Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Major upgrade and refactor #15

Merged
merged 21 commits into from
Jun 11, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Bare-bones integration with core
masongup-mdsol committed May 31, 2024
commit 3e8a347001705a38d956ed04f09f4fa179f90e0e
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -36,6 +36,8 @@ axum = { version = ">= 0.7.2", optional = true }
futures-core = { version = ">= 0.3.25", optional = true }
thiserror = ">= 1.0.37"

mauth-core = "0.4"

[dev-dependencies]
tokio = { version = ">= 1.0.1", features = ["rt-multi-thread", "macros"] }

10 changes: 0 additions & 10 deletions build.rs
Original file line number Diff line number Diff line change
@@ -25,16 +25,6 @@ fn main() {
let formatted_name = name.replace('-', "_");
code_str.push_str(&format!(
r#"
#[tokio::test]
async fn {formatted_name}_string_to_sign() {{
test_string_to_sign("{name}".to_string()).await;
}}
#[tokio::test]
async fn {formatted_name}_sign_string() {{
test_sign_string("{name}".to_string()).await;
}}
#[tokio::test]
async fn {formatted_name}_generate_headers() {{
test_generate_headers("{name}".to_string()).await;
219 changes: 67 additions & 152 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -40,22 +40,22 @@ use std::sync::{Arc, RwLock};

use base64::Engine;
use chrono::prelude::*;
use percent_encoding::{percent_decode_str, percent_encode, AsciiSet, NON_ALPHANUMERIC};
use regex::{Captures, Regex};
use reqwest::{header::HeaderValue, Body, Client, Method, Request, Response, Url};
use ring::rand::SystemRandom;
use ring::signature::{
RsaKeyPair, UnparsedPublicKey, RSA_PKCS1_2048_8192_SHA512, RSA_PKCS1_SHA512,
};
use ring::signature::{UnparsedPublicKey, RSA_PKCS1_2048_8192_SHA512};
use serde::Deserialize;
use sha2::{Digest, Sha512};
use thiserror::Error;
use tokio::io;
use uuid::Uuid;

use openssl::pkey::{PKey, Private, Public};
use openssl::pkey::Public;
use openssl::rsa::{Padding, Rsa};

use mauth_core::signer::Signer;

#[cfg(feature = "axum-service")]
use mauth_core::verifier::Verifier;

const CONFIG_FILE: &str = ".mauth_config.yml";

/// This is the primary struct of this class. It contains all of the information
@@ -65,12 +65,11 @@ const CONFIG_FILE: &str = ".mauth_config.yml";
/// makes the struct non-Sync.
pub struct MAuthInfo {
app_id: Uuid,
private_key: RsaKeyPair,
openssl_private_key: Rsa<Private>,
remote_key_store: Arc<RwLock<HashMap<Uuid, Rsa<Public>>>>,
mauth_uri_base: Url,
sign_with_v1_also: bool,
allow_v1_auth: bool,
signer: Signer,
}

/// This struct holds the digest information required to perform the signing operation. It is a
@@ -143,18 +142,15 @@ impl MAuthInfo {
.parse()?;

let pk_data = std::fs::read_to_string(&section.private_key_file)?;
let openssl_key = PKey::private_key_from_pem(&pk_data.into_bytes())?;
let der_key_data = openssl_key.private_key_to_der()?;

Ok(MAuthInfo {
app_id: Uuid::parse_str(&section.app_uuid)?,
mauth_uri_base: full_uri,
remote_key_store: input_keystore
.unwrap_or_else(|| Arc::new(RwLock::new(HashMap::new()))),
private_key: RsaKeyPair::from_der(&der_key_data)?,
openssl_private_key: openssl_key.rsa()?,
sign_with_v1_also: !section.v2_only_sign_requests.unwrap_or(false),
allow_v1_auth: !section.v2_only_authenticate.unwrap_or(false),
signer: Signer::new(section.app_uuid.clone(), pk_data).unwrap(),
})
}

@@ -298,9 +294,18 @@ impl MAuthInfo {
/// shortly after the signature takes place.
pub fn sign_request_v2(&self, req: &mut Request, body_digest: &BodyDigest) {
let timestamp_str = Utc::now().timestamp().to_string();
let string_to_sign = self.get_signing_string_v2(req, body_digest, &timestamp_str);
let signature = self.sign_string_v2(string_to_sign);
self.set_headers_v2(req, signature, &timestamp_str);
let some_string = self
.signer
.sign_string(
2,
req.method().as_str(),
req.url().path(),
req.url().query().unwrap_or(""),
&body_digest.body_data,
timestamp_str.clone(),
)
.unwrap();
self.set_headers_v2(req, some_string, &timestamp_str);
}

#[cfg(feature = "axum-service")]
@@ -309,8 +314,8 @@ impl MAuthInfo {
req: &http::request::Parts,
body_bytes: &bytes::Bytes,
) -> Result<Uuid, MAuthValidationError> {
let mut hasher = Sha512::default();
hasher.update(body_bytes);
// let mut hasher = Sha512::default();
// hasher.update(body_bytes);

//retrieve and parse auth string
let sig_header = req
@@ -330,29 +335,26 @@ impl MAuthInfo {
.map_err(|_| MAuthValidationError::InvalidTime)?;
Self::validate_timestamp(ts_str)?;

//Compute response signing string
let encoded_query: String = req.uri.query().map_or("".to_string(), Self::encode_query);

let string_to_sign = format!(
"{}\n{}\n{}\n{}\n{}\n{}",
req.method,
Self::normalize_url(req.uri.path()),
&hex::encode(hasher.finalize()),
&host_app_uuid,
&ts_str,
&encoded_query
);

match self.get_app_pub_key(&host_app_uuid).await {
None => Err(MAuthValidationError::KeyUnavailable),
Some(pub_key) => {
let ring_key = UnparsedPublicKey::new(
&RSA_PKCS1_2048_8192_SHA512,
bytes::Bytes::from(pub_key.public_key_to_der_pkcs1().unwrap()),
);
match ring_key.verify(&string_to_sign.into_bytes(), &raw_signature) {
let verifier = Verifier::new(
host_app_uuid,
String::from_utf8(pub_key.public_key_to_der_pkcs1().unwrap()).unwrap(),
)
.unwrap();
match verifier.verify_signature(
2,
req.method.as_str(),
req.uri.path(),
req.uri.query().unwrap_or(""),
body_bytes,
ts_str,
String::from_utf8(raw_signature).unwrap(),
) {
Ok(true) => Ok(host_app_uuid),
Ok(false) => Err(MAuthValidationError::SignatureVerifyFailure),
Err(_) => Err(MAuthValidationError::SignatureVerifyFailure),
Ok(()) => Ok(host_app_uuid),
}
}
}
@@ -382,63 +384,31 @@ impl MAuthInfo {
.map_err(|_| MAuthValidationError::InvalidTime)?;
Self::validate_timestamp(ts_str)?;

//Compute response signing string
let mut hasher = Sha512::default();
let string_to_sign1 = format!("{}\n{}\n", req.method, req.uri.path());
hasher.update(string_to_sign1.into_bytes());
hasher.update(body_bytes);
let string_to_sign2 = format!("\n{}\n{}", &host_app_uuid, &ts_str);
hasher.update(string_to_sign2.into_bytes());
let sign_input: Vec<u8> = hex::encode(hasher.finalize()).into_bytes();

match self.get_app_pub_key(&host_app_uuid).await {
None => Err(MAuthValidationError::KeyUnavailable),
Some(pub_key) => {
let mut sign_output: Vec<u8> = vec![0; pub_key.size() as usize];
let len = pub_key
.public_decrypt(&raw_signature, &mut sign_output, Padding::PKCS1)
.unwrap();
if *sign_input.as_slice() == sign_output[0..len] {
Ok(host_app_uuid)
} else {
Err(MAuthValidationError::SignatureVerifyFailure)
let verifier = Verifier::new(
host_app_uuid,
String::from_utf8(pub_key.public_key_to_der_pkcs1().unwrap()).unwrap(),
)
.unwrap();
match verifier.verify_signature(
1,
req.method.as_str(),
req.uri.path(),
req.uri.query().unwrap_or(""),
body_bytes,
ts_str,
String::from_utf8(raw_signature).unwrap(),
) {
Ok(true) => Ok(host_app_uuid),
Ok(false) => Err(MAuthValidationError::SignatureVerifyFailure),
Err(_) => Err(MAuthValidationError::SignatureVerifyFailure),
}
}
}
}

fn get_signing_string_v2(
&self,
req: &Request,
body_digest: &BodyDigest,
timestamp_str: &str,
) -> String {
let encoded_query: String = req.url().query().map_or("".to_string(), Self::encode_query);
format!(
"{}\n{}\n{}\n{}\n{}\n{}",
req.method(),
Self::normalize_url(req.url().path()),
&body_digest.digest_str,
&self.app_id,
&timestamp_str,
&encoded_query
)
}

fn sign_string_v2(&self, string: String) -> String {
let mut signature = vec![0; self.private_key.public().modulus_len()];
self.private_key
.sign(
&RSA_PKCS1_SHA512,
&SystemRandom::new(),
&string.into_bytes(),
&mut signature,
)
.unwrap();
let b64 = base64::engine::general_purpose::STANDARD;
b64.encode(&signature)
}

fn set_headers_v2(&self, req: &mut Request, signature: String, timestamp_str: &str) {
let sig_head_str = format!("MWSV2 {}:{};", self.app_id, &signature);
let headers = req.headers_mut();
@@ -449,53 +419,6 @@ impl MAuthInfo {
);
}

const MAUTH_ENCODE_CHARS: &'static AsciiSet = &NON_ALPHANUMERIC
.remove(b'-')
.remove(b'_')
.remove(b'%')
.remove(b'.')
.remove(b'~');

fn encode_query(qstr: &str) -> String {
let mut temp_param_list: Vec<Vec<Vec<u8>>> = qstr
.split('&')
.map(|p| {
p.split('=')
.map(|x| percent_decode_str(&x.replace('+', " ")).collect())
.collect()
})
.collect();

temp_param_list.sort();
temp_param_list
.iter()
.map(|p| {
p.iter()
.map(|x| percent_encode(x, Self::MAUTH_ENCODE_CHARS).to_string())
.collect::<Vec<String>>()
.join("=")
})
.collect::<Vec<String>>()
.join("&")
}

fn normalize_url(urlstr: &str) -> String {
let squeeze_regex = Regex::new(r"/+").unwrap();
let url = squeeze_regex.replace_all(urlstr, "/");
let percent_case_regex = Regex::new(r"%[a-f0-9]{2}").unwrap();
let url = percent_case_regex.replace_all(&url, |c: &Captures| c[0].to_uppercase());
let mut url = url.replace("/./", "/");
let path_regex2 = Regex::new(r"/[^/]+/\.\./?").unwrap();
loop {
let new_url = path_regex2.replace_all(&url, "/").to_string();
if new_url == url {
return new_url;
} else {
url = new_url;
}
}
}

/// Sign a provided request using the MAuth V1 protocol. The signature consists of 2 headers
/// containing both a timestamp and a signature string, and will be added to the headers of the
/// request. It is required to pass a `body`, even if the request is an empty-body GET.
@@ -504,30 +427,22 @@ impl MAuthInfo {
/// shortly after the signature takes place.
pub fn sign_request_v1(&self, req: &mut Request, body: &BodyDigest) {
let timestamp_str = Utc::now().timestamp().to_string();
let mut hasher = Sha512::default();
let string_to_sign1 = format!("{}\n{}\n", req.method(), req.url().path());
hasher.update(string_to_sign1.into_bytes());
hasher.update(body.body_data.clone());
let string_to_sign2 = format!("\n{}\n{}", &self.app_id, &timestamp_str);
hasher.update(string_to_sign2.into_bytes());

let mut sign_output = vec![0; self.openssl_private_key.size() as usize];
self.openssl_private_key
.private_encrypt(
&hex::encode(hasher.finalize()).into_bytes(),
&mut sign_output,
Padding::PKCS1,

let sig = self
.signer
.sign_string(
1,
req.method().as_str(),
req.url().path(),
req.url().query().unwrap_or(""),
&body.body_data,
timestamp_str.clone(),
)
.unwrap();
let b64 = base64::engine::general_purpose::STANDARD;
let signature = format!("MWS {}:{}", self.app_id, b64.encode(&sign_output));

let headers = req.headers_mut();
headers.insert("X-MWS-Time", HeaderValue::from_str(&timestamp_str).unwrap());
headers.insert(
"X-MWS-Authentication",
HeaderValue::from_str(&signature).unwrap(),
);
headers.insert("X-MWS-Authentication", HeaderValue::from_str(&sig).unwrap());
}

fn validate_timestamp(timestamp_str: &str) -> Result<(), MAuthValidationError> {
57 changes: 0 additions & 57 deletions src/protocol_test_suite.rs
Original file line number Diff line number Diff line change
@@ -5,14 +5,6 @@ use tokio::fs;

use std::path::{Path, PathBuf};

#[derive(Deserialize)]
struct RequestShape {
verb: String,
url: String,
body: Option<String>,
body_filepath: Option<String>,
}

#[derive(Deserialize)]
struct TestSignConfig {
app_uuid: String,
@@ -43,55 +35,6 @@ async fn setup_mauth_info() -> (MAuthInfo, u64) {
)
}

async fn test_string_to_sign(file_name: String) {
let (mauth_info, req_time) = setup_mauth_info().await;
let mut req_file_path = PathBuf::from(&BASE_PATH);
req_file_path.push(format!("{name}/{name}.req", name = &file_name));
let request_shape: RequestShape =
serde_json::from_slice(&fs::read(req_file_path).await.unwrap()).unwrap();

let mut sts_file_path = PathBuf::from(&BASE_PATH);
sts_file_path.push(format!("{name}/{name}.sts", name = &file_name));
let expected_string_to_sign =
String::from_utf8(fs::read(sts_file_path).await.unwrap()).unwrap();

let body_data = match (request_shape.body, request_shape.body_filepath) {
(Some(direct_str), None) => direct_str.as_bytes().to_vec(),
(None, Some(filename_str)) => {
let mut body_file_path = PathBuf::from(&BASE_PATH);
body_file_path.push(&file_name);
body_file_path.push(filename_str);
fs::read(body_file_path).await.unwrap()
}
_ => vec![],
};

let (body, digest) = MAuthInfo::build_body_with_digest_from_bytes(body_data);
// It seems the Url class really doesn't like relative URLs
let fixed_url = format!("http://a.com{}", request_shape.url.replace(" ", "%20"));
let method = Method::from_bytes(request_shape.verb.as_bytes()).unwrap();
let mut req = Request::new(method, fixed_url.parse().unwrap());
*req.body_mut() = Some(body);
let sts = mauth_info.get_signing_string_v2(&req, &digest, &req_time.to_string());

assert_eq!(expected_string_to_sign, sts);
}

async fn test_sign_string(file_name: String) {
let (mauth_info, _) = setup_mauth_info().await;
let mut sts_file_path = PathBuf::from(&BASE_PATH);
sts_file_path.push(format!("{name}/{name}.sts", name = &file_name));
let string_to_sign = String::from_utf8(fs::read(sts_file_path).await.unwrap()).unwrap();

let mut sig_file_path = PathBuf::from(&BASE_PATH);
sig_file_path.push(format!("{name}/{name}.sig", name = &file_name));
let expected_sig = String::from_utf8(fs::read(sig_file_path).await.unwrap()).unwrap();

let signed = mauth_info.sign_string_v2(string_to_sign);

assert_eq!(expected_sig, signed);
}

async fn test_generate_headers(file_name: String) {
let (mauth_info, req_time) = setup_mauth_info().await;