mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-06-27 16:35:59 +00:00
Merge branch 'rate-limiting' into 'next'
Draft: feat: rate limiting Closes #4 See merge request famedly/conduit!693
This commit is contained in:
commit
e0054552ea
7 changed files with 303 additions and 18 deletions
42
Cargo.lock
generated
42
Cargo.lock
generated
|
@ -189,6 +189,7 @@ dependencies = [
|
|||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper 1.0.1",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
|
@ -496,6 +497,7 @@ dependencies = [
|
|||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"clap",
|
||||
"dashmap",
|
||||
"directories",
|
||||
"figment",
|
||||
"futures-util",
|
||||
|
@ -664,6 +666,19 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "5.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"hashbrown 0.14.5",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.6.0"
|
||||
|
@ -2232,7 +2247,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "ruma"
|
||||
version = "0.10.1"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"assign",
|
||||
"js_int",
|
||||
|
@ -2253,7 +2268,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "ruma-appservice-api"
|
||||
version = "0.10.0"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"js_int",
|
||||
"ruma-common",
|
||||
|
@ -2265,7 +2280,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "ruma-client-api"
|
||||
version = "0.18.0"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"as_variant",
|
||||
"assign",
|
||||
|
@ -2288,7 +2303,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "ruma-common"
|
||||
version = "0.13.0"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"as_variant",
|
||||
"base64 0.22.1",
|
||||
|
@ -2318,7 +2333,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "ruma-events"
|
||||
version = "0.28.1"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"as_variant",
|
||||
"indexmap 2.2.6",
|
||||
|
@ -2334,13 +2349,14 @@ dependencies = [
|
|||
"thiserror",
|
||||
"tracing",
|
||||
"url",
|
||||
"web-time",
|
||||
"wildmatch",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ruma-federation-api"
|
||||
version = "0.9.0"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"js_int",
|
||||
"ruma-common",
|
||||
|
@ -2352,7 +2368,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "ruma-identifiers-validation"
|
||||
version = "0.9.5"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"js_int",
|
||||
"thiserror",
|
||||
|
@ -2361,7 +2377,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "ruma-identity-service-api"
|
||||
version = "0.9.0"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"js_int",
|
||||
"ruma-common",
|
||||
|
@ -2371,7 +2387,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "ruma-macros"
|
||||
version = "0.13.0"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"proc-macro-crate",
|
||||
|
@ -2386,7 +2402,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "ruma-push-gateway-api"
|
||||
version = "0.9.0"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"js_int",
|
||||
"ruma-common",
|
||||
|
@ -2398,7 +2414,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "ruma-server-util"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"headers",
|
||||
"http 1.1.0",
|
||||
|
@ -2411,7 +2427,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "ruma-signatures"
|
||||
version = "0.15.0"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"ed25519-dalek",
|
||||
|
@ -2427,7 +2443,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "ruma-state-res"
|
||||
version = "0.11.0"
|
||||
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
|
||||
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
|
||||
dependencies = [
|
||||
"itertools",
|
||||
"js_int",
|
||||
|
|
|
@ -34,6 +34,7 @@ axum = { version = "0.7", default-features = false, features = [
|
|||
"http2",
|
||||
"json",
|
||||
"matched-path",
|
||||
"tokio",
|
||||
], optional = true }
|
||||
axum-extra = { version = "0.9", features = ["typed-header"] }
|
||||
axum-server = { version = "0.6", features = ["tls-rustls"] }
|
||||
|
@ -145,6 +146,7 @@ tikv-jemallocator = { version = "0.5.0", features = [
|
|||
"unprefixed_malloc_on_supported_platforms",
|
||||
], optional = true }
|
||||
|
||||
dashmap = "5.5.3"
|
||||
sd-notify = { version = "0.4.1", optional = true }
|
||||
|
||||
# Used for matrix spec type definitions and helpers
|
||||
|
|
|
@ -1,4 +1,9 @@
|
|||
use std::{collections::BTreeMap, iter::FromIterator, str};
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
iter::FromIterator,
|
||||
net::IpAddr,
|
||||
str::{self, FromStr},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
async_trait,
|
||||
|
@ -15,7 +20,10 @@ use axum_extra::{
|
|||
use bytes::{BufMut, BytesMut};
|
||||
use http::{Request, StatusCode};
|
||||
use ruma::{
|
||||
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
|
||||
api::{
|
||||
client::error::{ErrorKind, RetryAfter},
|
||||
AuthScheme, IncomingRequest, OutgoingResponse,
|
||||
},
|
||||
server_util::authorization::XMatrix,
|
||||
CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId,
|
||||
};
|
||||
|
@ -23,7 +31,10 @@ use serde::Deserialize;
|
|||
use tracing::{debug, error, warn};
|
||||
|
||||
use super::{Ruma, RumaResponse};
|
||||
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
|
||||
use crate::{
|
||||
service::{appservice::RegistrationInfo, rate_limiting::Target},
|
||||
services, Error, Result,
|
||||
};
|
||||
|
||||
enum Token {
|
||||
Appservice(Box<RegistrationInfo>),
|
||||
|
@ -305,8 +316,51 @@ where
|
|||
}
|
||||
};
|
||||
|
||||
// doesn't work when Conduit is behind proxy
|
||||
// let remote_addr: ConnectInfo<SocketAddr> = parts.extract().await?;
|
||||
|
||||
let headers = parts.headers;
|
||||
|
||||
let target = if let Some(server) = sender_servername.clone() {
|
||||
Target::Server(server)
|
||||
} else if let Some(appservice) = appservice_info.clone() {
|
||||
if appservice.registration.rate_limited.unwrap_or(true) {
|
||||
Target::Appservice(appservice.registration.id)
|
||||
} else {
|
||||
Target::None
|
||||
}
|
||||
} else if let Some(user) = sender_user.clone() {
|
||||
Target::User(user)
|
||||
} else {
|
||||
let ip = headers
|
||||
.get("X-Forwarded-For")
|
||||
.and_then(|header| header.to_str().ok())
|
||||
.map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header))
|
||||
.and_then(|ip| IpAddr::from_str(ip).ok());
|
||||
|
||||
if let Some(ip) = ip {
|
||||
Target::Ip(ip)
|
||||
} else {
|
||||
Target::None
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(retry_after) = {
|
||||
services()
|
||||
.rate_limiting
|
||||
.update_or_reject(target, metadata)
|
||||
.map_err(Some)
|
||||
} {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::LimitExceeded {
|
||||
retry_after: retry_after.map(RetryAfter::Delay),
|
||||
},
|
||||
"Rate limit exceeded.",
|
||||
));
|
||||
}
|
||||
|
||||
let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
|
||||
*http_request.headers_mut().unwrap() = parts.headers;
|
||||
*http_request.headers_mut().unwrap() = headers;
|
||||
|
||||
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
||||
let user_id = sender_user.clone().unwrap_or_else(|| {
|
||||
|
|
|
@ -2,6 +2,7 @@ use std::{
|
|||
collections::BTreeMap,
|
||||
fmt,
|
||||
net::{IpAddr, Ipv4Addr},
|
||||
num::NonZeroU64,
|
||||
};
|
||||
|
||||
use ruma::{OwnedServerName, RoomVersionId};
|
||||
|
@ -82,6 +83,8 @@ pub struct Config {
|
|||
pub turn_secret: String,
|
||||
#[serde(default = "default_turn_ttl")]
|
||||
pub turn_ttl: u64,
|
||||
#[serde(default = "default_rate_limit")]
|
||||
pub rate_limiting: BTreeMap<Restriction, Limitation>,
|
||||
|
||||
pub emergency_password: Option<String>,
|
||||
|
||||
|
@ -101,6 +104,56 @@ pub struct WellKnownConfig {
|
|||
pub server: Option<OwnedServerName>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Restriction {
|
||||
Registration,
|
||||
Login,
|
||||
RegistrationTokenValidity,
|
||||
Message,
|
||||
Join,
|
||||
Invite,
|
||||
Knock,
|
||||
CreateMedia,
|
||||
Transaction,
|
||||
FederatedJoin,
|
||||
FederatedInvite,
|
||||
FederatedKnock,
|
||||
|
||||
#[default]
|
||||
CatchAll,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Copy, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
// When deserializing, we want this prefix
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
pub enum Timeframe {
|
||||
PerSecond(NonZeroU64),
|
||||
PerMinute(NonZeroU64),
|
||||
PerHour(NonZeroU64),
|
||||
PerDay(NonZeroU64),
|
||||
}
|
||||
|
||||
impl Timeframe {
|
||||
pub fn nano_gap(&self) -> u64 {
|
||||
match self {
|
||||
Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(),
|
||||
Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(),
|
||||
Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(),
|
||||
Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize)]
|
||||
pub struct Limitation {
|
||||
#[serde(default = "default_non_zero", flatten)]
|
||||
pub timeframe: Timeframe,
|
||||
#[serde(default = "default_non_zero")]
|
||||
pub burst_capacity: NonZeroU64,
|
||||
}
|
||||
|
||||
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
|
||||
|
||||
impl Config {
|
||||
|
@ -308,6 +361,20 @@ fn default_openid_token_ttl() -> u64 {
|
|||
60 * 60
|
||||
}
|
||||
|
||||
fn default_non_zero() -> NonZeroU64 {
|
||||
NonZeroU64::MIN
|
||||
}
|
||||
|
||||
pub fn default_rate_limit() -> BTreeMap<Restriction, Limitation> {
|
||||
BTreeMap::from_iter([(
|
||||
Restriction::default(),
|
||||
Limitation {
|
||||
timeframe: Timeframe::PerMinute(NonZeroU64::MIN),
|
||||
burst_capacity: NonZeroU64::MIN,
|
||||
},
|
||||
)])
|
||||
}
|
||||
|
||||
// I know, it's a great name
|
||||
pub fn default_default_room_version() -> RoomVersionId {
|
||||
RoomVersionId::V10
|
||||
|
|
|
@ -200,7 +200,9 @@ async fn run_server() -> io::Result<()> {
|
|||
.expect("failed to convert max request size"),
|
||||
));
|
||||
|
||||
let app = routes(config).layer(middlewares).into_make_service();
|
||||
let app = routes(config)
|
||||
.layer(middlewares)
|
||||
.into_make_service_with_connect_info::<SocketAddr>();
|
||||
let handle = ServerHandle::new();
|
||||
|
||||
tokio::spawn(shutdown_signal(handle.clone()));
|
||||
|
|
|
@ -17,6 +17,7 @@ pub mod key_backups;
|
|||
pub mod media;
|
||||
pub mod pdu;
|
||||
pub mod pusher;
|
||||
pub mod rate_limiting;
|
||||
pub mod rooms;
|
||||
pub mod sending;
|
||||
pub mod transaction_ids;
|
||||
|
@ -26,6 +27,7 @@ pub mod users;
|
|||
pub struct Services {
|
||||
pub appservice: appservice::Service,
|
||||
pub pusher: pusher::Service,
|
||||
pub rate_limiting: rate_limiting::Service,
|
||||
pub rooms: rooms::Service,
|
||||
pub transaction_ids: transaction_ids::Service,
|
||||
pub uiaa: uiaa::Service,
|
||||
|
@ -59,6 +61,7 @@ impl Services {
|
|||
Ok(Self {
|
||||
appservice: appservice::Service::build(db)?,
|
||||
pusher: pusher::Service { db },
|
||||
rate_limiting: rate_limiting::Service::build(),
|
||||
rooms: rooms::Service {
|
||||
alias: rooms::alias::Service { db },
|
||||
auth_chain: rooms::auth_chain::Service { db },
|
||||
|
|
141
src/service/rate_limiting/mod.rs
Normal file
141
src/service/rate_limiting/mod.rs
Normal file
|
@ -0,0 +1,141 @@
|
|||
use std::{
|
||||
hash::Hash,
|
||||
net::IpAddr,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use dashmap::{mapref::entry::Entry, DashMap};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::{
|
||||
account::{check_registration_token_validity, register},
|
||||
knock::knock_room,
|
||||
media::{create_content, create_content_async},
|
||||
membership::{invite_user, join_room_by_id, join_room_by_id_or_alias},
|
||||
message::send_message_event,
|
||||
session::login,
|
||||
state::send_state_event,
|
||||
},
|
||||
federation::{
|
||||
knock::send_knock,
|
||||
membership::{create_invite, create_join_event},
|
||||
transactions::send_transaction_message,
|
||||
},
|
||||
IncomingRequest, Metadata,
|
||||
},
|
||||
OwnedServerName, OwnedUserId,
|
||||
};
|
||||
|
||||
use crate::{config::Restriction, services, Result};
|
||||
|
||||
impl From<Metadata> for Restriction {
|
||||
fn from(metadata: Metadata) -> Self {
|
||||
#[allow(deprecated)]
|
||||
match metadata {
|
||||
register::v3::Request::METADATA => Restriction::Registration,
|
||||
login::v3::Request::METADATA => Restriction::Login,
|
||||
check_registration_token_validity::v1::Request::METADATA => {
|
||||
Restriction::RegistrationTokenValidity
|
||||
}
|
||||
send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => {
|
||||
Restriction::Message
|
||||
}
|
||||
join_room_by_id::v3::Request::METADATA
|
||||
| join_room_by_id_or_alias::v3::Request::METADATA => Restriction::Join,
|
||||
invite_user::v3::Request::METADATA => Restriction::Invite,
|
||||
create_content::v3::Request::METADATA | create_content_async::v3::Request::METADATA => {
|
||||
Restriction::CreateMedia
|
||||
}
|
||||
send_transaction_message::v1::Request::METADATA => Restriction::Transaction,
|
||||
create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => {
|
||||
Restriction::FederatedJoin
|
||||
}
|
||||
create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => {
|
||||
Restriction::FederatedInvite
|
||||
}
|
||||
send_knock::v1::Request::METADATA => Restriction::FederatedKnock,
|
||||
knock_room::v3::Request::METADATA => Restriction::Knock,
|
||||
_ => Self::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Service {
|
||||
store: DashMap<(Target, Restriction), (Instant, u64)>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
pub enum Target {
|
||||
User(OwnedUserId),
|
||||
// Server endpoints should be rate-limited on a server and room basis
|
||||
Server(OwnedServerName),
|
||||
Appservice(String),
|
||||
Ip(IpAddr),
|
||||
None,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn build() -> Self {
|
||||
Self {
|
||||
store: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> {
|
||||
if target == Target::None {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let restriction = metadata.into();
|
||||
|
||||
let arrival = Instant::now();
|
||||
|
||||
let config = &services().globals.config.rate_limiting;
|
||||
|
||||
let Some(limit) = config.get(&restriction) else {
|
||||
return Ok(());
|
||||
};
|
||||
// .unwrap_or(Limitation {
|
||||
// per_minute: NonZeroU64::new(1).unwrap(),
|
||||
// burst_capacity: NonZeroU64::new(1).unwrap(),
|
||||
// weight: NonZeroU64::new(1).unwrap(),
|
||||
// });
|
||||
|
||||
let gap = Duration::from_nanos(limit.timeframe.nano_gap());
|
||||
let key = (target, restriction);
|
||||
|
||||
match self.store.entry(key) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
let (instant, capacity) = entry.get_mut();
|
||||
|
||||
while *instant < arrival && *capacity != 0 {
|
||||
*capacity -= 1;
|
||||
*instant += gap;
|
||||
}
|
||||
|
||||
if *capacity >= limit.burst_capacity.get() {
|
||||
return Err(gap);
|
||||
} else {
|
||||
let zero_capacity = *capacity == 0;
|
||||
*capacity += 1;
|
||||
|
||||
// Ensures that the update point is in the future
|
||||
if zero_capacity {
|
||||
*instant = Instant::now()
|
||||
}
|
||||
|
||||
*instant += gap;
|
||||
}
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert((Instant::now() + gap, 1));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
///// In-memory state and utility functions used to check whether the client has exceeded its rate limit.
|
||||
///// This leverages the generic cell rate algorithm, making the required checks as cheap as possible.
|
||||
/////
|
Loading…
Add table
Add a link
Reference in a new issue