1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00

More flexible preview config

This commit is contained in:
Steven Vergenz 2024-10-31 10:28:51 -07:00
parent 6789ed336e
commit b76357c80e
5 changed files with 125 additions and 75 deletions

View file

@ -3,7 +3,10 @@
use std::time::Duration; use std::time::Duration;
use crate::{service::media::{FileMeta, UrlPreviewData}, services, utils, Error, Result, Ruma}; use crate::{
service::media::{FileMeta, UrlPreviewData},
config::UrlPreviewPermission,
services, utils, Error, Result, Ruma};
use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE}; use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE};
use ruma::{ use ruma::{
api::{ api::{
@ -11,7 +14,7 @@ use ruma::{
authenticated_media::{ authenticated_media::{
get_content, get_content_as_filename, get_content_thumbnail, get_media_config, get_content, get_content_as_filename, get_content_thumbnail, get_media_config,
}, },
error::{ErrorKind, RetryAfter}, error::ErrorKind,
media::{ media::{
self, create_content, get_media_preview, self, create_content, get_media_preview,
}, },
@ -174,6 +177,7 @@ fn url_request_allowed(addr: &IpAddr) -> bool {
} }
} }
/// Generate URL preview data from the given URL
async fn request_url_preview(url: &Url) -> Result<UrlPreviewData> { async fn request_url_preview(url: &Url) -> Result<UrlPreviewData> {
// resolve host to IP to ensure it's not a local IP (host guaranteed to not be None) // resolve host to IP to ensure it's not a local IP (host guaranteed to not be None)
let dns_resolver = services().globals.dns_resolver(); let dns_resolver = services().globals.dns_resolver();
@ -219,6 +223,7 @@ async fn request_url_preview(url: &Url) -> Result<UrlPreviewData> {
Ok(data) Ok(data)
} }
/// Retrieve URL preview data from database if available, or generate it
async fn get_url_preview(url: &Url) -> Result<UrlPreviewData> { async fn get_url_preview(url: &Url) -> Result<UrlPreviewData> {
if let Some(preview) = services().media.get_url_preview(url.as_str()).await { if let Some(preview) = services().media.get_url_preview(url.as_str()).await {
return Ok(preview); return Ok(preview);
@ -246,25 +251,15 @@ async fn get_url_preview(url: &Url) -> Result<UrlPreviewData> {
fn url_preview_allowed(url: &Url) -> bool { fn url_preview_allowed(url: &Url) -> bool {
// host's existence is already verified in get_media_preview_route, unwrap is safe // host's existence is already verified in get_media_preview_route, unwrap is safe
let host = url.host_str().unwrap().to_lowercase(); let host = url.host_str().unwrap().to_lowercase();
let host_parts_iter = host let preview_config = services().globals.url_previews();
.char_indices() match preview_config.default {
.filter_map(|(i, c)| { UrlPreviewPermission::Forbid => {
if i == 0 { preview_config.exceptions.iter().any(|ex| ex.matches(&host))
Some(host.as_str()) },
} UrlPreviewPermission::Allow => {
else if c == '.' { !preview_config.exceptions.iter().any(|ex| ex.matches(&host))
Some(&host[i+1..]) },
} }
else {
None
}
})
.rev().skip(1); // don't match TLDs
let ret = ["*"].into_iter().chain(host_parts_iter).any(|nld| {
services().globals.url_preview_allowlist().any(|a| a == nld)
});
ret // temp variable to avoid returning from the closure
} }
/// # `GET /_matrix/media/r0/preview_url` /// # `GET /_matrix/media/r0/preview_url`

View file

@ -4,12 +4,14 @@ use std::{
net::{IpAddr, Ipv4Addr}, net::{IpAddr, Ipv4Addr},
}; };
use wild_carded_domain::WildCardedDomain;
use ruma::{OwnedServerName, RoomVersionId}; use ruma::{OwnedServerName, RoomVersionId};
use serde::{de::IgnoredAny, Deserialize}; use serde::{de::IgnoredAny, Deserialize};
use tracing::warn; use tracing::warn;
use url::Url; use url::Url;
mod proxy; mod proxy;
mod wild_carded_domain;
use self::proxy::ProxyConfig; use self::proxy::ProxyConfig;
@ -85,8 +87,8 @@ pub struct Config {
pub emergency_password: Option<String>, pub emergency_password: Option<String>,
#[serde(default = "Vec::new")] #[serde(default)]
pub url_preview_allowlist: Vec<String>, pub url_previews: UrlPreviewConfig,
#[serde(flatten)] #[serde(flatten)]
pub catchall: BTreeMap<String, IgnoredAny>, pub catchall: BTreeMap<String, IgnoredAny>,
@ -104,6 +106,35 @@ pub struct WellKnownConfig {
pub server: Option<OwnedServerName>, pub server: Option<OwnedServerName>,
} }
#[derive(Clone, Debug, Deserialize, Default)]
pub struct UrlPreviewConfig {
pub default: UrlPreviewPermission,
pub exceptions: Vec<WildCardedDomain>,
}
/// Whether URL previews are permitted.
///
/// The type is a fieldless enum, so it is `Copy` and comparable with `==`.
#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
pub enum UrlPreviewPermission {
    /// Previews are permitted.
    Allow,
    /// Previews are refused; this is the `Default` value.
    #[default]
    Forbid,
}
impl UrlPreviewPermission {
pub fn invert(&self) -> Self {
match self {
UrlPreviewPermission::Allow => UrlPreviewPermission::Forbid,
UrlPreviewPermission::Forbid => UrlPreviewPermission::Allow,
}
}
}
impl fmt::Display for UrlPreviewPermission {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
UrlPreviewPermission::Allow => write!(f, "ALLOW"),
UrlPreviewPermission::Forbid => write!(f, "FORBID"),
}
}
}
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
impl Config { impl Config {
@ -235,7 +266,14 @@ impl fmt::Display for Config {
}), }),
("Well-known server name", well_known_server.as_str()), ("Well-known server name", well_known_server.as_str()),
("Well-known client URL", &self.well_known_client()), ("Well-known client URL", &self.well_known_client()),
("URL preview allowlist", &self.url_preview_allowlist.join(", ")), ("URL preview", {
let mut lst = vec![];
for exc in &self.url_previews.exceptions {
lst.push(format!("{} {}", self.url_previews.default.invert(), exc));
}
lst.push(format!("{} {}", self.url_previews.default, "*"));
&lst.join(", ")
}),
]; ];
let mut msg: String = "Active config values:\n\n".to_owned(); let mut msg: String = "Active config values:\n\n".to_owned();

View file

@ -2,6 +2,7 @@ use reqwest::{Proxy, Url};
use serde::Deserialize; use serde::Deserialize;
use crate::Result; use crate::Result;
use super::wild_carded_domain::WildCardedDomain;
/// ## Examples: /// ## Examples:
/// - No proxy (default): /// - No proxy (default):
@ -92,52 +93,3 @@ impl PartialProxyConfig {
} }
} }
} }
/// A domain name, that optionally allows a * as its first subdomain.
///
/// The variants differ in how `matches` compares them against a domain:
/// everything, suffix comparison, or exact equality.
#[derive(Clone, Debug)]
pub enum WildCardedDomain {
/// Matches every domain.
WildCard,
/// Matches any domain that ends with the contained string. When parsed
/// from `*.example.com` the contained string is `.example.com` (leading
/// dot kept).
WildCarded(String),
/// Matches only a domain equal to the contained string.
Exact(String),
}
impl WildCardedDomain {
pub fn matches(&self, domain: &str) -> bool {
match self {
WildCardedDomain::WildCard => true,
WildCardedDomain::WildCarded(d) => domain.ends_with(d),
WildCardedDomain::Exact(d) => domain == d,
}
}
pub fn more_specific_than(&self, other: &Self) -> bool {
match (self, other) {
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
(_, WildCardedDomain::WildCard) => true,
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
a != b && a.ends_with(b)
}
_ => false,
}
}
}
/// Parses a pattern: `*` or `*.suffix` becomes `WildCarded` (holding the text
/// after the `*`), anything else becomes `Exact`. Parsing never fails.
impl std::str::FromStr for WildCardedDomain {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // maybe do some domain validation?
        let parsed = match s.strip_prefix('*') {
            // Either a bare "*" (empty remainder) or "*." followed by a suffix.
            Some(rest) if rest.is_empty() || rest.starts_with('.') => {
                WildCardedDomain::WildCarded(rest.to_owned())
            }
            _ => WildCardedDomain::Exact(s.to_owned()),
        };
        Ok(parsed)
    }
}
// Deserialization is delegated entirely to the crate helper
// `crate::utils::deserialize_from_str`.
// NOTE(review): presumably that helper reads a string and parses it through
// the `FromStr` impl of this type — confirm in `crate::utils`.
impl<'de> Deserialize<'de> for WildCardedDomain {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
crate::utils::deserialize_from_str(deserializer)
}
}

