2024-07-11 21:55:52 +02:00
|
|
|
use std::{
|
|
|
|
borrow::Borrow,
|
2024-07-11 22:24:22 +02:00
|
|
|
collections::HashSet,
|
2024-07-11 21:55:52 +02:00
|
|
|
hash::{Hash, Hasher},
|
2024-07-11 22:24:22 +02:00
|
|
|
sync::Arc,
|
2024-07-11 21:55:52 +02:00
|
|
|
};
|
|
|
|
|
|
|
|
use crate::{
|
2024-07-11 22:24:22 +02:00
|
|
|
api::client_server::{LOGIN_TOKEN_EXPIRATION_SECS, TOKEN_LENGTH},
|
|
|
|
config::IdpConfig,
|
2024-07-11 21:55:52 +02:00
|
|
|
utils, Error, Result,
|
|
|
|
};
|
|
|
|
use futures_util::future::{self};
|
2024-07-15 12:24:06 +02:00
|
|
|
use http::HeaderValue;
|
2024-07-11 21:55:52 +02:00
|
|
|
use mas_oidc_client::{
|
2024-07-15 06:08:25 +02:00
|
|
|
http_service::HttpService,
|
2024-07-11 22:24:22 +02:00
|
|
|
requests::{authorization_code::AuthorizationValidationData, discovery},
|
|
|
|
types::oidc::VerifiedProviderMetadata,
|
2024-07-11 21:55:52 +02:00
|
|
|
};
|
2024-07-11 22:24:22 +02:00
|
|
|
use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId};
|
2024-07-11 21:55:52 +02:00
|
|
|
use serde::{Deserialize, Serialize};
|
2024-07-11 22:24:22 +02:00
|
|
|
use tokio::sync::OnceCell;
|
2024-07-15 06:08:25 +02:00
|
|
|
use tower::BoxError;
|
2024-07-15 12:24:06 +02:00
|
|
|
use tower_http::{set_header::SetRequestHeaderLayer, ServiceBuilderExt};
|
2024-07-11 21:55:52 +02:00
|
|
|
use tracing::error;
|
|
|
|
use url::Url;
|
|
|
|
|
|
|
|
use crate::services;
|
|
|
|
|
2024-07-11 22:24:22 +02:00
|
|
|
mod data;
|
2024-07-11 21:55:52 +02:00
|
|
|
pub use data::Data;
|
|
|
|
|
|
|
|
/// Expiration, in seconds, used for SSO authorization state (one hour).
pub const SSO_AUTH_EXPIRATION_SECS: u64 = 3_600;

/// Expiration, in seconds, used for SSO tokens (two minutes).
pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 2 * 60;

/// Name of the cookie carrying the SSO session.
pub const SSO_SESSION_COOKIE: &str = "sso-auth";

