mod data; use std::{ borrow::Borrow, collections::{HashMap, HashSet}, hash::{Hash, Hasher}, str::FromStr, sync::{Arc, RwLock}, time::{Duration, SystemTime, UNIX_EPOCH}, }; use crate::{ api::client_server::TOKEN_LENGTH, config::{sso::ProviderConfig as Config, IdpConfig}, utils, Error, Result, }; pub use data::Data; use email_address::EmailAddress; use futures_util::future::{self}; use mas_oidc_client::{ http_service::{hyper, HttpService}, jose::jwk::PublicJsonWebKeySet, requests::{ authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData}, discovery, jose::{self, JwtVerificationData}, userinfo, }, types::{ iana::jose::JsonWebSignatureAlg, oidc::VerifiedProviderMetadata, requests::AccessTokenResponse, IdToken, }, }; use rand::SeedableRng; use ruma::{api::client::error::ErrorKind, MilliSecondsSinceUnixEpoch, OwnedUserId, UserId}; use serde::{Deserialize, Serialize}; use serde_json::Value; use tokio::sync::{oneshot, OnceCell}; use tracing::error; use url::Url; use crate::services; pub use data::Data; pub const SSO_AUTH_EXPIRATION_SECS: u64 = 60 * 60; pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 60 * 2; pub const SSO_SESSION_COOKIE: &str = "sso-auth"; pub struct Service { db: &'static dyn Data, service: HttpService, providers: OnceCell>, } impl Service { pub fn build(db: &'static dyn Data) -> Result> { Ok(Arc::new(Self { db, service: HttpService::new(hyper::hyper_service()), providers: OnceCell::new(), })) } pub fn service(&self) -> &HttpService { &self.service } pub async fn start_handler(&self) -> Result<()> { let providers = services().globals.config.idps.iter(); self.providers .get_or_try_init(|| { future::try_join_all(providers.map(Provider::fetch_metadata)) .await .map(Vec::into_iter) .map(HashSet::from_iter) }) .await?; Ok(()) } pub fn get(&self, provider: &str) -> Option<&Provider> { let providers = self.providers.get().expect(""); providers.get(provider) } pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result> { self.db.user_from_subject(provider, subject) } } #[derive(Clone, Debug)] pub struct Provider { pub config: &'static IdpConfig, pub metadata: VerifiedProviderMetadata, } impl Provider { pub async fn fetch_metadata(config: &'static IdpConfig) -> Result { discovery::discover(services().sso.service(), &config.issuer) .await .map(|metadata| Provider { config, metadata }) .map_err(|e| { error!( "Failed to fetch identity provider metadata ({}): {}", &config.inner.id, e ); Error::bad_config("Failed to fetch identity provider metadata.") }) } async fn fetch_signing_keys(&self) -> Result { jose::fetch_jwks(&services().sso.service, self.metadata.jwks_uri()) .await .map_err(|e| { error!("Failed to fetch signing keys for token endpoint: {}", e); Error::bad_config("Failed to fetch signing keys for token endpoint.") }) } pub async fn fetch_access_token( &self, auth_code: String, validation_data: AuthorizationValidationData, ) -> Result<(AccessTokenResponse, Option>)> { } pub async fn fetch_userinfo( &self, access_token: &str, id_token: &IdToken<'_>, ) -> Result>> { } } impl Borrow for Provider { fn borrow(&self) -> &str { self.config.borrow() } } impl PartialEq for Provider { fn eq(&self, other: &Self) -> bool { self.config == other.config } } impl Eq for Provider {} impl Hash for Provider { fn hash(&self, hasher: &mut H) { self.config.hash(hasher) } } #[derive(Clone, Debug, Deserialize, Serialize)] pub struct RegistrationToken { pub info: RegistrationInfo, pub provider_id: String, pub unique_claim: String, pub redirect_uri: Url, pub expires_at: MilliSecondsSinceUnixEpoch, } impl RegistrationToken { pub fn new( provider_id: String, unique_claim: String, redirect_uri: Url, info: RegistrationInfo, ) -> Self { let expires_at = MilliSecondsSinceUnixEpoch::from_system_time( UNIX_EPOCH .checked_add(Duration::from_secs(REGISTRATION_EXPIRATION_SECS)) .expect("SystemTime should not overflow"), ) .expect("MilliSecondsSinceUnixEpoch is not too large"); Self { info, provider_id, unique_claim, redirect_uri, expires_at, } } } #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct RegistrationInfo { pub username: String, pub displayname: Option, pub avatar_url: Option, pub email: Option, } impl RegistrationInfo { pub fn new( claims: &HashMap, username: &str, displayname: &str, avatar_url: &str, email: &str, ) -> Self { Self { username: claims .get(username) .and_then(|v| v.as_str()) .map(ToOwned::to_owned) .unwrap_or_default(), displayname: claims .get(displayname) .and_then(|v| v.as_str()) .map(ToOwned::to_owned), avatar_url: claims .get(avatar_url) .and_then(|v| v.as_str()) .map(Url::parse) .and_then(Result::ok), email: claims .get(email) .and_then(|v| v.as_str()) .map(EmailAddress::from_str) .and_then(Result::ok), } } } #[derive(Clone, Deserialize, Serialize)] pub struct LoginToken { pub inner: String, pub provider_id: String, pub user_id: OwnedUserId, #[serde(rename = "exp")] expires_at: u64, } impl LoginToken { pub fn new(provider_id: String, user_id: OwnedUserId) -> Self { let expires_at = SystemTime::now() .checked_add(Duration::from_secs(LOGIN_TOKEN_EXPIRATION_SECS)) .expect("SystemTime should not overflow") .duration_since(UNIX_EPOCH) .expect("SystemTime went backwards") .as_secs(); Self { inner: utils::random_string(TOKEN_LENGTH), provider_id, user_id, expires_at, } } } #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ValidationData { pub provider: String, #[serde(flatten, with = "AuthorizationValidationDataDef")] pub inner: AuthorizationValidationData, } impl ValidationData { pub fn new(provider: String, inner: AuthorizationValidationData) -> Self { Self { provider, inner } } } #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(remote = "AuthorizationValidationData")] pub struct AuthorizationValidationDataDef { pub state: String, pub nonce: String, pub redirect_uri: Url, pub code_challenge_verifier: Option, } impl From for AuthorizationValidationDataDef { fn from( AuthorizationValidationData { state, nonce, redirect_uri, code_challenge_verifier, }: AuthorizationValidationData, ) -> Self { Self { state, nonce, redirect_uri, code_challenge_verifier, } } }