Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support JWT auth #13

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion akasa-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ num_cpus = "1.14.0"
parking_lot = "0.12.1"
cfg-if = "1.0.0"
# rhai = { version = "1.11.0", features = ["decimal"] }
# jsonwebtoken = "8.2.0"
jsonwebtoken = "9.3.1"
# jwt-simple = "0.11.2"
serde = { version = "1.0.147", features = ["derive"] }
thiserror = "1.0.38"
Expand Down
113 changes: 113 additions & 0 deletions akasa-core/src/auth/jwt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use std::collections::HashMap;

use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error;

use crate::config::JwtSecret;

#[derive(Error, Debug, PartialEq, Eq)]
pub enum JwtDecodeError {
#[error("InitError")]
InitError,
#[error("ValidationError")]
ValidationError(#[from] jsonwebtoken::errors::Error),
}

#[derive(Error, Debug, PartialEq, Eq)]
pub enum JwtEncodeError {
#[error("InitError")]
InitError,
#[error("EncodeError")]
EncodeError(#[from] jsonwebtoken::errors::Error),
}

#[derive(Clone)]
struct Secret {
decoding_key: DecodingKey,
}

#[derive(Clone)]
pub struct JWT {
validation: Validation,
secrets: HashMap<String, Secret>,
header: Header,
encoding_key: Option<EncodingKey>,
}

impl Default for JWT {
fn default() -> Self {
let mut validation = Validation::default();
let required_spec: Vec<String> = vec![];
validation.set_required_spec_claims(&required_spec);
Self {
validation,
secrets: Default::default(),
header: Default::default(),
encoding_key: None,
}
}
}

impl JWT {
pub fn update_from(&mut self, m: &HashMap<String, JwtSecret>) {
for (name, secret) in m.iter() {
let decoding_key = match secret {
JwtSecret::HS256 { secret } => {
let b = secret.as_bytes();
self.header.alg = Algorithm::HS256;
let encoder = EncodingKey::from_secret(b);
self.encoding_key = Some(encoder);
DecodingKey::from_secret(b)
}
JwtSecret::HS384 { secret } => {
let b = secret.as_bytes();
self.header.alg = Algorithm::HS384;
let encoder = EncodingKey::from_secret(b);
self.encoding_key = Some(encoder);
DecodingKey::from_secret(b)
}
JwtSecret::HS512 { secret } => {
let b = secret.as_bytes();
self.header.alg = Algorithm::HS512;
let encoder = EncodingKey::from_secret(b);
self.encoding_key = Some(encoder);
DecodingKey::from_secret(b)
}
};
let s = Secret { decoding_key };
if let Some(_secret) = self.secrets.insert(name.to_string(), s) {
log::warn!("JWT secret replaced by name {name}");
}
}
}

pub fn encode<T>(&self, claims: T) -> Result<String, JwtEncodeError>
where
T: Serialize,
{
self.encoding_key
.as_ref()
.map(|encoder| {
encode(&self.header, &claims, encoder).map_err(JwtEncodeError::EncodeError)
})
.unwrap_or(Err(JwtEncodeError::InitError))
}

pub fn decode<T>(&self, token: &[u8]) -> Result<T, JwtDecodeError>
where
T: DeserializeOwned + std::fmt::Debug,
{
let token = String::from_utf8_lossy(token);
let mut e = JwtDecodeError::InitError;
for (_name, secret) in self.secrets.iter() {
match decode(&token, &secret.decoding_key, &self.validation) {
Ok(token) => {
return Ok(token.claims);
}
Err(err) => e = JwtDecodeError::ValidationError(err),
}
}
Err(e)
}
}
78 changes: 78 additions & 0 deletions akasa-core/src/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
pub mod jwt;
pub mod user;

use std::collections::HashMap;

use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::auth::user::User;
use crate::{config::JwtSecret, protocols::mqtt::check_password, AuthPassword};

#[derive(Default, Serialize, Deserialize, PartialEq, Eq, Debug)]
pub struct Claims {
#[serde(default, skip_serializing_if = "String::is_empty")]
pub sub: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub exp: Option<usize>,
#[serde(default, skip_serializing_if = "skip_false")]
pub superuser: bool,
}

fn skip_false(b: &bool) -> bool {
!b
}

impl From<Claims> for User {
fn from(value: Claims) -> Self {
Self {
username: value.sub,
superuser: value.superuser,
}
}
}

#[derive(Default)]
pub struct Auth {
pub allow_anonymous: bool,
passwords: DashMap<String, AuthPassword>,
pub jwt: jwt::JWT,
}

#[derive(Error, Debug)]
pub enum AuthError {
#[error("NotAuthorized")]
NotAuthorized,
}

impl Auth {
pub fn authorize(&self, username: &str, password: &[u8]) -> Result<User, AuthError> {
if check_password(&self.passwords, username, password) {
Ok(User::new(username))
} else if let Ok(mut claims) = self.jwt.decode::<Claims>(password) {
if claims.sub.is_empty() {
claims.sub = username.to_string();
}
Ok(claims.into())
} else if self.allow_anonymous {
Ok(User::super_user(username))
} else {
Err(AuthError::NotAuthorized)
}
}

pub fn update_passwords(&self, m: DashMap<String, AuthPassword>) {
for (k, v) in m {
self.passwords.insert(k, v);
}
}

pub fn update_password(&mut self, username: String, pswd: AuthPassword) {
self.passwords.insert(username, pswd);
}

pub fn update_jwt(&mut self, m: &HashMap<String, JwtSecret>) {
self.jwt.update_from(m);
}
}
32 changes: 32 additions & 0 deletions akasa-core/src/auth/user.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use mqtt_proto::{TopicFilter, TopicName};
use serde::{Deserialize, Serialize};

#[derive(Default, Serialize, Deserialize, Debug)]
pub struct User {
pub username: String,
pub superuser: bool,
}

impl User {
pub fn new(username: &str) -> Self {
Self {
username: username.to_string(),
..Default::default()
}
}
pub fn super_user(username: &str) -> Self {
Self {
username: username.to_string(),
superuser: true,
}
}

pub fn allow_read(&self, _tf: TopicFilter) -> bool {
// TODO
true
}
pub fn allow_write(&self, _topic: TopicName) -> bool {
// TODO
true
}
}
20 changes: 20 additions & 0 deletions akasa-core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,26 @@ pub struct ScramPasswordInfo {
pub salt: Vec<u8>,
}

#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "alg")]
pub enum JwtSecret {
HS256 { secret: String },
HS384 { secret: String },
HS512 { secret: String },
}

#[derive(Default, Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct JwtConfig {
#[serde(default)]
pub secrets_file: Option<PathBuf>,
}

#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct AuthConfig {
pub enable: bool,
pub password_file: Option<PathBuf>,
#[serde(default)]
pub jwt: JwtConfig,
}

