mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-06-27 16:35:59 +00:00
300 lines
7.9 KiB
Rust
300 lines
7.9 KiB
Rust
|
mod data;
|
||
|
use std::{
|
||
|
borrow::Borrow,
|
||
|
collections::{HashMap, HashSet},
|
||
|
hash::{Hash, Hasher},
|
||
|
str::FromStr,
|
||
|
sync::{Arc, RwLock},
|
||
|
time::{Duration, SystemTime, UNIX_EPOCH},
|
||
|
};
|
||
|
|
||
|
use crate::{
|
||
|
api::client_server::TOKEN_LENGTH,
|
||
|
config::{sso::ProviderConfig as Config, IdpConfig},
|
||
|
utils, Error, Result,
|
||
|
};
|
||
|
pub use data::Data;
|
||
|
use email_address::EmailAddress;
|
||
|
use futures_util::future::{self};
|
||
|
use mas_oidc_client::{
|
||
|
http_service::{hyper, HttpService},
|
||
|
jose::jwk::PublicJsonWebKeySet,
|
||
|
requests::{
|
||
|
authorization_code::{self, AuthorizationRequestData, AuthorizationValidationData},
|
||
|
discovery,
|
||
|
jose::{self, JwtVerificationData},
|
||
|
userinfo,
|
||
|
},
|
||
|
types::{
|
||
|
iana::jose::JsonWebSignatureAlg, oidc::VerifiedProviderMetadata,
|
||
|
requests::AccessTokenResponse, IdToken,
|
||
|
},
|
||
|
};
|
||
|
use rand::SeedableRng;
|
||
|
use ruma::{api::client::error::ErrorKind, MilliSecondsSinceUnixEpoch, OwnedUserId, UserId};
|
||
|
use serde::{Deserialize, Serialize};
|
||
|
use serde_json::Value;
|
||
|
use tokio::sync::{oneshot, OnceCell};
|
||
|
use tracing::error;
|
||
|
use url::Url;
|
||
|
|
||
|
use crate::services;
|
||
|
|
||
|
pub use data::Data;
|
||
|
|
||
|
/// SSO authentication expiry window, in seconds (one hour).
pub const SSO_AUTH_EXPIRATION_SECS: u64 = 3_600;

/// SSO token expiry window, in seconds (two minutes).
pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 120;

/// Cookie name used for SSO sessions.
pub const SSO_SESSION_COOKIE: &str = "sso-auth";
|
||
|
|
||
|
pub struct Service {
|
||
|
db: &'static dyn Data,
|
||
|
service: HttpService,
|
||
|
providers: OnceCell<HashSet<Provider>>,
|
||
|
}
|
||
|
|
||
|
impl Service {
|
||
|
pub fn build(db: &'static dyn Data) -> Result<Arc<Self>> {
|
||
|
Ok(Arc::new(Self {
|
||
|
db,
|
||
|
service: HttpService::new(hyper::hyper_service()),
|
||
|
providers: OnceCell::new(),
|
||
|
}))
|
||
|
}
|
||
|
|
||
|
pub fn service(&self) -> &HttpService {
|
||
|
&self.service
|
||
|
}
|
||
|
|
||
|
pub async fn start_handler(&self) -> Result<()> {
|
||
|
let providers = services().globals.config.idps.iter();
|
||
|
|
||
|
self.providers
|
||
|
.get_or_try_init(|| {
|
||
|
future::try_join_all(providers.map(Provider::fetch_metadata))
|
||
|
.await
|
||
|
.map(Vec::into_iter)
|
||
|
.map(HashSet::from_iter)
|
||
|
})
|
||
|
.await?;
|
||
|
|
||
|
Ok(())
|
||
|
}
|
||
|
|
||
|
pub fn get(&self, provider: &str) -> Option<&Provider> {
|
||
|
let providers = self.providers.get().expect("");
|
||
|
|
||
|
providers.get(provider)
|
||
|
}
|
||
|
|
||
|
pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
|
||
|
self.db.user_from_subject(provider, subject)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Clone, Debug)]
|
||
|
pub struct Provider {
|
||
|
pub config: &'static IdpConfig,
|
||
|
pub metadata: VerifiedProviderMetadata,
|
||
|
}
|
||
|
|
||
|
impl Provider {
|
||
|
pub async fn fetch_metadata(config: &'static IdpConfig) -> Result<Self> {
|
||
|
discovery::discover(services().sso.service(), &config.issuer)
|
||
|
.await
|
||
|
.map(|metadata| Provider { config, metadata })
|
||
|
.map_err(|e| {
|
||
|
error!(
|
||
|
"Failed to fetch identity provider metadata ({}): {}",
|
||
|
&config.inner.id, e
|
||
|
);
|
||
|
|
||
|
Error::bad_config("Failed to fetch identity provider metadata.")
|
||
|
})
|
||
|
}
|
||
|
|
||
|
async fn fetch_signing_keys(&self) -> Result<PublicJsonWebKeySet> {
|
||
|
jose::fetch_jwks(&services().sso.service, self.metadata.jwks_uri())
|
||
|
.await
|
||
|
.map_err(|e| {
|
||
|
error!("Failed to fetch signing keys for token endpoint: {}", e);
|
||
|
|
||
|
Error::bad_config("Failed to fetch signing keys for token endpoint.")
|
||
|
})
|
||
|
}
|
||
|
|
||
|
pub async fn fetch_access_token(
|
||
|
&self,
|
||
|
auth_code: String,
|
||
|
validation_data: AuthorizationValidationData,
|
||
|
) -> Result<(AccessTokenResponse, Option<IdToken<'_>>)> {
|
||
|
}
|
||
|
|
||
|
pub async fn fetch_userinfo(
|
||
|
&self,
|
||
|
access_token: &str,
|
||
|
id_token: &IdToken<'_>,
|
||
|
) -> Result<Option<HashMap<String, Value>>> {
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl Borrow<str> for Provider {
|
||
|
fn borrow(&self) -> &str {
|
||
|
self.config.borrow()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl PartialEq for Provider {
|
||
|
fn eq(&self, other: &Self) -> bool {
|
||
|
self.config == other.config
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl Eq for Provider {}
|
||
|
|
||
|
impl Hash for Provider {
|
||
|
fn hash<H: Hasher>(&self, hasher: &mut H) {
|
||
|
self.config.hash(hasher)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||
|
pub struct RegistrationToken {
|
||
|
pub info: RegistrationInfo,
|
||
|
pub provider_id: String,
|
||
|
pub unique_claim: String,
|
||
|
pub redirect_uri: Url,
|
||
|
pub expires_at: MilliSecondsSinceUnixEpoch,
|
||
|
}
|
||
|
|
||
|
impl RegistrationToken {
|
||
|
pub fn new(
|
||
|
provider_id: String,
|
||
|
unique_claim: String,
|
||
|
redirect_uri: Url,
|
||
|
info: RegistrationInfo,
|
||
|
) -> Self {
|
||
|
let expires_at = MilliSecondsSinceUnixEpoch::from_system_time(
|
||
|
UNIX_EPOCH
|
||
|
.checked_add(Duration::from_secs(REGISTRATION_EXPIRATION_SECS))
|
||
|
.expect("SystemTime should not overflow"),
|
||
|
)
|
||
|
.expect("MilliSecondsSinceUnixEpoch is not too large");
|
||
|
|
||
|
Self {
|
||
|
info,
|
||
|
provider_id,
|
||
|
unique_claim,
|
||
|
redirect_uri,
|
||
|
expires_at,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||
|
pub struct RegistrationInfo {
|
||
|
pub username: String,
|
||
|
pub displayname: Option<String>,
|
||
|
pub avatar_url: Option<Url>,
|
||
|
pub email: Option<EmailAddress>,
|
||
|
}
|
||
|
|
||
|
impl RegistrationInfo {
|
||
|
pub fn new(
|
||
|
claims: &HashMap<String, Value>,
|
||
|
username: &str,
|
||
|
displayname: &str,
|
||
|
avatar_url: &str,
|
||
|
email: &str,
|
||
|
) -> Self {
|
||
|
Self {
|
||
|
username: claims
|
||
|
.get(username)
|
||
|
.and_then(|v| v.as_str())
|
||
|
.map(ToOwned::to_owned)
|
||
|
.unwrap_or_default(),
|
||
|
displayname: claims
|
||
|
.get(displayname)
|
||
|
.and_then(|v| v.as_str())
|
||
|
.map(ToOwned::to_owned),
|
||
|
avatar_url: claims
|
||
|
.get(avatar_url)
|
||
|
.and_then(|v| v.as_str())
|
||
|
.map(Url::parse)
|
||
|
.and_then(Result::ok),
|
||
|
email: claims
|
||
|
.get(email)
|
||
|
.and_then(|v| v.as_str())
|
||
|
.map(EmailAddress::from_str)
|
||
|
.and_then(Result::ok),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Clone, Deserialize, Serialize)]
|
||
|
pub struct LoginToken {
|
||
|
pub inner: String,
|
||
|
pub provider_id: String,
|
||
|
pub user_id: OwnedUserId,
|
||
|
|
||
|
#[serde(rename = "exp")]
|
||
|
expires_at: u64,
|
||
|
}
|
||
|
|
||
|
impl LoginToken {
|
||
|
pub fn new(provider_id: String, user_id: OwnedUserId) -> Self {
|
||
|
let expires_at = SystemTime::now()
|
||
|
.checked_add(Duration::from_secs(LOGIN_TOKEN_EXPIRATION_SECS))
|
||
|
.expect("SystemTime should not overflow")
|
||
|
.duration_since(UNIX_EPOCH)
|
||
|
.expect("SystemTime went backwards")
|
||
|
.as_secs();
|
||
|
|
||
|
Self {
|
||
|
inner: utils::random_string(TOKEN_LENGTH),
|
||
|
provider_id,
|
||
|
user_id,
|
||
|
expires_at,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||
|
pub struct ValidationData {
|
||
|
pub provider: String,
|
||
|
#[serde(flatten, with = "AuthorizationValidationDataDef")]
|
||
|
pub inner: AuthorizationValidationData,
|
||
|
}
|
||
|
|
||
|
impl ValidationData {
|
||
|
pub fn new(provider: String, inner: AuthorizationValidationData) -> Self {
|
||
|
Self { provider, inner }
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||
|
#[serde(remote = "AuthorizationValidationData")]
|
||
|
pub struct AuthorizationValidationDataDef {
|
||
|
pub state: String,
|
||
|
pub nonce: String,
|
||
|
pub redirect_uri: Url,
|
||
|
pub code_challenge_verifier: Option<String>,
|
||
|
}
|
||
|
|
||
|
impl From<AuthorizationValidationData> for AuthorizationValidationDataDef {
|
||
|
fn from(
|
||
|
AuthorizationValidationData {
|
||
|
state,
|
||
|
nonce,
|
||
|
redirect_uri,
|
||
|
code_challenge_verifier,
|
||
|
}: AuthorizationValidationData,
|
||
|
) -> Self {
|
||
|
Self {
|
||
|
state,
|
||
|
nonce,
|
||
|
redirect_uri,
|
||
|
code_challenge_verifier,
|
||
|
}
|
||
|
}
|
||
|
}
|