1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-07-02 16:38:36 +00:00

Change URL preview setting from bool to a mode, and add support for an allowlist

This commit is contained in:
Reiner Herrmann 2023-07-29 00:38:58 +02:00
parent bb4cade9fd
commit 61fd9166f6
5 changed files with 90 additions and 26 deletions

View file

@ -11,9 +11,11 @@ use ruma::api::client::{
#[cfg(feature = "url_preview")]
use {
crate::config::UrlPreviewMode,
crate::service::media::UrlPreviewData,
webpage::HTML,
std::{io::Cursor, net::IpAddr, sync::Arc, time::Duration},
reqwest::Url,
std::{io::Cursor, net::IpAddr, sync::Arc},
tokio::sync::Notify,
image::io::Reader as ImgReader,
};
@ -123,9 +125,9 @@ fn url_request_allowed(addr: &IpAddr) -> bool {
}
#[cfg(feature = "url_preview")]
async fn request_url_preview(url: String) -> Result<UrlPreviewData> {
async fn request_url_preview(url: &str) -> Result<UrlPreviewData> {
let client = services().globals.default_client();
let response = client.head(&url).send().await?;
let response = client.head(url).send().await?;
if !response
.remote_addr()
@ -151,8 +153,8 @@ async fn request_url_preview(url: String) -> Result<UrlPreviewData> {
}
};
let data = match content_type {
html if html.starts_with("text/html") => download_html(&client, &url).await?,
img if img.starts_with("image/") => download_image(&client, &url).await?,
html if html.starts_with("text/html") => download_html(&client, url).await?,
img if img.starts_with("image/") => download_image(&client, url).await?,
_ => {
return Err(Error::BadRequest(
ErrorKind::Unknown,
@ -161,14 +163,14 @@ async fn request_url_preview(url: String) -> Result<UrlPreviewData> {
}
};
services().media.set_url_preview(&url, &data).await?;
services().media.set_url_preview(url, &data).await?;
Ok(data)
}
#[cfg(feature = "url_preview")]
async fn get_url_preview(url: String) -> Result<UrlPreviewData> {
if let Some(preview) = services().media.get_url_preview(&url).await {
async fn get_url_preview(url: &str) -> Result<UrlPreviewData> {
if let Some(preview) = services().media.get_url_preview(url).await {
return Ok(preview);
}
@ -177,7 +179,7 @@ async fn get_url_preview(url: String) -> Result<UrlPreviewData> {
.url_preview_requests
.read()
.unwrap()
.get(&url)
.get(url)
.cloned();
match notif_opt {
@ -188,15 +190,15 @@ async fn get_url_preview(url: String) -> Result<UrlPreviewData> {
.url_preview_requests
.write()
.unwrap()
.insert(url.clone(), notifier.clone());
.insert(url.to_string(), notifier.clone());
}
let data = request_url_preview(url.clone()).await;
let data = request_url_preview(url).await;
notifier.notify_waiters();
{
services().media.url_preview_requests.write().unwrap().remove(&url);
services().media.url_preview_requests.write().unwrap().remove(url);
}
data
@ -208,7 +210,7 @@ async fn get_url_preview(url: String) -> Result<UrlPreviewData> {
notifier.await;
services().media
.get_url_preview(&url)
.get_url_preview(url)
.await
.ok_or(Error::BadRequest(
ErrorKind::Unknown,
@ -218,6 +220,29 @@ async fn get_url_preview(url: String) -> Result<UrlPreviewData> {
}
}
#[cfg(feature = "url_preview")]
fn url_preview_allowed(url_str: &str) -> bool {
let url = match Url::parse(url_str) {
Ok(u) => u,
Err(_) => return false,
};
if ["http", "https"].iter().all(|&scheme| scheme != url.scheme().to_lowercase()) {
return false;
}
match services().globals.url_preview_mode() {
UrlPreviewMode::All => true,
UrlPreviewMode::None => false,
UrlPreviewMode::Allowlist => {
match url.host_str() {
None => false,
Some(host) => {
services().globals.url_preview_allowlist().contains(&host.to_string())
}
}
}
}
}
/// # `GET /_matrix/media/r0/preview_url`
///
/// Returns URL preview.
@ -225,14 +250,15 @@ async fn get_url_preview(url: String) -> Result<UrlPreviewData> {
pub async fn get_media_preview_route(
body: Ruma<get_media_preview::v3::Request>,
) -> Result<get_media_preview::v3::Response> {
if !services().globals.allow_url_preview() {
let url = &body.url;
if !url_preview_allowed(url) {
return Err(Error::BadRequest(
ErrorKind::Unknown,
"Previewing URL not allowed",
));
}
if let Ok(preview) = get_url_preview(body.url.clone()).await {
if let Ok(preview) = get_url_preview(url).await {
let res = serde_json::value::to_raw_value(&preview).expect("Converting to JSON failed");
return Ok(get_media_preview::v3::Response::from_raw_value(res));
}