use std::{borrow::Borrow, collections::HashMap, iter::Iterator, time::SystemTime};

use crate::{
    config::IdpConfig,
    service::sso::{
        LoginToken, ValidationData, SSO_AUTH_EXPIRATION_SECS, SSO_SESSION_COOKIE,
        SUBJECT_CLAIM_KEY,
    },
    services, utils, Error, Result, Ruma,
};
use axum::{
    response::{AppendHeaders, IntoResponse, Redirect},
    RequestExt,
};
use axum_extra::{
    headers::{self},
    TypedHeader,
};
use http::header;
use mas_oidc_client::{
    requests::{
        authorization_code::{self, AuthorizationRequestData},
        jose::{self, JwtVerificationData},
        userinfo,
    },
    types::{
        client_credentials::ClientCredentials,
        errors::ClientError,
        iana::jose::JsonWebSignatureAlg,
        requests::{AccessTokenResponse, AuthorizationResponse},
    },
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use ruma::{
    api::client::{
        error::ErrorKind,
        session::{sso_login, sso_login_with_provider},
    },
    events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType},
    push, UserId,
};
use serde_json::Value;
use tracing::{error, info, warn};
use url::Url;

/// Path component of the SSO callback URL that is handed to identity providers. It is also
/// passed to `utils::build_cookie` when the SSO session cookie is set or cleared.
pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback";

/// # `GET /_matrix/client/v3/login/sso/redirect`
///
/// Redirect the user to the SSO interface without naming an identity provider.
///
/// The request is re-packaged as a provider-specific request with a default `idpId`, and the
/// work is delegated to [`get_sso_redirect_with_provider_route`].
///
/// TODO: this should be removed once Ruma supports trailing slashes.
pub async fn get_sso_redirect_route(
    Ruma {
        body,
        sender_user,
        sender_device,
        sender_servername,
        json_body,
        ..
    }: Ruma<sso_login::v3::Request>,
) -> Result<sso_login::v3::Response> {
    // Forward the caller's identity unchanged; appservice information is deliberately dropped.
    let forwarded = Ruma {
        body: sso_login_with_provider::v3::Request::new(Default::default(), body.redirect_url),
        sender_user,
        sender_device,
        sender_servername,
        json_body,
        appservice_info: None,
    };

    get_sso_redirect_with_provider_route(forwarded.into())
        .await
        .map(|response| sso_login::v3::Response {
            location: response.location,
            cookie: response.cookie,
        })
}

