diff --git a/Cargo.toml b/Cargo.toml index 67128f07..bc079115 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,15 +35,18 @@ axum = { version = "0.7", default-features = false, features = [ "json", "matched-path", ], optional = true } -axum-extra = { version = "0.9", features = ["typed-header"] } +axum-extra = { version = "0.9", features = ["cookie", "typed-header"] } axum-server = { version = "0.6", features = ["tls-rustls"] } tower = { version = "0.4.13", features = ["util"] } +# tower-http = { version = "0.5", features = [ +# "add-extension", +# "cors", +# "sensitive-headers", +# "trace", +# "util", +# ] } tower-http = { version = "0.5", features = [ - "add-extension", - "cors", - "sensitive-headers", - "trace", - "util", + "full", ] } tower-service = "0.3" @@ -140,20 +143,28 @@ figment = { version = "0.10.8", features = ["env", "toml"] } # Validating urls in config url = { version = "2", features = ["serde"] } +# HTML +maud = { version = "0.26.0", default-features = false, features = ["axum"] } + async-trait = "0.1.68" tikv-jemallocator = { version = "0.5.0", features = [ "unprefixed_malloc_on_supported_platforms", ], optional = true } sd-notify = { version = "0.4.1", optional = true } +http-body-util = "0.1.2" +hyper-rustls = { version = "0.27.2", default-features = false, features = ["http1", "http2", "ring", "rustls-native-certs", "rustls-platform-verifier"] } +mas-http = "0.9.0" # Used for matrix spec type definitions and helpers [dependencies.ruma] features = [ "appservice-api-c", + "client", "client-api", "compat", "federation-api", + "client-hyper", "push-gateway-api-c", "rand", "ring-compat", @@ -172,6 +183,11 @@ optional = true package = "rust-rocksdb" version = "0.25" +# Used for Single Sign-On +[dependencies.mas-oidc-client] +git = "https://github.com/matrix-org/matrix-authentication-service.git" +default-features = false + [target.'cfg(unix)'.dependencies] nix = { version = "0.28", features = ["resource"] } diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index 47ccdc83..f688ff68 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -322,6 +322,8 @@ pub async fn change_password_route( .ok_or_else(|| Error::BadRequest(ErrorKind::MissingToken, "Missing access token."))?; let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + // if services().users.password_hash(sender_user)? == Some(""); + let mut uiaainfo = UiaaInfo { flows: vec![AuthFlow { stages: vec![AuthType::Password], diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index 4af8890d..5dcea4fa 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -111,6 +111,10 @@ pub async fn upload_signing_keys_route( auth_error: None, }; + let master_key = services() + .users + .get_master_key(None, sender_user, &|user_id| user_id == sender_user)?; + if let Some(auth) = &body.auth { let (worked, uiaainfo) = services() @@ -126,7 +130,7 @@ pub async fn upload_signing_keys_route( .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); - } else { + } else if master_key.is_some() { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } diff --git a/src/api/client_server/mod.rs b/src/api/client_server/mod.rs index a35d7a98..07ee7a17 100644 --- a/src/api/client_server/mod.rs +++ b/src/api/client_server/mod.rs @@ -23,6 +23,7 @@ mod room; mod search; mod session; mod space; +mod sso; mod state; mod sync; mod tag; @@ -60,6 +61,7 @@ pub use room::*; pub use search::*; pub use session::*; pub use space::*; +pub use sso::*; pub use state::*; pub use sync::*; pub use tag::*; @@ -76,3 +78,5 @@ pub const DEVICE_ID_LENGTH: usize = 10; pub const TOKEN_LENGTH: usize = 32; pub const SESSION_ID_LENGTH: usize = 32; pub const AUTO_GEN_PASSWORD_LENGTH: usize = 15; +pub const AUTH_SESSION_EXPIRATION_SECS: u64 = 60 * 5; +pub const LOGIN_TOKEN_EXPIRATION_SECS: u64 = 15; diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index 07078328..aea8d3ac 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -1,5 +1,5 @@ use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{services, utils, Error, Result, Ruma}; +use crate::{service::sso::LoginToken, services, utils, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -24,10 +24,20 @@ struct Claims { pub async fn get_login_types_route( _body: Ruma, ) -> Result { - Ok(get_login_types::v3::Response::new(vec![ + let mut flows = vec![ get_login_types::v3::LoginType::Password(Default::default()), get_login_types::v3::LoginType::ApplicationService(Default::default()), - ])) + ]; + + if let v @ [_, ..] = &*services().sso.flows() { + let flow = get_login_types::v3::SsoLoginType { + identity_providers: v.to_owned(), + }; + + flows.push(get_login_types::v3::LoginType::Sso(flow)); + } + + Ok(get_login_types::v3::Response::new(flows)) } /// # `POST /_matrix/client/r0/login` @@ -101,35 +111,64 @@ pub async fn login_route(body: Ruma) -> Result { - if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { - let token = jsonwebtoken::decode::( - token, - jwt_decoding_key, - &jsonwebtoken::Validation::default(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?; - let username = token.claims.sub.to_lowercase(); - let user_id = - UserId::parse_with_server_name(username, services().globals.server_name()) + match ( + services().globals.jwt_decoding_key(), + &services().sso.providers().is_empty(), + ) { + (_, false) => { + let mut validation = + jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256); + validation.validate_nbf = false; + validation.set_required_spec_claims(&["sub", "exp", "aud", "iss"]); + + let login_token = services() + .globals + .validate_claims::(token, Some(validation)) .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") + Error::BadRequest(ErrorKind::InvalidParam, "Invalid token.") })?; - if services().appservice.is_exclusive_user_id(&user_id).await { + login_token.audience().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid token audience.") + })? + } + (Some(jwt_decoding_key), _) => { + let token = jsonwebtoken::decode::( + token, + jwt_decoding_key, + &jsonwebtoken::Validation::default(), + ) + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.") + })?; + let username = token.claims.sub.to_lowercase(); + let user_id = + UserId::parse_with_server_name(username, services().globals.server_name()) + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidUsername, + "Username is invalid.", + ) + })?; + + if services().appservice.is_exclusive_user_id(&user_id).await { + return Err(Error::BadRequest( + ErrorKind::Exclusive, + "User id reserved by appservice.", + )); + } + + user_id + } + (None, _) => { return Err(Error::BadRequest( - ErrorKind::Exclusive, - "User id reserved by appservice.", + ErrorKind::Unknown, + "Token login is not supported (server has no jwt decoding key).", )); } - - user_id - } else { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Token login is not supported (server has no jwt decoding key).", - )); } } + login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService { identifier, user, diff --git a/src/config/mod.rs b/src/config/mod.rs index 378ab929..93198cf3 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,16 +1,27 @@ use std::{ - collections::BTreeMap, + borrow::Borrow, + collections::{BTreeMap, HashSet}, fmt, + hash::{Hash, Hasher}, net::{IpAddr, Ipv4Addr}, }; -use ruma::{OwnedServerName, RoomVersionId}; -use serde::{de::IgnoredAny, Deserialize}; +use figment::value::{Dict, Value}; +use mas_oidc_client::types::{client_credentials::ClientCredentials, scope::Scope}; +use ruma::{ + api::client::session::get_login_types::v3::IdentityProvider, OwnedServerName, RoomVersionId, +}; +use serde::{ + de::{self, IgnoredAny}, + Deserialize, Deserializer, Serialize, +}; use tracing::warn; use url::Url; mod proxy; +use crate::{Error, Result}; + use self::proxy::ProxyConfig; #[derive(Clone, Debug, Deserialize)] @@ -67,6 +78,8 @@ pub struct Config { pub tracing_flame: bool, #[serde(default)] pub proxy: ProxyConfig, + #[serde(default, deserialize_with = "deserialize_providers")] + pub idps: HashSet, pub jwt_secret: Option, #[serde(default = "default_trusted_servers")] pub trusted_servers: Vec, @@ -101,6 +114,27 @@ pub struct WellKnownConfig { pub server: Option, } +#[derive(Clone, Debug, Deserialize)] +pub struct IdpConfig { + pub issuer: String, + #[serde(flatten)] + pub inner: IdentityProvider, + #[serde(deserialize_with = "deserialize_scopes")] + pub scopes: Scope, + + pub client_id: String, + pub client_secret: String, + pub auth_method: String, +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct Template { + pub localpart: Option, + pub displayname: Option, + pub avatar_url: Option, + pub email: Option, +} + const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; impl Config { @@ -244,6 +278,49 @@ impl fmt::Display for Config { } } +impl Borrow for IdpConfig { + fn borrow(&self) -> &str { + &self.inner.id + } +} + +impl PartialEq for IdpConfig { + fn eq(&self, other: &Self) -> bool { + self.inner.id == other.inner.id + } +} + +impl Eq for IdpConfig {} + +impl Hash for IdpConfig { + fn hash(&self, hasher: &mut H) { + self.inner.id.hash(hasher) + } +} + +impl Into for IdpConfig { + fn into(self) -> ClientCredentials { + let IdpConfig { + client_id, + client_secret, + auth_method, + .. + } = self; + + match &*auth_method { + "basic" => ClientCredentials::ClientSecretBasic { + client_id, + client_secret, + }, + "post" => ClientCredentials::ClientSecretPost { + client_id, + client_secret, + }, + _ => unimplemented!(), + } + } +} + fn false_fn() -> bool { false } @@ -312,3 +389,46 @@ fn default_openid_token_ttl() -> u64 { pub fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 } + +fn deserialize_scopes<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let scopes = >::deserialize(deserializer)?; + + scopes.join(" ").parse().map_err(de::Error::custom) +} + +fn deserialize_providers<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let mut result = HashSet::new(); + let dict = Dict::deserialize(deserializer) + .map(Dict::into_iter) + .map_err(de::Error::custom)?; + warn!(?dict); + + for (name, value) in dict { + let tag = value.tag(); + + let Some(dict) = value.into_dict() else { + return Err(de::Error::custom(Error::bad_config( + "Invalid SSO configuration. ", + ))); + }; + + let id = String::from("id"); + let name = name.parse().map_err(de::Error::custom)?; + + let dict = Some((id, name)).into_iter().chain(dict).collect(); + + result.insert( + Value::Dict(tag, dict) + .deserialize() + .map_err(de::Error::custom)?, + ); + } + + Ok(result) +} diff --git a/src/database/key_value/mod.rs b/src/database/key_value/mod.rs index c4496af8..5027c367 100644 --- a/src/database/key_value/mod.rs +++ b/src/database/key_value/mod.rs @@ -8,6 +8,7 @@ mod media; mod pusher; mod rooms; mod sending; +mod sso; mod transaction_ids; mod uiaa; mod users; diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 63321a40..fca0328c 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -119,6 +119,10 @@ impl service::users::Data for KeyValueDatabase { } } + fn set_placeholder_password(&self, user_id: &UserId) -> Result<()> { + self.userid_password.insert(user_id.as_bytes(), b"0xff") + } + /// Returns the displayname of a user on this homeserver. fn displayname(&self, user_id: &UserId) -> Result> { self.userid_displayname diff --git a/src/database/mod.rs b/src/database/mod.rs index 5171d4bb..16a5e60a 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -63,6 +63,9 @@ pub struct KeyValueDatabase { pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count + pub(super) userid_providersubjectid: Arc, + pub(super) providersubjectid_userid: Arc, + //pub uiaa: uiaa::Uiaa, pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication pub(super) userdevicesessionid_uiaarequest: @@ -297,6 +300,9 @@ impl KeyValueDatabase { userfilterid_filter: builder.open_tree("userfilterid_filter")?, todeviceid_events: builder.open_tree("todeviceid_events")?, + userid_providersubjectid: builder.open_tree("userid_providersubjectid")?, + providersubjectid_userid: builder.open_tree("providersubjectid_userid")?, + userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, @@ -971,6 +977,8 @@ impl KeyValueDatabase { services().admin.start_handler(); + services().sso.start_handler().await?; + // Set emergency access for the conduit user match set_emergency_access() { Ok(pwd_set) => { diff --git a/src/main.rs b/src/main.rs index 8d242c53..0b07fe2b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,10 @@ use axum::{ Router, }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; -use conduit::api::{client_server, server_server}; +use conduit::api::{ + client_server::{self, SSO_CALLBACK_PATH}, + server_server, +}; use figment::{ providers::{Env, Format, Toml}, Figment, @@ -276,6 +279,11 @@ fn routes(config: &Config) -> Router { .ruma_route(client_server::third_party_route) .ruma_route(client_server::request_3pid_management_token_via_email_route) .ruma_route(client_server::request_3pid_management_token_via_msisdn_route) + .ruma_route(client_server::get_sso_redirect_route) + .ruma_route(client_server::get_sso_redirect_with_provider_route) + // The specification will likely never introduce any endpoint for handling authorization callbacks. + // As a workaround, we use custom path that redirects the user to the default login handler. + .route(SSO_CALLBACK_PATH, get(client_server::sso_login_route)) .ruma_route(client_server::get_capabilities_route) .ruma_route(client_server::get_pushrules_all_route) .ruma_route(client_server::set_pushrule_route) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 3325e518..9a3c7d6a 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,9 +1,10 @@ mod data; pub use data::{Data, SigningKeys}; use ruma::{ - serde::Base64, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId, OwnedRoomAliasId, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, + serde::Base64, signatures::KeyPair, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId, + OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, }; +use serde::{de::DeserializeOwned, Serialize}; use crate::api::server_server::DestinationResponse; @@ -37,6 +38,9 @@ use tracing::{error, info}; use base64::{engine::general_purpose, Engine as _}; +// https://github.com/rust-lang/rust/issues/104699 +const PROBLEMATIC_CONST: &[u8] = b"0xCAFEBABE"; + type WellKnownMap = HashMap; type TlsNameMap = HashMap, u16)>; type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries @@ -505,6 +509,36 @@ impl Service { self.config.well_known_client() } + pub fn sign_claims(&self, claims: &S) -> String { + let key = jsonwebtoken::EncodingKey::from_secret( + self.keypair().sign(PROBLEMATIC_CONST).as_bytes(), + ); + + jsonwebtoken::encode(&jsonwebtoken::Header::default(), claims, &key) + .expect("signing JWTs always works") + } + + /// Decode and validate a macaroon with this server's macaroon key. + pub fn validate_claims( + &self, + token: &str, + validation_data: Option, + ) -> jsonwebtoken::errors::Result { + let key = jsonwebtoken::DecodingKey::from_secret( + self.keypair().sign(PROBLEMATIC_CONST).as_bytes(), + ); + + let mut v = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256); + + // these validations are redundant as all JWTs are stored in cookies + v.validate_exp = false; + v.validate_nbf = false; + v.required_spec_claims = Default::default(); + + jsonwebtoken::decode::(token, &key, &validation_data.unwrap_or(v)) + .map(|data| data.claims) + } + pub fn shutdown(&self) { self.shutdown.store(true, atomic::Ordering::Relaxed); // On shutdown diff --git a/src/service/mod.rs b/src/service/mod.rs index 4c11bc18..fae5a726 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -19,6 +19,7 @@ pub mod pdu; pub mod pusher; pub mod rooms; pub mod sending; +pub mod sso; pub mod transaction_ids; pub mod uiaa; pub mod users; @@ -35,6 +36,7 @@ pub struct Services { pub globals: globals::Service, pub key_backups: key_backups::Service, pub media: media::Service, + pub sso: Arc, pub sending: Arc, } @@ -51,6 +53,7 @@ impl Services { + key_backups::Data + media::Data + sending::Data + + sso::Data + 'static, >( db: &'static D, @@ -120,6 +123,7 @@ impl Services { key_backups: key_backups::Service { db }, media: media::Service { db }, sending: sending::Service::build(db, &config), + sso: sso::Service::build(db), globals: globals::Service::load(db, config)?, }) diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 696be958..677d49f0 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -110,6 +110,9 @@ impl Service { AuthData::Dummy(_) => { uiaainfo.completed.push(AuthType::Dummy); } + AuthData::FallbackAcknowledgement(fallback) => { + todo!() + } k => error!("type not supported: {:?}", k), } diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 4566c36d..eff94b6f 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -217,4 +217,6 @@ pub trait Data: Send + Sync { /// Find out which user an OpenID access token belongs to. fn find_from_openid_token(&self, token: &str) -> Result>; + + fn set_placeholder_password(&self, user_id: &UserId) -> Result<()>; } diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index a5694a10..15756ff4 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -602,6 +602,10 @@ impl Service { pub fn find_from_openid_token(&self, token: &str) -> Result> { self.db.find_from_openid_token(token) } + + pub fn set_placeholder_password(&self, user_id: &UserId) -> Result<()> { + self.db.set_placeholder_password(user_id) + } } /// Ensure that a user only sees signatures from themselves and the target user diff --git a/src/utils/error.rs b/src/utils/error.rs index 1d811106..30568001 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -175,6 +175,22 @@ impl Error { } } +impl From for Error { + fn from(e: mas_oidc_client::types::errors::ClientError) -> Self { + error!( + "Failed to complete authorization callback: {} {}", + e.error, + e.error_description.as_deref().unwrap_or_default() + ); + + // TODO: error conversion + Self::BadRequest( + ErrorKind::Unknown, + "Failed to complete authorization callback.", + ) + } +} + #[cfg(feature = "persy")] impl> From> for Error { fn from(err: persy::PE) -> Self { diff --git a/src/utils/mod.rs b/src/utils/mod.rs index d09a1033..4ff6fd6f 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,6 +1,7 @@ pub mod error; use argon2::{Config, Variant}; +use axum_extra::extract::cookie::{Cookie, SameSite}; use cmp::Ordering; use rand::prelude::*; use ring::digest; @@ -8,7 +9,7 @@ use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonO use std::{ cmp, fmt, str::FromStr, - time::{SystemTime, UNIX_EPOCH}, + time::{Duration, SystemTime, UNIX_EPOCH}, }; pub fn millis_since_unix_epoch() -> u64 { @@ -142,6 +143,29 @@ pub fn deserialize_from_str< deserializer.deserialize_str(Visitor(std::marker::PhantomData)) } +pub fn build_cookie<'c>( + name: &'c str, + value: &'c str, + path: &'c str, + max_age: Option, +) -> Cookie<'c> { + let mut cookie = Cookie::new(name, value); + + cookie.set_path(path); + cookie.set_secure(true); + cookie.set_http_only(true); + cookie.set_same_site(SameSite::None); + cookie.set_max_age( + max_age + .map(Duration::from_secs) + .map(TryInto::try_into) + .transpose() + .expect("time overflow"), + ); + + cookie +} + // Copied from librustdoc: // https://github.com/rust-lang/rust/blob/cbaeec14f90b59a91a6b0f17fc046c66fa811892/src/librustdoc/html/escape.rs