mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-06-27 16:35:59 +00:00
more rate limit targets
This commit is contained in:
parent
02cea0bb93
commit
d6abf5472b
4 changed files with 123 additions and 67 deletions
40
Cargo.lock
generated
40
Cargo.lock
generated
|
@ -189,6 +189,7 @@ dependencies = [
|
||||||
"serde_path_to_error",
|
"serde_path_to_error",
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
"sync_wrapper 1.0.1",
|
"sync_wrapper 1.0.1",
|
||||||
|
"tokio",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
|
@ -496,6 +497,7 @@ dependencies = [
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"bytes",
|
"bytes",
|
||||||
"clap",
|
"clap",
|
||||||
|
"dashmap",
|
||||||
"directories",
|
"directories",
|
||||||
"figment",
|
"figment",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
@ -515,6 +517,7 @@ dependencies = [
|
||||||
"opentelemetry_sdk",
|
"opentelemetry_sdk",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"persy",
|
"persy",
|
||||||
|
"quanta",
|
||||||
"rand",
|
"rand",
|
||||||
"regex",
|
"regex",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
@ -664,6 +667,19 @@ dependencies = [
|
||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dashmap"
|
||||||
|
version = "5.5.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"hashbrown 0.14.5",
|
||||||
|
"lock_api",
|
||||||
|
"once_cell",
|
||||||
|
"parking_lot_core",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "data-encoding"
|
name = "data-encoding"
|
||||||
version = "2.6.0"
|
version = "2.6.0"
|
||||||
|
@ -2053,6 +2069,21 @@ dependencies = [
|
||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quanta"
|
||||||
|
version = "0.12.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-utils",
|
||||||
|
"libc",
|
||||||
|
"once_cell",
|
||||||
|
"raw-cpuid",
|
||||||
|
"wasi",
|
||||||
|
"web-sys",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quick-error"
|
name = "quick-error"
|
||||||
version = "1.2.3"
|
version = "1.2.3"
|
||||||
|
@ -2098,6 +2129,15 @@ dependencies = [
|
||||||
"getrandom",
|
"getrandom",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "raw-cpuid"
|
||||||
|
version = "11.0.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.5.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
version = "0.5.1"
|
version = "0.5.1"
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
pub mod appservice_server;
|
pub mod appservice_server;
|
||||||
pub mod client_server;
|
pub mod client_server;
|
||||||
pub mod rate_limiting;
|
|
||||||
pub mod ruma_wrapper;
|
pub mod ruma_wrapper;
|
||||||
pub mod server_server;
|
pub mod server_server;
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
use std::{
|
use std::{
|
||||||
collections::BTreeMap,
|
collections::BTreeMap,
|
||||||
iter::FromIterator,
|
iter::FromIterator,
|
||||||
str::{self},
|
net::IpAddr,
|
||||||
|
str::{self, FromStr},
|
||||||
};
|
};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
|
@ -102,44 +103,6 @@ where
|
||||||
Token::None
|
Token::None
|
||||||
};
|
};
|
||||||
|
|
||||||
// doesn't work when Conduit is behind proxy
|
|
||||||
// let remote_addr: ConnectInfo<SocketAddr> = parts.extract().await?;
|
|
||||||
|
|
||||||
let target = match &token {
|
|
||||||
Token::User((user_id, _)) => Some(Target::User(user_id.clone())),
|
|
||||||
Token::None => {
|
|
||||||
let header = parts
|
|
||||||
.headers
|
|
||||||
.get("x-forwarded-for")
|
|
||||||
.ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
|
|
||||||
|
|
||||||
let s = header
|
|
||||||
.to_str()
|
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
|
|
||||||
Some(
|
|
||||||
s.parse()
|
|
||||||
.map(Target::Ip)
|
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")),
|
|
||||||
)
|
|
||||||
.transpose()?
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(retry_after_ms) = target.map_or(Ok(()), |t| {
|
|
||||||
let key = (t, (&metadata).into());
|
|
||||||
|
|
||||||
services()
|
|
||||||
.rate_limiting
|
|
||||||
.update_or_reject(&key)
|
|
||||||
.map_err(Some)
|
|
||||||
}) {
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::LimitExceeded { retry_after_ms },
|
|
||||||
"Rate limit exceeded.",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||||
|
|
||||||
let (sender_user, sender_device, sender_servername, appservice_info) =
|
let (sender_user, sender_device, sender_servername, appservice_info) =
|
||||||
|
@ -350,8 +313,64 @@ where
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// doesn't work when Conduit is behind proxy
|
||||||
|
// let remote_addr: ConnectInfo<SocketAddr> = parts.extract().await?;
|
||||||
|
|
||||||
|
let headers = parts.headers;
|
||||||
|
|
||||||
|
let target = if let Some(server) = sender_servername.clone() {
|
||||||
|
Target::Server(server)
|
||||||
|
|
||||||
|
// Token::User((user_id, _)) => Some(Target::User(user_id.clone())),
|
||||||
|
// Token::None => {
|
||||||
|
// let header = parts
|
||||||
|
// .headers
|
||||||
|
// .get("x-forwarded-for")
|
||||||
|
// .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
|
||||||
|
|
||||||
|
// let s = header
|
||||||
|
// .to_str()
|
||||||
|
// .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
|
||||||
|
// Some(
|
||||||
|
// s.parse()
|
||||||
|
// .map(Target::Ip)
|
||||||
|
// .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")),
|
||||||
|
// )
|
||||||
|
// .transpose()?
|
||||||
|
// }
|
||||||
|
// _ => None,
|
||||||
|
} else if let Some(appservice) = appservice_info.clone() {
|
||||||
|
Target::Appservice(appservice.registration.id)
|
||||||
|
} else if let Some(user) = sender_user.clone() {
|
||||||
|
Target::User(user)
|
||||||
|
} else {
|
||||||
|
let ip = headers
|
||||||
|
.get("X-Forwarded-For")
|
||||||
|
.and_then(|header| header.to_str().ok())
|
||||||
|
.map(|header| header.split_once(",").map(|(ip, _)| ip).unwrap_or(header))
|
||||||
|
.and_then(|ip| IpAddr::from_str(ip).ok());
|
||||||
|
|
||||||
|
if let Some(ip) = ip {
|
||||||
|
Target::Ip(ip)
|
||||||
|
} else {
|
||||||
|
Target::None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(retry_after_ms) = {
|
||||||
|
services()
|
||||||
|
.rate_limiting
|
||||||
|
.update_or_reject(target, metadata)
|
||||||
|
.map_err(Some)
|
||||||
|
} {
|
||||||
|
return Err(Error::BadRequest(
|
||||||
|
ErrorKind::LimitExceeded { retry_after_ms },
|
||||||
|
"Rate limit exceeded.",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
|
let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
|
||||||
*http_request.headers_mut().unwrap() = parts.headers;
|
*http_request.headers_mut().unwrap() = headers;
|
||||||
|
|
||||||
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
||||||
let user_id = sender_user.clone().unwrap_or_else(|| {
|
let user_id = sender_user.clone().unwrap_or_else(|| {
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
use std::{
|
use std::{
|
||||||
hash::Hash,
|
hash::Hash,
|
||||||
net::IpAddr,
|
net::IpAddr,
|
||||||
num::NonZeroU64,
|
|
||||||
sync::atomic::{AtomicU64, Ordering},
|
sync::atomic::{AtomicU64, Ordering},
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
@ -13,16 +12,13 @@ use ruma::{
|
||||||
client::{account::register, session::login},
|
client::{account::register, session::login},
|
||||||
IncomingRequest, Metadata,
|
IncomingRequest, Metadata,
|
||||||
},
|
},
|
||||||
OwnedUserId,
|
OwnedServerName, OwnedUserId,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{config::Restriction, services, Result};
|
||||||
config::{Limitation, Restriction},
|
|
||||||
services, Result,
|
|
||||||
};
|
|
||||||
|
|
||||||
impl From<&Metadata> for Restriction {
|
impl From<Metadata> for Restriction {
|
||||||
fn from(metadata: &Metadata) -> Self {
|
fn from(metadata: Metadata) -> Self {
|
||||||
[
|
[
|
||||||
(register::v3::Request::METADATA, Restriction::Registration),
|
(register::v3::Request::METADATA, Restriction::Registration),
|
||||||
(login::v3::Request::METADATA, Restriction::Login),
|
(login::v3::Request::METADATA, Restriction::Login),
|
||||||
|
@ -48,7 +44,10 @@ pub struct Service {
|
||||||
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
pub enum Target {
|
pub enum Target {
|
||||||
User(OwnedUserId),
|
User(OwnedUserId),
|
||||||
|
Server(OwnedServerName),
|
||||||
|
Appservice(String),
|
||||||
Ip(IpAddr),
|
Ip(IpAddr),
|
||||||
|
None,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Service {
|
impl Service {
|
||||||
|
@ -59,33 +58,32 @@ impl Service {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_or_reject(&self, key: &(Target, Restriction)) -> Result<(), Duration> {
|
pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> {
|
||||||
|
let restriction = metadata.into();
|
||||||
|
|
||||||
let arrival = self.clock.delta_as_nanos(0, self.clock.raw());
|
let arrival = self.clock.delta_as_nanos(0, self.clock.raw());
|
||||||
|
|
||||||
let config = &services().globals.config.rate_limiting;
|
let config = &services().globals.config.rate_limiting;
|
||||||
|
|
||||||
let Some(limit) = config
|
let Some(limit) = config.get(&restriction).map(ToOwned::to_owned) else {
|
||||||
.get(&key.1)
|
return Ok(());
|
||||||
.map(ToOwned::to_owned) else {
|
};
|
||||||
return Ok(());
|
// .unwrap_or(Limitation {
|
||||||
};
|
// per_minute: NonZeroU64::new(1).unwrap(),
|
||||||
// .unwrap_or(Limitation {
|
// burst_capacity: NonZeroU64::new(1).unwrap(),
|
||||||
// per_minute: NonZeroU64::new(1).unwrap(),
|
// weight: NonZeroU64::new(1).unwrap(),
|
||||||
// burst_capacity: NonZeroU64::new(1).unwrap(),
|
// });
|
||||||
// weight: NonZeroU64::new(1).unwrap(),
|
|
||||||
// });
|
let key = (target, restriction);
|
||||||
|
|
||||||
tracing::info!(?limit);
|
tracing::info!(?limit);
|
||||||
|
|
||||||
let increment = u64::try_from(Duration::from_secs(60).as_nanos())
|
let increment = 1_000_000_000u64 / limit.per_minute.get() * limit.weight.get();
|
||||||
.expect("1_000_000_000 to be smaller than u64::MAX")
|
|
||||||
/ limit.per_minute.get()
|
|
||||||
* limit.weight.get();
|
|
||||||
tracing::info!(?increment);
|
tracing::info!(?increment);
|
||||||
|
|
||||||
let mut prev_expectation = self
|
let mut prev_expectation = self
|
||||||
.store
|
.store
|
||||||
.get(key)
|
.get(&key)
|
||||||
.as_deref()
|
.as_deref()
|
||||||
.map(|n| n.load(Ordering::Acquire))
|
.map(|n| n.load(Ordering::Acquire))
|
||||||
.unwrap_or_else(|| arrival + increment);
|
.unwrap_or_else(|| arrival + increment);
|
||||||
|
@ -108,9 +106,9 @@ impl Service {
|
||||||
|
|
||||||
tracing::info!(?decision);
|
tracing::info!(?decision);
|
||||||
|
|
||||||
while let Ok(next_expectation) = decision {
|
let entry = self.store.entry(key);
|
||||||
let entry = self.store.entry(key.clone());
|
|
||||||
|
|
||||||
|
while let Ok(next_expectation) = decision {
|
||||||
match entry.or_default().compare_exchange_weak(
|
match entry.or_default().compare_exchange_weak(
|
||||||
prev_expectation,
|
prev_expectation,
|
||||||
next_expectation,
|
next_expectation,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue