From b76357c80e12f08dcac380459fcbc157a6bb7de1 Mon Sep 17 00:00:00 2001 From: Steven Vergenz <1882376+stevenvergenz@users.noreply.github.com> Date: Thu, 31 Oct 2024 10:28:51 -0700 Subject: [PATCH] More flexible preview config --- src/api/client_server/media.rs | 37 ++++++++---------- src/config/mod.rs | 44 ++++++++++++++++++++-- src/config/proxy.rs | 50 +------------------------ src/config/wild_carded_domain.rs | 64 ++++++++++++++++++++++++++++++++ src/service/globals/mod.rs | 5 ++- 5 files changed, 125 insertions(+), 75 deletions(-) create mode 100644 src/config/wild_carded_domain.rs diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index 0b529851..c6d3cb8a 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -3,7 +3,10 @@ use std::time::Duration; -use crate::{service::media::{FileMeta, UrlPreviewData}, services, utils, Error, Result, Ruma}; +use crate::{ + service::media::{FileMeta, UrlPreviewData}, + config::UrlPreviewPermission, + services, utils, Error, Result, Ruma}; use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE}; use ruma::{ api::{ @@ -11,7 +14,7 @@ use ruma::{ authenticated_media::{ get_content, get_content_as_filename, get_content_thumbnail, get_media_config, }, - error::{ErrorKind, RetryAfter}, + error::ErrorKind, media::{ self, create_content, get_media_preview, }, @@ -174,6 +177,7 @@ fn url_request_allowed(addr: &IpAddr) -> bool { } } +/// Generate URL preview data from the given URL async fn request_url_preview(url: &Url) -> Result { // resolve host to IP to ensure it's not a local IP (host guaranteed to not be None) let dns_resolver = services().globals.dns_resolver(); @@ -219,6 +223,7 @@ async fn request_url_preview(url: &Url) -> Result { Ok(data) } +/// Retrieve URL preview data from database if available, or generate it async fn get_url_preview(url: &Url) -> Result { if let Some(preview) = services().media.get_url_preview(url.as_str()).await { return Ok(preview); @@ -246,25 +251,15 @@ async fn get_url_preview(url: &Url) -> Result { fn url_preview_allowed(url: &Url) -> bool { // host's existence is already verified in get_media_preview_route, unwrap is safe let host = url.host_str().unwrap().to_lowercase(); - let host_parts_iter = host - .char_indices() - .filter_map(|(i, c)| { - if i == 0 { - Some(host.as_str()) - } - else if c == '.' { - Some(&host[i+1..]) - } - else { - None - } - }) - .rev().skip(1); // don't match TLDs - - let ret = ["*"].into_iter().chain(host_parts_iter).any(|nld| { - services().globals.url_preview_allowlist().any(|a| a == nld) - }); - ret // temp variable to avoid returning from the closure + let preview_config = services().globals.url_previews(); + match preview_config.default { + UrlPreviewPermission::Forbid => { + preview_config.exceptions.iter().any(|ex| ex.matches(&host)) + }, + UrlPreviewPermission::Allow => { + !preview_config.exceptions.iter().any(|ex| ex.matches(&host)) + }, + } } /// # `GET /_matrix/media/r0/preview_url` diff --git a/src/config/mod.rs b/src/config/mod.rs index b8eef648..29d8bc1f 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -4,12 +4,14 @@ use std::{ net::{IpAddr, Ipv4Addr}, }; +use wild_carded_domain::WildCardedDomain; use ruma::{OwnedServerName, RoomVersionId}; use serde::{de::IgnoredAny, Deserialize}; use tracing::warn; use url::Url; mod proxy; +mod wild_carded_domain; use self::proxy::ProxyConfig; @@ -85,8 +87,8 @@ pub struct Config { pub emergency_password: Option, - #[serde(default = "Vec::new")] - pub url_preview_allowlist: Vec, + #[serde(default)] + pub url_previews: UrlPreviewConfig, #[serde(flatten)] pub catchall: BTreeMap, @@ -104,6 +106,35 @@ pub struct WellKnownConfig { pub server: Option, } +#[derive(Clone, Debug, Deserialize, Default)] +pub struct UrlPreviewConfig { + pub default: UrlPreviewPermission, + pub exceptions: Vec, +} + +#[derive(Clone, Debug, Deserialize, Default)] +pub enum UrlPreviewPermission { + Allow, + #[default] + Forbid, +} +impl UrlPreviewPermission { + pub fn invert(&self) -> Self { + match self { + UrlPreviewPermission::Allow => UrlPreviewPermission::Forbid, + UrlPreviewPermission::Forbid => UrlPreviewPermission::Allow, + } + } +} +impl fmt::Display for UrlPreviewPermission { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + UrlPreviewPermission::Allow => write!(f, "ALLOW"), + UrlPreviewPermission::Forbid => write!(f, "FORBID"), + } + } +} + const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; impl Config { @@ -235,7 +266,14 @@ impl fmt::Display for Config { }), ("Well-known server name", well_known_server.as_str()), ("Well-known client URL", &self.well_known_client()), - ("URL preview allowlist", &self.url_preview_allowlist.join(", ")), + ("URL preview", { + let mut lst = vec![]; + for exc in &self.url_previews.exceptions { + lst.push(format!("{} {}", self.url_previews.default.invert(), exc)); + } + lst.push(format!("{} {}", self.url_previews.default, "*")); + &lst.join(", ") + }), ]; let mut msg: String = "Active config values:\n\n".to_owned(); diff --git a/src/config/proxy.rs b/src/config/proxy.rs index c03463e7..05762e40 100644 --- a/src/config/proxy.rs +++ b/src/config/proxy.rs @@ -2,6 +2,7 @@ use reqwest::{Proxy, Url}; use serde::Deserialize; use crate::Result; +use super::wild_carded_domain::WildCardedDomain; /// ## Examples: /// - No proxy (default): @@ -92,52 +93,3 @@ impl PartialProxyConfig { } } } - -/// A domain name, that optionally allows a * as its first subdomain. -#[derive(Clone, Debug)] -pub enum WildCardedDomain { - WildCard, - WildCarded(String), - Exact(String), -} -impl WildCardedDomain { - pub fn matches(&self, domain: &str) -> bool { - match self { - WildCardedDomain::WildCard => true, - WildCardedDomain::WildCarded(d) => domain.ends_with(d), - WildCardedDomain::Exact(d) => domain == d, - } - } - pub fn more_specific_than(&self, other: &Self) -> bool { - match (self, other) { - (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, - (_, WildCardedDomain::WildCard) => true, - (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), - (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { - a != b && a.ends_with(b) - } - _ => false, - } - } -} -impl std::str::FromStr for WildCardedDomain { - type Err = std::convert::Infallible; - fn from_str(s: &str) -> Result { - // maybe do some domain validation? - Ok(if s.starts_with("*.") { - WildCardedDomain::WildCarded(s[1..].to_owned()) - } else if s == "*" { - WildCardedDomain::WildCarded("".to_owned()) - } else { - WildCardedDomain::Exact(s.to_owned()) - }) - } -} -impl<'de> Deserialize<'de> for WildCardedDomain { - fn deserialize(deserializer: D) -> Result - where - D: serde::de::Deserializer<'de>, - { - crate::utils::deserialize_from_str(deserializer) - } -} diff --git a/src/config/wild_carded_domain.rs b/src/config/wild_carded_domain.rs new file mode 100644 index 00000000..9452f230 --- /dev/null +++ b/src/config/wild_carded_domain.rs @@ -0,0 +1,64 @@ +use serde::Deserialize; +use std::fmt; + +/// A domain name, that optionally allows a * as its first subdomain. +#[derive(Clone, Debug)] +pub enum WildCardedDomain { + WildCard, + WildCarded(String), + Exact(String), +} + +impl WildCardedDomain { + pub fn matches(&self, domain: &str) -> bool { + match self { + WildCardedDomain::WildCard => true, + WildCardedDomain::WildCarded(d) => domain.ends_with(d), + WildCardedDomain::Exact(d) => domain == d, + } + } + pub fn more_specific_than(&self, other: &Self) -> bool { + match (self, other) { + (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, + (_, WildCardedDomain::WildCard) => true, + (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), + (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { + a != b && a.ends_with(b) + } + _ => false, + } + } +} + +impl std::str::FromStr for WildCardedDomain { + type Err = std::convert::Infallible; + fn from_str(s: &str) -> Result { + // maybe do some domain validation? + Ok(if s.starts_with("*.") { + WildCardedDomain::WildCarded(s[1..].to_lowercase()) + } else if s == "*" { + WildCardedDomain::WildCarded("".to_lowercase()) + } else { + WildCardedDomain::Exact(s.to_lowercase()) + }) + } +} + +impl<'de> Deserialize<'de> for WildCardedDomain { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + crate::utils::deserialize_from_str(deserializer) + } +} + +impl fmt::Display for WildCardedDomain { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + WildCardedDomain::WildCard => write!(f, "*"), + WildCardedDomain::WildCarded(d) => write!(f, "*{d}"), + WildCardedDomain::Exact(d) => write!(f, "{d}"), + } + } +} diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 98c01902..88359afa 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -7,6 +7,7 @@ use ruma::{ use crate::api::server_server::DestinationResponse; +use crate::config::UrlPreviewConfig; use crate::{services, Config, Error, Result}; use futures_util::FutureExt; use hickory_resolver::TokioAsyncResolver; @@ -324,8 +325,8 @@ impl Service { self.config.allow_federation } - pub fn url_preview_allowlist(&self) -> impl Iterator { - self.config.url_preview_allowlist.iter().map(|x| x.as_str()) + pub fn url_previews(&self) -> &UrlPreviewConfig { + &self.config.url_previews } pub fn allow_room_creation(&self) -> bool {