mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-06-27 16:35:59 +00:00
more rate limit targets
This commit is contained in:
parent
02cea0bb93
commit
d6abf5472b
4 changed files with 123 additions and 67 deletions
40
Cargo.lock
generated
40
Cargo.lock
generated
|
@ -189,6 +189,7 @@ dependencies = [
|
|||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper 1.0.1",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
|
@ -496,6 +497,7 @@ dependencies = [
|
|||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"clap",
|
||||
"dashmap",
|
||||
"directories",
|
||||
"figment",
|
||||
"futures-util",
|
||||
|
@ -515,6 +517,7 @@ dependencies = [
|
|||
"opentelemetry_sdk",
|
||||
"parking_lot",
|
||||
"persy",
|
||||
"quanta",
|
||||
"rand",
|
||||
"regex",
|
||||
"reqwest",
|
||||
|
@ -664,6 +667,19 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "5.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"hashbrown 0.14.5",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.6.0"
|
||||
|
@ -2053,6 +2069,21 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quanta"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"raw-cpuid",
|
||||
"wasi",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "1.2.3"
|
||||
|
@ -2098,6 +2129,15 @@ dependencies = [
|
|||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "11.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd"
|
||||
dependencies = [
|
||||
"bitflags 2.5.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.1"
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
pub mod appservice_server;
|
||||
pub mod client_server;
|
||||
pub mod rate_limiting;
|
||||
pub mod ruma_wrapper;
|
||||
pub mod server_server;
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use std::{
|
||||
collections::BTreeMap,
|
||||
iter::FromIterator,
|
||||
str::{self},
|
||||
net::IpAddr,
|
||||
str::{self, FromStr},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
|
@ -102,44 +103,6 @@ where
|
|||
Token::None
|
||||
};
|
||||
|
||||
// doesn't work when Conduit is behind proxy
|
||||
// let remote_addr: ConnectInfo<SocketAddr> = parts.extract().await?;
|
||||
|
||||
let target = match &token {
|
||||
Token::User((user_id, _)) => Some(Target::User(user_id.clone())),
|
||||
Token::None => {
|
||||
let header = parts
|
||||
.headers
|
||||
.get("x-forwarded-for")
|
||||
.ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
|
||||
|
||||
let s = header
|
||||
.to_str()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
|
||||
Some(
|
||||
s.parse()
|
||||
.map(Target::Ip)
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")),
|
||||
)
|
||||
.transpose()?
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Err(retry_after_ms) = target.map_or(Ok(()), |t| {
|
||||
let key = (t, (&metadata).into());
|
||||
|
||||
services()
|
||||
.rate_limiting
|
||||
.update_or_reject(&key)
|
||||
.map_err(Some)
|
||||
}) {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::LimitExceeded { retry_after_ms },
|
||||
"Rate limit exceeded.",
|
||||
));
|
||||
}
|
||||
|
||||
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||
|
||||
let (sender_user, sender_device, sender_servername, appservice_info) =
|
||||
|
@ -350,8 +313,64 @@ where
|
|||
}
|
||||
};
|
||||
|
||||
// doesn't work when Conduit is behind proxy
|
||||
// let remote_addr: ConnectInfo<SocketAddr> = parts.extract().await?;
|
||||
|
||||
let headers = parts.headers;
|
||||
|
||||
let target = if let Some(server) = sender_servername.clone() {
|
||||
Target::Server(server)
|
||||
|
||||
// Token::User((user_id, _)) => Some(Target::User(user_id.clone())),
|
||||
// Token::None => {
|
||||
// let header = parts
|
||||
// .headers
|
||||
// .get("x-forwarded-for")
|
||||
// .ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
|
||||
|
||||
// let s = header
|
||||
// .to_str()
|
||||
// .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
|
||||
// Some(
|
||||
// s.parse()
|
||||
// .map(Target::Ip)
|
||||
// .map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")),
|
||||
// )
|
||||
// .transpose()?
|
||||
// }
|
||||
// _ => None,
|
||||
} else if let Some(appservice) = appservice_info.clone() {
|
||||
Target::Appservice(appservice.registration.id)
|
||||
} else if let Some(user) = sender_user.clone() {
|
||||
Target::User(user)
|
||||
} else {
|
||||
let ip = headers
|
||||
.get("X-Forwarded-For")
|
||||
.and_then(|header| header.to_str().ok())
|
||||
.map(|header| header.split_once(",").map(|(ip, _)| ip).unwrap_or(header))
|
||||
.and_then(|ip| IpAddr::from_str(ip).ok());
|
||||
|
||||
if let Some(ip) = ip {
|
||||
Target::Ip(ip)
|
||||
} else {
|
||||
Target::None
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(retry_after_ms) = {
|
||||
services()
|
||||
.rate_limiting
|
||||
.update_or_reject(target, metadata)
|
||||
.map_err(Some)
|
||||
} {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::LimitExceeded { retry_after_ms },
|
||||
"Rate limit exceeded.",
|
||||
));
|
||||
}
|
||||
|
||||
let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
|
||||
*http_request.headers_mut().unwrap() = parts.headers;
|
||||
*http_request.headers_mut().unwrap() = headers;
|
||||
|
||||
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
||||
let user_id = sender_user.clone().unwrap_or_else(|| {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
use std::{
|
||||
hash::Hash,
|
||||
net::IpAddr,
|
||||
num::NonZeroU64,
|
||||
sync::atomic::{AtomicU64, Ordering},
|
||||
time::Duration,
|
||||
};
|
||||
|
@ -13,16 +12,13 @@ use ruma::{
|
|||
client::{account::register, session::login},
|
||||
IncomingRequest, Metadata,
|
||||
},
|
||||
OwnedUserId,
|
||||
OwnedServerName, OwnedUserId,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
config::{Limitation, Restriction},
|
||||
services, Result,
|
||||
};
|
||||
use crate::{config::Restriction, services, Result};
|
||||
|
||||
impl From<&Metadata> for Restriction {
|
||||
fn from(metadata: &Metadata) -> Self {
|
||||
impl From<Metadata> for Restriction {
|
||||
fn from(metadata: Metadata) -> Self {
|
||||
[
|
||||
(register::v3::Request::METADATA, Restriction::Registration),
|
||||
(login::v3::Request::METADATA, Restriction::Login),
|
||||
|
@ -48,7 +44,10 @@ pub struct Service {
|
|||
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
pub enum Target {
|
||||
User(OwnedUserId),
|
||||
Server(OwnedServerName),
|
||||
Appservice(String),
|
||||
Ip(IpAddr),
|
||||
None,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
|
@ -59,33 +58,32 @@ impl Service {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn update_or_reject(&self, key: &(Target, Restriction)) -> Result<(), Duration> {
|
||||
pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> {
|
||||
let restriction = metadata.into();
|
||||
|
||||
let arrival = self.clock.delta_as_nanos(0, self.clock.raw());
|
||||
|
||||
let config = &services().globals.config.rate_limiting;
|
||||
|
||||
let Some(limit) = config
|
||||
.get(&key.1)
|
||||
.map(ToOwned::to_owned) else {
|
||||
return Ok(());
|
||||
};
|
||||
// .unwrap_or(Limitation {
|
||||
// per_minute: NonZeroU64::new(1).unwrap(),
|
||||
// burst_capacity: NonZeroU64::new(1).unwrap(),
|
||||
// weight: NonZeroU64::new(1).unwrap(),
|
||||
// });
|
||||
let Some(limit) = config.get(&restriction).map(ToOwned::to_owned) else {
|
||||
return Ok(());
|
||||
};
|
||||
// .unwrap_or(Limitation {
|
||||
// per_minute: NonZeroU64::new(1).unwrap(),
|
||||
// burst_capacity: NonZeroU64::new(1).unwrap(),
|
||||
// weight: NonZeroU64::new(1).unwrap(),
|
||||
// });
|
||||
|
||||
let key = (target, restriction);
|
||||
|
||||
tracing::info!(?limit);
|
||||
|
||||
let increment = u64::try_from(Duration::from_secs(60).as_nanos())
|
||||
.expect("1_000_000_000 to be smaller than u64::MAX")
|
||||
/ limit.per_minute.get()
|
||||
* limit.weight.get();
|
||||
let increment = 1_000_000_000u64 / limit.per_minute.get() * limit.weight.get();
|
||||
tracing::info!(?increment);
|
||||
|
||||
let mut prev_expectation = self
|
||||
.store
|
||||
.get(key)
|
||||
.get(&key)
|
||||
.as_deref()
|
||||
.map(|n| n.load(Ordering::Acquire))
|
||||
.unwrap_or_else(|| arrival + increment);
|
||||
|
@ -108,9 +106,9 @@ impl Service {
|
|||
|
||||
tracing::info!(?decision);
|
||||
|
||||
while let Ok(next_expectation) = decision {
|
||||
let entry = self.store.entry(key.clone());
|
||||
let entry = self.store.entry(key);
|
||||
|
||||
while let Ok(next_expectation) = decision {
|
||||
match entry.or_default().compare_exchange_weak(
|
||||
prev_expectation,
|
||||
next_expectation,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue