1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00
This commit is contained in:
avdb13 2024-07-15 06:08:25 +02:00
parent 5af171e7ee
commit 139588b64c
4 changed files with 131 additions and 21 deletions

View file

@ -41,10 +41,23 @@ tower = { version = "0.4.13", features = ["util"] }
tower-http = { version = "0.5", features = [ tower-http = { version = "0.5", features = [
"add-extension", "add-extension",
"cors", "cors",
"follow-redirect",
"map-request-body",
"sensitive-headers", "sensitive-headers",
"set-header",
"timeout",
"trace", "trace",
"util", "util",
] } ] }
# tower-http = { version = "0.5", features = [
# "add-extension",
# "cors",
# "decompression-full",
# "sensitive-headers",
# "set-header",
# "trace",
# "util",
# ] }
tower-service = "0.3" tower-service = "0.3"
# Async runtime and utilities # Async runtime and utilities
@ -140,8 +153,9 @@ figment = { version = "0.10.8", features = ["env", "toml"] }
# Validating urls in config # Validating urls in config
url = { version = "2", features = ["serde"] } url = { version = "2", features = ["serde"] }
mas-oidc-client = { version = "0.9", default-features = false, features = ["hyper"] }
# HTML # HTML
mas-oidc-client = { git = "https://github.com/matrix-org/matrix-authentication-service", default-features = false }
mas-http = { git = "https://github.com/matrix-org/matrix-authentication-service", features = ["client"] }
maud = { version = "0.26.0", default-features = false, features = ["axum"] } maud = { version = "0.26.0", default-features = false, features = ["axum"] }
async-trait = "0.1.68" async-trait = "0.1.68"

View file

@ -15,7 +15,7 @@ use axum_extra::{
headers::{self}, headers::{self},
TypedHeader, TypedHeader,
}; };
use http::header; use http::header::{self};
use mas_oidc_client::{ use mas_oidc_client::{
requests::{ requests::{
authorization_code::{self, AuthorizationRequestData}, authorization_code::{self, AuthorizationRequestData},
@ -42,7 +42,7 @@ use serde_json::Value;
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use url::Url; use url::Url;
pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback"; pub const CALLBACK_PATH: &str = "/_matrix/client/unstable/conduit/callback";
/// # `GET /_matrix/client/v3/login/sso/redirect` /// # `GET /_matrix/client/v3/login/sso/redirect`
/// ///
@ -155,10 +155,10 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result<axum::res
let AuthorizationResponse { let AuthorizationResponse {
code, code,
access_token, access_token: _,
token_type, token_type: _,
id_token, id_token: _,
expires_in, expires_in: _,
} = serde_html_form::from_str(query).map_err(|_| { } = serde_html_form::from_str(query).map_err(|_| {
serde_html_form::from_str(query) serde_html_form::from_str(query)
.map(ClientError::into) .map(ClientError::into)
@ -234,10 +234,10 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result<axum::res
let ( let (
AccessTokenResponse { AccessTokenResponse {
access_token, access_token,
refresh_token, refresh_token: _,
token_type, token_type: _,
expires_in, expires_in: _,
scope, scope: _,
.. ..
}, },
Some(id_token), Some(id_token),
@ -257,9 +257,9 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result<axum::res
unreachable!("ID token should never be empty") unreachable!("ID token should never be empty")
}; };
let mut userinfo = HashMap::default(); let mut _userinfo = HashMap::default();
if let Some(endpoint) = &provider.metadata.userinfo_endpoint { if let Some(endpoint) = provider.metadata.userinfo_endpoint.as_ref() {
userinfo = userinfo::fetch_userinfo( _userinfo = userinfo::fetch_userinfo(
services().sso.service(), services().sso.service(),
endpoint, endpoint,
&access_token, &access_token,
@ -268,7 +268,7 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result<axum::res
) )
.await .await
.map_err(|_| Error::bad_config("Failed to fetch claims for userinfo endpoint."))?; .map_err(|_| Error::bad_config("Failed to fetch claims for userinfo endpoint."))?;
}; }
let (_, id_token) = id_token.into_parts(); let (_, id_token) = id_token.into_parts();
@ -397,3 +397,85 @@ pub async fn handle_callback_route(req: axum::extract::Request) -> axum::respons
Err(e) => e.into_response(), Err(e) => e.into_response(),
} }
} }
mod sso_callback {
use axum_extra::headers::{self, HeaderMapExt};
use http::Method;
use mas_oidc_client::types::requests::AuthorizationResponse;
use ruma::{
api::{
client::Error,
error::{FromHttpRequestError, HeaderDeserializationError},
IncomingRequest, Metadata, OutgoingResponse,
},
metadata,
};
use crate::service::sso::SSO_SESSION_COOKIE;
pub const METADATA: Metadata = metadata! {
method: GET,
rate_limited: false,
authentication: None,
history: {
1.0 => "/_matrix/client/unstable/conduit/callback",
}
};
pub struct Request {
response: AuthorizationResponse,
cookie: String,
}
pub struct Response {}
impl IncomingRequest for Request {
type EndpointError = Error;
type OutgoingResponse = Response;
const METADATA: Metadata = METADATA;
fn try_from_http_request<B, S>(
req: http::Request<B>,
_path_args: &[S],
) -> Result<Self, FromHttpRequestError>
where
B: AsRef<[u8]>,
S: AsRef<str>,
{
if !(req.method() == METADATA.method
|| req.method() == Method::HEAD && METADATA.method == Method::GET)
{
return Err(FromHttpRequestError::MethodMismatch {
expected: METADATA.method,
received: req.method().clone(),
});
}
let response: AuthorizationResponse =
serde_html_form::from_str(req.uri().query().unwrap_or(""))?;
let Some(cookie) = req
.headers()
.typed_get()
.and_then(|cookie: headers::Cookie| {
cookie.get(SSO_SESSION_COOKIE).map(str::to_owned)
})
else {
return Err(HeaderDeserializationError::MissingHeader(
"Cookie".to_owned(),
))?;
};
Ok(Self { response, cookie })
}
}
impl OutgoingResponse for Response {
fn try_into_http_response<T: Default + bytes::BufMut>(
self,
) -> Result<http::Response<T>, ruma::api::error::IntoHttpError> {
todo!()
}
}
}

View file

@ -12,13 +12,14 @@ use crate::{
}; };
use futures_util::future::{self}; use futures_util::future::{self};
use mas_oidc_client::{ use mas_oidc_client::{
http_service::{hyper, HttpService}, http_service::HttpService,
requests::{authorization_code::AuthorizationValidationData, discovery}, requests::{authorization_code::AuthorizationValidationData, discovery},
types::oidc::VerifiedProviderMetadata, types::oidc::VerifiedProviderMetadata,
}; };
use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId}; use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::OnceCell; use tokio::sync::OnceCell;
use tower::BoxError;
use tracing::error; use tracing::error;
use url::Url; use url::Url;
@ -40,9 +41,21 @@ pub struct Service {
impl Service { impl Service {
pub fn build(db: &'static dyn Data) -> Result<Arc<Self>> { pub fn build(db: &'static dyn Data) -> Result<Arc<Self>> {
let client = tower::ServiceBuilder::new()
.map_err(BoxError::from)
.layer(mas_http::BytesToBodyRequestLayer)
.layer(mas_http::BodyToBytesResponseLayer)
// .override_request_header(http::header::USER_AGENT, "conduit".to_owned())
// .concurrency_limit(10)
// .follow_redirects()
// .layer(tower_http::timeout::TimeoutLayer::new(
// std::time::Duration::from_secs(10),
// ))
.service(mas_http::make_untraced_client());
Ok(Arc::new(Self { Ok(Arc::new(Self {
db, db,
service: HttpService::new(hyper::hyper_service()), service: HttpService::new(client),
providers: OnceCell::new(), providers: OnceCell::new(),
})) }))
} }
@ -159,7 +172,11 @@ pub struct ValidationData {
impl ValidationData { impl ValidationData {
pub fn new(provider: String, redirect_url: String, inner: AuthorizationValidationData) -> Self { pub fn new(provider: String, redirect_url: String, inner: AuthorizationValidationData) -> Self {
Self { provider, redirect_url, inner } Self {
provider,
redirect_url,
inner,
}
} }
} }

View file

@ -110,9 +110,6 @@ impl Service {
AuthData::Dummy(_) => { AuthData::Dummy(_) => {
uiaainfo.completed.push(AuthType::Dummy); uiaainfo.completed.push(AuthType::Dummy);
} }
AuthData::FallbackAcknowledgement(fallback) => {
todo!()
}
k => error!("type not supported: {:?}", k), k => error!("type not supported: {:?}", k),
} }