1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00
This commit is contained in:
mikoto 2024-06-10 07:55:59 +02:00 committed by Matthias Ahouansou
parent 1f313c6807
commit 02cea0bb93
No known key found for this signature in database
7 changed files with 227 additions and 3 deletions

View file

@ -34,6 +34,7 @@ axum = { version = "0.7", default-features = false, features = [
"http2",
"json",
"matched-path",
"tokio",
], optional = true }
axum-extra = { version = "0.9", features = ["typed-header"] }
axum-server = { version = "0.6", features = ["tls-rustls"] }
@ -146,6 +147,8 @@ tikv-jemallocator = { version = "0.5.0", features = [
], optional = true }
sd-notify = { version = "0.4.1", optional = true }
dashmap = "5.5.3"
quanta = "0.12.3"
# Used for matrix spec type definitions and helpers
[dependencies.ruma]

View file

@ -1,4 +1,5 @@
pub mod appservice_server;
pub mod client_server;
pub mod rate_limiting;
pub mod ruma_wrapper;
pub mod server_server;

View file

@ -1,4 +1,8 @@
use std::{collections::BTreeMap, iter::FromIterator, str};
use std::{
collections::BTreeMap,
iter::FromIterator,
str::{self},
};
use axum::{
async_trait,
@ -23,7 +27,10 @@ use serde::Deserialize;
use tracing::{debug, error, warn};
use super::{Ruma, RumaResponse};
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
use crate::{
service::{appservice::RegistrationInfo, rate_limiting::Target},
services, Error, Result,
};
enum Token {
Appservice(Box<RegistrationInfo>),
@ -95,6 +102,44 @@ where
Token::None
};
// doesn't work when Conduit is behind proxy
// let remote_addr: ConnectInfo<SocketAddr> = parts.extract().await?;
let target = match &token {
Token::User((user_id, _)) => Some(Target::User(user_id.clone())),
Token::None => {
let header = parts
.headers
.get("x-forwarded-for")
.ok_or_else(|| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
let s = header
.to_str()
.map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting."))?;
Some(
s.parse()
.map(Target::Ip)
.map_err(|_| Error::BadRequest(ErrorKind::Unauthorized, "Rate limiting.")),
)
.transpose()?
}
_ => None,
};
if let Err(retry_after_ms) = target.map_or(Ok(()), |t| {
let key = (t, (&metadata).into());
services()
.rate_limiting
.update_or_reject(&key)
.map_err(Some)
}) {
return Err(Error::BadRequest(
ErrorKind::LimitExceeded { retry_after_ms },
"Rate limit exceeded.",
));
}
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
let (sender_user, sender_device, sender_servername, appservice_info) =

View file

@ -2,6 +2,7 @@ use std::{
collections::BTreeMap,
fmt,
net::{IpAddr, Ipv4Addr},
num::NonZeroU64,
};
use ruma::{OwnedServerName, RoomVersionId};
@ -82,6 +83,8 @@ pub struct Config {
pub turn_secret: String,
#[serde(default = "default_turn_ttl")]
pub turn_ttl: u64,
#[serde(default = "default_rate_limit")]
pub rate_limiting: BTreeMap<Restriction, Limitation>,
pub emergency_password: Option<String>,
@ -101,6 +104,27 @@ pub struct WellKnownConfig {
pub server: Option<OwnedServerName>,
}
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[serde(rename_all = "lowercase")]
pub enum Restriction {
Registration,
Login,
#[default]
#[serde(rename = "")]
CatchAll,
}
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct Limitation {
#[serde(default = "default_non_zero")]
pub per_minute: NonZeroU64,
#[serde(default = "default_non_zero")]
pub burst_capacity: NonZeroU64,
#[serde(default = "default_non_zero")]
pub weight: NonZeroU64,
}
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
impl Config {
@ -308,6 +332,21 @@ fn default_openid_token_ttl() -> u64 {
60 * 60
}
fn default_non_zero() -> NonZeroU64 {
NonZeroU64::new(1).unwrap()
}
pub fn default_rate_limit() -> BTreeMap<Restriction, Limitation> {
BTreeMap::from_iter([(
Restriction::default(),
Limitation {
per_minute: default_non_zero(),
burst_capacity: default_non_zero(),
weight: default_non_zero(),
},
)])
}
// I know, it's a great name
pub fn default_default_room_version() -> RoomVersionId {
RoomVersionId::V10

View file

@ -200,7 +200,7 @@ async fn run_server() -> io::Result<()> {
.expect("failed to convert max request size"),
));
let app = routes(config).layer(middlewares).into_make_service();
let app = routes(config).layer(middlewares).into_make_service_with_connect_info::<SocketAddr>();
let handle = ServerHandle::new();
tokio::spawn(shutdown_signal(handle.clone()));

View file

@ -17,6 +17,7 @@ pub mod key_backups;
pub mod media;
pub mod pdu;
pub mod pusher;
pub mod rate_limiting;
pub mod rooms;
pub mod sending;
pub mod transaction_ids;
@ -26,6 +27,7 @@ pub mod users;
pub struct Services {
pub appservice: appservice::Service,
pub pusher: pusher::Service,
pub rate_limiting: rate_limiting::Service,
pub rooms: rooms::Service,
pub transaction_ids: transaction_ids::Service,
pub uiaa: uiaa::Service,
@ -59,6 +61,7 @@ impl Services {
Ok(Self {
appservice: appservice::Service::build(db)?,
pusher: pusher::Service { db },
rate_limiting: rate_limiting::Service::build(),
rooms: rooms::Service {
alias: rooms::alias::Service { db },
auth_chain: rooms::auth_chain::Service { db },

View file

@ -0,0 +1,133 @@
use std::{
hash::Hash,
net::IpAddr,
num::NonZeroU64,
sync::atomic::{AtomicU64, Ordering},
time::Duration,
};
use dashmap::DashMap;
use quanta::Clock;
use ruma::{
api::{
client::{account::register, session::login},
IncomingRequest, Metadata,
},
OwnedUserId,
};
use crate::{
config::{Limitation, Restriction},
services, Result,
};
impl From<&Metadata> for Restriction {
fn from(metadata: &Metadata) -> Self {
[
(register::v3::Request::METADATA, Restriction::Registration),
(login::v3::Request::METADATA, Restriction::Login),
]
.iter()
.find(|(other, _)| {
metadata
.history
.stable_paths()
.zip(other.history.stable_paths())
.all(|(a, b)| a == b)
})
.map(|(_, restriction)| restriction.to_owned())
.unwrap_or_default()
}
}
pub struct Service {
store: DashMap<(Target, Restriction), AtomicU64>,
clock: Clock,
}
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Target {
User(OwnedUserId),
Ip(IpAddr),
}
impl Service {
pub fn build() -> Self {
Self {
store: DashMap::new(),
clock: Clock::new(),
}
}
pub fn update_or_reject(&self, key: &(Target, Restriction)) -> Result<(), Duration> {
let arrival = self.clock.delta_as_nanos(0, self.clock.raw());
let config = &services().globals.config.rate_limiting;
let Some(limit) = config
.get(&key.1)
.map(ToOwned::to_owned) else {
return Ok(());
};
// .unwrap_or(Limitation {
// per_minute: NonZeroU64::new(1).unwrap(),
// burst_capacity: NonZeroU64::new(1).unwrap(),
// weight: NonZeroU64::new(1).unwrap(),
// });
tracing::info!(?limit);
let increment = u64::try_from(Duration::from_secs(60).as_nanos())
.expect("1_000_000_000 to be smaller than u64::MAX")
/ limit.per_minute.get()
* limit.weight.get();
tracing::info!(?increment);
let mut prev_expectation = self
.store
.get(key)
.as_deref()
.map(|n| n.load(Ordering::Acquire))
.unwrap_or_else(|| arrival + increment);
let weight = (increment * limit.burst_capacity.get()).max(1);
tracing::info!(?prev_expectation);
tracing::info!(?weight);
let f = |prev_expectation: u64| {
let allowed = prev_expectation.saturating_sub(weight);
if arrival < allowed {
Err(Duration::from_nanos(allowed - arrival))
} else {
Ok(prev_expectation.max(arrival) + increment)
}
};
let mut decision = f(prev_expectation);
tracing::info!(?decision);
while let Ok(next_expectation) = decision {
let entry = self.store.entry(key.clone());
match entry.or_default().compare_exchange_weak(
prev_expectation,
next_expectation,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => return Ok(()),
Err(actual) => prev_expectation = actual,
}
decision = f(prev_expectation);
}
decision.map(|_| ())
}
}
///// In-memory state and utility functions used to check whether the client has exceeded its rate limit.
///// This leverages the generic cell rate algorithm, making the required checks as cheap as possible.
/////