1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00

Merge branch 'rate-limiting' into 'next'

Draft: feat: rate limiting

Closes #4

See merge request famedly/conduit!693
This commit is contained in:
avdb 2024-07-20 21:04:46 +00:00
commit e0054552ea
7 changed files with 303 additions and 18 deletions

42
Cargo.lock generated
View file

@ -189,6 +189,7 @@ dependencies = [
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper 1.0.1",
"tokio",
"tower",
"tower-layer",
"tower-service",
@ -496,6 +497,7 @@ dependencies = [
"base64 0.22.1",
"bytes",
"clap",
"dashmap",
"directories",
"figment",
"futures-util",
@ -664,6 +666,19 @@ dependencies = [
"syn",
]
[[package]]
name = "dashmap"
version = "5.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
dependencies = [
"cfg-if",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.6.0"
@ -2232,7 +2247,7 @@ dependencies = [
[[package]]
name = "ruma"
version = "0.10.1"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"assign",
"js_int",
@ -2253,7 +2268,7 @@ dependencies = [
[[package]]
name = "ruma-appservice-api"
version = "0.10.0"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"js_int",
"ruma-common",
@ -2265,7 +2280,7 @@ dependencies = [
[[package]]
name = "ruma-client-api"
version = "0.18.0"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"as_variant",
"assign",
@ -2288,7 +2303,7 @@ dependencies = [
[[package]]
name = "ruma-common"
version = "0.13.0"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"as_variant",
"base64 0.22.1",
@ -2318,7 +2333,7 @@ dependencies = [
[[package]]
name = "ruma-events"
version = "0.28.1"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"as_variant",
"indexmap 2.2.6",
@ -2334,13 +2349,14 @@ dependencies = [
"thiserror",
"tracing",
"url",
"web-time",
"wildmatch",
]
[[package]]
name = "ruma-federation-api"
version = "0.9.0"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"js_int",
"ruma-common",
@ -2352,7 +2368,7 @@ dependencies = [
[[package]]
name = "ruma-identifiers-validation"
version = "0.9.5"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"js_int",
"thiserror",
@ -2361,7 +2377,7 @@ dependencies = [
[[package]]
name = "ruma-identity-service-api"
version = "0.9.0"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"js_int",
"ruma-common",
@ -2371,7 +2387,7 @@ dependencies = [
[[package]]
name = "ruma-macros"
version = "0.13.0"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"once_cell",
"proc-macro-crate",
@ -2386,7 +2402,7 @@ dependencies = [
[[package]]
name = "ruma-push-gateway-api"
version = "0.9.0"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"js_int",
"ruma-common",
@ -2398,7 +2414,7 @@ dependencies = [
[[package]]
name = "ruma-server-util"
version = "0.3.0"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"headers",
"http 1.1.0",
@ -2411,7 +2427,7 @@ dependencies = [
[[package]]
name = "ruma-signatures"
version = "0.15.0"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"base64 0.22.1",
"ed25519-dalek",
@ -2427,7 +2443,7 @@ dependencies = [
[[package]]
name = "ruma-state-res"
version = "0.11.0"
source = "git+https://github.com/ruma/ruma#fec2152d879a6c6c2bccce984d4b8f424f460cb2"
source = "git+https://github.com/ruma/ruma#50a46cc5f658fd1cef5bdae6f08db292c3135366"
dependencies = [
"itertools",
"js_int",

View file

@ -34,6 +34,7 @@ axum = { version = "0.7", default-features = false, features = [
"http2",
"json",
"matched-path",
"tokio",
], optional = true }
axum-extra = { version = "0.9", features = ["typed-header"] }
axum-server = { version = "0.6", features = ["tls-rustls"] }
@ -145,6 +146,7 @@ tikv-jemallocator = { version = "0.5.0", features = [
"unprefixed_malloc_on_supported_platforms",
], optional = true }
dashmap = "5.5.3"
sd-notify = { version = "0.4.1", optional = true }
# Used for matrix spec type definitions and helpers

View file

@ -1,4 +1,9 @@
use std::{collections::BTreeMap, iter::FromIterator, str};
use std::{
collections::BTreeMap,
iter::FromIterator,
net::IpAddr,
str::{self, FromStr},
};
use axum::{
async_trait,
@ -15,7 +20,10 @@ use axum_extra::{
use bytes::{BufMut, BytesMut};
use http::{Request, StatusCode};
use ruma::{
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
api::{
client::error::{ErrorKind, RetryAfter},
AuthScheme, IncomingRequest, OutgoingResponse,
},
server_util::authorization::XMatrix,
CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId,
};
@ -23,7 +31,10 @@ use serde::Deserialize;
use tracing::{debug, error, warn};
use super::{Ruma, RumaResponse};
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
use crate::{
service::{appservice::RegistrationInfo, rate_limiting::Target},
services, Error, Result,
};
enum Token {
Appservice(Box<RegistrationInfo>),
@ -305,8 +316,51 @@ where
}
};
// doesn't work when Conduit is behind proxy
// let remote_addr: ConnectInfo<SocketAddr> = parts.extract().await?;
let headers = parts.headers;
let target = if let Some(server) = sender_servername.clone() {
Target::Server(server)
} else if let Some(appservice) = appservice_info.clone() {
if appservice.registration.rate_limited.unwrap_or(true) {
Target::Appservice(appservice.registration.id)
} else {
Target::None
}
} else if let Some(user) = sender_user.clone() {
Target::User(user)
} else {
let ip = headers
.get("X-Forwarded-For")
.and_then(|header| header.to_str().ok())
.map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header))
.and_then(|ip| IpAddr::from_str(ip).ok());
if let Some(ip) = ip {
Target::Ip(ip)
} else {
Target::None
}
};
if let Err(retry_after) = {
services()
.rate_limiting
.update_or_reject(target, metadata)
.map_err(Some)
} {
return Err(Error::BadRequest(
ErrorKind::LimitExceeded {
retry_after: retry_after.map(RetryAfter::Delay),
},
"Rate limit exceeded.",
));
}
let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
*http_request.headers_mut().unwrap() = parts.headers;
*http_request.headers_mut().unwrap() = headers;
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
let user_id = sender_user.clone().unwrap_or_else(|| {

View file

@ -2,6 +2,7 @@ use std::{
collections::BTreeMap,
fmt,
net::{IpAddr, Ipv4Addr},
num::NonZeroU64,
};
use ruma::{OwnedServerName, RoomVersionId};
@ -82,6 +83,8 @@ pub struct Config {
pub turn_secret: String,
#[serde(default = "default_turn_ttl")]
pub turn_ttl: u64,
#[serde(default = "default_rate_limit")]
pub rate_limiting: BTreeMap<Restriction, Limitation>,
pub emergency_password: Option<String>,
@ -101,6 +104,56 @@ pub struct WellKnownConfig {
pub server: Option<OwnedServerName>,
}
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[serde(rename_all = "snake_case")]
pub enum Restriction {
Registration,
Login,
RegistrationTokenValidity,
Message,
Join,
Invite,
Knock,
CreateMedia,
Transaction,
FederatedJoin,
FederatedInvite,
FederatedKnock,
#[default]
CatchAll,
}
#[derive(Deserialize, Clone, Copy, Debug)]
#[serde(rename_all = "snake_case")]
// When deserializing, we want this prefix
#[allow(clippy::enum_variant_names)]
pub enum Timeframe {
PerSecond(NonZeroU64),
PerMinute(NonZeroU64),
PerHour(NonZeroU64),
PerDay(NonZeroU64),
}
impl Timeframe {
pub fn nano_gap(&self) -> u64 {
match self {
Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(),
Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(),
Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(),
Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(),
}
}
}
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct Limitation {
#[serde(default = "default_non_zero", flatten)]
pub timeframe: Timeframe,
#[serde(default = "default_non_zero")]
pub burst_capacity: NonZeroU64,
}
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
impl Config {
@ -308,6 +361,20 @@ fn default_openid_token_ttl() -> u64 {
60 * 60
}
fn default_non_zero() -> NonZeroU64 {
NonZeroU64::MIN
}
pub fn default_rate_limit() -> BTreeMap<Restriction, Limitation> {
BTreeMap::from_iter([(
Restriction::default(),
Limitation {
timeframe: Timeframe::PerMinute(NonZeroU64::MIN),
burst_capacity: NonZeroU64::MIN,
},
)])
}
// I know, it's a great name
pub fn default_default_room_version() -> RoomVersionId {
RoomVersionId::V10

View file

@ -200,7 +200,9 @@ async fn run_server() -> io::Result<()> {
.expect("failed to convert max request size"),
));
let app = routes(config).layer(middlewares).into_make_service();
let app = routes(config)
.layer(middlewares)
.into_make_service_with_connect_info::<SocketAddr>();
let handle = ServerHandle::new();
tokio::spawn(shutdown_signal(handle.clone()));

View file

@ -17,6 +17,7 @@ pub mod key_backups;
pub mod media;
pub mod pdu;
pub mod pusher;
pub mod rate_limiting;
pub mod rooms;
pub mod sending;
pub mod transaction_ids;
@ -26,6 +27,7 @@ pub mod users;
pub struct Services {
pub appservice: appservice::Service,
pub pusher: pusher::Service,
pub rate_limiting: rate_limiting::Service,
pub rooms: rooms::Service,
pub transaction_ids: transaction_ids::Service,
pub uiaa: uiaa::Service,
@ -59,6 +61,7 @@ impl Services {
Ok(Self {
appservice: appservice::Service::build(db)?,
pusher: pusher::Service { db },
rate_limiting: rate_limiting::Service::build(),
rooms: rooms::Service {
alias: rooms::alias::Service { db },
auth_chain: rooms::auth_chain::Service { db },

View file

@ -0,0 +1,141 @@
use std::{
hash::Hash,
net::IpAddr,
time::{Duration, Instant},
};
use dashmap::{mapref::entry::Entry, DashMap};
use ruma::{
api::{
client::{
account::{check_registration_token_validity, register},
knock::knock_room,
media::{create_content, create_content_async},
membership::{invite_user, join_room_by_id, join_room_by_id_or_alias},
message::send_message_event,
session::login,
state::send_state_event,
},
federation::{
knock::send_knock,
membership::{create_invite, create_join_event},
transactions::send_transaction_message,
},
IncomingRequest, Metadata,
},
OwnedServerName, OwnedUserId,
};
use crate::{config::Restriction, services, Result};
impl From<Metadata> for Restriction {
fn from(metadata: Metadata) -> Self {
#[allow(deprecated)]
match metadata {
register::v3::Request::METADATA => Restriction::Registration,
login::v3::Request::METADATA => Restriction::Login,
check_registration_token_validity::v1::Request::METADATA => {
Restriction::RegistrationTokenValidity
}
send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => {
Restriction::Message
}
join_room_by_id::v3::Request::METADATA
| join_room_by_id_or_alias::v3::Request::METADATA => Restriction::Join,
invite_user::v3::Request::METADATA => Restriction::Invite,
create_content::v3::Request::METADATA | create_content_async::v3::Request::METADATA => {
Restriction::CreateMedia
}
send_transaction_message::v1::Request::METADATA => Restriction::Transaction,
create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => {
Restriction::FederatedJoin
}
create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => {
Restriction::FederatedInvite
}
send_knock::v1::Request::METADATA => Restriction::FederatedKnock,
knock_room::v3::Request::METADATA => Restriction::Knock,
_ => Self::default(),
}
}
}
pub struct Service {
store: DashMap<(Target, Restriction), (Instant, u64)>,
}
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Target {
User(OwnedUserId),
// Server endpoints should be rate-limited on a server and room basis
Server(OwnedServerName),
Appservice(String),
Ip(IpAddr),
None,
}
impl Service {
pub fn build() -> Self {
Self {
store: DashMap::new(),
}
}
pub fn update_or_reject(&self, target: Target, metadata: Metadata) -> Result<(), Duration> {
if target == Target::None {
return Ok(());
}
let restriction = metadata.into();
let arrival = Instant::now();
let config = &services().globals.config.rate_limiting;
let Some(limit) = config.get(&restriction) else {
return Ok(());
};
// .unwrap_or(Limitation {
// per_minute: NonZeroU64::new(1).unwrap(),
// burst_capacity: NonZeroU64::new(1).unwrap(),
// weight: NonZeroU64::new(1).unwrap(),
// });
let gap = Duration::from_nanos(limit.timeframe.nano_gap());
let key = (target, restriction);
match self.store.entry(key) {
Entry::Occupied(mut entry) => {
let (instant, capacity) = entry.get_mut();
while *instant < arrival && *capacity != 0 {
*capacity -= 1;
*instant += gap;
}
if *capacity >= limit.burst_capacity.get() {
return Err(gap);
} else {
let zero_capacity = *capacity == 0;
*capacity += 1;
// Ensures that the update point is in the future
if zero_capacity {
*instant = Instant::now()
}
*instant += gap;
}
}
Entry::Vacant(entry) => {
entry.insert((Instant::now() + gap, 1));
}
};
Ok(())
}
}
///// In-memory state and utility functions used to check whether the client has exceeded its rate limit.
///// This leverages the generic cell rate algorithm, making the required checks as cheap as possible.
/////