1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-08-01 17:38:36 +00:00
conduit/src/service/rate_limiting/mod.rs
2024-07-06 17:06:39 +01:00

131 lines
3.6 KiB
Rust

use std::{
hash::Hash,
net::IpAddr,
sync::atomic::{AtomicU64, Ordering},
time::Duration,
};
use dashmap::DashMap;
use quanta::Clock;
use ruma::{
api::{
client::{account::register, session::login},
IncomingRequest, Metadata,
},
OwnedServerName, OwnedUserId,
};
use crate::{config::Restriction, services, Result};
impl From<Metadata> for Restriction {
    /// Classifies an endpoint, described by its `Metadata`, into a rate-limit `Restriction`.
    ///
    /// Returns `Restriction::Registration` for `register::v3` and `Restriction::Login` for
    /// `login::v3`. Every other endpoint maps to `Restriction::default()`.
    fn from(metadata: Metadata) -> Self {
        [
            (register::v3::Request::METADATA, Restriction::Registration),
            (login::v3::Request::METADATA, Restriction::Login),
        ]
        .into_iter()
        .find(|(other, _)| {
            // The method must match too: `GET /login` (login types) shares its path with
            // `POST /login`, but it must not be rate-limited as a login attempt.
            metadata.method == other.method
                // `Iterator::eq` also requires equal lengths. A `zip().all()` comparison is
                // vacuously true for endpoints without stable paths, which would wrongly
                // classify all of them as the first entry (registration).
                && metadata
                    .history
                    .stable_paths()
                    .eq(other.history.stable_paths())
        })
        .map(|(_, restriction)| restriction)
        .unwrap_or_default()
    }
}
/// In-memory rate-limiter state.
///
/// NOTE(review): `update_or_reject` inserts entries into `store`, but nothing in the visible
/// code removes them. The map therefore grows with the number of distinct targets; confirm
/// whether eviction happens elsewhere.
pub struct Service {
    /// Time source used to timestamp request arrivals, in nanoseconds.
    clock: Clock,
    /// One counter per `(Target, Restriction)` pair. It holds the next expected arrival time,
    /// in the same nanosecond time base as `clock`.
    store: DashMap<(Target, Restriction), AtomicU64>,
}
/// The entity a rate limit is accounted against. It is paired with a `Restriction` to form
/// the key of the limiter's state.
///
/// Variant order matters, because it determines the derived `PartialOrd`/`Ord`.
#[derive(Clone, Debug)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Target {
    /// A Matrix user, identified by user ID.
    User(OwnedUserId),
    /// A homeserver, identified by server name.
    Server(OwnedServerName),
    /// An application service, identified by a string.
    Appservice(String),
    /// An IP address.
    Ip(IpAddr),
    /// No specific target. Presumably used for requests that cannot be attributed; confirm
    /// with callers.
    None,
}
impl Service {
pub fn build() -> Self {
Self {
store: DashMap::new(),
clock: Clock::new(),
}
}
pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> {
let restriction = metadata.into();
let arrival = self.clock.delta_as_nanos(0, self.clock.raw());
let config = &services().globals.config.rate_limiting;
let Some(limit) = config.get(&restriction).map(ToOwned::to_owned) else {
return Ok(());
};
// .unwrap_or(Limitation {
// per_minute: NonZeroU64::new(1).unwrap(),
// burst_capacity: NonZeroU64::new(1).unwrap(),
// weight: NonZeroU64::new(1).unwrap(),
// });
let key = (target, restriction);
tracing::info!(?limit);
let increment = 1_000_000_000u64 / limit.per_minute.get() * limit.weight.get();
tracing::info!(?increment);
let mut prev_expectation = self
.store
.get(&key)
.as_deref()
.map(|n| n.load(Ordering::Acquire))
.unwrap_or_else(|| arrival + increment);
let weight = (increment * limit.burst_capacity.get()).max(1);
tracing::info!(?prev_expectation);
tracing::info!(?weight);
let f = |prev_expectation: u64| {
let allowed = prev_expectation.saturating_sub(weight);
if arrival < allowed {
Err(Duration::from_nanos(allowed - arrival))
} else {
Ok(prev_expectation.max(arrival) + increment)
}
};
let mut decision = f(prev_expectation);
tracing::info!(?decision);
let entry = self.store.entry(key);
while let Ok(next_expectation) = decision {
match entry.or_default().compare_exchange_weak(
prev_expectation,
next_expectation,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => return Ok(()),
Err(actual) => prev_expectation = actual,
}
decision = f(prev_expectation);
}
decision.map(|_| ())
}
}
///// In-memory state and utility functions used to check whether the client has exceeded its rate limit.
///// This leverages the generic cell rate algorithm, making the required checks as cheap as possible.
/////