1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00
conduit/src/service/sso/mod.rs

300 lines
7.9 KiB
Rust
Raw Normal View History

2024-07-11 21:55:52 +02:00
mod data;
use std::{
borrow::Borrow,
collections::{HashMap, HashSet},
hash::{Hash, Hasher},
str::FromStr,
sync::{Arc, RwLock},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use crate::{
api::client_server::TOKEN_LENGTH,
config::{sso::ProviderConfig as Config, IdpConfig},
utils, Error, Result,
};
pub use data::Data;
use email_address::EmailAddress;
use futures_util::future::{self};
use mas_oidc_client::{
http_service::{hyper, HttpService},
jose::jwk::PublicJsonWebKeySet,
requests::{
authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData},
discovery,
jose::{self, JwtVerificationData},
userinfo,
},
types::{
iana::jose::JsonWebSignatureAlg, oidc::VerifiedProviderMetadata,
requests::AccessTokenResponse, IdToken,
},
};
use rand::SeedableRng;
use ruma::{api::client::error::ErrorKind, MilliSecondsSinceUnixEpoch, OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::{oneshot, OnceCell};
use tracing::error;
use url::Url;
use crate::services;
pub use data::Data;
/// SSO authorization expiry window, in seconds (one hour).
pub const SSO_AUTH_EXPIRATION_SECS: u64 = 3_600;
/// SSO token expiry window, in seconds (two minutes).
pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 120;
/// Name of the SSO session cookie.
pub const SSO_SESSION_COOKIE: &str = "sso-auth";
/// Single sign-on service state.
///
/// Holds the database handle, the HTTP client used for outgoing OIDC requests,
/// and the set of configured identity providers. The provider set is filled in
/// once by `Service::start_handler`.
pub struct Service {
    /// Storage backend; `user_from_subject` lookups are delegated to it.
    db: &'static dyn Data,
    /// HTTP client used for provider discovery and JWKS fetching.
    service: HttpService,
    /// Identity providers, looked up by `&str` through `Provider`'s `Borrow<str>` impl.
    providers: OnceCell<HashSet<Provider>>,
}
impl Service {
    /// Creates the SSO service with an empty, not-yet-initialized provider set.
    ///
    /// Call [`Service::start_handler`] before using [`Service::get`].
    pub fn build(db: &'static dyn Data) -> Result<Arc<Self>> {
        Ok(Arc::new(Self {
            db,
            service: HttpService::new(hyper::hyper_service()),
            providers: OnceCell::new(),
        }))
    }

    /// Returns the HTTP client used for outgoing OIDC requests.
    pub fn service(&self) -> &HttpService {
        &self.service
    }

    /// Fetches discovery metadata for every configured identity provider and
    /// stores the resulting provider set.
    ///
    /// Initialization happens at most once; once it has succeeded, later calls
    /// return `Ok(())` without refetching. If initialization fails, the cell
    /// stays uninitialized, so a later call retries.
    ///
    /// # Errors
    /// Returns the error of a provider whose metadata could not be fetched.
    pub async fn start_handler(&self) -> Result<()> {
        let configs = services().globals.config.idps.iter();

        self.providers
            // `get_or_try_init` needs a closure that *returns a future*; `.await`
            // is only valid inside the async block, not in a plain closure.
            .get_or_try_init(move || async move {
                // Providers are fetched concurrently; the first failure aborts the rest.
                future::try_join_all(configs.map(Provider::fetch_metadata))
                    .await
                    .map(|providers| providers.into_iter().collect::<HashSet<_>>())
            })
            .await?;

        Ok(())
    }

    /// Looks up a configured identity provider by the string key `Provider`
    /// borrows as.
    ///
    /// # Panics
    /// Panics if called before [`Service::start_handler`] has completed successfully.
    pub fn get(&self, provider: &str) -> Option<&Provider> {
        self.providers
            .get()
            .expect("SSO providers must be initialized by `start_handler` before lookup")
            .get(provider)
    }

    /// Looks up, via the database, the local user associated with `subject` at `provider`.
    pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
        self.db.user_from_subject(provider, subject)
    }
}
/// A configured identity provider together with its verified OIDC discovery metadata.
///
/// Equality and hashing consider only `config`; `metadata` is ignored.
#[derive(Clone, Debug)]
pub struct Provider {
    /// Static configuration for this identity provider.
    pub config: &'static IdpConfig,
    /// Metadata obtained from discovery against `config.issuer`.
    pub metadata: VerifiedProviderMetadata,
}
impl Provider {
    /// Runs OIDC discovery against `config.issuer` and pairs the verified
    /// metadata with `config`.
    ///
    /// # Errors
    /// Logs the underlying discovery error and returns a configuration error.
    pub async fn fetch_metadata(config: &'static IdpConfig) -> Result<Self> {
        discovery::discover(services().sso.service(), &config.issuer)
            .await
            .map(|metadata| Self { config, metadata })
            .map_err(|e| {
                error!(
                    "Failed to fetch identity provider metadata ({}): {}",
                    &config.inner.id, e
                );
                Error::bad_config("Failed to fetch identity provider metadata.")
            })
    }

