diff --git a/Cargo.lock b/Cargo.lock index 8453335a..abfbda74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2441,6 +2441,7 @@ version = "1.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" dependencies = [ + "indexmap 2.2.5", "itoa", "ryu", "serde", diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 3ffd3a5e..37cf28af 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -14,6 +14,7 @@ channel = "1.78.0" components = [ # For rust-analyzer "rust-src", + "rust-analyzer", ] targets = [ "aarch64-unknown-linux-musl", diff --git a/src/api/client_server/context.rs b/src/api/client_server/context.rs index 8e193e6b..c4e4c732 100644 --- a/src/api/client_server/context.rs +++ b/src/api/client_server/context.rs @@ -1,7 +1,8 @@ -use crate::{services, Error, Result, Ruma}; +use crate::{services, utils::filter, Error, Result, Ruma}; use ruma::{ api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, events::StateEventType, + serde::Raw, }; use std::collections::HashSet; use tracing::error; @@ -78,7 +79,6 @@ pub async fn get_context_route( .rooms .timeline .pdus_until(sender_user, &room_id, base_token)? - .take(limit / 2) .filter_map(|r| r.ok()) // Remove buggy events .filter(|(_, pdu)| { services() @@ -109,13 +109,33 @@ pub async fn get_context_route( let events_before: Vec<_> = events_before .into_iter() .map(|(_, pdu)| pdu.to_room_event()) + .filter(|v| { + filter::senders( + v, + body.filter.senders.as_ref(), + body.filter.not_senders.as_ref(), + ) + }) + .filter(|v| { + filter::types( + v, + body.filter.types.as_ref(), + body.filter.not_types.as_ref(), + ) + }) + .filter(|v| { + body.filter + .url_filter + .map(|f| filter::url(v, &room_id, f)) + .unwrap_or(true) + }) + .take(limit / 2) .collect(); let events_after: Vec<_> = services() .rooms .timeline .pdus_after(sender_user, &room_id, base_token)? - .take(limit / 2) .filter_map(|r| r.ok()) // Remove buggy events .filter(|(_, pdu)| { services() @@ -165,6 +185,27 @@ pub async fn get_context_route( let events_after: Vec<_> = events_after .into_iter() .map(|(_, pdu)| pdu.to_room_event()) + .filter(|v| { + filter::senders( + v, + body.filter.senders.as_ref(), + body.filter.not_senders.as_ref(), + ) + }) + .filter(|v| { + filter::types( + v, + body.filter.types.as_ref(), + body.filter.not_types.as_ref(), + ) + }) + .filter(|v| { + body.filter + .url_filter + .map(|f| filter::url(v, &room_id, f)) + .unwrap_or(true) + }) + .take(limit / 2) .collect(); let mut state = Vec::new(); @@ -175,24 +216,34 @@ pub async fn get_context_route( .short .get_statekey_from_short(shortstatekey)?; - if event_type != StateEventType::RoomMember { - let pdu = match services().rooms.timeline.get_pdu(&id)? { - Some(pdu) => pdu, + if !filter::types( + &Raw::new(&serde_json::json!({"type": event_type})).expect("json can be serialized"), + body.filter.types.as_ref(), + body.filter.not_types.as_ref(), + ) { + continue; + } + + if event_type != StateEventType::RoomMember + || (!lazy_load_enabled || lazy_loaded.contains(&state_key)) + { + let event = match services().rooms.timeline.get_pdu(&id)? { + Some(pdu) => pdu.to_state_event(), None => { error!("Pdu in state not found: {}", id); continue; } }; - state.push(pdu.to_state_event()); - } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let pdu = match services().rooms.timeline.get_pdu(&id)? { - Some(pdu) => pdu, - None => { - error!("Pdu in state not found: {}", id); - continue; - } - }; - state.push(pdu.to_state_event()); + + if !filter::senders( + &event, + body.filter.senders.as_ref(), + body.filter.not_senders.as_ref(), + ) { + continue; + } + + state.push(event); } } diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index 8c95a3ef..59026a70 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -7,7 +7,7 @@ use crate::{ use ruma::{ api::client::{ - filter::{EventFormat, FilterDefinition, LazyLoadOptions, RoomFilter, UrlFilter}, + filter::{EventFormat, FilterDefinition, LazyLoadOptions, RoomFilter}, sync::sync_events::{ self, v3::{ @@ -1111,9 +1111,12 @@ async fn load_joined_room( .expect("json can be serialized"), _ => pdu.to_sync_room_event(), }) - .filter(|v| match filter.timeline.url_filter.unwrap_or(true) { - UrlFilter::EventsWithUrl => todo!(), - UrlFilter::EventsWithoutUrl => todo!(), + .filter(|v| { + filter + .timeline + .url_filter + .map(|f| filter::url(v, room_id, f)) + .unwrap_or(true) }) .filter(|v| { filter::senders( diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index c2bdfa8c..1c347d76 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -34,7 +34,7 @@ impl service::rooms::search::Data for KeyValueDatabase { key.push(0xff); key.extend_from_slice(pdu_id); - self.urltokenids.insert(&key, <&[u8]>::default())?; + self.urltokenids.insert(&key, Default::default())?; } Ok(()) @@ -85,15 +85,17 @@ impl service::rooms::search::Data for KeyValueDatabase { Ok(Some((Box::new(common_elements), words))) } - fn contains_url<'a>(&'a self, room_id: &RoomId, pdu_id: &[u8]) -> Result { + fn contains_url(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result { let prefix = services() .rooms .short .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); + .expect("room exists"); - todo!() + let mut key = prefix.to_be_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(pdu_id); + + self.urltokenids.get(&key).map(|v| v.is_some()) } } diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 7d046fa5..0ec480b1 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -11,5 +11,5 @@ pub trait Data: Send + Sync { search_string: &str, ) -> Result> + 'a>, Vec)>>; - fn contains_url<'a>(&'a self, room_id: &RoomId, pdu_id: &[u8]) -> Result; + fn contains_url(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result; } diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index b6f35e79..39a97e43 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -23,4 +23,9 @@ impl Service { ) -> Result> + 'a, Vec)>> { self.db.search_pdus(room_id, search_string) } + + #[tracing::instrument(skip(self))] + pub fn contains_url(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result { + self.db.contains_url(room_id, pdu_id) + } } diff --git a/src/utils/filter.rs b/src/utils/filter.rs index 22740c80..64350aff 100644 --- a/src/utils/filter.rs +++ b/src/utils/filter.rs @@ -1,6 +1,10 @@ -use ruma::{serde::Raw, OwnedRoomId, OwnedUserId}; +use ruma::{ + api::client::filter::UrlFilter, serde::Raw, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, +}; use serde::Deserialize; +use crate::services; + fn inclusion Deserialize<'a> + PartialEq>( event: &Raw, field: &str, @@ -37,3 +41,25 @@ pub fn senders( pub fn types(event: &Raw, types: Option<&Vec>, not_types: &[String]) -> bool { inclusion(event, "event_type", types, not_types) } + +pub fn url(event: &Raw, room_id: &RoomId, filter: UrlFilter) -> bool { + let Ok(Some(pdu_id)) = services().rooms.timeline.get_pdu_id( + &event + .get_field::("event_id") + .expect("event_id can be deserialized") + .expect("event should have event_id"), + ) else { + return filter == UrlFilter::EventsWithoutUrl; + }; + + let contains_url = services() + .rooms + .search + .contains_url(room_id, &pdu_id) + .unwrap_or(false); + + match filter { + UrlFilter::EventsWithUrl => contains_url, + UrlFilter::EventsWithoutUrl => !contains_url, + } +}