1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00

where did my code go???

This commit is contained in:
avdb13 2024-07-10 08:19:39 +02:00
parent 269455d93a
commit 895b66fa50
17 changed files with 331 additions and 38 deletions

View file

@ -35,15 +35,18 @@ axum = { version = "0.7", default-features = false, features = [
"json", "json",
"matched-path", "matched-path",
], optional = true } ], optional = true }
axum-extra = { version = "0.9", features = ["typed-header"] } axum-extra = { version = "0.9", features = ["cookie", "typed-header"] }
axum-server = { version = "0.6", features = ["tls-rustls"] } axum-server = { version = "0.6", features = ["tls-rustls"] }
tower = { version = "0.4.13", features = ["util"] } tower = { version = "0.4.13", features = ["util"] }
# tower-http = { version = "0.5", features = [
# "add-extension",
# "cors",
# "sensitive-headers",
# "trace",
# "util",
# ] }
tower-http = { version = "0.5", features = [ tower-http = { version = "0.5", features = [
"add-extension", "full",
"cors",
"sensitive-headers",
"trace",
"util",
] } ] }
tower-service = "0.3" tower-service = "0.3"
@ -140,20 +143,28 @@ figment = { version = "0.10.8", features = ["env", "toml"] }
# Validating urls in config # Validating urls in config
url = { version = "2", features = ["serde"] } url = { version = "2", features = ["serde"] }
# HTML
maud = { version = "0.26.0", default-features = false, features = ["axum"] }
async-trait = "0.1.68" async-trait = "0.1.68"
tikv-jemallocator = { version = "0.5.0", features = [ tikv-jemallocator = { version = "0.5.0", features = [
"unprefixed_malloc_on_supported_platforms", "unprefixed_malloc_on_supported_platforms",
], optional = true } ], optional = true }
sd-notify = { version = "0.4.1", optional = true } sd-notify = { version = "0.4.1", optional = true }
http-body-util = "0.1.2"
hyper-rustls = { version = "0.27.2", default-features = false, features = ["http1", "http2", "ring", "rustls-native-certs", "rustls-platform-verifier"] }
mas-http = "0.9.0"
# Used for matrix spec type definitions and helpers # Used for matrix spec type definitions and helpers
[dependencies.ruma] [dependencies.ruma]
features = [ features = [
"appservice-api-c", "appservice-api-c",
"client",
"client-api", "client-api",
"compat", "compat",
"federation-api", "federation-api",
"client-hyper",
"push-gateway-api-c", "push-gateway-api-c",
"rand", "rand",
"ring-compat", "ring-compat",
@ -172,6 +183,11 @@ optional = true
package = "rust-rocksdb" package = "rust-rocksdb"
version = "0.25" version = "0.25"
# Used for Single Sign-On
[dependencies.mas-oidc-client]
git = "https://github.com/matrix-org/matrix-authentication-service.git"
default-features = false
[target.'cfg(unix)'.dependencies] [target.'cfg(unix)'.dependencies]
nix = { version = "0.28", features = ["resource"] } nix = { version = "0.28", features = ["resource"] }

View file

@ -322,6 +322,8 @@ pub async fn change_password_route(
.ok_or_else(|| Error::BadRequest(ErrorKind::MissingToken, "Missing access token."))?; .ok_or_else(|| Error::BadRequest(ErrorKind::MissingToken, "Missing access token."))?;
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
// if services().users.password_hash(sender_user)? == Some("");
let mut uiaainfo = UiaaInfo { let mut uiaainfo = UiaaInfo {
flows: vec![AuthFlow { flows: vec![AuthFlow {
stages: vec![AuthType::Password], stages: vec![AuthType::Password],

View file

@ -111,6 +111,10 @@ pub async fn upload_signing_keys_route(
auth_error: None, auth_error: None,
}; };
let master_key = services()
.users
.get_master_key(None, sender_user, &|user_id| user_id == sender_user)?;
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = let (worked, uiaainfo) =
services() services()
@ -126,7 +130,7 @@ pub async fn upload_signing_keys_route(
.uiaa .uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else if master_key.is_some() {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
} }

View file

@ -23,6 +23,7 @@ mod room;
mod search; mod search;
mod session; mod session;
mod space; mod space;
mod sso;
mod state; mod state;
mod sync; mod sync;
mod tag; mod tag;
@ -60,6 +61,7 @@ pub use room::*;
pub use search::*; pub use search::*;
pub use session::*; pub use session::*;
pub use space::*; pub use space::*;
pub use sso::*;
pub use state::*; pub use state::*;
pub use sync::*; pub use sync::*;
pub use tag::*; pub use tag::*;
@ -76,3 +78,5 @@ pub const DEVICE_ID_LENGTH: usize = 10;
pub const TOKEN_LENGTH: usize = 32; pub const TOKEN_LENGTH: usize = 32;
pub const SESSION_ID_LENGTH: usize = 32; pub const SESSION_ID_LENGTH: usize = 32;
pub const AUTO_GEN_PASSWORD_LENGTH: usize = 15; pub const AUTO_GEN_PASSWORD_LENGTH: usize = 15;
pub const AUTH_SESSION_EXPIRATION_SECS: u64 = 60 * 5;
pub const LOGIN_TOKEN_EXPIRATION_SECS: u64 = 15;

View file

@ -1,5 +1,5 @@
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, Error, Result, Ruma}; use crate::{service::sso::LoginToken, services, utils, Error, Result, Ruma};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -24,10 +24,20 @@ struct Claims {
pub async fn get_login_types_route( pub async fn get_login_types_route(
_body: Ruma<get_login_types::v3::Request>, _body: Ruma<get_login_types::v3::Request>,
) -> Result<get_login_types::v3::Response> { ) -> Result<get_login_types::v3::Response> {
Ok(get_login_types::v3::Response::new(vec![ let mut flows = vec![
get_login_types::v3::LoginType::Password(Default::default()), get_login_types::v3::LoginType::Password(Default::default()),
get_login_types::v3::LoginType::ApplicationService(Default::default()), get_login_types::v3::LoginType::ApplicationService(Default::default()),
])) ];
if let v @ [_, ..] = &*services().sso.flows() {
let flow = get_login_types::v3::SsoLoginType {
identity_providers: v.to_owned(),
};
flows.push(get_login_types::v3::LoginType::Sso(flow));
}
Ok(get_login_types::v3::Response::new(flows))
} }
/// # `POST /_matrix/client/r0/login` /// # `POST /_matrix/client/r0/login`
@ -101,35 +111,64 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
user_id user_id
} }
login::v3::LoginInfo::Token(login::v3::Token { token }) => { login::v3::LoginInfo::Token(login::v3::Token { token }) => {
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { match (
let token = jsonwebtoken::decode::<Claims>( services().globals.jwt_decoding_key(),
token, &services().sso.providers().is_empty(),
jwt_decoding_key, ) {
&jsonwebtoken::Validation::default(), (_, false) => {
) let mut validation =
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?; jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256);
let username = token.claims.sub.to_lowercase(); validation.validate_nbf = false;
let user_id = validation.set_required_spec_claims(&["sub", "exp", "aud", "iss"]);
UserId::parse_with_server_name(username, services().globals.server_name())
let login_token = services()
.globals
.validate_claims::<LoginToken>(token, Some(validation))
.map_err(|_| { .map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") Error::BadRequest(ErrorKind::InvalidParam, "Invalid token.")
})?; })?;
if services().appservice.is_exclusive_user_id(&user_id).await { login_token.audience().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid token audience.")
})?
}
(Some(jwt_decoding_key), _) => {
let token = jsonwebtoken::decode::<Claims>(
token,
jwt_decoding_key,
&jsonwebtoken::Validation::default(),
)
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.")
})?;
let username = token.claims.sub.to_lowercase();
let user_id =
UserId::parse_with_server_name(username, services().globals.server_name())
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidUsername,
"Username is invalid.",
)
})?;
if services().appservice.is_exclusive_user_id(&user_id).await {
return Err(Error::BadRequest(
ErrorKind::Exclusive,
"User id reserved by appservice.",
));
}
user_id
}
(None, _) => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Exclusive, ErrorKind::Unknown,
"User id reserved by appservice.", "Token login is not supported (server has no jwt decoding key).",
)); ));
} }
user_id
} else {
return Err(Error::BadRequest(
ErrorKind::Unknown,
"Token login is not supported (server has no jwt decoding key).",
));
} }
} }
login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService { login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService {
identifier, identifier,
user, user,

View file

@ -1,16 +1,27 @@
use std::{ use std::{
collections::BTreeMap, borrow::Borrow,
collections::{BTreeMap, HashSet},
fmt, fmt,
hash::{Hash, Hasher},
net::{IpAddr, Ipv4Addr}, net::{IpAddr, Ipv4Addr},
}; };
use ruma::{OwnedServerName, RoomVersionId}; use figment::value::{Dict, Value};
use serde::{de::IgnoredAny, Deserialize}; use mas_oidc_client::types::{client_credentials::ClientCredentials, scope::Scope};
use ruma::{
api::client::session::get_login_types::v3::IdentityProvider, OwnedServerName, RoomVersionId,
};
use serde::{
de::{self, IgnoredAny},
Deserialize, Deserializer, Serialize,
};
use tracing::warn; use tracing::warn;
use url::Url; use url::Url;
mod proxy; mod proxy;
use crate::{Error, Result};
use self::proxy::ProxyConfig; use self::proxy::ProxyConfig;
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
@ -67,6 +78,8 @@ pub struct Config {
pub tracing_flame: bool, pub tracing_flame: bool,
#[serde(default)] #[serde(default)]
pub proxy: ProxyConfig, pub proxy: ProxyConfig,
#[serde(default, deserialize_with = "deserialize_providers")]
pub idps: HashSet<IdpConfig>,
pub jwt_secret: Option<String>, pub jwt_secret: Option<String>,
#[serde(default = "default_trusted_servers")] #[serde(default = "default_trusted_servers")]
pub trusted_servers: Vec<OwnedServerName>, pub trusted_servers: Vec<OwnedServerName>,
@ -101,6 +114,27 @@ pub struct WellKnownConfig {
pub server: Option<OwnedServerName>, pub server: Option<OwnedServerName>,
} }
#[derive(Clone, Debug, Deserialize)]
pub struct IdpConfig {
pub issuer: String,
#[serde(flatten)]
pub inner: IdentityProvider,
#[serde(deserialize_with = "deserialize_scopes")]
pub scopes: Scope,
pub client_id: String,
pub client_secret: String,
pub auth_method: String,
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Template {
pub localpart: Option<String>,
pub displayname: Option<String>,
pub avatar_url: Option<String>,
pub email: Option<String>,
}
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
impl Config { impl Config {
@ -244,6 +278,49 @@ impl fmt::Display for Config {
} }
} }
impl Borrow<str> for IdpConfig {
fn borrow(&self) -> &str {
&self.inner.id
}
}
impl PartialEq for IdpConfig {
fn eq(&self, other: &Self) -> bool {
self.inner.id == other.inner.id
}
}
impl Eq for IdpConfig {}
impl Hash for IdpConfig {
fn hash<H: Hasher>(&self, hasher: &mut H) {
self.inner.id.hash(hasher)
}
}
impl Into<ClientCredentials> for IdpConfig {
fn into(self) -> ClientCredentials {
let IdpConfig {
client_id,
client_secret,
auth_method,
..
} = self;
match &*auth_method {
"basic" => ClientCredentials::ClientSecretBasic {
client_id,
client_secret,
},
"post" => ClientCredentials::ClientSecretPost {
client_id,
client_secret,
},
_ => unimplemented!(),
}
}
}
fn false_fn() -> bool { fn false_fn() -> bool {
false false
} }
@ -312,3 +389,46 @@ fn default_openid_token_ttl() -> u64 {
pub fn default_default_room_version() -> RoomVersionId { pub fn default_default_room_version() -> RoomVersionId {
RoomVersionId::V10 RoomVersionId::V10
} }
fn deserialize_scopes<'de, D>(deserializer: D) -> Result<Scope, D::Error>
where
D: Deserializer<'de>,
{
let scopes = <Vec<String>>::deserialize(deserializer)?;
scopes.join(" ").parse().map_err(de::Error::custom)
}
fn deserialize_providers<'de, D>(deserializer: D) -> Result<HashSet<IdpConfig>, D::Error>
where
D: Deserializer<'de>,
{
let mut result = HashSet::new();
let dict = Dict::deserialize(deserializer)
.map(Dict::into_iter)
.map_err(de::Error::custom)?;
warn!(?dict);
for (name, value) in dict {
let tag = value.tag();
let Some(dict) = value.into_dict() else {
return Err(de::Error::custom(Error::bad_config(
"Invalid SSO configuration. ",
)));
};
let id = String::from("id");
let name = name.parse().map_err(de::Error::custom)?;
let dict = Some((id, name)).into_iter().chain(dict).collect();
result.insert(
Value::Dict(tag, dict)
.deserialize()
.map_err(de::Error::custom)?,
);
}
Ok(result)
}

View file

@ -8,6 +8,7 @@ mod media;
mod pusher; mod pusher;
mod rooms; mod rooms;
mod sending; mod sending;
mod sso;
mod transaction_ids; mod transaction_ids;
mod uiaa; mod uiaa;
mod users; mod users;

View file

@ -119,6 +119,10 @@ impl service::users::Data for KeyValueDatabase {
} }
} }
fn set_placeholder_password(&self, user_id: &UserId) -> Result<()> {
self.userid_password.insert(user_id.as_bytes(), b"0xff")
}
/// Returns the displayname of a user on this homeserver. /// Returns the displayname of a user on this homeserver.
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname self.userid_displayname

View file

@ -63,6 +63,9 @@ pub struct KeyValueDatabase {
pub(super) todeviceid_events: Arc<dyn KvTree>, // ToDeviceId = UserId + DeviceId + Count pub(super) todeviceid_events: Arc<dyn KvTree>, // ToDeviceId = UserId + DeviceId + Count
pub(super) userid_providersubjectid: Arc<dyn KvTree>,
pub(super) providersubjectid_userid: Arc<dyn KvTree>,
//pub uiaa: uiaa::Uiaa, //pub uiaa: uiaa::Uiaa,
pub(super) userdevicesessionid_uiaainfo: Arc<dyn KvTree>, // User-interactive authentication pub(super) userdevicesessionid_uiaainfo: Arc<dyn KvTree>, // User-interactive authentication
pub(super) userdevicesessionid_uiaarequest: pub(super) userdevicesessionid_uiaarequest:
@ -297,6 +300,9 @@ impl KeyValueDatabase {
userfilterid_filter: builder.open_tree("userfilterid_filter")?, userfilterid_filter: builder.open_tree("userfilterid_filter")?,
todeviceid_events: builder.open_tree("todeviceid_events")?, todeviceid_events: builder.open_tree("todeviceid_events")?,
userid_providersubjectid: builder.open_tree("userid_providersubjectid")?,
providersubjectid_userid: builder.open_tree("providersubjectid_userid")?,
userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
@ -971,6 +977,8 @@ impl KeyValueDatabase {
services().admin.start_handler(); services().admin.start_handler();
services().sso.start_handler().await?;
// Set emergency access for the conduit user // Set emergency access for the conduit user
match set_emergency_access() { match set_emergency_access() {
Ok(pwd_set) => { Ok(pwd_set) => {

View file

@ -9,7 +9,10 @@ use axum::{
Router, Router,
}; };
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
use conduit::api::{client_server, server_server}; use conduit::api::{
client_server::{self, SSO_CALLBACK_PATH},
server_server,
};
use figment::{ use figment::{
providers::{Env, Format, Toml}, providers::{Env, Format, Toml},
Figment, Figment,
@ -276,6 +279,11 @@ fn routes(config: &Config) -> Router {
.ruma_route(client_server::third_party_route) .ruma_route(client_server::third_party_route)
.ruma_route(client_server::request_3pid_management_token_via_email_route) .ruma_route(client_server::request_3pid_management_token_via_email_route)
.ruma_route(client_server::request_3pid_management_token_via_msisdn_route) .ruma_route(client_server::request_3pid_management_token_via_msisdn_route)
.ruma_route(client_server::get_sso_redirect_route)
.ruma_route(client_server::get_sso_redirect_with_provider_route)
// The specification will likely never introduce any endpoint for handling authorization callbacks.
// As a workaround, we use custom path that redirects the user to the default login handler.
.route(SSO_CALLBACK_PATH, get(client_server::sso_login_route))
.ruma_route(client_server::get_capabilities_route) .ruma_route(client_server::get_capabilities_route)
.ruma_route(client_server::get_pushrules_all_route) .ruma_route(client_server::get_pushrules_all_route)
.ruma_route(client_server::set_pushrule_route) .ruma_route(client_server::set_pushrule_route)

View file

@ -1,9 +1,10 @@
mod data; mod data;
pub use data::{Data, SigningKeys}; pub use data::{Data, SigningKeys};
use ruma::{ use ruma::{
serde::Base64, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId, OwnedRoomAliasId, serde::Base64, signatures::KeyPair, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId,
}; };
use serde::{de::DeserializeOwned, Serialize};
use crate::api::server_server::DestinationResponse; use crate::api::server_server::DestinationResponse;
@ -37,6 +38,9 @@ use tracing::{error, info};
use base64::{engine::general_purpose, Engine as _}; use base64::{engine::general_purpose, Engine as _};
// https://github.com/rust-lang/rust/issues/104699
const PROBLEMATIC_CONST: &[u8] = b"0xCAFEBABE";
type WellKnownMap = HashMap<OwnedServerName, DestinationResponse>; type WellKnownMap = HashMap<OwnedServerName, DestinationResponse>;
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>; type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
@ -505,6 +509,36 @@ impl Service {
self.config.well_known_client() self.config.well_known_client()
} }
pub fn sign_claims<S: Serialize>(&self, claims: &S) -> String {
let key = jsonwebtoken::EncodingKey::from_secret(
self.keypair().sign(PROBLEMATIC_CONST).as_bytes(),
);
jsonwebtoken::encode(&jsonwebtoken::Header::default(), claims, &key)
.expect("signing JWTs always works")
}
/// Decode and validate a macaroon with this server's macaroon key.
pub fn validate_claims<T: DeserializeOwned>(
&self,
token: &str,
validation_data: Option<jsonwebtoken::Validation>,
) -> jsonwebtoken::errors::Result<T> {
let key = jsonwebtoken::DecodingKey::from_secret(
self.keypair().sign(PROBLEMATIC_CONST).as_bytes(),
);
let mut v = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256);
// these validations are redundant as all JWTs are stored in cookies
v.validate_exp = false;
v.validate_nbf = false;
v.required_spec_claims = Default::default();
jsonwebtoken::decode::<T>(token, &key, &validation_data.unwrap_or(v))
.map(|data| data.claims)
}
pub fn shutdown(&self) { pub fn shutdown(&self) {
self.shutdown.store(true, atomic::Ordering::Relaxed); self.shutdown.store(true, atomic::Ordering::Relaxed);
// On shutdown // On shutdown

View file

@ -19,6 +19,7 @@ pub mod pdu;
pub mod pusher; pub mod pusher;
pub mod rooms; pub mod rooms;
pub mod sending; pub mod sending;
pub mod sso;
pub mod transaction_ids; pub mod transaction_ids;
pub mod uiaa; pub mod uiaa;
pub mod users; pub mod users;
@ -35,6 +36,7 @@ pub struct Services {
pub globals: globals::Service, pub globals: globals::Service,
pub key_backups: key_backups::Service, pub key_backups: key_backups::Service,
pub media: media::Service, pub media: media::Service,
pub sso: Arc<sso::Service>,
pub sending: Arc<sending::Service>, pub sending: Arc<sending::Service>,
} }
@ -51,6 +53,7 @@ impl Services {
+ key_backups::Data + key_backups::Data
+ media::Data + media::Data
+ sending::Data + sending::Data
+ sso::Data
+ 'static, + 'static,
>( >(
db: &'static D, db: &'static D,
@ -120,6 +123,7 @@ impl Services {
key_backups: key_backups::Service { db }, key_backups: key_backups::Service { db },
media: media::Service { db }, media: media::Service { db },
sending: sending::Service::build(db, &config), sending: sending::Service::build(db, &config),
sso: sso::Service::build(db),
globals: globals::Service::load(db, config)?, globals: globals::Service::load(db, config)?,
}) })

View file

@ -110,6 +110,9 @@ impl Service {
AuthData::Dummy(_) => { AuthData::Dummy(_) => {
uiaainfo.completed.push(AuthType::Dummy); uiaainfo.completed.push(AuthType::Dummy);
} }
AuthData::FallbackAcknowledgement(fallback) => {
todo!()
}
k => error!("type not supported: {:?}", k), k => error!("type not supported: {:?}", k),
} }

View file

@ -217,4 +217,6 @@ pub trait Data: Send + Sync {
/// Find out which user an OpenID access token belongs to. /// Find out which user an OpenID access token belongs to.
fn find_from_openid_token(&self, token: &str) -> Result<Option<OwnedUserId>>; fn find_from_openid_token(&self, token: &str) -> Result<Option<OwnedUserId>>;
fn set_placeholder_password(&self, user_id: &UserId) -> Result<()>;
} }

View file

@ -602,6 +602,10 @@ impl Service {
pub fn find_from_openid_token(&self, token: &str) -> Result<Option<OwnedUserId>> { pub fn find_from_openid_token(&self, token: &str) -> Result<Option<OwnedUserId>> {
self.db.find_from_openid_token(token) self.db.find_from_openid_token(token)
} }
pub fn set_placeholder_password(&self, user_id: &UserId) -> Result<()> {
self.db.set_placeholder_password(user_id)
}
} }
/// Ensure that a user only sees signatures from themselves and the target user /// Ensure that a user only sees signatures from themselves and the target user

View file

@ -175,6 +175,22 @@ impl Error {
} }
} }
impl From<mas_oidc_client::types::errors::ClientError> for Error {
fn from(e: mas_oidc_client::types::errors::ClientError) -> Self {
error!(
"Failed to complete authorization callback: {} {}",
e.error,
e.error_description.as_deref().unwrap_or_default()
);
// TODO: error conversion
Self::BadRequest(
ErrorKind::Unknown,
"Failed to complete authorization callback.",
)
}
}
#[cfg(feature = "persy")] #[cfg(feature = "persy")]
impl<T: Into<PersyError>> From<persy::PE<T>> for Error { impl<T: Into<PersyError>> From<persy::PE<T>> for Error {
fn from(err: persy::PE<T>) -> Self { fn from(err: persy::PE<T>) -> Self {

View file

@ -1,6 +1,7 @@
pub mod error; pub mod error;
use argon2::{Config, Variant}; use argon2::{Config, Variant};
use axum_extra::extract::cookie::{Cookie, SameSite};
use cmp::Ordering; use cmp::Ordering;
use rand::prelude::*; use rand::prelude::*;
use ring::digest; use ring::digest;
@ -8,7 +9,7 @@ use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonO
use std::{ use std::{
cmp, fmt, cmp, fmt,
str::FromStr, str::FromStr,
time::{SystemTime, UNIX_EPOCH}, time::{Duration, SystemTime, UNIX_EPOCH},
}; };
pub fn millis_since_unix_epoch() -> u64 { pub fn millis_since_unix_epoch() -> u64 {
@ -142,6 +143,29 @@ pub fn deserialize_from_str<
deserializer.deserialize_str(Visitor(std::marker::PhantomData)) deserializer.deserialize_str(Visitor(std::marker::PhantomData))
} }
pub fn build_cookie<'c>(
name: &'c str,
value: &'c str,
path: &'c str,
max_age: Option<u64>,
) -> Cookie<'c> {
let mut cookie = Cookie::new(name, value);
cookie.set_path(path);
cookie.set_secure(true);
cookie.set_http_only(true);
cookie.set_same_site(SameSite::None);
cookie.set_max_age(
max_age
.map(Duration::from_secs)
.map(TryInto::try_into)
.transpose()
.expect("time overflow"),
);
cookie
}
// Copied from librustdoc: // Copied from librustdoc:
// https://github.com/rust-lang/rust/blob/cbaeec14f90b59a91a6b0f17fc046c66fa811892/src/librustdoc/html/escape.rs // https://github.com/rust-lang/rust/blob/cbaeec14f90b59a91a6b0f17fc046c66fa811892/src/librustdoc/html/escape.rs