/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}`
///
/// Redirects the user to the authorization endpoint of the selected identity provider and sets
/// a signed session cookie that the callback uses to resume the flow.
pub async fn get_sso_redirect_with_provider_route( body: Ruma, ) -> Result { let idp_ids: Vec<&str> = services() .globals .config .idps .iter() .map(Borrow::borrow) .collect(); let provider = match &*idp_ids { [] => { return Err(Error::BadRequest( ErrorKind::forbidden(), "Single Sign-On is disabled.", )); } [idp_id] => services().sso.get(idp_id).expect("we know it exists"), [_, ..] => services().sso.get(&body.idp_id).ok_or_else(|| { Error::BadRequest(ErrorKind::InvalidParam, "Unknown identity provider.") })?, }; let redirect_url = body .redirect_url .parse::() .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid redirect_url."))?; let mut callback = services() .globals .well_known_client() .parse::() .map_err(|_| Error::bad_config("Invalid well_known_client url."))?; callback.set_path(CALLBACK_PATH); let (auth_url, validation_data) = authorization_code::build_authorization_url( provider.metadata.authorization_endpoint().clone(), AuthorizationRequestData::new( provider.config.client_id.clone(), provider.config.scopes.clone(), callback, ), &mut StdRng::from_entropy(), ) .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to build authorization_url."))?; let signed = services().globals.sign_claims(&ValidationData::new( Borrow::::borrow(provider).to_owned(), redirect_url.to_string(), validation_data, )); Ok(sso_login_with_provider::v3::Response { location: auth_url.to_string(), cookie: Some( utils::build_cookie( SSO_SESSION_COOKIE, &signed, CALLBACK_PATH, Some(SSO_AUTH_EXPIRATION_SECS), ) .to_string(), ), }) } async fn handle_callback_helper(req: axum::extract::Request) -> Result { let query = req.uri().query().ok_or_else(|| { Error::BadRequest(ErrorKind::MissingParam, "Empty authorization callback.") })?; let AuthorizationResponse { code, access_token, token_type, id_token, expires_in, } = serde_html_form::from_str(query).map_err(|_| { serde_html_form::from_str(query) .map(ClientError::into) .unwrap_or_else(|_| { error!("Failed to deserialize 
authorization callback: {}", query); Error::BadRequest( ErrorKind::Unknown, "Failed to deserialize authorization callback.", ) }) })?; let Ok(Some(cookie)): Result>, _> = req.extract().await else { return Err(Error::BadRequest( ErrorKind::MissingParam, "Missing session cookie.", )); }; let ValidationData { provider, redirect_url, inner: validation_data, } = services() .globals .validate_claims( cookie.get(SSO_SESSION_COOKIE).ok_or_else(|| { Error::BadRequest(ErrorKind::MissingParam, "Missing value for session cookie.") })?, None, ) .map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Invalid value for session cookie.") })?; let provider = services().sso.get(&provider).ok_or_else(|| { Error::BadRequest( ErrorKind::InvalidParam, "Unknown provider for session cookie.", ) })?; let IdpConfig { client_id, client_secret, auth_method, .. } = provider.config.clone(); let credentials = match &*auth_method { "basic" => ClientCredentials::ClientSecretBasic { client_id, client_secret, }, "post" => ClientCredentials::ClientSecretPost { client_id, client_secret, }, _ => todo!(), }; let ref jwks = jose::fetch_jwks(services().sso.service(), provider.metadata.jwks_uri()) .await .map_err(|_| Error::bad_config("Failed to fetch signing keys for token endpoint."))?; let jwt_verification_data = Some(JwtVerificationData { jwks, issuer: &provider.config.issuer, client_id: &provider.config.client_id, signing_algorithm: &JsonWebSignatureAlg::Rs256, }); let ( AccessTokenResponse { access_token, refresh_token, token_type, expires_in, scope, .. }, Some(id_token), ) = authorization_code::access_token_with_authorization_code( services().sso.service(), credentials, provider.metadata.token_endpoint(), code.unwrap_or_default(), validation_data, jwt_verification_data, SystemTime::now().into(), &mut StdRng::from_entropy(), ) .await .map_err(|_| Error::bad_config("Failed to fetch access token."))? 
else { unreachable!("ID token should never be empty") }; let mut userinfo = HashMap::default(); if let Some(endpoint) = &provider.metadata.userinfo_endpoint { userinfo = userinfo::fetch_userinfo( services().sso.service(), endpoint, &access_token, jwt_verification_data, &id_token, ) .await .map_err(|_| Error::bad_config("Failed to fetch claims for userinfo endpoint."))?; }; let (_, id_token) = id_token.into_parts(); let subject = match id_token.get(SUBJECT_CLAIM_KEY) { Some(Value::String(s)) => s.to_owned(), Some(Value::Number(n)) => n.to_string(), value => { return Err(Error::BadRequest( ErrorKind::Unknown, value .map(|_| { error!("Subject claim is missing from ID token: {id_token:?}"); "Subject claim is missing from ID token." }) .unwrap_or("Subject claim should be a string or number."), )); } }; let user_id = match services() .sso .user_from_subject(Borrow::::borrow(provider), &subject)? { Some(user_id) => user_id, None => { let mut localpart = subject.clone(); let user_id = loop { match UserId::parse_with_server_name(&*localpart, services().globals.server_name()) { Ok(user_id) if services().users.exists(&user_id)? 
=> break user_id, _ => { let n: u8 = rand::thread_rng().gen(); localpart = format!("{}{}", localpart, n % 10); } } }; services().users.set_placeholder_password(&user_id)?; let mut displayname = id_token .get("preferred_username") .or(id_token.get("nickname")) .as_deref() .map(Value::to_string) .unwrap_or(user_id.localpart().to_owned()); // If enabled append lightning bolt to display name (default true) if services().globals.enable_lightning_bolt() { displayname.push_str(" ⚡️"); } services() .users .set_displayname(&user_id, Some(displayname.clone()))?; // Initial account data services().account_data.update( None, &user_id, GlobalAccountDataEventType::PushRules.to_string().into(), &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { content: ruma::events::push_rules::PushRulesEventContent { global: push::Ruleset::server_default(&user_id), }, }) .expect("to json always works"), )?; info!("New user {} registered on this server.", user_id); services() .admin .send_message(RoomMessageEventContent::notice_plain(format!( "New user {user_id} registered on this server." ))); if let Some(admin_room) = services().admin.get_admin_room()? { if services() .rooms .state_cache .room_joined_count(&admin_room)? == Some(1) { services() .admin .make_user_admin(&user_id, displayname) .await?; warn!("Granting {} admin privileges as the first user", user_id); } } user_id } }; let signed = services().globals.sign_claims(&LoginToken::new( Borrow::::borrow(provider).to_owned(), user_id, )); let mut redirect_url: Url = redirect_url.parse().expect(""); redirect_url .query_pairs_mut() .append_pair("loginToken", &signed); Ok(( AppendHeaders(vec![( header::SET_COOKIE, utils::build_cookie(SSO_SESSION_COOKIE, "", CALLBACK_PATH, None).to_string(), )]), Redirect::temporary(redirect_url.as_str()), ) .into_response()) } /// # `GET /_conduit/client/sso/callback` /// /// Validate the authorization response received from the identity provider. 
/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect. /// If this is the first login, register the user, possibly interactively through a fallback page. pub async fn handle_callback_route(req: axum::extract::Request) -> axum::response::Response { match handle_callback_helper(req).await { Ok(res) => res, Err(e) => e.into_response(), } }