1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00
conduit/src/service/sso/mod.rs

214 lines
5.6 KiB
Rust
Raw Normal View History

2024-07-11 21:55:52 +02:00
use std::{
borrow::Borrow,
2024-07-11 22:24:22 +02:00
collections::HashSet,
2024-07-11 21:55:52 +02:00
hash::{Hash, Hasher},
2024-07-11 22:24:22 +02:00
sync::Arc,
2024-07-11 21:55:52 +02:00
};
use crate::{
2024-07-11 22:24:22 +02:00
api::client_server::{LOGIN_TOKEN_EXPIRATION_SECS, TOKEN_LENGTH},
config::IdpConfig,
2024-07-11 21:55:52 +02:00
utils, Error, Result,
};
use futures_util::future::{self};
2024-07-15 12:24:06 +02:00
use http::HeaderValue;
2024-07-11 21:55:52 +02:00
use mas_oidc_client::{
2024-07-15 06:08:25 +02:00
http_service::HttpService,
2024-07-11 22:24:22 +02:00
requests::{authorization_code::AuthorizationValidationData, discovery},
types::oidc::VerifiedProviderMetadata,
2024-07-11 21:55:52 +02:00
};
2024-07-11 22:24:22 +02:00
use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId};
2024-07-11 21:55:52 +02:00
use serde::{Deserialize, Serialize};
2024-07-11 22:24:22 +02:00
use tokio::sync::OnceCell;
2024-07-15 06:08:25 +02:00
use tower::BoxError;
2024-07-15 12:24:06 +02:00
use tower_http::{set_header::SetRequestHeaderLayer, ServiceBuilderExt};
2024-07-11 21:55:52 +02:00
use tracing::error;
use url::Url;
use crate::services;
2024-07-11 22:24:22 +02:00
mod data;
2024-07-11 21:55:52 +02:00
pub use data::Data;
/// Expiration of SSO authorization state, in seconds (one hour).
pub const SSO_AUTH_EXPIRATION_SECS: u64 = 3_600;
/// Expiration of SSO tokens, in seconds (two minutes).
pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 120;
/// Name of the cookie used for the SSO session.
pub const SSO_SESSION_COOKIE: &str = "sso-auth";
/// Key of the subject claim.
pub const SUBJECT_CLAIM_KEY: &str = "sub";
/// Single sign-on service.
///
/// Holds the database handle, the HTTP client used for every request to an
/// identity provider, and the set of providers whose metadata has been fetched.
pub struct Service {
/// Database handle; used here for subject-to-user lookups.
db: &'static dyn Data,
/// HTTP client shared by all identity-provider requests.
service: HttpService,
/// Providers with fetched metadata; empty until `start_handler` succeeds.
providers: OnceCell<HashSet<Provider>>,
}
impl Service {
pub fn build(db: &'static dyn Data) -> Result<Arc<Self>> {
2024-07-15 06:08:25 +02:00
let client = tower::ServiceBuilder::new()
.map_err(BoxError::from)
2024-07-15 12:24:06 +02:00
.layer(tower_http::timeout::TimeoutLayer::new(
std::time::Duration::from_secs(10),
))
2024-07-15 06:08:25 +02:00
.layer(mas_http::BytesToBodyRequestLayer)
.layer(mas_http::BodyToBytesResponseLayer)
2024-07-15 12:24:06 +02:00
.layer(SetRequestHeaderLayer::overriding(
http::header::USER_AGENT,
HeaderValue::from_static("conduit/0.9-alpha"),
))
.concurrency_limit(10)
.follow_redirects()
2024-07-15 06:08:25 +02:00
.service(mas_http::make_untraced_client());
2024-07-11 21:55:52 +02:00
Ok(Arc::new(Self {
db,
2024-07-15 06:08:25 +02:00
service: HttpService::new(client),
2024-07-11 21:55:52 +02:00
providers: OnceCell::new(),
}))
}
pub fn service(&self) -> &HttpService {
&self.service
}
pub async fn start_handler(&self) -> Result<()> {
let providers = services().globals.config.idps.iter();
self.providers
2024-07-11 22:24:22 +02:00
.get_or_try_init(|| async move {
2024-07-11 21:55:52 +02:00
future::try_join_all(providers.map(Provider::fetch_metadata))
.await
.map(Vec::into_iter)
.map(HashSet::from_iter)
})
.await?;
Ok(())
}
pub fn get(&self, provider: &str) -> Option<&Provider> {
let providers = self.providers.get().expect("");
providers.get(provider)
}
2024-07-11 22:24:22 +02:00
pub fn login_type(&self) -> impl Iterator<Item = IdentityProvider> + '_ {
let providers = self.providers.get().expect("");
providers.iter().map(|p| p.config.inner.clone())
}
2024-07-11 21:55:52 +02:00
pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
self.db.user_from_subject(provider, subject)
}
}
/// An identity provider from the configuration, paired with the metadata
/// obtained for it through OpenID Connect discovery.
#[derive(Clone, Debug)]
pub struct Provider {
/// The provider's entry in the server configuration.
pub config: &'static IdpConfig,
/// Discovery metadata, as returned (already verified) by `mas_oidc_client`.
pub metadata: VerifiedProviderMetadata,
}
impl Provider {
pub async fn fetch_metadata(config: &'static IdpConfig) -> Result<Self> {
discovery::discover(services().sso.service(), &config.issuer)
.await
.map(|metadata| Provider { config, metadata })
.map_err(|e| {
error!(
"Failed to fetch identity provider metadata ({}): {}",
&config.inner.id, e
);
Error::bad_config("Failed to fetch identity provider metadata.")
})
}
}
/// Exposes the provider's configured identifier as its lookup key, which is
/// what allows `HashSet<Provider>::get(&str)`.
impl Borrow<str> for Provider {
    fn borrow(&self) -> &str {
        self.config.borrow()
    }
}

