diff --git a/Cargo.lock b/Cargo.lock index 30d951a7..d28d930b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -189,6 +189,7 @@ dependencies = [ "serde_path_to_error", "serde_urlencoded", "sync_wrapper 1.0.1", + "tokio", "tower", "tower-layer", "tower-service", @@ -496,6 +497,7 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "dashmap", "directories", "figment", "futures-util", @@ -664,6 +666,19 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.6.0" @@ -2232,7 +2247,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "assign", "js_int", @@ -2253,7 +2268,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "js_int", "ruma-common", @@ -2265,7 +2280,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "as_variant", "assign", @@ -2288,7 +2303,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "as_variant", "base64 0.22.1", @@ -2318,7 +2333,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "as_variant", "indexmap 2.2.6", @@ -2334,13 +2349,14 @@ dependencies = [ "thiserror", "tracing", "url", + "web-time", "wildmatch", ] [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "js_int", "ruma-common", @@ -2352,7 +2368,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "js_int", "thiserror", @@ -2361,7 +2377,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "js_int", "ruma-common", @@ -2371,7 +2387,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "once_cell", "proc-macro-crate", @@ -2386,7 +2402,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "js_int", "ruma-common", @@ -2398,7 +2414,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "headers", "http 1.1.0", @@ -2411,7 +2427,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -2427,7 +2443,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2" +source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366" dependencies = [ "itertools", "js_int", diff --git a/Cargo.toml b/Cargo.toml index 67128f07..d24914ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ axum = { version = "0.7", default-features = false, features = [ "http2", "json", "matched-path", + "tokio", ], optional = true } axum-extra = { version = "0.9", features = ["typed-header"] } axum-server = { version = "0.6", features = ["tls-rustls"] } @@ -145,6 +146,7 @@ tikv-jemallocator = { version = "0.5.0", features = [ "unprefixed_malloc_on_supported_platforms", ], optional = true } +dashmap = "5.5.3" sd-notify = { version = "0.4.1", optional = true } # Used for matrix spec type definitions and helpers diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 2c5da21b..92abb4db 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -1,4 +1,9 @@ -use std::{collections::BTreeMap, iter::FromIterator, str}; +use std::{ + collections::BTreeMap, + iter::FromIterator, + net::IpAddr, + str::{self, FromStr}, +}; use axum::{ async_trait, @@ -15,7 +20,10 @@ use axum_extra::{ use bytes::{BufMut, BytesMut}; use http::{Request, StatusCode}; use ruma::{ - api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, + api::{ + client::error::{ErrorKind, RetryAfter}, + AuthScheme, IncomingRequest, OutgoingResponse, + }, server_util::authorization::XMatrix, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId, }; @@ -23,7 +31,10 @@ use serde::Deserialize; use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; -use crate::{service::appservice::RegistrationInfo, services, Error, Result}; +use crate::{ + service::{appservice::RegistrationInfo, rate_limiting::Target}, + services, Error, Result, +}; enum Token { Appservice(Box), @@ -305,8 +316,51 @@ where } }; + // doesn't work when Conduit is behind proxy + // let remote_addr: ConnectInfo = parts.extract().await?; + + let headers = parts.headers; + + let target = if let Some(server) = sender_servername.clone() { + Target::Server(server) + } else if let Some(appservice) = appservice_info.clone() { + if appservice.registration.rate_limited.unwrap_or(true) { + Target::Appservice(appservice.registration.id) + } else { + Target::None + } + } else if let Some(user) = sender_user.clone() { + Target::User(user) + } else { + let ip = headers + .get("X-Forwarded-For") + .and_then(|header| header.to_str().ok()) + .map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header)) + .and_then(|ip| IpAddr::from_str(ip).ok()); + + if let Some(ip) = ip { + Target::Ip(ip) + } else { + Target::None + } + }; + + if let Err(retry_after) = { + services() + .rate_limiting + .update_or_reject(target, metadata) + .map_err(Some) + } { + return Err(Error::BadRequest( + ErrorKind::LimitExceeded { + retry_after: retry_after.map(RetryAfter::Delay), + }, + "Rate limit exceeded.", + )); + } + let mut http_request = Request::builder().uri(parts.uri).method(parts.method); - *http_request.headers_mut().unwrap() = parts.headers; + *http_request.headers_mut().unwrap() = headers; if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { let user_id = sender_user.clone().unwrap_or_else(|| { diff --git a/src/config/mod.rs b/src/config/mod.rs index 378ab929..7294a5de 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,6 +2,7 @@ use std::{ collections::BTreeMap, fmt, net::{IpAddr, Ipv4Addr}, + num::NonZeroU64, }; use ruma::{OwnedServerName, RoomVersionId}; @@ -82,6 +83,8 @@ pub struct Config { pub turn_secret: String, #[serde(default = "default_turn_ttl")] pub turn_ttl: u64, + #[serde(default = "default_rate_limit")] + pub rate_limiting: BTreeMap, pub emergency_password: Option, @@ -101,6 +104,56 @@ pub struct WellKnownConfig { pub server: Option, } +#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[serde(rename_all = "snake_case")] +pub enum Restriction { + Registration, + Login, + RegistrationTokenValidity, + Message, + Join, + Invite, + Knock, + CreateMedia, + Transaction, + FederatedJoin, + FederatedInvite, + FederatedKnock, + + #[default] + CatchAll, +} + +#[derive(Deserialize, Clone, Copy, Debug)] +#[serde(rename_all = "snake_case")] +// When deserializing, we want this prefix +#[allow(clippy::enum_variant_names)] +pub enum Timeframe { + PerSecond(NonZeroU64), + PerMinute(NonZeroU64), + PerHour(NonZeroU64), + PerDay(NonZeroU64), +} + +impl Timeframe { + pub fn nano_gap(&self) -> u64 { + match self { + Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(), + Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(), + Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(), + Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(), + } + } +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct Limitation { + #[serde(default = "default_non_zero", flatten)] + pub timeframe: Timeframe, + #[serde(default = "default_non_zero")] + pub burst_capacity: NonZeroU64, +} + const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; impl Config { @@ -308,6 +361,20 @@ fn default_openid_token_ttl() -> u64 { 60 * 60 } +fn default_non_zero() -> NonZeroU64 { + NonZeroU64::MIN +} + +pub fn default_rate_limit() -> BTreeMap { + BTreeMap::from_iter([( + Restriction::default(), + Limitation { + timeframe: Timeframe::PerMinute(NonZeroU64::MIN), + burst_capacity: NonZeroU64::MIN, + }, + )]) +} + // I know, it's a great name pub fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 diff --git a/src/main.rs b/src/main.rs index 8d242c53..1e60fa9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -200,7 +200,9 @@ async fn run_server() -> io::Result<()> { .expect("failed to convert max request size"), )); - let app = routes(config).layer(middlewares).into_make_service(); + let app = routes(config) + .layer(middlewares) + .into_make_service_with_connect_info::(); let handle = ServerHandle::new(); tokio::spawn(shutdown_signal(handle.clone())); diff --git a/src/service/mod.rs b/src/service/mod.rs index 4c11bc18..c0a8307e 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -17,6 +17,7 @@ pub mod key_backups; pub mod media; pub mod pdu; pub mod pusher; +pub mod rate_limiting; pub mod rooms; pub mod sending; pub mod transaction_ids; @@ -26,6 +27,7 @@ pub mod users; pub struct Services { pub appservice: appservice::Service, pub pusher: pusher::Service, + pub rate_limiting: rate_limiting::Service, pub rooms: rooms::Service, pub transaction_ids: transaction_ids::Service, pub uiaa: uiaa::Service, @@ -59,6 +61,7 @@ impl Services { Ok(Self { appservice: appservice::Service::build(db)?, pusher: pusher::Service { db }, + rate_limiting: rate_limiting::Service::build(), rooms: rooms::Service { alias: rooms::alias::Service { db }, auth_chain: rooms::auth_chain::Service { db }, diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs new file mode 100644 index 00000000..68dc3a63 --- /dev/null +++ b/src/service/rate_limiting/mod.rs @@ -0,0 +1,141 @@ +use std::{ + hash::Hash, + net::IpAddr, + time::{Duration, Instant}, +}; + +use dashmap::{mapref::entry::Entry, DashMap}; +use ruma::{ + api::{ + client::{ + account::{check_registration_token_validity, register}, + knock::knock_room, + media::{create_content, create_content_async}, + membership::{invite_user, join_room_by_id, join_room_by_id_or_alias}, + message::send_message_event, + session::login, + state::send_state_event, + }, + federation::{ + knock::send_knock, + membership::{create_invite, create_join_event}, + transactions::send_transaction_message, + }, + IncomingRequest, Metadata, + }, + OwnedServerName, OwnedUserId, +}; + +use crate::{config::Restriction, services, Result}; + +impl From for Restriction { + fn from(metadata: Metadata) -> Self { + #[allow(deprecated)] + match metadata { + register::v3::Request::METADATA => Restriction::Registration, + login::v3::Request::METADATA => Restriction::Login, + check_registration_token_validity::v1::Request::METADATA => { + Restriction::RegistrationTokenValidity + } + send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => { + Restriction::Message + } + join_room_by_id::v3::Request::METADATA + | join_room_by_id_or_alias::v3::Request::METADATA => Restriction::Join, + invite_user::v3::Request::METADATA => Restriction::Invite, + create_content::v3::Request::METADATA | create_content_async::v3::Request::METADATA => { + Restriction::CreateMedia + } + send_transaction_message::v1::Request::METADATA => Restriction::Transaction, + create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => { + Restriction::FederatedJoin + } + create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => { + Restriction::FederatedInvite + } + send_knock::v1::Request::METADATA => Restriction::FederatedKnock, + knock_room::v3::Request::METADATA => Restriction::Knock, + _ => Self::default(), + } + } +} + +pub struct Service { + store: DashMap<(Target, Restriction), (Instant, u64)>, +} + +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Target { + User(OwnedUserId), + // Server endpoints should be rate-limited on a server and room basis + Server(OwnedServerName), + Appservice(String), + Ip(IpAddr), + None, +} + +impl Service { + pub fn build() -> Self { + Self { + store: DashMap::new(), + } + } + + pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> { + if target == Target::None { + return Ok(()); + } + + let restriction = metadata.into(); + + let arrival = Instant::now(); + + let config = &services().globals.config.rate_limiting; + + let Some(limit) = config.get(&restriction) else { + return Ok(()); + }; + // .unwrap_or(Limitation { + // per_minute: NonZeroU64::new(1).unwrap(), + // burst_capacity: NonZeroU64::new(1).unwrap(), + // weight: NonZeroU64::new(1).unwrap(), + // }); + + let gap = Duration::from_nanos(limit.timeframe.nano_gap()); + let key = (target, restriction); + + match self.store.entry(key) { + Entry::Occupied(mut entry) => { + let (instant, capacity) = entry.get_mut(); + + while *instant < arrival && *capacity != 0 { + *capacity -= 1; + *instant += gap; + } + + if *capacity >= limit.burst_capacity.get() { + return Err(gap); + } else { + let zero_capacity = *capacity == 0; + *capacity += 1; + + // Ensures that the update point is in the future + if zero_capacity { + *instant = Instant::now() + } + + *instant += gap; + } + } + Entry::Vacant(entry) => { + entry.insert((Instant::now() + gap, 1)); + } + }; + + Ok(()) + } +} + +///// In-memory state and utility functions used to check whether the client has exceeded its rate limit. +///// This leverages the generic cell rate algorithm, making the required checks as cheap as possible. +/////