View file

@ -0,0 +1,64 @@
use serde::Deserialize;
use std::fmt;
/// A domain name that optionally allows a `*` as its first subdomain.
///
/// The variants differ in how `matches` compares them against a domain:
/// everything, suffix comparison, or exact equality. `PartialEq`/`Eq` are
/// derived so patterns can be compared directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WildCardedDomain {
    /// Matches every domain.
    WildCard,
    /// Matches any domain that ends with the contained string. When parsed
    /// from `*.example.com` the contained string is `.example.com` (leading
    /// dot kept).
    WildCarded(String),
    /// Matches only a domain equal to the contained string.
    Exact(String),
}
impl WildCardedDomain {
pub fn matches(&self, domain: &str) -> bool {
match self {
WildCardedDomain::WildCard => true,
WildCardedDomain::WildCarded(d) => domain.ends_with(d),
WildCardedDomain::Exact(d) => domain == d,
}
}
pub fn more_specific_than(&self, other: &Self) -> bool {
match (self, other) {
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
(_, WildCardedDomain::WildCard) => true,
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
a != b && a.ends_with(b)
}
_ => false,
}
}
}
/// Parses a domain pattern, lower-casing it so it can be compared against
/// lower-cased host names. Parsing never fails.
///
/// * `*` becomes `WildCard`. This is equivalent to the empty-suffix
///   `WildCarded` form for `matches`, `more_specific_than` and `Display`,
///   but uses the dedicated variant.
/// * `*.suffix` becomes `WildCarded(".suffix")` (the leading dot is kept).
/// * Anything else becomes `Exact`.
impl std::str::FromStr for WildCardedDomain {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // maybe do some domain validation?
        if s == "*" {
            return Ok(WildCardedDomain::WildCard);
        }
        Ok(match s.strip_prefix('*') {
            // Only "*." introduces a wildcard; a "*" glued to other text is
            // treated as a literal name.
            Some(suffix) if suffix.starts_with('.') => {
                WildCardedDomain::WildCarded(suffix.to_lowercase())
            }
            _ => WildCardedDomain::Exact(s.to_lowercase()),
        })
    }
}
// Deserialization is delegated entirely to the crate helper
// `crate::utils::deserialize_from_str`.
// NOTE(review): presumably that helper reads a string and parses it through
// the `FromStr` impl of this type (which lower-cases the input) — confirm in
// `crate::utils`.
impl<'de> Deserialize<'de> for WildCardedDomain {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
crate::utils::deserialize_from_str(deserializer)
}
}
impl fmt::Display for WildCardedDomain {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WildCardedDomain::WildCard => write!(f, "*"),
WildCardedDomain::WildCarded(d) => write!(f, "*{d}"),
WildCardedDomain::Exact(d) => write!(f, "{d}"),
}
}
}

View file

@ -7,6 +7,7 @@ use ruma::{
use crate::api::server_server::DestinationResponse; use crate::api::server_server::DestinationResponse;
use crate::config::UrlPreviewConfig;
use crate::{services, Config, Error, Result}; use crate::{services, Config, Error, Result};
use futures_util::FutureExt; use futures_util::FutureExt;
use hickory_resolver::TokioAsyncResolver; use hickory_resolver::TokioAsyncResolver;
@ -324,8 +325,8 @@ impl Service {
self.config.allow_federation self.config.allow_federation
} }
pub fn url_preview_allowlist(&self) -> impl Iterator<Item=&str> { pub fn url_previews(&self) -> &UrlPreviewConfig {
self.config.url_preview_allowlist.iter().map(|x| x.as_str()) &self.config.url_previews
} }
pub fn allow_room_creation(&self) -> bool { pub fn allow_room_creation(&self) -> bool {