From 02cea0bb93ae91ebb1a6a3fcec41008b01fac732 Mon Sep 17 00:00:00 2001 From: mikoto Date: Mon, 10 Jun 2024 07:55:59 +0200 Subject: [PATCH 01/14] PoC --- Cargo.toml | 3 + src/api/mod.rs | 1 + src/api/ruma_wrapper/axum.rs | 49 +++++++++++- src/config/mod.rs | 39 +++++++++ src/main.rs | 2 +- src/service/mod.rs | 3 + src/service/rate_limiting/mod.rs | 133 +++++++++++++++++++++++++++++++ 7 files changed, 227 insertions(+), 3 deletions(-) create mode 100644 src/service/rate_limiting/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 67128f07..3d6d2594 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ axum = { version = "0.7", default-features = false, features = [ "http2", "json", "matched-path", + "tokio", ], optional = true } axum-extra = { version = "0.9", features = ["typed-header"] } axum-server = { version = "0.6", features = ["tls-rustls"] } @@ -146,6 +147,8 @@ tikv-jemallocator = { version = "0.5.0", features = [ ], optional = true } sd-notify = { version = "0.4.1", optional = true } +dashmap = "5.5.3" +quanta = "0.12.3" # Used for matrix spec type definitions and helpers [dependencies.ruma] diff --git a/src/api/mod.rs b/src/api/mod.rs index 0d2cd664..df951e58 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,4 +1,5 @@ pub mod appservice_server; pub mod client_server; +pub mod rate_limiting; pub mod ruma_wrapper; pub mod server_server; diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 2c5da21b..ac97b391 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -1,4 +1,8 @@ -use std::{collections::BTreeMap, iter::FromIterator, str}; +use std::{ + collections::BTreeMap, + iter::FromIterator, + str::{self}, +}; use axum::{ async_trait, @@ -23,7 +27,10 @@ use serde::Deserialize; use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; -use crate::{service::appservice::RegistrationInfo, services, Error, Result}; +use crate::{ + service::{appservice::RegistrationInfo, rate_limiting::Target}, + services, Error, Result, +}; enum Token { Appservice(Box), @@ -95,6 +102,44 @@ where Token::None }; + // doesn't work when Conduit is behind proxy + // let remote_addr: ConnectInfo = parts.extract().await?; + + let target = match &token { + Token::User((user_id, _)) => Some(Target::User(user_id.clone())), + Token::None => { + let header = parts + .headers + .get("x-forwarded-for") + .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; + + let s = header + .to_str() + .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; + Some( + s.parse() + .map(Target::Ip) + .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")), + ) + .transpose()? + } + _ => None, + }; + + if let Err(retry_after_ms) = target.map_or(Ok(()), |t| { + let key = (t, (&metadata).into()); + + services() + .rate_limiting + .update_or_reject(&key) + .map_err(Some) + }) { + return Err(Error::BadRequest( + ErrorKind::LimitExceeded { retry_after_ms }, + "Rate limit exceeded.", + )); + } + let mut json_body = serde_json::from_slice::(&body).ok(); let (sender_user, sender_device, sender_servername, appservice_info) = diff --git a/src/config/mod.rs b/src/config/mod.rs index 378ab929..81051252 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,6 +2,7 @@ use std::{ collections::BTreeMap, fmt, net::{IpAddr, Ipv4Addr}, + num::NonZeroU64, }; use ruma::{OwnedServerName, RoomVersionId}; @@ -82,6 +83,8 @@ pub struct Config { pub turn_secret: String, #[serde(default = "default_turn_ttl")] pub turn_ttl: u64, + #[serde(default = "default_rate_limit")] + pub rate_limiting: BTreeMap, pub emergency_password: Option, @@ -101,6 +104,27 @@ pub struct WellKnownConfig { pub server: Option, } +#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[serde(rename_all = "lowercase")] +pub enum Restriction { + Registration, + Login, + + #[default] + #[serde(rename = "")] + CatchAll, +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct Limitation { + #[serde(default = "default_non_zero")] + pub per_minute: NonZeroU64, + #[serde(default = "default_non_zero")] + pub burst_capacity: NonZeroU64, + #[serde(default = "default_non_zero")] + pub weight: NonZeroU64, +} + const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; impl Config { @@ -308,6 +332,21 @@ fn default_openid_token_ttl() -> u64 { 60 * 60 } +fn default_non_zero() -> NonZeroU64 { + NonZeroU64::new(1).unwrap() +} + +pub fn default_rate_limit() -> BTreeMap { + BTreeMap::from_iter([( + Restriction::default(), + Limitation { + per_minute: default_non_zero(), + burst_capacity: default_non_zero(), + weight: default_non_zero(), + }, + )]) +} + // I know, it's a great name pub fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 diff --git a/src/main.rs b/src/main.rs index 8d242c53..49e8962b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -200,7 +200,7 @@ async fn run_server() -> io::Result<()> { .expect("failed to convert max request size"), )); - let app = routes(config).layer(middlewares).into_make_service(); + let app = routes(config).layer(middlewares).into_make_service_with_connect_info::(); let handle = ServerHandle::new(); tokio::spawn(shutdown_signal(handle.clone())); diff --git a/src/service/mod.rs b/src/service/mod.rs index 4c11bc18..c0a8307e 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -17,6 +17,7 @@ pub mod key_backups; pub mod media; pub mod pdu; pub mod pusher; +pub mod rate_limiting; pub mod rooms; pub mod sending; pub mod transaction_ids; @@ -26,6 +27,7 @@ pub mod users; pub struct Services { pub appservice: appservice::Service, pub pusher: pusher::Service, + pub rate_limiting: rate_limiting::Service, pub rooms: rooms::Service, pub transaction_ids: transaction_ids::Service, pub uiaa: uiaa::Service, @@ -59,6 +61,7 @@ impl Services { Ok(Self { appservice: appservice::Service::build(db)?, pusher: pusher::Service { db }, + rate_limiting: rate_limiting::Service::build(), rooms: rooms::Service { alias: rooms::alias::Service { db }, auth_chain: rooms::auth_chain::Service { db }, diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs new file mode 100644 index 00000000..233513b0 --- /dev/null +++ b/src/service/rate_limiting/mod.rs @@ -0,0 +1,133 @@ +use std::{ + hash::Hash, + net::IpAddr, + num::NonZeroU64, + sync::atomic::{AtomicU64, Ordering}, + time::Duration, +}; + +use dashmap::DashMap; +use quanta::Clock; +use ruma::{ + api::{ + client::{account::register, session::login}, + IncomingRequest, Metadata, + }, + OwnedUserId, +}; + +use crate::{ + config::{Limitation, Restriction}, + services, Result, +}; + +impl From<&Metadata> for Restriction { + fn from(metadata: &Metadata) -> Self { + [ + (register::v3::Request::METADATA, Restriction::Registration), + (login::v3::Request::METADATA, Restriction::Login), + ] + .iter() + .find(|(other, _)| { + metadata + .history + .stable_paths() + .zip(other.history.stable_paths()) + .all(|(a, b)| a == b) + }) + .map(|(_, restriction)| restriction.to_owned()) + .unwrap_or_default() + } +} + +pub struct Service { + store: DashMap<(Target, Restriction), AtomicU64>, + clock: Clock, +} + +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Target { + User(OwnedUserId), + Ip(IpAddr), +} + +impl Service { + pub fn build() -> Self { + Self { + store: DashMap::new(), + clock: Clock::new(), + } + } + + pub fn update_or_reject(&self, key: &(Target, Restriction)) -> Result<(), Duration> { + let arrival = self.clock.delta_as_nanos(0, self.clock.raw()); + + let config = &services().globals.config.rate_limiting; + + let Some(limit) = config + .get(&key.1) + .map(ToOwned::to_owned) else { + return Ok(()); + }; + // .unwrap_or(Limitation { + // per_minute: NonZeroU64::new(1).unwrap(), + // burst_capacity: NonZeroU64::new(1).unwrap(), + // weight: NonZeroU64::new(1).unwrap(), + // }); + + tracing::info!(?limit); + + let increment = u64::try_from(Duration::from_secs(60).as_nanos()) + .expect("1_000_000_000 to be smaller than u64::MAX") + / limit.per_minute.get() + * limit.weight.get(); + tracing::info!(?increment); + + let mut prev_expectation = self + .store + .get(key) + .as_deref() + .map(|n| n.load(Ordering::Acquire)) + .unwrap_or_else(|| arrival + increment); + let weight = (increment * limit.burst_capacity.get()).max(1); + + tracing::info!(?prev_expectation); + tracing::info!(?weight); + + let f = |prev_expectation: u64| { + let allowed = prev_expectation.saturating_sub(weight); + + if arrival < allowed { + Err(Duration::from_nanos(allowed - arrival)) + } else { + Ok(prev_expectation.max(arrival) + increment) + } + }; + + let mut decision = f(prev_expectation); + + tracing::info!(?decision); + + while let Ok(next_expectation) = decision { + let entry = self.store.entry(key.clone()); + + match entry.or_default().compare_exchange_weak( + prev_expectation, + next_expectation, + Ordering::Release, + Ordering::Relaxed, + ) { + Ok(_) => return Ok(()), + Err(actual) => prev_expectation = actual, + } + + decision = f(prev_expectation); + } + + decision.map(|_| ()) + } +} + +///// In-memory state and utility functions used to check whether the client has exceeded its rate limit. +///// This leverages the generic cell rate algorithm, making the required checks as cheap as possible. +///// From d6abf5472b4252c31a2a78e4632eef0a0b98db90 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 26 Jun 2024 09:16:54 +0100 Subject: [PATCH 02/14] more rate limit targets --- Cargo.lock | 40 +++++++++++++ src/api/mod.rs | 1 - src/api/ruma_wrapper/axum.rs | 99 +++++++++++++++++++------------- src/service/rate_limiting/mod.rs | 50 ++++++++-------- 4 files changed, 123 insertions(+), 67 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 30d951a7..e587f668 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -189,6 +189,7 @@ dependencies = [ "serde_path_to_error", "serde_urlencoded", "sync_wrapper 1.0.1", + "tokio", "tower", "tower-layer", "tower-service", @@ -496,6 +497,7 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "dashmap", "directories", "figment", "futures-util", @@ -515,6 +517,7 @@ dependencies = [ "opentelemetry_sdk", "parking_lot", "persy", + "quanta", "rand", "regex", "reqwest", @@ -664,6 +667,19 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.6.0" @@ -2053,6 +2069,21 @@ dependencies = [ "syn", ] +[[package]] +name = "quanta" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -2098,6 +2129,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "raw-cpuid" +version = "11.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd" +dependencies = [ + "bitflags 2.5.0", +] + [[package]] name = "redox_syscall" version = "0.5.1" diff --git a/src/api/mod.rs b/src/api/mod.rs index df951e58..0d2cd664 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,5 +1,4 @@ pub mod appservice_server; pub mod client_server; -pub mod rate_limiting; pub mod ruma_wrapper; pub mod server_server; diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index ac97b391..c7bc5879 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -1,7 +1,8 @@ use std::{ collections::BTreeMap, iter::FromIterator, - str::{self}, + net::IpAddr, + str::{self, FromStr}, }; use axum::{ @@ -102,44 +103,6 @@ where Token::None }; - // doesn't work when Conduit is behind proxy - // let remote_addr: ConnectInfo = parts.extract().await?; - - let target = match &token { - Token::User((user_id, _)) => Some(Target::User(user_id.clone())), - Token::None => { - let header = parts - .headers - .get("x-forwarded-for") - .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; - - let s = header - .to_str() - .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; - Some( - s.parse() - .map(Target::Ip) - .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")), - ) - .transpose()? - } - _ => None, - }; - - if let Err(retry_after_ms) = target.map_or(Ok(()), |t| { - let key = (t, (&metadata).into()); - - services() - .rate_limiting - .update_or_reject(&key) - .map_err(Some) - }) { - return Err(Error::BadRequest( - ErrorKind::LimitExceeded { retry_after_ms }, - "Rate limit exceeded.", - )); - } - let mut json_body = serde_json::from_slice::(&body).ok(); let (sender_user, sender_device, sender_servername, appservice_info) = @@ -350,8 +313,64 @@ where } }; + // doesn't work when Conduit is behind proxy + // let remote_addr: ConnectInfo = parts.extract().await?; + + let headers = parts.headers; + + let target = if let Some(server) = sender_servername.clone() { + Target::Server(server) + + // Token::User((user_id, _)) => Some(Target::User(user_id.clone())), + // Token::None => { + // let header = parts + // .headers + // .get("x-forwarded-for") + // .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; + + // let s = header + // .to_str() + // .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; + // Some( + // s.parse() + // .map(Target::Ip) + // .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")), + // ) + // .transpose()? + // } + // _ => None, + } else if let Some(appservice) = appservice_info.clone() { + Target::Appservice(appservice.registration.id) + } else if let Some(user) = sender_user.clone() { + Target::User(user) + } else { + let ip = headers + .get("X-Forwarded-For") + .and_then(|header| header.to_str().ok()) + .map(|header| header.split_once(",").map(|(ip, _)| ip).unwrap_or(header)) + .and_then(|ip| IpAddr::from_str(ip).ok()); + + if let Some(ip) = ip { + Target::Ip(ip) + } else { + Target::None + } + }; + + if let Err(retry_after_ms) = { + services() + .rate_limiting + .update_or_reject(target, metadata) + .map_err(Some) + } { + return Err(Error::BadRequest( + ErrorKind::LimitExceeded { retry_after_ms }, + "Rate limit exceeded.", + )); + } + let mut http_request = Request::builder().uri(parts.uri).method(parts.method); - *http_request.headers_mut().unwrap() = parts.headers; + *http_request.headers_mut().unwrap() = headers; if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { let user_id = sender_user.clone().unwrap_or_else(|| { diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index 233513b0..be24f7fc 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -1,7 +1,6 @@ use std::{ hash::Hash, net::IpAddr, - num::NonZeroU64, sync::atomic::{AtomicU64, Ordering}, time::Duration, }; @@ -13,16 +12,13 @@ use ruma::{ client::{account::register, session::login}, IncomingRequest, Metadata, }, - OwnedUserId, + OwnedServerName, OwnedUserId, }; -use crate::{ - config::{Limitation, Restriction}, - services, Result, -}; +use crate::{config::Restriction, services, Result}; -impl From<&Metadata> for Restriction { - fn from(metadata: &Metadata) -> Self { +impl From for Restriction { + fn from(metadata: Metadata) -> Self { [ (register::v3::Request::METADATA, Restriction::Registration), (login::v3::Request::METADATA, Restriction::Login), @@ -48,7 +44,10 @@ pub struct Service { #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub enum Target { User(OwnedUserId), + Server(OwnedServerName), + Appservice(String), Ip(IpAddr), + None, } impl Service { @@ -59,33 +58,32 @@ impl Service { } } - pub fn update_or_reject(&self, key: &(Target, Restriction)) -> Result<(), Duration> { + pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> { + let restriction = metadata.into(); + let arrival = self.clock.delta_as_nanos(0, self.clock.raw()); let config = &services().globals.config.rate_limiting; - let Some(limit) = config - .get(&key.1) - .map(ToOwned::to_owned) else { - return Ok(()); - }; - // .unwrap_or(Limitation { - // per_minute: NonZeroU64::new(1).unwrap(), - // burst_capacity: NonZeroU64::new(1).unwrap(), - // weight: NonZeroU64::new(1).unwrap(), - // }); + let Some(limit) = config.get(&restriction).map(ToOwned::to_owned) else { + return Ok(()); + }; + // .unwrap_or(Limitation { + // per_minute: NonZeroU64::new(1).unwrap(), + // burst_capacity: NonZeroU64::new(1).unwrap(), + // weight: NonZeroU64::new(1).unwrap(), + // }); + + let key = (target, restriction); tracing::info!(?limit); - let increment = u64::try_from(Duration::from_secs(60).as_nanos()) - .expect("1_000_000_000 to be smaller than u64::MAX") - / limit.per_minute.get() - * limit.weight.get(); + let increment = 1_000_000_000u64 / limit.per_minute.get() * limit.weight.get(); tracing::info!(?increment); let mut prev_expectation = self .store - .get(key) + .get(&key) .as_deref() .map(|n| n.load(Ordering::Acquire)) .unwrap_or_else(|| arrival + increment); @@ -108,9 +106,9 @@ impl Service { tracing::info!(?decision); - while let Ok(next_expectation) = decision { - let entry = self.store.entry(key.clone()); + let entry = self.store.entry(key); + while let Ok(next_expectation) = decision { match entry.or_default().compare_exchange_weak( prev_expectation, next_expectation, From bdf12c2bbdf839caa653c7347b58508a22f628c0 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 26 Jun 2024 11:42:17 +0100 Subject: [PATCH 03/14] use into_iter --- src/service/rate_limiting/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index be24f7fc..1f5ef0a1 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -23,7 +23,7 @@ impl From for Restriction { (register::v3::Request::METADATA, Restriction::Registration), (login::v3::Request::METADATA, Restriction::Login), ] - .iter() + .into_iter() .find(|(other, _)| { metadata .history @@ -31,7 +31,7 @@ impl From for Restriction { .zip(other.history.stable_paths()) .all(|(a, b)| a == b) }) - .map(|(_, restriction)| restriction.to_owned()) + .map(|(_, restriction)| restriction) .unwrap_or_default() } } From 024f910bf9e4ca56339a53bc464bf15e35ae512c Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 26 Jun 2024 13:36:04 +0100 Subject: [PATCH 04/14] allow for different timeframes for configuration --- src/api/ruma_wrapper/axum.rs | 11 ++++++++--- src/config/mod.rs | 26 +++++++++++++++++++++++--- src/service/rate_limiting/mod.rs | 2 +- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index c7bc5879..8ab2a468 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -20,7 +20,10 @@ use axum_extra::{ use bytes::{BufMut, BytesMut}; use http::{Request, StatusCode}; use ruma::{ - api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, + api::{ + client::error::{ErrorKind, RetryAfter}, + AuthScheme, IncomingRequest, OutgoingResponse, + }, server_util::authorization::XMatrix, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId, }; @@ -357,14 +360,16 @@ where } }; - if let Err(retry_after_ms) = { + if let Err(retry_after) = { services() .rate_limiting .update_or_reject(target, metadata) .map_err(Some) } { return Err(Error::BadRequest( - ErrorKind::LimitExceeded { retry_after_ms }, + ErrorKind::LimitExceeded { + retry_after: retry_after.map(|dur| RetryAfter::Delay(dur)), + }, "Rate limit exceeded.", )); } diff --git a/src/config/mod.rs b/src/config/mod.rs index 81051252..efa02370 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -115,10 +115,30 @@ pub enum Restriction { CatchAll, } +#[derive(Deserialize, Clone, Copy, Debug)] +#[serde(rename_all = "snake_case")] +pub enum Timeframe { + PerSecond(NonZeroU64), + PerMinute(NonZeroU64), + PerHour(NonZeroU64), + PerDay(NonZeroU64), +} + +impl Timeframe { + pub fn nano_gap(&self) -> u64 { + match self { + Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(), + Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(), + Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(), + Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(), + } + } +} + #[derive(Clone, Copy, Debug, Deserialize)] pub struct Limitation { - #[serde(default = "default_non_zero")] - pub per_minute: NonZeroU64, + #[serde(default = "default_non_zero", flatten)] + pub timeframe: Timeframe, #[serde(default = "default_non_zero")] pub burst_capacity: NonZeroU64, #[serde(default = "default_non_zero")] @@ -340,7 +360,7 @@ pub fn default_rate_limit() -> BTreeMap { BTreeMap::from_iter([( Restriction::default(), Limitation { - per_minute: default_non_zero(), + timeframe: Timeframe::PerMinute(default_non_zero()), burst_capacity: default_non_zero(), weight: default_non_zero(), }, diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index 1f5ef0a1..bfe7746a 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -78,7 +78,7 @@ impl Service { tracing::info!(?limit); - let increment = 1_000_000_000u64 / limit.per_minute.get() * limit.weight.get(); + let increment = 1_000_000_000u64 / limit.timeframe.nano_gap() * limit.weight.get(); tracing::info!(?increment); let mut prev_expectation = self From e20fcb029ac9e818a1a4d19193fe00b38f81c5eb Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 26 Jun 2024 16:26:30 +0100 Subject: [PATCH 05/14] fix nano gap --- src/service/rate_limiting/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index bfe7746a..387a17fa 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -78,7 +78,7 @@ impl Service { tracing::info!(?limit); - let increment = 1_000_000_000u64 / limit.timeframe.nano_gap() * limit.weight.get(); + let increment = limit.timeframe.nano_gap() * limit.weight.get(); tracing::info!(?increment); let mut prev_expectation = self @@ -87,7 +87,7 @@ impl Service { .as_deref() .map(|n| n.load(Ordering::Acquire)) .unwrap_or_else(|| arrival + increment); - let weight = (increment * limit.burst_capacity.get()).max(1); + let weight = increment * limit.burst_capacity.get(); tracing::info!(?prev_expectation); tracing::info!(?weight); From bf902f160707b38275de3d7cc13edf25d9743ad6 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Tue, 9 Jul 2024 13:38:36 +0100 Subject: [PATCH 06/14] use ::MIN --- src/config/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index efa02370..70f9fca9 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -353,16 +353,16 @@ fn default_openid_token_ttl() -> u64 { } fn default_non_zero() -> NonZeroU64 { - NonZeroU64::new(1).unwrap() + NonZeroU64::MIN } pub fn default_rate_limit() -> BTreeMap { BTreeMap::from_iter([( Restriction::default(), Limitation { - timeframe: Timeframe::PerMinute(default_non_zero()), - burst_capacity: default_non_zero(), - weight: default_non_zero(), + timeframe: Timeframe::PerMinute(NonZeroU64::MIN), + burst_capacity: NonZeroU64::MIN, + weight: NonZeroU64::MIN, }, )]) } From ab21c5dbef9d82cd366fb3c4d9ea0e3079dcac44 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 10 Jul 2024 09:44:44 +0100 Subject: [PATCH 07/14] simplify --- Cargo.lock | 25 ----------- Cargo.toml | 3 +- src/api/ruma_wrapper/axum.rs | 23 +--------- src/config/mod.rs | 5 +-- src/service/rate_limiting/mod.rs | 73 ++++++++++---------------------- 5 files changed, 28 insertions(+), 101 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e587f668..e394ce76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -517,7 +517,6 @@ dependencies = [ "opentelemetry_sdk", "parking_lot", "persy", - "quanta", "rand", "regex", "reqwest", @@ -2069,21 +2068,6 @@ dependencies = [ "syn", ] -[[package]] -name = "quanta" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" -dependencies = [ - "crossbeam-utils", - "libc", - "once_cell", - "raw-cpuid", - "wasi", - "web-sys", - "winapi", -] - [[package]] name = "quick-error" version = "1.2.3" @@ -2129,15 +2113,6 @@ dependencies = [ "getrandom", ] -[[package]] -name = "raw-cpuid" -version = "11.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd" -dependencies = [ - "bitflags 2.5.0", -] - [[package]] name = "redox_syscall" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 3d6d2594..d24914ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -146,9 +146,8 @@ tikv-jemallocator = { version = "0.5.0", features = [ "unprefixed_malloc_on_supported_platforms", ], optional = true } -sd-notify = { version = "0.4.1", optional = true } dashmap = "5.5.3" -quanta = "0.12.3" +sd-notify = { version = "0.4.1", optional = true } # Used for matrix spec type definitions and helpers [dependencies.ruma] diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 8ab2a468..18d1baad 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -323,25 +323,6 @@ where let target = if let Some(server) = sender_servername.clone() { Target::Server(server) - - // Token::User((user_id, _)) => Some(Target::User(user_id.clone())), - // Token::None => { - // let header = parts - // .headers - // .get("x-forwarded-for") - // .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; - - // let s = header - // .to_str() - // .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; - // Some( - // s.parse() - // .map(Target::Ip) - // .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")), - // ) - // .transpose()? - // } - // _ => None, } else if let Some(appservice) = appservice_info.clone() { Target::Appservice(appservice.registration.id) } else if let Some(user) = sender_user.clone() { @@ -350,7 +331,7 @@ where let ip = headers .get("X-Forwarded-For") .and_then(|header| header.to_str().ok()) - .map(|header| header.split_once(",").map(|(ip, _)| ip).unwrap_or(header)) + .map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header)) .and_then(|ip| IpAddr::from_str(ip).ok()); if let Some(ip) = ip { @@ -368,7 +349,7 @@ where } { return Err(Error::BadRequest( ErrorKind::LimitExceeded { - retry_after: retry_after.map(|dur| RetryAfter::Delay(dur)), + retry_after: retry_after.map(RetryAfter::Delay), }, "Rate limit exceeded.", )); diff --git a/src/config/mod.rs b/src/config/mod.rs index 70f9fca9..2d08d855 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -117,6 +117,8 @@ pub enum Restriction { #[derive(Deserialize, Clone, Copy, Debug)] #[serde(rename_all = "snake_case")] +// When deserializing, we want this prefix +#[allow(clippy::enum_variant_names)] pub enum Timeframe { PerSecond(NonZeroU64), PerMinute(NonZeroU64), @@ -141,8 +143,6 @@ pub struct Limitation { pub timeframe: Timeframe, #[serde(default = "default_non_zero")] pub burst_capacity: NonZeroU64, - #[serde(default = "default_non_zero")] - pub weight: NonZeroU64, } const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; @@ -362,7 +362,6 @@ pub fn default_rate_limit() -> BTreeMap { Limitation { timeframe: Timeframe::PerMinute(NonZeroU64::MIN), burst_capacity: NonZeroU64::MIN, - weight: NonZeroU64::MIN, }, )]) } diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index 387a17fa..fb72fec9 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -1,12 +1,10 @@ use std::{ hash::Hash, net::IpAddr, - sync::atomic::{AtomicU64, Ordering}, - time::Duration, + time::{Duration, Instant}, }; -use dashmap::DashMap; -use quanta::Clock; +use dashmap::{mapref::entry::Entry, DashMap}; use ruma::{ api::{ client::{account::register, session::login}, @@ -37,8 +35,7 @@ impl From for Restriction { } pub struct Service { - store: DashMap<(Target, Restriction), AtomicU64>, - clock: Clock, + store: DashMap<(Target, Restriction), (Instant, u64)>, } #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] @@ -54,14 +51,13 @@ impl Service { pub fn build() -> Self { Self { store: DashMap::new(), - clock: Clock::new(), } } pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> { let restriction = metadata.into(); - let arrival = self.clock.delta_as_nanos(0, self.clock.raw()); + let arrival = Instant::now(); let config = &services().globals.config.rate_limiting; @@ -74,55 +70,32 @@ impl Service { // weight: NonZeroU64::new(1).unwrap(), // }); + let gap = Duration::from_nanos(limit.timeframe.nano_gap()); let key = (target, restriction); - tracing::info!(?limit); + match self.store.entry(key) { + Entry::Occupied(mut entry) => { + let (instant, capacity) = entry.get_mut(); - let increment = limit.timeframe.nano_gap() * limit.weight.get(); - tracing::info!(?increment); + while *instant < arrival && *capacity != 0 { + *capacity -= 1; + *instant += gap; + } - let mut prev_expectation = self - .store - .get(&key) - .as_deref() - .map(|n| n.load(Ordering::Acquire)) - .unwrap_or_else(|| arrival + increment); - let weight = increment * limit.burst_capacity.get(); - - tracing::info!(?prev_expectation); - tracing::info!(?weight); - - let f = |prev_expectation: u64| { - let allowed = prev_expectation.saturating_sub(weight); - - if arrival < allowed { - Err(Duration::from_nanos(allowed - arrival)) - } else { - Ok(prev_expectation.max(arrival) + increment) + if *capacity >= limit.burst_capacity.get() { + return Err(gap); + } else { + *capacity += 1; + // TODO: update thing + *instant += gap; + } + } + Entry::Vacant(entry) => { + entry.insert((Instant::now() + gap, 1)); } }; - let mut decision = f(prev_expectation); - - tracing::info!(?decision); - - let entry = self.store.entry(key); - - while let Ok(next_expectation) = decision { - match entry.or_default().compare_exchange_weak( - prev_expectation, - next_expectation, - Ordering::Release, - Ordering::Relaxed, - ) { - Ok(_) => return Ok(()), - Err(actual) => prev_expectation = actual, - } - - decision = f(prev_expectation); - } - - decision.map(|_| ()) + Ok(()) } } From 6a3b1945677ce756aa8fef1386c8f25f5f334581 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 10 Jul 2024 09:51:23 +0100 Subject: [PATCH 08/14] fmt --- src/main.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index 49e8962b..1e60fa9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -200,7 +200,9 @@ async fn run_server() -> io::Result<()> { .expect("failed to convert max request size"), )); - let app = routes(config).layer(middlewares).into_make_service_with_connect_info::(); + let app = routes(config) + .layer(middlewares) + .into_make_service_with_connect_info::(); let handle = ServerHandle::new(); tokio::spawn(shutdown_signal(handle.clone())); From 1e76cc5cee2b5c6ad7c05763dbc1028d43b71ee2 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 10 Jul 2024 10:40:33 +0100 Subject: [PATCH 09/14] don't rate limit appservices if the registration file says they shouldn't be --- src/api/ruma_wrapper/axum.rs | 6 +++++- src/service/rate_limiting/mod.rs | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 18d1baad..92abb4db 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -324,7 +324,11 @@ where let target = if let Some(server) = sender_servername.clone() { Target::Server(server) } else if let Some(appservice) = appservice_info.clone() { - Target::Appservice(appservice.registration.id) + if appservice.registration.rate_limited.unwrap_or(true) { + Target::Appservice(appservice.registration.id) + } else { + Target::None + } } else if let Some(user) = sender_user.clone() { Target::User(user) } else { diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index fb72fec9..b33b2381 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -55,6 +55,10 @@ impl Service { } pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> { + if target == Target::None { + return Ok(()); + } + let restriction = metadata.into(); let arrival = Instant::now(); From 499548321fd17466cf033eac7cc5eed0dde1d70e Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 10 Jul 2024 11:17:38 +0100 Subject: [PATCH 10/14] enforce maximum capacity --- src/service/rate_limiting/mod.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index b33b2381..3900e843 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -89,8 +89,14 @@ impl Service { if *capacity >= limit.burst_capacity.get() { return Err(gap); } else { + let zero_capacity = *capacity == 0; *capacity += 1; - // TODO: update thing + + // Ensures that the update point is in the future + if zero_capacity { + *instant = Instant::now() + } + *instant += gap; } } From 613107e7cf511163cffecc57b29edc6f8f89fadd Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Thu, 11 Jul 2024 13:14:10 +0100 Subject: [PATCH 11/14] add rate limiting for registration token validity --- src/config/mod.rs | 4 ++-- src/service/rate_limiting/mod.rs | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index 2d08d855..105b265a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -105,13 +105,13 @@ pub struct WellKnownConfig { } #[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] -#[serde(rename_all = "lowercase")] +#[serde(rename_all = "snake_case")] pub enum Restriction { Registration, Login, + RegistrationTokenValidity, #[default] - #[serde(rename = "")] CatchAll, } diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index 3900e843..e4c4e201 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -7,7 +7,10 @@ use std::{ use dashmap::{mapref::entry::Entry, DashMap}; use ruma::{ api::{ - client::{account::register, session::login}, + client::{ + account::{check_registration_token_validity, register}, + session::login, + }, IncomingRequest, Metadata, }, OwnedServerName, OwnedUserId, @@ -20,6 +23,10 @@ impl From for Restriction { [ (register::v3::Request::METADATA, Restriction::Registration), (login::v3::Request::METADATA, Restriction::Login), + ( + check_registration_token_validity::v1::Request::METADATA, + Restriction::RegistrationTokenValidity, + ), ] .into_iter() .find(|(other, _)| { From 3bf113d92001fe941f33f4f281ca0226e1de88e3 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Thu, 11 Jul 2024 13:15:45 +0100 Subject: [PATCH 12/14] don't take ownership of limitation --- src/service/rate_limiting/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index e4c4e201..ca2995ab 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -72,7 +72,7 @@ impl Service { let config = &services().globals.config.rate_limiting; - let Some(limit) = config.get(&restriction).map(ToOwned::to_owned) else { + let Some(limit) = config.get(&restriction) else { return Ok(()); }; // .unwrap_or(Limitation { From e3cfe360a106d38e217605beba6359e232af6047 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Thu, 11 Jul 2024 13:24:06 +0100 Subject: [PATCH 13/14] simplify conversion to restriction --- src/service/rate_limiting/mod.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index ca2995ab..d17fa4dd 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -29,13 +29,7 @@ impl From for Restriction { ), ] .into_iter() - .find(|(other, _)| { - metadata - .history - .stable_paths() - .zip(other.history.stable_paths()) - .all(|(a, b)| a == b) - }) + .find(|(other, _)| metadata.history.all_paths().eq(other.history.all_paths())) .map(|(_, restriction)| restriction) .unwrap_or_default() } From 619ea68405916dd09a1b861983af9f543020b5f0 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Fri, 12 Jul 2024 13:15:24 +0100 Subject: [PATCH 14/14] a lot more endpoints --- Cargo.lock | 27 ++++++++--------- src/config/mod.rs | 9 ++++++ src/service/rate_limiting/mod.rs | 50 ++++++++++++++++++++++++-------- 3 files changed, 61 insertions(+), 25 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e394ce76..d28d930b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2247,7 +2247,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "assign", "js_int", @@ -2268,7 +2268,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "js_int", "ruma-common", @@ -2280,7 +2280,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "as_variant", "assign", @@ -2303,7 +2303,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "as_variant", "base64 0.22.1", @@ -2333,7 +2333,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "as_variant", "indexmap 2.2.6", @@ -2349,13 +2349,14 @@ dependencies = [ "thiserror", "tracing", "url", + "web-time", "wildmatch", ] [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "js_int", "ruma-common", @@ -2367,7 +2368,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "js_int", "thiserror", @@ -2376,7 +2377,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "js_int", "ruma-common", @@ -2386,7 +2387,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "once_cell", "proc-macro-crate", @@ -2401,7 +2402,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "js_int", "ruma-common", @@ -2413,7 +2414,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "headers", "http 1.1.0", @@ -2426,7 +2427,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -2442,7 +2443,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "itertools", "js_int", diff --git a/src/config/mod.rs b/src/config/mod.rs index 105b265a..7294a5de 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -110,6 +110,15 @@ pub enum Restriction { Registration, Login, RegistrationTokenValidity, + Message, + Join, + Invite, + Knock, + CreateMedia, + Transaction, + FederatedJoin, + FederatedInvite, + FederatedKnock, #[default] CatchAll, diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index d17fa4dd..68dc3a63 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -9,7 +9,17 @@ use ruma::{ api::{ client::{ account::{check_registration_token_validity, register}, + knock::knock_room, + media::{create_content, create_content_async}, + membership::{invite_user, join_room_by_id, join_room_by_id_or_alias}, + message::send_message_event, session::login, + state::send_state_event, + }, + federation::{ + knock::send_knock, + membership::{create_invite, create_join_event}, + transactions::send_transaction_message, }, IncomingRequest, Metadata, }, @@ -20,18 +30,33 @@ use crate::{config::Restriction, services, Result}; impl From for Restriction { fn from(metadata: Metadata) -> Self { - [ - (register::v3::Request::METADATA, Restriction::Registration), - (login::v3::Request::METADATA, Restriction::Login), - ( - check_registration_token_validity::v1::Request::METADATA, - Restriction::RegistrationTokenValidity, - ), - ] - .into_iter() - .find(|(other, _)| metadata.history.all_paths().eq(other.history.all_paths())) - .map(|(_, restriction)| restriction) - .unwrap_or_default() + #[allow(deprecated)] + match metadata { + register::v3::Request::METADATA => Restriction::Registration, + login::v3::Request::METADATA => Restriction::Login, + check_registration_token_validity::v1::Request::METADATA => { + Restriction::RegistrationTokenValidity + } + send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => { + Restriction::Message + } + join_room_by_id::v3::Request::METADATA + | join_room_by_id_or_alias::v3::Request::METADATA => Restriction::Join, + invite_user::v3::Request::METADATA => Restriction::Invite, + create_content::v3::Request::METADATA | create_content_async::v3::Request::METADATA => { + Restriction::CreateMedia + } + send_transaction_message::v1::Request::METADATA => Restriction::Transaction, + create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => { + Restriction::FederatedJoin + } + create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => { + Restriction::FederatedInvite + } + send_knock::v1::Request::METADATA => Restriction::FederatedKnock, + knock_room::v3::Request::METADATA => Restriction::Knock, + _ => Self::default(), + } } } @@ -42,6 +67,7 @@ pub struct Service { #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub enum Target { User(OwnedUserId), + // Server endpoints should be rate-limited on a server and room basis Server(OwnedServerName), Appservice(String), Ip(IpAddr),