diff --git a/Cargo.toml b/Cargo.toml index 67128f07..3d6d2594 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ axum = { version = "0.7", default-features = false, features = [ "http2", "json", "matched-path", + "tokio", ], optional = true } axum-extra = { version = "0.9", features = ["typed-header"] } axum-server = { version = "0.6", features = ["tls-rustls"] } @@ -146,6 +147,8 @@ tikv-jemallocator = { version = "0.5.0", features = [ ], optional = true } sd-notify = { version = "0.4.1", optional = true } +dashmap = "5.5.3" +quanta = "0.12.3" # Used for matrix spec type definitions and helpers [dependencies.ruma] diff --git a/src/api/mod.rs b/src/api/mod.rs index 0d2cd664..df951e58 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,4 +1,5 @@ pub mod appservice_server; pub mod client_server; +pub mod rate_limiting; pub mod ruma_wrapper; pub mod server_server; diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 2c5da21b..ac97b391 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -1,4 +1,8 @@ -use std::{collections::BTreeMap, iter::FromIterator, str}; +use std::{ + collections::BTreeMap, + iter::FromIterator, + str::{self}, +}; use axum::{ async_trait, @@ -23,7 +27,10 @@ use serde::Deserialize; use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; -use crate::{service::appservice::RegistrationInfo, services, Error, Result}; +use crate::{ + service::{appservice::RegistrationInfo, rate_limiting::Target}, + services, Error, Result, +}; enum Token { Appservice(Box), @@ -95,6 +102,44 @@ where Token::None }; + // doesn't work when Conduit is behind proxy + // let remote_addr: ConnectInfo = parts.extract().await?; + + let target = match &token { + Token::User((user_id, _)) => Some(Target::User(user_id.clone())), + Token::None => { + let header = parts + .headers + .get("x-forwarded-for") + .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; + + let s = header + .to_str() + .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; + Some( + s.parse() + .map(Target::Ip) + .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")), + ) + .transpose()? + } + _ => None, + }; + + if let Err(retry_after_ms) = target.map_or(Ok(()), |t| { + let key = (t, (&metadata).into()); + + services() + .rate_limiting + .update_or_reject(&key) + .map_err(Some) + }) { + return Err(Error::BadRequest( + ErrorKind::LimitExceeded { retry_after_ms }, + "Rate limit exceeded.", + )); + } + let mut json_body = serde_json::from_slice::(&body).ok(); let (sender_user, sender_device, sender_servername, appservice_info) = diff --git a/src/config/mod.rs b/src/config/mod.rs index 378ab929..81051252 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,6 +2,7 @@ use std::{ collections::BTreeMap, fmt, net::{IpAddr, Ipv4Addr}, + num::NonZeroU64, }; use ruma::{OwnedServerName, RoomVersionId}; @@ -82,6 +83,8 @@ pub struct Config { pub turn_secret: String, #[serde(default = "default_turn_ttl")] pub turn_ttl: u64, + #[serde(default = "default_rate_limit")] + pub rate_limiting: BTreeMap, pub emergency_password: Option, @@ -101,6 +104,27 @@ pub struct WellKnownConfig { pub server: Option, } +#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[serde(rename_all = "lowercase")] +pub enum Restriction { + Registration, + Login, + + #[default] + #[serde(rename = "")] + CatchAll, +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct Limitation { + #[serde(default = "default_non_zero")] + pub per_minute: NonZeroU64, + #[serde(default = "default_non_zero")] + pub burst_capacity: NonZeroU64, + #[serde(default = "default_non_zero")] + pub weight: NonZeroU64, +} + const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; impl Config { @@ -308,6 +332,21 @@ fn default_openid_token_ttl() -> u64 { 60 * 60 } +fn default_non_zero() -> NonZeroU64 { + NonZeroU64::new(1).unwrap() +} + +pub fn default_rate_limit() -> BTreeMap { + BTreeMap::from_iter([( + Restriction::default(), + Limitation { + per_minute: default_non_zero(), + burst_capacity: default_non_zero(), + weight: default_non_zero(), + }, + )]) +} + // I know, it's a great name pub fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 diff --git a/src/main.rs b/src/main.rs index 8d242c53..49e8962b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -200,7 +200,7 @@ async fn run_server() -> io::Result<()> { .expect("failed to convert max request size"), )); - let app = routes(config).layer(middlewares).into_make_service(); + let app = routes(config).layer(middlewares).into_make_service_with_connect_info::(); let handle = ServerHandle::new(); tokio::spawn(shutdown_signal(handle.clone())); diff --git a/src/service/mod.rs b/src/service/mod.rs index 4c11bc18..c0a8307e 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -17,6 +17,7 @@ pub mod key_backups; pub mod media; pub mod pdu; pub mod pusher; +pub mod rate_limiting; pub mod rooms; pub mod sending; pub mod transaction_ids; @@ -26,6 +27,7 @@ pub mod users; pub struct Services { pub appservice: appservice::Service, pub pusher: pusher::Service, + pub rate_limiting: rate_limiting::Service, pub rooms: rooms::Service, pub transaction_ids: transaction_ids::Service, pub uiaa: uiaa::Service, @@ -59,6 +61,7 @@ impl Services { Ok(Self { appservice: appservice::Service::build(db)?, pusher: pusher::Service { db }, + rate_limiting: rate_limiting::Service::build(), rooms: rooms::Service { alias: rooms::alias::Service { db }, auth_chain: rooms::auth_chain::Service { db }, diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs new file mode 100644 index 00000000..233513b0 --- /dev/null +++ b/src/service/rate_limiting/mod.rs @@ -0,0 +1,133 @@ +use std::{ + hash::Hash, + net::IpAddr, + num::NonZeroU64, + sync::atomic::{AtomicU64, Ordering}, + time::Duration, +}; + +use dashmap::DashMap; +use quanta::Clock; +use ruma::{ + api::{ + client::{account::register, session::login}, + IncomingRequest, Metadata, + }, + OwnedUserId, +}; + +use crate::{ + config::{Limitation, Restriction}, + services, Result, +}; + +impl From<&Metadata> for Restriction { + fn from(metadata: &Metadata) -> Self { + [ + (register::v3::Request::METADATA, Restriction::Registration), + (login::v3::Request::METADATA, Restriction::Login), + ] + .iter() + .find(|(other, _)| { + metadata + .history + .stable_paths() + .zip(other.history.stable_paths()) + .all(|(a, b)| a == b) + }) + .map(|(_, restriction)| restriction.to_owned()) + .unwrap_or_default() + } +} + +pub struct Service { + store: DashMap<(Target, Restriction), AtomicU64>, + clock: Clock, +} + +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Target { + User(OwnedUserId), + Ip(IpAddr), +} + +impl Service { + pub fn build() -> Self { + Self { + store: DashMap::new(), + clock: Clock::new(), + } + } + + pub fn update_or_reject(&self, key: &(Target, Restriction)) -> Result<(), Duration> { + let arrival = self.clock.delta_as_nanos(0, self.clock.raw()); + + let config = &services().globals.config.rate_limiting; + + let Some(limit) = config + .get(&key.1) + .map(ToOwned::to_owned) else { + return Ok(()); + }; + // .unwrap_or(Limitation { + // per_minute: NonZeroU64::new(1).unwrap(), + // burst_capacity: NonZeroU64::new(1).unwrap(), + // weight: NonZeroU64::new(1).unwrap(), + // }); + + tracing::info!(?limit); + + let increment = u64::try_from(Duration::from_secs(60).as_nanos()) + .expect("1_000_000_000 to be smaller than u64::MAX") + / limit.per_minute.get() + * limit.weight.get(); + tracing::info!(?increment); + + let mut prev_expectation = self + .store + .get(key) + .as_deref() + .map(|n| n.load(Ordering::Acquire)) + .unwrap_or_else(|| arrival + increment); + let weight = (increment * limit.burst_capacity.get()).max(1); + + tracing::info!(?prev_expectation); + tracing::info!(?weight); + + let f = |prev_expectation: u64| { + let allowed = prev_expectation.saturating_sub(weight); + + if arrival < allowed { + Err(Duration::from_nanos(allowed - arrival)) + } else { + Ok(prev_expectation.max(arrival) + increment) + } + }; + + let mut decision = f(prev_expectation); + + tracing::info!(?decision); + + while let Ok(next_expectation) = decision { + let entry = self.store.entry(key.clone()); + + match entry.or_default().compare_exchange_weak( + prev_expectation, + next_expectation, + Ordering::Release, + Ordering::Relaxed, + ) { + Ok(_) => return Ok(()), + Err(actual) => prev_expectation = actual, + } + + decision = f(prev_expectation); + } + + decision.map(|_| ()) + } +} + +///// In-memory state and utility functions used to check whether the client has exceeded its rate limit. +///// This leverages the generic cell rate algorithm, making the required checks as cheap as possible. +/////