/// Two providers are equal when their borrowed `str` keys are equal.
///
/// Comparing the borrowed key (rather than delegating to `IdpConfig`) keeps
/// `Eq` consistent with `Borrow<str>` by construction, as `HashSet` lookups
/// by `&str` require.
impl PartialEq for Provider {
    fn eq(&self, other: &Self) -> bool {
        <Self as Borrow<str>>::borrow(self) == <Self as Borrow<str>>::borrow(other)
    }
}

impl Eq for Provider {}

/// Hashes exactly the borrowed `str` key, so `hash(provider)` equals the hash
/// `HashSet::get` computes for the `&str` it is given.
impl Hash for Provider {
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        <Self as Borrow<str>>::borrow(self).hash(hasher)
    }
}
#[derive(Clone, Deserialize, Serialize)]
pub struct LoginToken {
2024-07-11 22:24:22 +02:00
pub iss: String,
pub aud: OwnedUserId,
pub sub: String,
pub exp: u64,
2024-07-11 21:55:52 +02:00
}
impl LoginToken {
2024-07-11 22:24:22 +02:00
pub fn new(provider: String, user_id: OwnedUserId) -> Self {
2024-07-11 21:55:52 +02:00
Self {
2024-07-11 22:24:22 +02:00
iss: provider,
aud: user_id,
sub: utils::random_string(TOKEN_LENGTH),
exp: utils::millis_since_unix_epoch()
.checked_add(LOGIN_TOKEN_EXPIRATION_SECS * 1000)
.expect("time overflow"),
2024-07-11 21:55:52 +02:00
}
}
2024-07-15 12:24:06 +02:00
pub fn audience(self) -> OwnedUserId {
self.aud
2024-07-11 22:24:22 +02:00
}
2024-07-11 21:55:52 +02:00
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ValidationData {
pub provider: String,
2024-07-11 22:44:47 +02:00
pub redirect_url: String,
2024-07-11 21:55:52 +02:00
#[serde(flatten, with = "AuthorizationValidationDataDef")]
pub inner: AuthorizationValidationData,
}
impl ValidationData {
2024-07-11 22:44:47 +02:00
pub fn new(provider: String, redirect_url: String, inner: AuthorizationValidationData) -> Self {
2024-07-15 06:08:25 +02:00
Self {
provider,
redirect_url,
inner,
}
2024-07-11 21:55:52 +02:00
}
}
/// Serde "remote" definition mirroring `mas_oidc_client`'s
/// `AuthorizationValidationData`, used through
/// `#[serde(with = "AuthorizationValidationDataDef")]` on `ValidationData::inner`.
///
/// Field names and types must stay in sync with the upstream struct; serde's
/// remote derive relies on them matching.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(remote = "AuthorizationValidationData")]
pub struct AuthorizationValidationDataDef {
/// Mirrors `AuthorizationValidationData::state`.
pub state: String,
/// Mirrors `AuthorizationValidationData::nonce`.
pub nonce: String,
/// Mirrors `AuthorizationValidationData::redirect_uri`.
pub redirect_uri: Url,
/// Mirrors `AuthorizationValidationData::code_challenge_verifier`.
pub code_challenge_verifier: Option<String>,
}
impl From<AuthorizationValidationData> for AuthorizationValidationDataDef {
    /// Moves every field of the upstream value into the serde mirror type.
    fn from(data: AuthorizationValidationData) -> Self {
        let AuthorizationValidationData {
            state,
            nonce,
            redirect_uri,
            code_challenge_verifier,
        } = data;

        Self {
            state,
            nonce,
            redirect_uri,
            code_challenge_verifier,
        }
    }
}