From 024f910bf9e4ca56339a53bc464bf15e35ae512c Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 26 Jun 2024 13:36:04 +0100 Subject: [PATCH] allow for different timeframes for configuration --- src/api/ruma_wrapper/axum.rs | 11 ++++++++--- src/config/mod.rs | 26 +++++++++++++++++++++++--- src/service/rate_limiting/mod.rs | 2 +- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index c7bc5879..8ab2a468 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -20,7 +20,10 @@ use axum_extra::{ use bytes::{BufMut, BytesMut}; use http::{Request, StatusCode}; use ruma::{ - api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, + api::{ + client::error::{ErrorKind, RetryAfter}, + AuthScheme, IncomingRequest, OutgoingResponse, + }, server_util::authorization::XMatrix, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId, }; @@ -357,14 +360,16 @@ where } }; - if let Err(retry_after_ms) = { + if let Err(retry_after) = { services() .rate_limiting .update_or_reject(target, metadata) .map_err(Some) } { return Err(Error::BadRequest( - ErrorKind::LimitExceeded { retry_after_ms }, + ErrorKind::LimitExceeded { + retry_after: retry_after.map(|dur| RetryAfter::Delay(dur)), + }, "Rate limit exceeded.", )); } diff --git a/src/config/mod.rs b/src/config/mod.rs index 81051252..efa02370 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -115,10 +115,30 @@ pub enum Restriction { CatchAll, } +#[derive(Deserialize, Clone, Copy, Debug)] +#[serde(rename_all = "snake_case")] +pub enum Timeframe { + PerSecond(NonZeroU64), + PerMinute(NonZeroU64), + PerHour(NonZeroU64), + PerDay(NonZeroU64), +} + +impl Timeframe { + pub fn nano_gap(&self) -> u64 { + match self { + Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(), + Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(), + Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(), + Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(), + } + } +} + #[derive(Clone, Copy, Debug, Deserialize)] pub struct Limitation { - #[serde(default = "default_non_zero")] - pub per_minute: NonZeroU64, + #[serde(default = "default_non_zero", flatten)] + pub timeframe: Timeframe, #[serde(default = "default_non_zero")] pub burst_capacity: NonZeroU64, #[serde(default = "default_non_zero")] @@ -340,7 +360,7 @@ pub fn default_rate_limit() -> BTreeMap { BTreeMap::from_iter([( Restriction::default(), Limitation { - per_minute: default_non_zero(), + timeframe: Timeframe::PerMinute(default_non_zero()), burst_capacity: default_non_zero(), weight: default_non_zero(), }, diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index 1f5ef0a1..bfe7746a 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -78,7 +78,7 @@ impl Service { tracing::info!(?limit); - let increment = 1_000_000_000u64 / limit.per_minute.get() * limit.weight.get(); + let increment = 1_000_000_000u64 / limit.timeframe.nano_gap() * limit.weight.get(); tracing::info!(?increment); let mut prev_expectation = self