1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00

more rate limit targets

This commit is contained in:
Matthias Ahouansou 2024-06-26 09:16:54 +01:00
parent 02cea0bb93
commit d6abf5472b
No known key found for this signature in database
4 changed files with 123 additions and 67 deletions

40
Cargo.lock generated
View file

@ -189,6 +189,7 @@ dependencies = [
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper 1.0.1",
"tokio",
"tower",
"tower-layer",
"tower-service",
@ -496,6 +497,7 @@ dependencies = [
"base64 0.22.1",
"bytes",
"clap",
"dashmap",
"directories",
"figment",
"futures-util",
@ -515,6 +517,7 @@ dependencies = [
"opentelemetry_sdk",
"parking_lot",
"persy",
"quanta",
"rand",
"regex",
"reqwest",
@ -664,6 +667,19 @@ dependencies = [
"syn",
]
[[package]]
name = "dashmap"
version = "5.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
dependencies = [
"cfg-if",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.6.0"
@ -2053,6 +2069,21 @@ dependencies = [
"syn",
]
[[package]]
name = "quanta"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi",
"web-sys",
"winapi",
]
[[package]]
name = "quick-error"
version = "1.2.3"
@ -2098,6 +2129,15 @@ dependencies = [
"getrandom",
]
[[package]]
name = "raw-cpuid"
version = "11.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd"
dependencies = [
"bitflags 2.5.0",
]
[[package]]
name = "redox_syscall"
version = "0.5.1"

View file

@ -1,5 +1,4 @@
pub mod appservice_server;
pub mod client_server;
pub mod rate_limiting;
pub mod ruma_wrapper;
pub mod server_server;

View file

@ -1,7 +1,8 @@
use std::{
collections::BTreeMap,
iter::FromIterator,
str::{self},
net::IpAddr,
str::{self, FromStr},
};
use axum::{
@ -102,44 +103,6 @@ where
Token::None
};
// doesn't work when Conduit is behind proxy
// let remote_addr: ConnectInfo<SocketAddr> = parts.extract().await?;
let target = match &token {
Token::User((user_id, _)) => Some(Target::User(user_id.clone())),
Token::None => {
let header = parts
.headers
.get("x-forwarded-for")
.ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
let s = header
.to_str()
.map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
Some(
s.parse()
.map(Target::Ip)
.map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")),
)
.transpose()?
}
_ => None,
};
if let Err(retry_after_ms) = target.map_or(Ok(()), |t| {
let key = (t, (&metadata).into());
services()
.rate_limiting
.update_or_reject(&key)
.map_err(Some)
}) {
return Err(Error::BadRequest(
ErrorKind::LimitExceeded { retry_after_ms },
"Rate limit exceeded.",
));
}
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
let (sender_user, sender_device, sender_servername, appservice_info) =
@ -350,8 +313,64 @@ where
}
};
// doesn't work when Conduit is behind proxy
// let remote_addr: ConnectInfo<SocketAddr> = parts.extract().await?;
let headers = parts.headers;
let target = if let Some(server) = sender_servername.clone() {
Target::Server(server)
// Token::User((user_id, _)) => Some(Target::User(user_id.clone())),
// Token::None => {
// let header = parts
// .headers
// .get("x-forwarded-for")
// .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
// let s = header
// .to_str()
// .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
// Some(
// s.parse()
// .map(Target::Ip)
// .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")),
// )
// .transpose()?
// }
// _ => None,
} else if let Some(appservice) = appservice_info.clone() {
Target::Appservice(appservice.registration.id)
} else if let Some(user) = sender_user.clone() {
Target::User(user)
} else {
let ip = headers
.get("X-Forwarded-For")
.and_then(|header| header.to_str().ok())
.map(|header| header.split_once(",").map(|(ip, _)| ip).unwrap_or(header))
.and_then(|ip| IpAddr::from_str(ip).ok());
if let Some(ip) = ip {
Target::Ip(ip)
} else {
Target::None
}
};
if let Err(retry_after_ms) = {
services()
.rate_limiting
.update_or_reject(target, metadata)
.map_err(Some)
} {
return Err(Error::BadRequest(
ErrorKind::LimitExceeded { retry_after_ms },
"Rate limit exceeded.",
));
}
let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
*http_request.headers_mut().unwrap() = parts.headers;
*http_request.headers_mut().unwrap() = headers;
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
let user_id = sender_user.clone().unwrap_or_else(|| {

View file

@ -1,7 +1,6 @@
use std::{
hash::Hash,
net::IpAddr,
num::NonZeroU64,
sync::atomic::{AtomicU64, Ordering},
time::Duration,
};
@ -13,16 +12,13 @@ use ruma::{
client::{account::register, session::login},
IncomingRequest, Metadata,
},
OwnedUserId,
OwnedServerName, OwnedUserId,
};
use crate::{
config::{Limitation, Restriction},
services, Result,
};
use crate::{config::Restriction, services, Result};
impl From<&Metadata> for Restriction {
fn from(metadata: &Metadata) -> Self {
impl From<Metadata> for Restriction {
fn from(metadata: Metadata) -> Self {
[
(register::v3::Request::METADATA, Restriction::Registration),
(login::v3::Request::METADATA, Restriction::Login),
@ -48,7 +44,10 @@ pub struct Service {
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Target {
User(OwnedUserId),
Server(OwnedServerName),
Appservice(String),
Ip(IpAddr),
None,
}
impl Service {
@ -59,14 +58,14 @@ impl Service {
}
}
pub fn update_or_reject(&self, key: &(Target, Restriction)) -> Result<(), Duration> {
pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> {
let restriction = metadata.into();
let arrival = self.clock.delta_as_nanos(0, self.clock.raw());
let config = &services().globals.config.rate_limiting;
let Some(limit) = config
.get(&key.1)
.map(ToOwned::to_owned) else {
let Some(limit) = config.get(&restriction).map(ToOwned::to_owned) else {
return Ok(());
};
// .unwrap_or(Limitation {
@ -75,17 +74,16 @@ impl Service {
// weight: NonZeroU64::new(1).unwrap(),
// });
let key = (target, restriction);
tracing::info!(?limit);
let increment = u64::try_from(Duration::from_secs(60).as_nanos())
.expect("1_000_000_000 to be smaller than u64::MAX")
/ limit.per_minute.get()
* limit.weight.get();
let increment = 1_000_000_000u64 / limit.per_minute.get() * limit.weight.get();
tracing::info!(?increment);
let mut prev_expectation = self
.store
.get(key)
.get(&key)
.as_deref()
.map(|n| n.load(Ordering::Acquire))
.unwrap_or_else(|| arrival + increment);
@ -108,9 +106,9 @@ impl Service {
tracing::info!(?decision);
while let Ok(next_expectation) = decision {
let entry = self.store.entry(key.clone());
let entry = self.store.entry(key);
while let Ok(next_expectation) = decision {
match entry.or_default().compare_exchange_weak(
prev_expectation,
next_expectation,