mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-10-15 19:42:07 +00:00
WIP: rate-limiting
This commit is contained in:
parent
e757a98e10
commit
08319f011f
8 changed files with 686 additions and 12 deletions
|
@ -3,7 +3,13 @@
|
||||||
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma};
|
use crate::{
|
||||||
|
service::{
|
||||||
|
media::{size, FileMeta},
|
||||||
|
rate_limiting::Target,
|
||||||
|
},
|
||||||
|
services, utils, Error, Result, Ruma,
|
||||||
|
};
|
||||||
use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE};
|
use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE};
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::{
|
api::{
|
||||||
|
@ -54,6 +60,8 @@ pub async fn get_media_config_auth_route(
|
||||||
pub async fn create_content_route(
|
pub async fn create_content_route(
|
||||||
body: Ruma<create_content::v3::Request>,
|
body: Ruma<create_content::v3::Request>,
|
||||||
) -> Result<create_content::v3::Response> {
|
) -> Result<create_content::v3::Response> {
|
||||||
|
let sender_user = body.sender_user.expect("user is authenticated");
|
||||||
|
|
||||||
let create_content::v3::Request {
|
let create_content::v3::Request {
|
||||||
filename,
|
filename,
|
||||||
content_type,
|
content_type,
|
||||||
|
@ -61,6 +69,13 @@ pub async fn create_content_route(
|
||||||
..
|
..
|
||||||
} = body.body;
|
} = body.body;
|
||||||
|
|
||||||
|
let target = Target::from_client_request(body.appservice_info, &sender_user);
|
||||||
|
|
||||||
|
services()
|
||||||
|
.rate_limiting
|
||||||
|
.check_media_upload(target, size(&file)?)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let media_id = utils::random_string(MXC_LENGTH);
|
let media_id = utils::random_string(MXC_LENGTH);
|
||||||
|
|
||||||
services()
|
services()
|
||||||
|
@ -71,7 +86,7 @@ pub async fn create_content_route(
|
||||||
filename.as_deref(),
|
filename.as_deref(),
|
||||||
content_type.as_deref(),
|
content_type.as_deref(),
|
||||||
&file,
|
&file,
|
||||||
body.sender_user.as_deref(),
|
Some(&sender_user),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
@ -176,6 +191,17 @@ pub async fn get_content_route(
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
if let Some(target) = Target::from_client_request_optional_auth(
|
||||||
|
body.appservice_info,
|
||||||
|
&body.sender_user,
|
||||||
|
body.sender_ip_address,
|
||||||
|
) {
|
||||||
|
services()
|
||||||
|
.rate_limiting
|
||||||
|
.update_media_post_fetch(target, size(&file)?)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
Ok(media::get_content::v3::Response {
|
Ok(media::get_content::v3::Response {
|
||||||
file,
|
file,
|
||||||
content_type,
|
content_type,
|
||||||
|
@ -190,7 +216,18 @@ pub async fn get_content_route(
|
||||||
pub async fn get_content_auth_route(
|
pub async fn get_content_auth_route(
|
||||||
body: Ruma<get_content::v1::Request>,
|
body: Ruma<get_content::v1::Request>,
|
||||||
) -> Result<get_content::v1::Response> {
|
) -> Result<get_content::v1::Response> {
|
||||||
get_content(&body.server_name, body.media_id.clone(), true, true).await
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
|
let resp = get_content(&body.server_name, body.media_id.clone(), true, true).await?;
|
||||||
|
|
||||||
|
let target = Target::from_client_request(body.appservice_info, sender_user);
|
||||||
|
|
||||||
|
services()
|
||||||
|
.rate_limiting
|
||||||
|
.update_media_post_fetch(target, size(&resp.file)?)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_content(
|
pub async fn get_content(
|
||||||
|
@ -249,6 +286,17 @@ pub async fn get_content_as_filename_route(
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
if let Some(target) = Target::from_client_request_optional_auth(
|
||||||
|
body.appservice_info,
|
||||||
|
&body.sender_user,
|
||||||
|
body.sender_ip_address,
|
||||||
|
) {
|
||||||
|
services()
|
||||||
|
.rate_limiting
|
||||||
|
.update_media_post_fetch(target, size(&file)?)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
Ok(media::get_content_as_filename::v3::Response {
|
Ok(media::get_content_as_filename::v3::Response {
|
||||||
file,
|
file,
|
||||||
content_type,
|
content_type,
|
||||||
|
@ -263,14 +311,25 @@ pub async fn get_content_as_filename_route(
|
||||||
pub async fn get_content_as_filename_auth_route(
|
pub async fn get_content_as_filename_auth_route(
|
||||||
body: Ruma<get_content_as_filename::v1::Request>,
|
body: Ruma<get_content_as_filename::v1::Request>,
|
||||||
) -> Result<get_content_as_filename::v1::Response, Error> {
|
) -> Result<get_content_as_filename::v1::Response, Error> {
|
||||||
get_content_as_filename(
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
|
let resp = get_content_as_filename(
|
||||||
&body.server_name,
|
&body.server_name,
|
||||||
body.media_id.clone(),
|
body.media_id.clone(),
|
||||||
body.filename.clone(),
|
body.filename.clone(),
|
||||||
true,
|
true,
|
||||||
true,
|
true,
|
||||||
)
|
)
|
||||||
.await
|
.await?;
|
||||||
|
|
||||||
|
let target = Target::from_client_request(body.appservice_info, sender_user);
|
||||||
|
|
||||||
|
services()
|
||||||
|
.rate_limiting
|
||||||
|
.update_media_post_fetch(target, size(&resp.file)?)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_content_as_filename(
|
async fn get_content_as_filename(
|
||||||
|
@ -337,6 +396,17 @@ pub async fn get_content_thumbnail_route(
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
if let Some(target) = Target::from_client_request_optional_auth(
|
||||||
|
body.appservice_info,
|
||||||
|
&body.sender_user,
|
||||||
|
body.sender_ip_address,
|
||||||
|
) {
|
||||||
|
services()
|
||||||
|
.rate_limiting
|
||||||
|
.update_media_post_fetch(target, size(&file)?)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
Ok(media::get_content_thumbnail::v3::Response {
|
Ok(media::get_content_thumbnail::v3::Response {
|
||||||
file,
|
file,
|
||||||
content_type,
|
content_type,
|
||||||
|
@ -351,7 +421,9 @@ pub async fn get_content_thumbnail_route(
|
||||||
pub async fn get_content_thumbnail_auth_route(
|
pub async fn get_content_thumbnail_auth_route(
|
||||||
body: Ruma<get_content_thumbnail::v1::Request>,
|
body: Ruma<get_content_thumbnail::v1::Request>,
|
||||||
) -> Result<get_content_thumbnail::v1::Response> {
|
) -> Result<get_content_thumbnail::v1::Response> {
|
||||||
get_content_thumbnail(
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
|
let resp = get_content_thumbnail(
|
||||||
&body.server_name,
|
&body.server_name,
|
||||||
body.media_id.clone(),
|
body.media_id.clone(),
|
||||||
body.height,
|
body.height,
|
||||||
|
@ -361,7 +433,16 @@ pub async fn get_content_thumbnail_auth_route(
|
||||||
true,
|
true,
|
||||||
true,
|
true,
|
||||||
)
|
)
|
||||||
.await
|
.await?;
|
||||||
|
|
||||||
|
let target = Target::from_client_request(body.appservice_info, sender_user);
|
||||||
|
|
||||||
|
services()
|
||||||
|
.rate_limiting
|
||||||
|
.update_media_post_fetch(target, size(&resp.file)?)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
|
|
@ -1,4 +1,10 @@
|
||||||
use std::{collections::BTreeMap, error::Error as _, iter::FromIterator, str};
|
use std::{
|
||||||
|
collections::BTreeMap,
|
||||||
|
error::Error as _,
|
||||||
|
iter::FromIterator,
|
||||||
|
net::IpAddr,
|
||||||
|
str::{self, FromStr},
|
||||||
|
};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
|
@ -24,7 +30,10 @@ use serde::Deserialize;
|
||||||
use tracing::{debug, error, warn};
|
use tracing::{debug, error, warn};
|
||||||
|
|
||||||
use super::{Ruma, RumaResponse};
|
use super::{Ruma, RumaResponse};
|
||||||
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
|
use crate::{
|
||||||
|
service::{appservice::RegistrationInfo, rate_limiting::Target},
|
||||||
|
services, Error, Result,
|
||||||
|
};
|
||||||
|
|
||||||
enum Token {
|
enum Token {
|
||||||
Appservice(Box<RegistrationInfo>),
|
Appservice(Box<RegistrationInfo>),
|
||||||
|
@ -327,6 +336,23 @@ where
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let sender_ip_address = parts
|
||||||
|
.headers
|
||||||
|
.get("X-Forwarded-For")
|
||||||
|
.and_then(|header| header.to_str().ok())
|
||||||
|
.map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header))
|
||||||
|
.and_then(|ip| IpAddr::from_str(ip).ok());
|
||||||
|
|
||||||
|
let target = if let Some(server_name) = sender_servername.clone() {
|
||||||
|
Some(Target::Server(server_name))
|
||||||
|
} else if let Some(user) = &sender_user {
|
||||||
|
Some(Target::from_client_request(appservice_info.clone(), user))
|
||||||
|
} else {
|
||||||
|
sender_ip_address.map(Target::Ip)
|
||||||
|
};
|
||||||
|
|
||||||
|
services().rate_limiting.check(target, metadata).await?;
|
||||||
|
|
||||||
let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
|
let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
|
||||||
*http_request.headers_mut().unwrap() = parts.headers;
|
*http_request.headers_mut().unwrap() = parts.headers;
|
||||||
|
|
||||||
|
@ -377,6 +403,7 @@ where
|
||||||
sender_servername,
|
sender_servername,
|
||||||
appservice_info,
|
appservice_info,
|
||||||
json_body,
|
json_body,
|
||||||
|
sender_ip_address,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ use ruma::{
|
||||||
api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName,
|
api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName,
|
||||||
OwnedUserId,
|
OwnedUserId,
|
||||||
};
|
};
|
||||||
use std::ops::Deref;
|
use std::{net::IpAddr, ops::Deref};
|
||||||
|
|
||||||
#[cfg(feature = "conduit_bin")]
|
#[cfg(feature = "conduit_bin")]
|
||||||
mod axum;
|
mod axum;
|
||||||
|
@ -14,6 +14,7 @@ pub struct Ruma<T> {
|
||||||
pub sender_user: Option<OwnedUserId>,
|
pub sender_user: Option<OwnedUserId>,
|
||||||
pub sender_device: Option<OwnedDeviceId>,
|
pub sender_device: Option<OwnedDeviceId>,
|
||||||
pub sender_servername: Option<OwnedServerName>,
|
pub sender_servername: Option<OwnedServerName>,
|
||||||
|
pub sender_ip_address: Option<IpAddr>,
|
||||||
// This is None when body is not a valid string
|
// This is None when body is not a valid string
|
||||||
pub json_body: Option<CanonicalJsonValue>,
|
pub json_body: Option<CanonicalJsonValue>,
|
||||||
pub appservice_info: Option<RegistrationInfo>,
|
pub appservice_info: Option<RegistrationInfo>,
|
||||||
|
|
|
@ -4,8 +4,9 @@ use crate::{
|
||||||
api::client_server::{self, claim_keys_helper, get_keys_helper},
|
api::client_server::{self, claim_keys_helper, get_keys_helper},
|
||||||
service::{
|
service::{
|
||||||
globals::SigningKeys,
|
globals::SigningKeys,
|
||||||
media::FileMeta,
|
media::{size, FileMeta},
|
||||||
pdu::{gen_event_id_canonical_json, PduBuilder},
|
pdu::{gen_event_id_canonical_json, PduBuilder},
|
||||||
|
rate_limiting::Target,
|
||||||
},
|
},
|
||||||
services, utils, Error, PduEvent, Result, Ruma, SUPPORTED_VERSIONS,
|
services, utils, Error, PduEvent, Result, Ruma, SUPPORTED_VERSIONS,
|
||||||
};
|
};
|
||||||
|
@ -2237,6 +2238,11 @@ pub async fn create_invite_route(
|
||||||
pub async fn get_content_route(
|
pub async fn get_content_route(
|
||||||
body: Ruma<get_content::v1::Request>,
|
body: Ruma<get_content::v1::Request>,
|
||||||
) -> Result<get_content::v1::Response> {
|
) -> Result<get_content::v1::Response> {
|
||||||
|
let sender_servername = body
|
||||||
|
.sender_servername
|
||||||
|
.as_ref()
|
||||||
|
.expect("server is authenticated");
|
||||||
|
|
||||||
services()
|
services()
|
||||||
.media
|
.media
|
||||||
.check_blocked(services().globals.server_name(), &body.media_id)?;
|
.check_blocked(services().globals.server_name(), &body.media_id)?;
|
||||||
|
@ -2250,6 +2256,11 @@ pub async fn get_content_route(
|
||||||
.get(services().globals.server_name(), &body.media_id, true)
|
.get(services().globals.server_name(), &body.media_id, true)
|
||||||
.await?
|
.await?
|
||||||
{
|
{
|
||||||
|
services()
|
||||||
|
.rate_limiting
|
||||||
|
.update_media_post_fetch(Target::Server(sender_servername.to_owned()), size(&file)?)
|
||||||
|
.await;
|
||||||
|
|
||||||
Ok(get_content::v1::Response::new(
|
Ok(get_content::v1::Response::new(
|
||||||
ContentMetadata::new(),
|
ContentMetadata::new(),
|
||||||
FileOrLocation::File(Content {
|
FileOrLocation::File(Content {
|
||||||
|
@ -2269,6 +2280,11 @@ pub async fn get_content_route(
|
||||||
pub async fn get_content_thumbnail_route(
|
pub async fn get_content_thumbnail_route(
|
||||||
body: Ruma<get_content_thumbnail::v1::Request>,
|
body: Ruma<get_content_thumbnail::v1::Request>,
|
||||||
) -> Result<get_content_thumbnail::v1::Response> {
|
) -> Result<get_content_thumbnail::v1::Response> {
|
||||||
|
let sender_servername = body
|
||||||
|
.sender_servername
|
||||||
|
.as_ref()
|
||||||
|
.expect("server is authenticated");
|
||||||
|
|
||||||
services()
|
services()
|
||||||
.media
|
.media
|
||||||
.check_blocked(services().globals.server_name(), &body.media_id)?;
|
.check_blocked(services().globals.server_name(), &body.media_id)?;
|
||||||
|
@ -2295,6 +2311,11 @@ pub async fn get_content_thumbnail_route(
|
||||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
services()
|
||||||
|
.rate_limiting
|
||||||
|
.update_media_post_fetch(Target::Server(sender_servername.to_owned()), size(&file)?)
|
||||||
|
.await;
|
||||||
|
|
||||||
services()
|
services()
|
||||||
.media
|
.media
|
||||||
.upload_thumbnail(
|
.upload_thumbnail(
|
||||||
|
|
|
@ -17,7 +17,9 @@ use url::Url;
|
||||||
use crate::Error;
|
use crate::Error;
|
||||||
|
|
||||||
mod proxy;
|
mod proxy;
|
||||||
use self::proxy::ProxyConfig;
|
pub mod rate_limiting;
|
||||||
|
|
||||||
|
use self::{proxy::ProxyConfig, rate_limiting::Config as RateLimitingConfig};
|
||||||
|
|
||||||
const SHA256_HEX_LENGTH: u8 = 64;
|
const SHA256_HEX_LENGTH: u8 = 64;
|
||||||
|
|
||||||
|
@ -138,6 +140,8 @@ pub struct Config {
|
||||||
|
|
||||||
pub media: MediaConfig,
|
pub media: MediaConfig,
|
||||||
|
|
||||||
|
pub rate_limiting: RateLimitingConfig,
|
||||||
|
|
||||||
pub emergency_password: Option<String>,
|
pub emergency_password: Option<String>,
|
||||||
|
|
||||||
pub catchall: BTreeMap<String, IgnoredAny>,
|
pub catchall: BTreeMap<String, IgnoredAny>,
|
||||||
|
@ -281,6 +285,7 @@ impl From<IncompleteConfig> for Config {
|
||||||
log,
|
log,
|
||||||
turn,
|
turn,
|
||||||
media,
|
media,
|
||||||
|
rate_limiting: todo!(),
|
||||||
emergency_password,
|
emergency_password,
|
||||||
catchall,
|
catchall,
|
||||||
}
|
}
|
||||||
|
|
117
src/config/rate_limiting.rs
Normal file
117
src/config/rate_limiting.rs
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
use std::{collections::HashMap, num::NonZeroU64};
|
||||||
|
|
||||||
|
use bytesize::ByteSize;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
use crate::service::rate_limiting::{ClientRestriction, FederationRestriction, Restriction};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Config {
|
||||||
|
pub media: MediaConfig,
|
||||||
|
pub client: HashMap<ClientRestriction, Limitation>,
|
||||||
|
pub federation: HashMap<FederationRestriction, Limitation>,
|
||||||
|
pub global: GlobalConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn get(&self, restriction: &Restriction) -> &Limitation {
|
||||||
|
// Maybe look into https://github.com/moriyoshi-kasuga/enum-table
|
||||||
|
match restriction {
|
||||||
|
Restriction::Client(client_restriction) => self.client.get(client_restriction).unwrap(),
|
||||||
|
Restriction::Federation(federation_restriction) => {
|
||||||
|
self.federation.get(federation_restriction).unwrap()
|
||||||
|
}
|
||||||
|
Restriction::Media(media_restriction) => todo!(),
|
||||||
|
Restriction::CatchAll => todo!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct GlobalConfig {
|
||||||
|
pub client: HashMap<ClientRestriction, Limitation>,
|
||||||
|
pub federation: HashMap<FederationRestriction, Limitation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: fold into one instead of copy-pasting
|
||||||
|
impl GlobalConfig {
|
||||||
|
pub fn get(&self, restriction: &Restriction) -> &Limitation {
|
||||||
|
// Maybe look into https://github.com/moriyoshi-kasuga/enum-table
|
||||||
|
match restriction {
|
||||||
|
Restriction::Client(client_restriction) => self.client.get(client_restriction).unwrap(),
|
||||||
|
Restriction::Federation(federation_restriction) => {
|
||||||
|
self.federation.get(federation_restriction).unwrap()
|
||||||
|
}
|
||||||
|
Restriction::Media(media_restriction) => todo!(),
|
||||||
|
Restriction::CatchAll => todo!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Deserialize)]
|
||||||
|
pub struct Limitation {
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub timeframe: Timeframe,
|
||||||
|
pub burst_capacity: NonZeroU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Clone, Copy, Debug)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
// When deserializing, we want this prefix
|
||||||
|
#[allow(clippy::enum_variant_names)]
|
||||||
|
pub enum Timeframe {
|
||||||
|
PerSecond(NonZeroU64),
|
||||||
|
PerMinute(NonZeroU64),
|
||||||
|
PerHour(NonZeroU64),
|
||||||
|
PerDay(NonZeroU64),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Timeframe {
|
||||||
|
pub fn nano_gap(&self) -> u64 {
|
||||||
|
match self {
|
||||||
|
Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(),
|
||||||
|
Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(),
|
||||||
|
Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(),
|
||||||
|
Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct MediaConfig {
|
||||||
|
pub upload: MediaLimitation,
|
||||||
|
pub fetch: MediaLimitation,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Deserialize)]
|
||||||
|
pub struct MediaLimitation {
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub timeframe: MediaTimeframe,
|
||||||
|
pub burst_capacity: ByteSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Clone, Copy, Debug)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
// When deserializing, we want this prefix
|
||||||
|
#[allow(clippy::enum_variant_names)]
|
||||||
|
pub enum MediaTimeframe {
|
||||||
|
PerSecond(ByteSize),
|
||||||
|
PerMinute(ByteSize),
|
||||||
|
PerHour(ByteSize),
|
||||||
|
PerDay(ByteSize),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MediaTimeframe {
|
||||||
|
pub fn bytes_per_sec(&self) -> u64 {
|
||||||
|
match self {
|
||||||
|
MediaTimeframe::PerSecond(t) => t.as_u64(),
|
||||||
|
MediaTimeframe::PerMinute(t) => t.as_u64() / 60,
|
||||||
|
MediaTimeframe::PerHour(t) => t.as_u64() / (60 * 60),
|
||||||
|
MediaTimeframe::PerDay(t) => t.as_u64() / (60 * 60 * 24),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_non_zero() -> NonZeroU64 {
|
||||||
|
NonZeroU64::MIN
|
||||||
|
}
|
|
@ -17,6 +17,7 @@ pub mod key_backups;
|
||||||
pub mod media;
|
pub mod media;
|
||||||
pub mod pdu;
|
pub mod pdu;
|
||||||
pub mod pusher;
|
pub mod pusher;
|
||||||
|
pub mod rate_limiting;
|
||||||
pub mod rooms;
|
pub mod rooms;
|
||||||
pub mod sending;
|
pub mod sending;
|
||||||
pub mod transaction_ids;
|
pub mod transaction_ids;
|
||||||
|
@ -36,6 +37,7 @@ pub struct Services {
|
||||||
pub key_backups: key_backups::Service,
|
pub key_backups: key_backups::Service,
|
||||||
pub media: Arc<media::Service>,
|
pub media: Arc<media::Service>,
|
||||||
pub sending: Arc<sending::Service>,
|
pub sending: Arc<sending::Service>,
|
||||||
|
pub rate_limiting: Arc<rate_limiting::Service>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Services {
|
impl Services {
|
||||||
|
@ -123,6 +125,8 @@ impl Services {
|
||||||
media: Arc::new(media::Service { db }),
|
media: Arc::new(media::Service { db }),
|
||||||
sending: sending::Service::build(db, &config),
|
sending: sending::Service::build(db, &config),
|
||||||
|
|
||||||
|
rate_limiting: rate_limiting::Service::build(),
|
||||||
|
|
||||||
globals: globals::Service::load(db, config)?,
|
globals: globals::Service::load(db, config)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
418
src/service/rate_limiting/mod.rs
Normal file
418
src/service/rate_limiting/mod.rs
Normal file
|
@ -0,0 +1,418 @@
|
||||||
|
use std::{
|
||||||
|
collections::{hash_map::Entry, HashMap},
|
||||||
|
net::IpAddr,
|
||||||
|
sync::Arc,
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use ruma::{
|
||||||
|
api::{
|
||||||
|
client::error::{ErrorKind, RetryAfter},
|
||||||
|
federation::membership::create_knock_event,
|
||||||
|
Metadata,
|
||||||
|
},
|
||||||
|
OwnedServerName, OwnedUserId, UserId,
|
||||||
|
};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tokio::{sync::Mutex, time::Instant};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
config::rate_limiting::MediaLimitation, service::appservice::RegistrationInfo, services, Error,
|
||||||
|
Result,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
|
pub enum Target {
|
||||||
|
User(OwnedUserId),
|
||||||
|
// Server endpoints should be rate-limited on a server and room basis
|
||||||
|
Server(OwnedServerName),
|
||||||
|
Appservice { id: String, rate_limited: bool },
|
||||||
|
Ip(IpAddr),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Target {
|
||||||
|
pub fn from_client_request(
|
||||||
|
registration_info: Option<RegistrationInfo>,
|
||||||
|
sender_user: &UserId,
|
||||||
|
) -> Self {
|
||||||
|
if let Some(info) = registration_info {
|
||||||
|
// `rate_limited` only effects "masqueraded users", "The sender [user?] is excluded"
|
||||||
|
return Target::Appservice {
|
||||||
|
id: info.registration.id,
|
||||||
|
rate_limited: info.registration.rate_limited.unwrap_or(true)
|
||||||
|
&& !(sender_user.server_name() == services().globals.server_name()
|
||||||
|
&& info.registration.sender_localpart == sender_user.localpart()),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
Target::User(sender_user.to_owned())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_client_request_optional_auth(
|
||||||
|
registration_info: Option<RegistrationInfo>,
|
||||||
|
sender_user: &Option<OwnedUserId>,
|
||||||
|
ip_addr: Option<IpAddr>,
|
||||||
|
) -> Option<Self> {
|
||||||
|
if let Some(sender_user) = sender_user.as_ref() {
|
||||||
|
Some(Self::from_client_request(registration_info, sender_user))
|
||||||
|
} else {
|
||||||
|
ip_addr.map(Self::Ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rate_limited(&self) -> bool {
|
||||||
|
if let Target::Appservice {
|
||||||
|
rate_limited: false,
|
||||||
|
..
|
||||||
|
} = self
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
|
pub enum Restriction {
|
||||||
|
Client(ClientRestriction),
|
||||||
|
Federation(FederationRestriction),
|
||||||
|
Media(MediaRestriction),
|
||||||
|
|
||||||
|
#[default]
|
||||||
|
CatchAll,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ClientRestriction {
|
||||||
|
Registration,
|
||||||
|
Login,
|
||||||
|
RegistrationTokenValidity,
|
||||||
|
|
||||||
|
SendEvent,
|
||||||
|
|
||||||
|
Join,
|
||||||
|
Invite,
|
||||||
|
Knock,
|
||||||
|
|
||||||
|
SendReport,
|
||||||
|
CreateAlias,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum FederationRestriction {
|
||||||
|
Join,
|
||||||
|
Knock,
|
||||||
|
Invite,
|
||||||
|
|
||||||
|
// Transactions should be handled by a completely dedicated rate-limiter
|
||||||
|
Transaction,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||||
|
pub enum MediaRestriction {
|
||||||
|
Create,
|
||||||
|
Fetch,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Metadata> for Restriction {
|
||||||
|
fn from(value: Metadata) -> Self {
|
||||||
|
use ruma::api::{
|
||||||
|
client::{
|
||||||
|
account::{check_registration_token_validity, register},
|
||||||
|
alias::create_alias,
|
||||||
|
authenticated_media::{
|
||||||
|
get_content, get_content_as_filename, get_content_thumbnail, get_media_preview,
|
||||||
|
},
|
||||||
|
knock::knock_room,
|
||||||
|
media::{self, create_content, create_mxc_uri},
|
||||||
|
membership::{invite_user, join_room_by_id, join_room_by_id_or_alias},
|
||||||
|
message::send_message_event,
|
||||||
|
reporting::report_user,
|
||||||
|
room::{report_content, report_room},
|
||||||
|
session::login,
|
||||||
|
state::send_state_event,
|
||||||
|
},
|
||||||
|
federation::{
|
||||||
|
authenticated_media::{
|
||||||
|
get_content as federation_get_content,
|
||||||
|
get_content_thumbnail as federation_get_content_thumbnail,
|
||||||
|
},
|
||||||
|
membership::{create_invite, create_join_event},
|
||||||
|
},
|
||||||
|
IncomingRequest,
|
||||||
|
};
|
||||||
|
use Restriction::*;
|
||||||
|
|
||||||
|
match value {
|
||||||
|
register::v3::Request::METADATA => Client(ClientRestriction::Registration),
|
||||||
|
check_registration_token_validity::v1::Request::METADATA => {
|
||||||
|
Client(ClientRestriction::RegistrationTokenValidity)
|
||||||
|
}
|
||||||
|
login::v3::Request::METADATA => Client(ClientRestriction::Login),
|
||||||
|
send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => {
|
||||||
|
Client(ClientRestriction::SendEvent)
|
||||||
|
}
|
||||||
|
join_room_by_id::v3::Request::METADATA
|
||||||
|
| join_room_by_id_or_alias::v3::Request::METADATA => Client(ClientRestriction::Join),
|
||||||
|
invite_user::v3::Request::METADATA => Client(ClientRestriction::Invite),
|
||||||
|
knock_room::v3::Request::METADATA => Client(ClientRestriction::Knock),
|
||||||
|
report_user::v3::Request::METADATA
|
||||||
|
| report_content::v3::Request::METADATA
|
||||||
|
| report_room::v3::Request::METADATA => Client(ClientRestriction::SendReport),
|
||||||
|
create_alias::v3::Request::METADATA => Client(ClientRestriction::CreateAlias),
|
||||||
|
// NOTE: handle async media upload in a way that doesn't half the number of uploads you can do within a short timeframe, while not allowing pre-generation of MXC uris to allow uploading double the number of media at once
|
||||||
|
create_content::v3::Request::METADATA | create_mxc_uri::v1::Request::METADATA => {
|
||||||
|
Media(MediaRestriction::Create)
|
||||||
|
}
|
||||||
|
// Unauthenticate media is deprecated
|
||||||
|
#[allow(deprecated)]
|
||||||
|
media::get_content::v3::Request::METADATA
|
||||||
|
| media::get_content_as_filename::v3::Request::METADATA
|
||||||
|
| media::get_content_thumbnail::v3::Request::METADATA
|
||||||
|
| media::get_media_preview::v3::Request::METADATA
|
||||||
|
| get_content::v1::Request::METADATA
|
||||||
|
| get_content_as_filename::v1::Request::METADATA
|
||||||
|
| get_content_thumbnail::v1::Request::METADATA
|
||||||
|
| get_media_preview::v1::Request::METADATA
|
||||||
|
| federation_get_content::v1::Request::METADATA
|
||||||
|
| federation_get_content_thumbnail::v1::Request::METADATA => {
|
||||||
|
Media(MediaRestriction::Fetch)
|
||||||
|
}
|
||||||
|
// v1 is deprecated
|
||||||
|
#[allow(deprecated)]
|
||||||
|
create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => {
|
||||||
|
Federation(FederationRestriction::Join)
|
||||||
|
}
|
||||||
|
create_knock_event::v1::Request::METADATA => Federation(FederationRestriction::Knock),
|
||||||
|
create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => {
|
||||||
|
Federation(FederationRestriction::Invite)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => Self::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Service {
|
||||||
|
buckets: Mutex<HashMap<(Target, Restriction), Arc<Mutex<Instant>>>>,
|
||||||
|
global_bucket: Mutex<HashMap<Restriction, Arc<Mutex<Instant>>>>,
|
||||||
|
media_upload: Mutex<HashMap<Target, Arc<Mutex<Instant>>>>,
|
||||||
|
media_fetch: Mutex<HashMap<Target, Arc<Mutex<Instant>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Service {
|
||||||
|
pub fn build() -> Arc<Self> {
|
||||||
|
Arc::new(Self {
|
||||||
|
buckets: Mutex::new(HashMap::new()),
|
||||||
|
global_bucket: Mutex::new(HashMap::new()),
|
||||||
|
media_upload: Mutex::new(HashMap::new()),
|
||||||
|
media_fetch: Mutex::new(HashMap::new()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: use checked and saturating arithmetic
|
||||||
|
|
||||||
|
/// Takes the target and request, and either accepts the request while adding to the
|
||||||
|
/// bucket, or rejects the request, returning the duration that should be waited until
|
||||||
|
/// the request should be retried.
|
||||||
|
pub async fn check(&self, target: Option<Target>, request: Metadata) -> Result<()> {
|
||||||
|
let restriction: Restriction = request.into();
|
||||||
|
let arrival = Instant::now();
|
||||||
|
|
||||||
|
{
|
||||||
|
let map = self.global_bucket.lock().await;
|
||||||
|
|
||||||
|
if let Some(value) = map.get(&restriction) {
|
||||||
|
let value = value.lock().await;
|
||||||
|
|
||||||
|
if arrival.checked_duration_since(*value).is_none() {
|
||||||
|
instant_to_err(&value)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(target) = target {
|
||||||
|
if restriction == Restriction::Media(MediaRestriction::Fetch) {
|
||||||
|
self.check_media_pre_fetch(&target, arrival).await?
|
||||||
|
}
|
||||||
|
|
||||||
|
let config = services().globals.config.rate_limiting.get(&restriction);
|
||||||
|
|
||||||
|
let mut map = self.buckets.lock().await;
|
||||||
|
let entry = map.entry((target, restriction));
|
||||||
|
match entry {
|
||||||
|
Entry::Occupied(occupied_entry) => {
|
||||||
|
let entry = Arc::clone(occupied_entry.get());
|
||||||
|
let mut entry = entry.lock().await;
|
||||||
|
|
||||||
|
if arrival.checked_duration_since(*entry).is_none() {
|
||||||
|
return instant_to_err(&entry);
|
||||||
|
}
|
||||||
|
|
||||||
|
let min_instant = arrival
|
||||||
|
- Duration::from_nanos(
|
||||||
|
config.timeframe.nano_gap() * config.burst_capacity.get(),
|
||||||
|
);
|
||||||
|
*entry =
|
||||||
|
entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap());
|
||||||
|
}
|
||||||
|
Entry::Vacant(vacant_entry) => {
|
||||||
|
vacant_entry.insert(Arc::new(Mutex::new(
|
||||||
|
arrival
|
||||||
|
- Duration::from_nanos(
|
||||||
|
config.timeframe.nano_gap() * (config.burst_capacity.get() - 1),
|
||||||
|
),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let config = services()
|
||||||
|
.globals
|
||||||
|
.config
|
||||||
|
.rate_limiting
|
||||||
|
.global
|
||||||
|
.get(&restriction);
|
||||||
|
|
||||||
|
let mut map = self.global_bucket.lock().await;
|
||||||
|
|
||||||
|
let entry = map.entry(restriction);
|
||||||
|
match entry {
|
||||||
|
Entry::Occupied(occupied_entry) => {
|
||||||
|
let entry = Arc::clone(occupied_entry.get());
|
||||||
|
let mut entry = entry.lock().await;
|
||||||
|
|
||||||
|
if arrival.checked_duration_since(*entry).is_none() {
|
||||||
|
return instant_to_err(&entry);
|
||||||
|
}
|
||||||
|
|
||||||
|
let min_instant = arrival
|
||||||
|
- Duration::from_nanos(
|
||||||
|
config.timeframe.nano_gap() * config.burst_capacity.get(),
|
||||||
|
);
|
||||||
|
*entry =
|
||||||
|
entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap());
|
||||||
|
}
|
||||||
|
Entry::Vacant(vacant_entry) => {
|
||||||
|
vacant_entry.insert(Arc::new(Mutex::new(
|
||||||
|
arrival
|
||||||
|
- Duration::from_nanos(
|
||||||
|
config.timeframe.nano_gap() * (config.burst_capacity.get() - 1),
|
||||||
|
),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn check_media_upload(&self, target: Target, size: u64) -> Result<()> {
|
||||||
|
if !target.rate_limited() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let arrival = Instant::now();
|
||||||
|
|
||||||
|
let MediaLimitation {
|
||||||
|
timeframe,
|
||||||
|
burst_capacity,
|
||||||
|
} = services().globals.config.rate_limiting.media.upload;
|
||||||
|
|
||||||
|
let mut map = self.media_upload.lock().await;
|
||||||
|
let entry = map.entry(target);
|
||||||
|
|
||||||
|
match entry {
|
||||||
|
Entry::Occupied(occupied_entry) => {
|
||||||
|
let entry = Arc::clone(occupied_entry.get());
|
||||||
|
let mut entry = entry.lock().await;
|
||||||
|
|
||||||
|
//TODO: use more precise conversion than secs
|
||||||
|
let min_instant = arrival
|
||||||
|
- Duration::from_secs(burst_capacity.as_u64() / timeframe.bytes_per_sec());
|
||||||
|
let proposed_entry =
|
||||||
|
entry.max(min_instant) + Duration::from_secs(size / timeframe.bytes_per_sec());
|
||||||
|
|
||||||
|
if arrival.checked_duration_since(proposed_entry).is_none() {
|
||||||
|
return instant_to_err(&proposed_entry);
|
||||||
|
}
|
||||||
|
|
||||||
|
*entry = proposed_entry;
|
||||||
|
}
|
||||||
|
Entry::Vacant(vacant_entry) => {
|
||||||
|
vacant_entry.insert(Arc::new(Mutex::new(
|
||||||
|
arrival
|
||||||
|
- Duration::from_nanos(burst_capacity.as_u64() / timeframe.bytes_per_sec()),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn check_media_pre_fetch(&self, target: &Target, arrival: Instant) -> Result<()> {
|
||||||
|
let map = self.media_fetch.lock().await;
|
||||||
|
if let Some(mutex) = map.get(target) {
|
||||||
|
let mutex = mutex.lock().await;
|
||||||
|
|
||||||
|
if arrival.checked_duration_since(*mutex).is_none() {
|
||||||
|
return instant_to_err(&mutex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn update_media_post_fetch(&self, target: Target, size: u64) {
|
||||||
|
if !target.rate_limited() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let MediaLimitation {
|
||||||
|
timeframe,
|
||||||
|
burst_capacity,
|
||||||
|
} = services().globals.config.rate_limiting.media.fetch;
|
||||||
|
|
||||||
|
let arrival = Instant::now();
|
||||||
|
|
||||||
|
let mut map = self.media_fetch.lock().await;
|
||||||
|
let entry = map.entry(target);
|
||||||
|
|
||||||
|
match entry {
|
||||||
|
Entry::Occupied(occupied_entry) => {
|
||||||
|
let entry = Arc::clone(occupied_entry.get());
|
||||||
|
let mut entry = entry.lock().await;
|
||||||
|
|
||||||
|
//TODO: use more precise conversion than secs
|
||||||
|
let min_instant = arrival
|
||||||
|
- Duration::from_secs(burst_capacity.as_u64() / timeframe.bytes_per_sec());
|
||||||
|
let proposed_entry =
|
||||||
|
entry.max(min_instant) + Duration::from_secs(size / timeframe.bytes_per_sec());
|
||||||
|
|
||||||
|
*entry = proposed_entry;
|
||||||
|
}
|
||||||
|
Entry::Vacant(vacant_entry) => {
|
||||||
|
vacant_entry.insert(Arc::new(Mutex::new(
|
||||||
|
arrival
|
||||||
|
- Duration::from_nanos(burst_capacity.as_u64() / timeframe.bytes_per_sec()),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn instant_to_err(instant: &Instant) -> Result<()> {
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
Err(Error::BadRequest(
|
||||||
|
ErrorKind::LimitExceeded {
|
||||||
|
// Not using ::DateTime because conversion from Instant to SystemTime is convoluted
|
||||||
|
retry_after: Some(RetryAfter::Delay(instant.duration_since(now))),
|
||||||
|
},
|
||||||
|
"Rate limit exceeded",
|
||||||
|
))
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue