1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00
This commit is contained in:
Matthias Ahouansou 2024-07-10 09:44:44 +01:00
parent bf902f1607
commit ab21c5dbef
No known key found for this signature in database
5 changed files with 28 additions and 101 deletions

25
Cargo.lock generated
View file

@ -517,7 +517,6 @@ dependencies = [
"opentelemetry_sdk", "opentelemetry_sdk",
"parking_lot", "parking_lot",
"persy", "persy",
"quanta",
"rand", "rand",
"regex", "regex",
"reqwest", "reqwest",
@ -2069,21 +2068,6 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "quanta"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi",
"web-sys",
"winapi",
]
[[package]] [[package]]
name = "quick-error" name = "quick-error"
version = "1.2.3" version = "1.2.3"
@ -2129,15 +2113,6 @@ dependencies = [
"getrandom", "getrandom",
] ]
[[package]]
name = "raw-cpuid"
version = "11.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd"
dependencies = [
"bitflags 2.5.0",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.1" version = "0.5.1"

View file

@ -146,9 +146,8 @@ tikv-jemallocator = { version = "0.5.0", features = [
"unprefixed_malloc_on_supported_platforms", "unprefixed_malloc_on_supported_platforms",
], optional = true } ], optional = true }
sd-notify = { version = "0.4.1", optional = true }
dashmap = "5.5.3" dashmap = "5.5.3"
quanta = "0.12.3" sd-notify = { version = "0.4.1", optional = true }
# Used for matrix spec type definitions and helpers # Used for matrix spec type definitions and helpers
[dependencies.ruma] [dependencies.ruma]

View file

@ -323,25 +323,6 @@ where
let target = if let Some(server) = sender_servername.clone() { let target = if let Some(server) = sender_servername.clone() {
Target::Server(server) Target::Server(server)
// Token::User((user_id, _)) => Some(Target::User(user_id.clone())),
// Token::None => {
// let header = parts
// .headers
// .get("x-forwarded-for")
// .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
// let s = header
// .to_str()
// .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
// Some(
// s.parse()
// .map(Target::Ip)
// .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")),
// )
// .transpose()?
// }
// _ => None,
} else if let Some(appservice) = appservice_info.clone() { } else if let Some(appservice) = appservice_info.clone() {
Target::Appservice(appservice.registration.id) Target::Appservice(appservice.registration.id)
} else if let Some(user) = sender_user.clone() { } else if let Some(user) = sender_user.clone() {
@ -350,7 +331,7 @@ where
let ip = headers let ip = headers
.get("X-Forwarded-For") .get("X-Forwarded-For")
.and_then(|header| header.to_str().ok()) .and_then(|header| header.to_str().ok())
.map(|header| header.split_once(",").map(|(ip, _)| ip).unwrap_or(header)) .map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header))
.and_then(|ip| IpAddr::from_str(ip).ok()); .and_then(|ip| IpAddr::from_str(ip).ok());
if let Some(ip) = ip { if let Some(ip) = ip {
@ -368,7 +349,7 @@ where
} { } {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::LimitExceeded { ErrorKind::LimitExceeded {
retry_after: retry_after.map(|dur| RetryAfter::Delay(dur)), retry_after: retry_after.map(RetryAfter::Delay),
}, },
"Rate limit exceeded.", "Rate limit exceeded.",
)); ));

View file

@ -117,6 +117,8 @@ pub enum Restriction {
#[derive(Deserialize, Clone, Copy, Debug)] #[derive(Deserialize, Clone, Copy, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
// When deserializing, we want this prefix
#[allow(clippy::enum_variant_names)]
pub enum Timeframe { pub enum Timeframe {
PerSecond(NonZeroU64), PerSecond(NonZeroU64),
PerMinute(NonZeroU64), PerMinute(NonZeroU64),
@ -141,8 +143,6 @@ pub struct Limitation {
pub timeframe: Timeframe, pub timeframe: Timeframe,
#[serde(default = "default_non_zero")] #[serde(default = "default_non_zero")]
pub burst_capacity: NonZeroU64, pub burst_capacity: NonZeroU64,
#[serde(default = "default_non_zero")]
pub weight: NonZeroU64,
} }
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
@ -362,7 +362,6 @@ pub fn default_rate_limit() -> BTreeMap<Restriction, Limitation> {
Limitation { Limitation {
timeframe: Timeframe::PerMinute(NonZeroU64::MIN), timeframe: Timeframe::PerMinute(NonZeroU64::MIN),
burst_capacity: NonZeroU64::MIN, burst_capacity: NonZeroU64::MIN,
weight: NonZeroU64::MIN,
}, },
)]) )])
} }

View file

@ -1,12 +1,10 @@
use std::{ use std::{
hash::Hash, hash::Hash,
net::IpAddr, net::IpAddr,
sync::atomic::{AtomicU64, Ordering}, time::{Duration, Instant},
time::Duration,
}; };
use dashmap::DashMap; use dashmap::{mapref::entry::Entry, DashMap};
use quanta::Clock;
use ruma::{ use ruma::{
api::{ api::{
client::{account::register, session::login}, client::{account::register, session::login},
@ -37,8 +35,7 @@ impl From<Metadata> for Restriction {
} }
pub struct Service { pub struct Service {
store: DashMap<(Target, Restriction), AtomicU64>, store: DashMap<(Target, Restriction), (Instant, u64)>,
clock: Clock,
} }
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
@ -54,14 +51,13 @@ impl Service {
pub fn build() -> Self { pub fn build() -> Self {
Self { Self {
store: DashMap::new(), store: DashMap::new(),
clock: Clock::new(),
} }
} }
pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> { pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> {
let restriction = metadata.into(); let restriction = metadata.into();
let arrival = self.clock.delta_as_nanos(0, self.clock.raw()); let arrival = Instant::now();
let config = &services().globals.config.rate_limiting; let config = &services().globals.config.rate_limiting;
@ -74,55 +70,32 @@ impl Service {
// weight: NonZeroU64::new(1).unwrap(), // weight: NonZeroU64::new(1).unwrap(),
// }); // });
let gap = Duration::from_nanos(limit.timeframe.nano_gap());
let key = (target, restriction); let key = (target, restriction);
tracing::info!(?limit); match self.store.entry(key) {
Entry::Occupied(mut entry) => {
let (instant, capacity) = entry.get_mut();
let increment = limit.timeframe.nano_gap() * limit.weight.get(); while *instant < arrival && *capacity != 0 {
tracing::info!(?increment); *capacity -= 1;
*instant += gap;
}
let mut prev_expectation = self if *capacity >= limit.burst_capacity.get() {
.store return Err(gap);
.get(&key) } else {
.as_deref() *capacity += 1;
.map(|n| n.load(Ordering::Acquire)) // TODO: update thing
.unwrap_or_else(|| arrival + increment); *instant += gap;
let weight = increment * limit.burst_capacity.get(); }
}
tracing::info!(?prev_expectation); Entry::Vacant(entry) => {
tracing::info!(?weight); entry.insert((Instant::now() + gap, 1));
let f = |prev_expectation: u64| {
let allowed = prev_expectation.saturating_sub(weight);
if arrival < allowed {
Err(Duration::from_nanos(allowed - arrival))
} else {
Ok(prev_expectation.max(arrival) + increment)
} }
}; };
let mut decision = f(prev_expectation); Ok(())
tracing::info!(?decision);
let entry = self.store.entry(key);
while let Ok(next_expectation) = decision {
match entry.or_default().compare_exchange_weak(
prev_expectation,
next_expectation,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => return Ok(()),
Err(actual) => prev_expectation = actual,
}
decision = f(prev_expectation);
}
decision.map(|_| ())
} }
} }