1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00
This commit is contained in:
Matthias Ahouansou 2024-07-10 09:44:44 +01:00
parent bf902f1607
commit ab21c5dbef
No known key found for this signature in database
5 changed files with 28 additions and 101 deletions

25
Cargo.lock generated
View file

@ -517,7 +517,6 @@ dependencies = [
"opentelemetry_sdk",
"parking_lot",
"persy",
"quanta",
"rand",
"regex",
"reqwest",
@ -2069,21 +2068,6 @@ dependencies = [
"syn",
]
[[package]]
name = "quanta"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi",
"web-sys",
"winapi",
]
[[package]]
name = "quick-error"
version = "1.2.3"
@ -2129,15 +2113,6 @@ dependencies = [
"getrandom",
]
[[package]]
name = "raw-cpuid"
version = "11.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd"
dependencies = [
"bitflags 2.5.0",
]
[[package]]
name = "redox_syscall"
version = "0.5.1"

View file

@ -146,9 +146,8 @@ tikv-jemallocator = { version = "0.5.0", features = [
"unprefixed_malloc_on_supported_platforms",
], optional = true }
sd-notify = { version = "0.4.1", optional = true }
dashmap = "5.5.3"
quanta = "0.12.3"
sd-notify = { version = "0.4.1", optional = true }
# Used for matrix spec type definitions and helpers
[dependencies.ruma]

View file

@ -323,25 +323,6 @@ where
let target = if let Some(server) = sender_servername.clone() {
Target::Server(server)
// Token::User((user_id, _)) => Some(Target::User(user_id.clone())),
// Token::None => {
// let header = parts
// .headers
// .get("x-forwarded-for")
// .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
// let s = header
// .to_str()
// .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
// Some(
// s.parse()
// .map(Target::Ip)
// .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")),
// )
// .transpose()?
// }
// _ => None,
} else if let Some(appservice) = appservice_info.clone() {
Target::Appservice(appservice.registration.id)
} else if let Some(user) = sender_user.clone() {
@ -350,7 +331,7 @@ where
let ip = headers
.get("X-Forwarded-For")
.and_then(|header| header.to_str().ok())
.map(|header| header.split_once(",").map(|(ip, _)| ip).unwrap_or(header))
.map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header))
.and_then(|ip| IpAddr::from_str(ip).ok());
if let Some(ip) = ip {
@ -368,7 +349,7 @@ where
} {
return Err(Error::BadRequest(
ErrorKind::LimitExceeded {
retry_after: retry_after.map(|dur| RetryAfter::Delay(dur)),
retry_after: retry_after.map(RetryAfter::Delay),
},
"Rate limit exceeded.",
));

View file

@ -117,6 +117,8 @@ pub enum Restriction {
#[derive(Deserialize, Clone, Copy, Debug)]
#[serde(rename_all = "snake_case")]
// When deserializing, we want this prefix
#[allow(clippy::enum_variant_names)]
pub enum Timeframe {
PerSecond(NonZeroU64),
PerMinute(NonZeroU64),
@ -141,8 +143,6 @@ pub struct Limitation {
pub timeframe: Timeframe,
#[serde(default = "default_non_zero")]
pub burst_capacity: NonZeroU64,
#[serde(default = "default_non_zero")]
pub weight: NonZeroU64,
}
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
@ -362,7 +362,6 @@ pub fn default_rate_limit() -> BTreeMap<Restriction, Limitation> {
Limitation {
timeframe: Timeframe::PerMinute(NonZeroU64::MIN),
burst_capacity: NonZeroU64::MIN,
weight: NonZeroU64::MIN,
},
)])
}

View file

@ -1,12 +1,10 @@
use std::{
hash::Hash,
net::IpAddr,
sync::atomic::{AtomicU64, Ordering},
time::Duration,
time::{Duration, Instant},
};
use dashmap::DashMap;
use quanta::Clock;
use dashmap::{mapref::entry::Entry, DashMap};
use ruma::{
api::{
client::{account::register, session::login},
@ -37,8 +35,7 @@ impl From<Metadata> for Restriction {
}
pub struct Service {
store: DashMap<(Target, Restriction), AtomicU64>,
clock: Clock,
store: DashMap<(Target, Restriction), (Instant, u64)>,
}
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
@ -54,14 +51,13 @@ impl Service {
pub fn build() -> Self {
Self {
store: DashMap::new(),
clock: Clock::new(),
}
}
pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> {
let restriction = metadata.into();
let arrival = self.clock.delta_as_nanos(0, self.clock.raw());
let arrival = Instant::now();
let config = &services().globals.config.rate_limiting;
@ -74,55 +70,32 @@ impl Service {
// weight: NonZeroU64::new(1).unwrap(),
// });
let gap = Duration::from_nanos(limit.timeframe.nano_gap());
let key = (target, restriction);
tracing::info!(?limit);
match self.store.entry(key) {
Entry::Occupied(mut entry) => {
let (instant, capacity) = entry.get_mut();
let increment = limit.timeframe.nano_gap() * limit.weight.get();
tracing::info!(?increment);
while *instant < arrival && *capacity != 0 {
*capacity -= 1;
*instant += gap;
}
let mut prev_expectation = self
.store
.get(&key)
.as_deref()
.map(|n| n.load(Ordering::Acquire))
.unwrap_or_else(|| arrival + increment);
let weight = increment * limit.burst_capacity.get();
tracing::info!(?prev_expectation);
tracing::info!(?weight);
let f = |prev_expectation: u64| {
let allowed = prev_expectation.saturating_sub(weight);
if arrival < allowed {
Err(Duration::from_nanos(allowed - arrival))
} else {
Ok(prev_expectation.max(arrival) + increment)
if *capacity >= limit.burst_capacity.get() {
return Err(gap);
} else {
*capacity += 1;
// TODO: update thing
*instant += gap;
}
}
Entry::Vacant(entry) => {
entry.insert((Instant::now() + gap, 1));
}
};
let mut decision = f(prev_expectation);
tracing::info!(?decision);
let entry = self.store.entry(key);
while let Ok(next_expectation) = decision {
match entry.or_default().compare_exchange_weak(
prev_expectation,
next_expectation,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => return Ok(()),
Err(actual) => prev_expectation = actual,
}
decision = f(prev_expectation);
}
decision.map(|_| ())
Ok(())
}
}