#[derive(Serialize, Deserialize, Clone, Copy, Debug, Eq, PartialEq)]
Expand Down Expand Up @@ -174,6 +190,9 @@ impl Default for Config {
auth: AuthConfig {
enable: true,
password_file: Some(PathBuf::from("/path/to/passwords/file")),
jwt: JwtConfig {
secrets_file: Some(PathBuf::from("/path/to/secrets/jwt.yaml")),
},
},
scram_users: vec![("user", (b"***", 4096, b"salt"))]
.into_iter()
Expand Down Expand Up @@ -222,6 +241,7 @@ impl Config {
auth: AuthConfig {
enable: false,
password_file: None,
jwt: JwtConfig::default(),
},
..Default::default()
}
Expand Down
1 change: 1 addition & 0 deletions akasa-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod auth;
mod config;
mod hook;
mod protocols;
Expand Down
2 changes: 1 addition & 1 deletion akasa-core/src/protocols/mqtt/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mod auth;
pub mod auth;
mod common;
mod online_loop;
mod pending;
Expand Down
7 changes: 4 additions & 3 deletions akasa-core/src/protocols/mqtt/v3/packet/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use mqtt_proto::{
};
use tokio::io::AsyncWrite;

use crate::protocols::mqtt::{check_password, start_keep_alive_timer};
use crate::protocols::mqtt::start_keep_alive_timer;
use crate::state::{AddClientReceipt, ClientReceiver, GlobalState};

use super::super::Session;
Expand Down Expand Up @@ -71,7 +71,9 @@ clean session : {}
} else {
let username = packet.username.as_ref().unwrap();
let password = packet.password.as_ref().unwrap();
if !check_password(&global.auth_passwords, username, password) {
if let Ok(user) = global.auth.authorize(username, password) {
session.user = Some(Arc::new(user));
} else {
log::debug!("incorrect password for user: {}", username);
return_code = ConnectReturnCode::BadUserNameOrPassword;
}
Expand All @@ -95,7 +97,6 @@ clean session : {}
} else {
Arc::clone(&packet.client_id)
};
session.username = packet.username.map(|name| Arc::clone(&name));
session.keep_alive = packet.keep_alive;

if let Some(last_will) = packet.last_will {
Expand Down
5 changes: 3 additions & 2 deletions akasa-core/src/protocols/mqtt/v3/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use hashbrown::HashMap;
use mqtt_proto::{v3::LastWill, Pid, Protocol, QoS, TopicFilter, TopicName};
use parking_lot::RwLock;

use crate::auth::user::User;
use crate::config::Config;
use crate::state::{ClientId, ClientReceiver};

Expand All @@ -28,7 +29,7 @@ pub struct Session {
pub(super) client_id: ClientId,
pub client_identifier: Arc<String>,
pub assigned_client_id: bool,
pub username: Option<Arc<String>>,
pub user: Option<Arc<User>>,
pub keep_alive: u16,
pub clean_session: bool,
pub last_will: Option<LastWill>,
Expand Down Expand Up @@ -73,7 +74,7 @@ impl Session {
client_id: ClientId::max_value(),
client_identifier: Arc::new(String::new()),
assigned_client_id: false,
username: None,
user: None,
keep_alive: 0,
clean_session: true,
last_will: None,
Expand Down
7 changes: 4 additions & 3 deletions akasa-core/src/protocols/mqtt/v5/packet/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use scram::server::{AuthenticationStatus, ScramServer};
use tokio::io::AsyncWrite;

use crate::config::SaslMechanism;
use crate::protocols::mqtt::{check_password, start_keep_alive_timer};
use crate::protocols::mqtt::start_keep_alive_timer;
use crate::state::{AddClientReceipt, ClientReceiver, GlobalState};

use super::super::{ScramStage, Session, TracedRng};
Expand Down Expand Up @@ -57,7 +57,9 @@ pub(crate) async fn handle_connect<T: AsyncWrite + Unpin>(
} else {
let username = packet.username.as_ref().unwrap();
let password = packet.password.as_ref().unwrap();
if !check_password(&global.auth_passwords, username, password) {
if let Ok(user) = global.auth.authorize(username, password) {
session.user = Some(Arc::new(user));
} else {
log::debug!("incorrect password for user: {}", username);
reason_code = ConnectReasonCode::BadUserNameOrPassword;
}
Expand All @@ -80,7 +82,6 @@ pub(crate) async fn handle_connect<T: AsyncWrite + Unpin>(
} else {
Arc::clone(&packet.client_id)
};
session.username = packet.username;
session.keep_alive = if packet.keep_alive > global.config.max_keep_alive {
global.config.max_keep_alive
} else if packet.keep_alive < global.config.min_keep_alive {
Expand Down
5 changes: 3 additions & 2 deletions akasa-core/src/protocols/mqtt/v5/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use rand::{rngs::OsRng, RngCore};

use parking_lot::RwLock;

use crate::auth::user::User;
use crate::config::Config;
use crate::state::{ClientId, ClientReceiver};

Expand Down Expand Up @@ -46,7 +47,7 @@ pub struct Session {
pub(super) server_keep_alive: bool,
// (username, Option<role>)
pub scram_auth_result: Option<(String, Option<String>)>,
pub username: Option<Arc<String>>,
pub user: Option<Arc<User>>,
pub keep_alive: u16,
pub clean_start: bool,
pub last_will: Option<LastWill>,
Expand Down Expand Up @@ -112,7 +113,7 @@ impl Session {
assigned_client_id: false,
server_keep_alive: false,
scram_auth_result: None,
username: None,
user: None,
keep_alive: 0,
clean_start: true,
last_will: None,
Expand Down
Loading