From 269455d93ac82e515b33d859a19a8be4c2c43b38 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Thu, 4 Jul 2024 22:19:25 +0200 Subject: [PATCH 1/7] WIP: docs --- docs/configuration.md | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index d903a21e..a8fa07de 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -11,6 +11,7 @@ Conduit's configuration file is divided into the following sections: - [Global](#global) - [TLS](#tls) - [Proxy](#proxy) + - [SSO (Single Sign-On)](#sso) ## Global @@ -109,3 +110,39 @@ exclude = ["*.clearnet.onion"] [global] {{#include ../conduit-example.toml:22:}} ``` + +### SSO (Single Sign-On) + +Authentication through SSO instead of a password can be enabled by configuring OIDC (OpenID Connect) identity providers. +Identity providers using OAuth such as Github are not supported yet. + +> **Note:** The `*` symbol indicates that the field is required, and the values in **parentheses** are the possible values + +| Field | Type | Description | Default | +| --- | --- | --- | --- | +| `issuer`* | `Url` | The issuer URL. | N/A | +| `name` | `string` | The name displayed on fallback pages. | `issuer` | +| `icon` | `Url` OR `MxcUri` | The icon displayed on fallback pages. | N/A | +| `scopes` | `array` | The scopes used to obtain extra claims which can be used for templates. | `["openid"]` | + + + + +| `client_id`* | `string` | The provider-supplied, unique ID for the client. | N/A | +| `client_secret`* | `string` | The provider-supplied, unique ID for the client. | N/A | +| `authentication_method`* | `"basic" | "post"` | The method used for client authentication. | N/A | + + + + + + + + + + + + + + + From 895b66fa502537077ebebf83ae571d0c22ed68d7 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Wed, 10 Jul 2024 08:19:39 +0200 Subject: [PATCH 2/7] where did my code go??? --- Cargo.toml | 28 +++++-- src/api/client_server/account.rs | 2 + src/api/client_server/keys.rs | 6 +- src/api/client_server/mod.rs | 4 + src/api/client_server/session.rs | 87 +++++++++++++++------ src/config/mod.rs | 126 ++++++++++++++++++++++++++++++- src/database/key_value/mod.rs | 1 + src/database/key_value/users.rs | 4 + src/database/mod.rs | 8 ++ src/main.rs | 10 ++- src/service/globals/mod.rs | 38 +++++++++- src/service/mod.rs | 4 + src/service/uiaa/mod.rs | 3 + src/service/users/data.rs | 2 + src/service/users/mod.rs | 4 + src/utils/error.rs | 16 ++++ src/utils/mod.rs | 26 ++++++- 17 files changed, 331 insertions(+), 38 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 67128f07..bc079115 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,15 +35,18 @@ axum = { version = "0.7", default-features = false, features = [ "json", "matched-path", ], optional = true } -axum-extra = { version = "0.9", features = ["typed-header"] } +axum-extra = { version = "0.9", features = ["cookie", "typed-header"] } axum-server = { version = "0.6", features = ["tls-rustls"] } tower = { version = "0.4.13", features = ["util"] } +# tower-http = { version = "0.5", features = [ +# "add-extension", +# "cors", +# "sensitive-headers", +# "trace", +# "util", +# ] } tower-http = { version = "0.5", features = [ - "add-extension", - "cors", - "sensitive-headers", - "trace", - "util", + "full", ] } tower-service = "0.3" @@ -140,20 +143,28 @@ figment = { version = "0.10.8", features = ["env", "toml"] } # Validating urls in config url = { version = "2", features = ["serde"] } +# HTML +maud = { version = "0.26.0", default-features = false, features = ["axum"] } + async-trait = "0.1.68" tikv-jemallocator = { version = "0.5.0", features = [ "unprefixed_malloc_on_supported_platforms", ], optional = true } sd-notify = { version = "0.4.1", optional = true } +http-body-util = "0.1.2" +hyper-rustls = { version = "0.27.2", default-features = false, features = ["http1", "http2", "ring", "rustls-native-certs", "rustls-platform-verifier"] } +mas-http = "0.9.0" # Used for matrix spec type definitions and helpers [dependencies.ruma] features = [ "appservice-api-c", + "client", "client-api", "compat", "federation-api", + "client-hyper", "push-gateway-api-c", "rand", "ring-compat", @@ -172,6 +183,11 @@ optional = true package = "rust-rocksdb" version = "0.25" +# Used for Single Sign-On +[dependencies.mas-oidc-client] +git = "https://github.com/matrix-org/matrix-authentication-service.git" +default-features = false + [target.'cfg(unix)'.dependencies] nix = { version = "0.28", features = ["resource"] } diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index 47ccdc83..f688ff68 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -322,6 +322,8 @@ pub async fn change_password_route( .ok_or_else(|| Error::BadRequest(ErrorKind::MissingToken, "Missing access token."))?; let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + // if services().users.password_hash(sender_user)? == Some(""); + let mut uiaainfo = UiaaInfo { flows: vec![AuthFlow { stages: vec![AuthType::Password], diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index 4af8890d..5dcea4fa 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -111,6 +111,10 @@ pub async fn upload_signing_keys_route( auth_error: None, }; + let master_key = services() + .users + .get_master_key(None, sender_user, &|user_id| user_id == sender_user)?; + if let Some(auth) = &body.auth { let (worked, uiaainfo) = services() @@ -126,7 +130,7 @@ pub async fn upload_signing_keys_route( .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); - } else { + } else if master_key.is_some() { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } diff --git a/src/api/client_server/mod.rs b/src/api/client_server/mod.rs index a35d7a98..07ee7a17 100644 --- a/src/api/client_server/mod.rs +++ b/src/api/client_server/mod.rs @@ -23,6 +23,7 @@ mod room; mod search; mod session; mod space; +mod sso; mod state; mod sync; mod tag; @@ -60,6 +61,7 @@ pub use room::*; pub use search::*; pub use session::*; pub use space::*; +pub use sso::*; pub use state::*; pub use sync::*; pub use tag::*; @@ -76,3 +78,5 @@ pub const DEVICE_ID_LENGTH: usize = 10; pub const TOKEN_LENGTH: usize = 32; pub const SESSION_ID_LENGTH: usize = 32; pub const AUTO_GEN_PASSWORD_LENGTH: usize = 15; +pub const AUTH_SESSION_EXPIRATION_SECS: u64 = 60 * 5; +pub const LOGIN_TOKEN_EXPIRATION_SECS: u64 = 15; diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index 07078328..aea8d3ac 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -1,5 +1,5 @@ use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{services, utils, Error, Result, Ruma}; +use crate::{service::sso::LoginToken, services, utils, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -24,10 +24,20 @@ struct Claims { pub async fn get_login_types_route( _body: Ruma, ) -> Result { - Ok(get_login_types::v3::Response::new(vec![ + let mut flows = vec![ get_login_types::v3::LoginType::Password(Default::default()), get_login_types::v3::LoginType::ApplicationService(Default::default()), - ])) + ]; + + if let v @ [_, ..] = &*services().sso.flows() { + let flow = get_login_types::v3::SsoLoginType { + identity_providers: v.to_owned(), + }; + + flows.push(get_login_types::v3::LoginType::Sso(flow)); + } + + Ok(get_login_types::v3::Response::new(flows)) } /// # `POST /_matrix/client/r0/login` @@ -101,35 +111,64 @@ pub async fn login_route(body: Ruma) -> Result { - if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { - let token = jsonwebtoken::decode::( - token, - jwt_decoding_key, - &jsonwebtoken::Validation::default(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?; - let username = token.claims.sub.to_lowercase(); - let user_id = - UserId::parse_with_server_name(username, services().globals.server_name()) + match ( + services().globals.jwt_decoding_key(), + &services().sso.providers().is_empty(), + ) { + (_, false) => { + let mut validation = + jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256); + validation.validate_nbf = false; + validation.set_required_spec_claims(&["sub", "exp", "aud", "iss"]); + + let login_token = services() + .globals + .validate_claims::(token, Some(validation)) .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") + Error::BadRequest(ErrorKind::InvalidParam, "Invalid token.") })?; - if services().appservice.is_exclusive_user_id(&user_id).await { + login_token.audience().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid token audience.") + })? + } + (Some(jwt_decoding_key), _) => { + let token = jsonwebtoken::decode::( + token, + jwt_decoding_key, + &jsonwebtoken::Validation::default(), + ) + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.") + })?; + let username = token.claims.sub.to_lowercase(); + let user_id = + UserId::parse_with_server_name(username, services().globals.server_name()) + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidUsername, + "Username is invalid.", + ) + })?; + + if services().appservice.is_exclusive_user_id(&user_id).await { + return Err(Error::BadRequest( + ErrorKind::Exclusive, + "User id reserved by appservice.", + )); + } + + user_id + } + (None, _) => { return Err(Error::BadRequest( - ErrorKind::Exclusive, - "User id reserved by appservice.", + ErrorKind::Unknown, + "Token login is not supported (server has no jwt decoding key).", )); } - - user_id - } else { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Token login is not supported (server has no jwt decoding key).", - )); } } + login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService { identifier, user, diff --git a/src/config/mod.rs b/src/config/mod.rs index 378ab929..93198cf3 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,16 +1,27 @@ use std::{ - collections::BTreeMap, + borrow::Borrow, + collections::{BTreeMap, HashSet}, fmt, + hash::{Hash, Hasher}, net::{IpAddr, Ipv4Addr}, }; -use ruma::{OwnedServerName, RoomVersionId}; -use serde::{de::IgnoredAny, Deserialize}; +use figment::value::{Dict, Value}; +use mas_oidc_client::types::{client_credentials::ClientCredentials, scope::Scope}; +use ruma::{ + api::client::session::get_login_types::v3::IdentityProvider, OwnedServerName, RoomVersionId, +}; +use serde::{ + de::{self, IgnoredAny}, + Deserialize, Deserializer, Serialize, +}; use tracing::warn; use url::Url; mod proxy; +use crate::{Error, Result}; + use self::proxy::ProxyConfig; #[derive(Clone, Debug, Deserialize)] @@ -67,6 +78,8 @@ pub struct Config { pub tracing_flame: bool, #[serde(default)] pub proxy: ProxyConfig, + #[serde(default, deserialize_with = "deserialize_providers")] + pub idps: HashSet, pub jwt_secret: Option, #[serde(default = "default_trusted_servers")] pub trusted_servers: Vec, @@ -101,6 +114,27 @@ pub struct WellKnownConfig { pub server: Option, } +#[derive(Clone, Debug, Deserialize)] +pub struct IdpConfig { + pub issuer: String, + #[serde(flatten)] + pub inner: IdentityProvider, + #[serde(deserialize_with = "deserialize_scopes")] + pub scopes: Scope, + + pub client_id: String, + pub client_secret: String, + pub auth_method: String, +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct Template { + pub localpart: Option, + pub displayname: Option, + pub avatar_url: Option, + pub email: Option, +} + const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; impl Config { @@ -244,6 +278,49 @@ impl fmt::Display for Config { } } +impl Borrow for IdpConfig { + fn borrow(&self) -> &str { + &self.inner.id + } +} + +impl PartialEq for IdpConfig { + fn eq(&self, other: &Self) -> bool { + self.inner.id == other.inner.id + } +} + +impl Eq for IdpConfig {} + +impl Hash for IdpConfig { + fn hash(&self, hasher: &mut H) { + self.inner.id.hash(hasher) + } +} + +impl Into for IdpConfig { + fn into(self) -> ClientCredentials { + let IdpConfig { + client_id, + client_secret, + auth_method, + .. + } = self; + + match &*auth_method { + "basic" => ClientCredentials::ClientSecretBasic { + client_id, + client_secret, + }, + "post" => ClientCredentials::ClientSecretPost { + client_id, + client_secret, + }, + _ => unimplemented!(), + } + } +} + fn false_fn() -> bool { false } @@ -312,3 +389,46 @@ fn default_openid_token_ttl() -> u64 { pub fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 } + +fn deserialize_scopes<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let scopes = >::deserialize(deserializer)?; + + scopes.join(" ").parse().map_err(de::Error::custom) +} + +fn deserialize_providers<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let mut result = HashSet::new(); + let dict = Dict::deserialize(deserializer) + .map(Dict::into_iter) + .map_err(de::Error::custom)?; + warn!(?dict); + + for (name, value) in dict { + let tag = value.tag(); + + let Some(dict) = value.into_dict() else { + return Err(de::Error::custom(Error::bad_config( + "Invalid SSO configuration. ", + ))); + }; + + let id = String::from("id"); + let name = name.parse().map_err(de::Error::custom)?; + + let dict = Some((id, name)).into_iter().chain(dict).collect(); + + result.insert( + Value::Dict(tag, dict) + .deserialize() + .map_err(de::Error::custom)?, + ); + } + + Ok(result) +} diff --git a/src/database/key_value/mod.rs b/src/database/key_value/mod.rs index c4496af8..5027c367 100644 --- a/src/database/key_value/mod.rs +++ b/src/database/key_value/mod.rs @@ -8,6 +8,7 @@ mod media; mod pusher; mod rooms; mod sending; +mod sso; mod transaction_ids; mod uiaa; mod users; diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 63321a40..fca0328c 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -119,6 +119,10 @@ impl service::users::Data for KeyValueDatabase { } } + fn set_placeholder_password(&self, user_id: &UserId) -> Result<()> { + self.userid_password.insert(user_id.as_bytes(), b"0xff") + } + /// Returns the displayname of a user on this homeserver. fn displayname(&self, user_id: &UserId) -> Result> { self.userid_displayname diff --git a/src/database/mod.rs b/src/database/mod.rs index 5171d4bb..16a5e60a 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -63,6 +63,9 @@ pub struct KeyValueDatabase { pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count + pub(super) userid_providersubjectid: Arc, + pub(super) providersubjectid_userid: Arc, + //pub uiaa: uiaa::Uiaa, pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication pub(super) userdevicesessionid_uiaarequest: @@ -297,6 +300,9 @@ impl KeyValueDatabase { userfilterid_filter: builder.open_tree("userfilterid_filter")?, todeviceid_events: builder.open_tree("todeviceid_events")?, + userid_providersubjectid: builder.open_tree("userid_providersubjectid")?, + providersubjectid_userid: builder.open_tree("providersubjectid_userid")?, + userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, @@ -971,6 +977,8 @@ impl KeyValueDatabase { services().admin.start_handler(); + services().sso.start_handler().await?; + // Set emergency access for the conduit user match set_emergency_access() { Ok(pwd_set) => { diff --git a/src/main.rs b/src/main.rs index 8d242c53..0b07fe2b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,10 @@ use axum::{ Router, }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; -use conduit::api::{client_server, server_server}; +use conduit::api::{ + client_server::{self, SSO_CALLBACK_PATH}, + server_server, +}; use figment::{ providers::{Env, Format, Toml}, Figment, @@ -276,6 +279,11 @@ fn routes(config: &Config) -> Router { .ruma_route(client_server::third_party_route) .ruma_route(client_server::request_3pid_management_token_via_email_route) .ruma_route(client_server::request_3pid_management_token_via_msisdn_route) + .ruma_route(client_server::get_sso_redirect_route) + .ruma_route(client_server::get_sso_redirect_with_provider_route) + // The specification will likely never introduce any endpoint for handling authorization callbacks. + // As a workaround, we use custom path that redirects the user to the default login handler. + .route(SSO_CALLBACK_PATH, get(client_server::sso_login_route)) .ruma_route(client_server::get_capabilities_route) .ruma_route(client_server::get_pushrules_all_route) .ruma_route(client_server::set_pushrule_route) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 3325e518..9a3c7d6a 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,9 +1,10 @@ mod data; pub use data::{Data, SigningKeys}; use ruma::{ - serde::Base64, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId, OwnedRoomAliasId, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, + serde::Base64, signatures::KeyPair, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId, + OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, }; +use serde::{de::DeserializeOwned, Serialize}; use crate::api::server_server::DestinationResponse; @@ -37,6 +38,9 @@ use tracing::{error, info}; use base64::{engine::general_purpose, Engine as _}; +// https://github.com/rust-lang/rust/issues/104699 +const PROBLEMATIC_CONST: &[u8] = b"0xCAFEBABE"; + type WellKnownMap = HashMap; type TlsNameMap = HashMap, u16)>; type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries @@ -505,6 +509,36 @@ impl Service { self.config.well_known_client() } + pub fn sign_claims(&self, claims: &S) -> String { + let key = jsonwebtoken::EncodingKey::from_secret( + self.keypair().sign(PROBLEMATIC_CONST).as_bytes(), + ); + + jsonwebtoken::encode(&jsonwebtoken::Header::default(), claims, &key) + .expect("signing JWTs always works") + } + + /// Decode and validate a macaroon with this server's macaroon key. + pub fn validate_claims( + &self, + token: &str, + validation_data: Option, + ) -> jsonwebtoken::errors::Result { + let key = jsonwebtoken::DecodingKey::from_secret( + self.keypair().sign(PROBLEMATIC_CONST).as_bytes(), + ); + + let mut v = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256); + + // these validations are redundant as all JWTs are stored in cookies + v.validate_exp = false; + v.validate_nbf = false; + v.required_spec_claims = Default::default(); + + jsonwebtoken::decode::(token, &key, &validation_data.unwrap_or(v)) + .map(|data| data.claims) + } + pub fn shutdown(&self) { self.shutdown.store(true, atomic::Ordering::Relaxed); // On shutdown diff --git a/src/service/mod.rs b/src/service/mod.rs index 4c11bc18..fae5a726 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -19,6 +19,7 @@ pub mod pdu; pub mod pusher; pub mod rooms; pub mod sending; +pub mod sso; pub mod transaction_ids; pub mod uiaa; pub mod users; @@ -35,6 +36,7 @@ pub struct Services { pub globals: globals::Service, pub key_backups: key_backups::Service, pub media: media::Service, + pub sso: Arc, pub sending: Arc, } @@ -51,6 +53,7 @@ impl Services { + key_backups::Data + media::Data + sending::Data + + sso::Data + 'static, >( db: &'static D, @@ -120,6 +123,7 @@ impl Services { key_backups: key_backups::Service { db }, media: media::Service { db }, sending: sending::Service::build(db, &config), + sso: sso::Service::build(db), globals: globals::Service::load(db, config)?, }) diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 696be958..677d49f0 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -110,6 +110,9 @@ impl Service { AuthData::Dummy(_) => { uiaainfo.completed.push(AuthType::Dummy); } + AuthData::FallbackAcknowledgement(fallback) => { + todo!() + } k => error!("type not supported: {:?}", k), } diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 4566c36d..eff94b6f 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -217,4 +217,6 @@ pub trait Data: Send + Sync { /// Find out which user an OpenID access token belongs to. fn find_from_openid_token(&self, token: &str) -> Result>; + + fn set_placeholder_password(&self, user_id: &UserId) -> Result<()>; } diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index a5694a10..15756ff4 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -602,6 +602,10 @@ impl Service { pub fn find_from_openid_token(&self, token: &str) -> Result> { self.db.find_from_openid_token(token) } + + pub fn set_placeholder_password(&self, user_id: &UserId) -> Result<()> { + self.db.set_placeholder_password(user_id) + } } /// Ensure that a user only sees signatures from themselves and the target user diff --git a/src/utils/error.rs b/src/utils/error.rs index 1d811106..30568001 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -175,6 +175,22 @@ impl Error { } } +impl From for Error { + fn from(e: mas_oidc_client::types::errors::ClientError) -> Self { + error!( + "Failed to complete authorization callback: {} {}", + e.error, + e.error_description.as_deref().unwrap_or_default() + ); + + // TODO: error conversion + Self::BadRequest( + ErrorKind::Unknown, + "Failed to complete authorization callback.", + ) + } +} + #[cfg(feature = "persy")] impl> From> for Error { fn from(err: persy::PE) -> Self { diff --git a/src/utils/mod.rs b/src/utils/mod.rs index d09a1033..4ff6fd6f 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,6 +1,7 @@ pub mod error; use argon2::{Config, Variant}; +use axum_extra::extract::cookie::{Cookie, SameSite}; use cmp::Ordering; use rand::prelude::*; use ring::digest; @@ -8,7 +9,7 @@ use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonO use std::{ cmp, fmt, str::FromStr, - time::{SystemTime, UNIX_EPOCH}, + time::{Duration, SystemTime, UNIX_EPOCH}, }; pub fn millis_since_unix_epoch() -> u64 { @@ -142,6 +143,29 @@ pub fn deserialize_from_str< deserializer.deserialize_str(Visitor(std::marker::PhantomData)) } +pub fn build_cookie<'c>( + name: &'c str, + value: &'c str, + path: &'c str, + max_age: Option, +) -> Cookie<'c> { + let mut cookie = Cookie::new(name, value); + + cookie.set_path(path); + cookie.set_secure(true); + cookie.set_http_only(true); + cookie.set_same_site(SameSite::None); + cookie.set_max_age( + max_age + .map(Duration::from_secs) + .map(TryInto::try_into) + .transpose() + .expect("time overflow"), + ); + + cookie +} + // Copied from librustdoc: // https://github.com/rust-lang/rust/blob/cbaeec14f90b59a91a6b0f17fc046c66fa811892/src/librustdoc/html/escape.rs From 10ce7ea3a95a83c353516aa53c87b73d6fa906a4 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Thu, 11 Jul 2024 21:55:52 +0200 Subject: [PATCH 3/7] initial commit --- src/api/client_server/sso.rs | 648 +++++++++++++++++++++++++++++++++++ src/service/sso/data.rs | 9 + src/service/sso/mod.rs | 299 ++++++++++++++++ src/service/sso/templates.rs | 34 ++ 4 files changed, 990 insertions(+) create mode 100644 src/api/client_server/sso.rs create mode 100644 src/service/sso/data.rs create mode 100644 src/service/sso/mod.rs create mode 100644 src/service/sso/templates.rs diff --git a/src/api/client_server/sso.rs b/src/api/client_server/sso.rs new file mode 100644 index 00000000..c6d26d27 --- /dev/null +++ b/src/api/client_server/sso.rs @@ -0,0 +1,648 @@ +use std::{borrow::Borrow, collections::HashMap, iter::Iterator, time::SystemTime}; + +use crate::{ + config::{ + sso::{Registration, Template}, + IdpConfig, + }, + service::sso::{ + templates, LoginToken, RegistrationInfo, RegistrationToken, ValidationData, + REGISTRATION_EXPIRATION_SECS, SESSION_EXPIRATION_SECS, SSO_AUTH_EXPIRATION_SECS, + SSO_SESSION_COOKIE, + }, + services, utils, Error, Result, Ruma, +}; +use axum::{ + extract::RawQuery, + response::{AppendHeaders, IntoResponse, Redirect}, + RequestExt, +}; +use axum_extra::{ + headers::{self, HeaderMapExt}, + TypedHeader, +}; +use http::header; +use mas_oidc_client::{ + requests::{ + authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData}, + jose::{self, JwtVerificationData}, + userinfo, + }, + types::{ + client_credentials::ClientCredentials, + errors::ClientError, + iana::jose::JsonWebSignatureAlg, + requests::{AccessTokenResponse, AuthorizationResponse}, + }, +}; +use rand::{rngs::StdRng, SeedableRng}; +use ruma::{ + api::client::{ + error::ErrorKind, + session::{self, sso_login, sso_login_with_provider}, + }, + events::GlobalAccountDataEventType, + push, OwnedMxcUri, UserId, +}; +use serde_json::Number; +use tracing::error; +use url::Url; + +pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback"; + +/// # `GET /_matrix/client/v3/login/sso/redirect` +/// +/// Redirect the user to the SSO interface. +/// TODO: this should be removed once Ruma supports trailing slashes. +pub async fn get_sso_redirect( + body: Ruma, +) -> Result { + let sso_login_with_provider::v3::Response { location, cookie } = + get_sso_redirect_with_provider( + Ruma { + body: sso_login_with_provider::v3::Request::new( + Default::default(), + body.redirect_url.clone(), + ), + ..body + } + .into(), + ) + .await?; + + Ok(sso_login::v3::Response { location, cookie }) +} + +/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}` +/// +/// Redirects the user to the SSO interface. +pub async fn get_sso_redirect_with_provider( + body: Ruma, +) -> Result { + let idp_ids: Vec<&str> = services() + .globals + .config + .idps + .iter() + .map(Borrow::borrow) + .collect(); + + let provider = match &*idp_ids { + [] => { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Single Sign-On is disabled.", + )); + } + [idp_id] => services().sso.get(idp_id).expect("we know it exists"), + [_, ..] => services().sso.get(&body.idp_id).ok_or_else(|| { + Error::BadRequest(ErrorKind::InvalidParam, "Unknown identity provider.") + })?, + }; + + let redirect_url = body + .redirect_url + .parse::() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid redirect_url."))?; + + let mut callback = services() + .globals + .well_known_client() + .parse::() + .map_err(|_| Error::bad_config("Invalid well_known_client url."))?; + callback.set_path(CALLBACK_PATH); + + let (auth_url, validation_data) = authorization_code::build_authorization_url( + provider.metadata.authorization_endpoint().clone(), + AuthorizationRequestData::new( + provider.config.client_id.clone(), + provider.config.scopes.clone(), + redirect_url, + ), + &mut StdRng::from_entropy(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to build authorization_url."))?; + + let signed = services().globals.sign_claims(&ValidationData::new( + provider.borrow().to_string(), + validation_data, + )); + + Ok(sso_login_with_provider::v3::Response { + location: auth_url.to_string(), + cookie: Some( + utils::build_cookie( + SSO_SESSION_COOKIE, + &signed, + "/_conduit/client/sso/callback", + Some(SSO_AUTH_EXPIRATION_SECS), + ) + .to_string(), + ), + }) +} + +/// # `GET /_conduit/client/sso/callback` +/// +/// Validate the authorization response received from the identity provider. +/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect. +/// If this is the first login, register the user, possibly interactively through a fallback page. +pub async fn get_sso_callback(req: axum::extract::Request) -> Result { + let query = req.uri().query().ok_or_else(|| { + Error::BadRequest(ErrorKind::MissingParam, "Empty authorization callback.") + })?; + + let AuthorizationResponse { + code, + access_token, + token_type, + id_token, + expires_in, + } = serde_html_form::from_str::(query).map_err(|_| { + serde_html_form::from_str::(query).unwrap_or_else(|_| { + error!("Failed to deserialize authorization callback: {}", callback); + + Error::BadRequest( + ErrorKind::Unknown, + "Failed to deserialize authorization callback.", + ) + }) + })?; + + let cookie = req + .extract::>>() + .await + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid session cookie."))? + .ok_or_else(|_| Error::BadRequest(ErrorKind::MissingParam, "Missing session cookie."))?; + + let ValidationData { + provider, + inner: validation_data, + } = services() + .globals + .validate_claims( + cookie.get(SSO_SESSION_COOKIE).ok_or_else(|| { + Error::BadRequest(ErrorKind::MissingParam, "Missing value for session cookie.") + })?, + None, + ) + .map_err(|e| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid value for session cookie.") + })?; + + let provider = services().sso.get(&provider).ok_or_else(|e| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Unknown provider for session cookie.", + ) + })?; + + let IdpConfig { + client_id, + client_secret, + auth_method, + .. + } = provider.config.clone(); + + let credentials = match &auth_method { + "basic" => ClientCredentials::ClientSecretBasic { + client_id, + client_secret, + }, + "post" => ClientCredentials::ClientSecretPost { + client_id, + client_secret, + }, + _ => todo!(), + }; + let ( + AccessTokenResponse { + access_token, + refresh_token, + token_type, + expires_in, + scope, + .. + }, + Some(id_token), + ) = authorization_code::access_token_with_authorization_code( + services().sso.service(), + method, + provider.metadata.token_endpoint(), + code, + validation_data, + jwt_verification_data, + SystemTime::now().into(), + &mut StdRng::from_entropy(), + ) + .await + .map_err(|e| Error::bad_config("Failed to fetch access token."))? + else { + unreachable!("ID token should never be empty") + }; + + // let userinfo = provider.fetch_userinfo(&access_token, &id_token).await?; + + let mut userinfo = HashMap::default(); + if let Some(endpoint) = &provider.metadata.userinfo_endpoint { + let ref jwks = jose::fetch_jwks(services().sso.service(), provider.metadata.jwks_uri()) + .await + .map_err(|e| Error::bad_config("Failed to fetch signing keys for token endpoint."))?; + let jwt_verification_data = Some(JwtVerificationData { + jwks, + issuer: &provider.config.issuer, + client_id: credentials.client_id(), + signing_algorithm: &JsonWebSignatureAlg::Rs256, + }); + + userinfo = userinfo::fetch_userinfo( + services().sso.service(), + endpoint, + &access_token, + jwt_verification_data, + &id_token, + ) + .await + .map_err(|e| Error::bad_config("Failed to fetch claims for userinfo endpoint."))?; + }; + + let (_, mut claims) = id_token.into_parts(); + + let subject = claims.get("sub").ok_or_else(|| { + error!("Unique \"sub\" claim is missing from ID token: {claims:?}"); + + Error::bad_config("Unique \"sub\" claim is missing from ID token.") + })?; + + let subject = &subject + .as_str() + .map(str::to_owned) + .or_else(|| subject.as_number().map(Number::to_string)) + .expect("unique claim should be a string or number"); + + let redirect_uri = &validation_data.redirect_uri; + + if let Some(user_id) = services() + .sso + .user_from_claim(&validation_data.provider_id, subject)? + { + let login_token = LoginToken::new(validation_data.provider_id.to_owned(), user_id); + + let redirect_uri = redirect_with_login_token(redirect_uri.to_owned(), &login_token); + + return Ok(( + AppendHeaders(vec![( + header::SET_COOKIE, + utils::reset_cookie("sso-session").to_string(), + )]), + Redirect::temporary(redirect_uri.as_str()), + ) + .into_response()); + } + + match provider.config.registration { + Registration::Disabled => { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Single Sign-On registration is disabled.", + )) + } + Registration::Automated => todo!(), + Registration::Interactive => {} + }; + + let Template { + username, + displayname, + avatar_url, + email, + } = &provider.config.template; + let registration_info = + RegistrationInfo::new(&claims, username, displayname, avatar_url, email); + + let signed = services() + .globals + .sign_macaroon(&RegistrationToken::new( + validation_data.provider_id.clone(), + subject.to_owned(), + redirect_uri.to_owned(), + registration_info, + )) + .expect("signing macaroons always works"); + + let cookie = utils::build_cookie( + "sso-registration", + &signed, + "/_conduit/client/sso/register", + REGISTRATION_EXPIRATION_SECS, + ); + + Ok(( + AppendHeaders(vec![ + (header::SET_COOKIE, cookie.to_string()), + ( + header::SET_COOKIE, + utils::reset_cookie("sso-session").to_string(), + ), + ]), + Redirect::temporary("/_conduit/client/sso/register"), + ) + .into_response()) +} + +/// # `GET /_conduit/client/sso/pick_idp` +pub async fn pick_idp(RawQuery(query): RawQuery) -> impl IntoResponse { + let providers: Vec<_> = services() + .globals + .config + .sso + .iter() + .map(|p| p.inner.to_owned()) + .collect(); + + let body = maud::html! { + header { + h1 { "Log in to " (services().globals.server_name()) } + p { "Choose an identity provider to continue" } + } + main { + ul .providers { + @for provider in providers { + li { + a href={ "/_matrix/client/v3/login/sso/redirect/" (provider.id) "?" (query.as_deref().unwrap_or_default()) } { + @if let Some(url) = provider.icon.as_deref().and_then(utils::mxc_to_http) { + img src=(url); + } + } + span { + (provider.name) + } + } + } + } + } + }; + + ( + [(header::CONTENT_TYPE, "text/html; charset=utf-8")], + maud::html! { + (templates::base("Pick Identity Provider", body)) + + (templates::footer()) + }, + ) +} + +/// # `GET /_conduit/client/sso/register` +/// +/// Serve a registration form with defaults based on the retrieved claims. +/// This endpoint is only available when interactive registration is enabled. +pub async fn get_sso_registration( + cookie: TypedHeader, +) -> Result { + let token = cookie.get("sso-registration").ok_or_else(|| { + Error::BadRequest( + ErrorKind::MissingParam, + "Missing registration token cookie.", + ) + })?; + + let registration_token: RegistrationToken = services() + .globals + .validate_macaroon(token, None) + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid registration token cookie.", + ) + })?; + + let provider = services() + .sso + .get(®istration_token.provider_id) + .map(|p| p.config.inner.to_owned())?; + let server_name = services().globals.server_name(); + + let RegistrationInfo { + username, + displayname, + avatar_url, + email, + } = registration_token.info; + + let additional_info = (&displayname, &avatar_url, &email) != (&None, &None, &None); + + fn detail(title: &str, body: maud::Markup) -> maud::Markup { + maud::html! { + label .detail for=(title) { + div .check-row { + span .name { (title) } " " + span .use { "use" } + input #(title) type="checkbox" name={(title)"-checkbox"} value=(true) checked; + } + (body) + } + } + } + + let body = maud::html! { + header { + h1 { "Complete your registration at " (server_name) } + p { "Confirm your details to finish creating your account." } + } + main { + form .form #form method="post" { + div .username-div #username-div { + label for="username-input" { "Username (required)" } + div .prefix { "@" } + input .username-input type="text" name="username" + value=(username) autofocus autocorrect="off" autocapitalize="none"; + div .postfix { ":" (server_name) } + } + output .username-output for="username-input" { } + + @if additional_info { + section .additional-info { + h2 { + @if let Some(icon) = provider.icon.as_deref().and_then(utils::mxc_to_http) { + img src=(icon.to_string()); + } + "Optional data from " (provider.name) + } + @if let Some(avatar_url) = avatar_url.as_ref() { + (detail("avatar", maud::html!{ + img .avatar src=(avatar_url); + })) + } + @if let Some(displayname) = displayname.as_ref() { + (detail("displayname", maud::html!{ + p .value { (displayname) }; + })) + } + @if let Some(email) = email.as_ref() { + (detail("email", maud::html!{ + p .value { (email) }; + })) + } + } + } + + input type="submit" value="Submit" .primary-button {} + } + } + }; + + Ok(( + [(header::CONTENT_TYPE, "text/html; charset=utf-8")], + maud::html! { + (templates::base("Register Account", body)) + + (templates::footer()) + }, + ) + .into_response()) +} + +/// # `POST /_conduit/client/sso/register` +/// +/// Submit the registration form. +pub async fn submit_sso_registration( + cookie: TypedHeader, + axum::extract::Form(registration_info): axum::extract::Form, +) -> Result { + let token = cookie.get("sso-registration").ok_or_else(|| { + Error::BadRequest( + ErrorKind::MissingParam, + "Missing registration token cookie.", + ) + })?; + + let registration_token: RegistrationToken = services() + .globals + .validate_macaroon(token, None) + .map_err(|_| { + Error::BadRequest( + ErrorKind::MissingParam, + "Invalid registration token cookie.", + ) + })?; + + let RegistrationInfo { + username, + mut displayname, + avatar_url, + email: _, + } = registration_info; + + let user_id = + UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name()) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Invalid username."))?; + + if services().users.exists(&user_id)? { + return Err(Error::BadRequest( + ErrorKind::UserInUse, + "Desired UserId is already taken.", + )); + } + + if services().appservice.is_exclusive_user_id(&user_id).await { + return Err(Error::BadRequest( + ErrorKind::Exclusive, + "Desired UserId reserved by appservice.", + )); + } + + services().users.create(&user_id, None)?; + services().users.set_password_placeholder(&user_id)?; + + if let Some(avatar_url) = avatar_url { + let request = services().globals.default_client().get(avatar_url.as_ref()); + + let res = request.send().await.map_err(|_| { + Error::BadRequest(ErrorKind::UserInUse, "Desired UserId is already taken.") + })?; + + let filename = avatar_url.path_segments().and_then(Iterator::last); + + let (content_type, body): (Option, Vec) = ( + res.headers().typed_get(), + res.bytes().await.map(Into::into).map_err(|_| { + Error::BadRequest(ErrorKind::UserInUse, "Desired UserId is already taken.") + })?, + ); + + let mxc = format!( + "mxc://{}/{}", + services().globals.server_name(), + utils::random_string(crate::api::client_server::MXC_LENGTH) + ); + + services() + .media + .create( + mxc.clone(), + filename + .map(|filename| "inline; filename=".to_owned() + filename) + .as_deref(), + content_type.map(|header| header.to_string()).as_deref(), + &body, + ) + .await?; + + services() + .users + .set_avatar_url(&user_id, Some(OwnedMxcUri::from(mxc)))?; + }; + + if let (Some(displayname), true) = ( + displayname.as_mut(), + services().globals.config.enable_lightning_bolt, + ) { + displayname.push_str(" ⚡️"); + } + + services().users.set_displayname(&user_id, displayname)?; + + services().sso.save_claim( + ®istration_token.provider_id, + &user_id, + ®istration_token.unique_claim, + )?; + + services().account_data.update( + None, + &user_id, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: push::Ruleset::server_default(&user_id), + }, + }) + .expect("PushRulesEvent should always serialize"), + )?; + + let login_token = LoginToken::new(registration_token.provider_id, user_id); + let redirect_uri = redirect_with_login_token(registration_token.redirect_uri, &login_token); + + Ok(( + AppendHeaders([( + header::SET_COOKIE, + utils::reset_cookie("sso-registration").to_string(), + )]), + Redirect::temporary(redirect_uri.as_str()), + ) + .into_response()) +} + +fn redirect_with_login_token(mut redirect_uri: Url, login_token: &LoginToken) -> Url { + let signed = services() + .globals + .sign_macaroon(login_token) + .expect("signing macaroons should always works"); + + redirect_uri + .query_pairs_mut() + .append_pair("loginToken", &signed); + + redirect_uri +} diff --git a/src/service/sso/data.rs b/src/service/sso/data.rs new file mode 100644 index 00000000..75d45bf2 --- /dev/null +++ b/src/service/sso/data.rs @@ -0,0 +1,9 @@ +use ruma::{OwnedUserId, UserId}; + +use crate::Result; + +pub trait Data: Send + Sync { + fn save_subject(&self, provider: &str, user_id: &UserId, subject: &str) -> Result<()>; + + fn user_from_subject(&self, provider: &str, subject: &str) -> Result>; +} diff --git a/src/service/sso/mod.rs b/src/service/sso/mod.rs new file mode 100644 index 00000000..1206ed34 --- /dev/null +++ b/src/service/sso/mod.rs @@ -0,0 +1,299 @@ +mod data; +use std::{ + borrow::Borrow, + collections::{HashMap, HashSet}, + hash::{Hash, Hasher}, + str::FromStr, + sync::{Arc, RwLock}, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use crate::{ + api::client_server::TOKEN_LENGTH, + config::{sso::ProviderConfig as Config, IdpConfig}, + utils, Error, Result, +}; +pub use data::Data; +use email_address::EmailAddress; +use futures_util::future::{self}; +use mas_oidc_client::{ + http_service::{hyper, HttpService}, + jose::jwk::PublicJsonWebKeySet, + requests::{ + authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData}, + discovery, + jose::{self, JwtVerificationData}, + userinfo, + }, + types::{ + iana::jose::JsonWebSignatureAlg, oidc::VerifiedProviderMetadata, + requests::AccessTokenResponse, IdToken, + }, +}; +use rand::SeedableRng; +use ruma::{api::client::error::ErrorKind, MilliSecondsSinceUnixEpoch, OwnedUserId, UserId}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tokio::sync::{oneshot, OnceCell}; +use tracing::error; +use url::Url; + +use crate::services; + +pub use data::Data; + +pub const SSO_AUTH_EXPIRATION_SECS: u64 = 60 * 60; +pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 60 * 2; +pub const SSO_SESSION_COOKIE: &str = "sso-auth"; + +pub struct Service { + db: &'static dyn Data, + service: HttpService, + providers: OnceCell>, +} + +impl Service { + pub fn build(db: &'static dyn Data) -> Result> { + Ok(Arc::new(Self { + db, + service: HttpService::new(hyper::hyper_service()), + providers: OnceCell::new(), + })) + } + + pub fn service(&self) -> &HttpService { + &self.service + } + + pub async fn start_handler(&self) -> Result<()> { + let providers = services().globals.config.idps.iter(); + + self.providers + .get_or_try_init(|| { + future::try_join_all(providers.map(Provider::fetch_metadata)) + .await + .map(Vec::into_iter) + .map(HashSet::from_iter) + }) + .await?; + + Ok(()) + } + + pub fn get(&self, provider: &str) -> Option<&Provider> { + let providers = self.providers.get().expect(""); + + providers.get(provider) + } + + pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result> { + self.db.user_from_subject(provider, subject) + } +} + +#[derive(Clone, Debug)] +pub struct Provider { + pub config: &'static IdpConfig, + pub metadata: VerifiedProviderMetadata, +} + +impl Provider { + pub async fn fetch_metadata(config: &'static IdpConfig) -> Result { + discovery::discover(services().sso.service(), &config.issuer) + .await + .map(|metadata| Provider { config, metadata }) + .map_err(|e| { + error!( + "Failed to fetch identity provider metadata ({}): {}", + &config.inner.id, e + ); + + Error::bad_config("Failed to fetch identity provider metadata.") + }) + } + + async fn fetch_signing_keys(&self) -> Result { + jose::fetch_jwks(&services().sso.service, self.metadata.jwks_uri()) + .await + .map_err(|e| { + error!("Failed to fetch signing keys for token endpoint: {}", e); + + Error::bad_config("Failed to fetch signing keys for token endpoint.") + }) + } + + pub async fn fetch_access_token( + &self, + auth_code: String, + validation_data: AuthorizationValidationData, + ) -> Result<(AccessTokenResponse, Option>)> { + } + + pub async fn fetch_userinfo( + &self, + access_token: &str, + id_token: &IdToken<'_>, + ) -> Result>> { + } +} + +impl Borrow for Provider { + fn borrow(&self) -> &str { + self.config.borrow() + } +} + +impl PartialEq for Provider { + fn eq(&self, other: &Self) -> bool { + self.config == other.config + } +} + +impl Eq for Provider {} + +impl Hash for Provider { + fn hash(&self, hasher: &mut H) { + self.config.hash(hasher) + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct RegistrationToken { + pub info: RegistrationInfo, + pub provider_id: String, + pub unique_claim: String, + pub redirect_uri: Url, + pub expires_at: MilliSecondsSinceUnixEpoch, +} + +impl RegistrationToken { + pub fn new( + provider_id: String, + unique_claim: String, + redirect_uri: Url, + info: RegistrationInfo, + ) -> Self { + let expires_at = MilliSecondsSinceUnixEpoch::from_system_time( + UNIX_EPOCH + .checked_add(Duration::from_secs(REGISTRATION_EXPIRATION_SECS)) + .expect("SystemTime should not overflow"), + ) + .expect("MilliSecondsSinceUnixEpoch is not too large"); + + Self { + info, + provider_id, + unique_claim, + redirect_uri, + expires_at, + } + } +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct RegistrationInfo { + pub username: String, + pub displayname: Option, + pub avatar_url: Option, + pub email: Option, +} + +impl RegistrationInfo { + pub fn new( + claims: &HashMap, + username: &str, + displayname: &str, + avatar_url: &str, + email: &str, + ) -> Self { + Self { + username: claims + .get(username) + .and_then(|v| v.as_str()) + .map(ToOwned::to_owned) + .unwrap_or_default(), + displayname: claims + .get(displayname) + .and_then(|v| v.as_str()) + .map(ToOwned::to_owned), + avatar_url: claims + .get(avatar_url) + .and_then(|v| v.as_str()) + .map(Url::parse) + .and_then(Result::ok), + email: claims + .get(email) + .and_then(|v| v.as_str()) + .map(EmailAddress::from_str) + .and_then(Result::ok), + } + } +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct LoginToken { + pub inner: String, + pub provider_id: String, + pub user_id: OwnedUserId, + + #[serde(rename = "exp")] + expires_at: u64, +} + +impl LoginToken { + pub fn new(provider_id: String, user_id: OwnedUserId) -> Self { + let expires_at = SystemTime::now() + .checked_add(Duration::from_secs(LOGIN_TOKEN_EXPIRATION_SECS)) + .expect("SystemTime should not overflow") + .duration_since(UNIX_EPOCH) + .expect("SystemTime went backwards") + .as_secs(); + + Self { + inner: utils::random_string(TOKEN_LENGTH), + provider_id, + user_id, + expires_at, + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ValidationData { + pub provider: String, + #[serde(flatten, with = "AuthorizationValidationDataDef")] + pub inner: AuthorizationValidationData, +} + +impl ValidationData { + pub fn new(provider: String, inner: AuthorizationValidationData) -> Self { + Self { provider, inner } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(remote = "AuthorizationValidationData")] +pub struct AuthorizationValidationDataDef { + pub state: String, + pub nonce: String, + pub redirect_uri: Url, + pub code_challenge_verifier: Option, +} + +impl From for AuthorizationValidationDataDef { + fn from( + AuthorizationValidationData { + state, + nonce, + redirect_uri, + code_challenge_verifier, + }: AuthorizationValidationData, + ) -> Self { + Self { + state, + nonce, + redirect_uri, + code_challenge_verifier, + } + } +} diff --git a/src/service/sso/templates.rs b/src/service/sso/templates.rs new file mode 100644 index 00000000..01512cc4 --- /dev/null +++ b/src/service/sso/templates.rs @@ -0,0 +1,34 @@ +pub fn base(title: &str, body: maud::Markup) -> maud::Markup { + maud::html! { + (maud::DOCTYPE) + html lang="en" { + head { + meta charset="utf-8"; + meta name="viewport" content="width=device-width, initial-scale=1.0"; + link rel="icon" type="image/png" sizes="32x32" href="https://conduit.rs/conduit.svg"; + style { (FONT_FACE) } + title { (title) } + } + body { (body) } + } + } +} + +pub fn footer() -> maud::Markup { + let info = "An open network for secure, decentralized communication."; + + maud::html! { + footer { p { (info) } } + } +} + +const FONT_FACE: &str = r#" + @font-face { + font-family: 'Source Sans 3 Variable'; + font-style: normal; + font-display: swap; + font-weight: 200 900; + src: url(https://cdn.jsdelivr.net/fontsource/fonts/source-sans-3:vf@latest/latin-wght-normal.woff2) format('woff2-variations'); + unicode-range: U+0000-00FF,U+0131,U+0152-0153,U+02BB-02BC,U+02C6,U+02DA,U+02DC,U+0304,U+0308,U+0329,U+2000-206F,U+2074,U+20AC,U+2122,U+2191,U+2193,U+2212,U+2215,U+FEFF,U+FFFD; + } +"#; From b80141b33b6b0514edf4e4f8e60d8dc6055a74f3 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Thu, 11 Jul 2024 22:24:22 +0200 Subject: [PATCH 4/7] fix --- Cargo.toml | 28 +- src/api/client_server/session.rs | 32 +- src/api/client_server/sso.rs | 585 +++++++++---------------------- src/database/key_value/sso.rs | 29 ++ src/database/mod.rs | 3 + src/main.rs | 4 +- src/service/mod.rs | 2 +- src/service/sso/mod.rs | 171 ++------- 8 files changed, 255 insertions(+), 599 deletions(-) create mode 100644 src/database/key_value/sso.rs diff --git a/Cargo.toml b/Cargo.toml index bc079115..044eefea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,18 +35,15 @@ axum = { version = "0.7", default-features = false, features = [ "json", "matched-path", ], optional = true } -axum-extra = { version = "0.9", features = ["cookie", "typed-header"] } +axum-extra = { version = "0.9", features = ["typed-header", "cookie"] } axum-server = { version = "0.6", features = ["tls-rustls"] } tower = { version = "0.4.13", features = ["util"] } -# tower-http = { version = "0.5", features = [ -# "add-extension", -# "cors", -# "sensitive-headers", -# "trace", -# "util", -# ] } tower-http = { version = "0.5", features = [ - "full", + "add-extension", + "cors", + "sensitive-headers", + "trace", + "util", ] } tower-service = "0.3" @@ -143,6 +140,7 @@ figment = { version = "0.10.8", features = ["env", "toml"] } # Validating urls in config url = { version = "2", features = ["serde"] } +mas-oidc-client = { version = "0.9", default-features = false, features = ["hyper"] } # HTML maud = { version = "0.26.0", default-features = false, features = ["axum"] } @@ -152,19 +150,14 @@ tikv-jemallocator = { version = "0.5.0", features = [ ], optional = true } sd-notify = { version = "0.4.1", optional = true } -http-body-util = "0.1.2" -hyper-rustls = { version = "0.27.2", default-features = false, features = ["http1", "http2", "ring", "rustls-native-certs", "rustls-platform-verifier"] } -mas-http = "0.9.0" # Used for matrix spec type definitions and helpers [dependencies.ruma] features = [ "appservice-api-c", - "client", "client-api", "compat", "federation-api", - "client-hyper", "push-gateway-api-c", "rand", "ring-compat", @@ -183,16 +176,11 @@ optional = true package = "rust-rocksdb" version = "0.25" -# Used for Single Sign-On -[dependencies.mas-oidc-client] -git = "https://github.com/matrix-org/matrix-authentication-service.git" -default-features = false - [target.'cfg(unix)'.dependencies] nix = { version = "0.28", features = ["resource"] } [features] -default = ["backend_rocksdb", "backend_sqlite", "conduit_bin", "systemd"] +default = ["backend_sqlite", "conduit_bin"] #backend_sled = ["sled"] backend_persy = ["parking_lot", "persy"] backend_sqlite = ["sqlite"] diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index aea8d3ac..148c67f5 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -1,5 +1,6 @@ use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use crate::{service::sso::LoginToken, services, utils, Error, Result, Ruma}; +use jsonwebtoken::{Algorithm, Validation}; use ruma::{ api::client::{ error::ErrorKind, @@ -24,17 +25,16 @@ struct Claims { pub async fn get_login_types_route( _body: Ruma, ) -> Result { + let identity_providers: Vec<_> = services().sso.login_type().collect(); let mut flows = vec![ get_login_types::v3::LoginType::Password(Default::default()), get_login_types::v3::LoginType::ApplicationService(Default::default()), ]; - if let v @ [_, ..] = &*services().sso.flows() { - let flow = get_login_types::v3::SsoLoginType { - identity_providers: v.to_owned(), - }; - - flows.push(get_login_types::v3::LoginType::Sso(flow)); + if !identity_providers.is_empty() { + flows.push(get_login_types::v3::LoginType::Sso( + get_login_types::v3::SsoLoginType { identity_providers }, + )); } Ok(get_login_types::v3::Response::new(flows)) @@ -113,30 +113,26 @@ pub async fn login_route(body: Ruma) -> Result { match ( services().globals.jwt_decoding_key(), - &services().sso.providers().is_empty(), + services().sso.login_type().next().is_some(), ) { (_, false) => { - let mut validation = - jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256); + let mut validation = Validation::new(Algorithm::HS256); validation.validate_nbf = false; validation.set_required_spec_claims(&["sub", "exp", "aud", "iss"]); - let login_token = services() + services() .globals .validate_claims::(token, Some(validation)) - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid token.") - })?; - - login_token.audience().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid token audience.") - })? + .as_ref() + .map(LoginToken::audience) + .map(ToOwned::to_owned) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid token."))? } (Some(jwt_decoding_key), _) => { let token = jsonwebtoken::decode::( token, jwt_decoding_key, - &jsonwebtoken::Validation::default(), + &Validation::default(), ) .map_err(|_| { Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.") diff --git a/src/api/client_server/sso.rs b/src/api/client_server/sso.rs index c6d26d27..c5e3b0e3 100644 --- a/src/api/client_server/sso.rs +++ b/src/api/client_server/sso.rs @@ -1,30 +1,24 @@ use std::{borrow::Borrow, collections::HashMap, iter::Iterator, time::SystemTime}; use crate::{ - config::{ - sso::{Registration, Template}, - IdpConfig, - }, + config::IdpConfig, service::sso::{ - templates, LoginToken, RegistrationInfo, RegistrationToken, ValidationData, - REGISTRATION_EXPIRATION_SECS, SESSION_EXPIRATION_SECS, SSO_AUTH_EXPIRATION_SECS, - SSO_SESSION_COOKIE, + LoginToken, ValidationData, SSO_AUTH_EXPIRATION_SECS, SSO_SESSION_COOKIE, SUBJECT_CLAIM_KEY, }, services, utils, Error, Result, Ruma, }; use axum::{ - extract::RawQuery, response::{AppendHeaders, IntoResponse, Redirect}, RequestExt, }; use axum_extra::{ - headers::{self, HeaderMapExt}, + headers::{self}, TypedHeader, }; use http::header; use mas_oidc_client::{ requests::{ - authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData}, + authorization_code::{self, AuthorizationRequestData}, jose::{self, JwtVerificationData}, userinfo, }, @@ -35,17 +29,17 @@ use mas_oidc_client::{ requests::{AccessTokenResponse, AuthorizationResponse}, }, }; -use rand::{rngs::StdRng, SeedableRng}; +use rand::{rngs::StdRng, Rng, SeedableRng}; use ruma::{ api::client::{ error::ErrorKind, - session::{self, sso_login, sso_login_with_provider}, + session::{sso_login, sso_login_with_provider}, }, - events::GlobalAccountDataEventType, - push, OwnedMxcUri, UserId, + events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType}, + push, UserId, }; -use serde_json::Number; -use tracing::error; +use serde_json::Value; +use tracing::{error, info, warn}; use url::Url; pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback"; @@ -54,17 +48,28 @@ pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback"; /// /// Redirect the user to the SSO interface. /// TODO: this should be removed once Ruma supports trailing slashes. -pub async fn get_sso_redirect( - body: Ruma, +pub async fn get_sso_redirect_route( + Ruma { + body, + sender_user, + sender_device, + sender_servername, + json_body, + .. + }: Ruma, ) -> Result { let sso_login_with_provider::v3::Response { location, cookie } = - get_sso_redirect_with_provider( + get_sso_redirect_with_provider_route( Ruma { body: sso_login_with_provider::v3::Request::new( Default::default(), - body.redirect_url.clone(), + body.redirect_url, ), - ..body + sender_user, + sender_device, + sender_servername, + json_body, + appservice_info: None, } .into(), ) @@ -76,7 +81,7 @@ pub async fn get_sso_redirect( /// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}` /// /// Redirects the user to the SSO interface. -pub async fn get_sso_redirect_with_provider( +pub async fn get_sso_redirect_with_provider_route( body: Ruma, ) -> Result { let idp_ids: Vec<&str> = services() @@ -124,7 +129,7 @@ pub async fn get_sso_redirect_with_provider( .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to build authorization_url."))?; let signed = services().globals.sign_claims(&ValidationData::new( - provider.borrow().to_string(), + Borrow::::borrow(provider).to_owned(), validation_data, )); @@ -142,12 +147,7 @@ pub async fn get_sso_redirect_with_provider( }) } -/// # `GET /_conduit/client/sso/callback` -/// -/// Validate the authorization response received from the identity provider. -/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect. -/// If this is the first login, register the user, possibly interactively through a fallback page. -pub async fn get_sso_callback(req: axum::extract::Request) -> Result { +async fn handle_callback_helper(req: axum::extract::Request) -> Result { let query = req.uri().query().ok_or_else(|| { Error::BadRequest(ErrorKind::MissingParam, "Empty authorization callback.") })?; @@ -158,22 +158,26 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result(query).map_err(|_| { - serde_html_form::from_str::(query).unwrap_or_else(|_| { - error!("Failed to deserialize authorization callback: {}", callback); + } = serde_html_form::from_str(query).map_err(|_| { + serde_html_form::from_str(query) + .map(ClientError::into) + .unwrap_or_else(|_| { + error!("Failed to deserialize authorization callback: {}", query); - Error::BadRequest( - ErrorKind::Unknown, - "Failed to deserialize authorization callback.", - ) - }) + Error::BadRequest( + ErrorKind::Unknown, + "Failed to deserialize authorization callback.", + ) + }) })?; - let cookie = req - .extract::>>() - .await - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid session cookie."))? - .ok_or_else(|_| Error::BadRequest(ErrorKind::MissingParam, "Missing session cookie."))?; + let Ok(Some(cookie)): Result>, _> = req.extract().await + else { + return Err(Error::BadRequest( + ErrorKind::MissingParam, + "Missing session cookie.", + )); + }; let ValidationData { provider, @@ -186,11 +190,11 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result Result ClientCredentials::ClientSecretBasic { client_id, client_secret, @@ -215,6 +219,16 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result todo!(), }; + let ref jwks = jose::fetch_jwks(services().sso.service(), provider.metadata.jwks_uri()) + .await + .map_err(|_| Error::bad_config("Failed to fetch signing keys for token endpoint."))?; + let jwt_verification_data = Some(JwtVerificationData { + jwks, + issuer: &provider.config.issuer, + client_id: &provider.config.client_id, + signing_algorithm: &JsonWebSignatureAlg::Rs256, + }); + let ( AccessTokenResponse { access_token, @@ -227,34 +241,22 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result Result { + let subject = match id_token.get(SUBJECT_CLAIM_KEY) { + Some(Value::String(s)) => s.to_owned(), + Some(Value::Number(n)) => n.to_string(), + value => { return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Single Sign-On registration is disabled.", - )) + ErrorKind::Unknown, + value + .map(|_| { + error!("Subject claim is missing from ID token: {id_token:?}"); + + "Subject claim is missing from ID token." + }) + .unwrap_or("Subject claim should be a string or number."), + )); } - Registration::Automated => todo!(), - Registration::Interactive => {} }; - let Template { - username, - displayname, - avatar_url, - email, - } = &provider.config.template; - let registration_info = - RegistrationInfo::new(&claims, username, displayname, avatar_url, email); - - let signed = services() - .globals - .sign_macaroon(&RegistrationToken::new( - validation_data.provider_id.clone(), - subject.to_owned(), - redirect_uri.to_owned(), - registration_info, - )) - .expect("signing macaroons always works"); - - let cookie = utils::build_cookie( - "sso-registration", - &signed, - "/_conduit/client/sso/register", - REGISTRATION_EXPIRATION_SECS, - ); - - Ok(( - AppendHeaders(vec![ - (header::SET_COOKIE, cookie.to_string()), - ( - header::SET_COOKIE, - utils::reset_cookie("sso-session").to_string(), - ), - ]), - Redirect::temporary("/_conduit/client/sso/register"), - ) - .into_response()) -} - -/// # `GET /_conduit/client/sso/pick_idp` -pub async fn pick_idp(RawQuery(query): RawQuery) -> impl IntoResponse { - let providers: Vec<_> = services() - .globals - .config + let user_id = match services() .sso - .iter() - .map(|p| p.inner.to_owned()) - .collect(); + .user_from_subject(Borrow::::borrow(provider), &subject)? + { + Some(user_id) => user_id, + None => { + let mut localpart = subject.clone(); - let body = maud::html! { - header { - h1 { "Log in to " (services().globals.server_name()) } - p { "Choose an identity provider to continue" } - } - main { - ul .providers { - @for provider in providers { - li { - a href={ "/_matrix/client/v3/login/sso/redirect/" (provider.id) "?" (query.as_deref().unwrap_or_default()) } { - @if let Some(url) = provider.icon.as_deref().and_then(utils::mxc_to_http) { - img src=(url); - } - } - span { - (provider.name) - } + let user_id = loop { + match UserId::parse_with_server_name(&*localpart, services().globals.server_name()) + { + Ok(user_id) if services().users.exists(&user_id)? => break user_id, + _ => { + let n: u8 = rand::thread_rng().gen(); + + localpart = format!("{}{}", localpart, n % 10); } } + }; + + services().users.set_placeholder_password(&user_id)?; + let mut displayname = id_token + .get("preferred_username") + .or(id_token.get("nickname")) + .as_deref() + .map(Value::to_string) + .unwrap_or(user_id.localpart().to_owned()); + + // If enabled append lightning bolt to display name (default true) + if services().globals.enable_lightning_bolt() { + displayname.push_str(" ⚡️"); } + + services() + .users + .set_displayname(&user_id, Some(displayname.clone()))?; + + // Initial account data + services().account_data.update( + None, + &user_id, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json always works"), + )?; + + info!("New user {} registered on this server.", user_id); + services() + .admin + .send_message(RoomMessageEventContent::notice_plain(format!( + "New user {user_id} registered on this server." + ))); + + if let Some(admin_room) = services().admin.get_admin_room()? { + if services() + .rooms + .state_cache + .room_joined_count(&admin_room)? + == Some(1) + { + services() + .admin + .make_user_admin(&user_id, displayname) + .await?; + + warn!("Granting {} admin privileges as the first user", user_id); + } + } + + user_id } }; - ( - [(header::CONTENT_TYPE, "text/html; charset=utf-8")], - maud::html! { - (templates::base("Pick Identity Provider", body)) + let signed = services().globals.sign_claims(&LoginToken::new( + Borrow::::borrow(provider).to_owned(), + user_id, + )); - (templates::footer()) - }, - ) -} - -/// # `GET /_conduit/client/sso/register` -/// -/// Serve a registration form with defaults based on the retrieved claims. -/// This endpoint is only available when interactive registration is enabled. -pub async fn get_sso_registration( - cookie: TypedHeader, -) -> Result { - let token = cookie.get("sso-registration").ok_or_else(|| { - Error::BadRequest( - ErrorKind::MissingParam, - "Missing registration token cookie.", - ) - })?; - - let registration_token: RegistrationToken = services() - .globals - .validate_macaroon(token, None) - .map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid registration token cookie.", - ) - })?; - - let provider = services() - .sso - .get(®istration_token.provider_id) - .map(|p| p.config.inner.to_owned())?; - let server_name = services().globals.server_name(); - - let RegistrationInfo { - username, - displayname, - avatar_url, - email, - } = registration_token.info; - - let additional_info = (&displayname, &avatar_url, &email) != (&None, &None, &None); - - fn detail(title: &str, body: maud::Markup) -> maud::Markup { - maud::html! { - label .detail for=(title) { - div .check-row { - span .name { (title) } " " - span .use { "use" } - input #(title) type="checkbox" name={(title)"-checkbox"} value=(true) checked; - } - (body) - } - } - } - - let body = maud::html! { - header { - h1 { "Complete your registration at " (server_name) } - p { "Confirm your details to finish creating your account." } - } - main { - form .form #form method="post" { - div .username-div #username-div { - label for="username-input" { "Username (required)" } - div .prefix { "@" } - input .username-input type="text" name="username" - value=(username) autofocus autocorrect="off" autocapitalize="none"; - div .postfix { ":" (server_name) } - } - output .username-output for="username-input" { } - - @if additional_info { - section .additional-info { - h2 { - @if let Some(icon) = provider.icon.as_deref().and_then(utils::mxc_to_http) { - img src=(icon.to_string()); - } - "Optional data from " (provider.name) - } - @if let Some(avatar_url) = avatar_url.as_ref() { - (detail("avatar", maud::html!{ - img .avatar src=(avatar_url); - })) - } - @if let Some(displayname) = displayname.as_ref() { - (detail("displayname", maud::html!{ - p .value { (displayname) }; - })) - } - @if let Some(email) = email.as_ref() { - (detail("email", maud::html!{ - p .value { (email) }; - })) - } - } - } - - input type="submit" value="Submit" .primary-button {} - } - } - }; + let mut redirect_uri = validation_data.redirect_uri; + redirect_uri + .query_pairs_mut() + .append_pair("loginToken", &signed); Ok(( - [(header::CONTENT_TYPE, "text/html; charset=utf-8")], - maud::html! { - (templates::base("Register Account", body)) - - (templates::footer()) - }, - ) - .into_response()) -} - -/// # `POST /_conduit/client/sso/register` -/// -/// Submit the registration form. -pub async fn submit_sso_registration( - cookie: TypedHeader, - axum::extract::Form(registration_info): axum::extract::Form, -) -> Result { - let token = cookie.get("sso-registration").ok_or_else(|| { - Error::BadRequest( - ErrorKind::MissingParam, - "Missing registration token cookie.", - ) - })?; - - let registration_token: RegistrationToken = services() - .globals - .validate_macaroon(token, None) - .map_err(|_| { - Error::BadRequest( - ErrorKind::MissingParam, - "Invalid registration token cookie.", - ) - })?; - - let RegistrationInfo { - username, - mut displayname, - avatar_url, - email: _, - } = registration_info; - - let user_id = - UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Invalid username."))?; - - if services().users.exists(&user_id)? { - return Err(Error::BadRequest( - ErrorKind::UserInUse, - "Desired UserId is already taken.", - )); - } - - if services().appservice.is_exclusive_user_id(&user_id).await { - return Err(Error::BadRequest( - ErrorKind::Exclusive, - "Desired UserId reserved by appservice.", - )); - } - - services().users.create(&user_id, None)?; - services().users.set_password_placeholder(&user_id)?; - - if let Some(avatar_url) = avatar_url { - let request = services().globals.default_client().get(avatar_url.as_ref()); - - let res = request.send().await.map_err(|_| { - Error::BadRequest(ErrorKind::UserInUse, "Desired UserId is already taken.") - })?; - - let filename = avatar_url.path_segments().and_then(Iterator::last); - - let (content_type, body): (Option, Vec) = ( - res.headers().typed_get(), - res.bytes().await.map(Into::into).map_err(|_| { - Error::BadRequest(ErrorKind::UserInUse, "Desired UserId is already taken.") - })?, - ); - - let mxc = format!( - "mxc://{}/{}", - services().globals.server_name(), - utils::random_string(crate::api::client_server::MXC_LENGTH) - ); - - services() - .media - .create( - mxc.clone(), - filename - .map(|filename| "inline; filename=".to_owned() + filename) - .as_deref(), - content_type.map(|header| header.to_string()).as_deref(), - &body, - ) - .await?; - - services() - .users - .set_avatar_url(&user_id, Some(OwnedMxcUri::from(mxc)))?; - }; - - if let (Some(displayname), true) = ( - displayname.as_mut(), - services().globals.config.enable_lightning_bolt, - ) { - displayname.push_str(" ⚡️"); - } - - services().users.set_displayname(&user_id, displayname)?; - - services().sso.save_claim( - ®istration_token.provider_id, - &user_id, - ®istration_token.unique_claim, - )?; - - services().account_data.update( - None, - &user_id, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: push::Ruleset::server_default(&user_id), - }, - }) - .expect("PushRulesEvent should always serialize"), - )?; - - let login_token = LoginToken::new(registration_token.provider_id, user_id); - let redirect_uri = redirect_with_login_token(registration_token.redirect_uri, &login_token); - - Ok(( - AppendHeaders([( + AppendHeaders(vec![( header::SET_COOKIE, - utils::reset_cookie("sso-registration").to_string(), + utils::build_cookie(SSO_SESSION_COOKIE, "", CALLBACK_PATH, None).to_string(), )]), Redirect::temporary(redirect_uri.as_str()), ) .into_response()) } -fn redirect_with_login_token(mut redirect_uri: Url, login_token: &LoginToken) -> Url { - let signed = services() - .globals - .sign_macaroon(login_token) - .expect("signing macaroons should always works"); - - redirect_uri - .query_pairs_mut() - .append_pair("loginToken", &signed); - - redirect_uri +/// # `GET /_conduit/client/sso/callback` +/// +/// Validate the authorization response received from the identity provider. +/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect. +/// If this is the first login, register the user, possibly interactively through a fallback page. +pub async fn handle_callback_route(req: axum::extract::Request) -> axum::response::Response { + match handle_callback_helper(req).await { + Ok(res) => res, + Err(e) => e.into_response(), + } } diff --git a/src/database/key_value/sso.rs b/src/database/key_value/sso.rs new file mode 100644 index 00000000..1f6eab28 --- /dev/null +++ b/src/database/key_value/sso.rs @@ -0,0 +1,29 @@ +use ruma::{OwnedUserId, UserId}; + +use crate::{service, utils, Error, KeyValueDatabase, Result}; + +impl service::sso::Data for KeyValueDatabase { + fn save_subject(&self, provider: &str, user_id: &UserId, subject: &str) -> Result<()> { + let mut key = provider.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(subject.as_bytes()); + + self.subject_userid.insert(&key, user_id.as_bytes()) + } + + fn user_from_subject(&self, provider: &str, subject: &str) -> Result> { + let mut key = provider.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(subject.as_bytes()); + + self.subject_userid.get(&key)?.map_or(Ok(None), |bytes| { + Some( + UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("User ID in claim_userid is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("User ID in claim_userid is invalid.")), + ) + .transpose() + }) + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 16a5e60a..39550a93 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -49,6 +49,7 @@ pub struct KeyValueDatabase { pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 pub(super) token_userdeviceid: Arc, + pub(super) subject_userid: Arc, pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count @@ -289,6 +290,8 @@ impl KeyValueDatabase { userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, token_userdeviceid: builder.open_tree("token_userdeviceid")?, + subject_userid: builder.open_tree("subject_userid")?, + onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, keychangeid_userid: builder.open_tree("keychangeid_userid")?, diff --git a/src/main.rs b/src/main.rs index 0b07fe2b..c3ad4c1e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,7 @@ use axum::{ }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; use conduit::api::{ - client_server::{self, SSO_CALLBACK_PATH}, + client_server::{self, CALLBACK_PATH}, server_server, }; use figment::{ @@ -283,7 +283,7 @@ fn routes(config: &Config) -> Router { .ruma_route(client_server::get_sso_redirect_with_provider_route) // The specification will likely never introduce any endpoint for handling authorization callbacks. // As a workaround, we use custom path that redirects the user to the default login handler. - .route(SSO_CALLBACK_PATH, get(client_server::sso_login_route)) + .route(CALLBACK_PATH, get(client_server::handle_callback_route)) .ruma_route(client_server::get_capabilities_route) .ruma_route(client_server::get_pushrules_all_route) .ruma_route(client_server::set_pushrule_route) diff --git a/src/service/mod.rs b/src/service/mod.rs index fae5a726..6d8c34d2 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -123,7 +123,7 @@ impl Services { key_backups: key_backups::Service { db }, media: media::Service { db }, sending: sending::Service::build(db, &config), - sso: sso::Service::build(db), + sso: sso::Service::build(db)?, globals: globals::Service::load(db, config)?, }) diff --git a/src/service/sso/mod.rs b/src/service/sso/mod.rs index 1206ed34..31c9c3ab 100644 --- a/src/service/sso/mod.rs +++ b/src/service/sso/mod.rs @@ -1,50 +1,36 @@ -mod data; use std::{ borrow::Borrow, - collections::{HashMap, HashSet}, + collections::HashSet, hash::{Hash, Hasher}, - str::FromStr, - sync::{Arc, RwLock}, - time::{Duration, SystemTime, UNIX_EPOCH}, + sync::Arc, }; use crate::{ - api::client_server::TOKEN_LENGTH, - config::{sso::ProviderConfig as Config, IdpConfig}, + api::client_server::{LOGIN_TOKEN_EXPIRATION_SECS, TOKEN_LENGTH}, + config::IdpConfig, utils, Error, Result, }; -pub use data::Data; -use email_address::EmailAddress; use futures_util::future::{self}; use mas_oidc_client::{ http_service::{hyper, HttpService}, - jose::jwk::PublicJsonWebKeySet, - requests::{ - authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData}, - discovery, - jose::{self, JwtVerificationData}, - userinfo, - }, - types::{ - iana::jose::JsonWebSignatureAlg, oidc::VerifiedProviderMetadata, - requests::AccessTokenResponse, IdToken, - }, + requests::{authorization_code::AuthorizationValidationData, discovery}, + types::oidc::VerifiedProviderMetadata, }; -use rand::SeedableRng; -use ruma::{api::client::error::ErrorKind, MilliSecondsSinceUnixEpoch, OwnedUserId, UserId}; +use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId}; use serde::{Deserialize, Serialize}; -use serde_json::Value; -use tokio::sync::{oneshot, OnceCell}; +use tokio::sync::OnceCell; use tracing::error; use url::Url; use crate::services; +mod data; pub use data::Data; pub const SSO_AUTH_EXPIRATION_SECS: u64 = 60 * 60; pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 60 * 2; pub const SSO_SESSION_COOKIE: &str = "sso-auth"; +pub const SUBJECT_CLAIM_KEY: &str = "sub"; pub struct Service { db: &'static dyn Data, @@ -69,7 +55,7 @@ impl Service { let providers = services().globals.config.idps.iter(); self.providers - .get_or_try_init(|| { + .get_or_try_init(|| async move { future::try_join_all(providers.map(Provider::fetch_metadata)) .await .map(Vec::into_iter) @@ -86,6 +72,12 @@ impl Service { providers.get(provider) } + pub fn login_type(&self) -> impl Iterator + '_ { + let providers = self.providers.get().expect(""); + + providers.iter().map(|p| p.config.inner.clone()) + } + pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result> { self.db.user_from_subject(provider, subject) } @@ -111,30 +103,6 @@ impl Provider { Error::bad_config("Failed to fetch identity provider metadata.") }) } - - async fn fetch_signing_keys(&self) -> Result { - jose::fetch_jwks(&services().sso.service, self.metadata.jwks_uri()) - .await - .map_err(|e| { - error!("Failed to fetch signing keys for token endpoint: {}", e); - - Error::bad_config("Failed to fetch signing keys for token endpoint.") - }) - } - - pub async fn fetch_access_token( - &self, - auth_code: String, - validation_data: AuthorizationValidationData, - ) -> Result<(AccessTokenResponse, Option>)> { - } - - pub async fn fetch_userinfo( - &self, - access_token: &str, - id_token: &IdToken<'_>, - ) -> Result>> { - } } impl Borrow for Provider { @@ -157,105 +125,28 @@ impl Hash for Provider { } } -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct RegistrationToken { - pub info: RegistrationInfo, - pub provider_id: String, - pub unique_claim: String, - pub redirect_uri: Url, - pub expires_at: MilliSecondsSinceUnixEpoch, -} - -impl RegistrationToken { - pub fn new( - provider_id: String, - unique_claim: String, - redirect_uri: Url, - info: RegistrationInfo, - ) -> Self { - let expires_at = MilliSecondsSinceUnixEpoch::from_system_time( - UNIX_EPOCH - .checked_add(Duration::from_secs(REGISTRATION_EXPIRATION_SECS)) - .expect("SystemTime should not overflow"), - ) - .expect("MilliSecondsSinceUnixEpoch is not too large"); - - Self { - info, - provider_id, - unique_claim, - redirect_uri, - expires_at, - } - } -} - -#[derive(Clone, Debug, Default, Deserialize, Serialize)] -pub struct RegistrationInfo { - pub username: String, - pub displayname: Option, - pub avatar_url: Option, - pub email: Option, -} - -impl RegistrationInfo { - pub fn new( - claims: &HashMap, - username: &str, - displayname: &str, - avatar_url: &str, - email: &str, - ) -> Self { - Self { - username: claims - .get(username) - .and_then(|v| v.as_str()) - .map(ToOwned::to_owned) - .unwrap_or_default(), - displayname: claims - .get(displayname) - .and_then(|v| v.as_str()) - .map(ToOwned::to_owned), - avatar_url: claims - .get(avatar_url) - .and_then(|v| v.as_str()) - .map(Url::parse) - .and_then(Result::ok), - email: claims - .get(email) - .and_then(|v| v.as_str()) - .map(EmailAddress::from_str) - .and_then(Result::ok), - } - } -} - #[derive(Clone, Deserialize, Serialize)] pub struct LoginToken { - pub inner: String, - pub provider_id: String, - pub user_id: OwnedUserId, - - #[serde(rename = "exp")] - expires_at: u64, + pub iss: String, + pub aud: OwnedUserId, + pub sub: String, + pub exp: u64, } impl LoginToken { - pub fn new(provider_id: String, user_id: OwnedUserId) -> Self { - let expires_at = SystemTime::now() - .checked_add(Duration::from_secs(LOGIN_TOKEN_EXPIRATION_SECS)) - .expect("SystemTime should not overflow") - .duration_since(UNIX_EPOCH) - .expect("SystemTime went backwards") - .as_secs(); - + pub fn new(provider: String, user_id: OwnedUserId) -> Self { Self { - inner: utils::random_string(TOKEN_LENGTH), - provider_id, - user_id, - expires_at, + iss: provider, + aud: user_id, + sub: utils::random_string(TOKEN_LENGTH), + exp: utils::millis_since_unix_epoch() + .checked_add(LOGIN_TOKEN_EXPIRATION_SECS * 1000) + .expect("time overflow"), } } + pub fn audience(&self) -> &UserId { + &self.aud + } } #[derive(Clone, Debug, Deserialize, Serialize)] From 5af171e7ee9dcc730f87ac4ca7620d9ff2a7eac9 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Thu, 11 Jul 2024 22:44:47 +0200 Subject: [PATCH 5/7] ok --- src/api/client_server/sso.rs | 14 ++++++++------ src/database/key_value/sso.rs | 4 ++-- src/database/mod.rs | 4 ---- src/main.rs | 5 ++++- src/service/sso/mod.rs | 5 +++-- 5 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/api/client_server/sso.rs b/src/api/client_server/sso.rs index c5e3b0e3..bb7719bc 100644 --- a/src/api/client_server/sso.rs +++ b/src/api/client_server/sso.rs @@ -122,7 +122,7 @@ pub async fn get_sso_redirect_with_provider_route( AuthorizationRequestData::new( provider.config.client_id.clone(), provider.config.scopes.clone(), - redirect_url, + callback, ), &mut StdRng::from_entropy(), ) @@ -130,6 +130,7 @@ pub async fn get_sso_redirect_with_provider_route( let signed = services().globals.sign_claims(&ValidationData::new( Borrow::::borrow(provider).to_owned(), + redirect_url.to_string(), validation_data, )); @@ -139,7 +140,7 @@ pub async fn get_sso_redirect_with_provider_route( utils::build_cookie( SSO_SESSION_COOKIE, &signed, - "/_conduit/client/sso/callback", + CALLBACK_PATH, Some(SSO_AUTH_EXPIRATION_SECS), ) .to_string(), @@ -181,6 +182,7 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result Result Result Result Result> { @@ -16,7 +16,7 @@ impl service::sso::Data for KeyValueDatabase { key.push(0xff); key.extend_from_slice(subject.as_bytes()); - self.subject_userid.get(&key)?.map_or(Ok(None), |bytes| { + self.providersubjectid_userid.get(&key)?.map_or(Ok(None), |bytes| { Some( UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("User ID in claim_userid is invalid unicode.") diff --git a/src/database/mod.rs b/src/database/mod.rs index 39550a93..a52fb637 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -49,8 +49,6 @@ pub struct KeyValueDatabase { pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 pub(super) token_userdeviceid: Arc, - pub(super) subject_userid: Arc, - pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count pub(super) keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count @@ -290,8 +288,6 @@ impl KeyValueDatabase { userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, token_userdeviceid: builder.open_tree("token_userdeviceid")?, - subject_userid: builder.open_tree("subject_userid")?, - onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, keychangeid_userid: builder.open_tree("keychangeid_userid")?, diff --git a/src/main.rs b/src/main.rs index c3ad4c1e..15b59be4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -283,7 +283,10 @@ fn routes(config: &Config) -> Router { .ruma_route(client_server::get_sso_redirect_with_provider_route) // The specification will likely never introduce any endpoint for handling authorization callbacks. // As a workaround, we use custom path that redirects the user to the default login handler. - .route(CALLBACK_PATH, get(client_server::handle_callback_route)) + .route( + &format!("/{CALLBACK_PATH}"), + get(client_server::handle_callback_route), + ) .ruma_route(client_server::get_capabilities_route) .ruma_route(client_server::get_pushrules_all_route) .ruma_route(client_server::set_pushrule_route) diff --git a/src/service/sso/mod.rs b/src/service/sso/mod.rs index 31c9c3ab..242d92cc 100644 --- a/src/service/sso/mod.rs +++ b/src/service/sso/mod.rs @@ -152,13 +152,14 @@ impl LoginToken { #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ValidationData { pub provider: String, + pub redirect_url: String, #[serde(flatten, with = "AuthorizationValidationDataDef")] pub inner: AuthorizationValidationData, } impl ValidationData { - pub fn new(provider: String, inner: AuthorizationValidationData) -> Self { - Self { provider, inner } + pub fn new(provider: String, redirect_url: String, inner: AuthorizationValidationData) -> Self { + Self { provider, redirect_url, inner } } } From 139588b64c137efee6b398f39551cc74c637a651 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 15 Jul 2024 06:08:25 +0200 Subject: [PATCH 6/7] nice --- Cargo.toml | 16 ++++- src/api/client_server/sso.rs | 110 ++++++++++++++++++++++++++++++----- src/service/sso/mod.rs | 23 +++++++- src/service/uiaa/mod.rs | 3 - 4 files changed, 131 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 044eefea..c52ed191 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,10 +41,23 @@ tower = { version = "0.4.13", features = ["util"] } tower-http = { version = "0.5", features = [ "add-extension", "cors", + "follow-redirect", + "map-request-body", "sensitive-headers", + "set-header", + "timeout", "trace", "util", ] } +# tower-http = { version = "0.5", features = [ +# "add-extension", +# "cors", +# "decompression-full", +# "sensitive-headers", +# "set-header", +# "trace", +# "util", +# ] } tower-service = "0.3" # Async runtime and utilities @@ -140,8 +153,9 @@ figment = { version = "0.10.8", features = ["env", "toml"] } # Validating urls in config url = { version = "2", features = ["serde"] } -mas-oidc-client = { version = "0.9", default-features = false, features = ["hyper"] } # HTML +mas-oidc-client = { git = "https://github.com/matrix-org/matrix-authentication-service", default-features = false } +mas-http = { git = "https://github.com/matrix-org/matrix-authentication-service", features = ["client"] } maud = { version = "0.26.0", default-features = false, features = ["axum"] } async-trait = "0.1.68" diff --git a/src/api/client_server/sso.rs b/src/api/client_server/sso.rs index bb7719bc..e0b7f9a2 100644 --- a/src/api/client_server/sso.rs +++ b/src/api/client_server/sso.rs @@ -15,7 +15,7 @@ use axum_extra::{ headers::{self}, TypedHeader, }; -use http::header; +use http::header::{self}; use mas_oidc_client::{ requests::{ authorization_code::{self, AuthorizationRequestData}, @@ -42,7 +42,7 @@ use serde_json::Value; use tracing::{error, info, warn}; use url::Url; -pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback"; +pub const CALLBACK_PATH: &str = "/_matrix/client/unstable/conduit/callback"; /// # `GET /_matrix/client/v3/login/sso/redirect` /// @@ -155,10 +155,10 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result Result Result Result axum::respons Err(e) => e.into_response(), } } + +mod sso_callback { + use axum_extra::headers::{self, HeaderMapExt}; + use http::Method; + use mas_oidc_client::types::requests::AuthorizationResponse; + use ruma::{ + api::{ + client::Error, + error::{FromHttpRequestError, HeaderDeserializationError}, + IncomingRequest, Metadata, OutgoingResponse, + }, + metadata, + }; + + use crate::service::sso::SSO_SESSION_COOKIE; + + pub const METADATA: Metadata = metadata! { + method: GET, + rate_limited: false, + authentication: None, + history: { + 1.0 => "/_matrix/client/unstable/conduit/callback", + } + }; + + pub struct Request { + response: AuthorizationResponse, + cookie: String, + } + + pub struct Response {} + + impl IncomingRequest for Request { + type EndpointError = Error; + type OutgoingResponse = Response; + + const METADATA: Metadata = METADATA; + + fn try_from_http_request( + req: http::Request, + _path_args: &[S], + ) -> Result + where + B: AsRef<[u8]>, + S: AsRef, + { + if !(req.method() == METADATA.method + || req.method() == Method::HEAD && METADATA.method == Method::GET) + { + return Err(FromHttpRequestError::MethodMismatch { + expected: METADATA.method, + received: req.method().clone(), + }); + } + + let response: AuthorizationResponse = + serde_html_form::from_str(req.uri().query().unwrap_or(""))?; + + let Some(cookie) = req + .headers() + .typed_get() + .and_then(|cookie: headers::Cookie| { + cookie.get(SSO_SESSION_COOKIE).map(str::to_owned) + }) + else { + return Err(HeaderDeserializationError::MissingHeader( + "Cookie".to_owned(), + ))?; + }; + + Ok(Self { response, cookie }) + } + } + + impl OutgoingResponse for Response { + fn try_into_http_response( + self, + ) -> Result, ruma::api::error::IntoHttpError> { + todo!() + } + } +} diff --git a/src/service/sso/mod.rs b/src/service/sso/mod.rs index 242d92cc..6ffef5ab 100644 --- a/src/service/sso/mod.rs +++ b/src/service/sso/mod.rs @@ -12,13 +12,14 @@ use crate::{ }; use futures_util::future::{self}; use mas_oidc_client::{ - http_service::{hyper, HttpService}, + http_service::HttpService, requests::{authorization_code::AuthorizationValidationData, discovery}, types::oidc::VerifiedProviderMetadata, }; use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId}; use serde::{Deserialize, Serialize}; use tokio::sync::OnceCell; +use tower::BoxError; use tracing::error; use url::Url; @@ -40,9 +41,21 @@ pub struct Service { impl Service { pub fn build(db: &'static dyn Data) -> Result> { + let client = tower::ServiceBuilder::new() + .map_err(BoxError::from) + .layer(mas_http::BytesToBodyRequestLayer) + .layer(mas_http::BodyToBytesResponseLayer) + // .override_request_header(http::header::USER_AGENT, "conduit".to_owned()) + // .concurrency_limit(10) + // .follow_redirects() + // .layer(tower_http::timeout::TimeoutLayer::new( + // std::time::Duration::from_secs(10), + // )) + .service(mas_http::make_untraced_client()); + Ok(Arc::new(Self { db, - service: HttpService::new(hyper::hyper_service()), + service: HttpService::new(client), providers: OnceCell::new(), })) } @@ -159,7 +172,11 @@ pub struct ValidationData { impl ValidationData { pub fn new(provider: String, redirect_url: String, inner: AuthorizationValidationData) -> Self { - Self { provider, redirect_url, inner } + Self { + provider, + redirect_url, + inner, + } } } diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 677d49f0..696be958 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -110,9 +110,6 @@ impl Service { AuthData::Dummy(_) => { uiaainfo.completed.push(AuthType::Dummy); } - AuthData::FallbackAcknowledgement(fallback) => { - todo!() - } k => error!("type not supported: {:?}", k), } From 67c23d6dd4d57ee74a12fa2a4ab3cc140e277487 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 15 Jul 2024 12:24:06 +0200 Subject: [PATCH 7/7] feat: base support --- Cargo.toml | 28 ++--- docs/configuration.md | 21 +--- src/api/client_server/account.rs | 2 - src/api/client_server/keys.rs | 32 +++--- src/api/client_server/session.rs | 20 ++-- src/api/client_server/sso.rs | 184 +++++++++++++++---------------- src/main.rs | 10 +- src/service/globals/mod.rs | 8 +- src/service/sso/mod.rs | 21 ++-- 9 files changed, 148 insertions(+), 178 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c52ed191..93abd753 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ axum = { version = "0.7", default-features = false, features = [ "json", "matched-path", ], optional = true } -axum-extra = { version = "0.9", features = ["typed-header", "cookie"] } +axum-extra = { version = "0.9", features = ["cookie", "typed-header"] } axum-server = { version = "0.6", features = ["tls-rustls"] } tower = { version = "0.4.13", features = ["util"] } tower-http = { version = "0.5", features = [ @@ -49,15 +49,6 @@ tower-http = { version = "0.5", features = [ "trace", "util", ] } -# tower-http = { version = "0.5", features = [ -# "add-extension", -# "cors", -# "decompression-full", -# "sensitive-headers", -# "set-header", -# "trace", -# "util", -# ] } tower-service = "0.3" # Async runtime and utilities @@ -153,11 +144,6 @@ figment = { version = "0.10.8", features = ["env", "toml"] } # Validating urls in config url = { version = "2", features = ["serde"] } -# HTML -mas-oidc-client = { git = "https://github.com/matrix-org/matrix-authentication-service", default-features = false } -mas-http = { git = "https://github.com/matrix-org/matrix-authentication-service", features = ["client"] } -maud = { version = "0.26.0", default-features = false, features = ["axum"] } - async-trait = "0.1.68" tikv-jemallocator = { version = "0.5.0", features = [ "unprefixed_malloc_on_supported_platforms", @@ -190,11 +176,21 @@ optional = true package = "rust-rocksdb" version = "0.25" +[dependencies.mas-http] +features = ["client"] +git = "https://github.com/matrix-org/matrix-authentication-service" +rev = "fbc360d1a94ef2ebf63d979bb403228a700f43c8" + +[dependencies.mas-oidc-client] +features = [] +git = "https://github.com/matrix-org/matrix-authentication-service" +rev = "fbc360d1a94ef2ebf63d979bb403228a700f43c8" + [target.'cfg(unix)'.dependencies] nix = { version = "0.28", features = ["resource"] } [features] -default = ["backend_sqlite", "conduit_bin"] +default = ["backend_rocksdb", "backend_sqlite", "conduit_bin", "systemd"] #backend_sled = ["sled"] backend_persy = ["parking_lot", "persy"] backend_sqlite = ["sqlite"] diff --git a/docs/configuration.md b/docs/configuration.md index a8fa07de..21375004 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -124,25 +124,6 @@ Identity providers using OAuth such as Github are not supported yet. | `name` | `string` | The name displayed on fallback pages. | `issuer` | | `icon` | `Url` OR `MxcUri` | The icon displayed on fallback pages. | N/A | | `scopes` | `array` | The scopes used to obtain extra claims which can be used for templates. | `["openid"]` | - - - - | `client_id`* | `string` | The provider-supplied, unique ID for the client. | N/A | | `client_secret`* | `string` | The provider-supplied, unique ID for the client. | N/A | -| `authentication_method`* | `"basic" | "post"` | The method used for client authentication. | N/A | - - - - - - - - - - - - - - - +| `authentication_method`* | `"basic" OR "post"` | The method used for client authentication. | N/A | diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index f688ff68..47ccdc83 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -322,8 +322,6 @@ pub async fn change_password_route( .ok_or_else(|| Error::BadRequest(ErrorKind::MissingToken, "Missing access token."))?; let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - // if services().users.password_hash(sender_user)? == Some(""); - let mut uiaainfo = UiaaInfo { flows: vec![AuthFlow { stages: vec![AuthType::Password], diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index 5dcea4fa..05110248 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -100,6 +100,12 @@ pub async fn upload_signing_keys_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let master_key = services() + .users + .get_master_key(Some(sender_user), sender_user, &|other| { + sender_user == other + })?; + // UIAA let mut uiaainfo = UiaaInfo { flows: vec![AuthFlow { @@ -111,11 +117,15 @@ pub async fn upload_signing_keys_route( auth_error: None, }; - let master_key = services() - .users - .get_master_key(None, sender_user, &|user_id| user_id == sender_user)?; - - if let Some(auth) = &body.auth { + if let (Some(master_key), None) = (&body.master_key, master_key) { + services().users.add_cross_signing_keys( + sender_user, + master_key, + &body.self_signing_key, + &body.user_signing_key, + true, + )?; + } else if let Some(auth) = &body.auth { let (worked, uiaainfo) = services() .uiaa @@ -130,20 +140,10 @@ pub async fn upload_signing_keys_route( .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); - } else if master_key.is_some() { + } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - if let Some(master_key) = &body.master_key { - services().users.add_cross_signing_keys( - sender_user, - master_key, - &body.self_signing_key, - &body.user_signing_key, - true, // notify so that other users see the new keys - )?; - } - Ok(upload_signing_keys::v3::Response {}) } diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index 148c67f5..0c1189ae 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -113,20 +113,24 @@ pub async fn login_route(body: Ruma) -> Result { match ( services().globals.jwt_decoding_key(), - services().sso.login_type().next().is_some(), + services().globals.config.idps.is_empty(), ) { (_, false) => { - let mut validation = Validation::new(Algorithm::HS256); - validation.validate_nbf = false; - validation.set_required_spec_claims(&["sub", "exp", "aud", "iss"]); + let mut v = Validation::new(Algorithm::HS256); + + v.set_required_spec_claims(&["sub", "exp", "aud", "iss"]); + v.validate_aud = false; + v.validate_nbf = false; services() .globals - .validate_claims::(token, Some(validation)) - .as_ref() + .validate_claims::(token, Some(&v)) .map(LoginToken::audience) - .map(ToOwned::to_owned) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid token."))? + .map_err(|e| { + tracing::warn!("Invalid token: {}", e); + + Error::BadRequest(ErrorKind::InvalidParam, "Invalid token.") + })? } (Some(jwt_decoding_key), _) => { let token = jsonwebtoken::decode::( diff --git a/src/api/client_server/sso.rs b/src/api/client_server/sso.rs index e0b7f9a2..a35439c9 100644 --- a/src/api/client_server/sso.rs +++ b/src/api/client_server/sso.rs @@ -7,15 +7,7 @@ use crate::{ }, services, utils, Error, Result, Ruma, }; -use axum::{ - response::{AppendHeaders, IntoResponse, Redirect}, - RequestExt, -}; -use axum_extra::{ - headers::{self}, - TypedHeader, -}; -use http::header::{self}; +use futures_util::TryFutureExt; use mas_oidc_client::{ requests::{ authorization_code::{self, AuthorizationRequestData}, @@ -24,7 +16,6 @@ use mas_oidc_client::{ }, types::{ client_credentials::ClientCredentials, - errors::ClientError, iana::jose::JsonWebSignatureAlg, requests::{AccessTokenResponse, AuthorizationResponse}, }, @@ -33,6 +24,7 @@ use rand::{rngs::StdRng, Rng, SeedableRng}; use ruma::{ api::client::{ error::ErrorKind, + media::create_content, session::{sso_login, sso_login_with_provider}, }, events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType}, @@ -46,7 +38,7 @@ pub const CALLBACK_PATH: &str = "/_matrix/client/unstable/conduit/callback"; /// # `GET /_matrix/client/v3/login/sso/redirect` /// -/// Redirect the user to the SSO interface. +/// Redirect the user to the SSO interfa. /// TODO: this should be removed once Ruma supports trailing slashes. pub async fn get_sso_redirect_route( Ruma { @@ -148,37 +140,25 @@ pub async fn get_sso_redirect_with_provider_route( }) } -async fn handle_callback_helper(req: axum::extract::Request) -> Result { - let query = req.uri().query().ok_or_else(|| { - Error::BadRequest(ErrorKind::MissingParam, "Empty authorization callback.") - })?; - - let AuthorizationResponse { - code, - access_token: _, - token_type: _, - id_token: _, - expires_in: _, - } = serde_html_form::from_str(query).map_err(|_| { - serde_html_form::from_str(query) - .map(ClientError::into) - .unwrap_or_else(|_| { - error!("Failed to deserialize authorization callback: {}", query); - - Error::BadRequest( - ErrorKind::Unknown, - "Failed to deserialize authorization callback.", - ) - }) - })?; - - let Ok(Some(cookie)): Result>, _> = req.extract().await - else { - return Err(Error::BadRequest( - ErrorKind::MissingParam, - "Missing session cookie.", - )); - }; +/// # `GET /_conduit/client/sso/callback` +/// +/// Validate the authorization response received from the identity provider. +/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect. +/// If this is the first login, register the user, possibly interactively through a fallback page. +pub async fn handle_callback_route( + body: Ruma, +) -> Result { + let sso_callback::Request { + response: + AuthorizationResponse { + code, + access_token: _, + token_type: _, + id_token: _, + expires_in: _, + }, + cookie, + } = body.body; let ValidationData { provider, @@ -186,12 +166,7 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result Result Result Result s.to_owned(), Some(Value::Number(n)) => n.to_string(), @@ -299,8 +281,13 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result break user_id, + .map(|user_id| { + ( + user_id.clone(), + services().users.exists(&user_id).unwrap_or(true), + ) + }) { + Ok((user_id, false)) => break user_id, _ => { let n: u8 = rand::thread_rng().gen(); @@ -310,12 +297,15 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result Result Result Result axum::response::Response { - match handle_callback_helper(req).await { - Ok(res) => res, - Err(e) => e.into_response(), - } + Ok(sso_login_with_provider::v3::Response { + location: redirect_url.to_string(), + cookie: Some(utils::build_cookie(SSO_SESSION_COOKIE, "", CALLBACK_PATH, None).to_string()), + }) } mod sso_callback { @@ -404,9 +406,9 @@ mod sso_callback { use mas_oidc_client::types::requests::AuthorizationResponse; use ruma::{ api::{ - client::Error, + client::{session::sso_login_with_provider, Error}, error::{FromHttpRequestError, HeaderDeserializationError}, - IncomingRequest, Metadata, OutgoingResponse, + IncomingRequest, Metadata, }, metadata, }; @@ -423,15 +425,13 @@ mod sso_callback { }; pub struct Request { - response: AuthorizationResponse, - cookie: String, + pub response: AuthorizationResponse, + pub cookie: String, } - pub struct Response {} - impl IncomingRequest for Request { type EndpointError = Error; - type OutgoingResponse = Response; + type OutgoingResponse = sso_login_with_provider::v3::Response; const METADATA: Metadata = METADATA; @@ -470,12 +470,4 @@ mod sso_callback { Ok(Self { response, cookie }) } } - - impl OutgoingResponse for Response { - fn try_into_http_response( - self, - ) -> Result, ruma::api::error::IntoHttpError> { - todo!() - } - } } diff --git a/src/main.rs b/src/main.rs index 15b59be4..34887460 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,10 +9,7 @@ use axum::{ Router, }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; -use conduit::api::{ - client_server::{self, CALLBACK_PATH}, - server_server, -}; +use conduit::api::{client_server, server_server}; use figment::{ providers::{Env, Format, Toml}, Figment, @@ -283,10 +280,7 @@ fn routes(config: &Config) -> Router { .ruma_route(client_server::get_sso_redirect_with_provider_route) // The specification will likely never introduce any endpoint for handling authorization callbacks. // As a workaround, we use custom path that redirects the user to the default login handler. - .route( - &format!("/{CALLBACK_PATH}"), - get(client_server::handle_callback_route), - ) + .ruma_route(client_server::handle_callback_route) .ruma_route(client_server::get_capabilities_route) .ruma_route(client_server::get_pushrules_all_route) .ruma_route(client_server::set_pushrule_route) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 9a3c7d6a..2a4b76ff 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -18,7 +18,7 @@ use ruma::{ DeviceId, RoomVersionId, ServerName, UserId, }; use std::{ - collections::{BTreeMap, HashMap}, + collections::{BTreeMap, HashMap, HashSet}, error::Error as StdError, fs, future::{self, Future}, @@ -522,7 +522,7 @@ impl Service { pub fn validate_claims( &self, token: &str, - validation_data: Option, + validation_data: Option<&jsonwebtoken::Validation>, ) -> jsonwebtoken::errors::Result { let key = jsonwebtoken::DecodingKey::from_secret( self.keypair().sign(PROBLEMATIC_CONST).as_bytes(), @@ -533,9 +533,9 @@ impl Service { // these validations are redundant as all JWTs are stored in cookies v.validate_exp = false; v.validate_nbf = false; - v.required_spec_claims = Default::default(); + v.required_spec_claims = HashSet::new(); - jsonwebtoken::decode::(token, &key, &validation_data.unwrap_or(v)) + jsonwebtoken::decode::(token, &key, validation_data.unwrap_or(&v)) .map(|data| data.claims) } diff --git a/src/service/sso/mod.rs b/src/service/sso/mod.rs index 6ffef5ab..ac14edbf 100644 --- a/src/service/sso/mod.rs +++ b/src/service/sso/mod.rs @@ -11,6 +11,7 @@ use crate::{ utils, Error, Result, }; use futures_util::future::{self}; +use http::HeaderValue; use mas_oidc_client::{ http_service::HttpService, requests::{authorization_code::AuthorizationValidationData, discovery}, @@ -20,6 +21,7 @@ use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUse use serde::{Deserialize, Serialize}; use tokio::sync::OnceCell; use tower::BoxError; +use tower_http::{set_header::SetRequestHeaderLayer, ServiceBuilderExt}; use tracing::error; use url::Url; @@ -43,14 +45,17 @@ impl Service { pub fn build(db: &'static dyn Data) -> Result> { let client = tower::ServiceBuilder::new() .map_err(BoxError::from) + .layer(tower_http::timeout::TimeoutLayer::new( + std::time::Duration::from_secs(10), + )) .layer(mas_http::BytesToBodyRequestLayer) .layer(mas_http::BodyToBytesResponseLayer) - // .override_request_header(http::header::USER_AGENT, "conduit".to_owned()) - // .concurrency_limit(10) - // .follow_redirects() - // .layer(tower_http::timeout::TimeoutLayer::new( - // std::time::Duration::from_secs(10), - // )) + .layer(SetRequestHeaderLayer::overriding( + http::header::USER_AGENT, + HeaderValue::from_static("conduit/0.9-alpha"), + )) + .concurrency_limit(10) + .follow_redirects() .service(mas_http::make_untraced_client()); Ok(Arc::new(Self { @@ -157,8 +162,8 @@ impl LoginToken { .expect("time overflow"), } } - pub fn audience(&self) -> &UserId { - &self.aud + pub fn audience(self) -> OwnedUserId { + self.aud } }