use std::{ hash::Hash, net::IpAddr, num::NonZeroU64, sync::atomic::{AtomicU64, Ordering}, time::Duration, }; use dashmap::DashMap; use quanta::Clock; use ruma::{ api::{ client::{account::register, session::login}, IncomingRequest, Metadata, }, OwnedUserId, }; use crate::{ config::{Limitation, Restriction}, services, Result, }; impl From<&Metadata> for Restriction { fn from(metadata: &Metadata) -> Self { [ (register::v3::Request::METADATA, Restriction::Registration), (login::v3::Request::METADATA, Restriction::Login), ] .iter() .find(|(other, _)| { metadata .history .stable_paths() .zip(other.history.stable_paths()) .all(|(a, b)| a == b) }) .map(|(_, restriction)| restriction.to_owned()) .unwrap_or_default() } } pub struct Service { store: DashMap<(Target, Restriction), AtomicU64>, clock: Clock, } #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub enum Target { User(OwnedUserId), Ip(IpAddr), } impl Service { pub fn build() -> Self { Self { store: DashMap::new(), clock: Clock::new(), } } pub fn update_or_reject(&self, key: &(Target, Restriction)) -> Result<(), Duration> { let arrival = self.clock.delta_as_nanos(0, self.clock.raw()); let config = &services().globals.config.rate_limiting; let Some(limit) = config .get(&key.1) .map(ToOwned::to_owned) else { return Ok(()); }; // .unwrap_or(Limitation { // per_minute: NonZeroU64::new(1).unwrap(), // burst_capacity: NonZeroU64::new(1).unwrap(), // weight: NonZeroU64::new(1).unwrap(), // }); tracing::info!(?limit); let increment = u64::try_from(Duration::from_secs(60).as_nanos()) .expect("1_000_000_000 to be smaller than u64::MAX") / limit.per_minute.get() * limit.weight.get(); tracing::info!(?increment); let mut prev_expectation = self .store .get(key) .as_deref() .map(|n| n.load(Ordering::Acquire)) .unwrap_or_else(|| arrival + increment); let weight = (increment * limit.burst_capacity.get()).max(1); tracing::info!(?prev_expectation); tracing::info!(?weight); let f = |prev_expectation: u64| { let allowed = prev_expectation.saturating_sub(weight); if arrival < allowed { Err(Duration::from_nanos(allowed - arrival)) } else { Ok(prev_expectation.max(arrival) + increment) } }; let mut decision = f(prev_expectation); tracing::info!(?decision); while let Ok(next_expectation) = decision { let entry = self.store.entry(key.clone()); match entry.or_default().compare_exchange_weak( prev_expectation, next_expectation, Ordering::Release, Ordering::Relaxed, ) { Ok(_) => return Ok(()), Err(actual) => prev_expectation = actual, } decision = f(prev_expectation); } decision.map(|_| ()) } } ///// In-memory state and utility functions used to check whether the client has exceeded its rate limit. ///// This leverages the generic cell rate algorithm, making the required checks as cheap as possible. /////