mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-06-27 16:35:59 +00:00
simplify
This commit is contained in:
parent
bf902f1607
commit
ab21c5dbef
5 changed files with 28 additions and 101 deletions
25
Cargo.lock
generated
25
Cargo.lock
generated
|
@ -517,7 +517,6 @@ dependencies = [
|
|||
"opentelemetry_sdk",
|
||||
"parking_lot",
|
||||
"persy",
|
||||
"quanta",
|
||||
"rand",
|
||||
"regex",
|
||||
"reqwest",
|
||||
|
@ -2069,21 +2068,6 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quanta"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"raw-cpuid",
|
||||
"wasi",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "1.2.3"
|
||||
|
@ -2129,15 +2113,6 @@ dependencies = [
|
|||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "11.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd"
|
||||
dependencies = [
|
||||
"bitflags 2.5.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.1"
|
||||
|
|
|
@ -146,9 +146,8 @@ tikv-jemallocator = { version = "0.5.0", features = [
|
|||
"unprefixed_malloc_on_supported_platforms",
|
||||
], optional = true }
|
||||
|
||||
sd-notify = { version = "0.4.1", optional = true }
|
||||
dashmap = "5.5.3"
|
||||
quanta = "0.12.3"
|
||||
sd-notify = { version = "0.4.1", optional = true }
|
||||
|
||||
# Used for matrix spec type definitions and helpers
|
||||
[dependencies.ruma]
|
||||
|
|
|
@ -323,25 +323,6 @@ where
|
|||
|
||||
let target = if let Some(server) = sender_servername.clone() {
|
||||
Target::Server(server)
|
||||
|
||||
// Token::User((user_id, _)) => Some(Target::User(user_id.clone())),
|
||||
// Token::None => {
|
||||
// let header = parts
|
||||
// .headers
|
||||
// .get("x-forwarded-for")
|
||||
// .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
|
||||
|
||||
// let s = header
|
||||
// .to_str()
|
||||
// .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
|
||||
// Some(
|
||||
// s.parse()
|
||||
// .map(Target::Ip)
|
||||
// .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")),
|
||||
// )
|
||||
// .transpose()?
|
||||
// }
|
||||
// _ => None,
|
||||
} else if let Some(appservice) = appservice_info.clone() {
|
||||
Target::Appservice(appservice.registration.id)
|
||||
} else if let Some(user) = sender_user.clone() {
|
||||
|
@ -350,7 +331,7 @@ where
|
|||
let ip = headers
|
||||
.get("X-Forwarded-For")
|
||||
.and_then(|header| header.to_str().ok())
|
||||
.map(|header| header.split_once(",").map(|(ip, _)| ip).unwrap_or(header))
|
||||
.map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header))
|
||||
.and_then(|ip| IpAddr::from_str(ip).ok());
|
||||
|
||||
if let Some(ip) = ip {
|
||||
|
@ -368,7 +349,7 @@ where
|
|||
} {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::LimitExceeded {
|
||||
retry_after: retry_after.map(|dur| RetryAfter::Delay(dur)),
|
||||
retry_after: retry_after.map(RetryAfter::Delay),
|
||||
},
|
||||
"Rate limit exceeded.",
|
||||
));
|
||||
|
|
|
@ -117,6 +117,8 @@ pub enum Restriction {
|
|||
|
||||
#[derive(Deserialize, Clone, Copy, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
// When deserializing, we want this prefix
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
pub enum Timeframe {
|
||||
PerSecond(NonZeroU64),
|
||||
PerMinute(NonZeroU64),
|
||||
|
@ -141,8 +143,6 @@ pub struct Limitation {
|
|||
pub timeframe: Timeframe,
|
||||
#[serde(default = "default_non_zero")]
|
||||
pub burst_capacity: NonZeroU64,
|
||||
#[serde(default = "default_non_zero")]
|
||||
pub weight: NonZeroU64,
|
||||
}
|
||||
|
||||
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
|
||||
|
@ -362,7 +362,6 @@ pub fn default_rate_limit() -> BTreeMap<Restriction, Limitation> {
|
|||
Limitation {
|
||||
timeframe: Timeframe::PerMinute(NonZeroU64::MIN),
|
||||
burst_capacity: NonZeroU64::MIN,
|
||||
weight: NonZeroU64::MIN,
|
||||
},
|
||||
)])
|
||||
}
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
use std::{
|
||||
hash::Hash,
|
||||
net::IpAddr,
|
||||
sync::atomic::{AtomicU64, Ordering},
|
||||
time::Duration,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use quanta::Clock;
|
||||
use dashmap::{mapref::entry::Entry, DashMap};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::{account::register, session::login},
|
||||
|
@ -37,8 +35,7 @@ impl From<Metadata> for Restriction {
|
|||
}
|
||||
|
||||
pub struct Service {
|
||||
store: DashMap<(Target, Restriction), AtomicU64>,
|
||||
clock: Clock,
|
||||
store: DashMap<(Target, Restriction), (Instant, u64)>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
|
@ -54,14 +51,13 @@ impl Service {
|
|||
pub fn build() -> Self {
|
||||
Self {
|
||||
store: DashMap::new(),
|
||||
clock: Clock::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> {
|
||||
let restriction = metadata.into();
|
||||
|
||||
let arrival = self.clock.delta_as_nanos(0, self.clock.raw());
|
||||
let arrival = Instant::now();
|
||||
|
||||
let config = &services().globals.config.rate_limiting;
|
||||
|
||||
|
@ -74,55 +70,32 @@ impl Service {
|
|||
// weight: NonZeroU64::new(1).unwrap(),
|
||||
// });
|
||||
|
||||
let gap = Duration::from_nanos(limit.timeframe.nano_gap());
|
||||
let key = (target, restriction);
|
||||
|
||||
tracing::info!(?limit);
|
||||
match self.store.entry(key) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
let (instant, capacity) = entry.get_mut();
|
||||
|
||||
let increment = limit.timeframe.nano_gap() * limit.weight.get();
|
||||
tracing::info!(?increment);
|
||||
while *instant < arrival && *capacity != 0 {
|
||||
*capacity -= 1;
|
||||
*instant += gap;
|
||||
}
|
||||
|
||||
let mut prev_expectation = self
|
||||
.store
|
||||
.get(&key)
|
||||
.as_deref()
|
||||
.map(|n| n.load(Ordering::Acquire))
|
||||
.unwrap_or_else(|| arrival + increment);
|
||||
let weight = increment * limit.burst_capacity.get();
|
||||
|
||||
tracing::info!(?prev_expectation);
|
||||
tracing::info!(?weight);
|
||||
|
||||
let f = |prev_expectation: u64| {
|
||||
let allowed = prev_expectation.saturating_sub(weight);
|
||||
|
||||
if arrival < allowed {
|
||||
Err(Duration::from_nanos(allowed - arrival))
|
||||
} else {
|
||||
Ok(prev_expectation.max(arrival) + increment)
|
||||
if *capacity >= limit.burst_capacity.get() {
|
||||
return Err(gap);
|
||||
} else {
|
||||
*capacity += 1;
|
||||
// TODO: update thing
|
||||
*instant += gap;
|
||||
}
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert((Instant::now() + gap, 1));
|
||||
}
|
||||
};
|
||||
|
||||
let mut decision = f(prev_expectation);
|
||||
|
||||
tracing::info!(?decision);
|
||||
|
||||
let entry = self.store.entry(key);
|
||||
|
||||
while let Ok(next_expectation) = decision {
|
||||
match entry.or_default().compare_exchange_weak(
|
||||
prev_expectation,
|
||||
next_expectation,
|
||||
Ordering::Release,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => return Ok(()),
|
||||
Err(actual) => prev_expectation = actual,
|
||||
}
|
||||
|
||||
decision = f(prev_expectation);
|
||||
}
|
||||
|
||||
decision.map(|_| ())
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue