diff --git a/Cargo.lock b/Cargo.lock index 20013bd5..39b50aa8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2687,6 +2687,7 @@ version = "1.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" dependencies = [ + "indexmap 2.2.5", "itoa", "ryu", "serde", diff --git a/Cargo.toml b/Cargo.toml index 66f6adbc..d5ce3e71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,7 @@ http = "1" # Used to find data directory for default db path directories = "5" # Used for ruma wrapper -serde_json = { version = "1.0.96", features = ["raw_value"] } +serde_json = { version = "1.0.96", features = ["raw_value", "preserve_order"] } # Used for appservice registration files serde_yaml = "0.9.21" # Used for pdu definition diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 3ffd3a5e..37cf28af 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -14,6 +14,7 @@ channel = "1.78.0" components = [ # For rust-analyzer "rust-src", + "rust-analyzer", ] targets = [ "aarch64-unknown-linux-musl", diff --git a/src/api/client_server/context.rs b/src/api/client_server/context.rs index a5edb5eb..c2c94111 100644 --- a/src/api/client_server/context.rs +++ b/src/api/client_server/context.rs @@ -1,7 +1,8 @@ -use crate::{services, Error, Result, Ruma}; +use crate::{services, utils::filter, Error, Result, Ruma}; use ruma::{ api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, events::StateEventType, + serde::Raw, }; use std::collections::HashSet; use tracing::error; @@ -78,7 +79,6 @@ pub async fn get_context_route( .rooms .timeline .pdus_until(sender_user, &room_id, base_token)? - .take(limit / 2) .filter_map(|r| r.ok()) // Remove buggy events .filter(|(_, pdu)| { services() @@ -109,13 +109,33 @@ pub async fn get_context_route( let events_before: Vec<_> = events_before .into_iter() .map(|(_, pdu)| pdu.to_room_event()) + .filter(|v| { + filter::senders( + v, + body.filter.senders.as_ref(), + body.filter.not_senders.as_ref(), + ) + }) + .filter(|v| { + filter::types( + v, + body.filter.types.as_ref(), + body.filter.not_types.as_ref(), + ) + }) + .filter(|v| { + body.filter + .url_filter + .map(|f| filter::url(v, &room_id, f)) + .unwrap_or(true) + }) + .take(limit / 2) .collect(); let events_after: Vec<_> = services() .rooms .timeline .pdus_after(sender_user, &room_id, base_token)? - .take(limit / 2) .filter_map(|r| r.ok()) // Remove buggy events .filter(|(_, pdu)| { services() @@ -165,6 +185,27 @@ pub async fn get_context_route( let events_after: Vec<_> = events_after .into_iter() .map(|(_, pdu)| pdu.to_room_event()) + .filter(|v| { + filter::senders( + v, + body.filter.senders.as_ref(), + body.filter.not_senders.as_ref(), + ) + }) + .filter(|v| { + filter::types( + v, + body.filter.types.as_ref(), + body.filter.not_types.as_ref(), + ) + }) + .filter(|v| { + body.filter + .url_filter + .map(|f| filter::url(v, &room_id, f)) + .unwrap_or(true) + }) + .take(limit / 2) .collect(); let mut state = Vec::new(); @@ -175,24 +216,34 @@ pub async fn get_context_route( .short .get_statekey_from_short(shortstatekey)?; - if event_type != StateEventType::RoomMember { - let pdu = match services().rooms.timeline.get_pdu(&id)? { - Some(pdu) => pdu, + if !filter::types( + &Raw::new(&serde_json::json!({"type": event_type})).expect("json can be serialized"), + body.filter.types.as_ref(), + body.filter.not_types.as_ref(), + ) { + continue; + } + + if event_type != StateEventType::RoomMember + || (!lazy_load_enabled || lazy_loaded.contains(&state_key)) + { + let event = match services().rooms.timeline.get_pdu(&id)? { + Some(pdu) => pdu.to_state_event(), None => { error!("Pdu in state not found: {}", id); continue; } }; - state.push(pdu.to_state_event()); - } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let pdu = match services().rooms.timeline.get_pdu(&id)? { - Some(pdu) => pdu, - None => { - error!("Pdu in state not found: {}", id); - continue; - } - }; - state.push(pdu.to_state_event()); + + if !filter::senders( + &event, + body.filter.senders.as_ref(), + body.filter.not_senders.as_ref(), + ) { + continue; + } + + state.push(event); } } diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index 57ceec3b..564f50a9 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -1,11 +1,13 @@ use crate::{ service::{pdu::EventHash, rooms::timeline::PduCount}, - services, utils, Error, PduEvent, Result, Ruma, RumaResponse, + services, + utils::{self, filter}, + Error, PduEvent, Result, Ruma, RumaResponse, }; use ruma::{ api::client::{ - filter::{FilterDefinition, LazyLoadOptions}, + filter::{EventFormat, FilterDefinition, LazyLoadOptions, RoomFilter}, sync::sync_events::{ self, v3::{ @@ -194,6 +196,8 @@ async fn sync_helper( .unwrap_or_default(), }; + let event_fields = filter.event_fields.as_ref(); + let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options { LazyLoadOptions::Enabled { include_redundant_members: redundant, @@ -231,6 +235,15 @@ async fn sync_helper( .collect::>(); for room_id in all_joined_rooms { let room_id = room_id?; + + if filter::rooms( + &Raw::new(&serde_json::json!({"room_id": room_id})).expect("json can be serialized"), + filter.room.rooms.as_ref(), + filter.room.not_rooms.as_ref(), + ) { + continue; + } + if let Ok(joined_room) = load_joined_room( &sender_user, &sender_device, @@ -241,6 +254,9 @@ async fn sync_helper( next_batchcount, lazy_load_enabled, lazy_load_send_redundant, + &filter.room, + event_fields, + &filter.event_format, full_state, &mut device_list_updates, &mut left_encrypted_users, @@ -289,11 +305,14 @@ async fn sync_helper( } let mut left_rooms = BTreeMap::new(); - let all_left_rooms: Vec<_> = services() - .rooms - .state_cache - .rooms_left(&sender_user) - .collect(); + let all_left_rooms = match filter.room.include_leave { + false => Vec::with_capacity(0), + true => services() + .rooms + .state_cache + .rooms_left(&sender_user) + .collect(), + }; for result in all_left_rooms { let (room_id, _) = result?; @@ -541,6 +560,7 @@ async fn sync_helper( knock: BTreeMap::new(), // TODO }, presence: Presence { + // HashMap events: presence_updates .into_values() .map(|v| Raw::new(&v).expect("PresenceEvent always serializes successfully")) @@ -549,6 +569,7 @@ async fn sync_helper( account_data: GlobalAccountData { events: services() .account_data + // HashMap> .changes_since(None, &sender_user, since)? .into_iter() .filter_map(|(_, v)| { @@ -606,6 +627,9 @@ async fn load_joined_room( next_batchcount: PduCount, lazy_load_enabled: bool, lazy_load_send_redundant: bool, + filter: &RoomFilter, + _event_fields: Option<&Vec>, + event_format: &EventFormat, full_state: bool, device_list_updates: &mut HashSet, left_encrypted_users: &mut HashSet, @@ -1081,7 +1105,33 @@ async fn load_joined_room( let room_events: Vec<_> = timeline_pdus .iter() - .map(|(_, pdu)| pdu.to_sync_room_event()) + .map(|(_, pdu)| match event_format { + EventFormat::Federation => Raw::new(pdu) + .map(Raw::cast) + .expect("json can be serialized"), + _ => pdu.to_sync_room_event(), + }) + .filter(|v| { + filter + .timeline + .url_filter + .map(|f| filter::url(v, room_id, f)) + .unwrap_or(true) + }) + .filter(|v| { + filter::senders( + v, + filter.timeline.senders.as_ref(), + filter.timeline.not_senders.as_ref(), + ) + }) + .filter(|v| { + filter::types( + v, + filter.timeline.types.as_ref(), + filter.timeline.not_types.as_ref(), + ) + }) .collect(); let mut edus: Vec<_> = services() @@ -1090,7 +1140,23 @@ async fn load_joined_room( .read_receipt .readreceipts_since(room_id, since) .filter_map(|r| r.ok()) // Filter out buggy events + // TODO .map(|(_, _, v)| v) + .filter(|v| { + filter::senders( + v, + filter.ephemeral.senders.as_ref(), + filter.ephemeral.not_senders.as_ref(), + ) + }) + .filter(|v| { + filter::types( + v, + filter.ephemeral.types.as_ref(), + filter.ephemeral.not_types.as_ref(), + ) + }) + .take(filter.ephemeral.limit.map_or(10, u64::from).min(100) as usize) .collect(); if services() @@ -1123,11 +1189,15 @@ async fn load_joined_room( .account_data .changes_since(Some(room_id), sender_user, since)? .into_iter() - .filter_map(|(_, v)| { - serde_json::from_str(v.json().get()) - .map_err(|_| Error::bad_database("Invalid account event in database.")) - .ok() + .filter_map(|(_, v)| serde_json::from_str(v.json().get()).ok()) + .filter(|v| { + filter::types( + v, + filter.account_data.types.as_ref(), + filter.account_data.not_types.as_ref(), + ) }) + .take(filter.account_data.limit.map_or(10, u64::from).min(100) as usize) .collect(), }, summary: RoomSummary { @@ -1142,12 +1212,33 @@ async fn load_joined_room( timeline: Timeline { limited: limited || joined_since_last_sync, prev_batch, + // Vec> events: room_events, }, state: State { events: state_events .iter() - .map(|pdu| pdu.to_sync_state_event()) + .map(|pdu| match event_format { + EventFormat::Federation => Raw::new(pdu) + .map(Raw::cast) + .expect("json can be serialized"), + _ => pdu.to_sync_state_event(), + }) + .filter(|v| { + filter::senders( + v, + filter.state.senders.as_ref(), + filter.state.not_senders.as_ref(), + ) + }) + .filter(|v| { + filter::types( + v, + filter.state.types.as_ref(), + filter.state.not_types.as_ref(), + ) + }) + .take(filter.state.limit.map_or(10, u64::from).min(100) as usize) .collect(), }, ephemeral: Ephemeral { events: edus }, diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs index 970b36b5..b6d1fc26 100644 --- a/src/database/key_value/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use ruma::{ api::client::error::ErrorKind, - events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, + events::{AnyRoomAccountDataEvent, RoomAccountDataEventType}, serde::Raw, RoomId, UserId, }; @@ -101,7 +101,7 @@ impl service::account_data::Data for KeyValueDatabase { room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result>> { + ) -> Result>> { let mut userdata = HashMap::new(); let mut prefix = room_id @@ -129,7 +129,7 @@ impl service::account_data::Data for KeyValueDatabase { )?) .map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?, ), - serde_json::from_slice::>(&v).map_err(|_| { + serde_json::from_slice::>(&v).map_err(|_| { Error::bad_database("Database contains invalid account data.") })?, )) diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index ad573f06..1c347d76 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -1,13 +1,23 @@ +use std::str::FromStr; + use ruma::RoomId; +use url::Url; use crate::{database::KeyValueDatabase, service, services, utils, Result}; impl service::rooms::search::Data for KeyValueDatabase { fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { - let mut batch = message_body + let mut contains_url = false; + + let mut token_batch = message_body .split_terminator(|c: char| !c.is_alphanumeric()) .filter(|s| !s.is_empty()) - .filter(|word| word.len() <= 50) + .filter(|word| { + contains_url = + contains_url || (word.starts_with("http") && Url::from_str(word).is_ok()); + + word.len() <= 50 + }) .map(str::to_lowercase) .map(|word| { let mut key = shortroomid.to_be_bytes().to_vec(); @@ -17,7 +27,17 @@ impl service::rooms::search::Data for KeyValueDatabase { (key, Vec::new()) }); - self.tokenids.insert_batch(&mut batch) + self.tokenids.insert_batch(&mut token_batch)?; + + if contains_url { + let mut key = shortroomid.to_be_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(pdu_id); + + self.urltokenids.insert(&key, Default::default())?; + } + + Ok(()) } fn search_pdus<'a>( @@ -64,4 +84,18 @@ impl service::rooms::search::Data for KeyValueDatabase { Ok(Some((Box::new(common_elements), words))) } + + fn contains_url(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result { + let prefix = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists"); + + let mut key = prefix.to_be_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(pdu_id); + + self.urltokenids.get(&key).map(|v| v.is_some()) + } } diff --git a/src/database/mod.rs b/src/database/mod.rs index 1b178bd5..b2a21ff7 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -86,6 +86,7 @@ pub struct KeyValueDatabase { pub(super) threadid_userids: Arc, // ThreadId = RoomId + Count pub(super) tokenids: Arc, // TokenId = ShortRoomId + Token + PduIdCount + pub(super) urltokenids: Arc, // useful for `RoomEventFilter::contains_url` /// Participating servers in a room. pub(super) roomserverids: Arc, // RoomServerId = RoomId + ServerName @@ -314,6 +315,7 @@ impl KeyValueDatabase { threadid_userids: builder.open_tree("threadid_userids")?, tokenids: builder.open_tree("tokenids")?, + urltokenids: builder.open_tree("urltokenids")?, roomserverids: builder.open_tree("roomserverids")?, serverroomids: builder.open_tree("serverroomids")?, diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs index c7c92981..c6eb23c0 100644 --- a/src/service/account_data/data.rs +++ b/src/service/account_data/data.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use crate::Result; use ruma::{ - events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, + events::{AnyRoomAccountDataEvent, RoomAccountDataEventType}, serde::Raw, RoomId, UserId, }; @@ -31,5 +31,5 @@ pub trait Data: Send + Sync { room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result>>; + ) -> Result>>; } diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index f9c49b1a..7c3a9549 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -3,7 +3,7 @@ mod data; pub use data::Data; use ruma::{ - events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, + events::{AnyRoomAccountDataEvent, RoomAccountDataEventType}, serde::Raw, RoomId, UserId, }; @@ -47,7 +47,7 @@ impl Service { room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result>> { + ) -> Result>> { self.db.changes_since(room_id, user_id, since) } } diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 7ea7e3d1..0ec480b1 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -10,4 +10,6 @@ pub trait Data: Send + Sync { room_id: &RoomId, search_string: &str, ) -> Result> + 'a>, Vec)>>; + + fn contains_url(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result; } diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index b6f35e79..39a97e43 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -23,4 +23,9 @@ impl Service { ) -> Result> + 'a, Vec)>> { self.db.search_pdus(room_id, search_string) } + + #[tracing::instrument(skip(self))] + pub fn contains_url(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result { + self.db.contains_url(room_id, pdu_id) + } } diff --git a/src/utils/filter.rs b/src/utils/filter.rs new file mode 100644 index 00000000..64350aff --- /dev/null +++ b/src/utils/filter.rs @@ -0,0 +1,65 @@ +use ruma::{ + api::client::filter::UrlFilter, serde::Raw, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, +}; +use serde::Deserialize; + +use crate::services; + +fn inclusion Deserialize<'a> + PartialEq>( + event: &Raw, + field: &str, + include: Option<&Vec>, + exclude: &[F], +) -> bool { + let value = event + .get_field::(field) + .expect("room events should deserialize") + .expect("field should exist"); + + include + .map(|v| v.iter().any(|item| *item == value)) + .unwrap_or(true) + && exclude.iter().all(|item| *item != value) +} + +pub fn rooms( + event: &Raw, + rooms: Option<&Vec>, + not_rooms: &[OwnedRoomId], +) -> bool { + inclusion(event, "room_id", rooms, not_rooms) +} + +pub fn senders( + event: &Raw, + senders: Option<&Vec>, + not_senders: &[OwnedUserId], +) -> bool { + inclusion(event, "sender", senders, not_senders) +} + +pub fn types(event: &Raw, types: Option<&Vec>, not_types: &[String]) -> bool { + inclusion(event, "event_type", types, not_types) +} + +pub fn url(event: &Raw, room_id: &RoomId, filter: UrlFilter) -> bool { + let Ok(Some(pdu_id)) = services().rooms.timeline.get_pdu_id( + &event + .get_field::("event_id") + .expect("event_id can be deserialized") + .expect("event should have event_id"), + ) else { + return filter == UrlFilter::EventsWithoutUrl; + }; + + let contains_url = services() + .rooms + .search + .contains_url(room_id, &pdu_id) + .unwrap_or(false); + + match filter { + UrlFilter::EventsWithUrl => contains_url, + UrlFilter::EventsWithoutUrl => !contains_url, + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index d09a1033..6ca6a2e3 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,4 +1,5 @@ pub mod error; +pub mod filter; use argon2::{Config, Variant}; use cmp::Ordering;