mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-06-27 16:35:59 +00:00
Merge branch 'rate-limiting' into 'next'
Draft: feat: rate limiting Closes #4 See merge request famedly/conduit!693
This commit is contained in:
commit
e0054552ea
7 changed files with 303 additions and 18 deletions
42
Cargo.lock
generated
42
Cargo.lock
generated
|
@ -189,6 +189,7 @@ dependencies = [
|
||||||
"serde_path_to_error",
|
"serde_path_to_error",
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
"sync_wrapper 1.0.1",
|
"sync_wrapper 1.0.1",
|
||||||
|
"tokio",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
|
@ -496,6 +497,7 @@ dependencies = [
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"bytes",
|
"bytes",
|
||||||
"clap",
|
"clap",
|
||||||
|
"dashmap",
|
||||||
"directories",
|
"directories",
|
||||||
"figment",
|
"figment",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
@ -664,6 +666,19 @@ dependencies = [
|
||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dashmap"
|
||||||
|
version = "5.5.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"hashbrown 0.14.5",
|
||||||
|
"lock_api",
|
||||||
|
"once_cell",
|
||||||
|
"parking_lot_core",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "data-encoding"
|
name = "data-encoding"
|
||||||
version = "2.6.0"
|
version = "2.6.0"
|
||||||
|
@ -2232,7 +2247,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma"
|
name = "ruma"
|
||||||
version = "0.10.1"
|
version = "0.10.1"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"assign",
|
"assign",
|
||||||
"js_int",
|
"js_int",
|
||||||
|
@ -2253,7 +2268,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma-appservice-api"
|
name = "ruma-appservice-api"
|
||||||
version = "0.10.0"
|
version = "0.10.0"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"js_int",
|
"js_int",
|
||||||
"ruma-common",
|
"ruma-common",
|
||||||
|
@ -2265,7 +2280,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma-client-api"
|
name = "ruma-client-api"
|
||||||
version = "0.18.0"
|
version = "0.18.0"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"as_variant",
|
"as_variant",
|
||||||
"assign",
|
"assign",
|
||||||
|
@ -2288,7 +2303,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma-common"
|
name = "ruma-common"
|
||||||
version = "0.13.0"
|
version = "0.13.0"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"as_variant",
|
"as_variant",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
|
@ -2318,7 +2333,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma-events"
|
name = "ruma-events"
|
||||||
version = "0.28.1"
|
version = "0.28.1"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"as_variant",
|
"as_variant",
|
||||||
"indexmap 2.2.6",
|
"indexmap 2.2.6",
|
||||||
|
@ -2334,13 +2349,14 @@ dependencies = [
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tracing",
|
"tracing",
|
||||||
"url",
|
"url",
|
||||||
|
"web-time",
|
||||||
"wildmatch",
|
"wildmatch",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma-federation-api"
|
name = "ruma-federation-api"
|
||||||
version = "0.9.0"
|
version = "0.9.0"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"js_int",
|
"js_int",
|
||||||
"ruma-common",
|
"ruma-common",
|
||||||
|
@ -2352,7 +2368,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma-identifiers-validation"
|
name = "ruma-identifiers-validation"
|
||||||
version = "0.9.5"
|
version = "0.9.5"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"js_int",
|
"js_int",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
|
@ -2361,7 +2377,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma-identity-service-api"
|
name = "ruma-identity-service-api"
|
||||||
version = "0.9.0"
|
version = "0.9.0"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"js_int",
|
"js_int",
|
||||||
"ruma-common",
|
"ruma-common",
|
||||||
|
@ -2371,7 +2387,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma-macros"
|
name = "ruma-macros"
|
||||||
version = "0.13.0"
|
version = "0.13.0"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"proc-macro-crate",
|
"proc-macro-crate",
|
||||||
|
@ -2386,7 +2402,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma-push-gateway-api"
|
name = "ruma-push-gateway-api"
|
||||||
version = "0.9.0"
|
version = "0.9.0"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"js_int",
|
"js_int",
|
||||||
"ruma-common",
|
"ruma-common",
|
||||||
|
@ -2398,7 +2414,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma-server-util"
|
name = "ruma-server-util"
|
||||||
version = "0.3.0"
|
version = "0.3.0"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"headers",
|
"headers",
|
||||||
"http 1.1.0",
|
"http 1.1.0",
|
||||||
|
@ -2411,7 +2427,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma-signatures"
|
name = "ruma-signatures"
|
||||||
version = "0.15.0"
|
version = "0.15.0"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"ed25519-dalek",
|
"ed25519-dalek",
|
||||||
|
@ -2427,7 +2443,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruma-state-res"
|
name = "ruma-state-res"
|
||||||
version = "0.11.0"
|
version = "0.11.0"
|
||||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"itertools",
|
"itertools",
|
||||||
"js_int",
|
"js_int",
|
||||||
|
|
|
@ -34,6 +34,7 @@ axum = { version = "0.7", default-features = false, features = [
|
||||||
"http2",
|
"http2",
|
||||||
"json",
|
"json",
|
||||||
"matched-path",
|
"matched-path",
|
||||||
|
"tokio",
|
||||||
], optional = true }
|
], optional = true }
|
||||||
axum-extra = { version = "0.9", features = ["typed-header"] }
|
axum-extra = { version = "0.9", features = ["typed-header"] }
|
||||||
axum-server = { version = "0.6", features = ["tls-rustls"] }
|
axum-server = { version = "0.6", features = ["tls-rustls"] }
|
||||||
|
@ -145,6 +146,7 @@ tikv-jemallocator = { version = "0.5.0", features = [
|
||||||
"unprefixed_malloc_on_supported_platforms",
|
"unprefixed_malloc_on_supported_platforms",
|
||||||
], optional = true }
|
], optional = true }
|
||||||
|
|
||||||
|
dashmap = "5.5.3"
|
||||||
sd-notify = { version = "0.4.1", optional = true }
|
sd-notify = { version = "0.4.1", optional = true }
|
||||||
|
|
||||||
# Used for matrix spec type definitions and helpers
|
# Used for matrix spec type definitions and helpers
|
||||||
|
|
|
@ -1,4 +1,9 @@
|
||||||
use std::{collections::BTreeMap, iter::FromIterator, str};
|
use std::{
|
||||||
|
collections::BTreeMap,
|
||||||
|
iter::FromIterator,
|
||||||
|
net::IpAddr,
|
||||||
|
str::{self, FromStr},
|
||||||
|
};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
async_trait,
|
async_trait,
|
||||||
|
@ -15,7 +20,10 @@ use axum_extra::{
|
||||||
use bytes::{BufMut, BytesMut};
|
use bytes::{BufMut, BytesMut};
|
||||||
use http::{Request, StatusCode};
|
use http::{Request, StatusCode};
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
|
api::{
|
||||||
|
client::error::{ErrorKind, RetryAfter},
|
||||||
|
AuthScheme, IncomingRequest, OutgoingResponse,
|
||||||
|
},
|
||||||
server_util::authorization::XMatrix,
|
server_util::authorization::XMatrix,
|
||||||
CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId,
|
CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId,
|
||||||
};
|
};
|
||||||
|
@ -23,7 +31,10 @@ use serde::Deserialize;
|
||||||
use tracing::{debug, error, warn};
|
use tracing::{debug, error, warn};
|
||||||
|
|
||||||
use super::{Ruma, RumaResponse};
|
use super::{Ruma, RumaResponse};
|
||||||
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
|
use crate::{
|
||||||
|
service::{appservice::RegistrationInfo, rate_limiting::Target},
|
||||||
|
services, Error, Result,
|
||||||
|
};
|
||||||
|
|
||||||
enum Token {
|
enum Token {
|
||||||
Appservice(Box<RegistrationInfo>),
|
Appservice(Box<RegistrationInfo>),
|
||||||
|
@ -305,8 +316,51 @@ where
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// doesn't work when Conduit is behind proxy
|
||||||
|
// let remote_addr: ConnectInfo<SocketAddr> = parts.extract().await?;
|
||||||
|
|
||||||
|
let headers = parts.headers;
|
||||||
|
|
||||||
|
let target = if let Some(server) = sender_servername.clone() {
|
||||||
|
Target::Server(server)
|
||||||
|
} else if let Some(appservice) = appservice_info.clone() {
|
||||||
|
if appservice.registration.rate_limited.unwrap_or(true) {
|
||||||
|
Target::Appservice(appservice.registration.id)
|
||||||
|
} else {
|
||||||
|
Target::None
|
||||||
|
}
|
||||||
|
} else if let Some(user) = sender_user.clone() {
|
||||||
|
Target::User(user)
|
||||||
|
} else {
|
||||||
|
let ip = headers
|
||||||
|
.get("X-Forwarded-For")
|
||||||
|
.and_then(|header| header.to_str().ok())
|
||||||
|
.map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header))
|
||||||
|
.and_then(|ip| IpAddr::from_str(ip).ok());
|
||||||
|
|
||||||
|
if let Some(ip) = ip {
|
||||||
|
Target::Ip(ip)
|
||||||
|
} else {
|
||||||
|
Target::None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(retry_after) = {
|
||||||
|
services()
|
||||||
|
.rate_limiting
|
||||||
|
.update_or_reject(target, metadata)
|
||||||
|
.map_err(Some)
|
||||||
|
} {
|
||||||
|
return Err(Error::BadRequest(
|
||||||
|
ErrorKind::LimitExceeded {
|
||||||
|
retry_after: retry_after.map(RetryAfter::Delay),
|
||||||
|
},
|
||||||
|
"Rate limit exceeded.",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
|
let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
|
||||||
*http_request.headers_mut().unwrap() = parts.headers;
|
*http_request.headers_mut().unwrap() = headers;
|
||||||
|
|
||||||
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
||||||
let user_id = sender_user.clone().unwrap_or_else(|| {
|
let user_id = sender_user.clone().unwrap_or_else(|| {
|
||||||
|
|
|
@ -2,6 +2,7 @@ use std::{
|
||||||
collections::BTreeMap,
|
collections::BTreeMap,
|
||||||
fmt,
|
fmt,
|
||||||
net::{IpAddr, Ipv4Addr},
|
net::{IpAddr, Ipv4Addr},
|
||||||
|
num::NonZeroU64,
|
||||||
};
|
};
|
||||||
|
|
||||||
use ruma::{OwnedServerName, RoomVersionId};
|
use ruma::{OwnedServerName, RoomVersionId};
|
||||||
|
@ -82,6 +83,8 @@ pub struct Config {
|
||||||
pub turn_secret: String,
|
pub turn_secret: String,
|
||||||
#[serde(default = "default_turn_ttl")]
|
#[serde(default = "default_turn_ttl")]
|
||||||
pub turn_ttl: u64,
|
pub turn_ttl: u64,
|
||||||
|
#[serde(default = "default_rate_limit")]
|
||||||
|
pub rate_limiting: BTreeMap<Restriction, Limitation>,
|
||||||
|
|
||||||
pub emergency_password: Option<String>,
|
pub emergency_password: Option<String>,
|
||||||
|
|
||||||
|
@ -101,6 +104,56 @@ pub struct WellKnownConfig {
|
||||||
pub server: Option<OwnedServerName>,
|
pub server: Option<OwnedServerName>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum Restriction {
|
||||||
|
Registration,
|
||||||
|
Login,
|
||||||
|
RegistrationTokenValidity,
|
||||||
|
Message,
|
||||||
|
Join,
|
||||||
|
Invite,
|
||||||
|
Knock,
|
||||||
|
CreateMedia,
|
||||||
|
Transaction,
|
||||||
|
FederatedJoin,
|
||||||
|
FederatedInvite,
|
||||||
|
FederatedKnock,
|
||||||
|
|
||||||
|
#[default]
|
||||||
|
CatchAll,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Clone, Copy, Debug)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
// When deserializing, we want this prefix
|
||||||
|
#[allow(clippy::enum_variant_names)]
|
||||||
|
pub enum Timeframe {
|
||||||
|
PerSecond(NonZeroU64),
|
||||||
|
PerMinute(NonZeroU64),
|
||||||
|
PerHour(NonZeroU64),
|
||||||
|
PerDay(NonZeroU64),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Timeframe {
|
||||||
|
pub fn nano_gap(&self) -> u64 {
|
||||||
|
match self {
|
||||||
|
Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(),
|
||||||
|
Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(),
|
||||||
|
Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(),
|
||||||
|
Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Deserialize)]
|
||||||
|
pub struct Limitation {
|
||||||
|
#[serde(default = "default_non_zero", flatten)]
|
||||||
|
pub timeframe: Timeframe,
|
||||||
|
#[serde(default = "default_non_zero")]
|
||||||
|
pub burst_capacity: NonZeroU64,
|
||||||
|
}
|
||||||
|
|
||||||
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
|
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
@ -308,6 +361,20 @@ fn default_openid_token_ttl() -> u64 {
|
||||||
60 * 60
|
60 * 60
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_non_zero() -> NonZeroU64 {
|
||||||
|
NonZeroU64::MIN
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_rate_limit() -> BTreeMap<Restriction, Limitation> {
|
||||||
|
BTreeMap::from_iter([(
|
||||||
|
Restriction::default(),
|
||||||
|
Limitation {
|
||||||
|
timeframe: Timeframe::PerMinute(NonZeroU64::MIN),
|
||||||
|
burst_capacity: NonZeroU64::MIN,
|
||||||
|
},
|
||||||
|
)])
|
||||||
|
}
|
||||||
|
|
||||||
// I know, it's a great name
|
// I know, it's a great name
|
||||||
pub fn default_default_room_version() -> RoomVersionId {
|
pub fn default_default_room_version() -> RoomVersionId {
|
||||||
RoomVersionId::V10
|
RoomVersionId::V10
|
||||||
|
|
|
@ -200,7 +200,9 @@ async fn run_server() -> io::Result<()> {
|
||||||
.expect("failed to convert max request size"),
|
.expect("failed to convert max request size"),
|
||||||
));
|
));
|
||||||
|
|
||||||
let app = routes(config).layer(middlewares).into_make_service();
|
let app = routes(config)
|
||||||
|
.layer(middlewares)
|
||||||
|
.into_make_service_with_connect_info::<SocketAddr>();
|
||||||
let handle = ServerHandle::new();
|
let handle = ServerHandle::new();
|
||||||
|
|
||||||
tokio::spawn(shutdown_signal(handle.clone()));
|
tokio::spawn(shutdown_signal(handle.clone()));
|
||||||
|
|
|
@ -17,6 +17,7 @@ pub mod key_backups;
|
||||||
pub mod media;
|
pub mod media;
|
||||||
pub mod pdu;
|
pub mod pdu;
|
||||||
pub mod pusher;
|
pub mod pusher;
|
||||||
|
pub mod rate_limiting;
|
||||||
pub mod rooms;
|
pub mod rooms;
|
||||||
pub mod sending;
|
pub mod sending;
|
||||||
pub mod transaction_ids;
|
pub mod transaction_ids;
|
||||||
|
@ -26,6 +27,7 @@ pub mod users;
|
||||||
pub struct Services {
|
pub struct Services {
|
||||||
pub appservice: appservice::Service,
|
pub appservice: appservice::Service,
|
||||||
pub pusher: pusher::Service,
|
pub pusher: pusher::Service,
|
||||||
|
pub rate_limiting: rate_limiting::Service,
|
||||||
pub rooms: rooms::Service,
|
pub rooms: rooms::Service,
|
||||||
pub transaction_ids: transaction_ids::Service,
|
pub transaction_ids: transaction_ids::Service,
|
||||||
pub uiaa: uiaa::Service,
|
pub uiaa: uiaa::Service,
|
||||||
|
@ -59,6 +61,7 @@ impl Services {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
appservice: appservice::Service::build(db)?,
|
appservice: appservice::Service::build(db)?,
|
||||||
pusher: pusher::Service { db },
|
pusher: pusher::Service { db },
|
||||||
|
rate_limiting: rate_limiting::Service::build(),
|
||||||
rooms: rooms::Service {
|
rooms: rooms::Service {
|
||||||
alias: rooms::alias::Service { db },
|
alias: rooms::alias::Service { db },
|
||||||
auth_chain: rooms::auth_chain::Service { db },
|
auth_chain: rooms::auth_chain::Service { db },
|
||||||
|
|
141
src/service/rate_limiting/mod.rs
Normal file
141
src/service/rate_limiting/mod.rs
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
use std::{
|
||||||
|
hash::Hash,
|
||||||
|
net::IpAddr,
|
||||||
|
time::{Duration, Instant},
|
||||||
|
};
|
||||||
|
|
||||||
|
use dashmap::{mapref::entry::Entry, DashMap};
|
||||||
|
use ruma::{
|
||||||
|
api::{
|
||||||
|
client::{
|
||||||
|
account::{check_registration_token_validity, register},
|
||||||
|
knock::knock_room,
|
||||||
|
media::{create_content, create_content_async},
|
||||||
|
membership::{invite_user, join_room_by_id, join_room_by_id_or_alias},
|
||||||
|
message::send_message_event,
|
||||||
|
session::login,
|
||||||
|
state::send_state_event,
|
||||||
|
},
|
||||||
|
federation::{
|
||||||
|
knock::send_knock,
|
||||||
|
membership::{create_invite, create_join_event},
|
||||||
|
transactions::send_transaction_message,
|
||||||
|
},
|
||||||
|
IncomingRequest, Metadata,
|
||||||
|
},
|
||||||
|
OwnedServerName, OwnedUserId,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{config::Restriction, services, Result};
|
||||||
|
|
||||||
|
impl From<Metadata> for Restriction {
|
||||||
|
fn from(metadata: Metadata) -> Self {
|
||||||
|
#[allow(deprecated)]
|
||||||
|
match metadata {
|
||||||
|
register::v3::Request::METADATA => Restriction::Registration,
|
||||||
|
login::v3::Request::METADATA => Restriction::Login,
|
||||||
|
check_registration_token_validity::v1::Request::METADATA => {
|
||||||
|
Restriction::RegistrationTokenValidity
|
||||||
|
}
|
||||||
|
send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => {
|
||||||
|
Restriction::Message
|
||||||
|
}
|
||||||
|
join_room_by_id::v3::Request::METADATA
|
||||||
|
| join_room_by_id_or_alias::v3::Request::METADATA => Restriction::Join,
|
||||||
|
invite_user::v3::Request::METADATA => Restriction::Invite,
|
||||||
|
create_content::v3::Request::METADATA | create_content_async::v3::Request::METADATA => {
|
||||||
|
Restriction::CreateMedia
|
||||||
|
}
|
||||||
|
send_transaction_message::v1::Request::METADATA => Restriction::Transaction,
|
||||||
|
create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => {
|
||||||
|
Restriction::FederatedJoin
|
||||||
|
}
|
||||||
|
create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => {
|
||||||
|
Restriction::FederatedInvite
|
||||||
|
}
|
||||||
|
send_knock::v1::Request::METADATA => Restriction::FederatedKnock,
|
||||||
|
knock_room::v3::Request::METADATA => Restriction::Knock,
|
||||||
|
_ => Self::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Service {
|
||||||
|
store: DashMap<(Target, Restriction), (Instant, u64)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
|
pub enum Target {
|
||||||
|
User(OwnedUserId),
|
||||||
|
// Server endpoints should be rate-limited on a server and room basis
|
||||||
|
Server(OwnedServerName),
|
||||||
|
Appservice(String),
|
||||||
|
Ip(IpAddr),
|
||||||
|
None,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Service {
|
||||||
|
pub fn build() -> Self {
|
||||||
|
Self {
|
||||||
|
store: DashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> {
|
||||||
|
if target == Target::None {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let restriction = metadata.into();
|
||||||
|
|
||||||
|
let arrival = Instant::now();
|
||||||
|
|
||||||
|
let config = &services().globals.config.rate_limiting;
|
||||||
|
|
||||||
|
let Some(limit) = config.get(&restriction) else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
// .unwrap_or(Limitation {
|
||||||
|
// per_minute: NonZeroU64::new(1).unwrap(),
|
||||||
|
// burst_capacity: NonZeroU64::new(1).unwrap(),
|
||||||
|
// weight: NonZeroU64::new(1).unwrap(),
|
||||||
|
// });
|
||||||
|
|
||||||
|
let gap = Duration::from_nanos(limit.timeframe.nano_gap());
|
||||||
|
let key = (target, restriction);
|
||||||
|
|
||||||
|
match self.store.entry(key) {
|
||||||
|
Entry::Occupied(mut entry) => {
|
||||||
|
let (instant, capacity) = entry.get_mut();
|
||||||
|
|
||||||
|
while *instant < arrival && *capacity != 0 {
|
||||||
|
*capacity -= 1;
|
||||||
|
*instant += gap;
|
||||||
|
}
|
||||||
|
|
||||||
|
if *capacity >= limit.burst_capacity.get() {
|
||||||
|
return Err(gap);
|
||||||
|
} else {
|
||||||
|
let zero_capacity = *capacity == 0;
|
||||||
|
*capacity += 1;
|
||||||
|
|
||||||
|
// Ensures that the update point is in the future
|
||||||
|
if zero_capacity {
|
||||||
|
*instant = Instant::now()
|
||||||
|
}
|
||||||
|
|
||||||
|
*instant += gap;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Entry::Vacant(entry) => {
|
||||||
|
entry.insert((Instant::now() + gap, 1));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///// In-memory state and utility functions used to check whether the client has exceeded its rate limit.
|
||||||
|
///// This leverages the generic cell rate algorithm, making the required checks as cheap as possible.
|
||||||
|
/////
|
Loading…
Add table
Add a link
Reference in a new issue