diff --git a/Cargo.toml b/Cargo.toml index 0cdde4ab..b3d95485 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,13 +35,17 @@ axum = { version = "0.7", default-features = false, features = [ "json", "matched-path", ], optional = true } -axum-extra = { version = "0.9", features = ["typed-header"] } +axum-extra = { version = "0.9", features = ["cookie", "typed-header"] } axum-server = { version = "0.6", features = ["tls-rustls"] } tower = { version = "0.4.13", features = ["util"] } tower-http = { version = "0.5", features = [ "add-extension", "cors", + "follow-redirect", + "map-request-body", "sensitive-headers", + "set-header", + "timeout", "trace", "util", ] } @@ -172,6 +176,16 @@ optional = true package = "rust-rocksdb" version = "0.25" +[dependencies.mas-http] +features = ["client"] +git = "https://github.com/matrix-org/matrix-authentication-service" +rev = "fbc360d1a94ef2ebf63d979bb403228a700f43c8" + +[dependencies.mas-oidc-client] +features = [] +git = "https://github.com/matrix-org/matrix-authentication-service" +rev = "fbc360d1a94ef2ebf63d979bb403228a700f43c8" + [target.'cfg(unix)'.dependencies] nix = { version = "0.28", features = ["resource"] } diff --git a/docs/configuration.md b/docs/configuration.md index 9687ead1..0fee0d03 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -13,6 +13,7 @@ Conduit's configuration file is divided into the following sections: - [Global](#global) - [TLS](#tls) - [Proxy](#proxy) + - [SSO (Single Sign-On)](#sso) ## Global @@ -111,3 +112,20 @@ exclude = ["*.clearnet.onion"] [global] {{#include ../conduit-example.toml:22:}} ``` + +### SSO (Single Sign-On) + +Authentication through SSO instead of a password can be enabled by configuring OIDC (OpenID Connect) identity providers. +Identity providers using OAuth such as Github are not supported yet. + +> **Note:** The `*` symbol indicates that the field is required, and the values in **parentheses** are the possible values + +| Field | Type | Description | Default | +| --- | --- | --- | --- | +| `issuer`* | `Url` | The issuer URL. | N/A | +| `name` | `string` | The name displayed on fallback pages. | `issuer` | +| `icon` | `Url` OR `MxcUri` | The icon displayed on fallback pages. | N/A | +| `scopes` | `array` | The scopes used to obtain extra claims which can be used for templates. | `["openid"]` | +| `client_id`* | `string` | The provider-supplied, unique ID for the client. | N/A | +| `client_secret`* | `string` | The provider-supplied, unique ID for the client. | N/A | +| `authentication_method`* | `"basic" OR "post"` | The method used for client authentication. | N/A | diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index 4af8890d..05110248 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -100,6 +100,12 @@ pub async fn upload_signing_keys_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let master_key = services() + .users + .get_master_key(Some(sender_user), sender_user, &|other| { + sender_user == other + })?; + // UIAA let mut uiaainfo = UiaaInfo { flows: vec![AuthFlow { @@ -111,7 +117,15 @@ pub async fn upload_signing_keys_route( auth_error: None, }; - if let Some(auth) = &body.auth { + if let (Some(master_key), None) = (&body.master_key, master_key) { + services().users.add_cross_signing_keys( + sender_user, + master_key, + &body.self_signing_key, + &body.user_signing_key, + true, + )?; + } else if let Some(auth) = &body.auth { let (worked, uiaainfo) = services() .uiaa @@ -130,16 +144,6 @@ pub async fn upload_signing_keys_route( return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - if let Some(master_key) = &body.master_key { - services().users.add_cross_signing_keys( - sender_user, - master_key, - &body.self_signing_key, - &body.user_signing_key, - true, // notify so that other users see the new keys - )?; - } - Ok(upload_signing_keys::v3::Response {}) } diff --git a/src/api/client_server/mod.rs b/src/api/client_server/mod.rs index a35d7a98..07ee7a17 100644 --- a/src/api/client_server/mod.rs +++ b/src/api/client_server/mod.rs @@ -23,6 +23,7 @@ mod room; mod search; mod session; mod space; +mod sso; mod state; mod sync; mod tag; @@ -60,6 +61,7 @@ pub use room::*; pub use search::*; pub use session::*; pub use space::*; +pub use sso::*; pub use state::*; pub use sync::*; pub use tag::*; @@ -76,3 +78,5 @@ pub const DEVICE_ID_LENGTH: usize = 10; pub const TOKEN_LENGTH: usize = 32; pub const SESSION_ID_LENGTH: usize = 32; pub const AUTO_GEN_PASSWORD_LENGTH: usize = 15; +pub const AUTH_SESSION_EXPIRATION_SECS: u64 = 60 * 5; +pub const LOGIN_TOKEN_EXPIRATION_SECS: u64 = 15; diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index 07078328..0c1189ae 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -1,5 +1,6 @@ use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{services, utils, Error, Result, Ruma}; +use crate::{service::sso::LoginToken, services, utils, Error, Result, Ruma}; +use jsonwebtoken::{Algorithm, Validation}; use ruma::{ api::client::{ error::ErrorKind, @@ -24,10 +25,19 @@ struct Claims { pub async fn get_login_types_route( _body: Ruma, ) -> Result { - Ok(get_login_types::v3::Response::new(vec![ + let identity_providers: Vec<_> = services().sso.login_type().collect(); + let mut flows = vec![ get_login_types::v3::LoginType::Password(Default::default()), get_login_types::v3::LoginType::ApplicationService(Default::default()), - ])) + ]; + + if !identity_providers.is_empty() { + flows.push(get_login_types::v3::LoginType::Sso( + get_login_types::v3::SsoLoginType { identity_providers }, + )); + } + + Ok(get_login_types::v3::Response::new(flows)) } /// # `POST /_matrix/client/r0/login` @@ -101,35 +111,64 @@ pub async fn login_route(body: Ruma) -> Result { - if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { - let token = jsonwebtoken::decode::( - token, - jwt_decoding_key, - &jsonwebtoken::Validation::default(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?; - let username = token.claims.sub.to_lowercase(); - let user_id = - UserId::parse_with_server_name(username, services().globals.server_name()) - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") - })?; + match ( + services().globals.jwt_decoding_key(), + services().globals.config.idps.is_empty(), + ) { + (_, false) => { + let mut v = Validation::new(Algorithm::HS256); - if services().appservice.is_exclusive_user_id(&user_id).await { + v.set_required_spec_claims(&["sub", "exp", "aud", "iss"]); + v.validate_aud = false; + v.validate_nbf = false; + + services() + .globals + .validate_claims::(token, Some(&v)) + .map(LoginToken::audience) + .map_err(|e| { + tracing::warn!("Invalid token: {}", e); + + Error::BadRequest(ErrorKind::InvalidParam, "Invalid token.") + })? + } + (Some(jwt_decoding_key), _) => { + let token = jsonwebtoken::decode::( + token, + jwt_decoding_key, + &Validation::default(), + ) + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.") + })?; + let username = token.claims.sub.to_lowercase(); + let user_id = + UserId::parse_with_server_name(username, services().globals.server_name()) + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidUsername, + "Username is invalid.", + ) + })?; + + if services().appservice.is_exclusive_user_id(&user_id).await { + return Err(Error::BadRequest( + ErrorKind::Exclusive, + "User id reserved by appservice.", + )); + } + + user_id + } + (None, _) => { return Err(Error::BadRequest( - ErrorKind::Exclusive, - "User id reserved by appservice.", + ErrorKind::Unknown, + "Token login is not supported (server has no jwt decoding key).", )); } - - user_id - } else { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Token login is not supported (server has no jwt decoding key).", - )); } } + login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService { identifier, user, diff --git a/src/api/client_server/sso.rs b/src/api/client_server/sso.rs new file mode 100644 index 00000000..a35439c9 --- /dev/null +++ b/src/api/client_server/sso.rs @@ -0,0 +1,473 @@ +use std::{borrow::Borrow, collections::HashMap, iter::Iterator, time::SystemTime}; + +use crate::{ + config::IdpConfig, + service::sso::{ + LoginToken, ValidationData, SSO_AUTH_EXPIRATION_SECS, SSO_SESSION_COOKIE, SUBJECT_CLAIM_KEY, + }, + services, utils, Error, Result, Ruma, +}; +use futures_util::TryFutureExt; +use mas_oidc_client::{ + requests::{ + authorization_code::{self, AuthorizationRequestData}, + jose::{self, JwtVerificationData}, + userinfo, + }, + types::{ + client_credentials::ClientCredentials, + iana::jose::JsonWebSignatureAlg, + requests::{AccessTokenResponse, AuthorizationResponse}, + }, +}; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use ruma::{ + api::client::{ + error::ErrorKind, + media::create_content, + session::{sso_login, sso_login_with_provider}, + }, + events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType}, + push, UserId, +}; +use serde_json::Value; +use tracing::{error, info, warn}; +use url::Url; + +pub const CALLBACK_PATH: &str = "/_matrix/client/unstable/conduit/callback"; + +/// # `GET /_matrix/client/v3/login/sso/redirect` +/// +/// Redirect the user to the SSO interfa. +/// TODO: this should be removed once Ruma supports trailing slashes. +pub async fn get_sso_redirect_route( + Ruma { + body, + sender_user, + sender_device, + sender_servername, + json_body, + .. + }: Ruma, +) -> Result { + let sso_login_with_provider::v3::Response { location, cookie } = + get_sso_redirect_with_provider_route( + Ruma { + body: sso_login_with_provider::v3::Request::new( + Default::default(), + body.redirect_url, + ), + sender_user, + sender_device, + sender_servername, + json_body, + appservice_info: None, + } + .into(), + ) + .await?; + + Ok(sso_login::v3::Response { location, cookie }) +} + +/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}` +/// +/// Redirects the user to the SSO interface. +pub async fn get_sso_redirect_with_provider_route( + body: Ruma, +) -> Result { + let idp_ids: Vec<&str> = services() + .globals + .config + .idps + .iter() + .map(Borrow::borrow) + .collect(); + + let provider = match &*idp_ids { + [] => { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Single Sign-On is disabled.", + )); + } + [idp_id] => services().sso.get(idp_id).expect("we know it exists"), + [_, ..] => services().sso.get(&body.idp_id).ok_or_else(|| { + Error::BadRequest(ErrorKind::InvalidParam, "Unknown identity provider.") + })?, + }; + + let redirect_url = body + .redirect_url + .parse::() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid redirect_url."))?; + + let mut callback = services() + .globals + .well_known_client() + .parse::() + .map_err(|_| Error::bad_config("Invalid well_known_client url."))?; + callback.set_path(CALLBACK_PATH); + + let (auth_url, validation_data) = authorization_code::build_authorization_url( + provider.metadata.authorization_endpoint().clone(), + AuthorizationRequestData::new( + provider.config.client_id.clone(), + provider.config.scopes.clone(), + callback, + ), + &mut StdRng::from_entropy(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to build authorization_url."))?; + + let signed = services().globals.sign_claims(&ValidationData::new( + Borrow::::borrow(provider).to_owned(), + redirect_url.to_string(), + validation_data, + )); + + Ok(sso_login_with_provider::v3::Response { + location: auth_url.to_string(), + cookie: Some( + utils::build_cookie( + SSO_SESSION_COOKIE, + &signed, + CALLBACK_PATH, + Some(SSO_AUTH_EXPIRATION_SECS), + ) + .to_string(), + ), + }) +} + +/// # `GET /_conduit/client/sso/callback` +/// +/// Validate the authorization response received from the identity provider. +/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect. +/// If this is the first login, register the user, possibly interactively through a fallback page. +pub async fn handle_callback_route( + body: Ruma, +) -> Result { + let sso_callback::Request { + response: + AuthorizationResponse { + code, + access_token: _, + token_type: _, + id_token: _, + expires_in: _, + }, + cookie, + } = body.body; + + let ValidationData { + provider, + redirect_url, + inner: validation_data, + } = services() + .globals + .validate_claims(&cookie, None) + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid value for session cookie.") + })?; + + let provider = services().sso.get(&provider).ok_or_else(|| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Unknown provider for session cookie.", + ) + })?; + + let IdpConfig { + client_id, + client_secret, + auth_method, + .. + } = provider.config.clone(); + + let credentials = match &*auth_method { + "basic" => ClientCredentials::ClientSecretBasic { + client_id, + client_secret, + }, + "post" => ClientCredentials::ClientSecretPost { + client_id, + client_secret, + }, + _ => todo!(), + }; + let ref jwks = jose::fetch_jwks(services().sso.service(), provider.metadata.jwks_uri()) + .await + .map_err(|_| Error::bad_config("Failed to fetch signing keys for token endpoint."))?; + let idt_verification_data = Some(JwtVerificationData { + jwks, + issuer: &provider.config.issuer, + client_id: &provider.config.client_id, + signing_algorithm: &JsonWebSignatureAlg::Rs256, + }); + + let ( + AccessTokenResponse { + access_token, + refresh_token: _, + token_type: _, + expires_in: _, + scope: _, + .. + }, + Some(id_token), + ) = authorization_code::access_token_with_authorization_code( + services().sso.service(), + credentials, + provider.metadata.token_endpoint(), + code.unwrap_or_default(), + validation_data, + idt_verification_data, + SystemTime::now().into(), + &mut StdRng::from_entropy(), + ) + .await + .map_err(|_| Error::bad_config("Failed to fetch access token."))? + else { + unreachable!("ID token should never be empty") + }; + + let mut userinfo = HashMap::default(); + if let Some(endpoint) = provider.metadata.userinfo_endpoint.as_ref() { + userinfo = userinfo::fetch_userinfo( + services().sso.service(), + endpoint, + &access_token, + None, + &id_token, + ) + .await + .map_err(|e| { + tracing::error!("Failed to fetch claims for userinfo endpoint: {:?}", e); + + Error::bad_config("Failed to fetch claims for userinfo endpoint.") + })?; + } + + let (_, id_token) = id_token.into_parts(); + + info!("userinfo: {:?}", &userinfo); + info!("id_token: {:?}", &id_token); + + let subject = match id_token.get(SUBJECT_CLAIM_KEY) { + Some(Value::String(s)) => s.to_owned(), + Some(Value::Number(n)) => n.to_string(), + value => { + return Err(Error::BadRequest( + ErrorKind::Unknown, + value + .map(|_| { + error!("Subject claim is missing from ID token: {id_token:?}"); + + "Subject claim is missing from ID token." + }) + .unwrap_or("Subject claim should be a string or number."), + )); + } + }; + + let user_id = match services() + .sso + .user_from_subject(Borrow::::borrow(provider), &subject)? + { + Some(user_id) => user_id, + None => { + let mut localpart = subject.clone(); + + let user_id = loop { + match UserId::parse_with_server_name(&*localpart, services().globals.server_name()) + .map(|user_id| { + ( + user_id.clone(), + services().users.exists(&user_id).unwrap_or(true), + ) + }) { + Ok((user_id, false)) => break user_id, + _ => { + let n: u8 = rand::thread_rng().gen(); + + localpart = format!("{}{}", localpart, n % 10); + } + } + }; + + services().users.set_placeholder_password(&user_id)?; + let displayname = id_token + .get("preferred_username") + .or(id_token.get("nickname")); + let mut displayname = displayname + .as_deref() + .map(Value::as_str) + .flatten() + .unwrap_or(user_id.localpart()) + .to_owned(); + + // If enabled append lightning bolt to display name (default true) + if services().globals.enable_lightning_bolt() { + displayname.push_str(" ⚡️"); + } + + services() + .users + .set_displayname(&user_id, Some(displayname.clone()))?; + + if let Some(Value::String(url)) = userinfo.get("picture").or(id_token.get("picture")) { + let req = services() + .globals + .default_client() + .get(url) + .send() + .and_then(reqwest::Response::bytes); + + if let Ok(file) = req.await { + let _ = crate::api::client_server::create_content_route(Ruma { + body: create_content::v3::Request::new(file.to_vec()), + sender_user: None, + sender_device: None, + sender_servername: None, + json_body: None, + appservice_info: None, + }) + .await + .and_then(|res| { + tracing::info!("successfully imported avatar for {}", &user_id); + + services() + .users + .set_avatar_url(&user_id, Some(res.content_uri)) + }); + } + } + + // Initial account data + services().account_data.update( + None, + &user_id, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json always works"), + )?; + + info!("New user {} registered on this server.", user_id); + services() + .admin + .send_message(RoomMessageEventContent::notice_plain(format!( + "New user {user_id} registered on this server." + ))); + + if let Some(admin_room) = services().admin.get_admin_room()? { + if services() + .rooms + .state_cache + .room_joined_count(&admin_room)? + == Some(1) + { + services() + .admin + .make_user_admin(&user_id, displayname.to_owned()) + .await?; + + warn!("Granting {} admin privileges as the first user", user_id); + } + } + + user_id + } + }; + + let signed = services().globals.sign_claims(&LoginToken::new( + Borrow::::borrow(provider).to_owned(), + user_id, + )); + + let mut redirect_url: Url = redirect_url.parse().expect(""); + redirect_url + .query_pairs_mut() + .append_pair("loginToken", &signed); + + Ok(sso_login_with_provider::v3::Response { + location: redirect_url.to_string(), + cookie: Some(utils::build_cookie(SSO_SESSION_COOKIE, "", CALLBACK_PATH, None).to_string()), + }) +} + +mod sso_callback { + use axum_extra::headers::{self, HeaderMapExt}; + use http::Method; + use mas_oidc_client::types::requests::AuthorizationResponse; + use ruma::{ + api::{ + client::{session::sso_login_with_provider, Error}, + error::{FromHttpRequestError, HeaderDeserializationError}, + IncomingRequest, Metadata, + }, + metadata, + }; + + use crate::service::sso::SSO_SESSION_COOKIE; + + pub const METADATA: Metadata = metadata! { + method: GET, + rate_limited: false, + authentication: None, + history: { + 1.0 => "/_matrix/client/unstable/conduit/callback", + } + }; + + pub struct Request { + pub response: AuthorizationResponse, + pub cookie: String, + } + + impl IncomingRequest for Request { + type EndpointError = Error; + type OutgoingResponse = sso_login_with_provider::v3::Response; + + const METADATA: Metadata = METADATA; + + fn try_from_http_request( + req: http::Request, + _path_args: &[S], + ) -> Result + where + B: AsRef<[u8]>, + S: AsRef, + { + if !(req.method() == METADATA.method + || req.method() == Method::HEAD && METADATA.method == Method::GET) + { + return Err(FromHttpRequestError::MethodMismatch { + expected: METADATA.method, + received: req.method().clone(), + }); + } + + let response: AuthorizationResponse = + serde_html_form::from_str(req.uri().query().unwrap_or(""))?; + + let Some(cookie) = req + .headers() + .typed_get() + .and_then(|cookie: headers::Cookie| { + cookie.get(SSO_SESSION_COOKIE).map(str::to_owned) + }) + else { + return Err(HeaderDeserializationError::MissingHeader( + "Cookie".to_owned(), + ))?; + }; + + Ok(Self { response, cookie }) + } + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 378ab929..93198cf3 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,16 +1,27 @@ use std::{ - collections::BTreeMap, + borrow::Borrow, + collections::{BTreeMap, HashSet}, fmt, + hash::{Hash, Hasher}, net::{IpAddr, Ipv4Addr}, }; -use ruma::{OwnedServerName, RoomVersionId}; -use serde::{de::IgnoredAny, Deserialize}; +use figment::value::{Dict, Value}; +use mas_oidc_client::types::{client_credentials::ClientCredentials, scope::Scope}; +use ruma::{ + api::client::session::get_login_types::v3::IdentityProvider, OwnedServerName, RoomVersionId, +}; +use serde::{ + de::{self, IgnoredAny}, + Deserialize, Deserializer, Serialize, +}; use tracing::warn; use url::Url; mod proxy; +use crate::{Error, Result}; + use self::proxy::ProxyConfig; #[derive(Clone, Debug, Deserialize)] @@ -67,6 +78,8 @@ pub struct Config { pub tracing_flame: bool, #[serde(default)] pub proxy: ProxyConfig, + #[serde(default, deserialize_with = "deserialize_providers")] + pub idps: HashSet, pub jwt_secret: Option, #[serde(default = "default_trusted_servers")] pub trusted_servers: Vec, @@ -101,6 +114,27 @@ pub struct WellKnownConfig { pub server: Option, } +#[derive(Clone, Debug, Deserialize)] +pub struct IdpConfig { + pub issuer: String, + #[serde(flatten)] + pub inner: IdentityProvider, + #[serde(deserialize_with = "deserialize_scopes")] + pub scopes: Scope, + + pub client_id: String, + pub client_secret: String, + pub auth_method: String, +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct Template { + pub localpart: Option, + pub displayname: Option, + pub avatar_url: Option, + pub email: Option, +} + const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; impl Config { @@ -244,6 +278,49 @@ impl fmt::Display for Config { } } +impl Borrow for IdpConfig { + fn borrow(&self) -> &str { + &self.inner.id + } +} + +impl PartialEq for IdpConfig { + fn eq(&self, other: &Self) -> bool { + self.inner.id == other.inner.id + } +} + +impl Eq for IdpConfig {} + +impl Hash for IdpConfig { + fn hash(&self, hasher: &mut H) { + self.inner.id.hash(hasher) + } +} + +impl Into for IdpConfig { + fn into(self) -> ClientCredentials { + let IdpConfig { + client_id, + client_secret, + auth_method, + .. + } = self; + + match &*auth_method { + "basic" => ClientCredentials::ClientSecretBasic { + client_id, + client_secret, + }, + "post" => ClientCredentials::ClientSecretPost { + client_id, + client_secret, + }, + _ => unimplemented!(), + } + } +} + fn false_fn() -> bool { false } @@ -312,3 +389,46 @@ fn default_openid_token_ttl() -> u64 { pub fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 } + +fn deserialize_scopes<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let scopes = >::deserialize(deserializer)?; + + scopes.join(" ").parse().map_err(de::Error::custom) +} + +fn deserialize_providers<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let mut result = HashSet::new(); + let dict = Dict::deserialize(deserializer) + .map(Dict::into_iter) + .map_err(de::Error::custom)?; + warn!(?dict); + + for (name, value) in dict { + let tag = value.tag(); + + let Some(dict) = value.into_dict() else { + return Err(de::Error::custom(Error::bad_config( + "Invalid SSO configuration. ", + ))); + }; + + let id = String::from("id"); + let name = name.parse().map_err(de::Error::custom)?; + + let dict = Some((id, name)).into_iter().chain(dict).collect(); + + result.insert( + Value::Dict(tag, dict) + .deserialize() + .map_err(de::Error::custom)?, + ); + } + + Ok(result) +} diff --git a/src/database/key_value/mod.rs b/src/database/key_value/mod.rs index c4496af8..5027c367 100644 --- a/src/database/key_value/mod.rs +++ b/src/database/key_value/mod.rs @@ -8,6 +8,7 @@ mod media; mod pusher; mod rooms; mod sending; +mod sso; mod transaction_ids; mod uiaa; mod users; diff --git a/src/database/key_value/sso.rs b/src/database/key_value/sso.rs new file mode 100644 index 00000000..9aa85ada --- /dev/null +++ b/src/database/key_value/sso.rs @@ -0,0 +1,29 @@ +use ruma::{OwnedUserId, UserId}; + +use crate::{service, utils, Error, KeyValueDatabase, Result}; + +impl service::sso::Data for KeyValueDatabase { + fn save_subject(&self, provider: &str, user_id: &UserId, subject: &str) -> Result<()> { + let mut key = provider.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(subject.as_bytes()); + + self.providersubjectid_userid.insert(&key, user_id.as_bytes()) + } + + fn user_from_subject(&self, provider: &str, subject: &str) -> Result> { + let mut key = provider.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(subject.as_bytes()); + + self.providersubjectid_userid.get(&key)?.map_or(Ok(None), |bytes| { + Some( + UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("User ID in claim_userid is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("User ID in claim_userid is invalid.")), + ) + .transpose() + }) + } +} diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 63321a40..fca0328c 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -119,6 +119,10 @@ impl service::users::Data for KeyValueDatabase { } } + fn set_placeholder_password(&self, user_id: &UserId) -> Result<()> { + self.userid_password.insert(user_id.as_bytes(), b"0xff") + } + /// Returns the displayname of a user on this homeserver. fn displayname(&self, user_id: &UserId) -> Result> { self.userid_displayname diff --git a/src/database/mod.rs b/src/database/mod.rs index 2317f7a8..bb5cb2ca 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -50,7 +50,6 @@ pub struct KeyValueDatabase { pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 pub(super) token_userdeviceid: Arc, - pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count pub(super) keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count @@ -64,6 +63,9 @@ pub struct KeyValueDatabase { pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count + pub(super) userid_providersubjectid: Arc, + pub(super) providersubjectid_userid: Arc, + //pub uiaa: uiaa::Uiaa, pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication pub(super) userdevicesessionid_uiaarequest: @@ -298,6 +300,9 @@ impl KeyValueDatabase { userfilterid_filter: builder.open_tree("userfilterid_filter")?, todeviceid_events: builder.open_tree("todeviceid_events")?, + userid_providersubjectid: builder.open_tree("userid_providersubjectid")?, + providersubjectid_userid: builder.open_tree("providersubjectid_userid")?, + userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, @@ -1050,6 +1055,8 @@ impl KeyValueDatabase { services().admin.start_handler(); + services().sso.start_handler().await?; + // Set emergency access for the conduit user match set_emergency_access() { Ok(pwd_set) => { diff --git a/src/main.rs b/src/main.rs index 2776c200..cfc3756d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -292,6 +292,11 @@ fn routes(config: &Config) -> Router { .ruma_route(client_server::third_party_route) .ruma_route(client_server::request_3pid_management_token_via_email_route) .ruma_route(client_server::request_3pid_management_token_via_msisdn_route) + .ruma_route(client_server::get_sso_redirect_route) + .ruma_route(client_server::get_sso_redirect_with_provider_route) + // The specification will likely never introduce any endpoint for handling authorization callbacks. + // As a workaround, we use custom path that redirects the user to the default login handler. + .ruma_route(client_server::handle_callback_route) .ruma_route(client_server::get_capabilities_route) .ruma_route(client_server::get_pushrules_all_route) .ruma_route(client_server::set_pushrule_route) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 3325e518..2a4b76ff 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,9 +1,10 @@ mod data; pub use data::{Data, SigningKeys}; use ruma::{ - serde::Base64, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId, OwnedRoomAliasId, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, + serde::Base64, signatures::KeyPair, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId, + OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, }; +use serde::{de::DeserializeOwned, Serialize}; use crate::api::server_server::DestinationResponse; @@ -17,7 +18,7 @@ use ruma::{ DeviceId, RoomVersionId, ServerName, UserId, }; use std::{ - collections::{BTreeMap, HashMap}, + collections::{BTreeMap, HashMap, HashSet}, error::Error as StdError, fs, future::{self, Future}, @@ -37,6 +38,9 @@ use tracing::{error, info}; use base64::{engine::general_purpose, Engine as _}; +// https://github.com/rust-lang/rust/issues/104699 +const PROBLEMATIC_CONST: &[u8] = b"0xCAFEBABE"; + type WellKnownMap = HashMap; type TlsNameMap = HashMap, u16)>; type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries @@ -505,6 +509,36 @@ impl Service { self.config.well_known_client() } + pub fn sign_claims(&self, claims: &S) -> String { + let key = jsonwebtoken::EncodingKey::from_secret( + self.keypair().sign(PROBLEMATIC_CONST).as_bytes(), + ); + + jsonwebtoken::encode(&jsonwebtoken::Header::default(), claims, &key) + .expect("signing JWTs always works") + } + + /// Decode and validate a macaroon with this server's macaroon key. + pub fn validate_claims( + &self, + token: &str, + validation_data: Option<&jsonwebtoken::Validation>, + ) -> jsonwebtoken::errors::Result { + let key = jsonwebtoken::DecodingKey::from_secret( + self.keypair().sign(PROBLEMATIC_CONST).as_bytes(), + ); + + let mut v = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256); + + // these validations are redundant as all JWTs are stored in cookies + v.validate_exp = false; + v.validate_nbf = false; + v.required_spec_claims = HashSet::new(); + + jsonwebtoken::decode::(token, &key, validation_data.unwrap_or(&v)) + .map(|data| data.claims) + } + pub fn shutdown(&self) { self.shutdown.store(true, atomic::Ordering::Relaxed); // On shutdown diff --git a/src/service/mod.rs b/src/service/mod.rs index 552c71af..c7d75fcd 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -19,6 +19,7 @@ pub mod pdu; pub mod pusher; pub mod rooms; pub mod sending; +pub mod sso; pub mod transaction_ids; pub mod uiaa; pub mod users; @@ -35,6 +36,7 @@ pub struct Services { pub globals: globals::Service, pub key_backups: key_backups::Service, pub media: media::Service, + pub sso: Arc, pub sending: Arc, } @@ -51,6 +53,7 @@ impl Services { + key_backups::Data + media::Data + sending::Data + + sso::Data + 'static, >( db: &'static D, @@ -120,6 +123,7 @@ impl Services { key_backups: key_backups::Service { db }, media: media::Service { db }, sending: sending::Service::build(db, &config), + sso: sso::Service::build(db)?, globals: globals::Service::load(db, config)?, }) diff --git a/src/service/sso/data.rs b/src/service/sso/data.rs new file mode 100644 index 00000000..75d45bf2 --- /dev/null +++ b/src/service/sso/data.rs @@ -0,0 +1,9 @@ +use ruma::{OwnedUserId, UserId}; + +use crate::Result; + +pub trait Data: Send + Sync { + fn save_subject(&self, provider: &str, user_id: &UserId, subject: &str) -> Result<()>; + + fn user_from_subject(&self, provider: &str, subject: &str) -> Result>; +} diff --git a/src/service/sso/mod.rs b/src/service/sso/mod.rs new file mode 100644 index 00000000..ac14edbf --- /dev/null +++ b/src/service/sso/mod.rs @@ -0,0 +1,213 @@ +use std::{ + borrow::Borrow, + collections::HashSet, + hash::{Hash, Hasher}, + sync::Arc, +}; + +use crate::{ + api::client_server::{LOGIN_TOKEN_EXPIRATION_SECS, TOKEN_LENGTH}, + config::IdpConfig, + utils, Error, Result, +}; +use futures_util::future::{self}; +use http::HeaderValue; +use mas_oidc_client::{ + http_service::HttpService, + requests::{authorization_code::AuthorizationValidationData, discovery}, + types::oidc::VerifiedProviderMetadata, +}; +use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId}; +use serde::{Deserialize, Serialize}; +use tokio::sync::OnceCell; +use tower::BoxError; +use tower_http::{set_header::SetRequestHeaderLayer, ServiceBuilderExt}; +use tracing::error; +use url::Url; + +use crate::services; + +mod data; +pub use data::Data; + +pub const SSO_AUTH_EXPIRATION_SECS: u64 = 60 * 60; +pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 60 * 2; +pub const SSO_SESSION_COOKIE: &str = "sso-auth"; +pub const SUBJECT_CLAIM_KEY: &str = "sub"; + +pub struct Service { + db: &'static dyn Data, + service: HttpService, + providers: OnceCell>, +} + +impl Service { + pub fn build(db: &'static dyn Data) -> Result> { + let client = tower::ServiceBuilder::new() + .map_err(BoxError::from) + .layer(tower_http::timeout::TimeoutLayer::new( + std::time::Duration::from_secs(10), + )) + .layer(mas_http::BytesToBodyRequestLayer) + .layer(mas_http::BodyToBytesResponseLayer) + .layer(SetRequestHeaderLayer::overriding( + http::header::USER_AGENT, + HeaderValue::from_static("conduit/0.9-alpha"), + )) + .concurrency_limit(10) + .follow_redirects() + .service(mas_http::make_untraced_client()); + + Ok(Arc::new(Self { + db, + service: HttpService::new(client), + providers: OnceCell::new(), + })) + } + + pub fn service(&self) -> &HttpService { + &self.service + } + + pub async fn start_handler(&self) -> Result<()> { + let providers = services().globals.config.idps.iter(); + + self.providers + .get_or_try_init(|| async move { + future::try_join_all(providers.map(Provider::fetch_metadata)) + .await + .map(Vec::into_iter) + .map(HashSet::from_iter) + }) + .await?; + + Ok(()) + } + + pub fn get(&self, provider: &str) -> Option<&Provider> { + let providers = self.providers.get().expect(""); + + providers.get(provider) + } + + pub fn login_type(&self) -> impl Iterator + '_ { + let providers = self.providers.get().expect(""); + + providers.iter().map(|p| p.config.inner.clone()) + } + + pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result> { + self.db.user_from_subject(provider, subject) + } +} + +#[derive(Clone, Debug)] +pub struct Provider { + pub config: &'static IdpConfig, + pub metadata: VerifiedProviderMetadata, +} + +impl Provider { + pub async fn fetch_metadata(config: &'static IdpConfig) -> Result { + discovery::discover(services().sso.service(), &config.issuer) + .await + .map(|metadata| Provider { config, metadata }) + .map_err(|e| { + error!( + "Failed to fetch identity provider metadata ({}): {}", + &config.inner.id, e + ); + + Error::bad_config("Failed to fetch identity provider metadata.") + }) + } +} + +impl Borrow for Provider { + fn borrow(&self) -> &str { + self.config.borrow() + } +} + +impl PartialEq for Provider { + fn eq(&self, other: &Self) -> bool { + self.config == other.config + } +} + +impl Eq for Provider {} + +impl Hash for Provider { + fn hash(&self, hasher: &mut H) { + self.config.hash(hasher) + } +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct LoginToken { + pub iss: String, + pub aud: OwnedUserId, + pub sub: String, + pub exp: u64, +} + +impl LoginToken { + pub fn new(provider: String, user_id: OwnedUserId) -> Self { + Self { + iss: provider, + aud: user_id, + sub: utils::random_string(TOKEN_LENGTH), + exp: utils::millis_since_unix_epoch() + .checked_add(LOGIN_TOKEN_EXPIRATION_SECS * 1000) + .expect("time overflow"), + } + } + pub fn audience(self) -> OwnedUserId { + self.aud + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ValidationData { + pub provider: String, + pub redirect_url: String, + #[serde(flatten, with = "AuthorizationValidationDataDef")] + pub inner: AuthorizationValidationData, +} + +impl ValidationData { + pub fn new(provider: String, redirect_url: String, inner: AuthorizationValidationData) -> Self { + Self { + provider, + redirect_url, + inner, + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(remote = "AuthorizationValidationData")] +pub struct AuthorizationValidationDataDef { + pub state: String, + pub nonce: String, + pub redirect_uri: Url, + pub code_challenge_verifier: Option, +} + +impl From for AuthorizationValidationDataDef { + fn from( + AuthorizationValidationData { + state, + nonce, + redirect_uri, + code_challenge_verifier, + }: AuthorizationValidationData, + ) -> Self { + Self { + state, + nonce, + redirect_uri, + code_challenge_verifier, + } + } +} diff --git a/src/service/sso/templates.rs b/src/service/sso/templates.rs new file mode 100644 index 00000000..01512cc4 --- /dev/null +++ b/src/service/sso/templates.rs @@ -0,0 +1,34 @@ +pub fn base(title: &str, body: maud::Markup) -> maud::Markup { + maud::html! { + (maud::DOCTYPE) + html lang="en" { + head { + meta charset="utf-8"; + meta name="viewport" content="width=device-width, initial-scale=1.0"; + link rel="icon" type="image/png" sizes="32x32" href="https://conduit.rs/conduit.svg"; + style { (FONT_FACE) } + title { (title) } + } + body { (body) } + } + } +} + +pub fn footer() -> maud::Markup { + let info = "An open network for secure, decentralized communication."; + + maud::html! { + footer { p { (info) } } + } +} + +const FONT_FACE: &str = r#" + @font-face { + font-family: 'Source Sans 3 Variable'; + font-style: normal; + font-display: swap; + font-weight: 200 900; + src: url(https://cdn.jsdelivr.net/fontsource/fonts/source-sans-3:vf@latest/latin-wght-normal.woff2) format('woff2-variations'); + unicode-range: U+0000-00FF,U+0131,U+0152-0153,U+02BB-02BC,U+02C6,U+02DA,U+02DC,U+0304,U+0308,U+0329,U+2000-206F,U+2074,U+20AC,U+2122,U+2191,U+2193,U+2212,U+2215,U+FEFF,U+FFFD; + } +"#; diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 4566c36d..eff94b6f 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -217,4 +217,6 @@ pub trait Data: Send + Sync { /// Find out which user an OpenID access token belongs to. fn find_from_openid_token(&self, token: &str) -> Result>; + + fn set_placeholder_password(&self, user_id: &UserId) -> Result<()>; } diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index a5694a10..15756ff4 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -602,6 +602,10 @@ impl Service { pub fn find_from_openid_token(&self, token: &str) -> Result> { self.db.find_from_openid_token(token) } + + pub fn set_placeholder_password(&self, user_id: &UserId) -> Result<()> { + self.db.set_placeholder_password(user_id) + } } /// Ensure that a user only sees signatures from themselves and the target user diff --git a/src/utils/error.rs b/src/utils/error.rs index 1d811106..30568001 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -175,6 +175,22 @@ impl Error { } } +impl From for Error { + fn from(e: mas_oidc_client::types::errors::ClientError) -> Self { + error!( + "Failed to complete authorization callback: {} {}", + e.error, + e.error_description.as_deref().unwrap_or_default() + ); + + // TODO: error conversion + Self::BadRequest( + ErrorKind::Unknown, + "Failed to complete authorization callback.", + ) + } +} + #[cfg(feature = "persy")] impl> From> for Error { fn from(err: persy::PE) -> Self { diff --git a/src/utils/mod.rs b/src/utils/mod.rs index d09a1033..4ff6fd6f 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,6 +1,7 @@ pub mod error; use argon2::{Config, Variant}; +use axum_extra::extract::cookie::{Cookie, SameSite}; use cmp::Ordering; use rand::prelude::*; use ring::digest; @@ -8,7 +9,7 @@ use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonO use std::{ cmp, fmt, str::FromStr, - time::{SystemTime, UNIX_EPOCH}, + time::{Duration, SystemTime, UNIX_EPOCH}, }; pub fn millis_since_unix_epoch() -> u64 { @@ -142,6 +143,29 @@ pub fn deserialize_from_str< deserializer.deserialize_str(Visitor(std::marker::PhantomData)) } +pub fn build_cookie<'c>( + name: &'c str, + value: &'c str, + path: &'c str, + max_age: Option, +) -> Cookie<'c> { + let mut cookie = Cookie::new(name, value); + + cookie.set_path(path); + cookie.set_secure(true); + cookie.set_http_only(true); + cookie.set_same_site(SameSite::None); + cookie.set_max_age( + max_age + .map(Duration::from_secs) + .map(TryInto::try_into) + .transpose() + .expect("time overflow"), + ); + + cookie +} + // Copied from librustdoc: // https://github.com/rust-lang/rust/blob/cbaeec14f90b59a91a6b0f17fc046c66fa811892/src/librustdoc/html/escape.rs