From 839498ada74bacba73347c109321531874eb488f Mon Sep 17 00:00:00 2001 From: Steven Vergenz <1882376+stevenvergenz@users.noreply.github.com> Date: Wed, 30 Oct 2024 16:48:38 -0700 Subject: [PATCH] Parse URL only once --- src/api/client_server/media.rs | 74 ++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index 77db7f83..fcbbdd15 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -174,20 +174,22 @@ fn url_request_allowed(addr: &IpAddr) -> bool { } } -async fn request_url_preview(url: &str) -> Result { - let client = services().globals.default_client(); - let response = client.head(url).send().await?; - - if !response - .remote_addr() - .map_or(false, |a| url_request_allowed(&a.ip())) - { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Requesting from this address forbidden", - )); +async fn request_url_preview(url: &Url) -> Result { + // resolve host to IP to ensure it's not a local IP (host guaranteed to not be None) + let dns_resolver = services().globals.dns_resolver(); + match dns_resolver.lookup_ip(url.host_str().unwrap()).await { + Err(_) => { + return Err(Error::BadServerResponse("Failed to resolve media preview host")); + }, + Ok(lookup) if lookup.iter().any(|ip| !url_request_allowed(&ip)) => { + return Err(Error::BadRequest(ErrorKind::Unknown, "Requesting from this address forbidden")); + }, + Ok(_) => { }, } + let client = services().globals.default_client(); + let response = client.head(url.as_str()).send().await?; + let content_type = match response .headers() .get(reqwest::header::CONTENT_TYPE) @@ -202,8 +204,8 @@ async fn request_url_preview(url: &str) -> Result { } }; let data = match content_type { - html if html.starts_with("text/html") => download_html(&client, url).await?, - img if img.starts_with("image/") => download_image(&client, url).await?, + html if html.starts_with("text/html") => download_html(&client, url.as_str()).await?, + img if img.starts_with("image/") => download_image(&client, url.as_str()).await?, _ => { return Err(Error::BadRequest( ErrorKind::Unknown, @@ -212,13 +214,13 @@ async fn request_url_preview(url: &str) -> Result { } }; - services().media.set_url_preview(url, &data).await?; + services().media.set_url_preview(url.as_str(), &data).await?; Ok(data) } -async fn get_url_preview(url: &str) -> Result { - if let Some(preview) = services().media.get_url_preview(url).await { +async fn get_url_preview(url: &Url) -> Result { + if let Some(preview) = services().media.get_url_preview(url.as_str()).await { return Ok(preview); } @@ -229,18 +231,18 @@ async fn get_url_preview(url: &str) -> Result { .url_preview_mutex .write() .unwrap() - .entry(url.to_owned()) + .entry(url.as_str().to_owned()) .or_default(), ); let _request_lock = mutex_request.lock().await; - match services().media.get_url_preview(url).await { + match services().media.get_url_preview(url.as_str()).await { Some(preview) => Ok(preview), None => request_url_preview(url).await } } -fn url_preview_allowed(url_str: &str) -> bool { +fn url_preview_allowed(url: &Url) -> bool { const DEFAULT_ALLOWLIST: &[&str] = &[ "matrix.org", "mastodon.social", @@ -248,13 +250,6 @@ fn url_preview_allowed(url_str: &str) -> bool { "wikipedia.org", ]; - let url = match Url::parse(url_str) { - Ok(u) => u, - Err(_) => return false, - }; - if ["http", "https"].iter().all(|&scheme| scheme != url.scheme().to_lowercase()) { - return false; - } let mut host = match url.host_str() { None => return false, Some(h) => h.to_lowercase(), @@ -286,15 +281,34 @@ fn url_preview_allowed(url_str: &str) -> bool { pub async fn get_media_preview_route( body: Ruma, ) -> Result { - let url = &body.url; - if !url_preview_allowed(url) { + let url = match Url::parse(&body.url) { + Err(_) => { + return Err(Error::BadRequest( + ErrorKind::Unknown, + "Not a valid URL", + )); + }, + Ok(u) + if u.scheme() != "http" + && u.scheme() != "https" + || u.host().is_none() + => { + return Err(Error::BadRequest( + ErrorKind::Unknown, + "Not a valid HTTP URL", + )); + }, + Ok(url) => url, + }; + + if !url_preview_allowed(&url) { return Err(Error::BadRequest( ErrorKind::Unknown, "Previewing URL not allowed", )); } - if let Ok(preview) = get_url_preview(url).await { + if let Ok(preview) = get_url_preview(&url).await { let res = serde_json::value::to_raw_value(&preview).expect("Converting to JSON failed"); return Ok(get_media_preview::v3::Response::from_raw_value(res)); }