From 0d723046621234a6f6bada7973742c33f880b69a Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Sat, 14 Jun 2025 15:36:16 +0100 Subject: [PATCH] WIP: rate-limiting --- src/api/client_server/media.rs | 359 ++++++++++++------- src/api/ruma_wrapper/axum.rs | 31 +- src/api/ruma_wrapper/mod.rs | 3 +- src/api/server_server.rs | 23 +- src/config/mod.rs | 10 +- src/config/rate_limiting.rs | 115 ++++++ src/database/key_value/media.rs | 34 +- src/service/admin/mod.rs | 9 +- src/service/media/data.rs | 5 +- src/service/media/mod.rs | 44 ++- src/service/mod.rs | 4 + src/service/rate_limiting/mod.rs | 591 +++++++++++++++++++++++++++++++ 12 files changed, 1070 insertions(+), 158 deletions(-) create mode 100644 src/config/rate_limiting.rs create mode 100644 src/service/rate_limiting/mod.rs diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index e922b157..85364e73 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -3,7 +3,13 @@ use std::time::Duration; -use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma}; +use crate::{ + service::{ + media::{size, FileMeta}, + rate_limiting::Target, + }, + services, utils, Error, Result, Ruma, +}; use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE}; use ruma::{ api::{ @@ -54,6 +60,8 @@ pub async fn get_media_config_auth_route( pub async fn create_content_route( body: Ruma, ) -> Result { + let sender_user = body.sender_user.expect("user is authenticated"); + let create_content::v3::Request { filename, content_type, @@ -61,6 +69,13 @@ pub async fn create_content_route( .. } = body.body; + let target = Target::from_client_request(body.appservice_info, &sender_user); + + services() + .rate_limiting + .check_media_upload(target, size(&file)?) + .await?; + let media_id = utils::random_string(MXC_LENGTH); services() @@ -71,7 +86,7 @@ pub async fn create_content_route( filename.as_deref(), content_type.as_deref(), &file, - body.sender_user.as_deref(), + Some(&sender_user), ) .await?; @@ -84,7 +99,13 @@ pub async fn create_content_route( pub async fn get_remote_content( server_name: &ServerName, media_id: String, + target: Target, ) -> Result { + services() + .rate_limiting + .check_media_pre_fetch(&target) + .await?; + let content_response = match services() .sending .send_federation_request( @@ -153,6 +174,11 @@ pub async fn get_remote_content( ) .await?; + services() + .rate_limiting + .update_media_post_fetch(target, size(&content_response.file)?) + .await; + Ok(content_response) } @@ -171,11 +197,21 @@ pub async fn get_content_route( } = get_content( &body.server_name, body.media_id.clone(), - body.allow_remote, - false, + body.sender_ip_address.map(Target::Ip), ) .await?; + if let Some(target) = Target::from_client_request_optional_auth( + body.appservice_info, + &body.sender_user, + body.sender_ip_address, + ) { + services() + .rate_limiting + .update_media_post_fetch(target, size(&file)?) + .await; + } + Ok(media::get_content::v3::Response { file, content_type, @@ -190,14 +226,24 @@ pub async fn get_content_route( pub async fn get_content_auth_route( body: Ruma, ) -> Result { - get_content(&body.server_name, body.media_id.clone(), true, true).await + let Ruma:: { + body, + sender_user, + appservice_info, + .. + } = body; + + let sender_user = sender_user.as_ref().expect("user is authenticated"); + + let target = Target::from_client_request(appservice_info, sender_user); + + get_content(&body.server_name, body.media_id.clone(), Some(target)).await } pub async fn get_content( server_name: &ServerName, media_id: String, - allow_remote: bool, - authenticated: bool, + target: Option, ) -> Result { services().media.check_blocked(server_name, &media_id)?; @@ -207,7 +253,7 @@ pub async fn get_content( file, })) = services() .media - .get(server_name, &media_id, authenticated) + .get(server_name, &media_id, target.clone()) .await { Ok(get_content::v1::Response { @@ -215,16 +261,25 @@ pub async fn get_content( content_type, content_disposition: Some(content_disposition), }) - } else if server_name != services().globals.server_name() && allow_remote && authenticated { - let remote_content_response = get_remote_content(server_name, media_id.clone()).await?; - - Ok(get_content::v1::Response { - content_disposition: remote_content_response.content_disposition, - content_type: remote_content_response.content_type, - file: remote_content_response.file, - }) } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) + let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); + + if let Some(target) = target { + if server_name != services().globals.server_name() && target.is_authenticated() { + let remote_content_response = + get_remote_content(server_name, media_id.clone(), target).await?; + + Ok(get_content::v1::Response { + content_disposition: remote_content_response.content_disposition, + content_type: remote_content_response.content_type, + file: remote_content_response.file, + }) + } else { + error + } + } else { + error + } } } @@ -244,8 +299,7 @@ pub async fn get_content_as_filename_route( &body.server_name, body.media_id.clone(), body.filename.clone(), - body.allow_remote, - false, + body.sender_ip_address.map(Target::Ip), ) .await?; @@ -263,22 +317,33 @@ pub async fn get_content_as_filename_route( pub async fn get_content_as_filename_auth_route( body: Ruma, ) -> Result { - get_content_as_filename( + let Ruma:: { + body, + sender_user, + appservice_info, + .. + } = body; + + let sender_user = sender_user.as_ref().expect("user is authenticated"); + + let target = Target::from_client_request(appservice_info, sender_user); + + let resp = get_content_as_filename( &body.server_name, body.media_id.clone(), body.filename.clone(), - true, - true, + Some(target), ) - .await + .await?; + + Ok(resp) } async fn get_content_as_filename( server_name: &ServerName, media_id: String, filename: String, - allow_remote: bool, - authenticated: bool, + target: Option, ) -> Result { services().media.check_blocked(server_name, &media_id)?; @@ -286,7 +351,7 @@ async fn get_content_as_filename( file, content_type, .. })) = services() .media - .get(server_name, &media_id, authenticated) + .get(server_name, &media_id, target.clone()) .await { Ok(get_content_as_filename::v1::Response { @@ -297,19 +362,28 @@ async fn get_content_as_filename( .with_filename(Some(filename.clone())), ), }) - } else if server_name != services().globals.server_name() && allow_remote && authenticated { - let remote_content_response = get_remote_content(server_name, media_id.clone()).await?; - - Ok(get_content_as_filename::v1::Response { - content_disposition: Some( - ContentDisposition::new(ContentDispositionType::Inline) - .with_filename(Some(filename.clone())), - ), - content_type: remote_content_response.content_type, - file: remote_content_response.file, - }) } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) + let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); + + if let Some(target) = target { + if server_name != services().globals.server_name() && target.is_authenticated() { + let remote_content_response = + get_remote_content(server_name, media_id.clone(), target).await?; + + Ok(get_content_as_filename::v1::Response { + content_disposition: Some( + ContentDisposition::new(ContentDispositionType::Inline) + .with_filename(Some(filename.clone())), + ), + content_type: remote_content_response.content_type, + file: remote_content_response.file, + }) + } else { + error + } + } else { + error + } } } @@ -321,6 +395,17 @@ async fn get_content_as_filename( pub async fn get_content_thumbnail_route( body: Ruma, ) -> Result { + let Ruma:: { + body, + sender_user, + sender_ip_address, + appservice_info, + .. + } = body; + + let target = + Target::from_client_request_optional_auth(appservice_info, &sender_user, sender_ip_address); + let get_content_thumbnail::v1::Response { file, content_type, @@ -332,8 +417,7 @@ pub async fn get_content_thumbnail_route( body.width, body.method.clone(), body.animated, - body.allow_remote, - false, + target, ) .await?; @@ -351,17 +435,27 @@ pub async fn get_content_thumbnail_route( pub async fn get_content_thumbnail_auth_route( body: Ruma, ) -> Result { - get_content_thumbnail( + let Ruma:: { + body, + sender_user, + appservice_info, + .. + } = body; + let sender_user = sender_user.as_ref().expect("user is authenticated"); + let target = Target::from_client_request(appservice_info, sender_user); + + let resp = get_content_thumbnail( &body.server_name, body.media_id.clone(), body.height, body.width, body.method.clone(), body.animated, - true, - true, + Some(target), ) - .await + .await?; + + Ok(resp) } #[allow(clippy::too_many_arguments)] @@ -372,8 +466,7 @@ async fn get_content_thumbnail( width: UInt, method: Option, animated: Option, - allow_remote: bool, - authenticated: bool, + target: Option, ) -> Result { services().media.check_blocked(server_name, &media_id)?; @@ -392,7 +485,7 @@ async fn get_content_thumbnail( height .try_into() .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?, - authenticated, + target.clone(), ) .await? { @@ -401,99 +494,117 @@ async fn get_content_thumbnail( content_type, content_disposition: Some(content_disposition), }) - } else if server_name != services().globals.server_name() && allow_remote && authenticated { - let thumbnail_response = match services() - .sending - .send_federation_request( - server_name, - federation_media::get_content_thumbnail::v1::Request { - height, - width, - method: method.clone(), - media_id: media_id.clone(), - timeout_ms: Duration::from_secs(20), - animated, - }, - ) - .await - { - Ok(federation_media::get_content_thumbnail::v1::Response { - metadata: _, - content: FileOrLocation::File(content), - }) => get_content_thumbnail::v1::Response { - file: content.file, - content_type: content.content_type, - content_disposition: content.content_disposition, - }, + } else { + let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); - Ok(federation_media::get_content_thumbnail::v1::Response { - metadata: _, - content: FileOrLocation::Location(url), - }) => { - let get_content::v1::Response { - file, - content_type, - content_disposition, - } = get_location_content(url).await?; + if let Some(target) = target { + if server_name != services().globals.server_name() { + services() + .rate_limiting + .check_media_pre_fetch(&target) + .await?; - get_content_thumbnail::v1::Response { - file, - content_type, - content_disposition, - } - } - Err(Error::BadRequest(ErrorKind::Unrecognized, _)) => { - let media::get_content_thumbnail::v3::Response { - file, - content_type, - content_disposition, - .. - } = services() + let thumbnail_response = match services() .sending .send_federation_request( server_name, - media::get_content_thumbnail::v3::Request { + federation_media::get_content_thumbnail::v1::Request { height, width, method: method.clone(), - server_name: server_name.to_owned(), media_id: media_id.clone(), timeout_ms: Duration::from_secs(20), - allow_redirect: false, animated, - allow_remote: false, }, ) + .await + { + Ok(federation_media::get_content_thumbnail::v1::Response { + metadata: _, + content: FileOrLocation::File(content), + }) => get_content_thumbnail::v1::Response { + file: content.file, + content_type: content.content_type, + content_disposition: content.content_disposition, + }, + + Ok(federation_media::get_content_thumbnail::v1::Response { + metadata: _, + content: FileOrLocation::Location(url), + }) => { + let get_content::v1::Response { + file, + content_type, + content_disposition, + } = get_location_content(url).await?; + + get_content_thumbnail::v1::Response { + file, + content_type, + content_disposition, + } + } + Err(Error::BadRequest(ErrorKind::Unrecognized, _)) => { + let media::get_content_thumbnail::v3::Response { + file, + content_type, + content_disposition, + .. + } = services() + .sending + .send_federation_request( + server_name, + media::get_content_thumbnail::v3::Request { + height, + width, + method: method.clone(), + server_name: server_name.to_owned(), + media_id: media_id.clone(), + timeout_ms: Duration::from_secs(20), + allow_redirect: false, + animated, + allow_remote: false, + }, + ) + .await?; + + get_content_thumbnail::v1::Response { + file, + content_type, + content_disposition, + } + } + Err(e) => return Err(e), + }; + + services() + .rate_limiting + .update_media_post_fetch(target, size(&thumbnail_response.file)?) + .await; + + services() + .media + .upload_thumbnail( + server_name, + &media_id, + thumbnail_response + .content_disposition + .as_ref() + .and_then(|cd| cd.filename.as_deref()), + thumbnail_response.content_type.as_deref(), + width.try_into().expect("all UInts are valid u32s"), + height.try_into().expect("all UInts are valid u32s"), + &thumbnail_response.file, + ) .await?; - get_content_thumbnail::v1::Response { - file, - content_type, - content_disposition, - } + Ok(thumbnail_response) + } else { + error } - Err(e) => return Err(e), - }; - - services() - .media - .upload_thumbnail( - server_name, - &media_id, - thumbnail_response - .content_disposition - .as_ref() - .and_then(|cd| cd.filename.as_deref()), - thumbnail_response.content_type.as_deref(), - width.try_into().expect("all UInts are valid u32s"), - height.try_into().expect("all UInts are valid u32s"), - &thumbnail_response.file, - ) - .await?; - - Ok(thumbnail_response) - } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) + } else { + error + } } } diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 04456543..1a9e176a 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -1,4 +1,10 @@ -use std::{collections::BTreeMap, error::Error as _, iter::FromIterator, str}; +use std::{ + collections::BTreeMap, + error::Error as _, + iter::FromIterator, + net::IpAddr, + str::{self, FromStr}, +}; use axum::{ body::Body, @@ -24,7 +30,10 @@ use serde::Deserialize; use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; -use crate::{service::appservice::RegistrationInfo, services, Error, Result}; +use crate::{ + service::{appservice::RegistrationInfo, rate_limiting::Target}, + services, Error, Result, +}; enum Token { Appservice(Box), @@ -327,6 +336,23 @@ where } }; + let sender_ip_address = parts + .headers + .get("X-Forwarded-For") + .and_then(|header| header.to_str().ok()) + .map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header)) + .and_then(|ip| IpAddr::from_str(ip).ok()); + + let target = if let Some(server_name) = sender_servername.clone() { + Some(Target::Server(server_name)) + } else if let Some(user) = &sender_user { + Some(Target::from_client_request(appservice_info.clone(), user)) + } else { + sender_ip_address.map(Target::Ip) + }; + + services().rate_limiting.check(target, metadata).await?; + let mut http_request = Request::builder().uri(parts.uri).method(parts.method); *http_request.headers_mut().unwrap() = parts.headers; @@ -377,6 +403,7 @@ where sender_servername, appservice_info, json_body, + sender_ip_address, }) } } diff --git a/src/api/ruma_wrapper/mod.rs b/src/api/ruma_wrapper/mod.rs index 862da1dc..a741676c 100644 --- a/src/api/ruma_wrapper/mod.rs +++ b/src/api/ruma_wrapper/mod.rs @@ -3,7 +3,7 @@ use ruma::{ api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, }; -use std::ops::Deref; +use std::{net::IpAddr, ops::Deref}; #[cfg(feature = "conduit_bin")] mod axum; @@ -14,6 +14,7 @@ pub struct Ruma { pub sender_user: Option, pub sender_device: Option, pub sender_servername: Option, + pub sender_ip_address: Option, // This is None when body is not a valid string pub json_body: Option, pub appservice_info: Option, diff --git a/src/api/server_server.rs b/src/api/server_server.rs index adc764ff..10f4fb8c 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -6,6 +6,7 @@ use crate::{ globals::SigningKeys, media::FileMeta, pdu::{gen_event_id_canonical_json, PduBuilder}, + rate_limiting::Target, }, services, utils, Error, PduEvent, Result, Ruma, SUPPORTED_VERSIONS, }; @@ -2237,6 +2238,13 @@ pub async fn create_invite_route( pub async fn get_content_route( body: Ruma, ) -> Result { + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + let target = Some(Target::Server(sender_servername.to_owned())); + services() .media .check_blocked(services().globals.server_name(), &body.media_id)?; @@ -2247,7 +2255,11 @@ pub async fn get_content_route( file, }) = services() .media - .get(services().globals.server_name(), &body.media_id, true) + .get( + services().globals.server_name(), + &body.media_id, + target.clone(), + ) .await? { Ok(get_content::v1::Response::new( @@ -2269,6 +2281,13 @@ pub async fn get_content_route( pub async fn get_content_thumbnail_route( body: Ruma, ) -> Result { + let Ruma:: { + body, + sender_servername, + .. + } = body; + let sender_servername = sender_servername.expect("server is authenticated"); + services() .media .check_blocked(services().globals.server_name(), &body.media_id)?; @@ -2288,7 +2307,7 @@ pub async fn get_content_thumbnail_route( body.height .try_into() .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, - true, + Some(Target::Server(sender_servername)), ) .await? else { diff --git a/src/config/mod.rs b/src/config/mod.rs index 098dc20d..0e7549cf 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -17,7 +17,9 @@ use url::Url; use crate::Error; mod proxy; -use self::proxy::ProxyConfig; +pub mod rate_limiting; + +use self::{proxy::ProxyConfig, rate_limiting::Config as RateLimitingConfig}; const SHA256_HEX_LENGTH: u8 = 64; @@ -92,6 +94,8 @@ pub struct IncompleteConfig { #[serde(default)] pub media: IncompleteMediaConfig, + pub rate_limiting: RateLimitingConfig, + pub emergency_password: Option, #[serde(flatten)] @@ -138,6 +142,8 @@ pub struct Config { pub media: MediaConfig, + pub rate_limiting: RateLimitingConfig, + pub emergency_password: Option, pub catchall: BTreeMap, @@ -184,6 +190,7 @@ impl From for Config { turn_ttl, turn, media, + rate_limiting, emergency_password, catchall, } = val; @@ -281,6 +288,7 @@ impl From for Config { log, turn, media, + rate_limiting, emergency_password, catchall, } diff --git a/src/config/rate_limiting.rs b/src/config/rate_limiting.rs new file mode 100644 index 00000000..a1f26886 --- /dev/null +++ b/src/config/rate_limiting.rs @@ -0,0 +1,115 @@ +use std::{collections::HashMap, hash::Hash, num::NonZeroU64}; + +use bytesize::ByteSize; +use serde::Deserialize; + +use crate::service::rate_limiting::{ClientRestriction, FederationRestriction, Restriction}; + +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + #[serde(flatten)] + pub target: ConfigFragment, + pub global: ConfigFragment, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ConfigFragment { + pub client: ConfigSideFragment, + pub federation: ConfigSideFragment, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ConfigSideFragment +where + K: Eq + Hash, +{ + #[serde(flatten)] + pub map: HashMap, + pub media: C, +} + +impl ConfigFragment { + pub fn get(&self, restriction: &Restriction) -> &RequestLimitation { + // Maybe look into https://github.com/moriyoshi-kasuga/enum-table + match restriction { + Restriction::Client(client_restriction) => { + self.client.map.get(client_restriction).unwrap() + } + Restriction::Federation(federation_restriction) => { + self.federation.map.get(federation_restriction).unwrap() + } + Restriction::Media(media_restriction) => todo!(), + Restriction::CatchAll => todo!(), + } + } +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct RequestLimitation { + #[serde(flatten)] + pub timeframe: Timeframe, + pub burst_capacity: NonZeroU64, +} + +#[derive(Deserialize, Clone, Copy, Debug)] +#[serde(rename_all = "snake_case")] +// When deserializing, we want this prefix +#[allow(clippy::enum_variant_names)] +pub enum Timeframe { + PerSecond(NonZeroU64), + PerMinute(NonZeroU64), + PerHour(NonZeroU64), + PerDay(NonZeroU64), +} + +impl Timeframe { + pub fn nano_gap(&self) -> u64 { + match self { + Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(), + Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(), + Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(), + Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(), + } + } +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct ClientMediaConfig { + pub download: MediaLimitation, + pub upload: MediaLimitation, + pub fetch: MediaLimitation, +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct FederationMediaConfig { + pub download: MediaLimitation, +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct MediaLimitation { + #[serde(flatten)] + pub timeframe: MediaTimeframe, + pub burst_capacity: ByteSize, +} + +#[derive(Deserialize, Clone, Copy, Debug)] +#[serde(rename_all = "snake_case")] +// When deserializing, we want this prefix +#[allow(clippy::enum_variant_names)] +pub enum MediaTimeframe { + PerSecond(ByteSize), + PerMinute(ByteSize), + PerHour(ByteSize), + PerDay(ByteSize), +} + +impl MediaTimeframe { + pub fn bytes_per_sec(&self) -> u64 { + match self { + MediaTimeframe::PerSecond(t) => t.as_u64(), + MediaTimeframe::PerMinute(t) => t.as_u64() / 60, + MediaTimeframe::PerHour(t) => t.as_u64() / (60 * 60), + MediaTimeframe::PerDay(t) => t.as_u64() / (60 * 60 * 24), + } + } +} diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs index 695c7d3c..f6326cdb 100644 --- a/src/database/key_value/media.rs +++ b/src/database/key_value/media.rs @@ -203,19 +203,7 @@ impl service::media::Data for KeyValueDatabase { let is_blocked_via_filehash = self.is_blocked_filehash(&sha256_digest)?; - let time_info = if let Some(filehash_meta) = self - .filehash_metadata - .get(&sha256_digest)? - .map(FilehashMetadata::from_vec) - { - Some(FileInfo { - creation: filehash_meta.creation(&sha256_digest)?, - last_access: filehash_meta.last_access(&sha256_digest)?, - size: filehash_meta.size(&sha256_digest)?, - }) - } else { - None - }; + let file_info = self.file_info(&sha256_digest)?; Some(MediaQueryFileInfo { uploader_localpart, @@ -224,7 +212,7 @@ impl service::media::Data for KeyValueDatabase { content_type, unauthenticated_access_permitted, is_blocked_via_filehash, - file_info: time_info, + file_info, }) } else { None @@ -1353,6 +1341,24 @@ impl service::media::Data for KeyValueDatabase { Ok(()) } } + + fn file_info(&self, sha256_digest: &[u8]) -> Result, Error> { + Ok( + if let Some(filehash_meta) = self + .filehash_metadata + .get(sha256_digest)? + .map(FilehashMetadata::from_vec) + { + Some(FileInfo { + creation: filehash_meta.creation(&sha256_digest)?, + last_access: filehash_meta.last_access(&sha256_digest)?, + size: filehash_meta.size(&sha256_digest)?, + }) + } else { + None + }, + ) + } } impl KeyValueDatabase { diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 0a3f87b9..a7ef3d83 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -41,6 +41,7 @@ use tokio::sync::{mpsc, Mutex, RwLock}; use crate::{ api::client_server::{self, leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH}, + service::rate_limiting::Target, services, utils::{self, HtmlEscape}, Error, PduEvent, Result, @@ -1174,8 +1175,12 @@ impl Service { file, content_type, content_disposition, - } = client_server::media::get_content(server_name, media_id.to_owned(), true, true) - .await?; + } = client_server::media::get_content( + server_name, + media_id.to_owned(), + Some(Target::User(services().globals.server_user().to_owned())), + ) + .await?; if let Ok(image) = image::load_from_memory(&file) { let filename = content_disposition.and_then(|cd| cd.filename); diff --git a/src/service/media/data.rs b/src/service/media/data.rs index 444f5f9a..53934837 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,7 +1,7 @@ use ruma::{OwnedServerName, ServerName, UserId}; use sha2::{digest::Output, Sha256}; -use crate::{config::MediaRetentionConfig, Error, Result}; +use crate::{config::MediaRetentionConfig, service::media::FileInfo, Error, Result}; use super::{ BlockedMediaInfo, DbFileMeta, MediaListItem, MediaQuery, MediaType, ServerNameOrUserId, @@ -124,4 +124,7 @@ pub trait Data: Send + Sync { fn update_last_accessed(&self, server_name: &ServerName, media_id: &str) -> Result<()>; fn update_last_accessed_filehash(&self, sha256_digest: &[u8]) -> Result<()>; + + /// Returns the known information about a file + fn file_info(&self, sha256_digest: &[u8]) -> Result>; } diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 2f5c814d..5d3283ad 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -17,6 +17,7 @@ use tracing::{error, info, warn}; use crate::{ config::{DirectoryStructure, MediaBackendConfig, S3MediaBackend}, + service::rate_limiting::Target, services, utils, Error, Result, }; use image::imageops::FilterType; @@ -237,7 +238,7 @@ impl Service { &self, servername: &ServerName, media_id: &str, - authenticated: bool, + target: Option, ) -> Result> { let DbFileMeta { sha256_digest, @@ -246,12 +247,19 @@ impl Service { unauthenticated_access_permitted, } = self.db.search_file_metadata(servername, media_id)?; - if !(authenticated || unauthenticated_access_permitted) { + if !(target.as_ref().is_some_and(Target::is_authenticated) + || unauthenticated_access_permitted) + { return Ok(None); } let file = self.get_file(&sha256_digest, None).await?; + services() + .rate_limiting + .check_media_download(target, size(&file)?) + .await?; + Ok(Some(FileMeta { content_disposition: content_disposition(filename, &content_type), content_type, @@ -288,7 +296,7 @@ impl Service { media_id: &str, width: u32, height: u32, - authenticated: bool, + target: Option, ) -> Result> { if let Some((width, height, crop)) = self.thumbnail_properties(width, height) { if let Ok(DbFileMeta { @@ -300,10 +308,19 @@ impl Service { .db .search_thumbnail_metadata(servername, media_id, width, height) { - if !(authenticated || unauthenticated_access_permitted) { + if !(target.as_ref().is_some_and(Target::is_authenticated) + || unauthenticated_access_permitted) + { return Ok(None); } + let file_info = self.file_info(&sha256_digest)?; + + services() + .rate_limiting + .check_media_download(target, file_info.size) + .await?; + // Using saved thumbnail let file = self .get_file(&sha256_digest, Some((servername, media_id))) @@ -314,19 +331,15 @@ impl Service { content_type, file, })) - } else if !authenticated { + } else if !target.as_ref().is_some_and(Target::is_authenticated) { return Ok(None); } else if let Ok(DbFileMeta { sha256_digest, filename, content_type, - unauthenticated_access_permitted, + .. }) = self.db.search_file_metadata(servername, media_id) { - if !(authenticated || unauthenticated_access_permitted) { - return Ok(None); - } - let content_disposition = content_disposition(filename.clone(), &content_type); // Generate a thumbnail let file = self.get_file(&sha256_digest, None).await?; @@ -426,7 +439,9 @@ impl Service { return Ok(None); }; - if !(authenticated || unauthenticated_access_permitted) { + if !(target.as_ref().is_some_and(Target::is_authenticated) + || unauthenticated_access_permitted) + { return Ok(None); } @@ -662,6 +677,13 @@ impl Service { .update_last_accessed_filehash(sha256_digest) .map(|_| file) } + + fn file_info(&self, sha256_digest: &[u8]) -> Result { + self.db + .file_info(sha256_digest) + .transpose() + .unwrap_or_else(|| Err(Error::BadRequest(ErrorKind::NotFound, "Fi)le not found"))) + } } /// Creates the media file, using the configured media backend diff --git a/src/service/mod.rs b/src/service/mod.rs index 432c0e7a..6c511391 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -17,6 +17,7 @@ pub mod key_backups; pub mod media; pub mod pdu; pub mod pusher; +pub mod rate_limiting; pub mod rooms; pub mod sending; pub mod transaction_ids; @@ -36,6 +37,7 @@ pub struct Services { pub key_backups: key_backups::Service, pub media: Arc, pub sending: Arc, + pub rate_limiting: Arc, } impl Services { @@ -123,6 +125,8 @@ impl Services { media: Arc::new(media::Service { db }), sending: sending::Service::build(db, &config), + rate_limiting: rate_limiting::Service::build(&config), + globals: globals::Service::load(db, config)?, }) } diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs new file mode 100644 index 00000000..057411cd --- /dev/null +++ b/src/service/rate_limiting/mod.rs @@ -0,0 +1,591 @@ +use std::{ + collections::{hash_map::Entry, HashMap}, + net::IpAddr, + sync::Arc, + time::Duration, +}; + +use ruma::{ + api::{ + client::error::{ErrorKind, RetryAfter}, + federation::membership::create_knock_event, + Metadata, + }, + OwnedServerName, OwnedUserId, UserId, +}; +use serde::Deserialize; +use tokio::{ + sync::{Mutex, MutexGuard}, + time::Instant, +}; + +use crate::{ + config::rate_limiting::MediaLimitation, service::appservice::RegistrationInfo, services, + Config, Error, Result, +}; + +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Target { + User(OwnedUserId), + // Server endpoints should be rate-limited on a server and room basis + Server(OwnedServerName), + Appservice { id: String, rate_limited: bool }, + Ip(IpAddr), +} + +impl Target { + pub fn from_client_request( + registration_info: Option, + sender_user: &UserId, + ) -> Self { + if let Some(info) = registration_info { + // `rate_limited` only effects "masqueraded users", "The sender [user?] is excluded" + return Target::Appservice { + id: info.registration.id, + rate_limited: info.registration.rate_limited.unwrap_or(true) + && !(sender_user.server_name() == services().globals.server_name() + && info.registration.sender_localpart == sender_user.localpart()), + }; + } + + Target::User(sender_user.to_owned()) + } + + pub fn from_client_request_optional_auth( + registration_info: Option, + sender_user: &Option, + ip_addr: Option, + ) -> Option { + if let Some(sender_user) = sender_user.as_ref() { + Some(Self::from_client_request(registration_info, sender_user)) + } else { + ip_addr.map(Self::Ip) + } + } + + fn rate_limited(&self) -> bool { + match self { + Target::User(user_id) => user_id != services().globals.server_user(), + Target::Appservice { + id: _, + rate_limited, + } => *rate_limited, + _ => true, + } + } + + pub fn is_authenticated(&self) -> bool { + !matches!(self, Target::Ip(_)) + } +} + +#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Restriction { + Client(ClientRestriction), + Federation(FederationRestriction), + Media(MediaRestriction), + + #[default] + CatchAll, +} + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[serde(rename_all = "snake_case")] +pub enum ClientRestriction { + Registration, + Login, + RegistrationTokenValidity, + + SendEvent, + + Join, + Invite, + Knock, + + SendReport, + CreateAlias, +} + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[serde(rename_all = "snake_case")] +pub enum FederationRestriction { + Join, + Knock, + Invite, + + // Transactions should be handled by a completely dedicated rate-limiter + Transaction, +} + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum MediaRestriction { + Create, + Download, +} + +impl From for Restriction { + fn from(value: Metadata) -> Self { + use ruma::api::{ + client::{ + account::{check_registration_token_validity, register}, + alias::create_alias, + authenticated_media::{ + get_content, get_content_as_filename, get_content_thumbnail, get_media_preview, + }, + knock::knock_room, + media::{self, create_content, create_mxc_uri}, + membership::{invite_user, join_room_by_id, join_room_by_id_or_alias}, + message::send_message_event, + reporting::report_user, + room::{report_content, report_room}, + session::login, + state::send_state_event, + }, + federation::{ + authenticated_media::{ + get_content as federation_get_content, + get_content_thumbnail as federation_get_content_thumbnail, + }, + membership::{create_invite, create_join_event}, + }, + IncomingRequest, + }; + use Restriction::*; + + match value { + register::v3::Request::METADATA => Client(ClientRestriction::Registration), + check_registration_token_validity::v1::Request::METADATA => { + Client(ClientRestriction::RegistrationTokenValidity) + } + login::v3::Request::METADATA => Client(ClientRestriction::Login), + send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => { + Client(ClientRestriction::SendEvent) + } + join_room_by_id::v3::Request::METADATA + | join_room_by_id_or_alias::v3::Request::METADATA => Client(ClientRestriction::Join), + invite_user::v3::Request::METADATA => Client(ClientRestriction::Invite), + knock_room::v3::Request::METADATA => Client(ClientRestriction::Knock), + report_user::v3::Request::METADATA + | report_content::v3::Request::METADATA + | report_room::v3::Request::METADATA => Client(ClientRestriction::SendReport), + create_alias::v3::Request::METADATA => Client(ClientRestriction::CreateAlias), + // NOTE: handle async media upload in a way that doesn't half the number of uploads you can do within a short timeframe, while not allowing pre-generation of MXC uris to allow uploading double the number of media at once + create_content::v3::Request::METADATA | create_mxc_uri::v1::Request::METADATA => { + Media(MediaRestriction::Create) + } + // Unauthenticate media is deprecated + #[allow(deprecated)] + media::get_content::v3::Request::METADATA + | media::get_content_as_filename::v3::Request::METADATA + | media::get_content_thumbnail::v3::Request::METADATA + | media::get_media_preview::v3::Request::METADATA + | get_content::v1::Request::METADATA + | get_content_as_filename::v1::Request::METADATA + | get_content_thumbnail::v1::Request::METADATA + | get_media_preview::v1::Request::METADATA + | federation_get_content::v1::Request::METADATA + | federation_get_content_thumbnail::v1::Request::METADATA => { + Media(MediaRestriction::Download) + } + // v1 is deprecated + #[allow(deprecated)] + create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => { + Federation(FederationRestriction::Join) + } + create_knock_event::v1::Request::METADATA => Federation(FederationRestriction::Knock), + create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => { + Federation(FederationRestriction::Invite) + } + + _ => Self::default(), + } + } +} + +type MediaBucket = Mutex>>>; +type GlobalMediaBucket = Arc>; + +pub struct Service { + buckets: Mutex>>>, + global_bucket: Mutex>>>, + + media_upload: MediaBucket, + media_fetch: MediaBucket, + media_download: MediaBucket, + + global_media_upload: GlobalMediaBucket, + global_media_fetch: GlobalMediaBucket, + global_media_download_client: GlobalMediaBucket, + global_media_download_federation: GlobalMediaBucket, +} + +impl Service { + pub fn build(config: &Config) -> Arc { + let now = Instant::now(); + let global_media_config = &config.rate_limiting.global; + + Arc::new(Self { + buckets: Mutex::new(HashMap::new()), + global_bucket: Mutex::new(HashMap::new()), + + media_upload: Mutex::new(HashMap::new()), + media_fetch: Mutex::new(HashMap::new()), + media_download: Mutex::new(HashMap::new()), + + global_media_upload: default_media_entry(global_media_config.client.media.upload, now), + global_media_fetch: default_media_entry(global_media_config.client.media.fetch, now), + global_media_download_client: default_media_entry( + global_media_config.client.media.download, + now, + ), + global_media_download_federation: default_media_entry( + global_media_config.federation.media.download, + now, + ), + }) + } + + //TODO: use checked and saturating arithmetic + + /// Takes the target and request, and either accepts the request while adding to the + /// bucket, or rejects the request, returning the duration that should be waited until + /// the request should be retried. + pub async fn check(&self, target: Option, request: Metadata) -> Result<()> { + let restriction: Restriction = request.into(); + let arrival = Instant::now(); + + { + let map = self.global_bucket.lock().await; + + if let Some(value) = map.get(&restriction) { + let value = value.lock().await; + + if arrival.checked_duration_since(*value).is_none() { + instant_to_err(&value)?; + } + } + } + + if let Some(target) = target { + let config = services() + .globals + .config + .rate_limiting + .target + .get(&restriction); + + let mut map = self.buckets.lock().await; + let entry = map.entry((target, restriction)); + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + let mut entry = entry.lock().await; + + if arrival.checked_duration_since(*entry).is_none() { + return instant_to_err(&entry); + } + + let min_instant = arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * config.burst_capacity.get(), + ); + *entry = + entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap()); + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(Arc::new(Mutex::new( + arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * (config.burst_capacity.get() - 1), + ), + ))); + } + } + } + + { + let config = services() + .globals + .config + .rate_limiting + .global + .get(&restriction); + + let mut map = self.global_bucket.lock().await; + + let entry = map.entry(restriction); + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + let mut entry = entry.lock().await; + + if arrival.checked_duration_since(*entry).is_none() { + return instant_to_err(&entry); + } + + let min_instant = arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * config.burst_capacity.get(), + ); + *entry = + entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap()); + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(Arc::new(Mutex::new( + arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * (config.burst_capacity.get() - 1), + ), + ))); + } + } + } + + Ok(()) + } + + pub async fn check_media_download(&self, target: Option, size: u64) -> Result<()> { + // All targets besides servers use the client-server API + let (target_limitation, global_limitation, global_bucket) = + if let Some(Target::Server(_)) = &target { + ( + services() + .globals + .config + .rate_limiting + .target + .federation + .media + .download, + services() + .globals + .config + .rate_limiting + .global + .federation + .media + .download, + &self.global_media_download_federation, + ) + } else { + ( + services() + .globals + .config + .rate_limiting + .target + .client + .media + .download, + services() + .globals + .config + .rate_limiting + .global + .client + .media + .download, + &self.global_media_download_client, + ) + }; + + check_media( + target, + size, + target_limitation, + global_limitation, + &self.media_download, + global_bucket, + ) + .await + } + + pub async fn check_media_upload(&self, target: Target, size: u64) -> Result<()> { + let target_limitation = services() + .globals + .config + .rate_limiting + .target + // Media can only be uploaded on the client-server API + .client + .media + .upload; + + let global_limitation = services() + .globals + .config + .rate_limiting + .global + // Media can only be uploaded on the client-server API + .client + .media + .upload; + + check_media( + Some(target), + size, + target_limitation, + global_limitation, + &self.media_upload, + &self.global_media_upload, + ) + .await + } + + pub async fn check_media_pre_fetch(&self, target: &Target) -> Result<()> { + let arrival = Instant::now(); + + let global_bucket = self.global_media_fetch; + + let map = self.media_fetch.lock().await; + if let Some(mutex) = map.get(target) { + let mutex = mutex.lock().await; + + if arrival.checked_duration_since(*mutex).is_none() { + return instant_to_err(&mutex); + } + } + + Ok(()) + } + + pub async fn update_media_post_fetch(&self, target: Target, size: u64) { + if !target.rate_limited() { + return; + } + + let limitation = services() + .globals + .config + .rate_limiting + .target + // Media can only be "fetched" (causing our server to download media from another server) by the client-server API + .client + .media + .fetch; + + let arrival = Instant::now(); + + let mut map = self.media_fetch.lock().await; + let entry = map.entry(target); + + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + + update_media_entry(size, &limitation, &arrival, entry, false).await; + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(Arc::new(Mutex::new( + arrival + - Duration::from_nanos( + limitation.burst_capacity.as_u64() + / limitation.timeframe.bytes_per_sec(), + ), + ))); + } + } + } +} + +async fn update_media_entry( + size: u64, + limitation: &MediaLimitation, + arrival: &Instant, + entry: Arc>, + and_check: bool, +) -> Result<()> { + let mut entry = entry.lock().await; + + //TODO: use more precise conversion than secs + let proposed_entry = get_proposed_entry(size, limitation, arrival, &entry, and_check)?; + + *entry = proposed_entry; + + Ok(()) +} + +fn get_proposed_entry( + size: u64, + limitation: &MediaLimitation, + arrival: &Instant, + entry: &MutexGuard<'_, Instant>, + and_check: bool, +) -> Result { + let min_instant = *arrival + - Duration::from_secs( + limitation.burst_capacity.as_u64() / limitation.timeframe.bytes_per_sec(), + ); + + let proposed_entry = + entry.max(min_instant) + Duration::from_secs(size / limitation.timeframe.bytes_per_sec()); + + if and_check && arrival.checked_duration_since(proposed_entry).is_none() { + return instant_to_err(&proposed_entry).map(|_| proposed_entry); + } + + Ok(proposed_entry) +} + +async fn check_media( + target: Option, + size: u64, + target_limitation: MediaLimitation, + global_limitation: MediaLimitation, + target_map: &MediaBucket, + global_bucket: &GlobalMediaBucket, +) -> Result<()> { + if !target.as_ref().is_some_and(Target::rate_limited) { + return Ok(()); + } + + let arrival = Instant::now(); + + let mut global_bucket = global_bucket.lock().await; + let proposed = get_proposed_entry(size, &global_limitation, &arrival, &global_bucket, true)?; + + if let Some(target) = target { + let mut map = target_map.lock().await; + let entry = map.entry(target); + + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + + update_media_entry(size, &target_limitation, &arrival, entry, true).await; + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(default_media_entry(target_limitation, arrival)); + } + } + } + + *global_bucket = proposed; + + Ok(()) +} + +fn default_media_entry( + target_limitation: MediaLimitation, + arrival: Instant, +) -> Arc> { + Arc::new(Mutex::new( + arrival + - Duration::from_nanos( + target_limitation.burst_capacity.as_u64() + / target_limitation.timeframe.bytes_per_sec(), + ), + )) +} + +fn instant_to_err(instant: &Instant) -> Result<()> { + let now = Instant::now(); + + Err(Error::BadRequest( + ErrorKind::LimitExceeded { + // Not using ::DateTime because conversion from Instant to SystemTime is convoluted + retry_after: Some(RetryAfter::Delay(instant.duration_since(now))), + }, + "Rate limit exceeded", + )) +}