/// Name of the OpenID Connect claim that holds the subject identifier.
pub const SUBJECT_CLAIM_KEY: &str = "sub";
|
2024-07-11 21:55:52 +02:00
|
|
|
|
|
|
|
pub struct Service {
|
|
|
|
db: &'static dyn Data,
|
|
|
|
service: HttpService,
|
|
|
|
providers: OnceCell<HashSet<Provider>>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Service {
|
|
|
|
pub fn build(db: &'static dyn Data) -> Result<Arc<Self>> {
|
2024-07-15 06:08:25 +02:00
|
|
|
let client = tower::ServiceBuilder::new()
|
|
|
|
.map_err(BoxError::from)
|
2024-07-15 12:24:06 +02:00
|
|
|
.layer(tower_http::timeout::TimeoutLayer::new(
|
|
|
|
std::time::Duration::from_secs(10),
|
|
|
|
))
|
2024-07-15 06:08:25 +02:00
|
|
|
.layer(mas_http::BytesToBodyRequestLayer)
|
|
|
|
.layer(mas_http::BodyToBytesResponseLayer)
|
2024-07-15 12:24:06 +02:00
|
|
|
.layer(SetRequestHeaderLayer::overriding(
|
|
|
|
http::header::USER_AGENT,
|
|
|
|
HeaderValue::from_static("conduit/0.9-alpha"),
|
|
|
|
))
|
|
|
|
.concurrency_limit(10)
|
|
|
|
.follow_redirects()
|
2024-07-15 06:08:25 +02:00
|
|
|
.service(mas_http::make_untraced_client());
|
|
|
|
|
2024-07-11 21:55:52 +02:00
|
|
|
Ok(Arc::new(Self {
|
|
|
|
db,
|
2024-07-15 06:08:25 +02:00
|
|
|
service: HttpService::new(client),
|
2024-07-11 21:55:52 +02:00
|
|
|
providers: OnceCell::new(),
|
|
|
|
}))
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn service(&self) -> &HttpService {
|
|
|
|
&self.service
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn start_handler(&self) -> Result<()> {
|
|
|
|
let providers = services().globals.config.idps.iter();
|
|
|
|
|
|
|
|
self.providers
|
2024-07-11 22:24:22 +02:00
|
|
|
.get_or_try_init(|| async move {
|
2024-07-11 21:55:52 +02:00
|
|
|
future::try_join_all(providers.map(Provider::fetch_metadata))
|
|
|
|
.await
|
|
|
|
.map(Vec::into_iter)
|
|
|
|
.map(HashSet::from_iter)
|
|
|
|
})
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn get(&self, provider: &str) -> Option<&Provider> {
|
|
|
|
let providers = self.providers.get().expect("");
|
|
|
|
|
|
|
|
providers.get(provider)
|
|
|
|
}
|
|
|
|
|
2024-07-11 22:24:22 +02:00
|
|
|
pub fn login_type(&self) -> impl Iterator<Item = IdentityProvider> + '_ {
|
|
|
|
let providers = self.providers.get().expect("");
|
|
|
|
|
|
|
|
providers.iter().map(|p| p.config.inner.clone())
|
|
|
|
}
|
|
|
|
|
2024-07-11 21:55:52 +02:00
|
|
|
pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
|
|
|
|
self.db.user_from_subject(provider, subject)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Clone, Debug)]
|
|
|
|
pub struct Provider {
|
|
|
|
pub config: &'static IdpConfig,
|
|
|
|
pub metadata: VerifiedProviderMetadata,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Provider {
|
|
|
|
pub async fn fetch_metadata(config: &'static IdpConfig) -> Result<Self> {
|
|
|
|
discovery::discover(services().sso.service(), &config.issuer)
|
|
|
|
.await
|
|
|
|
.map(|metadata| Provider { config, metadata })
|
|
|
|
.map_err(|e| {
|
|
|
|
error!(
|
|
|
|
"Failed to fetch identity provider metadata ({}): {}",
|
|
|
|
&config.inner.id, e
|
|
|
|
);
|
|
|
|
|
|
|
|
Error::bad_config("Failed to fetch identity provider metadata.")
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Borrow<str> for Provider {
|
|
|
|
fn borrow(&self) -> &str {
|
|
|
|
self.config.borrow()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl PartialEq for Provider {
|
|
|
|
fn eq(&self, other: &Self) -> bool {
|
|
|
|
self.config == other.config
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Eq for Provider {}
|
|
|
|
|
|
|
|
impl Hash for Provider {
|
|
|
|
fn hash<H: Hasher>(&self, hasher: &mut H) {
|
|
|
|
self.config.hash(hasher)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Clone, Deserialize, Serialize)]
|
|
|
|
pub struct LoginToken {
|
2024-07-11 22:24:22 +02:00
|
|
|
pub iss: String,
|
|
|
|
pub aud: OwnedUserId,
|
|
|
|
pub sub: String,
|
|
|
|
pub exp: u64,
|
2024-07-11 21:55:52 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
impl LoginToken {
|
2024-07-11 22:24:22 +02:00
|
|
|
pub fn new(provider: String, user_id: OwnedUserId) -> Self {
|
2024-07-11 21:55:52 +02:00
|
|
|
Self {
|
2024-07-11 22:24:22 +02:00
|
|
|
iss: provider,
|
|
|
|
aud: user_id,
|
|
|
|
sub: utils::random_string(TOKEN_LENGTH),
|
|
|
|
exp: utils::millis_since_unix_epoch()
|
|
|
|
.checked_add(LOGIN_TOKEN_EXPIRATION_SECS * 1000)
|
|
|
|
.expect("time overflow"),
|
2024-07-11 21:55:52 +02:00
|
|
|
}
|
|
|
|
}
|
2024-07-15 12:24:06 +02:00
|
|
|
pub fn audience(self) -> OwnedUserId {
|
|
|
|
self.aud
|
2024-07-11 22:24:22 +02:00
|
|
|
}
|
2024-07-11 21:55:52 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
|
|
pub struct ValidationData {
|
|
|
|
pub provider: String,
|
2024-07-11 22:44:47 +02:00
|
|
|
pub redirect_url: String,
|
2024-07-11 21:55:52 +02:00
|
|
|
#[serde(flatten, with = "AuthorizationValidationDataDef")]
|
|
|
|
pub inner: AuthorizationValidationData,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl ValidationData {
|
2024-07-11 22:44:47 +02:00
|
|
|
pub fn new(provider: String, redirect_url: String, inner: AuthorizationValidationData) -> Self {
|
2024-07-15 06:08:25 +02:00
|
|
|
Self {
|
|
|
|
provider,
|
|
|
|
redirect_url,
|
|
|
|
inner,
|
|
|
|
}
|
2024-07-11 21:55:52 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
|
|
#[serde(remote = "AuthorizationValidationData")]
|
|
|
|
pub struct AuthorizationValidationDataDef {
|
|
|
|
pub state: String,
|
|
|
|
pub nonce: String,
|
|
|
|
pub redirect_uri: Url,
|
|
|
|
pub code_challenge_verifier: Option<String>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl From<AuthorizationValidationData> for AuthorizationValidationDataDef {
|
|
|
|
fn from(
|
|
|
|
AuthorizationValidationData {
|
|
|
|
state,
|
|
|
|
nonce,
|
|
|
|
redirect_uri,
|
|
|
|
code_challenge_verifier,
|
|
|
|
}: AuthorizationValidationData,
|
|
|
|
) -> Self {
|
|
|
|
Self {
|
|
|
|
state,
|
|
|
|
nonce,
|
|
|
|
redirect_uri,
|
|
|
|
code_challenge_verifier,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|