From 67c23d6dd4d57ee74a12fa2a4ab3cc140e277487 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 15 Jul 2024 12:24:06 +0200 Subject: [PATCH] feat: base support --- Cargo.toml | 28 ++--- docs/configuration.md | 21 +--- src/api/client_server/account.rs | 2 - src/api/client_server/keys.rs | 32 +++--- src/api/client_server/session.rs | 20 ++-- src/api/client_server/sso.rs | 184 +++++++++++++++---------------- src/main.rs | 10 +- src/service/globals/mod.rs | 8 +- src/service/sso/mod.rs | 21 ++-- 9 files changed, 148 insertions(+), 178 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c52ed191..93abd753 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ axum = { version = "0.7", default-features = false, features = [ "json", "matched-path", ], optional = true } -axum-extra = { version = "0.9", features = ["typed-header", "cookie"] } +axum-extra = { version = "0.9", features = ["cookie", "typed-header"] } axum-server = { version = "0.6", features = ["tls-rustls"] } tower = { version = "0.4.13", features = ["util"] } tower-http = { version = "0.5", features = [ @@ -49,15 +49,6 @@ tower-http = { version = "0.5", features = [ "trace", "util", ] } -# tower-http = { version = "0.5", features = [ -# "add-extension", -# "cors", -# "decompression-full", -# "sensitive-headers", -# "set-header", -# "trace", -# "util", -# ] } tower-service = "0.3" # Async runtime and utilities @@ -153,11 +144,6 @@ figment = { version = "0.10.8", features = ["env", "toml"] } # Validating urls in config url = { version = "2", features = ["serde"] } -# HTML -mas-oidc-client = { git = "https://github.com/matrix-org/matrix-authentication-service", default-features = false } -mas-http = { git = "https://github.com/matrix-org/matrix-authentication-service", features = ["client"] } -maud = { version = "0.26.0", default-features = false, features = ["axum"] } - async-trait = "0.1.68" tikv-jemallocator = { version = "0.5.0", features = [ "unprefixed_malloc_on_supported_platforms", @@ -190,11 +176,21 @@ optional = true package = "rust-rocksdb" version = "0.25" +[dependencies.mas-http] +features = ["client"] +git = "https://github.com/matrix-org/matrix-authentication-service" +rev = "fbc360d1a94ef2ebf63d979bb403228a700f43c8" + +[dependencies.mas-oidc-client] +features = [] +git = "https://github.com/matrix-org/matrix-authentication-service" +rev = "fbc360d1a94ef2ebf63d979bb403228a700f43c8" + [target.'cfg(unix)'.dependencies] nix = { version = "0.28", features = ["resource"] } [features] -default = ["backend_sqlite", "conduit_bin"] +default = ["backend_rocksdb", "backend_sqlite", "conduit_bin", "systemd"] #backend_sled = ["sled"] backend_persy = ["parking_lot", "persy"] backend_sqlite = ["sqlite"] diff --git a/docs/configuration.md b/docs/configuration.md index a8fa07de..21375004 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -124,25 +124,6 @@ Identity providers using OAuth such as Github are not supported yet. | `name` | `string` | The name displayed on fallback pages. | `issuer` | | `icon` | `Url` OR `MxcUri` | The icon displayed on fallback pages. | N/A | | `scopes` | `array` | The scopes used to obtain extra claims which can be used for templates. | `["openid"]` | - - - - | `client_id`* | `string` | The provider-supplied, unique ID for the client. | N/A | | `client_secret`* | `string` | The provider-supplied, unique ID for the client. | N/A | -| `authentication_method`* | `"basic" | "post"` | The method used for client authentication. | N/A | - - - - - - - - - - - - - - - +| `authentication_method`* | `"basic" OR "post"` | The method used for client authentication. | N/A | diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index f688ff68..47ccdc83 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -322,8 +322,6 @@ pub async fn change_password_route( .ok_or_else(|| Error::BadRequest(ErrorKind::MissingToken, "Missing access token."))?; let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - // if services().users.password_hash(sender_user)? == Some(""); - let mut uiaainfo = UiaaInfo { flows: vec![AuthFlow { stages: vec![AuthType::Password], diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index 5dcea4fa..05110248 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -100,6 +100,12 @@ pub async fn upload_signing_keys_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let master_key = services() + .users + .get_master_key(Some(sender_user), sender_user, &|other| { + sender_user == other + })?; + // UIAA let mut uiaainfo = UiaaInfo { flows: vec![AuthFlow { @@ -111,11 +117,15 @@ pub async fn upload_signing_keys_route( auth_error: None, }; - let master_key = services() - .users - .get_master_key(None, sender_user, &|user_id| user_id == sender_user)?; - - if let Some(auth) = &body.auth { + if let (Some(master_key), None) = (&body.master_key, master_key) { + services().users.add_cross_signing_keys( + sender_user, + master_key, + &body.self_signing_key, + &body.user_signing_key, + true, + )?; + } else if let Some(auth) = &body.auth { let (worked, uiaainfo) = services() .uiaa @@ -130,20 +140,10 @@ pub async fn upload_signing_keys_route( .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); - } else if master_key.is_some() { + } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - if let Some(master_key) = &body.master_key { - services().users.add_cross_signing_keys( - sender_user, - master_key, - &body.self_signing_key, - &body.user_signing_key, - true, // notify so that other users see the new keys - )?; - } - Ok(upload_signing_keys::v3::Response {}) } diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index 148c67f5..0c1189ae 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -113,20 +113,24 @@ pub async fn login_route(body: Ruma) -> Result { match ( services().globals.jwt_decoding_key(), - services().sso.login_type().next().is_some(), + services().globals.config.idps.is_empty(), ) { (_, false) => { - let mut validation = Validation::new(Algorithm::HS256); - validation.validate_nbf = false; - validation.set_required_spec_claims(&["sub", "exp", "aud", "iss"]); + let mut v = Validation::new(Algorithm::HS256); + + v.set_required_spec_claims(&["sub", "exp", "aud", "iss"]); + v.validate_aud = false; + v.validate_nbf = false; services() .globals - .validate_claims::(token, Some(validation)) - .as_ref() + .validate_claims::(token, Some(&v)) .map(LoginToken::audience) - .map(ToOwned::to_owned) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid token."))? + .map_err(|e| { + tracing::warn!("Invalid token: {}", e); + + Error::BadRequest(ErrorKind::InvalidParam, "Invalid token.") + })? } (Some(jwt_decoding_key), _) => { let token = jsonwebtoken::decode::( diff --git a/src/api/client_server/sso.rs b/src/api/client_server/sso.rs index e0b7f9a2..a35439c9 100644 --- a/src/api/client_server/sso.rs +++ b/src/api/client_server/sso.rs @@ -7,15 +7,7 @@ use crate::{ }, services, utils, Error, Result, Ruma, }; -use axum::{ - response::{AppendHeaders, IntoResponse, Redirect}, - RequestExt, -}; -use axum_extra::{ - headers::{self}, - TypedHeader, -}; -use http::header::{self}; +use futures_util::TryFutureExt; use mas_oidc_client::{ requests::{ authorization_code::{self, AuthorizationRequestData}, @@ -24,7 +16,6 @@ use mas_oidc_client::{ }, types::{ client_credentials::ClientCredentials, - errors::ClientError, iana::jose::JsonWebSignatureAlg, requests::{AccessTokenResponse, AuthorizationResponse}, }, @@ -33,6 +24,7 @@ use rand::{rngs::StdRng, Rng, SeedableRng}; use ruma::{ api::client::{ error::ErrorKind, + media::create_content, session::{sso_login, sso_login_with_provider}, }, events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType}, @@ -46,7 +38,7 @@ pub const CALLBACK_PATH: &str = "/_matrix/client/unstable/conduit/callback"; /// # `GET /_matrix/client/v3/login/sso/redirect` /// -/// Redirect the user to the SSO interface. +/// Redirect the user to the SSO interfa. /// TODO: this should be removed once Ruma supports trailing slashes. pub async fn get_sso_redirect_route( Ruma { @@ -148,37 +140,25 @@ pub async fn get_sso_redirect_with_provider_route( }) } -async fn handle_callback_helper(req: axum::extract::Request) -> Result { - let query = req.uri().query().ok_or_else(|| { - Error::BadRequest(ErrorKind::MissingParam, "Empty authorization callback.") - })?; - - let AuthorizationResponse { - code, - access_token: _, - token_type: _, - id_token: _, - expires_in: _, - } = serde_html_form::from_str(query).map_err(|_| { - serde_html_form::from_str(query) - .map(ClientError::into) - .unwrap_or_else(|_| { - error!("Failed to deserialize authorization callback: {}", query); - - Error::BadRequest( - ErrorKind::Unknown, - "Failed to deserialize authorization callback.", - ) - }) - })?; - - let Ok(Some(cookie)): Result>, _> = req.extract().await - else { - return Err(Error::BadRequest( - ErrorKind::MissingParam, - "Missing session cookie.", - )); - }; +/// # `GET /_conduit/client/sso/callback` +/// +/// Validate the authorization response received from the identity provider. +/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect. +/// If this is the first login, register the user, possibly interactively through a fallback page. +pub async fn handle_callback_route( + body: Ruma, +) -> Result { + let sso_callback::Request { + response: + AuthorizationResponse { + code, + access_token: _, + token_type: _, + id_token: _, + expires_in: _, + }, + cookie, + } = body.body; let ValidationData { provider, @@ -186,12 +166,7 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result Result Result Result s.to_owned(), Some(Value::Number(n)) => n.to_string(), @@ -299,8 +281,13 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result break user_id, + .map(|user_id| { + ( + user_id.clone(), + services().users.exists(&user_id).unwrap_or(true), + ) + }) { + Ok((user_id, false)) => break user_id, _ => { let n: u8 = rand::thread_rng().gen(); @@ -310,12 +297,15 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result Result Result Result axum::response::Response { - match handle_callback_helper(req).await { - Ok(res) => res, - Err(e) => e.into_response(), - } + Ok(sso_login_with_provider::v3::Response { + location: redirect_url.to_string(), + cookie: Some(utils::build_cookie(SSO_SESSION_COOKIE, "", CALLBACK_PATH, None).to_string()), + }) } mod sso_callback { @@ -404,9 +406,9 @@ mod sso_callback { use mas_oidc_client::types::requests::AuthorizationResponse; use ruma::{ api::{ - client::Error, + client::{session::sso_login_with_provider, Error}, error::{FromHttpRequestError, HeaderDeserializationError}, - IncomingRequest, Metadata, OutgoingResponse, + IncomingRequest, Metadata, }, metadata, }; @@ -423,15 +425,13 @@ mod sso_callback { }; pub struct Request { - response: AuthorizationResponse, - cookie: String, + pub response: AuthorizationResponse, + pub cookie: String, } - pub struct Response {} - impl IncomingRequest for Request { type EndpointError = Error; - type OutgoingResponse = Response; + type OutgoingResponse = sso_login_with_provider::v3::Response; const METADATA: Metadata = METADATA; @@ -470,12 +470,4 @@ mod sso_callback { Ok(Self { response, cookie }) } } - - impl OutgoingResponse for Response { - fn try_into_http_response( - self, - ) -> Result, ruma::api::error::IntoHttpError> { - todo!() - } - } } diff --git a/src/main.rs b/src/main.rs index 15b59be4..34887460 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,10 +9,7 @@ use axum::{ Router, }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; -use conduit::api::{ - client_server::{self, CALLBACK_PATH}, - server_server, -}; +use conduit::api::{client_server, server_server}; use figment::{ providers::{Env, Format, Toml}, Figment, @@ -283,10 +280,7 @@ fn routes(config: &Config) -> Router { .ruma_route(client_server::get_sso_redirect_with_provider_route) // The specification will likely never introduce any endpoint for handling authorization callbacks. // As a workaround, we use custom path that redirects the user to the default login handler. - .route( - &format!("/{CALLBACK_PATH}"), - get(client_server::handle_callback_route), - ) + .ruma_route(client_server::handle_callback_route) .ruma_route(client_server::get_capabilities_route) .ruma_route(client_server::get_pushrules_all_route) .ruma_route(client_server::set_pushrule_route) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 9a3c7d6a..2a4b76ff 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -18,7 +18,7 @@ use ruma::{ DeviceId, RoomVersionId, ServerName, UserId, }; use std::{ - collections::{BTreeMap, HashMap}, + collections::{BTreeMap, HashMap, HashSet}, error::Error as StdError, fs, future::{self, Future}, @@ -522,7 +522,7 @@ impl Service { pub fn validate_claims( &self, token: &str, - validation_data: Option, + validation_data: Option<&jsonwebtoken::Validation>, ) -> jsonwebtoken::errors::Result { let key = jsonwebtoken::DecodingKey::from_secret( self.keypair().sign(PROBLEMATIC_CONST).as_bytes(), @@ -533,9 +533,9 @@ impl Service { // these validations are redundant as all JWTs are stored in cookies v.validate_exp = false; v.validate_nbf = false; - v.required_spec_claims = Default::default(); + v.required_spec_claims = HashSet::new(); - jsonwebtoken::decode::(token, &key, &validation_data.unwrap_or(v)) + jsonwebtoken::decode::(token, &key, validation_data.unwrap_or(&v)) .map(|data| data.claims) } diff --git a/src/service/sso/mod.rs b/src/service/sso/mod.rs index 6ffef5ab..ac14edbf 100644 --- a/src/service/sso/mod.rs +++ b/src/service/sso/mod.rs @@ -11,6 +11,7 @@ use crate::{ utils, Error, Result, }; use futures_util::future::{self}; +use http::HeaderValue; use mas_oidc_client::{ http_service::HttpService, requests::{authorization_code::AuthorizationValidationData, discovery}, @@ -20,6 +21,7 @@ use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUse use serde::{Deserialize, Serialize}; use tokio::sync::OnceCell; use tower::BoxError; +use tower_http::{set_header::SetRequestHeaderLayer, ServiceBuilderExt}; use tracing::error; use url::Url; @@ -43,14 +45,17 @@ impl Service { pub fn build(db: &'static dyn Data) -> Result> { let client = tower::ServiceBuilder::new() .map_err(BoxError::from) + .layer(tower_http::timeout::TimeoutLayer::new( + std::time::Duration::from_secs(10), + )) .layer(mas_http::BytesToBodyRequestLayer) .layer(mas_http::BodyToBytesResponseLayer) - // .override_request_header(http::header::USER_AGENT, "conduit".to_owned()) - // .concurrency_limit(10) - // .follow_redirects() - // .layer(tower_http::timeout::TimeoutLayer::new( - // std::time::Duration::from_secs(10), - // )) + .layer(SetRequestHeaderLayer::overriding( + http::header::USER_AGENT, + HeaderValue::from_static("conduit/0.9-alpha"), + )) + .concurrency_limit(10) + .follow_redirects() .service(mas_http::make_untraced_client()); Ok(Arc::new(Self { @@ -157,8 +162,8 @@ impl LoginToken { .expect("time overflow"), } } - pub fn audience(&self) -> &UserId { - &self.aud + pub fn audience(self) -> OwnedUserId { + self.aud } }