From d6abf5472b4252c31a2a78e4632eef0a0b98db90 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 26 Jun 2024 09:16:54 +0100 Subject: [PATCH] more rate limit targets --- Cargo.lock | 40 +++++++++++++ src/api/mod.rs | 1 - src/api/ruma_wrapper/axum.rs | 99 +++++++++++++++++++------------- src/service/rate_limiting/mod.rs | 50 ++++++++-------- 4 files changed, 123 insertions(+), 67 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 30d951a7..e587f668 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -189,6 +189,7 @@ dependencies = [ "serde_path_to_error", "serde_urlencoded", "sync_wrapper 1.0.1", + "tokio", "tower", "tower-layer", "tower-service", @@ -496,6 +497,7 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "dashmap", "directories", "figment", "futures-util", @@ -515,6 +517,7 @@ dependencies = [ "opentelemetry_sdk", "parking_lot", "persy", + "quanta", "rand", "regex", "reqwest", @@ -664,6 +667,19 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.6.0" @@ -2053,6 +2069,21 @@ dependencies = [ "syn", ] +[[package]] +name = "quanta" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -2098,6 +2129,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "raw-cpuid" +version = "11.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd" +dependencies = [ + "bitflags 2.5.0", +] + [[package]] name = "redox_syscall" version = "0.5.1" diff --git a/src/api/mod.rs b/src/api/mod.rs index df951e58..0d2cd664 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,5 +1,4 @@ pub mod appservice_server; pub mod client_server; -pub mod rate_limiting; pub mod ruma_wrapper; pub mod server_server; diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index ac97b391..c7bc5879 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -1,7 +1,8 @@ use std::{ collections::BTreeMap, iter::FromIterator, - str::{self}, + net::IpAddr, + str::{self, FromStr}, }; use axum::{ @@ -102,44 +103,6 @@ where Token::None }; - // doesn't work when Conduit is behind proxy - // let remote_addr: ConnectInfo = parts.extract().await?; - - let target = match &token { - Token::User((user_id, _)) => Some(Target::User(user_id.clone())), - Token::None => { - let header = parts - .headers - .get("x-forwarded-for") - .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; - - let s = header - .to_str() - .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; - Some( - s.parse() - .map(Target::Ip) - .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")), - ) - .transpose()? - } - _ => None, - }; - - if let Err(retry_after_ms) = target.map_or(Ok(()), |t| { - let key = (t, (&metadata).into()); - - services() - .rate_limiting - .update_or_reject(&key) - .map_err(Some) - }) { - return Err(Error::BadRequest( - ErrorKind::LimitExceeded { retry_after_ms }, - "Rate limit exceeded.", - )); - } - let mut json_body = serde_json::from_slice::(&body).ok(); let (sender_user, sender_device, sender_servername, appservice_info) = @@ -350,8 +313,64 @@ where } }; + // doesn't work when Conduit is behind proxy + // let remote_addr: ConnectInfo = parts.extract().await?; + + let headers = parts.headers; + + let target = if let Some(server) = sender_servername.clone() { + Target::Server(server) + + // Token::User((user_id, _)) => Some(Target::User(user_id.clone())), + // Token::None => { + // let header = parts + // .headers + // .get("x-forwarded-for") + // .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; + + // let s = header + // .to_str() + // .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; + // Some( + // s.parse() + // .map(Target::Ip) + // .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")), + // ) + // .transpose()? + // } + // _ => None, + } else if let Some(appservice) = appservice_info.clone() { + Target::Appservice(appservice.registration.id) + } else if let Some(user) = sender_user.clone() { + Target::User(user) + } else { + let ip = headers + .get("X-Forwarded-For") + .and_then(|header| header.to_str().ok()) + .map(|header| header.split_once(",").map(|(ip, _)| ip).unwrap_or(header)) + .and_then(|ip| IpAddr::from_str(ip).ok()); + + if let Some(ip) = ip { + Target::Ip(ip) + } else { + Target::None + } + }; + + if let Err(retry_after_ms) = { + services() + .rate_limiting + .update_or_reject(target, metadata) + .map_err(Some) + } { + return Err(Error::BadRequest( + ErrorKind::LimitExceeded { retry_after_ms }, + "Rate limit exceeded.", + )); + } + let mut http_request = Request::builder().uri(parts.uri).method(parts.method); - *http_request.headers_mut().unwrap() = parts.headers; + *http_request.headers_mut().unwrap() = headers; if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { let user_id = sender_user.clone().unwrap_or_else(|| { diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index 233513b0..be24f7fc 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -1,7 +1,6 @@ use std::{ hash::Hash, net::IpAddr, - num::NonZeroU64, sync::atomic::{AtomicU64, Ordering}, time::Duration, }; @@ -13,16 +12,13 @@ use ruma::{ client::{account::register, session::login}, IncomingRequest, Metadata, }, - OwnedUserId, + OwnedServerName, OwnedUserId, }; -use crate::{ - config::{Limitation, Restriction}, - services, Result, -}; +use crate::{config::Restriction, services, Result}; -impl From<&Metadata> for Restriction { - fn from(metadata: &Metadata) -> Self { +impl From for Restriction { + fn from(metadata: Metadata) -> Self { [ (register::v3::Request::METADATA, Restriction::Registration), (login::v3::Request::METADATA, Restriction::Login), @@ -48,7 +44,10 @@ pub struct Service { #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub enum Target { User(OwnedUserId), + Server(OwnedServerName), + Appservice(String), Ip(IpAddr), + None, } impl Service { @@ -59,33 +58,32 @@ impl Service { } } - pub fn update_or_reject(&self, key: &(Target, Restriction)) -> Result<(), Duration> { + pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> { + let restriction = metadata.into(); + let arrival = self.clock.delta_as_nanos(0, self.clock.raw()); let config = &services().globals.config.rate_limiting; - let Some(limit) = config - .get(&key.1) - .map(ToOwned::to_owned) else { - return Ok(()); - }; - // .unwrap_or(Limitation { - // per_minute: NonZeroU64::new(1).unwrap(), - // burst_capacity: NonZeroU64::new(1).unwrap(), - // weight: NonZeroU64::new(1).unwrap(), - // }); + let Some(limit) = config.get(&restriction).map(ToOwned::to_owned) else { + return Ok(()); + }; + // .unwrap_or(Limitation { + // per_minute: NonZeroU64::new(1).unwrap(), + // burst_capacity: NonZeroU64::new(1).unwrap(), + // weight: NonZeroU64::new(1).unwrap(), + // }); + + let key = (target, restriction); tracing::info!(?limit); - let increment = u64::try_from(Duration::from_secs(60).as_nanos()) - .expect("1_000_000_000 to be smaller than u64::MAX") - / limit.per_minute.get() - * limit.weight.get(); + let increment = 1_000_000_000u64 / limit.per_minute.get() * limit.weight.get(); tracing::info!(?increment); let mut prev_expectation = self .store - .get(key) + .get(&key) .as_deref() .map(|n| n.load(Ordering::Acquire)) .unwrap_or_else(|| arrival + increment); @@ -108,9 +106,9 @@ impl Service { tracing::info!(?decision); - while let Ok(next_expectation) = decision { - let entry = self.store.entry(key.clone()); + let entry = self.store.entry(key); + while let Ok(next_expectation) = decision { match entry.or_default().compare_exchange_weak( prev_expectation, next_expectation,