mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-06-27 16:35:59 +00:00
PoC
This commit is contained in:
parent
1f313c6807
commit
02cea0bb93
7 changed files with 227 additions and 3 deletions
|
@ -34,6 +34,7 @@ axum = { version = "0.7", default-features = false, features = [
|
||||||
"http2",
|
"http2",
|
||||||
"json",
|
"json",
|
||||||
"matched-path",
|
"matched-path",
|
||||||
|
"tokio",
|
||||||
], optional = true }
|
], optional = true }
|
||||||
axum-extra = { version = "0.9", features = ["typed-header"] }
|
axum-extra = { version = "0.9", features = ["typed-header"] }
|
||||||
axum-server = { version = "0.6", features = ["tls-rustls"] }
|
axum-server = { version = "0.6", features = ["tls-rustls"] }
|
||||||
|
@ -146,6 +147,8 @@ tikv-jemallocator = { version = "0.5.0", features = [
|
||||||
], optional = true }
|
], optional = true }
|
||||||
|
|
||||||
sd-notify = { version = "0.4.1", optional = true }
|
sd-notify = { version = "0.4.1", optional = true }
|
||||||
|
dashmap = "5.5.3"
|
||||||
|
quanta = "0.12.3"
|
||||||
|
|
||||||
# Used for matrix spec type definitions and helpers
|
# Used for matrix spec type definitions and helpers
|
||||||
[dependencies.ruma]
|
[dependencies.ruma]
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
pub mod appservice_server;
|
pub mod appservice_server;
|
||||||
pub mod client_server;
|
pub mod client_server;
|
||||||
|
pub mod rate_limiting;
|
||||||
pub mod ruma_wrapper;
|
pub mod ruma_wrapper;
|
||||||
pub mod server_server;
|
pub mod server_server;
|
||||||
|
|
|
@ -1,4 +1,8 @@
|
||||||
use std::{collections::BTreeMap, iter::FromIterator, str};
|
use std::{
|
||||||
|
collections::BTreeMap,
|
||||||
|
iter::FromIterator,
|
||||||
|
str::{self},
|
||||||
|
};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
async_trait,
|
async_trait,
|
||||||
|
@ -23,7 +27,10 @@ use serde::Deserialize;
|
||||||
use tracing::{debug, error, warn};
|
use tracing::{debug, error, warn};
|
||||||
|
|
||||||
use super::{Ruma, RumaResponse};
|
use super::{Ruma, RumaResponse};
|
||||||
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
|
use crate::{
|
||||||
|
service::{appservice::RegistrationInfo, rate_limiting::Target},
|
||||||
|
services, Error, Result,
|
||||||
|
};
|
||||||
|
|
||||||
enum Token {
|
enum Token {
|
||||||
Appservice(Box<RegistrationInfo>),
|
Appservice(Box<RegistrationInfo>),
|
||||||
|
@ -95,6 +102,44 @@ where
|
||||||
Token::None
|
Token::None
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// doesn't work when Conduit is behind proxy
|
||||||
|
// let remote_addr: ConnectInfo<SocketAddr> = parts.extract().await?;
|
||||||
|
|
||||||
|
let target = match &token {
|
||||||
|
Token::User((user_id, _)) => Some(Target::User(user_id.clone())),
|
||||||
|
Token::None => {
|
||||||
|
let header = parts
|
||||||
|
.headers
|
||||||
|
.get("x-forwarded-for")
|
||||||
|
.ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
|
||||||
|
|
||||||
|
let s = header
|
||||||
|
.to_str()
|
||||||
|
.map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
|
||||||
|
Some(
|
||||||
|
s.parse()
|
||||||
|
.map(Target::Ip)
|
||||||
|
.map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")),
|
||||||
|
)
|
||||||
|
.transpose()?
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(retry_after_ms) = target.map_or(Ok(()), |t| {
|
||||||
|
let key = (t, (&metadata).into());
|
||||||
|
|
||||||
|
services()
|
||||||
|
.rate_limiting
|
||||||
|
.update_or_reject(&key)
|
||||||
|
.map_err(Some)
|
||||||
|
}) {
|
||||||
|
return Err(Error::BadRequest(
|
||||||
|
ErrorKind::LimitExceeded { retry_after_ms },
|
||||||
|
"Rate limit exceeded.",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||||
|
|
||||||
let (sender_user, sender_device, sender_servername, appservice_info) =
|
let (sender_user, sender_device, sender_servername, appservice_info) =
|
||||||
|
|
|
@ -2,6 +2,7 @@ use std::{
|
||||||
collections::BTreeMap,
|
collections::BTreeMap,
|
||||||
fmt,
|
fmt,
|
||||||
net::{IpAddr, Ipv4Addr},
|
net::{IpAddr, Ipv4Addr},
|
||||||
|
num::NonZeroU64,
|
||||||
};
|
};
|
||||||
|
|
||||||
use ruma::{OwnedServerName, RoomVersionId};
|
use ruma::{OwnedServerName, RoomVersionId};
|
||||||
|
@ -82,6 +83,8 @@ pub struct Config {
|
||||||
pub turn_secret: String,
|
pub turn_secret: String,
|
||||||
#[serde(default = "default_turn_ttl")]
|
#[serde(default = "default_turn_ttl")]
|
||||||
pub turn_ttl: u64,
|
pub turn_ttl: u64,
|
||||||
|
#[serde(default = "default_rate_limit")]
|
||||||
|
pub rate_limiting: BTreeMap<Restriction, Limitation>,
|
||||||
|
|
||||||
pub emergency_password: Option<String>,
|
pub emergency_password: Option<String>,
|
||||||
|
|
||||||
|
@ -101,6 +104,27 @@ pub struct WellKnownConfig {
|
||||||
pub server: Option<OwnedServerName>,
|
pub server: Option<OwnedServerName>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum Restriction {
|
||||||
|
Registration,
|
||||||
|
Login,
|
||||||
|
|
||||||
|
#[default]
|
||||||
|
#[serde(rename = "")]
|
||||||
|
CatchAll,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Deserialize)]
|
||||||
|
pub struct Limitation {
|
||||||
|
#[serde(default = "default_non_zero")]
|
||||||
|
pub per_minute: NonZeroU64,
|
||||||
|
#[serde(default = "default_non_zero")]
|
||||||
|
pub burst_capacity: NonZeroU64,
|
||||||
|
#[serde(default = "default_non_zero")]
|
||||||
|
pub weight: NonZeroU64,
|
||||||
|
}
|
||||||
|
|
||||||
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
|
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
@ -308,6 +332,21 @@ fn default_openid_token_ttl() -> u64 {
|
||||||
60 * 60
|
60 * 60
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_non_zero() -> NonZeroU64 {
|
||||||
|
NonZeroU64::new(1).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_rate_limit() -> BTreeMap<Restriction, Limitation> {
|
||||||
|
BTreeMap::from_iter([(
|
||||||
|
Restriction::default(),
|
||||||
|
Limitation {
|
||||||
|
per_minute: default_non_zero(),
|
||||||
|
burst_capacity: default_non_zero(),
|
||||||
|
weight: default_non_zero(),
|
||||||
|
},
|
||||||
|
)])
|
||||||
|
}
|
||||||
|
|
||||||
// I know, it's a great name
|
// I know, it's a great name
|
||||||
pub fn default_default_room_version() -> RoomVersionId {
|
pub fn default_default_room_version() -> RoomVersionId {
|
||||||
RoomVersionId::V10
|
RoomVersionId::V10
|
||||||
|
|
|
@ -200,7 +200,7 @@ async fn run_server() -> io::Result<()> {
|
||||||
.expect("failed to convert max request size"),
|
.expect("failed to convert max request size"),
|
||||||
));
|
));
|
||||||
|
|
||||||
let app = routes(config).layer(middlewares).into_make_service();
|
let app = routes(config).layer(middlewares).into_make_service_with_connect_info::<SocketAddr>();
|
||||||
let handle = ServerHandle::new();
|
let handle = ServerHandle::new();
|
||||||
|
|
||||||
tokio::spawn(shutdown_signal(handle.clone()));
|
tokio::spawn(shutdown_signal(handle.clone()));
|
||||||
|
|
|
@ -17,6 +17,7 @@ pub mod key_backups;
|
||||||
pub mod media;
|
pub mod media;
|
||||||
pub mod pdu;
|
pub mod pdu;
|
||||||
pub mod pusher;
|
pub mod pusher;
|
||||||
|
pub mod rate_limiting;
|
||||||
pub mod rooms;
|
pub mod rooms;
|
||||||
pub mod sending;
|
pub mod sending;
|
||||||
pub mod transaction_ids;
|
pub mod transaction_ids;
|
||||||
|
@ -26,6 +27,7 @@ pub mod users;
|
||||||
pub struct Services {
|
pub struct Services {
|
||||||
pub appservice: appservice::Service,
|
pub appservice: appservice::Service,
|
||||||
pub pusher: pusher::Service,
|
pub pusher: pusher::Service,
|
||||||
|
pub rate_limiting: rate_limiting::Service,
|
||||||
pub rooms: rooms::Service,
|
pub rooms: rooms::Service,
|
||||||
pub transaction_ids: transaction_ids::Service,
|
pub transaction_ids: transaction_ids::Service,
|
||||||
pub uiaa: uiaa::Service,
|
pub uiaa: uiaa::Service,
|
||||||
|
@ -59,6 +61,7 @@ impl Services {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
appservice: appservice::Service::build(db)?,
|
appservice: appservice::Service::build(db)?,
|
||||||
pusher: pusher::Service { db },
|
pusher: pusher::Service { db },
|
||||||
|
rate_limiting: rate_limiting::Service::build(),
|
||||||
rooms: rooms::Service {
|
rooms: rooms::Service {
|
||||||
alias: rooms::alias::Service { db },
|
alias: rooms::alias::Service { db },
|
||||||
auth_chain: rooms::auth_chain::Service { db },
|
auth_chain: rooms::auth_chain::Service { db },
|
||||||
|
|
133
src/service/rate_limiting/mod.rs
Normal file
133
src/service/rate_limiting/mod.rs
Normal file
|
@ -0,0 +1,133 @@
|
||||||
|
use std::{
|
||||||
|
hash::Hash,
|
||||||
|
net::IpAddr,
|
||||||
|
num::NonZeroU64,
|
||||||
|
sync::atomic::{AtomicU64, Ordering},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use quanta::Clock;
|
||||||
|
use ruma::{
|
||||||
|
api::{
|
||||||
|
client::{account::register, session::login},
|
||||||
|
IncomingRequest, Metadata,
|
||||||
|
},
|
||||||
|
OwnedUserId,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
config::{Limitation, Restriction},
|
||||||
|
services, Result,
|
||||||
|
};
|
||||||
|
|
||||||
|
impl From<&Metadata> for Restriction {
|
||||||
|
fn from(metadata: &Metadata) -> Self {
|
||||||
|
[
|
||||||
|
(register::v3::Request::METADATA, Restriction::Registration),
|
||||||
|
(login::v3::Request::METADATA, Restriction::Login),
|
||||||
|
]
|
||||||
|
.iter()
|
||||||
|
.find(|(other, _)| {
|
||||||
|
metadata
|
||||||
|
.history
|
||||||
|
.stable_paths()
|
||||||
|
.zip(other.history.stable_paths())
|
||||||
|
.all(|(a, b)| a == b)
|
||||||
|
})
|
||||||
|
.map(|(_, restriction)| restriction.to_owned())
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Service {
|
||||||
|
store: DashMap<(Target, Restriction), AtomicU64>,
|
||||||
|
clock: Clock,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
|
pub enum Target {
|
||||||
|
User(OwnedUserId),
|
||||||
|
Ip(IpAddr),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Service {
|
||||||
|
pub fn build() -> Self {
|
||||||
|
Self {
|
||||||
|
store: DashMap::new(),
|
||||||
|
clock: Clock::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_or_reject(&self, key: &(Target, Restriction)) -> Result<(), Duration> {
|
||||||
|
let arrival = self.clock.delta_as_nanos(0, self.clock.raw());
|
||||||
|
|
||||||
|
let config = &services().globals.config.rate_limiting;
|
||||||
|
|
||||||
|
let Some(limit) = config
|
||||||
|
.get(&key.1)
|
||||||
|
.map(ToOwned::to_owned) else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
// .unwrap_or(Limitation {
|
||||||
|
// per_minute: NonZeroU64::new(1).unwrap(),
|
||||||
|
// burst_capacity: NonZeroU64::new(1).unwrap(),
|
||||||
|
// weight: NonZeroU64::new(1).unwrap(),
|
||||||
|
// });
|
||||||
|
|
||||||
|
tracing::info!(?limit);
|
||||||
|
|
||||||
|
let increment = u64::try_from(Duration::from_secs(60).as_nanos())
|
||||||
|
.expect("1_000_000_000 to be smaller than u64::MAX")
|
||||||
|
/ limit.per_minute.get()
|
||||||
|
* limit.weight.get();
|
||||||
|
tracing::info!(?increment);
|
||||||
|
|
||||||
|
let mut prev_expectation = self
|
||||||
|
.store
|
||||||
|
.get(key)
|
||||||
|
.as_deref()
|
||||||
|
.map(|n| n.load(Ordering::Acquire))
|
||||||
|
.unwrap_or_else(|| arrival + increment);
|
||||||
|
let weight = (increment * limit.burst_capacity.get()).max(1);
|
||||||
|
|
||||||
|
tracing::info!(?prev_expectation);
|
||||||
|
tracing::info!(?weight);
|
||||||
|
|
||||||
|
let f = |prev_expectation: u64| {
|
||||||
|
let allowed = prev_expectation.saturating_sub(weight);
|
||||||
|
|
||||||
|
if arrival < allowed {
|
||||||
|
Err(Duration::from_nanos(allowed - arrival))
|
||||||
|
} else {
|
||||||
|
Ok(prev_expectation.max(arrival) + increment)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut decision = f(prev_expectation);
|
||||||
|
|
||||||
|
tracing::info!(?decision);
|
||||||
|
|
||||||
|
while let Ok(next_expectation) = decision {
|
||||||
|
let entry = self.store.entry(key.clone());
|
||||||
|
|
||||||
|
match entry.or_default().compare_exchange_weak(
|
||||||
|
prev_expectation,
|
||||||
|
next_expectation,
|
||||||
|
Ordering::Release,
|
||||||
|
Ordering::Relaxed,
|
||||||
|
) {
|
||||||
|
Ok(_) => return Ok(()),
|
||||||
|
Err(actual) => prev_expectation = actual,
|
||||||
|
}
|
||||||
|
|
||||||
|
decision = f(prev_expectation);
|
||||||
|
}
|
||||||
|
|
||||||
|
decision.map(|_| ())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///// In-memory state and utility functions used to check whether the client has exceeded its rate limit.
|
||||||
|
///// This leverages the generic cell rate algorithm, making the required checks as cheap as possible.
|
||||||
|
/////
|
Loading…
Add table
Add a link
Reference in a new issue