From ab21c5dbef9d82cd366fb3c4d9ea0e3079dcac44 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 10 Jul 2024 09:44:44 +0100 Subject: [PATCH] simplify --- Cargo.lock | 25 ----------- Cargo.toml | 3 +- src/api/ruma_wrapper/axum.rs | 23 +--------- src/config/mod.rs | 5 +-- src/service/rate_limiting/mod.rs | 73 ++++++++++---------------------- 5 files changed, 28 insertions(+), 101 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e587f668..e394ce76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -517,7 +517,6 @@ dependencies = [ "opentelemetry_sdk", "parking_lot", "persy", - "quanta", "rand", "regex", "reqwest", @@ -2069,21 +2068,6 @@ dependencies = [ "syn", ] -[[package]] -name = "quanta" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" -dependencies = [ - "crossbeam-utils", - "libc", - "once_cell", - "raw-cpuid", - "wasi", - "web-sys", - "winapi", -] - [[package]] name = "quick-error" version = "1.2.3" @@ -2129,15 +2113,6 @@ dependencies = [ "getrandom", ] -[[package]] -name = "raw-cpuid" -version = "11.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd" -dependencies = [ - "bitflags 2.5.0", -] - [[package]] name = "redox_syscall" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 3d6d2594..d24914ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -146,9 +146,8 @@ tikv-jemallocator = { version = "0.5.0", features = [ "unprefixed_malloc_on_supported_platforms", ], optional = true } -sd-notify = { version = "0.4.1", optional = true } dashmap = "5.5.3" -quanta = "0.12.3" +sd-notify = { version = "0.4.1", optional = true } # Used for matrix spec type definitions and helpers [dependencies.ruma] diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 8ab2a468..18d1baad 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -323,25 +323,6 @@ where let target = if let Some(server) = sender_servername.clone() { Target::Server(server) - - // Token::User((user_id, _)) => Some(Target::User(user_id.clone())), - // Token::None => { - // let header = parts - // .headers - // .get("x-forwarded-for") - // .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; - - // let s = header - // .to_str() - // .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?; - // Some( - // s.parse() - // .map(Target::Ip) - // .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")), - // ) - // .transpose()? - // } - // _ => None, } else if let Some(appservice) = appservice_info.clone() { Target::Appservice(appservice.registration.id) } else if let Some(user) = sender_user.clone() { @@ -350,7 +331,7 @@ where let ip = headers .get("X-Forwarded-For") .and_then(|header| header.to_str().ok()) - .map(|header| header.split_once(",").map(|(ip, _)| ip).unwrap_or(header)) + .map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header)) .and_then(|ip| IpAddr::from_str(ip).ok()); if let Some(ip) = ip { @@ -368,7 +349,7 @@ where } { return Err(Error::BadRequest( ErrorKind::LimitExceeded { - retry_after: retry_after.map(|dur| RetryAfter::Delay(dur)), + retry_after: retry_after.map(RetryAfter::Delay), }, "Rate limit exceeded.", )); diff --git a/src/config/mod.rs b/src/config/mod.rs index 70f9fca9..2d08d855 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -117,6 +117,8 @@ pub enum Restriction { #[derive(Deserialize, Clone, Copy, Debug)] #[serde(rename_all = "snake_case")] +// When deserializing, we want this prefix +#[allow(clippy::enum_variant_names)] pub enum Timeframe { PerSecond(NonZeroU64), PerMinute(NonZeroU64), @@ -141,8 +143,6 @@ pub struct Limitation { pub timeframe: Timeframe, #[serde(default = "default_non_zero")] pub burst_capacity: NonZeroU64, - #[serde(default = "default_non_zero")] - pub weight: NonZeroU64, } const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; @@ -362,7 +362,6 @@ pub fn default_rate_limit() -> BTreeMap { Limitation { timeframe: Timeframe::PerMinute(NonZeroU64::MIN), burst_capacity: NonZeroU64::MIN, - weight: NonZeroU64::MIN, }, )]) } diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs index 387a17fa..fb72fec9 100644 --- a/src/service/rate_limiting/mod.rs +++ b/src/service/rate_limiting/mod.rs @@ -1,12 +1,10 @@ use std::{ hash::Hash, net::IpAddr, - sync::atomic::{AtomicU64, Ordering}, - time::Duration, + time::{Duration, Instant}, }; -use dashmap::DashMap; -use quanta::Clock; +use dashmap::{mapref::entry::Entry, DashMap}; use ruma::{ api::{ client::{account::register, session::login}, @@ -37,8 +35,7 @@ impl From for Restriction { } pub struct Service { - store: DashMap<(Target, Restriction), AtomicU64>, - clock: Clock, + store: DashMap<(Target, Restriction), (Instant, u64)>, } #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] @@ -54,14 +51,13 @@ impl Service { pub fn build() -> Self { Self { store: DashMap::new(), - clock: Clock::new(), } } pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> { let restriction = metadata.into(); - let arrival = self.clock.delta_as_nanos(0, self.clock.raw()); + let arrival = Instant::now(); let config = &services().globals.config.rate_limiting; @@ -74,55 +70,32 @@ impl Service { // weight: NonZeroU64::new(1).unwrap(), // }); + let gap = Duration::from_nanos(limit.timeframe.nano_gap()); let key = (target, restriction); - tracing::info!(?limit); + match self.store.entry(key) { + Entry::Occupied(mut entry) => { + let (instant, capacity) = entry.get_mut(); - let increment = limit.timeframe.nano_gap() * limit.weight.get(); - tracing::info!(?increment); + while *instant < arrival && *capacity != 0 { + *capacity -= 1; + *instant += gap; + } - let mut prev_expectation = self - .store - .get(&key) - .as_deref() - .map(|n| n.load(Ordering::Acquire)) - .unwrap_or_else(|| arrival + increment); - let weight = increment * limit.burst_capacity.get(); - - tracing::info!(?prev_expectation); - tracing::info!(?weight); - - let f = |prev_expectation: u64| { - let allowed = prev_expectation.saturating_sub(weight); - - if arrival < allowed { - Err(Duration::from_nanos(allowed - arrival)) - } else { - Ok(prev_expectation.max(arrival) + increment) + if *capacity >= limit.burst_capacity.get() { + return Err(gap); + } else { + *capacity += 1; + // TODO: update thing + *instant += gap; + } + } + Entry::Vacant(entry) => { + entry.insert((Instant::now() + gap, 1)); } }; - let mut decision = f(prev_expectation); - - tracing::info!(?decision); - - let entry = self.store.entry(key); - - while let Ok(next_expectation) = decision { - match entry.or_default().compare_exchange_weak( - prev_expectation, - next_expectation, - Ordering::Release, - Ordering::Relaxed, - ) { - Ok(_) => return Ok(()), - Err(actual) => prev_expectation = actual, - } - - decision = f(prev_expectation); - } - - decision.map(|_| ()) + Ok(()) } }