diff --git a/Cargo.toml b/Cargo.toml index bc079115..044eefea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,18 +35,15 @@ axum = { version = "0.7", default-features = false, features = [ "json", "matched-path", ], optional = true } -axum-extra = { version = "0.9", features = ["cookie", "typed-header"] } +axum-extra = { version = "0.9", features = ["typed-header", "cookie"] } axum-server = { version = "0.6", features = ["tls-rustls"] } tower = { version = "0.4.13", features = ["util"] } -# tower-http = { version = "0.5", features = [ -# "add-extension", -# "cors", -# "sensitive-headers", -# "trace", -# "util", -# ] } tower-http = { version = "0.5", features = [ - "full", + "add-extension", + "cors", + "sensitive-headers", + "trace", + "util", ] } tower-service = "0.3" @@ -143,6 +140,7 @@ figment = { version = "0.10.8", features = ["env", "toml"] } # Validating urls in config url = { version = "2", features = ["serde"] } +mas-oidc-client = { version = "0.9", default-features = false, features = ["hyper"] } # HTML maud = { version = "0.26.0", default-features = false, features = ["axum"] } @@ -152,19 +150,14 @@ tikv-jemallocator = { version = "0.5.0", features = [ ], optional = true } sd-notify = { version = "0.4.1", optional = true } -http-body-util = "0.1.2" -hyper-rustls = { version = "0.27.2", default-features = false, features = ["http1", "http2", "ring", "rustls-native-certs", "rustls-platform-verifier"] } -mas-http = "0.9.0" # Used for matrix spec type definitions and helpers [dependencies.ruma] features = [ "appservice-api-c", - "client", "client-api", "compat", "federation-api", - "client-hyper", "push-gateway-api-c", "rand", "ring-compat", @@ -183,16 +176,11 @@ optional = true package = "rust-rocksdb" version = "0.25" -# Used for Single Sign-On -[dependencies.mas-oidc-client] -git = "https://github.com/matrix-org/matrix-authentication-service.git" -default-features = false - [target.'cfg(unix)'.dependencies] nix = { version = "0.28", features = ["resource"] } [features] -default = ["backend_rocksdb", "backend_sqlite", "conduit_bin", "systemd"] +default = ["backend_sqlite", "conduit_bin"] #backend_sled = ["sled"] backend_persy = ["parking_lot", "persy"] backend_sqlite = ["sqlite"] diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index aea8d3ac..148c67f5 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -1,5 +1,6 @@ use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use crate::{service::sso::LoginToken, services, utils, Error, Result, Ruma}; +use jsonwebtoken::{Algorithm, Validation}; use ruma::{ api::client::{ error::ErrorKind, @@ -24,17 +25,16 @@ struct Claims { pub async fn get_login_types_route( _body: Ruma, ) -> Result { + let identity_providers: Vec<_> = services().sso.login_type().collect(); let mut flows = vec![ get_login_types::v3::LoginType::Password(Default::default()), get_login_types::v3::LoginType::ApplicationService(Default::default()), ]; - if let v @ [_, ..] = &*services().sso.flows() { - let flow = get_login_types::v3::SsoLoginType { - identity_providers: v.to_owned(), - }; - - flows.push(get_login_types::v3::LoginType::Sso(flow)); + if !identity_providers.is_empty() { + flows.push(get_login_types::v3::LoginType::Sso( + get_login_types::v3::SsoLoginType { identity_providers }, + )); } Ok(get_login_types::v3::Response::new(flows)) @@ -113,30 +113,26 @@ pub async fn login_route(body: Ruma) -> Result { match ( services().globals.jwt_decoding_key(), - &services().sso.providers().is_empty(), + services().sso.login_type().next().is_some(), ) { (_, false) => { - let mut validation = - jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256); + let mut validation = Validation::new(Algorithm::HS256); validation.validate_nbf = false; validation.set_required_spec_claims(&["sub", "exp", "aud", "iss"]); - let login_token = services() + services() .globals .validate_claims::(token, Some(validation)) - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid token.") - })?; - - login_token.audience().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid token audience.") - })? + .as_ref() + .map(LoginToken::audience) + .map(ToOwned::to_owned) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid token."))? } (Some(jwt_decoding_key), _) => { let token = jsonwebtoken::decode::( token, jwt_decoding_key, - &jsonwebtoken::Validation::default(), + &Validation::default(), ) .map_err(|_| { Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.") diff --git a/src/api/client_server/sso.rs b/src/api/client_server/sso.rs index c6d26d27..c5e3b0e3 100644 --- a/src/api/client_server/sso.rs +++ b/src/api/client_server/sso.rs @@ -1,30 +1,24 @@ use std::{borrow::Borrow, collections::HashMap, iter::Iterator, time::SystemTime}; use crate::{ - config::{ - sso::{Registration, Template}, - IdpConfig, - }, + config::IdpConfig, service::sso::{ - templates, LoginToken, RegistrationInfo, RegistrationToken, ValidationData, - REGISTRATION_EXPIRATION_SECS, SESSION_EXPIRATION_SECS, SSO_AUTH_EXPIRATION_SECS, - SSO_SESSION_COOKIE, + LoginToken, ValidationData, SSO_AUTH_EXPIRATION_SECS, SSO_SESSION_COOKIE, SUBJECT_CLAIM_KEY, }, services, utils, Error, Result, Ruma, }; use axum::{ - extract::RawQuery, response::{AppendHeaders, IntoResponse, Redirect}, RequestExt, }; use axum_extra::{ - headers::{self, HeaderMapExt}, + headers::{self}, TypedHeader, }; use http::header; use mas_oidc_client::{ requests::{ - authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData}, + authorization_code::{self, AuthorizationRequestData}, jose::{self, JwtVerificationData}, userinfo, }, @@ -35,17 +29,17 @@ use mas_oidc_client::{ requests::{AccessTokenResponse, AuthorizationResponse}, }, }; -use rand::{rngs::StdRng, SeedableRng}; +use rand::{rngs::StdRng, Rng, SeedableRng}; use ruma::{ api::client::{ error::ErrorKind, - session::{self, sso_login, sso_login_with_provider}, + session::{sso_login, sso_login_with_provider}, }, - events::GlobalAccountDataEventType, - push, OwnedMxcUri, UserId, + events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType}, + push, UserId, }; -use serde_json::Number; -use tracing::error; +use serde_json::Value; +use tracing::{error, info, warn}; use url::Url; pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback"; @@ -54,17 +48,28 @@ pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback"; /// /// Redirect the user to the SSO interface. /// TODO: this should be removed once Ruma supports trailing slashes. -pub async fn get_sso_redirect( - body: Ruma, +pub async fn get_sso_redirect_route( + Ruma { + body, + sender_user, + sender_device, + sender_servername, + json_body, + .. + }: Ruma, ) -> Result { let sso_login_with_provider::v3::Response { location, cookie } = - get_sso_redirect_with_provider( + get_sso_redirect_with_provider_route( Ruma { body: sso_login_with_provider::v3::Request::new( Default::default(), - body.redirect_url.clone(), + body.redirect_url, ), - ..body + sender_user, + sender_device, + sender_servername, + json_body, + appservice_info: None, } .into(), ) @@ -76,7 +81,7 @@ pub async fn get_sso_redirect( /// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}` /// /// Redirects the user to the SSO interface. -pub async fn get_sso_redirect_with_provider( +pub async fn get_sso_redirect_with_provider_route( body: Ruma, ) -> Result { let idp_ids: Vec<&str> = services() @@ -124,7 +129,7 @@ pub async fn get_sso_redirect_with_provider( .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to build authorization_url."))?; let signed = services().globals.sign_claims(&ValidationData::new( - provider.borrow().to_string(), + Borrow::::borrow(provider).to_owned(), validation_data, )); @@ -142,12 +147,7 @@ pub async fn get_sso_redirect_with_provider( }) } -/// # `GET /_conduit/client/sso/callback` -/// -/// Validate the authorization response received from the identity provider. -/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect. -/// If this is the first login, register the user, possibly interactively through a fallback page. -pub async fn get_sso_callback(req: axum::extract::Request) -> Result { +async fn handle_callback_helper(req: axum::extract::Request) -> Result { let query = req.uri().query().ok_or_else(|| { Error::BadRequest(ErrorKind::MissingParam, "Empty authorization callback.") })?; @@ -158,22 +158,26 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result(query).map_err(|_| { - serde_html_form::from_str::(query).unwrap_or_else(|_| { - error!("Failed to deserialize authorization callback: {}", callback); + } = serde_html_form::from_str(query).map_err(|_| { + serde_html_form::from_str(query) + .map(ClientError::into) + .unwrap_or_else(|_| { + error!("Failed to deserialize authorization callback: {}", query); - Error::BadRequest( - ErrorKind::Unknown, - "Failed to deserialize authorization callback.", - ) - }) + Error::BadRequest( + ErrorKind::Unknown, + "Failed to deserialize authorization callback.", + ) + }) })?; - let cookie = req - .extract::>>() - .await - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid session cookie."))? - .ok_or_else(|_| Error::BadRequest(ErrorKind::MissingParam, "Missing session cookie."))?; + let Ok(Some(cookie)): Result>, _> = req.extract().await + else { + return Err(Error::BadRequest( + ErrorKind::MissingParam, + "Missing session cookie.", + )); + }; let ValidationData { provider, @@ -186,11 +190,11 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result Result ClientCredentials::ClientSecretBasic { client_id, client_secret, @@ -215,6 +219,16 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result todo!(), }; + let ref jwks = jose::fetch_jwks(services().sso.service(), provider.metadata.jwks_uri()) + .await + .map_err(|_| Error::bad_config("Failed to fetch signing keys for token endpoint."))?; + let jwt_verification_data = Some(JwtVerificationData { + jwks, + issuer: &provider.config.issuer, + client_id: &provider.config.client_id, + signing_algorithm: &JsonWebSignatureAlg::Rs256, + }); + let ( AccessTokenResponse { access_token, @@ -227,34 +241,22 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result Result { + let subject = match id_token.get(SUBJECT_CLAIM_KEY) { + Some(Value::String(s)) => s.to_owned(), + Some(Value::Number(n)) => n.to_string(), + value => { return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Single Sign-On registration is disabled.", - )) + ErrorKind::Unknown, + value + .map(|_| { + error!("Subject claim is missing from ID token: {id_token:?}"); + + "Subject claim is missing from ID token." + }) + .unwrap_or("Subject claim should be a string or number."), + )); } - Registration::Automated => todo!(), - Registration::Interactive => {} }; - let Template { - username, - displayname, - avatar_url, - email, - } = &provider.config.template; - let registration_info = - RegistrationInfo::new(&claims, username, displayname, avatar_url, email); - - let signed = services() - .globals - .sign_macaroon(&RegistrationToken::new( - validation_data.provider_id.clone(), - subject.to_owned(), - redirect_uri.to_owned(), - registration_info, - )) - .expect("signing macaroons always works"); - - let cookie = utils::build_cookie( - "sso-registration", - &signed, - "/_conduit/client/sso/register", - REGISTRATION_EXPIRATION_SECS, - ); - - Ok(( - AppendHeaders(vec![ - (header::SET_COOKIE, cookie.to_string()), - ( - header::SET_COOKIE, - utils::reset_cookie("sso-session").to_string(), - ), - ]), - Redirect::temporary("/_conduit/client/sso/register"), - ) - .into_response()) -} - -/// # `GET /_conduit/client/sso/pick_idp` -pub async fn pick_idp(RawQuery(query): RawQuery) -> impl IntoResponse { - let providers: Vec<_> = services() - .globals - .config + let user_id = match services() .sso - .iter() - .map(|p| p.inner.to_owned()) - .collect(); + .user_from_subject(Borrow::::borrow(provider), &subject)? + { + Some(user_id) => user_id, + None => { + let mut localpart = subject.clone(); - let body = maud::html! { - header { - h1 { "Log in to " (services().globals.server_name()) } - p { "Choose an identity provider to continue" } - } - main { - ul .providers { - @for provider in providers { - li { - a href={ "/_matrix/client/v3/login/sso/redirect/" (provider.id) "?" (query.as_deref().unwrap_or_default()) } { - @if let Some(url) = provider.icon.as_deref().and_then(utils::mxc_to_http) { - img src=(url); - } - } - span { - (provider.name) - } + let user_id = loop { + match UserId::parse_with_server_name(&*localpart, services().globals.server_name()) + { + Ok(user_id) if services().users.exists(&user_id)? => break user_id, + _ => { + let n: u8 = rand::thread_rng().gen(); + + localpart = format!("{}{}", localpart, n % 10); } } + }; + + services().users.set_placeholder_password(&user_id)?; + let mut displayname = id_token + .get("preferred_username") + .or(id_token.get("nickname")) + .as_deref() + .map(Value::to_string) + .unwrap_or(user_id.localpart().to_owned()); + + // If enabled append lightning bolt to display name (default true) + if services().globals.enable_lightning_bolt() { + displayname.push_str(" ⚡️"); } + + services() + .users + .set_displayname(&user_id, Some(displayname.clone()))?; + + // Initial account data + services().account_data.update( + None, + &user_id, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json always works"), + )?; + + info!("New user {} registered on this server.", user_id); + services() + .admin + .send_message(RoomMessageEventContent::notice_plain(format!( + "New user {user_id} registered on this server." + ))); + + if let Some(admin_room) = services().admin.get_admin_room()? { + if services() + .rooms + .state_cache + .room_joined_count(&admin_room)? + == Some(1) + { + services() + .admin + .make_user_admin(&user_id, displayname) + .await?; + + warn!("Granting {} admin privileges as the first user", user_id); + } + } + + user_id } }; - ( - [(header::CONTENT_TYPE, "text/html; charset=utf-8")], - maud::html! { - (templates::base("Pick Identity Provider", body)) + let signed = services().globals.sign_claims(&LoginToken::new( + Borrow::::borrow(provider).to_owned(), + user_id, + )); - (templates::footer()) - }, - ) -} - -/// # `GET /_conduit/client/sso/register` -/// -/// Serve a registration form with defaults based on the retrieved claims. -/// This endpoint is only available when interactive registration is enabled. -pub async fn get_sso_registration( - cookie: TypedHeader, -) -> Result { - let token = cookie.get("sso-registration").ok_or_else(|| { - Error::BadRequest( - ErrorKind::MissingParam, - "Missing registration token cookie.", - ) - })?; - - let registration_token: RegistrationToken = services() - .globals - .validate_macaroon(token, None) - .map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid registration token cookie.", - ) - })?; - - let provider = services() - .sso - .get(®istration_token.provider_id) - .map(|p| p.config.inner.to_owned())?; - let server_name = services().globals.server_name(); - - let RegistrationInfo { - username, - displayname, - avatar_url, - email, - } = registration_token.info; - - let additional_info = (&displayname, &avatar_url, &email) != (&None, &None, &None); - - fn detail(title: &str, body: maud::Markup) -> maud::Markup { - maud::html! { - label .detail for=(title) { - div .check-row { - span .name { (title) } " " - span .use { "use" } - input #(title) type="checkbox" name={(title)"-checkbox"} value=(true) checked; - } - (body) - } - } - } - - let body = maud::html! { - header { - h1 { "Complete your registration at " (server_name) } - p { "Confirm your details to finish creating your account." } - } - main { - form .form #form method="post" { - div .username-div #username-div { - label for="username-input" { "Username (required)" } - div .prefix { "@" } - input .username-input type="text" name="username" - value=(username) autofocus autocorrect="off" autocapitalize="none"; - div .postfix { ":" (server_name) } - } - output .username-output for="username-input" { } - - @if additional_info { - section .additional-info { - h2 { - @if let Some(icon) = provider.icon.as_deref().and_then(utils::mxc_to_http) { - img src=(icon.to_string()); - } - "Optional data from " (provider.name) - } - @if let Some(avatar_url) = avatar_url.as_ref() { - (detail("avatar", maud::html!{ - img .avatar src=(avatar_url); - })) - } - @if let Some(displayname) = displayname.as_ref() { - (detail("displayname", maud::html!{ - p .value { (displayname) }; - })) - } - @if let Some(email) = email.as_ref() { - (detail("email", maud::html!{ - p .value { (email) }; - })) - } - } - } - - input type="submit" value="Submit" .primary-button {} - } - } - }; + let mut redirect_uri = validation_data.redirect_uri; + redirect_uri + .query_pairs_mut() + .append_pair("loginToken", &signed); Ok(( - [(header::CONTENT_TYPE, "text/html; charset=utf-8")], - maud::html! { - (templates::base("Register Account", body)) - - (templates::footer()) - }, - ) - .into_response()) -} - -/// # `POST /_conduit/client/sso/register` -/// -/// Submit the registration form. -pub async fn submit_sso_registration( - cookie: TypedHeader, - axum::extract::Form(registration_info): axum::extract::Form, -) -> Result { - let token = cookie.get("sso-registration").ok_or_else(|| { - Error::BadRequest( - ErrorKind::MissingParam, - "Missing registration token cookie.", - ) - })?; - - let registration_token: RegistrationToken = services() - .globals - .validate_macaroon(token, None) - .map_err(|_| { - Error::BadRequest( - ErrorKind::MissingParam, - "Invalid registration token cookie.", - ) - })?; - - let RegistrationInfo { - username, - mut displayname, - avatar_url, - email: _, - } = registration_info; - - let user_id = - UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Invalid username."))?; - - if services().users.exists(&user_id)? { - return Err(Error::BadRequest( - ErrorKind::UserInUse, - "Desired UserId is already taken.", - )); - } - - if services().appservice.is_exclusive_user_id(&user_id).await { - return Err(Error::BadRequest( - ErrorKind::Exclusive, - "Desired UserId reserved by appservice.", - )); - } - - services().users.create(&user_id, None)?; - services().users.set_password_placeholder(&user_id)?; - - if let Some(avatar_url) = avatar_url { - let request = services().globals.default_client().get(avatar_url.as_ref()); - - let res = request.send().await.map_err(|_| { - Error::BadRequest(ErrorKind::UserInUse, "Desired UserId is already taken.") - })?; - - let filename = avatar_url.path_segments().and_then(Iterator::last); - - let (content_type, body): (Option, Vec) = ( - res.headers().typed_get(), - res.bytes().await.map(Into::into).map_err(|_| { - Error::BadRequest(ErrorKind::UserInUse, "Desired UserId is already taken.") - })?, - ); - - let mxc = format!( - "mxc://{}/{}", - services().globals.server_name(), - utils::random_string(crate::api::client_server::MXC_LENGTH) - ); - - services() - .media - .create( - mxc.clone(), - filename - .map(|filename| "inline; filename=".to_owned() + filename) - .as_deref(), - content_type.map(|header| header.to_string()).as_deref(), - &body, - ) - .await?; - - services() - .users - .set_avatar_url(&user_id, Some(OwnedMxcUri::from(mxc)))?; - }; - - if let (Some(displayname), true) = ( - displayname.as_mut(), - services().globals.config.enable_lightning_bolt, - ) { - displayname.push_str(" ⚡️"); - } - - services().users.set_displayname(&user_id, displayname)?; - - services().sso.save_claim( - ®istration_token.provider_id, - &user_id, - ®istration_token.unique_claim, - )?; - - services().account_data.update( - None, - &user_id, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: push::Ruleset::server_default(&user_id), - }, - }) - .expect("PushRulesEvent should always serialize"), - )?; - - let login_token = LoginToken::new(registration_token.provider_id, user_id); - let redirect_uri = redirect_with_login_token(registration_token.redirect_uri, &login_token); - - Ok(( - AppendHeaders([( + AppendHeaders(vec![( header::SET_COOKIE, - utils::reset_cookie("sso-registration").to_string(), + utils::build_cookie(SSO_SESSION_COOKIE, "", CALLBACK_PATH, None).to_string(), )]), Redirect::temporary(redirect_uri.as_str()), ) .into_response()) } -fn redirect_with_login_token(mut redirect_uri: Url, login_token: &LoginToken) -> Url { - let signed = services() - .globals - .sign_macaroon(login_token) - .expect("signing macaroons should always works"); - - redirect_uri - .query_pairs_mut() - .append_pair("loginToken", &signed); - - redirect_uri +/// # `GET /_conduit/client/sso/callback` +/// +/// Validate the authorization response received from the identity provider. +/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect. +/// If this is the first login, register the user, possibly interactively through a fallback page. +pub async fn handle_callback_route(req: axum::extract::Request) -> axum::response::Response { + match handle_callback_helper(req).await { + Ok(res) => res, + Err(e) => e.into_response(), + } } diff --git a/src/database/key_value/sso.rs b/src/database/key_value/sso.rs new file mode 100644 index 00000000..1f6eab28 --- /dev/null +++ b/src/database/key_value/sso.rs @@ -0,0 +1,29 @@ +use ruma::{OwnedUserId, UserId}; + +use crate::{service, utils, Error, KeyValueDatabase, Result}; + +impl service::sso::Data for KeyValueDatabase { + fn save_subject(&self, provider: &str, user_id: &UserId, subject: &str) -> Result<()> { + let mut key = provider.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(subject.as_bytes()); + + self.subject_userid.insert(&key, user_id.as_bytes()) + } + + fn user_from_subject(&self, provider: &str, subject: &str) -> Result> { + let mut key = provider.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(subject.as_bytes()); + + self.subject_userid.get(&key)?.map_or(Ok(None), |bytes| { + Some( + UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("User ID in claim_userid is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("User ID in claim_userid is invalid.")), + ) + .transpose() + }) + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 16a5e60a..39550a93 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -49,6 +49,7 @@ pub struct KeyValueDatabase { pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 pub(super) token_userdeviceid: Arc, + pub(super) subject_userid: Arc, pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count @@ -289,6 +290,8 @@ impl KeyValueDatabase { userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, token_userdeviceid: builder.open_tree("token_userdeviceid")?, + subject_userid: builder.open_tree("subject_userid")?, + onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, keychangeid_userid: builder.open_tree("keychangeid_userid")?, diff --git a/src/main.rs b/src/main.rs index 0b07fe2b..c3ad4c1e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,7 @@ use axum::{ }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; use conduit::api::{ - client_server::{self, SSO_CALLBACK_PATH}, + client_server::{self, CALLBACK_PATH}, server_server, }; use figment::{ @@ -283,7 +283,7 @@ fn routes(config: &Config) -> Router { .ruma_route(client_server::get_sso_redirect_with_provider_route) // The specification will likely never introduce any endpoint for handling authorization callbacks. // As a workaround, we use custom path that redirects the user to the default login handler. - .route(SSO_CALLBACK_PATH, get(client_server::sso_login_route)) + .route(CALLBACK_PATH, get(client_server::handle_callback_route)) .ruma_route(client_server::get_capabilities_route) .ruma_route(client_server::get_pushrules_all_route) .ruma_route(client_server::set_pushrule_route) diff --git a/src/service/mod.rs b/src/service/mod.rs index fae5a726..6d8c34d2 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -123,7 +123,7 @@ impl Services { key_backups: key_backups::Service { db }, media: media::Service { db }, sending: sending::Service::build(db, &config), - sso: sso::Service::build(db), + sso: sso::Service::build(db)?, globals: globals::Service::load(db, config)?, }) diff --git a/src/service/sso/mod.rs b/src/service/sso/mod.rs index 1206ed34..31c9c3ab 100644 --- a/src/service/sso/mod.rs +++ b/src/service/sso/mod.rs @@ -1,50 +1,36 @@ -mod data; use std::{ borrow::Borrow, - collections::{HashMap, HashSet}, + collections::HashSet, hash::{Hash, Hasher}, - str::FromStr, - sync::{Arc, RwLock}, - time::{Duration, SystemTime, UNIX_EPOCH}, + sync::Arc, }; use crate::{ - api::client_server::TOKEN_LENGTH, - config::{sso::ProviderConfig as Config, IdpConfig}, + api::client_server::{LOGIN_TOKEN_EXPIRATION_SECS, TOKEN_LENGTH}, + config::IdpConfig, utils, Error, Result, }; -pub use data::Data; -use email_address::EmailAddress; use futures_util::future::{self}; use mas_oidc_client::{ http_service::{hyper, HttpService}, - jose::jwk::PublicJsonWebKeySet, - requests::{ - authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData}, - discovery, - jose::{self, JwtVerificationData}, - userinfo, - }, - types::{ - iana::jose::JsonWebSignatureAlg, oidc::VerifiedProviderMetadata, - requests::AccessTokenResponse, IdToken, - }, + requests::{authorization_code::AuthorizationValidationData, discovery}, + types::oidc::VerifiedProviderMetadata, }; -use rand::SeedableRng; -use ruma::{api::client::error::ErrorKind, MilliSecondsSinceUnixEpoch, OwnedUserId, UserId}; +use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId}; use serde::{Deserialize, Serialize}; -use serde_json::Value; -use tokio::sync::{oneshot, OnceCell}; +use tokio::sync::OnceCell; use tracing::error; use url::Url; use crate::services; +mod data; pub use data::Data; pub const SSO_AUTH_EXPIRATION_SECS: u64 = 60 * 60; pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 60 * 2; pub const SSO_SESSION_COOKIE: &str = "sso-auth"; +pub const SUBJECT_CLAIM_KEY: &str = "sub"; pub struct Service { db: &'static dyn Data, @@ -69,7 +55,7 @@ impl Service { let providers = services().globals.config.idps.iter(); self.providers - .get_or_try_init(|| { + .get_or_try_init(|| async move { future::try_join_all(providers.map(Provider::fetch_metadata)) .await .map(Vec::into_iter) @@ -86,6 +72,12 @@ impl Service { providers.get(provider) } + pub fn login_type(&self) -> impl Iterator + '_ { + let providers = self.providers.get().expect(""); + + providers.iter().map(|p| p.config.inner.clone()) + } + pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result> { self.db.user_from_subject(provider, subject) } @@ -111,30 +103,6 @@ impl Provider { Error::bad_config("Failed to fetch identity provider metadata.") }) } - - async fn fetch_signing_keys(&self) -> Result { - jose::fetch_jwks(&services().sso.service, self.metadata.jwks_uri()) - .await - .map_err(|e| { - error!("Failed to fetch signing keys for token endpoint: {}", e); - - Error::bad_config("Failed to fetch signing keys for token endpoint.") - }) - } - - pub async fn fetch_access_token( - &self, - auth_code: String, - validation_data: AuthorizationValidationData, - ) -> Result<(AccessTokenResponse, Option>)> { - } - - pub async fn fetch_userinfo( - &self, - access_token: &str, - id_token: &IdToken<'_>, - ) -> Result>> { - } } impl Borrow for Provider { @@ -157,105 +125,28 @@ impl Hash for Provider { } } -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct RegistrationToken { - pub info: RegistrationInfo, - pub provider_id: String, - pub unique_claim: String, - pub redirect_uri: Url, - pub expires_at: MilliSecondsSinceUnixEpoch, -} - -impl RegistrationToken { - pub fn new( - provider_id: String, - unique_claim: String, - redirect_uri: Url, - info: RegistrationInfo, - ) -> Self { - let expires_at = MilliSecondsSinceUnixEpoch::from_system_time( - UNIX_EPOCH - .checked_add(Duration::from_secs(REGISTRATION_EXPIRATION_SECS)) - .expect("SystemTime should not overflow"), - ) - .expect("MilliSecondsSinceUnixEpoch is not too large"); - - Self { - info, - provider_id, - unique_claim, - redirect_uri, - expires_at, - } - } -} - -#[derive(Clone, Debug, Default, Deserialize, Serialize)] -pub struct RegistrationInfo { - pub username: String, - pub displayname: Option, - pub avatar_url: Option, - pub email: Option, -} - -impl RegistrationInfo { - pub fn new( - claims: &HashMap, - username: &str, - displayname: &str, - avatar_url: &str, - email: &str, - ) -> Self { - Self { - username: claims - .get(username) - .and_then(|v| v.as_str()) - .map(ToOwned::to_owned) - .unwrap_or_default(), - displayname: claims - .get(displayname) - .and_then(|v| v.as_str()) - .map(ToOwned::to_owned), - avatar_url: claims - .get(avatar_url) - .and_then(|v| v.as_str()) - .map(Url::parse) - .and_then(Result::ok), - email: claims - .get(email) - .and_then(|v| v.as_str()) - .map(EmailAddress::from_str) - .and_then(Result::ok), - } - } -} - #[derive(Clone, Deserialize, Serialize)] pub struct LoginToken { - pub inner: String, - pub provider_id: String, - pub user_id: OwnedUserId, - - #[serde(rename = "exp")] - expires_at: u64, + pub iss: String, + pub aud: OwnedUserId, + pub sub: String, + pub exp: u64, } impl LoginToken { - pub fn new(provider_id: String, user_id: OwnedUserId) -> Self { - let expires_at = SystemTime::now() - .checked_add(Duration::from_secs(LOGIN_TOKEN_EXPIRATION_SECS)) - .expect("SystemTime should not overflow") - .duration_since(UNIX_EPOCH) - .expect("SystemTime went backwards") - .as_secs(); - + pub fn new(provider: String, user_id: OwnedUserId) -> Self { Self { - inner: utils::random_string(TOKEN_LENGTH), - provider_id, - user_id, - expires_at, + iss: provider, + aud: user_id, + sub: utils::random_string(TOKEN_LENGTH), + exp: utils::millis_since_unix_epoch() + .checked_add(LOGIN_TOKEN_EXPIRATION_SECS * 1000) + .expect("time overflow"), } } + pub fn audience(&self) -> &UserId { + &self.aud + } } #[derive(Clone, Debug, Deserialize, Serialize)]