1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00
This commit is contained in:
avdb13 2024-07-11 22:24:22 +02:00
parent 10ce7ea3a9
commit b80141b33b
8 changed files with 255 additions and 599 deletions

View file

@ -35,18 +35,15 @@ axum = { version = "0.7", default-features = false, features = [
"json", "json",
"matched-path", "matched-path",
], optional = true } ], optional = true }
axum-extra = { version = "0.9", features = ["cookie", "typed-header"] } axum-extra = { version = "0.9", features = ["typed-header", "cookie"] }
axum-server = { version = "0.6", features = ["tls-rustls"] } axum-server = { version = "0.6", features = ["tls-rustls"] }
tower = { version = "0.4.13", features = ["util"] } tower = { version = "0.4.13", features = ["util"] }
# tower-http = { version = "0.5", features = [
# "add-extension",
# "cors",
# "sensitive-headers",
# "trace",
# "util",
# ] }
tower-http = { version = "0.5", features = [ tower-http = { version = "0.5", features = [
"full", "add-extension",
"cors",
"sensitive-headers",
"trace",
"util",
] } ] }
tower-service = "0.3" tower-service = "0.3"
@ -143,6 +140,7 @@ figment = { version = "0.10.8", features = ["env", "toml"] }
# Validating urls in config # Validating urls in config
url = { version = "2", features = ["serde"] } url = { version = "2", features = ["serde"] }
mas-oidc-client = { version = "0.9", default-features = false, features = ["hyper"] }
# HTML # HTML
maud = { version = "0.26.0", default-features = false, features = ["axum"] } maud = { version = "0.26.0", default-features = false, features = ["axum"] }
@ -152,19 +150,14 @@ tikv-jemallocator = { version = "0.5.0", features = [
], optional = true } ], optional = true }
sd-notify = { version = "0.4.1", optional = true } sd-notify = { version = "0.4.1", optional = true }
http-body-util = "0.1.2"
hyper-rustls = { version = "0.27.2", default-features = false, features = ["http1", "http2", "ring", "rustls-native-certs", "rustls-platform-verifier"] }
mas-http = "0.9.0"
# Used for matrix spec type definitions and helpers # Used for matrix spec type definitions and helpers
[dependencies.ruma] [dependencies.ruma]
features = [ features = [
"appservice-api-c", "appservice-api-c",
"client",
"client-api", "client-api",
"compat", "compat",
"federation-api", "federation-api",
"client-hyper",
"push-gateway-api-c", "push-gateway-api-c",
"rand", "rand",
"ring-compat", "ring-compat",
@ -183,16 +176,11 @@ optional = true
package = "rust-rocksdb" package = "rust-rocksdb"
version = "0.25" version = "0.25"
# Used for Single Sign-On
[dependencies.mas-oidc-client]
git = "https://github.com/matrix-org/matrix-authentication-service.git"
default-features = false
[target.'cfg(unix)'.dependencies] [target.'cfg(unix)'.dependencies]
nix = { version = "0.28", features = ["resource"] } nix = { version = "0.28", features = ["resource"] }
[features] [features]
default = ["backend_rocksdb", "backend_sqlite", "conduit_bin", "systemd"] default = ["backend_sqlite", "conduit_bin"]
#backend_sled = ["sled"] #backend_sled = ["sled"]
backend_persy = ["parking_lot", "persy"] backend_persy = ["parking_lot", "persy"]
backend_sqlite = ["sqlite"] backend_sqlite = ["sqlite"]

View file

@ -1,5 +1,6 @@
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{service::sso::LoginToken, services, utils, Error, Result, Ruma}; use crate::{service::sso::LoginToken, services, utils, Error, Result, Ruma};
use jsonwebtoken::{Algorithm, Validation};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -24,17 +25,16 @@ struct Claims {
pub async fn get_login_types_route( pub async fn get_login_types_route(
_body: Ruma<get_login_types::v3::Request>, _body: Ruma<get_login_types::v3::Request>,
) -> Result<get_login_types::v3::Response> { ) -> Result<get_login_types::v3::Response> {
let identity_providers: Vec<_> = services().sso.login_type().collect();
let mut flows = vec![ let mut flows = vec![
get_login_types::v3::LoginType::Password(Default::default()), get_login_types::v3::LoginType::Password(Default::default()),
get_login_types::v3::LoginType::ApplicationService(Default::default()), get_login_types::v3::LoginType::ApplicationService(Default::default()),
]; ];
if let v @ [_, ..] = &*services().sso.flows() { if !identity_providers.is_empty() {
let flow = get_login_types::v3::SsoLoginType { flows.push(get_login_types::v3::LoginType::Sso(
identity_providers: v.to_owned(), get_login_types::v3::SsoLoginType { identity_providers },
}; ));
flows.push(get_login_types::v3::LoginType::Sso(flow));
} }
Ok(get_login_types::v3::Response::new(flows)) Ok(get_login_types::v3::Response::new(flows))
@ -113,30 +113,26 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
login::v3::LoginInfo::Token(login::v3::Token { token }) => { login::v3::LoginInfo::Token(login::v3::Token { token }) => {
match ( match (
services().globals.jwt_decoding_key(), services().globals.jwt_decoding_key(),
&services().sso.providers().is_empty(), services().sso.login_type().next().is_some(),
) { ) {
(_, false) => { (_, false) => {
let mut validation = let mut validation = Validation::new(Algorithm::HS256);
jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256);
validation.validate_nbf = false; validation.validate_nbf = false;
validation.set_required_spec_claims(&["sub", "exp", "aud", "iss"]); validation.set_required_spec_claims(&["sub", "exp", "aud", "iss"]);
let login_token = services() services()
.globals .globals
.validate_claims::<LoginToken>(token, Some(validation)) .validate_claims::<LoginToken>(token, Some(validation))
.map_err(|_| { .as_ref()
Error::BadRequest(ErrorKind::InvalidParam, "Invalid token.") .map(LoginToken::audience)
})?; .map(ToOwned::to_owned)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid token."))?
login_token.audience().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid token audience.")
})?
} }
(Some(jwt_decoding_key), _) => { (Some(jwt_decoding_key), _) => {
let token = jsonwebtoken::decode::<Claims>( let token = jsonwebtoken::decode::<Claims>(
token, token,
jwt_decoding_key, jwt_decoding_key,
&jsonwebtoken::Validation::default(), &Validation::default(),
) )
.map_err(|_| { .map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.") Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.")

View file

@ -1,30 +1,24 @@
use std::{borrow::Borrow, collections::HashMap, iter::Iterator, time::SystemTime}; use std::{borrow::Borrow, collections::HashMap, iter::Iterator, time::SystemTime};
use crate::{ use crate::{
config::{ config::IdpConfig,
sso::{Registration, Template},
IdpConfig,
},
service::sso::{ service::sso::{
templates, LoginToken, RegistrationInfo, RegistrationToken, ValidationData, LoginToken, ValidationData, SSO_AUTH_EXPIRATION_SECS, SSO_SESSION_COOKIE, SUBJECT_CLAIM_KEY,
REGISTRATION_EXPIRATION_SECS, SESSION_EXPIRATION_SECS, SSO_AUTH_EXPIRATION_SECS,
SSO_SESSION_COOKIE,
}, },
services, utils, Error, Result, Ruma, services, utils, Error, Result, Ruma,
}; };
use axum::{ use axum::{
extract::RawQuery,
response::{AppendHeaders, IntoResponse, Redirect}, response::{AppendHeaders, IntoResponse, Redirect},
RequestExt, RequestExt,
}; };
use axum_extra::{ use axum_extra::{
headers::{self, HeaderMapExt}, headers::{self},
TypedHeader, TypedHeader,
}; };
use http::header; use http::header;
use mas_oidc_client::{ use mas_oidc_client::{
requests::{ requests::{
authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData}, authorization_code::{self, AuthorizationRequestData},
jose::{self, JwtVerificationData}, jose::{self, JwtVerificationData},
userinfo, userinfo,
}, },
@ -35,17 +29,17 @@ use mas_oidc_client::{
requests::{AccessTokenResponse, AuthorizationResponse}, requests::{AccessTokenResponse, AuthorizationResponse},
}, },
}; };
use rand::{rngs::StdRng, SeedableRng}; use rand::{rngs::StdRng, Rng, SeedableRng};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
session::{self, sso_login, sso_login_with_provider}, session::{sso_login, sso_login_with_provider},
}, },
events::GlobalAccountDataEventType, events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType},
push, OwnedMxcUri, UserId, push, UserId,
}; };
use serde_json::Number; use serde_json::Value;
use tracing::error; use tracing::{error, info, warn};
use url::Url; use url::Url;
pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback"; pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback";
@ -54,17 +48,28 @@ pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback";
/// ///
/// Redirect the user to the SSO interface. /// Redirect the user to the SSO interface.
/// TODO: this should be removed once Ruma supports trailing slashes. /// TODO: this should be removed once Ruma supports trailing slashes.
pub async fn get_sso_redirect( pub async fn get_sso_redirect_route(
body: Ruma<sso_login::v3::Request>, Ruma {
body,
sender_user,
sender_device,
sender_servername,
json_body,
..
}: Ruma<sso_login::v3::Request>,
) -> Result<sso_login::v3::Response> { ) -> Result<sso_login::v3::Response> {
let sso_login_with_provider::v3::Response { location, cookie } = let sso_login_with_provider::v3::Response { location, cookie } =
get_sso_redirect_with_provider( get_sso_redirect_with_provider_route(
Ruma { Ruma {
body: sso_login_with_provider::v3::Request::new( body: sso_login_with_provider::v3::Request::new(
Default::default(), Default::default(),
body.redirect_url.clone(), body.redirect_url,
), ),
..body sender_user,
sender_device,
sender_servername,
json_body,
appservice_info: None,
} }
.into(), .into(),
) )
@ -76,7 +81,7 @@ pub async fn get_sso_redirect(
/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}` /// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}`
/// ///
/// Redirects the user to the SSO interface. /// Redirects the user to the SSO interface.
pub async fn get_sso_redirect_with_provider( pub async fn get_sso_redirect_with_provider_route(
body: Ruma<sso_login_with_provider::v3::Request>, body: Ruma<sso_login_with_provider::v3::Request>,
) -> Result<sso_login_with_provider::v3::Response> { ) -> Result<sso_login_with_provider::v3::Response> {
let idp_ids: Vec<&str> = services() let idp_ids: Vec<&str> = services()
@ -124,7 +129,7 @@ pub async fn get_sso_redirect_with_provider(
.map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to build authorization_url."))?; .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to build authorization_url."))?;
let signed = services().globals.sign_claims(&ValidationData::new( let signed = services().globals.sign_claims(&ValidationData::new(
provider.borrow().to_string(), Borrow::<str>::borrow(provider).to_owned(),
validation_data, validation_data,
)); ));
@ -142,12 +147,7 @@ pub async fn get_sso_redirect_with_provider(
}) })
} }
/// # `GET /_conduit/client/sso/callback` async fn handle_callback_helper(req: axum::extract::Request) -> Result<axum::response::Response> {
///
/// Validate the authorization response received from the identity provider.
/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect.
/// If this is the first login, register the user, possibly interactively through a fallback page.
pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::response::Response> {
let query = req.uri().query().ok_or_else(|| { let query = req.uri().query().ok_or_else(|| {
Error::BadRequest(ErrorKind::MissingParam, "Empty authorization callback.") Error::BadRequest(ErrorKind::MissingParam, "Empty authorization callback.")
})?; })?;
@ -158,9 +158,11 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
token_type, token_type,
id_token, id_token,
expires_in, expires_in,
} = serde_html_form::from_str::<AuthorizationResponse>(query).map_err(|_| { } = serde_html_form::from_str(query).map_err(|_| {
serde_html_form::from_str::<ClientError>(query).unwrap_or_else(|_| { serde_html_form::from_str(query)
error!("Failed to deserialize authorization callback: {}", callback); .map(ClientError::into)
.unwrap_or_else(|_| {
error!("Failed to deserialize authorization callback: {}", query);
Error::BadRequest( Error::BadRequest(
ErrorKind::Unknown, ErrorKind::Unknown,
@ -169,11 +171,13 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
}) })
})?; })?;
let cookie = req let Ok(Some(cookie)): Result<Option<TypedHeader<headers::Cookie>>, _> = req.extract().await
.extract::<Option<TypedHeader<headers::Cookie>>>() else {
.await return Err(Error::BadRequest(
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid session cookie."))? ErrorKind::MissingParam,
.ok_or_else(|_| Error::BadRequest(ErrorKind::MissingParam, "Missing session cookie."))?; "Missing session cookie.",
));
};
let ValidationData { let ValidationData {
provider, provider,
@ -186,11 +190,11 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
})?, })?,
None, None,
) )
.map_err(|e| { .map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid value for session cookie.") Error::BadRequest(ErrorKind::InvalidParam, "Invalid value for session cookie.")
})?; })?;
let provider = services().sso.get(&provider).ok_or_else(|e| { let provider = services().sso.get(&provider).ok_or_else(|| {
Error::BadRequest( Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Unknown provider for session cookie.", "Unknown provider for session cookie.",
@ -204,7 +208,7 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
.. ..
} = provider.config.clone(); } = provider.config.clone();
let credentials = match &auth_method { let credentials = match &*auth_method {
"basic" => ClientCredentials::ClientSecretBasic { "basic" => ClientCredentials::ClientSecretBasic {
client_id, client_id,
client_secret, client_secret,
@ -215,6 +219,16 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
}, },
_ => todo!(), _ => todo!(),
}; };
let ref jwks = jose::fetch_jwks(services().sso.service(), provider.metadata.jwks_uri())
.await
.map_err(|_| Error::bad_config("Failed to fetch signing keys for token endpoint."))?;
let jwt_verification_data = Some(JwtVerificationData {
jwks,
issuer: &provider.config.issuer,
client_id: &provider.config.client_id,
signing_algorithm: &JsonWebSignatureAlg::Rs256,
});
let ( let (
AccessTokenResponse { AccessTokenResponse {
access_token, access_token,
@ -227,34 +241,22 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
Some(id_token), Some(id_token),
) = authorization_code::access_token_with_authorization_code( ) = authorization_code::access_token_with_authorization_code(
services().sso.service(), services().sso.service(),
method, credentials,
provider.metadata.token_endpoint(), provider.metadata.token_endpoint(),
code, code.unwrap_or_default(),
validation_data, validation_data.clone(),
jwt_verification_data, jwt_verification_data,
SystemTime::now().into(), SystemTime::now().into(),
&mut StdRng::from_entropy(), &mut StdRng::from_entropy(),
) )
.await .await
.map_err(|e| Error::bad_config("Failed to fetch access token."))? .map_err(|_| Error::bad_config("Failed to fetch access token."))?
else { else {
unreachable!("ID token should never be empty") unreachable!("ID token should never be empty")
}; };
// let userinfo = provider.fetch_userinfo(&access_token, &id_token).await?;
let mut userinfo = HashMap::default(); let mut userinfo = HashMap::default();
if let Some(endpoint) = &provider.metadata.userinfo_endpoint { if let Some(endpoint) = &provider.metadata.userinfo_endpoint {
let ref jwks = jose::fetch_jwks(services().sso.service(), provider.metadata.jwks_uri())
.await
.map_err(|e| Error::bad_config("Failed to fetch signing keys for token endpoint."))?;
let jwt_verification_data = Some(JwtVerificationData {
jwks,
issuer: &provider.config.issuer,
client_id: credentials.client_id(),
signing_algorithm: &JsonWebSignatureAlg::Rs256,
});
userinfo = userinfo::fetch_userinfo( userinfo = userinfo::fetch_userinfo(
services().sso.service(), services().sso.service(),
endpoint, endpoint,
@ -263,352 +265,66 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
&id_token, &id_token,
) )
.await .await
.map_err(|e| Error::bad_config("Failed to fetch claims for userinfo endpoint."))?; .map_err(|_| Error::bad_config("Failed to fetch claims for userinfo endpoint."))?;
}; };
let (_, mut claims) = id_token.into_parts(); let (_, id_token) = id_token.into_parts();
let subject = claims.get("sub").ok_or_else(|| { let subject = match id_token.get(SUBJECT_CLAIM_KEY) {
error!("Unique \"sub\" claim is missing from ID token: {claims:?}"); Some(Value::String(s)) => s.to_owned(),
Some(Value::Number(n)) => n.to_string(),
value => {
return Err(Error::BadRequest(
ErrorKind::Unknown,
value
.map(|_| {
error!("Subject claim is missing from ID token: {id_token:?}");
Error::bad_config("Unique \"sub\" claim is missing from ID token.") "Subject claim is missing from ID token."
})?; })
.unwrap_or("Subject claim should be a string or number."),
));
}
};
let subject = &subject let user_id = match services()
.as_str()
.map(str::to_owned)
.or_else(|| subject.as_number().map(Number::to_string))
.expect("unique claim should be a string or number");
let redirect_uri = &validation_data.redirect_uri;
if let Some(user_id) = services()
.sso .sso
.user_from_claim(&validation_data.provider_id, subject)? .user_from_subject(Borrow::<str>::borrow(provider), &subject)?
{ {
let login_token = LoginToken::new(validation_data.provider_id.to_owned(), user_id); Some(user_id) => user_id,
None => {
let mut localpart = subject.clone();
let redirect_uri = redirect_with_login_token(redirect_uri.to_owned(), &login_token); let user_id = loop {
match UserId::parse_with_server_name(&*localpart, services().globals.server_name())
{
Ok(user_id) if services().users.exists(&user_id)? => break user_id,
_ => {
let n: u8 = rand::thread_rng().gen();
return Ok(( localpart = format!("{}{}", localpart, n % 10);
AppendHeaders(vec![(
header::SET_COOKIE,
utils::reset_cookie("sso-session").to_string(),
)]),
Redirect::temporary(redirect_uri.as_str()),
)
.into_response());
}
match provider.config.registration {
Registration::Disabled => {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Single Sign-On registration is disabled.",
))
}
Registration::Automated => todo!(),
Registration::Interactive => {}
};
let Template {
username,
displayname,
avatar_url,
email,
} = &provider.config.template;
let registration_info =
RegistrationInfo::new(&claims, username, displayname, avatar_url, email);
let signed = services()
.globals
.sign_macaroon(&RegistrationToken::new(
validation_data.provider_id.clone(),
subject.to_owned(),
redirect_uri.to_owned(),
registration_info,
))
.expect("signing macaroons always works");
let cookie = utils::build_cookie(
"sso-registration",
&signed,
"/_conduit/client/sso/register",
REGISTRATION_EXPIRATION_SECS,
);
Ok((
AppendHeaders(vec![
(header::SET_COOKIE, cookie.to_string()),
(
header::SET_COOKIE,
utils::reset_cookie("sso-session").to_string(),
),
]),
Redirect::temporary("/_conduit/client/sso/register"),
)
.into_response())
}
/// # `GET /_conduit/client/sso/pick_idp`
pub async fn pick_idp(RawQuery(query): RawQuery) -> impl IntoResponse {
let providers: Vec<_> = services()
.globals
.config
.sso
.iter()
.map(|p| p.inner.to_owned())
.collect();
let body = maud::html! {
header {
h1 { "Log in to " (services().globals.server_name()) }
p { "Choose an identity provider to continue" }
}
main {
ul .providers {
@for provider in providers {
li {
a href={ "/_matrix/client/v3/login/sso/redirect/" (provider.id) "?" (query.as_deref().unwrap_or_default()) } {
@if let Some(url) = provider.icon.as_deref().and_then(utils::mxc_to_http) {
img src=(url);
}
}
span {
(provider.name)
}
}
}
}
}
};
(
[(header::CONTENT_TYPE, "text/html; charset=utf-8")],
maud::html! {
(templates::base("Pick Identity Provider", body))
(templates::footer())
},
)
}
/// # `GET /_conduit/client/sso/register`
///
/// Serve a registration form with defaults based on the retrieved claims.
/// This endpoint is only available when interactive registration is enabled.
pub async fn get_sso_registration(
cookie: TypedHeader<headers::Cookie>,
) -> Result<axum::response::Response> {
let token = cookie.get("sso-registration").ok_or_else(|| {
Error::BadRequest(
ErrorKind::MissingParam,
"Missing registration token cookie.",
)
})?;
let registration_token: RegistrationToken = services()
.globals
.validate_macaroon(token, None)
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid registration token cookie.",
)
})?;
let provider = services()
.sso
.get(&registration_token.provider_id)
.map(|p| p.config.inner.to_owned())?;
let server_name = services().globals.server_name();
let RegistrationInfo {
username,
displayname,
avatar_url,
email,
} = registration_token.info;
let additional_info = (&displayname, &avatar_url, &email) != (&None, &None, &None);
fn detail(title: &str, body: maud::Markup) -> maud::Markup {
maud::html! {
label .detail for=(title) {
div .check-row {
span .name { (title) } " "
span .use { "use" }
input #(title) type="checkbox" name={(title)"-checkbox"} value=(true) checked;
}
(body)
}
}
}
let body = maud::html! {
header {
h1 { "Complete your registration at " (server_name) }
p { "Confirm your details to finish creating your account." }
}
main {
form .form #form method="post" {
div .username-div #username-div {
label for="username-input" { "Username (required)" }
div .prefix { "@" }
input .username-input type="text" name="username"
value=(username) autofocus autocorrect="off" autocapitalize="none";
div .postfix { ":" (server_name) }
}
output .username-output for="username-input" { }
@if additional_info {
section .additional-info {
h2 {
@if let Some(icon) = provider.icon.as_deref().and_then(utils::mxc_to_http) {
img src=(icon.to_string());
}
"Optional data from " (provider.name)
}
@if let Some(avatar_url) = avatar_url.as_ref() {
(detail("avatar", maud::html!{
img .avatar src=(avatar_url);
}))
}
@if let Some(displayname) = displayname.as_ref() {
(detail("displayname", maud::html!{
p .value { (displayname) };
}))
}
@if let Some(email) = email.as_ref() {
(detail("email", maud::html!{
p .value { (email) };
}))
}
}
}
input type="submit" value="Submit" .primary-button {}
} }
} }
}; };
Ok(( services().users.set_placeholder_password(&user_id)?;
[(header::CONTENT_TYPE, "text/html; charset=utf-8")], let mut displayname = id_token
maud::html! { .get("preferred_username")
(templates::base("Register Account", body)) .or(id_token.get("nickname"))
.as_deref()
.map(Value::to_string)
.unwrap_or(user_id.localpart().to_owned());
(templates::footer()) // If enabled append lightning bolt to display name (default true)
}, if services().globals.enable_lightning_bolt() {
)
.into_response())
}
/// # `POST /_conduit/client/sso/register`
///
/// Submit the registration form.
pub async fn submit_sso_registration(
cookie: TypedHeader<headers::Cookie>,
axum::extract::Form(registration_info): axum::extract::Form<RegistrationInfo>,
) -> Result<axum::response::Response> {
let token = cookie.get("sso-registration").ok_or_else(|| {
Error::BadRequest(
ErrorKind::MissingParam,
"Missing registration token cookie.",
)
})?;
let registration_token: RegistrationToken = services()
.globals
.validate_macaroon(token, None)
.map_err(|_| {
Error::BadRequest(
ErrorKind::MissingParam,
"Invalid registration token cookie.",
)
})?;
let RegistrationInfo {
username,
mut displayname,
avatar_url,
email: _,
} = registration_info;
let user_id =
UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name())
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Invalid username."))?;
if services().users.exists(&user_id)? {
return Err(Error::BadRequest(
ErrorKind::UserInUse,
"Desired UserId is already taken.",
));
}
if services().appservice.is_exclusive_user_id(&user_id).await {
return Err(Error::BadRequest(
ErrorKind::Exclusive,
"Desired UserId reserved by appservice.",
));
}
services().users.create(&user_id, None)?;
services().users.set_password_placeholder(&user_id)?;
if let Some(avatar_url) = avatar_url {
let request = services().globals.default_client().get(avatar_url.as_ref());
let res = request.send().await.map_err(|_| {
Error::BadRequest(ErrorKind::UserInUse, "Desired UserId is already taken.")
})?;
let filename = avatar_url.path_segments().and_then(Iterator::last);
let (content_type, body): (Option<headers::ContentType>, Vec<u8>) = (
res.headers().typed_get(),
res.bytes().await.map(Into::into).map_err(|_| {
Error::BadRequest(ErrorKind::UserInUse, "Desired UserId is already taken.")
})?,
);
let mxc = format!(
"mxc://{}/{}",
services().globals.server_name(),
utils::random_string(crate::api::client_server::MXC_LENGTH)
);
services()
.media
.create(
mxc.clone(),
filename
.map(|filename| "inline; filename=".to_owned() + filename)
.as_deref(),
content_type.map(|header| header.to_string()).as_deref(),
&body,
)
.await?;
services()
.users
.set_avatar_url(&user_id, Some(OwnedMxcUri::from(mxc)))?;
};
if let (Some(displayname), true) = (
displayname.as_mut(),
services().globals.config.enable_lightning_bolt,
) {
displayname.push_str(" ⚡️"); displayname.push_str(" ⚡️");
} }
services().users.set_displayname(&user_id, displayname)?; services()
.users
services().sso.save_claim( .set_displayname(&user_id, Some(displayname.clone()))?;
&registration_token.provider_id,
&user_id,
&registration_token.unique_claim,
)?;
// Initial account data
services().account_data.update( services().account_data.update(
None, None,
&user_id, &user_id,
@ -618,31 +334,64 @@ pub async fn submit_sso_registration(
global: push::Ruleset::server_default(&user_id), global: push::Ruleset::server_default(&user_id),
}, },
}) })
.expect("PushRulesEvent should always serialize"), .expect("to json always works"),
)?; )?;
let login_token = LoginToken::new(registration_token.provider_id, user_id); info!("New user {} registered on this server.", user_id);
let redirect_uri = redirect_with_login_token(registration_token.redirect_uri, &login_token); services()
.admin
.send_message(RoomMessageEventContent::notice_plain(format!(
"New user {user_id} registered on this server."
)));
if let Some(admin_room) = services().admin.get_admin_room()? {
if services()
.rooms
.state_cache
.room_joined_count(&admin_room)?
== Some(1)
{
services()
.admin
.make_user_admin(&user_id, displayname)
.await?;
warn!("Granting {} admin privileges as the first user", user_id);
}
}
user_id
}
};
let signed = services().globals.sign_claims(&LoginToken::new(
Borrow::<str>::borrow(provider).to_owned(),
user_id,
));
let mut redirect_uri = validation_data.redirect_uri;
redirect_uri
.query_pairs_mut()
.append_pair("loginToken", &signed);
Ok(( Ok((
AppendHeaders([( AppendHeaders(vec![(
header::SET_COOKIE, header::SET_COOKIE,
utils::reset_cookie("sso-registration").to_string(), utils::build_cookie(SSO_SESSION_COOKIE, "", CALLBACK_PATH, None).to_string(),
)]), )]),
Redirect::temporary(redirect_uri.as_str()), Redirect::temporary(redirect_uri.as_str()),
) )
.into_response()) .into_response())
} }
fn redirect_with_login_token(mut redirect_uri: Url, login_token: &LoginToken) -> Url { /// # `GET /_conduit/client/sso/callback`
let signed = services() ///
.globals /// Validate the authorization response received from the identity provider.
.sign_macaroon(login_token) /// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect.
.expect("signing macaroons should always works"); /// If this is the first login, register the user, possibly interactively through a fallback page.
pub async fn handle_callback_route(req: axum::extract::Request) -> axum::response::Response {
redirect_uri match handle_callback_helper(req).await {
.query_pairs_mut() Ok(res) => res,
.append_pair("loginToken", &signed); Err(e) => e.into_response(),
}
redirect_uri
} }

