mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-06-27 16:35:59 +00:00
fix
This commit is contained in:
parent
10ce7ea3a9
commit
b80141b33b
8 changed files with 255 additions and 599 deletions
28
Cargo.toml
28
Cargo.toml
|
@ -35,18 +35,15 @@ axum = { version = "0.7", default-features = false, features = [
|
||||||
"json",
|
"json",
|
||||||
"matched-path",
|
"matched-path",
|
||||||
], optional = true }
|
], optional = true }
|
||||||
axum-extra = { version = "0.9", features = ["cookie", "typed-header"] }
|
axum-extra = { version = "0.9", features = ["typed-header", "cookie"] }
|
||||||
axum-server = { version = "0.6", features = ["tls-rustls"] }
|
axum-server = { version = "0.6", features = ["tls-rustls"] }
|
||||||
tower = { version = "0.4.13", features = ["util"] }
|
tower = { version = "0.4.13", features = ["util"] }
|
||||||
# tower-http = { version = "0.5", features = [
|
|
||||||
# "add-extension",
|
|
||||||
# "cors",
|
|
||||||
# "sensitive-headers",
|
|
||||||
# "trace",
|
|
||||||
# "util",
|
|
||||||
# ] }
|
|
||||||
tower-http = { version = "0.5", features = [
|
tower-http = { version = "0.5", features = [
|
||||||
"full",
|
"add-extension",
|
||||||
|
"cors",
|
||||||
|
"sensitive-headers",
|
||||||
|
"trace",
|
||||||
|
"util",
|
||||||
] }
|
] }
|
||||||
tower-service = "0.3"
|
tower-service = "0.3"
|
||||||
|
|
||||||
|
@ -143,6 +140,7 @@ figment = { version = "0.10.8", features = ["env", "toml"] }
|
||||||
# Validating urls in config
|
# Validating urls in config
|
||||||
url = { version = "2", features = ["serde"] }
|
url = { version = "2", features = ["serde"] }
|
||||||
|
|
||||||
|
mas-oidc-client = { version = "0.9", default-features = false, features = ["hyper"] }
|
||||||
# HTML
|
# HTML
|
||||||
maud = { version = "0.26.0", default-features = false, features = ["axum"] }
|
maud = { version = "0.26.0", default-features = false, features = ["axum"] }
|
||||||
|
|
||||||
|
@ -152,19 +150,14 @@ tikv-jemallocator = { version = "0.5.0", features = [
|
||||||
], optional = true }
|
], optional = true }
|
||||||
|
|
||||||
sd-notify = { version = "0.4.1", optional = true }
|
sd-notify = { version = "0.4.1", optional = true }
|
||||||
http-body-util = "0.1.2"
|
|
||||||
hyper-rustls = { version = "0.27.2", default-features = false, features = ["http1", "http2", "ring", "rustls-native-certs", "rustls-platform-verifier"] }
|
|
||||||
mas-http = "0.9.0"
|
|
||||||
|
|
||||||
# Used for matrix spec type definitions and helpers
|
# Used for matrix spec type definitions and helpers
|
||||||
[dependencies.ruma]
|
[dependencies.ruma]
|
||||||
features = [
|
features = [
|
||||||
"appservice-api-c",
|
"appservice-api-c",
|
||||||
"client",
|
|
||||||
"client-api",
|
"client-api",
|
||||||
"compat",
|
"compat",
|
||||||
"federation-api",
|
"federation-api",
|
||||||
"client-hyper",
|
|
||||||
"push-gateway-api-c",
|
"push-gateway-api-c",
|
||||||
"rand",
|
"rand",
|
||||||
"ring-compat",
|
"ring-compat",
|
||||||
|
@ -183,16 +176,11 @@ optional = true
|
||||||
package = "rust-rocksdb"
|
package = "rust-rocksdb"
|
||||||
version = "0.25"
|
version = "0.25"
|
||||||
|
|
||||||
# Used for Single Sign-On
|
|
||||||
[dependencies.mas-oidc-client]
|
|
||||||
git = "https://github.com/matrix-org/matrix-authentication-service.git"
|
|
||||||
default-features = false
|
|
||||||
|
|
||||||
[target.'cfg(unix)'.dependencies]
|
[target.'cfg(unix)'.dependencies]
|
||||||
nix = { version = "0.28", features = ["resource"] }
|
nix = { version = "0.28", features = ["resource"] }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["backend_rocksdb", "backend_sqlite", "conduit_bin", "systemd"]
|
default = ["backend_sqlite", "conduit_bin"]
|
||||||
#backend_sled = ["sled"]
|
#backend_sled = ["sled"]
|
||||||
backend_persy = ["parking_lot", "persy"]
|
backend_persy = ["parking_lot", "persy"]
|
||||||
backend_sqlite = ["sqlite"]
|
backend_sqlite = ["sqlite"]
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
|
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
|
||||||
use crate::{service::sso::LoginToken, services, utils, Error, Result, Ruma};
|
use crate::{service::sso::LoginToken, services, utils, Error, Result, Ruma};
|
||||||
|
use jsonwebtoken::{Algorithm, Validation};
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{
|
api::client::{
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
|
@ -24,17 +25,16 @@ struct Claims {
|
||||||
pub async fn get_login_types_route(
|
pub async fn get_login_types_route(
|
||||||
_body: Ruma<get_login_types::v3::Request>,
|
_body: Ruma<get_login_types::v3::Request>,
|
||||||
) -> Result<get_login_types::v3::Response> {
|
) -> Result<get_login_types::v3::Response> {
|
||||||
|
let identity_providers: Vec<_> = services().sso.login_type().collect();
|
||||||
let mut flows = vec![
|
let mut flows = vec![
|
||||||
get_login_types::v3::LoginType::Password(Default::default()),
|
get_login_types::v3::LoginType::Password(Default::default()),
|
||||||
get_login_types::v3::LoginType::ApplicationService(Default::default()),
|
get_login_types::v3::LoginType::ApplicationService(Default::default()),
|
||||||
];
|
];
|
||||||
|
|
||||||
if let v @ [_, ..] = &*services().sso.flows() {
|
if !identity_providers.is_empty() {
|
||||||
let flow = get_login_types::v3::SsoLoginType {
|
flows.push(get_login_types::v3::LoginType::Sso(
|
||||||
identity_providers: v.to_owned(),
|
get_login_types::v3::SsoLoginType { identity_providers },
|
||||||
};
|
));
|
||||||
|
|
||||||
flows.push(get_login_types::v3::LoginType::Sso(flow));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(get_login_types::v3::Response::new(flows))
|
Ok(get_login_types::v3::Response::new(flows))
|
||||||
|
@ -113,30 +113,26 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
||||||
login::v3::LoginInfo::Token(login::v3::Token { token }) => {
|
login::v3::LoginInfo::Token(login::v3::Token { token }) => {
|
||||||
match (
|
match (
|
||||||
services().globals.jwt_decoding_key(),
|
services().globals.jwt_decoding_key(),
|
||||||
&services().sso.providers().is_empty(),
|
services().sso.login_type().next().is_some(),
|
||||||
) {
|
) {
|
||||||
(_, false) => {
|
(_, false) => {
|
||||||
let mut validation =
|
let mut validation = Validation::new(Algorithm::HS256);
|
||||||
jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256);
|
|
||||||
validation.validate_nbf = false;
|
validation.validate_nbf = false;
|
||||||
validation.set_required_spec_claims(&["sub", "exp", "aud", "iss"]);
|
validation.set_required_spec_claims(&["sub", "exp", "aud", "iss"]);
|
||||||
|
|
||||||
let login_token = services()
|
services()
|
||||||
.globals
|
.globals
|
||||||
.validate_claims::<LoginToken>(token, Some(validation))
|
.validate_claims::<LoginToken>(token, Some(validation))
|
||||||
.map_err(|_| {
|
.as_ref()
|
||||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid token.")
|
.map(LoginToken::audience)
|
||||||
})?;
|
.map(ToOwned::to_owned)
|
||||||
|
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid token."))?
|
||||||
login_token.audience().map_err(|_| {
|
|
||||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid token audience.")
|
|
||||||
})?
|
|
||||||
}
|
}
|
||||||
(Some(jwt_decoding_key), _) => {
|
(Some(jwt_decoding_key), _) => {
|
||||||
let token = jsonwebtoken::decode::<Claims>(
|
let token = jsonwebtoken::decode::<Claims>(
|
||||||
token,
|
token,
|
||||||
jwt_decoding_key,
|
jwt_decoding_key,
|
||||||
&jsonwebtoken::Validation::default(),
|
&Validation::default(),
|
||||||
)
|
)
|
||||||
.map_err(|_| {
|
.map_err(|_| {
|
||||||
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.")
|
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.")
|
||||||
|
|
|
@ -1,30 +1,24 @@
|
||||||
use std::{borrow::Borrow, collections::HashMap, iter::Iterator, time::SystemTime};
|
use std::{borrow::Borrow, collections::HashMap, iter::Iterator, time::SystemTime};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
config::{
|
config::IdpConfig,
|
||||||
sso::{Registration, Template},
|
|
||||||
IdpConfig,
|
|
||||||
},
|
|
||||||
service::sso::{
|
service::sso::{
|
||||||
templates, LoginToken, RegistrationInfo, RegistrationToken, ValidationData,
|
LoginToken, ValidationData, SSO_AUTH_EXPIRATION_SECS, SSO_SESSION_COOKIE, SUBJECT_CLAIM_KEY,
|
||||||
REGISTRATION_EXPIRATION_SECS, SESSION_EXPIRATION_SECS, SSO_AUTH_EXPIRATION_SECS,
|
|
||||||
SSO_SESSION_COOKIE,
|
|
||||||
},
|
},
|
||||||
services, utils, Error, Result, Ruma,
|
services, utils, Error, Result, Ruma,
|
||||||
};
|
};
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::RawQuery,
|
|
||||||
response::{AppendHeaders, IntoResponse, Redirect},
|
response::{AppendHeaders, IntoResponse, Redirect},
|
||||||
RequestExt,
|
RequestExt,
|
||||||
};
|
};
|
||||||
use axum_extra::{
|
use axum_extra::{
|
||||||
headers::{self, HeaderMapExt},
|
headers::{self},
|
||||||
TypedHeader,
|
TypedHeader,
|
||||||
};
|
};
|
||||||
use http::header;
|
use http::header;
|
||||||
use mas_oidc_client::{
|
use mas_oidc_client::{
|
||||||
requests::{
|
requests::{
|
||||||
authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData},
|
authorization_code::{self, AuthorizationRequestData},
|
||||||
jose::{self, JwtVerificationData},
|
jose::{self, JwtVerificationData},
|
||||||
userinfo,
|
userinfo,
|
||||||
},
|
},
|
||||||
|
@ -35,17 +29,17 @@ use mas_oidc_client::{
|
||||||
requests::{AccessTokenResponse, AuthorizationResponse},
|
requests::{AccessTokenResponse, AuthorizationResponse},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use rand::{rngs::StdRng, SeedableRng};
|
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{
|
api::client::{
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
session::{self, sso_login, sso_login_with_provider},
|
session::{sso_login, sso_login_with_provider},
|
||||||
},
|
},
|
||||||
events::GlobalAccountDataEventType,
|
events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType},
|
||||||
push, OwnedMxcUri, UserId,
|
push, UserId,
|
||||||
};
|
};
|
||||||
use serde_json::Number;
|
use serde_json::Value;
|
||||||
use tracing::error;
|
use tracing::{error, info, warn};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback";
|
pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback";
|
||||||
|
@ -54,17 +48,28 @@ pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback";
|
||||||
///
|
///
|
||||||
/// Redirect the user to the SSO interface.
|
/// Redirect the user to the SSO interface.
|
||||||
/// TODO: this should be removed once Ruma supports trailing slashes.
|
/// TODO: this should be removed once Ruma supports trailing slashes.
|
||||||
pub async fn get_sso_redirect(
|
pub async fn get_sso_redirect_route(
|
||||||
body: Ruma<sso_login::v3::Request>,
|
Ruma {
|
||||||
|
body,
|
||||||
|
sender_user,
|
||||||
|
sender_device,
|
||||||
|
sender_servername,
|
||||||
|
json_body,
|
||||||
|
..
|
||||||
|
}: Ruma<sso_login::v3::Request>,
|
||||||
) -> Result<sso_login::v3::Response> {
|
) -> Result<sso_login::v3::Response> {
|
||||||
let sso_login_with_provider::v3::Response { location, cookie } =
|
let sso_login_with_provider::v3::Response { location, cookie } =
|
||||||
get_sso_redirect_with_provider(
|
get_sso_redirect_with_provider_route(
|
||||||
Ruma {
|
Ruma {
|
||||||
body: sso_login_with_provider::v3::Request::new(
|
body: sso_login_with_provider::v3::Request::new(
|
||||||
Default::default(),
|
Default::default(),
|
||||||
body.redirect_url.clone(),
|
body.redirect_url,
|
||||||
),
|
),
|
||||||
..body
|
sender_user,
|
||||||
|
sender_device,
|
||||||
|
sender_servername,
|
||||||
|
json_body,
|
||||||
|
appservice_info: None,
|
||||||
}
|
}
|
||||||
.into(),
|
.into(),
|
||||||
)
|
)
|
||||||
|
@ -76,7 +81,7 @@ pub async fn get_sso_redirect(
|
||||||
/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}`
|
/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}`
|
||||||
///
|
///
|
||||||
/// Redirects the user to the SSO interface.
|
/// Redirects the user to the SSO interface.
|
||||||
pub async fn get_sso_redirect_with_provider(
|
pub async fn get_sso_redirect_with_provider_route(
|
||||||
body: Ruma<sso_login_with_provider::v3::Request>,
|
body: Ruma<sso_login_with_provider::v3::Request>,
|
||||||
) -> Result<sso_login_with_provider::v3::Response> {
|
) -> Result<sso_login_with_provider::v3::Response> {
|
||||||
let idp_ids: Vec<&str> = services()
|
let idp_ids: Vec<&str> = services()
|
||||||
|
@ -124,7 +129,7 @@ pub async fn get_sso_redirect_with_provider(
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to build authorization_url."))?;
|
.map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to build authorization_url."))?;
|
||||||
|
|
||||||
let signed = services().globals.sign_claims(&ValidationData::new(
|
let signed = services().globals.sign_claims(&ValidationData::new(
|
||||||
provider.borrow().to_string(),
|
Borrow::<str>::borrow(provider).to_owned(),
|
||||||
validation_data,
|
validation_data,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -142,12 +147,7 @@ pub async fn get_sso_redirect_with_provider(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `GET /_conduit/client/sso/callback`
|
async fn handle_callback_helper(req: axum::extract::Request) -> Result<axum::response::Response> {
|
||||||
///
|
|
||||||
/// Validate the authorization response received from the identity provider.
|
|
||||||
/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect.
|
|
||||||
/// If this is the first login, register the user, possibly interactively through a fallback page.
|
|
||||||
pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::response::Response> {
|
|
||||||
let query = req.uri().query().ok_or_else(|| {
|
let query = req.uri().query().ok_or_else(|| {
|
||||||
Error::BadRequest(ErrorKind::MissingParam, "Empty authorization callback.")
|
Error::BadRequest(ErrorKind::MissingParam, "Empty authorization callback.")
|
||||||
})?;
|
})?;
|
||||||
|
@ -158,22 +158,26 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
|
||||||
token_type,
|
token_type,
|
||||||
id_token,
|
id_token,
|
||||||
expires_in,
|
expires_in,
|
||||||
} = serde_html_form::from_str::<AuthorizationResponse>(query).map_err(|_| {
|
} = serde_html_form::from_str(query).map_err(|_| {
|
||||||
serde_html_form::from_str::<ClientError>(query).unwrap_or_else(|_| {
|
serde_html_form::from_str(query)
|
||||||
error!("Failed to deserialize authorization callback: {}", callback);
|
.map(ClientError::into)
|
||||||
|
.unwrap_or_else(|_| {
|
||||||
|
error!("Failed to deserialize authorization callback: {}", query);
|
||||||
|
|
||||||
Error::BadRequest(
|
Error::BadRequest(
|
||||||
ErrorKind::Unknown,
|
ErrorKind::Unknown,
|
||||||
"Failed to deserialize authorization callback.",
|
"Failed to deserialize authorization callback.",
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let cookie = req
|
let Ok(Some(cookie)): Result<Option<TypedHeader<headers::Cookie>>, _> = req.extract().await
|
||||||
.extract::<Option<TypedHeader<headers::Cookie>>>()
|
else {
|
||||||
.await
|
return Err(Error::BadRequest(
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid session cookie."))?
|
ErrorKind::MissingParam,
|
||||||
.ok_or_else(|_| Error::BadRequest(ErrorKind::MissingParam, "Missing session cookie."))?;
|
"Missing session cookie.",
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
let ValidationData {
|
let ValidationData {
|
||||||
provider,
|
provider,
|
||||||
|
@ -186,11 +190,11 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
|
||||||
})?,
|
})?,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
.map_err(|e| {
|
.map_err(|_| {
|
||||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid value for session cookie.")
|
Error::BadRequest(ErrorKind::InvalidParam, "Invalid value for session cookie.")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let provider = services().sso.get(&provider).ok_or_else(|e| {
|
let provider = services().sso.get(&provider).ok_or_else(|| {
|
||||||
Error::BadRequest(
|
Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"Unknown provider for session cookie.",
|
"Unknown provider for session cookie.",
|
||||||
|
@ -204,7 +208,7 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
|
||||||
..
|
..
|
||||||
} = provider.config.clone();
|
} = provider.config.clone();
|
||||||
|
|
||||||
let credentials = match &auth_method {
|
let credentials = match &*auth_method {
|
||||||
"basic" => ClientCredentials::ClientSecretBasic {
|
"basic" => ClientCredentials::ClientSecretBasic {
|
||||||
client_id,
|
client_id,
|
||||||
client_secret,
|
client_secret,
|
||||||
|
@ -215,6 +219,16 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
|
||||||
},
|
},
|
||||||
_ => todo!(),
|
_ => todo!(),
|
||||||
};
|
};
|
||||||
|
let ref jwks = jose::fetch_jwks(services().sso.service(), provider.metadata.jwks_uri())
|
||||||
|
.await
|
||||||
|
.map_err(|_| Error::bad_config("Failed to fetch signing keys for token endpoint."))?;
|
||||||
|
let jwt_verification_data = Some(JwtVerificationData {
|
||||||
|
jwks,
|
||||||
|
issuer: &provider.config.issuer,
|
||||||
|
client_id: &provider.config.client_id,
|
||||||
|
signing_algorithm: &JsonWebSignatureAlg::Rs256,
|
||||||
|
});
|
||||||
|
|
||||||
let (
|
let (
|
||||||
AccessTokenResponse {
|
AccessTokenResponse {
|
||||||
access_token,
|
access_token,
|
||||||
|
@ -227,34 +241,22 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
|
||||||
Some(id_token),
|
Some(id_token),
|
||||||
) = authorization_code::access_token_with_authorization_code(
|
) = authorization_code::access_token_with_authorization_code(
|
||||||
services().sso.service(),
|
services().sso.service(),
|
||||||
method,
|
credentials,
|
||||||
provider.metadata.token_endpoint(),
|
provider.metadata.token_endpoint(),
|
||||||
code,
|
code.unwrap_or_default(),
|
||||||
validation_data,
|
validation_data.clone(),
|
||||||
jwt_verification_data,
|
jwt_verification_data,
|
||||||
SystemTime::now().into(),
|
SystemTime::now().into(),
|
||||||
&mut StdRng::from_entropy(),
|
&mut StdRng::from_entropy(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| Error::bad_config("Failed to fetch access token."))?
|
.map_err(|_| Error::bad_config("Failed to fetch access token."))?
|
||||||
else {
|
else {
|
||||||
unreachable!("ID token should never be empty")
|
unreachable!("ID token should never be empty")
|
||||||
};
|
};
|
||||||
|
|
||||||
// let userinfo = provider.fetch_userinfo(&access_token, &id_token).await?;
|
|
||||||
|
|
||||||
let mut userinfo = HashMap::default();
|
let mut userinfo = HashMap::default();
|
||||||
if let Some(endpoint) = &provider.metadata.userinfo_endpoint {
|
if let Some(endpoint) = &provider.metadata.userinfo_endpoint {
|
||||||
let ref jwks = jose::fetch_jwks(services().sso.service(), provider.metadata.jwks_uri())
|
|
||||||
.await
|
|
||||||
.map_err(|e| Error::bad_config("Failed to fetch signing keys for token endpoint."))?;
|
|
||||||
let jwt_verification_data = Some(JwtVerificationData {
|
|
||||||
jwks,
|
|
||||||
issuer: &provider.config.issuer,
|
|
||||||
client_id: credentials.client_id(),
|
|
||||||
signing_algorithm: &JsonWebSignatureAlg::Rs256,
|
|
||||||
});
|
|
||||||
|
|
||||||
userinfo = userinfo::fetch_userinfo(
|
userinfo = userinfo::fetch_userinfo(
|
||||||
services().sso.service(),
|
services().sso.service(),
|
||||||
endpoint,
|
endpoint,
|
||||||
|
@ -263,386 +265,133 @@ pub async fn get_sso_callback(req: axum::extract::Request) -> Result<axum::respo
|
||||||
&id_token,
|
&id_token,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| Error::bad_config("Failed to fetch claims for userinfo endpoint."))?;
|
.map_err(|_| Error::bad_config("Failed to fetch claims for userinfo endpoint."))?;
|
||||||
};
|
};
|
||||||
|
|
||||||
let (_, mut claims) = id_token.into_parts();
|
let (_, id_token) = id_token.into_parts();
|
||||||
|
|
||||||
let subject = claims.get("sub").ok_or_else(|| {
|
let subject = match id_token.get(SUBJECT_CLAIM_KEY) {
|
||||||
error!("Unique \"sub\" claim is missing from ID token: {claims:?}");
|
Some(Value::String(s)) => s.to_owned(),
|
||||||
|
Some(Value::Number(n)) => n.to_string(),
|
||||||
Error::bad_config("Unique \"sub\" claim is missing from ID token.")
|
value => {
|
||||||
})?;
|
|
||||||
|
|
||||||
let subject = &subject
|
|
||||||
.as_str()
|
|
||||||
.map(str::to_owned)
|
|
||||||
.or_else(|| subject.as_number().map(Number::to_string))
|
|
||||||
.expect("unique claim should be a string or number");
|
|
||||||
|
|
||||||
let redirect_uri = &validation_data.redirect_uri;
|
|
||||||
|
|
||||||
if let Some(user_id) = services()
|
|
||||||
.sso
|
|
||||||
.user_from_claim(&validation_data.provider_id, subject)?
|
|
||||||
{
|
|
||||||
let login_token = LoginToken::new(validation_data.provider_id.to_owned(), user_id);
|
|
||||||
|
|
||||||
let redirect_uri = redirect_with_login_token(redirect_uri.to_owned(), &login_token);
|
|
||||||
|
|
||||||
return Ok((
|
|
||||||
AppendHeaders(vec![(
|
|
||||||
header::SET_COOKIE,
|
|
||||||
utils::reset_cookie("sso-session").to_string(),
|
|
||||||
)]),
|
|
||||||
Redirect::temporary(redirect_uri.as_str()),
|
|
||||||
)
|
|
||||||
.into_response());
|
|
||||||
}
|
|
||||||
|
|
||||||
match provider.config.registration {
|
|
||||||
Registration::Disabled => {
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::forbidden(),
|
ErrorKind::Unknown,
|
||||||
"Single Sign-On registration is disabled.",
|
value
|
||||||
))
|
.map(|_| {
|
||||||
|
error!("Subject claim is missing from ID token: {id_token:?}");
|
||||||
|
|
||||||
|
"Subject claim is missing from ID token."
|
||||||
|
})
|
||||||
|
.unwrap_or("Subject claim should be a string or number."),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
Registration::Automated => todo!(),
|
|
||||||
Registration::Interactive => {}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let Template {
|
let user_id = match services()
|
||||||
username,
|
|
||||||
displayname,
|
|
||||||
avatar_url,
|
|
||||||
email,
|
|
||||||
} = &provider.config.template;
|
|
||||||
let registration_info =
|
|
||||||
RegistrationInfo::new(&claims, username, displayname, avatar_url, email);
|
|
||||||
|
|
||||||
let signed = services()
|
|
||||||
.globals
|
|
||||||
.sign_macaroon(&RegistrationToken::new(
|
|
||||||
validation_data.provider_id.clone(),
|
|
||||||
subject.to_owned(),
|
|
||||||
redirect_uri.to_owned(),
|
|
||||||
registration_info,
|
|
||||||
))
|
|
||||||
.expect("signing macaroons always works");
|
|
||||||
|
|
||||||
let cookie = utils::build_cookie(
|
|
||||||
"sso-registration",
|
|
||||||
&signed,
|
|
||||||
"/_conduit/client/sso/register",
|
|
||||||
REGISTRATION_EXPIRATION_SECS,
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok((
|
|
||||||
AppendHeaders(vec![
|
|
||||||
(header::SET_COOKIE, cookie.to_string()),
|
|
||||||
(
|
|
||||||
header::SET_COOKIE,
|
|
||||||
utils::reset_cookie("sso-session").to_string(),
|
|
||||||
),
|
|
||||||
]),
|
|
||||||
Redirect::temporary("/_conduit/client/sso/register"),
|
|
||||||
)
|
|
||||||
.into_response())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # `GET /_conduit/client/sso/pick_idp`
|
|
||||||
pub async fn pick_idp(RawQuery(query): RawQuery) -> impl IntoResponse {
|
|
||||||
let providers: Vec<_> = services()
|
|
||||||
.globals
|
|
||||||
.config
|
|
||||||
.sso
|
.sso
|
||||||
.iter()
|
.user_from_subject(Borrow::<str>::borrow(provider), &subject)?
|
||||||
.map(|p| p.inner.to_owned())
|
{
|
||||||
.collect();
|
Some(user_id) => user_id,
|
||||||
|
None => {
|
||||||
|
let mut localpart = subject.clone();
|
||||||
|
|
||||||
let body = maud::html! {
|
let user_id = loop {
|
||||||
header {
|
match UserId::parse_with_server_name(&*localpart, services().globals.server_name())
|
||||||
h1 { "Log in to " (services().globals.server_name()) }
|
{
|
||||||
p { "Choose an identity provider to continue" }
|
Ok(user_id) if services().users.exists(&user_id)? => break user_id,
|
||||||
}
|
_ => {
|
||||||
main {
|
let n: u8 = rand::thread_rng().gen();
|
||||||
ul .providers {
|
|
||||||
@for provider in providers {
|
localpart = format!("{}{}", localpart, n % 10);
|
||||||
li {
|
|
||||||
a href={ "/_matrix/client/v3/login/sso/redirect/" (provider.id) "?" (query.as_deref().unwrap_or_default()) } {
|
|
||||||
@if let Some(url) = provider.icon.as_deref().and_then(utils::mxc_to_http) {
|
|
||||||
img src=(url);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
span {
|
|
||||||
(provider.name)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
services().users.set_placeholder_password(&user_id)?;
|
||||||
|
let mut displayname = id_token
|
||||||
|
.get("preferred_username")
|
||||||
|
.or(id_token.get("nickname"))
|
||||||
|
.as_deref()
|
||||||
|
.map(Value::to_string)
|
||||||
|
.unwrap_or(user_id.localpart().to_owned());
|
||||||
|
|
||||||
|
// If enabled append lightning bolt to display name (default true)
|
||||||
|
if services().globals.enable_lightning_bolt() {
|
||||||
|
displayname.push_str(" ⚡️");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
services()
|
||||||
|
.users
|
||||||
|
.set_displayname(&user_id, Some(displayname.clone()))?;
|
||||||
|
|
||||||
|
// Initial account data
|
||||||
|
services().account_data.update(
|
||||||
|
None,
|
||||||
|
&user_id,
|
||||||
|
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||||
|
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
|
||||||
|
content: ruma::events::push_rules::PushRulesEventContent {
|
||||||
|
global: push::Ruleset::server_default(&user_id),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.expect("to json always works"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
info!("New user {} registered on this server.", user_id);
|
||||||
|
services()
|
||||||
|
.admin
|
||||||
|
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||||
|
"New user {user_id} registered on this server."
|
||||||
|
)));
|
||||||
|
|
||||||
|
if let Some(admin_room) = services().admin.get_admin_room()? {
|
||||||
|
if services()
|
||||||
|
.rooms
|
||||||
|
.state_cache
|
||||||
|
.room_joined_count(&admin_room)?
|
||||||
|
== Some(1)
|
||||||
|
{
|
||||||
|
services()
|
||||||
|
.admin
|
||||||
|
.make_user_admin(&user_id, displayname)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
warn!("Granting {} admin privileges as the first user", user_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
user_id
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
(
|
let signed = services().globals.sign_claims(&LoginToken::new(
|
||||||
[(header::CONTENT_TYPE, "text/html; charset=utf-8")],
|
Borrow::<str>::borrow(provider).to_owned(),
|
||||||
maud::html! {
|
user_id,
|
||||||
(templates::base("Pick Identity Provider", body))
|
));
|
||||||
|
|
||||||
(templates::footer())
|
let mut redirect_uri = validation_data.redirect_uri;
|
||||||
},
|
redirect_uri
|
||||||
)
|
.query_pairs_mut()
|
||||||
}
|
.append_pair("loginToken", &signed);
|
||||||
|
|
||||||
/// # `GET /_conduit/client/sso/register`
|
|
||||||
///
|
|
||||||
/// Serve a registration form with defaults based on the retrieved claims.
|
|
||||||
/// This endpoint is only available when interactive registration is enabled.
|
|
||||||
pub async fn get_sso_registration(
|
|
||||||
cookie: TypedHeader<headers::Cookie>,
|
|
||||||
) -> Result<axum::response::Response> {
|
|
||||||
let token = cookie.get("sso-registration").ok_or_else(|| {
|
|
||||||
Error::BadRequest(
|
|
||||||
ErrorKind::MissingParam,
|
|
||||||
"Missing registration token cookie.",
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let registration_token: RegistrationToken = services()
|
|
||||||
.globals
|
|
||||||
.validate_macaroon(token, None)
|
|
||||||
.map_err(|_| {
|
|
||||||
Error::BadRequest(
|
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Invalid registration token cookie.",
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let provider = services()
|
|
||||||
.sso
|
|
||||||
.get(®istration_token.provider_id)
|
|
||||||
.map(|p| p.config.inner.to_owned())?;
|
|
||||||
let server_name = services().globals.server_name();
|
|
||||||
|
|
||||||
let RegistrationInfo {
|
|
||||||
username,
|
|
||||||
displayname,
|
|
||||||
avatar_url,
|
|
||||||
email,
|
|
||||||
} = registration_token.info;
|
|
||||||
|
|
||||||
let additional_info = (&displayname, &avatar_url, &email) != (&None, &None, &None);
|
|
||||||
|
|
||||||
fn detail(title: &str, body: maud::Markup) -> maud::Markup {
|
|
||||||
maud::html! {
|
|
||||||
label .detail for=(title) {
|
|
||||||
div .check-row {
|
|
||||||
span .name { (title) } " "
|
|
||||||
span .use { "use" }
|
|
||||||
input #(title) type="checkbox" name={(title)"-checkbox"} value=(true) checked;
|
|
||||||
}
|
|
||||||
(body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let body = maud::html! {
|
|
||||||
header {
|
|
||||||
h1 { "Complete your registration at " (server_name) }
|
|
||||||
p { "Confirm your details to finish creating your account." }
|
|
||||||
}
|
|
||||||
main {
|
|
||||||
form .form #form method="post" {
|
|
||||||
div .username-div #username-div {
|
|
||||||
label for="username-input" { "Username (required)" }
|
|
||||||
div .prefix { "@" }
|
|
||||||
input .username-input type="text" name="username"
|
|
||||||
value=(username) autofocus autocorrect="off" autocapitalize="none";
|
|
||||||
div .postfix { ":" (server_name) }
|
|
||||||
}
|
|
||||||
output .username-output for="username-input" { }
|
|
||||||
|
|
||||||
@if additional_info {
|
|
||||||
section .additional-info {
|
|
||||||
h2 {
|
|
||||||
@if let Some(icon) = provider.icon.as_deref().and_then(utils::mxc_to_http) {
|
|
||||||
img src=(icon.to_string());
|
|
||||||
}
|
|
||||||
"Optional data from " (provider.name)
|
|
||||||
}
|
|
||||||
@if let Some(avatar_url) = avatar_url.as_ref() {
|
|
||||||
(detail("avatar", maud::html!{
|
|
||||||
img .avatar src=(avatar_url);
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
@if let Some(displayname) = displayname.as_ref() {
|
|
||||||
(detail("displayname", maud::html!{
|
|
||||||
p .value { (displayname) };
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
@if let Some(email) = email.as_ref() {
|
|
||||||
(detail("email", maud::html!{
|
|
||||||
p .value { (email) };
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
input type="submit" value="Submit" .primary-button {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok((
|
Ok((
|
||||||
[(header::CONTENT_TYPE, "text/html; charset=utf-8")],
|
AppendHeaders(vec![(
|
||||||
maud::html! {
|
|
||||||
(templates::base("Register Account", body))
|
|
||||||
|
|
||||||
(templates::footer())
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.into_response())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # `POST /_conduit/client/sso/register`
|
|
||||||
///
|
|
||||||
/// Submit the registration form.
|
|
||||||
pub async fn submit_sso_registration(
|
|
||||||
cookie: TypedHeader<headers::Cookie>,
|
|
||||||
axum::extract::Form(registration_info): axum::extract::Form<RegistrationInfo>,
|
|
||||||
) -> Result<axum::response::Response> {
|
|
||||||
let token = cookie.get("sso-registration").ok_or_else(|| {
|
|
||||||
Error::BadRequest(
|
|
||||||
ErrorKind::MissingParam,
|
|
||||||
"Missing registration token cookie.",
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let registration_token: RegistrationToken = services()
|
|
||||||
.globals
|
|
||||||
.validate_macaroon(token, None)
|
|
||||||
.map_err(|_| {
|
|
||||||
Error::BadRequest(
|
|
||||||
ErrorKind::MissingParam,
|
|
||||||
"Invalid registration token cookie.",
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let RegistrationInfo {
|
|
||||||
username,
|
|
||||||
mut displayname,
|
|
||||||
avatar_url,
|
|
||||||
email: _,
|
|
||||||
} = registration_info;
|
|
||||||
|
|
||||||
let user_id =
|
|
||||||
UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name())
|
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Invalid username."))?;
|
|
||||||
|
|
||||||
if services().users.exists(&user_id)? {
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::UserInUse,
|
|
||||||
"Desired UserId is already taken.",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
if services().appservice.is_exclusive_user_id(&user_id).await {
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::Exclusive,
|
|
||||||
"Desired UserId reserved by appservice.",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
services().users.create(&user_id, None)?;
|
|
||||||
services().users.set_password_placeholder(&user_id)?;
|
|
||||||
|
|
||||||
if let Some(avatar_url) = avatar_url {
|
|
||||||
let request = services().globals.default_client().get(avatar_url.as_ref());
|
|
||||||
|
|
||||||
let res = request.send().await.map_err(|_| {
|
|
||||||
Error::BadRequest(ErrorKind::UserInUse, "Desired UserId is already taken.")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let filename = avatar_url.path_segments().and_then(Iterator::last);
|
|
||||||
|
|
||||||
let (content_type, body): (Option<headers::ContentType>, Vec<u8>) = (
|
|
||||||
res.headers().typed_get(),
|
|
||||||
res.bytes().await.map(Into::into).map_err(|_| {
|
|
||||||
Error::BadRequest(ErrorKind::UserInUse, "Desired UserId is already taken.")
|
|
||||||
})?,
|
|
||||||
);
|
|
||||||
|
|
||||||
let mxc = format!(
|
|
||||||
"mxc://{}/{}",
|
|
||||||
services().globals.server_name(),
|
|
||||||
utils::random_string(crate::api::client_server::MXC_LENGTH)
|
|
||||||
);
|
|
||||||
|
|
||||||
services()
|
|
||||||
.media
|
|
||||||
.create(
|
|
||||||
mxc.clone(),
|
|
||||||
filename
|
|
||||||
.map(|filename| "inline; filename=".to_owned() + filename)
|
|
||||||
.as_deref(),
|
|
||||||
content_type.map(|header| header.to_string()).as_deref(),
|
|
||||||
&body,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
services()
|
|
||||||
.users
|
|
||||||
.set_avatar_url(&user_id, Some(OwnedMxcUri::from(mxc)))?;
|
|
||||||
};
|
|
||||||
|
|
||||||
if let (Some(displayname), true) = (
|
|
||||||
displayname.as_mut(),
|
|
||||||
services().globals.config.enable_lightning_bolt,
|
|
||||||
) {
|
|
||||||
displayname.push_str(" ⚡️");
|
|
||||||
}
|
|
||||||
|
|
||||||
services().users.set_displayname(&user_id, displayname)?;
|
|
||||||
|
|
||||||
services().sso.save_claim(
|
|
||||||
®istration_token.provider_id,
|
|
||||||
&user_id,
|
|
||||||
®istration_token.unique_claim,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
services().account_data.update(
|
|
||||||
None,
|
|
||||||
&user_id,
|
|
||||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
|
||||||
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
|
|
||||||
content: ruma::events::push_rules::PushRulesEventContent {
|
|
||||||
global: push::Ruleset::server_default(&user_id),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.expect("PushRulesEvent should always serialize"),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let login_token = LoginToken::new(registration_token.provider_id, user_id);
|
|
||||||
let redirect_uri = redirect_with_login_token(registration_token.redirect_uri, &login_token);
|
|
||||||
|
|
||||||
Ok((
|
|
||||||
AppendHeaders([(
|
|
||||||
header::SET_COOKIE,
|
header::SET_COOKIE,
|
||||||
utils::reset_cookie("sso-registration").to_string(),
|
utils::build_cookie(SSO_SESSION_COOKIE, "", CALLBACK_PATH, None).to_string(),
|
||||||
)]),
|
)]),
|
||||||
Redirect::temporary(redirect_uri.as_str()),
|
Redirect::temporary(redirect_uri.as_str()),
|
||||||
)
|
)
|
||||||
.into_response())
|
.into_response())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn redirect_with_login_token(mut redirect_uri: Url, login_token: &LoginToken) -> Url {
|
/// # `GET /_conduit/client/sso/callback`
|
||||||
let signed = services()
|
///
|
||||||
.globals
|
/// Validate the authorization response received from the identity provider.
|
||||||
.sign_macaroon(login_token)
|
/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect.
|
||||||
.expect("signing macaroons should always works");
|
/// If this is the first login, register the user, possibly interactively through a fallback page.
|
||||||
|
pub async fn handle_callback_route(req: axum::extract::Request) -> axum::response::Response {
|
||||||
redirect_uri
|
match handle_callback_helper(req).await {
|
||||||
.query_pairs_mut()
|
Ok(res) => res,
|
||||||
.append_pair("loginToken", &signed);
|
Err(e) => e.into_response(),
|
||||||
|
}
|
||||||
redirect_uri
|
|
||||||
}
|
}
|
||||||
|
|
29
src/database/key_value/sso.rs
Normal file
29
src/database/key_value/sso.rs
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
use ruma::{OwnedUserId, UserId};
|
||||||
|
|
||||||
|
use crate::{service, utils, Error, KeyValueDatabase, Result};
|
||||||
|
|
||||||
|
impl service::sso::Data for KeyValueDatabase {
|
||||||
|
fn save_subject(&self, provider: &str, user_id: &UserId, subject: &str) -> Result<()> {
|
||||||
|
let mut key = provider.as_bytes().to_vec();
|
||||||
|
key.push(0xff);
|
||||||
|
key.extend_from_slice(subject.as_bytes());
|
||||||
|
|
||||||
|
self.subject_userid.insert(&key, user_id.as_bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
|
||||||
|
let mut key = provider.as_bytes().to_vec();
|
||||||
|
key.push(0xff);
|
||||||
|
key.extend_from_slice(subject.as_bytes());
|
||||||
|
|
||||||
|
self.subject_userid.get(&key)?.map_or(Ok(None), |bytes| {
|
||||||
|
Some(
|
||||||
|
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||||
|
Error::bad_database("User ID in claim_userid is invalid unicode.")
|
||||||
|
})?)
|
||||||
|
.map_err(|_| Error::bad_database("User ID in claim_userid is invalid.")),
|
||||||
|
)
|
||||||
|
.transpose()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -49,6 +49,7 @@ pub struct KeyValueDatabase {
|
||||||
pub(super) userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists
|
pub(super) userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists
|
||||||
pub(super) userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64
|
pub(super) userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64
|
||||||
pub(super) token_userdeviceid: Arc<dyn KvTree>,
|
pub(super) token_userdeviceid: Arc<dyn KvTree>,
|
||||||
|
pub(super) subject_userid: Arc<dyn KvTree>,
|
||||||
|
|
||||||
pub(super) onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId
|
pub(super) onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId
|
||||||
pub(super) userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count
|
pub(super) userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count
|
||||||
|
@ -289,6 +290,8 @@ impl KeyValueDatabase {
|
||||||
userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?,
|
userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?,
|
||||||
userid_devicelistversion: builder.open_tree("userid_devicelistversion")?,
|
userid_devicelistversion: builder.open_tree("userid_devicelistversion")?,
|
||||||
token_userdeviceid: builder.open_tree("token_userdeviceid")?,
|
token_userdeviceid: builder.open_tree("token_userdeviceid")?,
|
||||||
|
subject_userid: builder.open_tree("subject_userid")?,
|
||||||
|
|
||||||
onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?,
|
onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?,
|
||||||
userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?,
|
userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?,
|
||||||
keychangeid_userid: builder.open_tree("keychangeid_userid")?,
|
keychangeid_userid: builder.open_tree("keychangeid_userid")?,
|
||||||
|
|
|
@ -10,7 +10,7 @@ use axum::{
|
||||||
};
|
};
|
||||||
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
|
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
|
||||||
use conduit::api::{
|
use conduit::api::{
|
||||||
client_server::{self, SSO_CALLBACK_PATH},
|
client_server::{self, CALLBACK_PATH},
|
||||||
server_server,
|
server_server,
|
||||||
};
|
};
|
||||||
use figment::{
|
use figment::{
|
||||||
|
@ -283,7 +283,7 @@ fn routes(config: &Config) -> Router {
|
||||||
.ruma_route(client_server::get_sso_redirect_with_provider_route)
|
.ruma_route(client_server::get_sso_redirect_with_provider_route)
|
||||||
// The specification will likely never introduce any endpoint for handling authorization callbacks.
|
// The specification will likely never introduce any endpoint for handling authorization callbacks.
|
||||||
// As a workaround, we use custom path that redirects the user to the default login handler.
|
// As a workaround, we use custom path that redirects the user to the default login handler.
|
||||||
.route(SSO_CALLBACK_PATH, get(client_server::sso_login_route))
|
.route(CALLBACK_PATH, get(client_server::handle_callback_route))
|
||||||
.ruma_route(client_server::get_capabilities_route)
|
.ruma_route(client_server::get_capabilities_route)
|
||||||
.ruma_route(client_server::get_pushrules_all_route)
|
.ruma_route(client_server::get_pushrules_all_route)
|
||||||
.ruma_route(client_server::set_pushrule_route)
|
.ruma_route(client_server::set_pushrule_route)
|
||||||
|
|
|
@ -123,7 +123,7 @@ impl Services {
|
||||||
key_backups: key_backups::Service { db },
|
key_backups: key_backups::Service { db },
|
||||||
media: media::Service { db },
|
media: media::Service { db },
|
||||||
sending: sending::Service::build(db, &config),
|
sending: sending::Service::build(db, &config),
|
||||||
sso: sso::Service::build(db),
|
sso: sso::Service::build(db)?,
|
||||||
|
|
||||||
globals: globals::Service::load(db, config)?,
|
globals: globals::Service::load(db, config)?,
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,50 +1,36 @@
|
||||||
mod data;
|
|
||||||
use std::{
|
use std::{
|
||||||
borrow::Borrow,
|
borrow::Borrow,
|
||||||
collections::{HashMap, HashSet},
|
collections::HashSet,
|
||||||
hash::{Hash, Hasher},
|
hash::{Hash, Hasher},
|
||||||
str::FromStr,
|
sync::Arc,
|
||||||
sync::{Arc, RwLock},
|
|
||||||
time::{Duration, SystemTime, UNIX_EPOCH},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
api::client_server::TOKEN_LENGTH,
|
api::client_server::{LOGIN_TOKEN_EXPIRATION_SECS, TOKEN_LENGTH},
|
||||||
config::{sso::ProviderConfig as Config, IdpConfig},
|
config::IdpConfig,
|
||||||
utils, Error, Result,
|
utils, Error, Result,
|
||||||
};
|
};
|
||||||
pub use data::Data;
|
|
||||||
use email_address::EmailAddress;
|
|
||||||
use futures_util::future::{self};
|
use futures_util::future::{self};
|
||||||
use mas_oidc_client::{
|
use mas_oidc_client::{
|
||||||
http_service::{hyper, HttpService},
|
http_service::{hyper, HttpService},
|
||||||
jose::jwk::PublicJsonWebKeySet,
|
requests::{authorization_code::AuthorizationValidationData, discovery},
|
||||||
requests::{
|
types::oidc::VerifiedProviderMetadata,
|
||||||
authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData},
|
|
||||||
discovery,
|
|
||||||
jose::{self, JwtVerificationData},
|
|
||||||
userinfo,
|
|
||||||
},
|
|
||||||
types::{
|
|
||||||
iana::jose::JsonWebSignatureAlg, oidc::VerifiedProviderMetadata,
|
|
||||||
requests::AccessTokenResponse, IdToken,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
use rand::SeedableRng;
|
use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId};
|
||||||
use ruma::{api::client::error::ErrorKind, MilliSecondsSinceUnixEpoch, OwnedUserId, UserId};
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use tokio::sync::OnceCell;
|
||||||
use tokio::sync::{oneshot, OnceCell};
|
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
use crate::services;
|
use crate::services;
|
||||||
|
|
||||||
|
mod data;
|
||||||
pub use data::Data;
|
pub use data::Data;
|
||||||
|
|
||||||
pub const SSO_AUTH_EXPIRATION_SECS: u64 = 60 * 60;
|
pub const SSO_AUTH_EXPIRATION_SECS: u64 = 60 * 60;
|
||||||
pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 60 * 2;
|
pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 60 * 2;
|
||||||
pub const SSO_SESSION_COOKIE: &str = "sso-auth";
|
pub const SSO_SESSION_COOKIE: &str = "sso-auth";
|
||||||
|
pub const SUBJECT_CLAIM_KEY: &str = "sub";
|
||||||
|
|
||||||
pub struct Service {
|
pub struct Service {
|
||||||
db: &'static dyn Data,
|
db: &'static dyn Data,
|
||||||
|
@ -69,7 +55,7 @@ impl Service {
|
||||||
let providers = services().globals.config.idps.iter();
|
let providers = services().globals.config.idps.iter();
|
||||||
|
|
||||||
self.providers
|
self.providers
|
||||||
.get_or_try_init(|| {
|
.get_or_try_init(|| async move {
|
||||||
future::try_join_all(providers.map(Provider::fetch_metadata))
|
future::try_join_all(providers.map(Provider::fetch_metadata))
|
||||||
.await
|
.await
|
||||||
.map(Vec::into_iter)
|
.map(Vec::into_iter)
|
||||||
|
@ -86,6 +72,12 @@ impl Service {
|
||||||
providers.get(provider)
|
providers.get(provider)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn login_type(&self) -> impl Iterator<Item = IdentityProvider> + '_ {
|
||||||
|
let providers = self.providers.get().expect("");
|
||||||
|
|
||||||
|
providers.iter().map(|p| p.config.inner.clone())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
|
pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
|
||||||
self.db.user_from_subject(provider, subject)
|
self.db.user_from_subject(provider, subject)
|
||||||
}
|
}
|
||||||
|
@ -111,30 +103,6 @@ impl Provider {
|
||||||
Error::bad_config("Failed to fetch identity provider metadata.")
|
Error::bad_config("Failed to fetch identity provider metadata.")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn fetch_signing_keys(&self) -> Result<PublicJsonWebKeySet> {
|
|
||||||
jose::fetch_jwks(&services().sso.service, self.metadata.jwks_uri())
|
|
||||||
.await
|
|
||||||
.map_err(|e| {
|
|
||||||
error!("Failed to fetch signing keys for token endpoint: {}", e);
|
|
||||||
|
|
||||||
Error::bad_config("Failed to fetch signing keys for token endpoint.")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn fetch_access_token(
|
|
||||||
&self,
|
|
||||||
auth_code: String,
|
|
||||||
validation_data: AuthorizationValidationData,
|
|
||||||
) -> Result<(AccessTokenResponse, Option<IdToken<'_>>)> {
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn fetch_userinfo(
|
|
||||||
&self,
|
|
||||||
access_token: &str,
|
|
||||||
id_token: &IdToken<'_>,
|
|
||||||
) -> Result<Option<HashMap<String, Value>>> {
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Borrow<str> for Provider {
|
impl Borrow<str> for Provider {
|
||||||
|
@ -157,105 +125,28 @@ impl Hash for Provider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
||||||
pub struct RegistrationToken {
|
|
||||||
pub info: RegistrationInfo,
|
|
||||||
pub provider_id: String,
|
|
||||||
pub unique_claim: String,
|
|
||||||
pub redirect_uri: Url,
|
|
||||||
pub expires_at: MilliSecondsSinceUnixEpoch,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RegistrationToken {
|
|
||||||
pub fn new(
|
|
||||||
provider_id: String,
|
|
||||||
unique_claim: String,
|
|
||||||
redirect_uri: Url,
|
|
||||||
info: RegistrationInfo,
|
|
||||||
) -> Self {
|
|
||||||
let expires_at = MilliSecondsSinceUnixEpoch::from_system_time(
|
|
||||||
UNIX_EPOCH
|
|
||||||
.checked_add(Duration::from_secs(REGISTRATION_EXPIRATION_SECS))
|
|
||||||
.expect("SystemTime should not overflow"),
|
|
||||||
)
|
|
||||||
.expect("MilliSecondsSinceUnixEpoch is not too large");
|
|
||||||
|
|
||||||
Self {
|
|
||||||
info,
|
|
||||||
provider_id,
|
|
||||||
unique_claim,
|
|
||||||
redirect_uri,
|
|
||||||
expires_at,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
|
||||||
pub struct RegistrationInfo {
|
|
||||||
pub username: String,
|
|
||||||
pub displayname: Option<String>,
|
|
||||||
pub avatar_url: Option<Url>,
|
|
||||||
pub email: Option<EmailAddress>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RegistrationInfo {
|
|
||||||
pub fn new(
|
|
||||||
claims: &HashMap<String, Value>,
|
|
||||||
username: &str,
|
|
||||||
displayname: &str,
|
|
||||||
avatar_url: &str,
|
|
||||||
email: &str,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
username: claims
|
|
||||||
.get(username)
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.map(ToOwned::to_owned)
|
|
||||||
.unwrap_or_default(),
|
|
||||||
displayname: claims
|
|
||||||
.get(displayname)
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.map(ToOwned::to_owned),
|
|
||||||
avatar_url: claims
|
|
||||||
.get(avatar_url)
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.map(Url::parse)
|
|
||||||
.and_then(Result::ok),
|
|
||||||
email: claims
|
|
||||||
.get(email)
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.map(EmailAddress::from_str)
|
|
||||||
.and_then(Result::ok),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize)]
|
#[derive(Clone, Deserialize, Serialize)]
|
||||||
pub struct LoginToken {
|
pub struct LoginToken {
|
||||||
pub inner: String,
|
pub iss: String,
|
||||||
pub provider_id: String,
|
pub aud: OwnedUserId,
|
||||||
pub user_id: OwnedUserId,
|
pub sub: String,
|
||||||
|
pub exp: u64,
|
||||||
#[serde(rename = "exp")]
|
|
||||||
expires_at: u64,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LoginToken {
|
impl LoginToken {
|
||||||
pub fn new(provider_id: String, user_id: OwnedUserId) -> Self {
|
pub fn new(provider: String, user_id: OwnedUserId) -> Self {
|
||||||
let expires_at = SystemTime::now()
|
|
||||||
.checked_add(Duration::from_secs(LOGIN_TOKEN_EXPIRATION_SECS))
|
|
||||||
.expect("SystemTime should not overflow")
|
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.expect("SystemTime went backwards")
|
|
||||||
.as_secs();
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
inner: utils::random_string(TOKEN_LENGTH),
|
iss: provider,
|
||||||
provider_id,
|
aud: user_id,
|
||||||
user_id,
|
sub: utils::random_string(TOKEN_LENGTH),
|
||||||
expires_at,
|
exp: utils::millis_since_unix_epoch()
|
||||||
|
.checked_add(LOGIN_TOKEN_EXPIRATION_SECS * 1000)
|
||||||
|
.expect("time overflow"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pub fn audience(&self) -> &UserId {
|
||||||
|
&self.aud
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue