mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-08-01 17:38:36 +00:00
131 lines
3.6 KiB
Rust
131 lines
3.6 KiB
Rust
use std::{
|
|
hash::Hash,
|
|
net::IpAddr,
|
|
sync::atomic::{AtomicU64, Ordering},
|
|
time::Duration,
|
|
};
|
|
|
|
use dashmap::DashMap;
|
|
use quanta::Clock;
|
|
use ruma::{
|
|
api::{
|
|
client::{account::register, session::login},
|
|
IncomingRequest, Metadata,
|
|
},
|
|
OwnedServerName, OwnedUserId,
|
|
};
|
|
|
|
use crate::{config::Restriction, services, Result};
|
|
|
|
impl From<Metadata> for Restriction {
|
|
fn from(metadata: Metadata) -> Self {
|
|
[
|
|
(register::v3::Request::METADATA, Restriction::Registration),
|
|
(login::v3::Request::METADATA, Restriction::Login),
|
|
]
|
|
.iter()
|
|
.find(|(other, _)| {
|
|
metadata
|
|
.history
|
|
.stable_paths()
|
|
.zip(other.history.stable_paths())
|
|
.all(|(a, b)| a == b)
|
|
})
|
|
.map(|(_, restriction)| restriction.to_owned())
|
|
.unwrap_or_default()
|
|
}
|
|
}
|
|
|
|
pub struct Service {
|
|
store: DashMap<(Target, Restriction), AtomicU64>,
|
|
clock: Clock,
|
|
}
|
|
|
|
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
|
pub enum Target {
|
|
User(OwnedUserId),
|
|
Server(OwnedServerName),
|
|
Appservice(String),
|
|
Ip(IpAddr),
|
|
None,
|
|
}
|
|
|
|
impl Service {
|
|
pub fn build() -> Self {
|
|
Self {
|
|
store: DashMap::new(),
|
|
clock: Clock::new(),
|
|
}
|
|
}
|
|
|
|
pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> {
|
|
let restriction = metadata.into();
|
|
|
|
let arrival = self.clock.delta_as_nanos(0, self.clock.raw());
|
|
|
|
let config = &services().globals.config.rate_limiting;
|
|
|
|
let Some(limit) = config.get(&restriction).map(ToOwned::to_owned) else {
|
|
return Ok(());
|
|
};
|
|
// .unwrap_or(Limitation {
|
|
// per_minute: NonZeroU64::new(1).unwrap(),
|
|
// burst_capacity: NonZeroU64::new(1).unwrap(),
|
|
// weight: NonZeroU64::new(1).unwrap(),
|
|
// });
|
|
|
|
let key = (target, restriction);
|
|
|
|
tracing::info!(?limit);
|
|
|
|
let increment = 1_000_000_000u64 / limit.per_minute.get() * limit.weight.get();
|
|
tracing::info!(?increment);
|
|
|
|
let mut prev_expectation = self
|
|
.store
|
|
.get(&key)
|
|
.as_deref()
|
|
.map(|n| n.load(Ordering::Acquire))
|
|
.unwrap_or_else(|| arrival + increment);
|
|
let weight = (increment * limit.burst_capacity.get()).max(1);
|
|
|
|
tracing::info!(?prev_expectation);
|
|
tracing::info!(?weight);
|
|
|
|
let f = |prev_expectation: u64| {
|
|
let allowed = prev_expectation.saturating_sub(weight);
|
|
|
|
if arrival < allowed {
|
|
Err(Duration::from_nanos(allowed - arrival))
|
|
} else {
|
|
Ok(prev_expectation.max(arrival) + increment)
|
|
}
|
|
};
|
|
|
|
let mut decision = f(prev_expectation);
|
|
|
|
tracing::info!(?decision);
|
|
|
|
let entry = self.store.entry(key);
|
|
|
|
while let Ok(next_expectation) = decision {
|
|
match entry.or_default().compare_exchange_weak(
|
|
prev_expectation,
|
|
next_expectation,
|
|
Ordering::Release,
|
|
Ordering::Relaxed,
|
|
) {
|
|
Ok(_) => return Ok(()),
|
|
Err(actual) => prev_expectation = actual,
|
|
}
|
|
|
|
decision = f(prev_expectation);
|
|
}
|
|
|
|
decision.map(|_| ())
|
|
}
|
|
}
|
|
|
|
///// In-memory state and utility functions used to check whether the client has exceeded its rate limit.
|
|
///// This leverages the generic cell rate algorithm, making the required checks as cheap as possible.
|
|
/////
|