From 139588b64c137efee6b398f39551cc74c637a651 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 15 Jul 2024 06:08:25 +0200 Subject: [PATCH] nice --- Cargo.toml | 16 ++++- src/api/client_server/sso.rs | 110 ++++++++++++++++++++++++++++++----- src/service/sso/mod.rs | 23 +++++++- src/service/uiaa/mod.rs | 3 - 4 files changed, 131 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 044eefea..c52ed191 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,10 +41,23 @@ tower = { version = "0.4.13", features = ["util"] } tower-http = { version = "0.5", features = [ "add-extension", "cors", + "follow-redirect", + "map-request-body", "sensitive-headers", + "set-header", + "timeout", "trace", "util", ] } +# tower-http = { version = "0.5", features = [ +# "add-extension", +# "cors", +# "decompression-full", +# "sensitive-headers", +# "set-header", +# "trace", +# "util", +# ] } tower-service = "0.3" # Async runtime and utilities @@ -140,8 +153,9 @@ figment = { version = "0.10.8", features = ["env", "toml"] } # Validating urls in config url = { version = "2", features = ["serde"] } -mas-oidc-client = { version = "0.9", default-features = false, features = ["hyper"] } # HTML +mas-oidc-client = { git = "https://github.com/matrix-org/matrix-authentication-service", default-features = false } +mas-http = { git = "https://github.com/matrix-org/matrix-authentication-service", features = ["client"] } maud = { version = "0.26.0", default-features = false, features = ["axum"] } async-trait = "0.1.68" diff --git a/src/api/client_server/sso.rs b/src/api/client_server/sso.rs index bb7719bc..e0b7f9a2 100644 --- a/src/api/client_server/sso.rs +++ b/src/api/client_server/sso.rs @@ -15,7 +15,7 @@ use axum_extra::{ headers::{self}, TypedHeader, }; -use http::header; +use http::header::{self}; use mas_oidc_client::{ requests::{ authorization_code::{self, AuthorizationRequestData}, @@ -42,7 +42,7 @@ use serde_json::Value; use tracing::{error, info, warn}; use url::Url; -pub const CALLBACK_PATH: &str = "_matrix/client/unstable/sso/callback"; +pub const CALLBACK_PATH: &str = "/_matrix/client/unstable/conduit/callback"; /// # `GET /_matrix/client/v3/login/sso/redirect` /// @@ -155,10 +155,10 @@ async fn handle_callback_helper(req: axum::extract::Request) -> Result Result Result Result axum::respons Err(e) => e.into_response(), } } + +mod sso_callback { + use axum_extra::headers::{self, HeaderMapExt}; + use http::Method; + use mas_oidc_client::types::requests::AuthorizationResponse; + use ruma::{ + api::{ + client::Error, + error::{FromHttpRequestError, HeaderDeserializationError}, + IncomingRequest, Metadata, OutgoingResponse, + }, + metadata, + }; + + use crate::service::sso::SSO_SESSION_COOKIE; + + pub const METADATA: Metadata = metadata! { + method: GET, + rate_limited: false, + authentication: None, + history: { + 1.0 => "/_matrix/client/unstable/conduit/callback", + } + }; + + pub struct Request { + response: AuthorizationResponse, + cookie: String, + } + + pub struct Response {} + + impl IncomingRequest for Request { + type EndpointError = Error; + type OutgoingResponse = Response; + + const METADATA: Metadata = METADATA; + + fn try_from_http_request( + req: http::Request, + _path_args: &[S], + ) -> Result + where + B: AsRef<[u8]>, + S: AsRef, + { + if !(req.method() == METADATA.method + || req.method() == Method::HEAD && METADATA.method == Method::GET) + { + return Err(FromHttpRequestError::MethodMismatch { + expected: METADATA.method, + received: req.method().clone(), + }); + } + + let response: AuthorizationResponse = + serde_html_form::from_str(req.uri().query().unwrap_or(""))?; + + let Some(cookie) = req + .headers() + .typed_get() + .and_then(|cookie: headers::Cookie| { + cookie.get(SSO_SESSION_COOKIE).map(str::to_owned) + }) + else { + return Err(HeaderDeserializationError::MissingHeader( + "Cookie".to_owned(), + ))?; + }; + + Ok(Self { response, cookie }) + } + } + + impl OutgoingResponse for Response { + fn try_into_http_response( + self, + ) -> Result, ruma::api::error::IntoHttpError> { + todo!() + } + } +} diff --git a/src/service/sso/mod.rs b/src/service/sso/mod.rs index 242d92cc..6ffef5ab 100644 --- a/src/service/sso/mod.rs +++ b/src/service/sso/mod.rs @@ -12,13 +12,14 @@ use crate::{ }; use futures_util::future::{self}; use mas_oidc_client::{ - http_service::{hyper, HttpService}, + http_service::HttpService, requests::{authorization_code::AuthorizationValidationData, discovery}, types::oidc::VerifiedProviderMetadata, }; use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId}; use serde::{Deserialize, Serialize}; use tokio::sync::OnceCell; +use tower::BoxError; use tracing::error; use url::Url; @@ -40,9 +41,21 @@ pub struct Service { impl Service { pub fn build(db: &'static dyn Data) -> Result> { + let client = tower::ServiceBuilder::new() + .map_err(BoxError::from) + .layer(mas_http::BytesToBodyRequestLayer) + .layer(mas_http::BodyToBytesResponseLayer) + // .override_request_header(http::header::USER_AGENT, "conduit".to_owned()) + // .concurrency_limit(10) + // .follow_redirects() + // .layer(tower_http::timeout::TimeoutLayer::new( + // std::time::Duration::from_secs(10), + // )) + .service(mas_http::make_untraced_client()); + Ok(Arc::new(Self { db, - service: HttpService::new(hyper::hyper_service()), + service: HttpService::new(client), providers: OnceCell::new(), })) } @@ -159,7 +172,11 @@ pub struct ValidationData { impl ValidationData { pub fn new(provider: String, redirect_url: String, inner: AuthorizationValidationData) -> Self { - Self { provider, redirect_url, inner } + Self { + provider, + redirect_url, + inner, + } } } diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 677d49f0..696be958 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -110,9 +110,6 @@ impl Service { AuthData::Dummy(_) => { uiaainfo.completed.push(AuthType::Dummy); } - AuthData::FallbackAcknowledgement(fallback) => { - todo!() - } k => error!("type not supported: {:?}", k), }