View file

@ -0,0 +1,29 @@
use ruma::{OwnedUserId, UserId};
use crate::{service, utils, Error, KeyValueDatabase, Result};
impl service::sso::Data for KeyValueDatabase {
fn save_subject(&self, provider: &str, user_id: &UserId, subject: &str) -> Result<()> {
let mut key = provider.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(subject.as_bytes());
self.subject_userid.insert(&key, user_id.as_bytes())
}
fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
let mut key = provider.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(subject.as_bytes());
self.subject_userid.get(&key)?.map_or(Ok(None), |bytes| {
Some(
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("User ID in claim_userid is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("User ID in claim_userid is invalid.")),
)
.transpose()
})
}
}

View file

@ -49,6 +49,7 @@ pub struct KeyValueDatabase {
pub(super) userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists pub(super) userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists
pub(super) userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64 pub(super) userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64
pub(super) token_userdeviceid: Arc<dyn KvTree>, pub(super) token_userdeviceid: Arc<dyn KvTree>,
pub(super) subject_userid: Arc<dyn KvTree>,
pub(super) onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId pub(super) onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId
pub(super) userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count pub(super) userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count
@ -289,6 +290,8 @@ impl KeyValueDatabase {
userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?,
userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, userid_devicelistversion: builder.open_tree("userid_devicelistversion")?,
token_userdeviceid: builder.open_tree("token_userdeviceid")?, token_userdeviceid: builder.open_tree("token_userdeviceid")?,
subject_userid: builder.open_tree("subject_userid")?,
onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?,
userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?,
keychangeid_userid: builder.open_tree("keychangeid_userid")?, keychangeid_userid: builder.open_tree("keychangeid_userid")?,

View file

@ -10,7 +10,7 @@ use axum::{
}; };
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
use conduit::api::{ use conduit::api::{
client_server::{self, SSO_CALLBACK_PATH}, client_server::{self, CALLBACK_PATH},
server_server, server_server,
}; };
use figment::{ use figment::{
@ -283,7 +283,7 @@ fn routes(config: &Config) -> Router {
.ruma_route(client_server::get_sso_redirect_with_provider_route) .ruma_route(client_server::get_sso_redirect_with_provider_route)
// The specification will likely never introduce any endpoint for handling authorization callbacks. // The specification will likely never introduce any endpoint for handling authorization callbacks.
// As a workaround, we use custom path that redirects the user to the default login handler. // As a workaround, we use custom path that redirects the user to the default login handler.
.route(SSO_CALLBACK_PATH, get(client_server::sso_login_route)) .route(CALLBACK_PATH, get(client_server::handle_callback_route))
.ruma_route(client_server::get_capabilities_route) .ruma_route(client_server::get_capabilities_route)
.ruma_route(client_server::get_pushrules_all_route) .ruma_route(client_server::get_pushrules_all_route)
.ruma_route(client_server::set_pushrule_route) .ruma_route(client_server::set_pushrule_route)

View file

@ -123,7 +123,7 @@ impl Services {
key_backups: key_backups::Service { db }, key_backups: key_backups::Service { db },
media: media::Service { db }, media: media::Service { db },
sending: sending::Service::build(db, &config), sending: sending::Service::build(db, &config),
sso: sso::Service::build(db), sso: sso::Service::build(db)?,
globals: globals::Service::load(db, config)?, globals: globals::Service::load(db, config)?,
}) })

View file

@ -1,50 +1,36 @@
mod data;
use std::{ use std::{
borrow::Borrow, borrow::Borrow,
collections::{HashMap, HashSet}, collections::HashSet,
hash::{Hash, Hasher}, hash::{Hash, Hasher},
str::FromStr, sync::Arc,
sync::{Arc, RwLock},
time::{Duration, SystemTime, UNIX_EPOCH},
}; };
use crate::{ use crate::{
api::client_server::TOKEN_LENGTH, api::client_server::{LOGIN_TOKEN_EXPIRATION_SECS, TOKEN_LENGTH},
config::{sso::ProviderConfig as Config, IdpConfig}, config::IdpConfig,
utils, Error, Result, utils, Error, Result,
}; };
pub use data::Data;
use email_address::EmailAddress;
use futures_util::future::{self}; use futures_util::future::{self};
use mas_oidc_client::{ use mas_oidc_client::{
http_service::{hyper, HttpService}, http_service::{hyper, HttpService},
jose::jwk::PublicJsonWebKeySet, requests::{authorization_code::AuthorizationValidationData, discovery},
requests::{ types::oidc::VerifiedProviderMetadata,
authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData},
discovery,
jose::{self, JwtVerificationData},
userinfo,
},
types::{
iana::jose::JsonWebSignatureAlg, oidc::VerifiedProviderMetadata,
requests::AccessTokenResponse, IdToken,
},
}; };
use rand::SeedableRng; use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId};
use ruma::{api::client::error::ErrorKind, MilliSecondsSinceUnixEpoch, OwnedUserId, UserId};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use tokio::sync::OnceCell;
use tokio::sync::{oneshot, OnceCell};
use tracing::error; use tracing::error;
use url::Url; use url::Url;
use crate::services; use crate::services;
mod data;
pub use data::Data; pub use data::Data;
pub const SSO_AUTH_EXPIRATION_SECS: u64 = 60 * 60; pub const SSO_AUTH_EXPIRATION_SECS: u64 = 60 * 60;
pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 60 * 2; pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 60 * 2;
pub const SSO_SESSION_COOKIE: &str = "sso-auth"; pub const SSO_SESSION_COOKIE: &str = "sso-auth";
pub const SUBJECT_CLAIM_KEY: &str = "sub";
pub struct Service { pub struct Service {
db: &'static dyn Data, db: &'static dyn Data,
@ -69,7 +55,7 @@ impl Service {
let providers = services().globals.config.idps.iter(); let providers = services().globals.config.idps.iter();
self.providers self.providers
.get_or_try_init(|| { .get_or_try_init(|| async move {
future::try_join_all(providers.map(Provider::fetch_metadata)) future::try_join_all(providers.map(Provider::fetch_metadata))
.await .await
.map(Vec::into_iter) .map(Vec::into_iter)
@ -86,6 +72,12 @@ impl Service {
providers.get(provider) providers.get(provider)
} }
pub fn login_type(&self) -> impl Iterator<Item = IdentityProvider> + '_ {
let providers = self.providers.get().expect("");
providers.iter().map(|p| p.config.inner.clone())
}
pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> { pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
self.db.user_from_subject(provider, subject) self.db.user_from_subject(provider, subject)
} }
@ -111,30 +103,6 @@ impl Provider {
Error::bad_config("Failed to fetch identity provider metadata.") Error::bad_config("Failed to fetch identity provider metadata.")
}) })
} }
async fn fetch_signing_keys(&self) -> Result<PublicJsonWebKeySet> {
jose::fetch_jwks(&services().sso.service, self.metadata.jwks_uri())
.await
.map_err(|e| {
error!("Failed to fetch signing keys for token endpoint: {}", e);
Error::bad_config("Failed to fetch signing keys for token endpoint.")
})
}
pub async fn fetch_access_token(
&self,
auth_code: String,
validation_data: AuthorizationValidationData,
) -> Result<(AccessTokenResponse, Option<IdToken<'_>>)> {
}
pub async fn fetch_userinfo(
&self,
access_token: &str,
id_token: &IdToken<'_>,
) -> Result<Option<HashMap<String, Value>>> {
}
} }
impl Borrow<str> for Provider { impl Borrow<str> for Provider {
@ -157,105 +125,28 @@ impl Hash for Provider {
} }
} }
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RegistrationToken {
pub info: RegistrationInfo,
pub provider_id: String,
pub unique_claim: String,
pub redirect_uri: Url,
pub expires_at: MilliSecondsSinceUnixEpoch,
}
impl RegistrationToken {
pub fn new(
provider_id: String,
unique_claim: String,
redirect_uri: Url,
info: RegistrationInfo,
) -> Self {
let expires_at = MilliSecondsSinceUnixEpoch::from_system_time(
UNIX_EPOCH
.checked_add(Duration::from_secs(REGISTRATION_EXPIRATION_SECS))
.expect("SystemTime should not overflow"),
)
.expect("MilliSecondsSinceUnixEpoch is not too large");
Self {
info,
provider_id,
unique_claim,
redirect_uri,
expires_at,
}
}
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct RegistrationInfo {
pub username: String,
pub displayname: Option<String>,
pub avatar_url: Option<Url>,
pub email: Option<EmailAddress>,
}
impl RegistrationInfo {
pub fn new(
claims: &HashMap<String, Value>,
username: &str,
displayname: &str,
avatar_url: &str,
email: &str,
) -> Self {
Self {
username: claims
.get(username)
.and_then(|v| v.as_str())
.map(ToOwned::to_owned)
.unwrap_or_default(),
displayname: claims
.get(displayname)
.and_then(|v| v.as_str())
.map(ToOwned::to_owned),
avatar_url: claims
.get(avatar_url)
.and_then(|v| v.as_str())
.map(Url::parse)
.and_then(Result::ok),
email: claims
.get(email)
.and_then(|v| v.as_str())
.map(EmailAddress::from_str)
.and_then(Result::ok),
}
}
}
#[derive(Clone, Deserialize, Serialize)] #[derive(Clone, Deserialize, Serialize)]
pub struct LoginToken { pub struct LoginToken {
pub inner: String, pub iss: String,
pub provider_id: String, pub aud: OwnedUserId,
pub user_id: OwnedUserId, pub sub: String,
pub exp: u64,
#[serde(rename = "exp")]
expires_at: u64,
} }
impl LoginToken { impl LoginToken {
pub fn new(provider_id: String, user_id: OwnedUserId) -> Self { pub fn new(provider: String, user_id: OwnedUserId) -> Self {
let expires_at = SystemTime::now()
.checked_add(Duration::from_secs(LOGIN_TOKEN_EXPIRATION_SECS))
.expect("SystemTime should not overflow")
.duration_since(UNIX_EPOCH)
.expect("SystemTime went backwards")
.as_secs();
Self { Self {
inner: utils::random_string(TOKEN_LENGTH), iss: provider,
provider_id, aud: user_id,
user_id, sub: utils::random_string(TOKEN_LENGTH),
expires_at, exp: utils::millis_since_unix_epoch()
.checked_add(LOGIN_TOKEN_EXPIRATION_SECS * 1000)
.expect("time overflow"),
} }
} }
pub fn audience(&self) -> &UserId {
&self.aud
}
} }
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]