mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-10-15 19:42:07 +00:00
WIP: rate-limiting
This commit is contained in:
parent
e757a98e10
commit
0d72304662
12 changed files with 1070 additions and 158 deletions
|
@ -3,7 +3,13 @@
|
|||
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma};
|
||||
use crate::{
|
||||
service::{
|
||||
media::{size, FileMeta},
|
||||
rate_limiting::Target,
|
||||
},
|
||||
services, utils, Error, Result, Ruma,
|
||||
};
|
||||
use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE};
|
||||
use ruma::{
|
||||
api::{
|
||||
|
@ -54,6 +60,8 @@ pub async fn get_media_config_auth_route(
|
|||
pub async fn create_content_route(
|
||||
body: Ruma<create_content::v3::Request>,
|
||||
) -> Result<create_content::v3::Response> {
|
||||
let sender_user = body.sender_user.expect("user is authenticated");
|
||||
|
||||
let create_content::v3::Request {
|
||||
filename,
|
||||
content_type,
|
||||
|
@ -61,6 +69,13 @@ pub async fn create_content_route(
|
|||
..
|
||||
} = body.body;
|
||||
|
||||
let target = Target::from_client_request(body.appservice_info, &sender_user);
|
||||
|
||||
services()
|
||||
.rate_limiting
|
||||
.check_media_upload(target, size(&file)?)
|
||||
.await?;
|
||||
|
||||
let media_id = utils::random_string(MXC_LENGTH);
|
||||
|
||||
services()
|
||||
|
@ -71,7 +86,7 @@ pub async fn create_content_route(
|
|||
filename.as_deref(),
|
||||
content_type.as_deref(),
|
||||
&file,
|
||||
body.sender_user.as_deref(),
|
||||
Some(&sender_user),
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
@ -84,7 +99,13 @@ pub async fn create_content_route(
|
|||
pub async fn get_remote_content(
|
||||
server_name: &ServerName,
|
||||
media_id: String,
|
||||
target: Target,
|
||||
) -> Result<get_content::v1::Response, Error> {
|
||||
services()
|
||||
.rate_limiting
|
||||
.check_media_pre_fetch(&target)
|
||||
.await?;
|
||||
|
||||
let content_response = match services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
|
@ -153,6 +174,11 @@ pub async fn get_remote_content(
|
|||
)
|
||||
.await?;
|
||||
|
||||
services()
|
||||
.rate_limiting
|
||||
.update_media_post_fetch(target, size(&content_response.file)?)
|
||||
.await;
|
||||
|
||||
Ok(content_response)
|
||||
}
|
||||
|
||||
|
@ -171,11 +197,21 @@ pub async fn get_content_route(
|
|||
} = get_content(
|
||||
&body.server_name,
|
||||
body.media_id.clone(),
|
||||
body.allow_remote,
|
||||
false,
|
||||
body.sender_ip_address.map(Target::Ip),
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let Some(target) = Target::from_client_request_optional_auth(
|
||||
body.appservice_info,
|
||||
&body.sender_user,
|
||||
body.sender_ip_address,
|
||||
) {
|
||||
services()
|
||||
.rate_limiting
|
||||
.update_media_post_fetch(target, size(&file)?)
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(media::get_content::v3::Response {
|
||||
file,
|
||||
content_type,
|
||||
|
@ -190,14 +226,24 @@ pub async fn get_content_route(
|
|||
pub async fn get_content_auth_route(
|
||||
body: Ruma<get_content::v1::Request>,
|
||||
) -> Result<get_content::v1::Response> {
|
||||
get_content(&body.server_name, body.media_id.clone(), true, true).await
|
||||
let Ruma::<get_content::v1::Request> {
|
||||
body,
|
||||
sender_user,
|
||||
appservice_info,
|
||||
..
|
||||
} = body;
|
||||
|
||||
let sender_user = sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let target = Target::from_client_request(appservice_info, sender_user);
|
||||
|
||||
get_content(&body.server_name, body.media_id.clone(), Some(target)).await
|
||||
}
|
||||
|
||||
pub async fn get_content(
|
||||
server_name: &ServerName,
|
||||
media_id: String,
|
||||
allow_remote: bool,
|
||||
authenticated: bool,
|
||||
target: Option<Target>,
|
||||
) -> Result<get_content::v1::Response, Error> {
|
||||
services().media.check_blocked(server_name, &media_id)?;
|
||||
|
||||
|
@ -207,7 +253,7 @@ pub async fn get_content(
|
|||
file,
|
||||
})) = services()
|
||||
.media
|
||||
.get(server_name, &media_id, authenticated)
|
||||
.get(server_name, &media_id, target.clone())
|
||||
.await
|
||||
{
|
||||
Ok(get_content::v1::Response {
|
||||
|
@ -215,16 +261,25 @@ pub async fn get_content(
|
|||
content_type,
|
||||
content_disposition: Some(content_disposition),
|
||||
})
|
||||
} else if server_name != services().globals.server_name() && allow_remote && authenticated {
|
||||
let remote_content_response = get_remote_content(server_name, media_id.clone()).await?;
|
||||
|
||||
Ok(get_content::v1::Response {
|
||||
content_disposition: remote_content_response.content_disposition,
|
||||
content_type: remote_content_response.content_type,
|
||||
file: remote_content_response.file,
|
||||
})
|
||||
} else {
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
|
||||
let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||
|
||||
if let Some(target) = target {
|
||||
if server_name != services().globals.server_name() && target.is_authenticated() {
|
||||
let remote_content_response =
|
||||
get_remote_content(server_name, media_id.clone(), target).await?;
|
||||
|
||||
Ok(get_content::v1::Response {
|
||||
content_disposition: remote_content_response.content_disposition,
|
||||
content_type: remote_content_response.content_type,
|
||||
file: remote_content_response.file,
|
||||
})
|
||||
} else {
|
||||
error
|
||||
}
|
||||
} else {
|
||||
error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -244,8 +299,7 @@ pub async fn get_content_as_filename_route(
|
|||
&body.server_name,
|
||||
body.media_id.clone(),
|
||||
body.filename.clone(),
|
||||
body.allow_remote,
|
||||
false,
|
||||
body.sender_ip_address.map(Target::Ip),
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
@ -263,22 +317,33 @@ pub async fn get_content_as_filename_route(
|
|||
pub async fn get_content_as_filename_auth_route(
|
||||
body: Ruma<get_content_as_filename::v1::Request>,
|
||||
) -> Result<get_content_as_filename::v1::Response, Error> {
|
||||
get_content_as_filename(
|
||||
let Ruma::<get_content_as_filename::v1::Request> {
|
||||
body,
|
||||
sender_user,
|
||||
appservice_info,
|
||||
..
|
||||
} = body;
|
||||
|
||||
let sender_user = sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let target = Target::from_client_request(appservice_info, sender_user);
|
||||
|
||||
let resp = get_content_as_filename(
|
||||
&body.server_name,
|
||||
body.media_id.clone(),
|
||||
body.filename.clone(),
|
||||
true,
|
||||
true,
|
||||
Some(target),
|
||||
)
|
||||
.await
|
||||
.await?;
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
async fn get_content_as_filename(
|
||||
server_name: &ServerName,
|
||||
media_id: String,
|
||||
filename: String,
|
||||
allow_remote: bool,
|
||||
authenticated: bool,
|
||||
target: Option<Target>,
|
||||
) -> Result<get_content_as_filename::v1::Response, Error> {
|
||||
services().media.check_blocked(server_name, &media_id)?;
|
||||
|
||||
|
@ -286,7 +351,7 @@ async fn get_content_as_filename(
|
|||
file, content_type, ..
|
||||
})) = services()
|
||||
.media
|
||||
.get(server_name, &media_id, authenticated)
|
||||
.get(server_name, &media_id, target.clone())
|
||||
.await
|
||||
{
|
||||
Ok(get_content_as_filename::v1::Response {
|
||||
|
@ -297,19 +362,28 @@ async fn get_content_as_filename(
|
|||
.with_filename(Some(filename.clone())),
|
||||
),
|
||||
})
|
||||
} else if server_name != services().globals.server_name() && allow_remote && authenticated {
|
||||
let remote_content_response = get_remote_content(server_name, media_id.clone()).await?;
|
||||
|
||||
Ok(get_content_as_filename::v1::Response {
|
||||
content_disposition: Some(
|
||||
ContentDisposition::new(ContentDispositionType::Inline)
|
||||
.with_filename(Some(filename.clone())),
|
||||
),
|
||||
content_type: remote_content_response.content_type,
|
||||
file: remote_content_response.file,
|
||||
})
|
||||
} else {
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
|
||||
let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||
|
||||
if let Some(target) = target {
|
||||
if server_name != services().globals.server_name() && target.is_authenticated() {
|
||||
let remote_content_response =
|
||||
get_remote_content(server_name, media_id.clone(), target).await?;
|
||||
|
||||
Ok(get_content_as_filename::v1::Response {
|
||||
content_disposition: Some(
|
||||
ContentDisposition::new(ContentDispositionType::Inline)
|
||||
.with_filename(Some(filename.clone())),
|
||||
),
|
||||
content_type: remote_content_response.content_type,
|
||||
file: remote_content_response.file,
|
||||
})
|
||||
} else {
|
||||
error
|
||||
}
|
||||
} else {
|
||||
error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -321,6 +395,17 @@ async fn get_content_as_filename(
|
|||
pub async fn get_content_thumbnail_route(
|
||||
body: Ruma<media::get_content_thumbnail::v3::Request>,
|
||||
) -> Result<media::get_content_thumbnail::v3::Response> {
|
||||
let Ruma::<media::get_content_thumbnail::v3::Request> {
|
||||
body,
|
||||
sender_user,
|
||||
sender_ip_address,
|
||||
appservice_info,
|
||||
..
|
||||
} = body;
|
||||
|
||||
let target =
|
||||
Target::from_client_request_optional_auth(appservice_info, &sender_user, sender_ip_address);
|
||||
|
||||
let get_content_thumbnail::v1::Response {
|
||||
file,
|
||||
content_type,
|
||||
|
@ -332,8 +417,7 @@ pub async fn get_content_thumbnail_route(
|
|||
body.width,
|
||||
body.method.clone(),
|
||||
body.animated,
|
||||
body.allow_remote,
|
||||
false,
|
||||
target,
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
@ -351,17 +435,27 @@ pub async fn get_content_thumbnail_route(
|
|||
pub async fn get_content_thumbnail_auth_route(
|
||||
body: Ruma<get_content_thumbnail::v1::Request>,
|
||||
) -> Result<get_content_thumbnail::v1::Response> {
|
||||
get_content_thumbnail(
|
||||
let Ruma::<get_content_thumbnail::v1::Request> {
|
||||
body,
|
||||
sender_user,
|
||||
appservice_info,
|
||||
..
|
||||
} = body;
|
||||
let sender_user = sender_user.as_ref().expect("user is authenticated");
|
||||
let target = Target::from_client_request(appservice_info, sender_user);
|
||||
|
||||
let resp = get_content_thumbnail(
|
||||
&body.server_name,
|
||||
body.media_id.clone(),
|
||||
body.height,
|
||||
body.width,
|
||||
body.method.clone(),
|
||||
body.animated,
|
||||
true,
|
||||
true,
|
||||
Some(target),
|
||||
)
|
||||
.await
|
||||
.await?;
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
|
@ -372,8 +466,7 @@ async fn get_content_thumbnail(
|
|||
width: UInt,
|
||||
method: Option<Method>,
|
||||
animated: Option<bool>,
|
||||
allow_remote: bool,
|
||||
authenticated: bool,
|
||||
target: Option<Target>,
|
||||
) -> Result<get_content_thumbnail::v1::Response, Error> {
|
||||
services().media.check_blocked(server_name, &media_id)?;
|
||||
|
||||
|
@ -392,7 +485,7 @@ async fn get_content_thumbnail(
|
|||
height
|
||||
.try_into()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
|
||||
authenticated,
|
||||
target.clone(),
|
||||
)
|
||||
.await?
|
||||
{
|
||||
|
@ -401,99 +494,117 @@ async fn get_content_thumbnail(
|
|||
content_type,
|
||||
content_disposition: Some(content_disposition),
|
||||
})
|
||||
} else if server_name != services().globals.server_name() && allow_remote && authenticated {
|
||||
let thumbnail_response = match services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
server_name,
|
||||
federation_media::get_content_thumbnail::v1::Request {
|
||||
height,
|
||||
width,
|
||||
method: method.clone(),
|
||||
media_id: media_id.clone(),
|
||||
timeout_ms: Duration::from_secs(20),
|
||||
animated,
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(federation_media::get_content_thumbnail::v1::Response {
|
||||
metadata: _,
|
||||
content: FileOrLocation::File(content),
|
||||
}) => get_content_thumbnail::v1::Response {
|
||||
file: content.file,
|
||||
content_type: content.content_type,
|
||||
content_disposition: content.content_disposition,
|
||||
},
|
||||
} else {
|
||||
let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||
|
||||
Ok(federation_media::get_content_thumbnail::v1::Response {
|
||||
metadata: _,
|
||||
content: FileOrLocation::Location(url),
|
||||
}) => {
|
||||
let get_content::v1::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
} = get_location_content(url).await?;
|
||||
if let Some(target) = target {
|
||||
if server_name != services().globals.server_name() {
|
||||
services()
|
||||
.rate_limiting
|
||||
.check_media_pre_fetch(&target)
|
||||
.await?;
|
||||
|
||||
get_content_thumbnail::v1::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
}
|
||||
}
|
||||
Err(Error::BadRequest(ErrorKind::Unrecognized, _)) => {
|
||||
let media::get_content_thumbnail::v3::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
..
|
||||
} = services()
|
||||
let thumbnail_response = match services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
server_name,
|
||||
media::get_content_thumbnail::v3::Request {
|
||||
federation_media::get_content_thumbnail::v1::Request {
|
||||
height,
|
||||
width,
|
||||
method: method.clone(),
|
||||
server_name: server_name.to_owned(),
|
||||
media_id: media_id.clone(),
|
||||
timeout_ms: Duration::from_secs(20),
|
||||
allow_redirect: false,
|
||||
animated,
|
||||
allow_remote: false,
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(federation_media::get_content_thumbnail::v1::Response {
|
||||
metadata: _,
|
||||
content: FileOrLocation::File(content),
|
||||
}) => get_content_thumbnail::v1::Response {
|
||||
file: content.file,
|
||||
content_type: content.content_type,
|
||||
content_disposition: content.content_disposition,
|
||||
},
|
||||
|
||||
Ok(federation_media::get_content_thumbnail::v1::Response {
|
||||
metadata: _,
|
||||
content: FileOrLocation::Location(url),
|
||||
}) => {
|
||||
let get_content::v1::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
} = get_location_content(url).await?;
|
||||
|
||||
get_content_thumbnail::v1::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
}
|
||||
}
|
||||
Err(Error::BadRequest(ErrorKind::Unrecognized, _)) => {
|
||||
let media::get_content_thumbnail::v3::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
..
|
||||
} = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
server_name,
|
||||
media::get_content_thumbnail::v3::Request {
|
||||
height,
|
||||
width,
|
||||
method: method.clone(),
|
||||
server_name: server_name.to_owned(),
|
||||
media_id: media_id.clone(),
|
||||
timeout_ms: Duration::from_secs(20),
|
||||
allow_redirect: false,
|
||||
animated,
|
||||
allow_remote: false,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
get_content_thumbnail::v1::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
services()
|
||||
.rate_limiting
|
||||
.update_media_post_fetch(target, size(&thumbnail_response.file)?)
|
||||
.await;
|
||||
|
||||
services()
|
||||
.media
|
||||
.upload_thumbnail(
|
||||
server_name,
|
||||
&media_id,
|
||||
thumbnail_response
|
||||
.content_disposition
|
||||
.as_ref()
|
||||
.and_then(|cd| cd.filename.as_deref()),
|
||||
thumbnail_response.content_type.as_deref(),
|
||||
width.try_into().expect("all UInts are valid u32s"),
|
||||
height.try_into().expect("all UInts are valid u32s"),
|
||||
&thumbnail_response.file,
|
||||
)
|
||||
.await?;
|
||||
|
||||
get_content_thumbnail::v1::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
}
|
||||
Ok(thumbnail_response)
|
||||
} else {
|
||||
error
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
services()
|
||||
.media
|
||||
.upload_thumbnail(
|
||||
server_name,
|
||||
&media_id,
|
||||
thumbnail_response
|
||||
.content_disposition
|
||||
.as_ref()
|
||||
.and_then(|cd| cd.filename.as_deref()),
|
||||
thumbnail_response.content_type.as_deref(),
|
||||
width.try_into().expect("all UInts are valid u32s"),
|
||||
height.try_into().expect("all UInts are valid u32s"),
|
||||
&thumbnail_response.file,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(thumbnail_response)
|
||||
} else {
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
|
||||
} else {
|
||||
error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
use std::{collections::BTreeMap, error::Error as _, iter::FromIterator, str};
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
error::Error as _,
|
||||
iter::FromIterator,
|
||||
net::IpAddr,
|
||||
str::{self, FromStr},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
|
@ -24,7 +30,10 @@ use serde::Deserialize;
|
|||
use tracing::{debug, error, warn};
|
||||
|
||||
use super::{Ruma, RumaResponse};
|
||||
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
|
||||
use crate::{
|
||||
service::{appservice::RegistrationInfo, rate_limiting::Target},
|
||||
services, Error, Result,
|
||||
};
|
||||
|
||||
enum Token {
|
||||
Appservice(Box<RegistrationInfo>),
|
||||
|
@ -327,6 +336,23 @@ where
|
|||
}
|
||||
};
|
||||
|
||||
let sender_ip_address = parts
|
||||
.headers
|
||||
.get("X-Forwarded-For")
|
||||
.and_then(|header| header.to_str().ok())
|
||||
.map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header))
|
||||
.and_then(|ip| IpAddr::from_str(ip).ok());
|
||||
|
||||
let target = if let Some(server_name) = sender_servername.clone() {
|
||||
Some(Target::Server(server_name))
|
||||
} else if let Some(user) = &sender_user {
|
||||
Some(Target::from_client_request(appservice_info.clone(), user))
|
||||
} else {
|
||||
sender_ip_address.map(Target::Ip)
|
||||
};
|
||||
|
||||
services().rate_limiting.check(target, metadata).await?;
|
||||
|
||||
let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
|
||||
*http_request.headers_mut().unwrap() = parts.headers;
|
||||
|
||||
|
@ -377,6 +403,7 @@ where
|
|||
sender_servername,
|
||||
appservice_info,
|
||||
json_body,
|
||||
sender_ip_address,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ use ruma::{
|
|||
api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName,
|
||||
OwnedUserId,
|
||||
};
|
||||
use std::ops::Deref;
|
||||
use std::{net::IpAddr, ops::Deref};
|
||||
|
||||
#[cfg(feature = "conduit_bin")]
|
||||
mod axum;
|
||||
|
@ -14,6 +14,7 @@ pub struct Ruma<T> {
|
|||
pub sender_user: Option<OwnedUserId>,
|
||||
pub sender_device: Option<OwnedDeviceId>,
|
||||
pub sender_servername: Option<OwnedServerName>,
|
||||
pub sender_ip_address: Option<IpAddr>,
|
||||
// This is None when body is not a valid string
|
||||
pub json_body: Option<CanonicalJsonValue>,
|
||||
pub appservice_info: Option<RegistrationInfo>,
|
||||
|
|
|
@ -6,6 +6,7 @@ use crate::{
|
|||
globals::SigningKeys,
|
||||
media::FileMeta,
|
||||
pdu::{gen_event_id_canonical_json, PduBuilder},
|
||||
rate_limiting::Target,
|
||||
},
|
||||
services, utils, Error, PduEvent, Result, Ruma, SUPPORTED_VERSIONS,
|
||||
};
|
||||
|
@ -2237,6 +2238,13 @@ pub async fn create_invite_route(
|
|||
pub async fn get_content_route(
|
||||
body: Ruma<get_content::v1::Request>,
|
||||
) -> Result<get_content::v1::Response> {
|
||||
let sender_servername = body
|
||||
.sender_servername
|
||||
.as_ref()
|
||||
.expect("server is authenticated");
|
||||
|
||||
let target = Some(Target::Server(sender_servername.to_owned()));
|
||||
|
||||
services()
|
||||
.media
|
||||
.check_blocked(services().globals.server_name(), &body.media_id)?;
|
||||
|
@ -2247,7 +2255,11 @@ pub async fn get_content_route(
|
|||
file,
|
||||
}) = services()
|
||||
.media
|
||||
.get(services().globals.server_name(), &body.media_id, true)
|
||||
.get(
|
||||
services().globals.server_name(),
|
||||
&body.media_id,
|
||||
target.clone(),
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Ok(get_content::v1::Response::new(
|
||||
|
@ -2269,6 +2281,13 @@ pub async fn get_content_route(
|
|||
pub async fn get_content_thumbnail_route(
|
||||
body: Ruma<get_content_thumbnail::v1::Request>,
|
||||
) -> Result<get_content_thumbnail::v1::Response> {
|
||||
let Ruma::<get_content_thumbnail::v1::Request> {
|
||||
body,
|
||||
sender_servername,
|
||||
..
|
||||
} = body;
|
||||
let sender_servername = sender_servername.expect("server is authenticated");
|
||||
|
||||
services()
|
||||
.media
|
||||
.check_blocked(services().globals.server_name(), &body.media_id)?;
|
||||
|
@ -2288,7 +2307,7 @@ pub async fn get_content_thumbnail_route(
|
|||
body.height
|
||||
.try_into()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
|
||||
true,
|
||||
Some(Target::Server(sender_servername)),
|
||||
)
|
||||
.await?
|
||||
else {
|
||||
|
|
|
@ -17,7 +17,9 @@ use url::Url;
|
|||
use crate::Error;
|
||||
|
||||
mod proxy;
|
||||
use self::proxy::ProxyConfig;
|
||||
pub mod rate_limiting;
|
||||
|
||||
use self::{proxy::ProxyConfig, rate_limiting::Config as RateLimitingConfig};
|
||||
|
||||
const SHA256_HEX_LENGTH: u8 = 64;
|
||||
|
||||
|
@ -92,6 +94,8 @@ pub struct IncompleteConfig {
|
|||
#[serde(default)]
|
||||
pub media: IncompleteMediaConfig,
|
||||
|
||||
pub rate_limiting: RateLimitingConfig,
|
||||
|
||||
pub emergency_password: Option<String>,
|
||||
|
||||
#[serde(flatten)]
|
||||
|
@ -138,6 +142,8 @@ pub struct Config {
|
|||
|
||||
pub media: MediaConfig,
|
||||
|
||||
pub rate_limiting: RateLimitingConfig,
|
||||
|
||||
pub emergency_password: Option<String>,
|
||||
|
||||
pub catchall: BTreeMap<String, IgnoredAny>,
|
||||
|
@ -184,6 +190,7 @@ impl From<IncompleteConfig> for Config {
|
|||
turn_ttl,
|
||||
turn,
|
||||
media,
|
||||
rate_limiting,
|
||||
emergency_password,
|
||||
catchall,
|
||||
} = val;
|
||||
|
@ -281,6 +288,7 @@ impl From<IncompleteConfig> for Config {
|
|||
log,
|
||||
turn,
|
||||
media,
|
||||
rate_limiting,
|
||||
emergency_password,
|
||||
catchall,
|
||||
}
|
||||
|
|
115
src/config/rate_limiting.rs
Normal file
115
src/config/rate_limiting.rs
Normal file
|
@ -0,0 +1,115 @@
|
|||
use std::{collections::HashMap, hash::Hash, num::NonZeroU64};
|
||||
|
||||
use bytesize::ByteSize;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::service::rate_limiting::{ClientRestriction, FederationRestriction, Restriction};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Config {
|
||||
#[serde(flatten)]
|
||||
pub target: ConfigFragment,
|
||||
pub global: ConfigFragment,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ConfigFragment {
|
||||
pub client: ConfigSideFragment<ClientRestriction, ClientMediaConfig>,
|
||||
pub federation: ConfigSideFragment<FederationRestriction, FederationMediaConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ConfigSideFragment<K, C>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
{
|
||||
#[serde(flatten)]
|
||||
pub map: HashMap<K, RequestLimitation>,
|
||||
pub media: C,
|
||||
}
|
||||
|
||||
impl ConfigFragment {
|
||||
pub fn get(&self, restriction: &Restriction) -> &RequestLimitation {
|
||||
// Maybe look into https://github.com/moriyoshi-kasuga/enum-table
|
||||
match restriction {
|
||||
Restriction::Client(client_restriction) => {
|
||||
self.client.map.get(client_restriction).unwrap()
|
||||
}
|
||||
Restriction::Federation(federation_restriction) => {
|
||||
self.federation.map.get(federation_restriction).unwrap()
|
||||
}
|
||||
Restriction::Media(media_restriction) => todo!(),
|
||||
Restriction::CatchAll => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize)]
|
||||
pub struct RequestLimitation {
|
||||
#[serde(flatten)]
|
||||
pub timeframe: Timeframe,
|
||||
pub burst_capacity: NonZeroU64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Copy, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
// When deserializing, we want this prefix
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
pub enum Timeframe {
|
||||
PerSecond(NonZeroU64),
|
||||
PerMinute(NonZeroU64),
|
||||
PerHour(NonZeroU64),
|
||||
PerDay(NonZeroU64),
|
||||
}
|
||||
|
||||
impl Timeframe {
|
||||
pub fn nano_gap(&self) -> u64 {
|
||||
match self {
|
||||
Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(),
|
||||
Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(),
|
||||
Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(),
|
||||
Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize)]
|
||||
pub struct ClientMediaConfig {
|
||||
pub download: MediaLimitation,
|
||||
pub upload: MediaLimitation,
|
||||
pub fetch: MediaLimitation,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize)]
|
||||
pub struct FederationMediaConfig {
|
||||
pub download: MediaLimitation,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize)]
|
||||
pub struct MediaLimitation {
|
||||
#[serde(flatten)]
|
||||
pub timeframe: MediaTimeframe,
|
||||
pub burst_capacity: ByteSize,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Copy, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
// When deserializing, we want this prefix
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
pub enum MediaTimeframe {
|
||||
PerSecond(ByteSize),
|
||||
PerMinute(ByteSize),
|
||||
PerHour(ByteSize),
|
||||
PerDay(ByteSize),
|
||||
}
|
||||
|
||||
impl MediaTimeframe {
|
||||
pub fn bytes_per_sec(&self) -> u64 {
|
||||
match self {
|
||||
MediaTimeframe::PerSecond(t) => t.as_u64(),
|
||||
MediaTimeframe::PerMinute(t) => t.as_u64() / 60,
|
||||
MediaTimeframe::PerHour(t) => t.as_u64() / (60 * 60),
|
||||
MediaTimeframe::PerDay(t) => t.as_u64() / (60 * 60 * 24),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -203,19 +203,7 @@ impl service::media::Data for KeyValueDatabase {
|
|||
|
||||
let is_blocked_via_filehash = self.is_blocked_filehash(&sha256_digest)?;
|
||||
|
||||
let time_info = if let Some(filehash_meta) = self
|
||||
.filehash_metadata
|
||||
.get(&sha256_digest)?
|
||||
.map(FilehashMetadata::from_vec)
|
||||
{
|
||||
Some(FileInfo {
|
||||
creation: filehash_meta.creation(&sha256_digest)?,
|
||||
last_access: filehash_meta.last_access(&sha256_digest)?,
|
||||
size: filehash_meta.size(&sha256_digest)?,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let file_info = self.file_info(&sha256_digest)?;
|
||||
|
||||
Some(MediaQueryFileInfo {
|
||||
uploader_localpart,
|
||||
|
@ -224,7 +212,7 @@ impl service::media::Data for KeyValueDatabase {
|
|||
content_type,
|
||||
unauthenticated_access_permitted,
|
||||
is_blocked_via_filehash,
|
||||
file_info: time_info,
|
||||
file_info,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
@ -1353,6 +1341,24 @@ impl service::media::Data for KeyValueDatabase {
|
|||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn file_info(&self, sha256_digest: &[u8]) -> Result<Option<FileInfo>, Error> {
|
||||
Ok(
|
||||
if let Some(filehash_meta) = self
|
||||
.filehash_metadata
|
||||
.get(sha256_digest)?
|
||||
.map(FilehashMetadata::from_vec)
|
||||
{
|
||||
Some(FileInfo {
|
||||
creation: filehash_meta.creation(&sha256_digest)?,
|
||||
last_access: filehash_meta.last_access(&sha256_digest)?,
|
||||
size: filehash_meta.size(&sha256_digest)?,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KeyValueDatabase {
|
||||
|
|
|
@ -41,6 +41,7 @@ use tokio::sync::{mpsc, Mutex, RwLock};
|
|||
|
||||
use crate::{
|
||||
api::client_server::{self, leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH},
|
||||
service::rate_limiting::Target,
|
||||
services,
|
||||
utils::{self, HtmlEscape},
|
||||
Error, PduEvent, Result,
|
||||
|
@ -1174,8 +1175,12 @@ impl Service {
|
|||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
} = client_server::media::get_content(server_name, media_id.to_owned(), true, true)
|
||||
.await?;
|
||||
} = client_server::media::get_content(
|
||||
server_name,
|
||||
media_id.to_owned(),
|
||||
Some(Target::User(services().globals.server_user().to_owned())),
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let Ok(image) = image::load_from_memory(&file) {
|
||||
let filename = content_disposition.and_then(|cd| cd.filename);
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use ruma::{OwnedServerName, ServerName, UserId};
|
||||
use sha2::{digest::Output, Sha256};
|
||||
|
||||
use crate::{config::MediaRetentionConfig, Error, Result};
|
||||
use crate::{config::MediaRetentionConfig, service::media::FileInfo, Error, Result};
|
||||
|
||||
use super::{
|
||||
BlockedMediaInfo, DbFileMeta, MediaListItem, MediaQuery, MediaType, ServerNameOrUserId,
|
||||
|
@ -124,4 +124,7 @@ pub trait Data: Send + Sync {
|
|||
fn update_last_accessed(&self, server_name: &ServerName, media_id: &str) -> Result<()>;
|
||||
|
||||
fn update_last_accessed_filehash(&self, sha256_digest: &[u8]) -> Result<()>;
|
||||
|
||||
/// Returns the known information about a file
|
||||
fn file_info(&self, sha256_digest: &[u8]) -> Result<Option<FileInfo>>;
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ use tracing::{error, info, warn};
|
|||
|
||||
use crate::{
|
||||
config::{DirectoryStructure, MediaBackendConfig, S3MediaBackend},
|
||||
service::rate_limiting::Target,
|
||||
services, utils, Error, Result,
|
||||
};
|
||||
use image::imageops::FilterType;
|
||||
|
@ -237,7 +238,7 @@ impl Service {
|
|||
&self,
|
||||
servername: &ServerName,
|
||||
media_id: &str,
|
||||
authenticated: bool,
|
||||
target: Option<Target>,
|
||||
) -> Result<Option<FileMeta>> {
|
||||
let DbFileMeta {
|
||||
sha256_digest,
|
||||
|
@ -246,12 +247,19 @@ impl Service {
|
|||
unauthenticated_access_permitted,
|
||||
} = self.db.search_file_metadata(servername, media_id)?;
|
||||
|
||||
if !(authenticated || unauthenticated_access_permitted) {
|
||||
if !(target.as_ref().is_some_and(Target::is_authenticated)
|
||||
|| unauthenticated_access_permitted)
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let file = self.get_file(&sha256_digest, None).await?;
|
||||
|
||||
services()
|
||||
.rate_limiting
|
||||
.check_media_download(target, size(&file)?)
|
||||
.await?;
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition: content_disposition(filename, &content_type),
|
||||
content_type,
|
||||
|
@ -288,7 +296,7 @@ impl Service {
|
|||
media_id: &str,
|
||||
width: u32,
|
||||
height: u32,
|
||||
authenticated: bool,
|
||||
target: Option<Target>,
|
||||
) -> Result<Option<FileMeta>> {
|
||||
if let Some((width, height, crop)) = self.thumbnail_properties(width, height) {
|
||||
if let Ok(DbFileMeta {
|
||||
|
@ -300,10 +308,19 @@ impl Service {
|
|||
.db
|
||||
.search_thumbnail_metadata(servername, media_id, width, height)
|
||||
{
|
||||
if !(authenticated || unauthenticated_access_permitted) {
|
||||
if !(target.as_ref().is_some_and(Target::is_authenticated)
|
||||
|| unauthenticated_access_permitted)
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let file_info = self.file_info(&sha256_digest)?;
|
||||
|
||||
services()
|
||||
.rate_limiting
|
||||
.check_media_download(target, file_info.size)
|
||||
.await?;
|
||||
|
||||
// Using saved thumbnail
|
||||
let file = self
|
||||
.get_file(&sha256_digest, Some((servername, media_id)))
|
||||
|
@ -314,19 +331,15 @@ impl Service {
|
|||
content_type,
|
||||
file,
|
||||
}))
|
||||
} else if !authenticated {
|
||||
} else if !target.as_ref().is_some_and(Target::is_authenticated) {
|
||||
return Ok(None);
|
||||
} else if let Ok(DbFileMeta {
|
||||
sha256_digest,
|
||||
filename,
|
||||
content_type,
|
||||
unauthenticated_access_permitted,
|
||||
..
|
||||
}) = self.db.search_file_metadata(servername, media_id)
|
||||
{
|
||||
if !(authenticated || unauthenticated_access_permitted) {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let content_disposition = content_disposition(filename.clone(), &content_type);
|
||||
// Generate a thumbnail
|
||||
let file = self.get_file(&sha256_digest, None).await?;
|
||||
|
@ -426,7 +439,9 @@ impl Service {
|
|||
return Ok(None);
|
||||
};
|
||||
|
||||
if !(authenticated || unauthenticated_access_permitted) {
|
||||
if !(target.as_ref().is_some_and(Target::is_authenticated)
|
||||
|| unauthenticated_access_permitted)
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
|
@ -662,6 +677,13 @@ impl Service {
|
|||
.update_last_accessed_filehash(sha256_digest)
|
||||
.map(|_| file)
|
||||
}
|
||||
|
||||
fn file_info(&self, sha256_digest: &[u8]) -> Result<FileInfo> {
|
||||
self.db
|
||||
.file_info(sha256_digest)
|
||||
.transpose()
|
||||
.unwrap_or_else(|| Err(Error::BadRequest(ErrorKind::NotFound, "Fi)le not found")))
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates the media file, using the configured media backend
|
||||
|
|
|
@ -17,6 +17,7 @@ pub mod key_backups;
|
|||
pub mod media;
|
||||
pub mod pdu;
|
||||
pub mod pusher;
|
||||
pub mod rate_limiting;
|
||||
pub mod rooms;
|
||||
pub mod sending;
|
||||
pub mod transaction_ids;
|
||||
|
@ -36,6 +37,7 @@ pub struct Services {
|
|||
pub key_backups: key_backups::Service,
|
||||
pub media: Arc<media::Service>,
|
||||
pub sending: Arc<sending::Service>,
|
||||
pub rate_limiting: Arc<rate_limiting::Service>,
|
||||
}
|
||||
|
||||
impl Services {
|
||||
|
@ -123,6 +125,8 @@ impl Services {
|
|||
media: Arc::new(media::Service { db }),
|
||||
sending: sending::Service::build(db, &config),
|
||||
|
||||
rate_limiting: rate_limiting::Service::build(&config),
|
||||
|
||||
globals: globals::Service::load(db, config)?,
|
||||
})
|
||||
}
|
||||
|
|
591
src/service/rate_limiting/mod.rs
Normal file
591
src/service/rate_limiting/mod.rs
Normal file
|
@ -0,0 +1,591 @@
|
|||
use std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
net::IpAddr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use ruma::{
|
||||
api::{
|
||||
client::error::{ErrorKind, RetryAfter},
|
||||
federation::membership::create_knock_event,
|
||||
Metadata,
|
||||
},
|
||||
OwnedServerName, OwnedUserId, UserId,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use tokio::{
|
||||
sync::{Mutex, MutexGuard},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
config::rate_limiting::MediaLimitation, service::appservice::RegistrationInfo, services,
|
||||
Config, Error, Result,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
pub enum Target {
|
||||
User(OwnedUserId),
|
||||
// Server endpoints should be rate-limited on a server and room basis
|
||||
Server(OwnedServerName),
|
||||
Appservice { id: String, rate_limited: bool },
|
||||
Ip(IpAddr),
|
||||
}
|
||||
|
||||
impl Target {
|
||||
pub fn from_client_request(
|
||||
registration_info: Option<RegistrationInfo>,
|
||||
sender_user: &UserId,
|
||||
) -> Self {
|
||||
if let Some(info) = registration_info {
|
||||
// `rate_limited` only effects "masqueraded users", "The sender [user?] is excluded"
|
||||
return Target::Appservice {
|
||||
id: info.registration.id,
|
||||
rate_limited: info.registration.rate_limited.unwrap_or(true)
|
||||
&& !(sender_user.server_name() == services().globals.server_name()
|
||||
&& info.registration.sender_localpart == sender_user.localpart()),
|
||||
};
|
||||
}
|
||||
|
||||
Target::User(sender_user.to_owned())
|
||||
}
|
||||
|
||||
pub fn from_client_request_optional_auth(
|
||||
registration_info: Option<RegistrationInfo>,
|
||||
sender_user: &Option<OwnedUserId>,
|
||||
ip_addr: Option<IpAddr>,
|
||||
) -> Option<Self> {
|
||||
if let Some(sender_user) = sender_user.as_ref() {
|
||||
Some(Self::from_client_request(registration_info, sender_user))
|
||||
} else {
|
||||
ip_addr.map(Self::Ip)
|
||||
}
|
||||
}
|
||||
|
||||
fn rate_limited(&self) -> bool {
|
||||
match self {
|
||||
Target::User(user_id) => user_id != services().globals.server_user(),
|
||||
Target::Appservice {
|
||||
id: _,
|
||||
rate_limited,
|
||||
} => *rate_limited,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
!matches!(self, Target::Ip(_))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
pub enum Restriction {
|
||||
Client(ClientRestriction),
|
||||
Federation(FederationRestriction),
|
||||
Media(MediaRestriction),
|
||||
|
||||
#[default]
|
||||
CatchAll,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ClientRestriction {
|
||||
Registration,
|
||||
Login,
|
||||
RegistrationTokenValidity,
|
||||
|
||||
SendEvent,
|
||||
|
||||
Join,
|
||||
Invite,
|
||||
Knock,
|
||||
|
||||
SendReport,
|
||||
CreateAlias,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum FederationRestriction {
|
||||
Join,
|
||||
Knock,
|
||||
Invite,
|
||||
|
||||
// Transactions should be handled by a completely dedicated rate-limiter
|
||||
Transaction,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
pub enum MediaRestriction {
|
||||
Create,
|
||||
Download,
|
||||
}
|
||||
|
||||
impl From<Metadata> for Restriction {
|
||||
fn from(value: Metadata) -> Self {
|
||||
use ruma::api::{
|
||||
client::{
|
||||
account::{check_registration_token_validity, register},
|
||||
alias::create_alias,
|
||||
authenticated_media::{
|
||||
get_content, get_content_as_filename, get_content_thumbnail, get_media_preview,
|
||||
},
|
||||
knock::knock_room,
|
||||
media::{self, create_content, create_mxc_uri},
|
||||
membership::{invite_user, join_room_by_id, join_room_by_id_or_alias},
|
||||
message::send_message_event,
|
||||
reporting::report_user,
|
||||
room::{report_content, report_room},
|
||||
session::login,
|
||||
state::send_state_event,
|
||||
},
|
||||
federation::{
|
||||
authenticated_media::{
|
||||
get_content as federation_get_content,
|
||||
get_content_thumbnail as federation_get_content_thumbnail,
|
||||
},
|
||||
membership::{create_invite, create_join_event},
|
||||
},
|
||||
IncomingRequest,
|
||||
};
|
||||
use Restriction::*;
|
||||
|
||||
match value {
|
||||
register::v3::Request::METADATA => Client(ClientRestriction::Registration),
|
||||
check_registration_token_validity::v1::Request::METADATA => {
|
||||
Client(ClientRestriction::RegistrationTokenValidity)
|
||||
}
|
||||
login::v3::Request::METADATA => Client(ClientRestriction::Login),
|
||||
send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => {
|
||||
Client(ClientRestriction::SendEvent)
|
||||
}
|
||||
join_room_by_id::v3::Request::METADATA
|
||||
| join_room_by_id_or_alias::v3::Request::METADATA => Client(ClientRestriction::Join),
|
||||
invite_user::v3::Request::METADATA => Client(ClientRestriction::Invite),
|
||||
knock_room::v3::Request::METADATA => Client(ClientRestriction::Knock),
|
||||
report_user::v3::Request::METADATA
|
||||
| report_content::v3::Request::METADATA
|
||||
| report_room::v3::Request::METADATA => Client(ClientRestriction::SendReport),
|
||||
create_alias::v3::Request::METADATA => Client(ClientRestriction::CreateAlias),
|
||||
// NOTE: handle async media upload in a way that doesn't half the number of uploads you can do within a short timeframe, while not allowing pre-generation of MXC uris to allow uploading double the number of media at once
|
||||
create_content::v3::Request::METADATA | create_mxc_uri::v1::Request::METADATA => {
|
||||
Media(MediaRestriction::Create)
|
||||
}
|
||||
// Unauthenticate media is deprecated
|
||||
#[allow(deprecated)]
|
||||
media::get_content::v3::Request::METADATA
|
||||
| media::get_content_as_filename::v3::Request::METADATA
|
||||
| media::get_content_thumbnail::v3::Request::METADATA
|
||||
| media::get_media_preview::v3::Request::METADATA
|
||||
| get_content::v1::Request::METADATA
|
||||
| get_content_as_filename::v1::Request::METADATA
|
||||
| get_content_thumbnail::v1::Request::METADATA
|
||||
| get_media_preview::v1::Request::METADATA
|
||||
| federation_get_content::v1::Request::METADATA
|
||||
| federation_get_content_thumbnail::v1::Request::METADATA => {
|
||||
Media(MediaRestriction::Download)
|
||||
}
|
||||
// v1 is deprecated
|
||||
#[allow(deprecated)]
|
||||
create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => {
|
||||
Federation(FederationRestriction::Join)
|
||||
}
|
||||
create_knock_event::v1::Request::METADATA => Federation(FederationRestriction::Knock),
|
||||
create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => {
|
||||
Federation(FederationRestriction::Invite)
|
||||
}
|
||||
|
||||
_ => Self::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type MediaBucket = Mutex<HashMap<Target, Arc<Mutex<Instant>>>>;
|
||||
type GlobalMediaBucket = Arc<Mutex<Instant>>;
|
||||
|
||||
pub struct Service {
|
||||
buckets: Mutex<HashMap<(Target, Restriction), Arc<Mutex<Instant>>>>,
|
||||
global_bucket: Mutex<HashMap<Restriction, Arc<Mutex<Instant>>>>,
|
||||
|
||||
media_upload: MediaBucket,
|
||||
media_fetch: MediaBucket,
|
||||
media_download: MediaBucket,
|
||||
|
||||
global_media_upload: GlobalMediaBucket,
|
||||
global_media_fetch: GlobalMediaBucket,
|
||||
global_media_download_client: GlobalMediaBucket,
|
||||
global_media_download_federation: GlobalMediaBucket,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn build(config: &Config) -> Arc<Self> {
|
||||
let now = Instant::now();
|
||||
let global_media_config = &config.rate_limiting.global;
|
||||
|
||||
Arc::new(Self {
|
||||
buckets: Mutex::new(HashMap::new()),
|
||||
global_bucket: Mutex::new(HashMap::new()),
|
||||
|
||||
media_upload: Mutex::new(HashMap::new()),
|
||||
media_fetch: Mutex::new(HashMap::new()),
|
||||
media_download: Mutex::new(HashMap::new()),
|
||||
|
||||
global_media_upload: default_media_entry(global_media_config.client.media.upload, now),
|
||||
global_media_fetch: default_media_entry(global_media_config.client.media.fetch, now),
|
||||
global_media_download_client: default_media_entry(
|
||||
global_media_config.client.media.download,
|
||||
now,
|
||||
),
|
||||
global_media_download_federation: default_media_entry(
|
||||
global_media_config.federation.media.download,
|
||||
now,
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
//TODO: use checked and saturating arithmetic
|
||||
|
||||
/// Takes the target and request, and either accepts the request while adding to the
|
||||
/// bucket, or rejects the request, returning the duration that should be waited until
|
||||
/// the request should be retried.
|
||||
pub async fn check(&self, target: Option<Target>, request: Metadata) -> Result<()> {
|
||||
let restriction: Restriction = request.into();
|
||||
let arrival = Instant::now();
|
||||
|
||||
{
|
||||
let map = self.global_bucket.lock().await;
|
||||
|
||||
if let Some(value) = map.get(&restriction) {
|
||||
let value = value.lock().await;
|
||||
|
||||
if arrival.checked_duration_since(*value).is_none() {
|
||||
instant_to_err(&value)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(target) = target {
|
||||
let config = services()
|
||||
.globals
|
||||
.config
|
||||
.rate_limiting
|
||||
.target
|
||||
.get(&restriction);
|
||||
|
||||
let mut map = self.buckets.lock().await;
|
||||
let entry = map.entry((target, restriction));
|
||||
match entry {
|
||||
Entry::Occupied(occupied_entry) => {
|
||||
let entry = Arc::clone(occupied_entry.get());
|
||||
let mut entry = entry.lock().await;
|
||||
|
||||
if arrival.checked_duration_since(*entry).is_none() {
|
||||
return instant_to_err(&entry);
|
||||
}
|
||||
|
||||
let min_instant = arrival
|
||||
- Duration::from_nanos(
|
||||
config.timeframe.nano_gap() * config.burst_capacity.get(),
|
||||
);
|
||||
*entry =
|
||||
entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap());
|
||||
}
|
||||
Entry::Vacant(vacant_entry) => {
|
||||
vacant_entry.insert(Arc::new(Mutex::new(
|
||||
arrival
|
||||
- Duration::from_nanos(
|
||||
config.timeframe.nano_gap() * (config.burst_capacity.get() - 1),
|
||||
),
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let config = services()
|
||||
.globals
|
||||
.config
|
||||
.rate_limiting
|
||||
.global
|
||||
.get(&restriction);
|
||||
|
||||
let mut map = self.global_bucket.lock().await;
|
||||
|
||||
let entry = map.entry(restriction);
|
||||
match entry {
|
||||
Entry::Occupied(occupied_entry) => {
|
||||
let entry = Arc::clone(occupied_entry.get());
|
||||
let mut entry = entry.lock().await;
|
||||
|
||||
if arrival.checked_duration_since(*entry).is_none() {
|
||||
return instant_to_err(&entry);
|
||||
}
|
||||
|
||||
let min_instant = arrival
|
||||
- Duration::from_nanos(
|
||||
config.timeframe.nano_gap() * config.burst_capacity.get(),
|
||||
);
|
||||
*entry =
|
||||
entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap());
|
||||
}
|
||||
Entry::Vacant(vacant_entry) => {
|
||||
vacant_entry.insert(Arc::new(Mutex::new(
|
||||
arrival
|
||||
- Duration::from_nanos(
|
||||
config.timeframe.nano_gap() * (config.burst_capacity.get() - 1),
|
||||
),
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn check_media_download(&self, target: Option<Target>, size: u64) -> Result<()> {
|
||||
// All targets besides servers use the client-server API
|
||||
let (target_limitation, global_limitation, global_bucket) =
|
||||
if let Some(Target::Server(_)) = &target {
|
||||
(
|
||||
services()
|
||||
.globals
|
||||
.config
|
||||
.rate_limiting
|
||||
.target
|
||||
.federation
|
||||
.media
|
||||
.download,
|
||||
services()
|
||||
.globals
|
||||
.config
|
||||
.rate_limiting
|
||||
.global
|
||||
.federation
|
||||
.media
|
||||
.download,
|
||||
&self.global_media_download_federation,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
services()
|
||||
.globals
|
||||
.config
|
||||
.rate_limiting
|
||||
.target
|
||||
.client
|
||||
.media
|
||||
.download,
|
||||
services()
|
||||
.globals
|
||||
.config
|
||||
.rate_limiting
|
||||
.global
|
||||
.client
|
||||
.media
|
||||
.download,
|
||||
&self.global_media_download_client,
|
||||
)
|
||||
};
|
||||
|
||||
check_media(
|
||||
target,
|
||||
size,
|
||||
target_limitation,
|
||||
global_limitation,
|
||||
&self.media_download,
|
||||
global_bucket,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn check_media_upload(&self, target: Target, size: u64) -> Result<()> {
|
||||
let target_limitation = services()
|
||||
.globals
|
||||
.config
|
||||
.rate_limiting
|
||||
.target
|
||||
// Media can only be uploaded on the client-server API
|
||||
.client
|
||||
.media
|
||||
.upload;
|
||||
|
||||
let global_limitation = services()
|
||||
.globals
|
||||
.config
|
||||
.rate_limiting
|
||||
.global
|
||||
// Media can only be uploaded on the client-server API
|
||||
.client
|
||||
.media
|
||||
.upload;
|
||||
|
||||
check_media(
|
||||
Some(target),
|
||||
size,
|
||||
target_limitation,
|
||||
global_limitation,
|
||||
&self.media_upload,
|
||||
&self.global_media_upload,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn check_media_pre_fetch(&self, target: &Target) -> Result<()> {
|
||||
let arrival = Instant::now();
|
||||
|
||||
let global_bucket = self.global_media_fetch;
|
||||
|
||||
let map = self.media_fetch.lock().await;
|
||||
if let Some(mutex) = map.get(target) {
|
||||
let mutex = mutex.lock().await;
|
||||
|
||||
if arrival.checked_duration_since(*mutex).is_none() {
|
||||
return instant_to_err(&mutex);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_media_post_fetch(&self, target: Target, size: u64) {
|
||||
if !target.rate_limited() {
|
||||
return;
|
||||
}
|
||||
|
||||
let limitation = services()
|
||||
.globals
|
||||
.config
|
||||
.rate_limiting
|
||||
.target
|
||||
// Media can only be "fetched" (causing our server to download media from another server) by the client-server API
|
||||
.client
|
||||
.media
|
||||
.fetch;
|
||||
|
||||
let arrival = Instant::now();
|
||||
|
||||
let mut map = self.media_fetch.lock().await;
|
||||
let entry = map.entry(target);
|
||||
|
||||
match entry {
|
||||
Entry::Occupied(occupied_entry) => {
|
||||
let entry = Arc::clone(occupied_entry.get());
|
||||
|
||||
update_media_entry(size, &limitation, &arrival, entry, false).await;
|
||||
}
|
||||
Entry::Vacant(vacant_entry) => {
|
||||
vacant_entry.insert(Arc::new(Mutex::new(
|
||||
arrival
|
||||
- Duration::from_nanos(
|
||||
limitation.burst_capacity.as_u64()
|
||||
/ limitation.timeframe.bytes_per_sec(),
|
||||
),
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn update_media_entry(
|
||||
size: u64,
|
||||
limitation: &MediaLimitation,
|
||||
arrival: &Instant,
|
||||
entry: Arc<Mutex<Instant>>,
|
||||
and_check: bool,
|
||||
) -> Result<()> {
|
||||
let mut entry = entry.lock().await;
|
||||
|
||||
//TODO: use more precise conversion than secs
|
||||
let proposed_entry = get_proposed_entry(size, limitation, arrival, &entry, and_check)?;
|
||||
|
||||
*entry = proposed_entry;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_proposed_entry(
|
||||
size: u64,
|
||||
limitation: &MediaLimitation,
|
||||
arrival: &Instant,
|
||||
entry: &MutexGuard<'_, Instant>,
|
||||
and_check: bool,
|
||||
) -> Result<Instant> {
|
||||
let min_instant = *arrival
|
||||
- Duration::from_secs(
|
||||
limitation.burst_capacity.as_u64() / limitation.timeframe.bytes_per_sec(),
|
||||
);
|
||||
|
||||
let proposed_entry =
|
||||
entry.max(min_instant) + Duration::from_secs(size / limitation.timeframe.bytes_per_sec());
|
||||
|
||||
if and_check && arrival.checked_duration_since(proposed_entry).is_none() {
|
||||
return instant_to_err(&proposed_entry).map(|_| proposed_entry);
|
||||
}
|
||||
|
||||
Ok(proposed_entry)
|
||||
}
|
||||
|
||||
async fn check_media(
|
||||
target: Option<Target>,
|
||||
size: u64,
|
||||
target_limitation: MediaLimitation,
|
||||
global_limitation: MediaLimitation,
|
||||
target_map: &MediaBucket,
|
||||
global_bucket: &GlobalMediaBucket,
|
||||
) -> Result<()> {
|
||||
if !target.as_ref().is_some_and(Target::rate_limited) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let arrival = Instant::now();
|
||||
|
||||
let mut global_bucket = global_bucket.lock().await;
|
||||
let proposed = get_proposed_entry(size, &global_limitation, &arrival, &global_bucket, true)?;
|
||||
|
||||
if let Some(target) = target {
|
||||
let mut map = target_map.lock().await;
|
||||
let entry = map.entry(target);
|
||||
|
||||
match entry {
|
||||
Entry::Occupied(occupied_entry) => {
|
||||
let entry = Arc::clone(occupied_entry.get());
|
||||
|
||||
update_media_entry(size, &target_limitation, &arrival, entry, true).await;
|
||||
}
|
||||
Entry::Vacant(vacant_entry) => {
|
||||
vacant_entry.insert(default_media_entry(target_limitation, arrival));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
*global_bucket = proposed;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn default_media_entry(
|
||||
target_limitation: MediaLimitation,
|
||||
arrival: Instant,
|
||||
) -> Arc<Mutex<Instant>> {
|
||||
Arc::new(Mutex::new(
|
||||
arrival
|
||||
- Duration::from_nanos(
|
||||
target_limitation.burst_capacity.as_u64()
|
||||
/ target_limitation.timeframe.bytes_per_sec(),
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn instant_to_err(instant: &Instant) -> Result<()> {
|
||||
let now = Instant::now();
|
||||
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::LimitExceeded {
|
||||
// Not using ::DateTime because conversion from Instant to SystemTime is convoluted
|
||||
retry_after: Some(RetryAfter::Delay(instant.duration_since(now))),
|
||||
},
|
||||
"Rate limit exceeded",
|
||||
))
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue