diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index e922b157..ce770519 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -3,7 +3,13 @@ use std::time::Duration; -use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma}; +use crate::{ + service::{ + media::{size, FileMeta}, + rate_limiting::Target, + }, + services, utils, Error, Result, Ruma, +}; use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE}; use ruma::{ api::{ @@ -54,6 +60,8 @@ pub async fn get_media_config_auth_route( pub async fn create_content_route( body: Ruma, ) -> Result { + let sender_user = body.sender_user.expect("user is authenticated"); + let create_content::v3::Request { filename, content_type, @@ -61,6 +69,13 @@ pub async fn create_content_route( .. } = body.body; + let target = Target::from_client_request(body.appservice_info, &sender_user); + + services() + .rate_limiting + .check_media_upload(target, size(&file)?) + .await?; + let media_id = utils::random_string(MXC_LENGTH); services() @@ -71,7 +86,7 @@ pub async fn create_content_route( filename.as_deref(), content_type.as_deref(), &file, - body.sender_user.as_deref(), + Some(&sender_user), ) .await?; @@ -176,6 +191,17 @@ pub async fn get_content_route( ) .await?; + if let Some(target) = Target::from_client_request_optional_auth( + body.appservice_info, + &body.sender_user, + body.sender_ip_address, + ) { + services() + .rate_limiting + .update_media_post_fetch(target, size(&file)?) + .await; + } + Ok(media::get_content::v3::Response { file, content_type, @@ -190,7 +216,18 @@ pub async fn get_content_route( pub async fn get_content_auth_route( body: Ruma, ) -> Result { - get_content(&body.server_name, body.media_id.clone(), true, true).await + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let resp = get_content(&body.server_name, body.media_id.clone(), true, true).await?; + + let target = Target::from_client_request(body.appservice_info, sender_user); + + services() + .rate_limiting + .update_media_post_fetch(target, size(&resp.file)?) + .await; + + Ok(resp) } pub async fn get_content( @@ -249,6 +286,17 @@ pub async fn get_content_as_filename_route( ) .await?; + if let Some(target) = Target::from_client_request_optional_auth( + body.appservice_info, + &body.sender_user, + body.sender_ip_address, + ) { + services() + .rate_limiting + .update_media_post_fetch(target, size(&file)?) + .await; + } + Ok(media::get_content_as_filename::v3::Response { file, content_type, @@ -263,14 +311,25 @@ pub async fn get_content_as_filename_route( pub async fn get_content_as_filename_auth_route( body: Ruma, ) -> Result { - get_content_as_filename( + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let resp = get_content_as_filename( &body.server_name, body.media_id.clone(), body.filename.clone(), true, true, ) - .await + .await?; + + let target = Target::from_client_request(body.appservice_info, sender_user); + + services() + .rate_limiting + .update_media_post_fetch(target, size(&resp.file)?) + .await; + + Ok(resp) } async fn get_content_as_filename( @@ -337,6 +396,17 @@ pub async fn get_content_thumbnail_route( ) .await?; + if let Some(target) = Target::from_client_request_optional_auth( + body.appservice_info, + &body.sender_user, + body.sender_ip_address, + ) { + services() + .rate_limiting + .update_media_post_fetch(target, size(&file)?) + .await; + } + Ok(media::get_content_thumbnail::v3::Response { file, content_type, @@ -351,7 +421,9 @@ pub async fn get_content_thumbnail_route( pub async fn get_content_thumbnail_auth_route( body: Ruma, ) -> Result { - get_content_thumbnail( + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let resp = get_content_thumbnail( &body.server_name, body.media_id.clone(), body.height, @@ -361,7 +433,16 @@ pub async fn get_content_thumbnail_auth_route( true, true, ) - .await + .await?; + + let target = Target::from_client_request(body.appservice_info, sender_user); + + services() + .rate_limiting + .update_media_post_fetch(target, size(&resp.file)?) + .await; + + Ok(resp) } #[allow(clippy::too_many_arguments)] diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 04456543..1a9e176a 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -1,4 +1,10 @@ -use std::{collections::BTreeMap, error::Error as _, iter::FromIterator, str}; +use std::{ + collections::BTreeMap, + error::Error as _, + iter::FromIterator, + net::IpAddr, + str::{self, FromStr}, +}; use axum::{ body::Body, @@ -24,7 +30,10 @@ use serde::Deserialize; use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; -use crate::{service::appservice::RegistrationInfo, services, Error, Result}; +use crate::{ + service::{appservice::RegistrationInfo, rate_limiting::Target}, + services, Error, Result, +}; enum Token { Appservice(Box), @@ -327,6 +336,23 @@ where } }; + let sender_ip_address = parts + .headers + .get("X-Forwarded-For") + .and_then(|header| header.to_str().ok()) + .map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header)) + .and_then(|ip| IpAddr::from_str(ip).ok()); + + let target = if let Some(server_name) = sender_servername.clone() { + Some(Target::Server(server_name)) + } else if let Some(user) = &sender_user { + Some(Target::from_client_request(appservice_info.clone(), user)) + } else { + sender_ip_address.map(Target::Ip) + }; + + services().rate_limiting.check(target, metadata).await?; + let mut http_request = Request::builder().uri(parts.uri).method(parts.method); *http_request.headers_mut().unwrap() = parts.headers; @@ -377,6 +403,7 @@ where sender_servername, appservice_info, json_body, + sender_ip_address, }) } } diff --git a/src/api/ruma_wrapper/mod.rs b/src/api/ruma_wrapper/mod.rs index 862da1dc..a741676c 100644 --- a/src/api/ruma_wrapper/mod.rs +++ b/src/api/ruma_wrapper/mod.rs @@ -3,7 +3,7 @@ use ruma::{ api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, }; -use std::ops::Deref; +use std::{net::IpAddr, ops::Deref}; #[cfg(feature = "conduit_bin")] mod axum; @@ -14,6 +14,7 @@ pub struct Ruma { pub sender_user: Option, pub sender_device: Option, pub sender_servername: Option, + pub sender_ip_address: Option, // This is None when body is not a valid string pub json_body: Option, pub appservice_info: Option, diff --git a/src/api/server_server.rs b/src/api/server_server.rs index adc764ff..8aff7541 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -4,8 +4,9 @@ use crate::{ api::client_server::{self, claim_keys_helper, get_keys_helper}, service::{ globals::SigningKeys, - media::FileMeta, + media::{size, FileMeta}, pdu::{gen_event_id_canonical_json, PduBuilder}, + rate_limiting::Target, }, services, utils, Error, PduEvent, Result, Ruma, SUPPORTED_VERSIONS, }; @@ -2237,6 +2238,11 @@ pub async fn create_invite_route( pub async fn get_content_route( body: Ruma, ) -> Result { + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + services() .media .check_blocked(services().globals.server_name(), &body.media_id)?; @@ -2250,6 +2256,11 @@ pub async fn get_content_route( .get(services().globals.server_name(), &body.media_id, true) .await? { + services() + .rate_limiting + .update_media_post_fetch(Target::Server(sender_servername.to_owned()), size(&file)?) + .await; + Ok(get_content::v1::Response::new( ContentMetadata::new(), FileOrLocation::File(Content { @@ -2269,6 +2280,11 @@ pub async fn get_content_route( pub async fn get_content_thumbnail_route( body: Ruma, ) -> Result { + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + services() .media .check_blocked(services().globals.server_name(), &body.media_id)?; @@ -2295,6 +2311,11 @@ pub async fn get_content_thumbnail_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); }; + services() + .rate_limiting + .update_media_post_fetch(Target::Server(sender_servername.to_owned()), size(&file)?) + .await; + services() .media .upload_thumbnail( diff --git a/src/config/mod.rs b/src/config/mod.rs index 098dc20d..b484647d 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -17,7 +17,9 @@ use url::Url; use crate::Error; mod proxy; -use self::proxy::ProxyConfig; +pub mod rate_limiting; + +use self::{proxy::ProxyConfig, rate_limiting::Config as RateLimitingConfig}; const SHA256_HEX_LENGTH: u8 = 64; @@ -138,6 +140,8 @@ pub struct Config { pub media: MediaConfig, + pub rate_limiting: RateLimitingConfig, + pub emergency_password: Option, pub catchall: BTreeMap, @@ -281,6 +285,7 @@ impl From for Config { log, turn, media, + rate_limiting: todo!(), emergency_password, catchall, } diff --git a/src/config/rate_limiting.rs b/src/config/rate_limiting.rs new file mode 100644 index 00000000..c2fc2b43 --- /dev/null +++ b/src/config/rate_limiting.rs @@ -0,0 +1,117 @@ +use std::{collections::HashMap, num::NonZeroU64}; + +use bytesize::ByteSize; +use serde::Deserialize; + +use crate::service::rate_limiting::{ClientRestriction, FederationRestriction, Restriction}; + +#[derive(Debug, Clone)] +pub struct Config { + pub media: MediaConfig, + pub client: HashMap, + pub federation: HashMap, + pub global: GlobalConfig, +} + +impl Config { + pub fn get(&self, restriction: &Restriction) -> &Limitation { + // Maybe look into https://github.com/moriyoshi-kasuga/enum-table + match restriction { + Restriction::Client(client_restriction) => self.client.get(client_restriction).unwrap(), + Restriction::Federation(federation_restriction) => { + self.federation.get(federation_restriction).unwrap() + } + Restriction::Media(media_restriction) => todo!(), + Restriction::CatchAll => todo!(), + } + } +} + +#[derive(Debug, Clone)] +pub struct GlobalConfig { + pub client: HashMap, + pub federation: HashMap, +} + +//TODO: fold into one instead of copy-pasting +impl GlobalConfig { + pub fn get(&self, restriction: &Restriction) -> &Limitation { + // Maybe look into https://github.com/moriyoshi-kasuga/enum-table + match restriction { + Restriction::Client(client_restriction) => self.client.get(client_restriction).unwrap(), + Restriction::Federation(federation_restriction) => { + self.federation.get(federation_restriction).unwrap() + } + Restriction::Media(media_restriction) => todo!(), + Restriction::CatchAll => todo!(), + } + } +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct Limitation { + #[serde(flatten)] + pub timeframe: Timeframe, + pub burst_capacity: NonZeroU64, +} + +#[derive(Deserialize, Clone, Copy, Debug)] +#[serde(rename_all = "snake_case")] +// When deserializing, we want this prefix +#[allow(clippy::enum_variant_names)] +pub enum Timeframe { + PerSecond(NonZeroU64), + PerMinute(NonZeroU64), + PerHour(NonZeroU64), + PerDay(NonZeroU64), +} + +impl Timeframe { + pub fn nano_gap(&self) -> u64 { + match self { + Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(), + Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(), + Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(), + Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct MediaConfig { + pub upload: MediaLimitation, + pub fetch: MediaLimitation, +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct MediaLimitation { + #[serde(flatten)] + pub timeframe: MediaTimeframe, + pub burst_capacity: ByteSize, +} + +#[derive(Deserialize, Clone, Copy, Debug)] +#[serde(rename_all = "snake_case")] +// When deserializing, we want this prefix +#[allow(clippy::enum_variant_names)] +pub enum MediaTimeframe { + PerSecond(ByteSize), + PerMinute(ByteSize), + PerHour(ByteSize), + PerDay(ByteSize), +} + +impl MediaTimeframe { + pub fn bytes_per_sec(&self) -> u64 { + match self { + MediaTimeframe::PerSecond(t) => t.as_u64(), + MediaTimeframe::PerMinute(t) => t.as_u64() / 60, + MediaTimeframe::PerHour(t) => t.as_u64() / (60 * 60), + MediaTimeframe::PerDay(t) => t.as_u64() / (60 * 60 * 24), + } + } +} + +fn default_non_zero() -> NonZeroU64 { + NonZeroU64::MIN +} diff --git a/src/service/mod.rs b/src/service/mod.rs index 432c0e7a..cc171ebe 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -17,6 +17,7 @@ pub mod key_backups; pub mod media; pub mod pdu; pub mod pusher; +pub mod rate_limiting; pub mod rooms; pub mod sending; pub mod transaction_ids; @@ -36,6 +37,7 @@ pub struct Services { pub key_backups: key_backups::Service, pub media: Arc, pub sending: Arc, + pub rate_limiting: Arc, } impl Services { @@ -123,6 +125,8 @@ impl Services { media: Arc::new(media::Service { db }), sending: sending::Service::build(db, &config), + rate_limiting: rate_limiting::Service::build(), + globals: globals::Service::load(db, config)?, }) } diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs new file mode 100644 index 00000000..bf496a55 --- /dev/null +++ b/src/service/rate_limiting/mod.rs @@ -0,0 +1,418 @@ +use std::{ + collections::{hash_map::Entry, HashMap}, + net::IpAddr, + sync::Arc, + time::Duration, +}; + +use ruma::{ + api::{ + client::error::{ErrorKind, RetryAfter}, + federation::membership::create_knock_event, + Metadata, + }, + OwnedServerName, OwnedUserId, UserId, +}; +use serde::Deserialize; +use tokio::{sync::Mutex, time::Instant}; + +use crate::{ + config::rate_limiting::MediaLimitation, service::appservice::RegistrationInfo, services, Error, + Result, +}; + +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Target { + User(OwnedUserId), + // Server endpoints should be rate-limited on a server and room basis + Server(OwnedServerName), + Appservice { id: String, rate_limited: bool }, + Ip(IpAddr), +} + +impl Target { + pub fn from_client_request( + registration_info: Option, + sender_user: &UserId, + ) -> Self { + if let Some(info) = registration_info { + // `rate_limited` only effects "masqueraded users", "The sender [user?] is excluded" + return Target::Appservice { + id: info.registration.id, + rate_limited: info.registration.rate_limited.unwrap_or(true) + && !(sender_user.server_name() == services().globals.server_name() + && info.registration.sender_localpart == sender_user.localpart()), + }; + } + + Target::User(sender_user.to_owned()) + } + + pub fn from_client_request_optional_auth( + registration_info: Option, + sender_user: &Option, + ip_addr: Option, + ) -> Option { + if let Some(sender_user) = sender_user.as_ref() { + Some(Self::from_client_request(registration_info, sender_user)) + } else { + ip_addr.map(Self::Ip) + } + } + + fn rate_limited(&self) -> bool { + if let Target::Appservice { + rate_limited: false, + .. + } = self + { + return false; + } + + true + } +} + +#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Restriction { + Client(ClientRestriction), + Federation(FederationRestriction), + Media(MediaRestriction), + + #[default] + CatchAll, +} + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[serde(rename_all = "snake_case")] +pub enum ClientRestriction { + Registration, + Login, + RegistrationTokenValidity, + + SendEvent, + + Join, + Invite, + Knock, + + SendReport, + CreateAlias, +} + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[serde(rename_all = "snake_case")] +pub enum FederationRestriction { + Join, + Knock, + Invite, + + // Transactions should be handled by a completely dedicated rate-limiter + Transaction, +} + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum MediaRestriction { + Create, + Fetch, +} + +impl From for Restriction { + fn from(value: Metadata) -> Self { + use ruma::api::{ + client::{ + account::{check_registration_token_validity, register}, + alias::create_alias, + authenticated_media::{ + get_content, get_content_as_filename, get_content_thumbnail, get_media_preview, + }, + knock::knock_room, + media::{self, create_content, create_mxc_uri}, + membership::{invite_user, join_room_by_id, join_room_by_id_or_alias}, + message::send_message_event, + reporting::report_user, + room::{report_content, report_room}, + session::login, + state::send_state_event, + }, + federation::{ + authenticated_media::{ + get_content as federation_get_content, + get_content_thumbnail as federation_get_content_thumbnail, + }, + membership::{create_invite, create_join_event}, + }, + IncomingRequest, + }; + use Restriction::*; + + match value { + register::v3::Request::METADATA => Client(ClientRestriction::Registration), + check_registration_token_validity::v1::Request::METADATA => { + Client(ClientRestriction::RegistrationTokenValidity) + } + login::v3::Request::METADATA => Client(ClientRestriction::Login), + send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => { + Client(ClientRestriction::SendEvent) + } + join_room_by_id::v3::Request::METADATA + | join_room_by_id_or_alias::v3::Request::METADATA => Client(ClientRestriction::Join), + invite_user::v3::Request::METADATA => Client(ClientRestriction::Invite), + knock_room::v3::Request::METADATA => Client(ClientRestriction::Knock), + report_user::v3::Request::METADATA + | report_content::v3::Request::METADATA + | report_room::v3::Request::METADATA => Client(ClientRestriction::SendReport), + create_alias::v3::Request::METADATA => Client(ClientRestriction::CreateAlias), + // NOTE: handle async media upload in a way that doesn't half the number of uploads you can do within a short timeframe, while not allowing pre-generation of MXC uris to allow uploading double the number of media at once + create_content::v3::Request::METADATA | create_mxc_uri::v1::Request::METADATA => { + Media(MediaRestriction::Create) + } + // Unauthenticate media is deprecated + #[allow(deprecated)] + media::get_content::v3::Request::METADATA + | media::get_content_as_filename::v3::Request::METADATA + | media::get_content_thumbnail::v3::Request::METADATA + | media::get_media_preview::v3::Request::METADATA + | get_content::v1::Request::METADATA + | get_content_as_filename::v1::Request::METADATA + | get_content_thumbnail::v1::Request::METADATA + | get_media_preview::v1::Request::METADATA + | federation_get_content::v1::Request::METADATA + | federation_get_content_thumbnail::v1::Request::METADATA => { + Media(MediaRestriction::Fetch) + } + // v1 is deprecated + #[allow(deprecated)] + create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => { + Federation(FederationRestriction::Join) + } + create_knock_event::v1::Request::METADATA => Federation(FederationRestriction::Knock), + create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => { + Federation(FederationRestriction::Invite) + } + + _ => Self::default(), + } + } +} + +pub struct Service { + buckets: Mutex>>>, + global_bucket: Mutex>>>, + media_upload: Mutex>>>, + media_fetch: Mutex>>>, +} + +impl Service { + pub fn build() -> Arc { + Arc::new(Self { + buckets: Mutex::new(HashMap::new()), + global_bucket: Mutex::new(HashMap::new()), + media_upload: Mutex::new(HashMap::new()), + media_fetch: Mutex::new(HashMap::new()), + }) + } + + //TODO: use checked and saturating arithmetic + + /// Takes the target and request, and either accepts the request while adding to the + /// bucket, or rejects the request, returning the duration that should be waited until + /// the request should be retried. + pub async fn check(&self, target: Option, request: Metadata) -> Result<()> { + let restriction: Restriction = request.into(); + let arrival = Instant::now(); + + { + let map = self.global_bucket.lock().await; + + if let Some(value) = map.get(&restriction) { + let value = value.lock().await; + + if arrival.checked_duration_since(*value).is_none() { + instant_to_err(&value)?; + } + } + } + + if let Some(target) = target { + if restriction == Restriction::Media(MediaRestriction::Fetch) { + self.check_media_pre_fetch(&target, arrival).await? + } + + let config = services().globals.config.rate_limiting.get(&restriction); + + let mut map = self.buckets.lock().await; + let entry = map.entry((target, restriction)); + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + let mut entry = entry.lock().await; + + if arrival.checked_duration_since(*entry).is_none() { + return instant_to_err(&entry); + } + + let min_instant = arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * config.burst_capacity.get(), + ); + *entry = + entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap()); + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(Arc::new(Mutex::new( + arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * (config.burst_capacity.get() - 1), + ), + ))); + } + } + } + + { + let config = services() + .globals + .config + .rate_limiting + .global + .get(&restriction); + + let mut map = self.global_bucket.lock().await; + + let entry = map.entry(restriction); + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + let mut entry = entry.lock().await; + + if arrival.checked_duration_since(*entry).is_none() { + return instant_to_err(&entry); + } + + let min_instant = arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * config.burst_capacity.get(), + ); + *entry = + entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap()); + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(Arc::new(Mutex::new( + arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * (config.burst_capacity.get() - 1), + ), + ))); + } + } + } + + Ok(()) + } + + pub async fn check_media_upload(&self, target: Target, size: u64) -> Result<()> { + if !target.rate_limited() { + return Ok(()); + } + + let arrival = Instant::now(); + + let MediaLimitation { + timeframe, + burst_capacity, + } = services().globals.config.rate_limiting.media.upload; + + let mut map = self.media_upload.lock().await; + let entry = map.entry(target); + + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + let mut entry = entry.lock().await; + + //TODO: use more precise conversion than secs + let min_instant = arrival + - Duration::from_secs(burst_capacity.as_u64() / timeframe.bytes_per_sec()); + let proposed_entry = + entry.max(min_instant) + Duration::from_secs(size / timeframe.bytes_per_sec()); + + if arrival.checked_duration_since(proposed_entry).is_none() { + return instant_to_err(&proposed_entry); + } + + *entry = proposed_entry; + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(Arc::new(Mutex::new( + arrival + - Duration::from_nanos(burst_capacity.as_u64() / timeframe.bytes_per_sec()), + ))); + } + } + + Ok(()) + } + + async fn check_media_pre_fetch(&self, target: &Target, arrival: Instant) -> Result<()> { + let map = self.media_fetch.lock().await; + if let Some(mutex) = map.get(target) { + let mutex = mutex.lock().await; + + if arrival.checked_duration_since(*mutex).is_none() { + return instant_to_err(&mutex); + } + } + + Ok(()) + } + + pub async fn update_media_post_fetch(&self, target: Target, size: u64) { + if !target.rate_limited() { + return; + } + + let MediaLimitation { + timeframe, + burst_capacity, + } = services().globals.config.rate_limiting.media.fetch; + + let arrival = Instant::now(); + + let mut map = self.media_fetch.lock().await; + let entry = map.entry(target); + + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + let mut entry = entry.lock().await; + + //TODO: use more precise conversion than secs + let min_instant = arrival + - Duration::from_secs(burst_capacity.as_u64() / timeframe.bytes_per_sec()); + let proposed_entry = + entry.max(min_instant) + Duration::from_secs(size / timeframe.bytes_per_sec()); + + *entry = proposed_entry; + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(Arc::new(Mutex::new( + arrival + - Duration::from_nanos(burst_capacity.as_u64() / timeframe.bytes_per_sec()), + ))); + } + } + } +} + +fn instant_to_err(instant: &Instant) -> Result<()> { + let now = Instant::now(); + + Err(Error::BadRequest( + ErrorKind::LimitExceeded { + // Not using ::DateTime because conversion from Instant to SystemTime is convoluted + retry_after: Some(RetryAfter::Delay(instant.duration_since(now))), + }, + "Rate limit exceeded", + )) +}