    /// Fetches the provider's JSON Web Key Set from the `jwks_uri` advertised in
    /// its discovery metadata.
    ///
    /// # Errors
    /// Logs the underlying error and returns a configuration error.
    async fn fetch_signing_keys(&self) -> Result<PublicJsonWebKeySet> {
        // Use the public accessor, as `fetch_metadata` does, instead of reaching
        // into `Service`'s private field.
        jose::fetch_jwks(services().sso.service(), self.metadata.jwks_uri())
            .await
            .map_err(|e| {
                error!("Failed to fetch signing keys for token endpoint: {}", e);
                Error::bad_config("Failed to fetch signing keys for token endpoint.")
            })
    }

    /// Presumably exchanges `auth_code` (checked against `validation_data`) for an
    /// access token and optional ID token. TODO confirm: the body is not shown here.
    pub async fn fetch_access_token(
        &self,
        auth_code: String,
        validation_data: AuthorizationValidationData,
    ) -> Result<(AccessTokenResponse, Option<IdToken<'_>>)> {
        // NOTE(review): the body is empty as shown, so this does not compile;
        // a `Result` value must be returned here.
    }

    /// Presumably fetches userinfo claims using `access_token`, checked against
    /// `id_token`. TODO confirm: the body is not shown here.
    pub async fn fetch_userinfo(
        &self,
        access_token: &str,
        id_token: &IdToken<'_>,
    ) -> Result<Option<HashMap<String, Value>>> {
        // NOTE(review): the body is empty as shown, so this does not compile;
        // a `Result` value must be returned here.
    }
}
impl Borrow<str> for Provider {
fn borrow(&self) -> &str {
self.config.borrow()
}
}
impl PartialEq for Provider {
fn eq(&self, other: &Self) -> bool {
self.config == other.config
}
}
impl Eq for Provider {}
impl Hash for Provider {
fn hash<H: Hasher>(&self, hasher: &mut H) {
self.config.hash(hasher)
}
}
/// Data needed to complete a registration that started at identity provider
/// `provider_id`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RegistrationToken {
    /// Profile data extracted from the provider's claims.
    pub info: RegistrationInfo,
    /// Identifier of the identity provider.
    pub provider_id: String,
    /// Value of the provider claim used to identify the user.
    /// NOTE(review): presumably the subject claim; confirm against callers.
    pub unique_claim: String,
    /// Redirect target. NOTE(review): presumably where the client goes once
    /// registration finishes; confirm.
    pub redirect_uri: Url,
    /// Expiry timestamp, set by `RegistrationToken::new`.
    pub expires_at: MilliSecondsSinceUnixEpoch,
}
impl RegistrationToken {
pub fn new(
provider_id: String,
unique_claim: String,
redirect_uri: Url,
info: RegistrationInfo,
) -> Self {
let expires_at = MilliSecondsSinceUnixEpoch::from_system_time(
UNIX_EPOCH
.checked_add(Duration::from_secs(REGISTRATION_EXPIRATION_SECS))
.expect("SystemTime should not overflow"),
)
.expect("MilliSecondsSinceUnixEpoch is not too large");
Self {
info,
provider_id,
unique_claim,
redirect_uri,
expires_at,
}
}
}
/// Profile information for a new account, built from identity-provider claims
/// by `RegistrationInfo::new`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct RegistrationInfo {
    /// Requested username; `new` leaves it empty when the claim is missing or not a string.
    pub username: String,
    /// Display name, if the claim was present as a string.
    pub displayname: Option<String>,
    /// Avatar URL, kept only if the claim parsed as a URL.
    pub avatar_url: Option<Url>,
    /// Email address, kept only if the claim parsed as an email address.
    pub email: Option<EmailAddress>,
}
impl RegistrationInfo {
pub fn new(
claims: &HashMap<String, Value>,
username: &str,
displayname: &str,
avatar_url: &str,
email: &str,
) -> Self {
Self {
username: claims
.get(username)
.and_then(|v| v.as_str())
.map(ToOwned::to_owned)
.unwrap_or_default(),
displayname: claims
.get(displayname)
.and_then(|v| v.as_str())
.map(ToOwned::to_owned),
avatar_url: claims
.get(avatar_url)
.and_then(|v| v.as_str())
.map(Url::parse)
.and_then(Result::ok),
email: claims
.get(email)
.and_then(|v| v.as_str())
.map(EmailAddress::from_str)
.and_then(Result::ok),
}
}
}
/// A login token for `user_id`, associated with identity provider `provider_id`.
///
/// NOTE(review): unlike its sibling types this does not derive `Debug`,
/// presumably to keep the token out of logs. Confirm before adding it.
#[derive(Clone, Deserialize, Serialize)]
pub struct LoginToken {
    /// Random token string of `TOKEN_LENGTH` characters, generated by `LoginToken::new`.
    pub inner: String,
    /// Identifier of the identity provider.
    pub provider_id: String,
    /// The user this token logs in as.
    pub user_id: OwnedUserId,
    /// Expiry as whole seconds since the Unix epoch; serialized under the key `exp`.
    #[serde(rename = "exp")]
    expires_at: u64,
}
impl LoginToken {
    /// Issues a fresh random login token for `user_id` that expires
    /// `LOGIN_TOKEN_EXPIRATION_SECS` seconds from now.
    ///
    /// # Panics
    /// Panics if the expiry overflows `SystemTime` or precedes the Unix epoch.
    pub fn new(provider_id: String, user_id: OwnedUserId) -> Self {
        let lifetime = Duration::from_secs(LOGIN_TOKEN_EXPIRATION_SECS);
        let deadline = SystemTime::now()
            .checked_add(lifetime)
            .expect("SystemTime should not overflow");
        let since_epoch = deadline
            .duration_since(UNIX_EPOCH)
            .expect("SystemTime went backwards");

        Self {
            inner: utils::random_string(TOKEN_LENGTH),
            provider_id,
            user_id,
            expires_at: since_epoch.as_secs(),
        }
    }
}
/// Authorization validation data tagged with the provider it belongs to.
///
/// The fields of `inner` are serialized flattened next to `provider`, through
/// the `AuthorizationValidationDataDef` remote definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ValidationData {
    /// Identifier of the identity provider.
    pub provider: String,
    /// State, nonce, redirect URI and optional PKCE verifier for the flow.
    #[serde(flatten, with = "AuthorizationValidationDataDef")]
    pub inner: AuthorizationValidationData,
}
impl ValidationData {
pub fn new(provider: String, inner: AuthorizationValidationData) -> Self {
Self { provider, inner }
}
}
/// Serde remote definition mirroring `mas_oidc_client`'s
/// `AuthorizationValidationData`, used via `#[serde(with = ...)]`.
///
/// NOTE(review): presumably needed because the remote type has no serde impls
/// of its own; confirm. Field names and types must match the remote type exactly.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(remote = "AuthorizationValidationData")]
pub struct AuthorizationValidationDataDef {
    /// OAuth `state` value for the request.
    pub state: String,
    /// OIDC `nonce` value for the request.
    pub nonce: String,
    /// Redirect URI used in the authorization request.
    pub redirect_uri: Url,
    /// PKCE code verifier, if one was used.
    pub code_challenge_verifier: Option<String>,
}
impl From<AuthorizationValidationData> for AuthorizationValidationDataDef {
fn from(
AuthorizationValidationData {
state,
nonce,
redirect_uri,
code_challenge_verifier,
}: AuthorizationValidationData,
) -> Self {
Self {
state,
nonce,
redirect_uri,
code_challenge_verifier,
}
}
}