1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-09-15 18:57:03 +00:00

Merge branch 'search-filters' into 'next'

Draft: Filters

See merge request famedly/conduit!674
This commit is contained in:
avdb 2024-06-12 08:18:56 +00:00
commit 9c87941459
14 changed files with 293 additions and 40 deletions

1
Cargo.lock generated
View file

@ -2687,6 +2687,7 @@ version = "1.0.117"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3"
dependencies = [ dependencies = [
"indexmap 2.2.5",
"itoa", "itoa",
"ryu", "ryu",
"serde", "serde",

View file

@ -60,7 +60,7 @@ http = "1"
# Used to find data directory for default db path # Used to find data directory for default db path
directories = "5" directories = "5"
# Used for ruma wrapper # Used for ruma wrapper
serde_json = { version = "1.0.96", features = ["raw_value"] } serde_json = { version = "1.0.96", features = ["raw_value", "preserve_order"] }
# Used for appservice registration files # Used for appservice registration files
serde_yaml = "0.9.21" serde_yaml = "0.9.21"
# Used for pdu definition # Used for pdu definition

View file

@ -14,6 +14,7 @@ channel = "1.78.0"
components = [ components = [
# For rust-analyzer # For rust-analyzer
"rust-src", "rust-src",
"rust-analyzer",
] ]
targets = [ targets = [
"aarch64-unknown-linux-musl", "aarch64-unknown-linux-musl",

View file

@ -1,7 +1,8 @@
use crate::{services, Error, Result, Ruma}; use crate::{services, utils::filter, Error, Result, Ruma};
use ruma::{ use ruma::{
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
events::StateEventType, events::StateEventType,
serde::Raw,
}; };
use std::collections::HashSet; use std::collections::HashSet;
use tracing::error; use tracing::error;
@ -78,7 +79,6 @@ pub async fn get_context_route(
.rooms .rooms
.timeline .timeline
.pdus_until(sender_user, &room_id, base_token)? .pdus_until(sender_user, &room_id, base_token)?
.take(limit / 2)
.filter_map(|r| r.ok()) // Remove buggy events .filter_map(|r| r.ok()) // Remove buggy events
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services()
@ -109,13 +109,33 @@ pub async fn get_context_route(
let events_before: Vec<_> = events_before let events_before: Vec<_> = events_before
.into_iter() .into_iter()
.map(|(_, pdu)| pdu.to_room_event()) .map(|(_, pdu)| pdu.to_room_event())
.filter(|v| {
filter::senders(
v,
body.filter.senders.as_ref(),
body.filter.not_senders.as_ref(),
)
})
.filter(|v| {
filter::types(
v,
body.filter.types.as_ref(),
body.filter.not_types.as_ref(),
)
})
.filter(|v| {
body.filter
.url_filter
.map(|f| filter::url(v, &room_id, f))
.unwrap_or(true)
})
.take(limit / 2)
.collect(); .collect();
let events_after: Vec<_> = services() let events_after: Vec<_> = services()
.rooms .rooms
.timeline .timeline
.pdus_after(sender_user, &room_id, base_token)? .pdus_after(sender_user, &room_id, base_token)?
.take(limit / 2)
.filter_map(|r| r.ok()) // Remove buggy events .filter_map(|r| r.ok()) // Remove buggy events
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services()
@ -165,6 +185,27 @@ pub async fn get_context_route(
let events_after: Vec<_> = events_after let events_after: Vec<_> = events_after
.into_iter() .into_iter()
.map(|(_, pdu)| pdu.to_room_event()) .map(|(_, pdu)| pdu.to_room_event())
.filter(|v| {
filter::senders(
v,
body.filter.senders.as_ref(),
body.filter.not_senders.as_ref(),
)
})
.filter(|v| {
filter::types(
v,
body.filter.types.as_ref(),
body.filter.not_types.as_ref(),
)
})
.filter(|v| {
body.filter
.url_filter
.map(|f| filter::url(v, &room_id, f))
.unwrap_or(true)
})
.take(limit / 2)
.collect(); .collect();
let mut state = Vec::new(); let mut state = Vec::new();
@ -175,24 +216,34 @@ pub async fn get_context_route(
.short .short
.get_statekey_from_short(shortstatekey)?; .get_statekey_from_short(shortstatekey)?;
if event_type != StateEventType::RoomMember { if !filter::types(
let pdu = match services().rooms.timeline.get_pdu(&id)? { &Raw::new(&serde_json::json!({"type": event_type})).expect("json can be serialized"),
Some(pdu) => pdu, body.filter.types.as_ref(),
body.filter.not_types.as_ref(),
) {
continue;
}
if event_type != StateEventType::RoomMember
|| (!lazy_load_enabled || lazy_loaded.contains(&state_key))
{
let event = match services().rooms.timeline.get_pdu(&id)? {
Some(pdu) => pdu.to_state_event(),
None => { None => {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
continue; continue;
} }
}; };
state.push(pdu.to_state_event());
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { if !filter::senders(
let pdu = match services().rooms.timeline.get_pdu(&id)? { &event,
Some(pdu) => pdu, body.filter.senders.as_ref(),
None => { body.filter.not_senders.as_ref(),
error!("Pdu in state not found: {}", id); ) {
continue; continue;
} }
};
state.push(pdu.to_state_event()); state.push(event);
} }
} }

View file

@ -1,11 +1,13 @@
use crate::{ use crate::{
service::{pdu::EventHash, rooms::timeline::PduCount}, service::{pdu::EventHash, rooms::timeline::PduCount},
services, utils, Error, PduEvent, Result, Ruma, RumaResponse, services,
utils::{self, filter},
Error, PduEvent, Result, Ruma, RumaResponse,
}; };
use ruma::{ use ruma::{
api::client::{ api::client::{
filter::{FilterDefinition, LazyLoadOptions}, filter::{EventFormat, FilterDefinition, LazyLoadOptions, RoomFilter},
sync::sync_events::{ sync::sync_events::{
self, self,
v3::{ v3::{
@ -194,6 +196,8 @@ async fn sync_helper(
.unwrap_or_default(), .unwrap_or_default(),
}; };
let event_fields = filter.event_fields.as_ref();
let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options { let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options {
LazyLoadOptions::Enabled { LazyLoadOptions::Enabled {
include_redundant_members: redundant, include_redundant_members: redundant,
@ -231,6 +235,15 @@ async fn sync_helper(
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for room_id in all_joined_rooms { for room_id in all_joined_rooms {
let room_id = room_id?; let room_id = room_id?;
if filter::rooms(
&Raw::new(&serde_json::json!({"room_id": room_id})).expect("json can be serialized"),
filter.room.rooms.as_ref(),
filter.room.not_rooms.as_ref(),
) {
continue;
}
if let Ok(joined_room) = load_joined_room( if let Ok(joined_room) = load_joined_room(
&sender_user, &sender_user,
&sender_device, &sender_device,
@ -241,6 +254,9 @@ async fn sync_helper(
next_batchcount, next_batchcount,
lazy_load_enabled, lazy_load_enabled,
lazy_load_send_redundant, lazy_load_send_redundant,
&filter.room,
event_fields,
&filter.event_format,
full_state, full_state,
&mut device_list_updates, &mut device_list_updates,
&mut left_encrypted_users, &mut left_encrypted_users,
@ -289,11 +305,14 @@ async fn sync_helper(
} }
let mut left_rooms = BTreeMap::new(); let mut left_rooms = BTreeMap::new();
let all_left_rooms: Vec<_> = services() let all_left_rooms = match filter.room.include_leave {
false => Vec::with_capacity(0),
true => services()
.rooms .rooms
.state_cache .state_cache
.rooms_left(&sender_user) .rooms_left(&sender_user)
.collect(); .collect(),
};
for result in all_left_rooms { for result in all_left_rooms {
let (room_id, _) = result?; let (room_id, _) = result?;
@ -541,6 +560,7 @@ async fn sync_helper(
knock: BTreeMap::new(), // TODO knock: BTreeMap::new(), // TODO
}, },
presence: Presence { presence: Presence {
// HashMap<OwnedUserId, PresenceEvent>
events: presence_updates events: presence_updates
.into_values() .into_values()
.map(|v| Raw::new(&v).expect("PresenceEvent always serializes successfully")) .map(|v| Raw::new(&v).expect("PresenceEvent always serializes successfully"))
@ -549,6 +569,7 @@ async fn sync_helper(
account_data: GlobalAccountData { account_data: GlobalAccountData {
events: services() events: services()
.account_data .account_data
// HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>
.changes_since(None, &sender_user, since)? .changes_since(None, &sender_user, since)?
.into_iter() .into_iter()
.filter_map(|(_, v)| { .filter_map(|(_, v)| {
@ -606,6 +627,9 @@ async fn load_joined_room(
next_batchcount: PduCount, next_batchcount: PduCount,
lazy_load_enabled: bool, lazy_load_enabled: bool,
lazy_load_send_redundant: bool, lazy_load_send_redundant: bool,
filter: &RoomFilter,
_event_fields: Option<&Vec<String>>,
event_format: &EventFormat,
full_state: bool, full_state: bool,
device_list_updates: &mut HashSet<OwnedUserId>, device_list_updates: &mut HashSet<OwnedUserId>,
left_encrypted_users: &mut HashSet<OwnedUserId>, left_encrypted_users: &mut HashSet<OwnedUserId>,
@ -1081,7 +1105,33 @@ async fn load_joined_room(
let room_events: Vec<_> = timeline_pdus let room_events: Vec<_> = timeline_pdus
.iter() .iter()
.map(|(_, pdu)| pdu.to_sync_room_event()) .map(|(_, pdu)| match event_format {
EventFormat::Federation => Raw::new(pdu)
.map(Raw::cast)
.expect("json can be serialized"),
_ => pdu.to_sync_room_event(),
})
.filter(|v| {
filter
.timeline
.url_filter
.map(|f| filter::url(v, room_id, f))
.unwrap_or(true)
})
.filter(|v| {
filter::senders(
v,
filter.timeline.senders.as_ref(),
filter.timeline.not_senders.as_ref(),
)
})
.filter(|v| {
filter::types(
v,
filter.timeline.types.as_ref(),
filter.timeline.not_types.as_ref(),
)
})
.collect(); .collect();
let mut edus: Vec<_> = services() let mut edus: Vec<_> = services()
@ -1090,7 +1140,23 @@ async fn load_joined_room(
.read_receipt .read_receipt
.readreceipts_since(room_id, since) .readreceipts_since(room_id, since)
.filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|r| r.ok()) // Filter out buggy events
// TODO
.map(|(_, _, v)| v) .map(|(_, _, v)| v)
.filter(|v| {
filter::senders(
v,
filter.ephemeral.senders.as_ref(),
filter.ephemeral.not_senders.as_ref(),
)
})
.filter(|v| {
filter::types(
v,
filter.ephemeral.types.as_ref(),
filter.ephemeral.not_types.as_ref(),
)
})
.take(filter.ephemeral.limit.map_or(10, u64::from).min(100) as usize)
.collect(); .collect();
if services() if services()
@ -1123,11 +1189,15 @@ async fn load_joined_room(
.account_data .account_data
.changes_since(Some(room_id), sender_user, since)? .changes_since(Some(room_id), sender_user, since)?
.into_iter() .into_iter()
.filter_map(|(_, v)| { .filter_map(|(_, v)| serde_json::from_str(v.json().get()).ok())
serde_json::from_str(v.json().get()) .filter(|v| {
.map_err(|_| Error::bad_database("Invalid account event in database.")) filter::types(
.ok() v,
filter.account_data.types.as_ref(),
filter.account_data.not_types.as_ref(),
)
}) })
.take(filter.account_data.limit.map_or(10, u64::from).min(100) as usize)
.collect(), .collect(),
}, },
summary: RoomSummary { summary: RoomSummary {
@ -1142,12 +1212,33 @@ async fn load_joined_room(
timeline: Timeline { timeline: Timeline {
limited: limited || joined_since_last_sync, limited: limited || joined_since_last_sync,
prev_batch, prev_batch,
// Vec<Raw<AnySyncTimelineEvent>>
events: room_events, events: room_events,
}, },
state: State { state: State {
events: state_events events: state_events
.iter() .iter()
.map(|pdu| pdu.to_sync_state_event()) .map(|pdu| match event_format {
EventFormat::Federation => Raw::new(pdu)
.map(Raw::cast)
.expect("json can be serialized"),
_ => pdu.to_sync_state_event(),
})
.filter(|v| {
filter::senders(
v,
filter.state.senders.as_ref(),
filter.state.not_senders.as_ref(),
)
})
.filter(|v| {
filter::types(
v,
filter.state.types.as_ref(),
filter.state.not_types.as_ref(),
)
})
.take(filter.state.limit.map_or(10, u64::from).min(100) as usize)
.collect(), .collect(),
}, },
ephemeral: Ephemeral { events: edus }, ephemeral: Ephemeral { events: edus },

View file

@ -2,7 +2,7 @@ use std::collections::HashMap;
use ruma::{ use ruma::{
api::client::error::ErrorKind, api::client::error::ErrorKind,
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, events::{AnyRoomAccountDataEvent, RoomAccountDataEventType},
serde::Raw, serde::Raw,
RoomId, UserId, RoomId, UserId,
}; };
@ -101,7 +101,7 @@ impl service::account_data::Data for KeyValueDatabase {
room_id: Option<&RoomId>, room_id: Option<&RoomId>,
user_id: &UserId, user_id: &UserId,
since: u64, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> { ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>> {
let mut userdata = HashMap::new(); let mut userdata = HashMap::new();
let mut prefix = room_id let mut prefix = room_id
@ -129,7 +129,7 @@ impl service::account_data::Data for KeyValueDatabase {
)?) )?)
.map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?, .map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?,
), ),
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v).map_err(|_| { serde_json::from_slice::<Raw<AnyRoomAccountDataEvent>>(&v).map_err(|_| {
Error::bad_database("Database contains invalid account data.") Error::bad_database("Database contains invalid account data.")
})?, })?,
)) ))

View file

@ -1,13 +1,23 @@
use std::str::FromStr;
use ruma::RoomId; use ruma::RoomId;
use url::Url;
use crate::{database::KeyValueDatabase, service, services, utils, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Result};
impl service::rooms::search::Data for KeyValueDatabase { impl service::rooms::search::Data for KeyValueDatabase {
fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
let mut batch = message_body let mut contains_url = false;
let mut token_batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric()) .split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty()) .filter(|s| !s.is_empty())
.filter(|word| word.len() <= 50) .filter(|word| {
contains_url =
contains_url || (word.starts_with("http") && Url::from_str(word).is_ok());
word.len() <= 50
})
.map(str::to_lowercase) .map(str::to_lowercase)
.map(|word| { .map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec(); let mut key = shortroomid.to_be_bytes().to_vec();
@ -17,7 +27,17 @@ impl service::rooms::search::Data for KeyValueDatabase {
(key, Vec::new()) (key, Vec::new())
}); });
self.tokenids.insert_batch(&mut batch) self.tokenids.insert_batch(&mut token_batch)?;
if contains_url {
let mut key = shortroomid.to_be_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(pdu_id);
self.urltokenids.insert(&key, Default::default())?;
}
Ok(())
} }
fn search_pdus<'a>( fn search_pdus<'a>(
@ -64,4 +84,18 @@ impl service::rooms::search::Data for KeyValueDatabase {
Ok(Some((Box::new(common_elements), words))) Ok(Some((Box::new(common_elements), words)))
} }
fn contains_url(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<bool> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = prefix.to_be_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(pdu_id);
self.urltokenids.get(&key).map(|v| v.is_some())
}
} }

View file

@ -86,6 +86,7 @@ pub struct KeyValueDatabase {
pub(super) threadid_userids: Arc<dyn KvTree>, // ThreadId = RoomId + Count pub(super) threadid_userids: Arc<dyn KvTree>, // ThreadId = RoomId + Count
pub(super) tokenids: Arc<dyn KvTree>, // TokenId = ShortRoomId + Token + PduIdCount pub(super) tokenids: Arc<dyn KvTree>, // TokenId = ShortRoomId + Token + PduIdCount
pub(super) urltokenids: Arc<dyn KvTree>, // useful for `RoomEventFilter::contains_url`
/// Participating servers in a room. /// Participating servers in a room.
pub(super) roomserverids: Arc<dyn KvTree>, // RoomServerId = RoomId + ServerName pub(super) roomserverids: Arc<dyn KvTree>, // RoomServerId = RoomId + ServerName
@ -314,6 +315,7 @@ impl KeyValueDatabase {
threadid_userids: builder.open_tree("threadid_userids")?, threadid_userids: builder.open_tree("threadid_userids")?,
tokenids: builder.open_tree("tokenids")?, tokenids: builder.open_tree("tokenids")?,
urltokenids: builder.open_tree("urltokenids")?,
roomserverids: builder.open_tree("roomserverids")?, roomserverids: builder.open_tree("roomserverids")?,
serverroomids: builder.open_tree("serverroomids")?, serverroomids: builder.open_tree("serverroomids")?,

View file

@ -2,7 +2,7 @@ use std::collections::HashMap;
use crate::Result; use crate::Result;
use ruma::{ use ruma::{
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, events::{AnyRoomAccountDataEvent, RoomAccountDataEventType},
serde::Raw, serde::Raw,
RoomId, UserId, RoomId, UserId,
}; };
@ -31,5 +31,5 @@ pub trait Data: Send + Sync {
room_id: Option<&RoomId>, room_id: Option<&RoomId>,
user_id: &UserId, user_id: &UserId,
since: u64, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>; ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>>;
} }

View file

@ -3,7 +3,7 @@ mod data;
pub use data::Data; pub use data::Data;
use ruma::{ use ruma::{
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, events::{AnyRoomAccountDataEvent, RoomAccountDataEventType},
serde::Raw, serde::Raw,
RoomId, UserId, RoomId, UserId,
}; };
@ -47,7 +47,7 @@ impl Service {
room_id: Option<&RoomId>, room_id: Option<&RoomId>,
user_id: &UserId, user_id: &UserId,
since: u64, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> { ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>> {
self.db.changes_since(room_id, user_id, since) self.db.changes_since(room_id, user_id, since)
} }
} }

View file

@ -10,4 +10,6 @@ pub trait Data: Send + Sync {
room_id: &RoomId, room_id: &RoomId,
search_string: &str, search_string: &str,
) -> Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>; ) -> Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
fn contains_url(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<bool>;
} }

View file

@ -23,4 +23,9 @@ impl Service {
) -> Result<Option<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)>> { ) -> Result<Option<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)>> {
self.db.search_pdus(room_id, search_string) self.db.search_pdus(room_id, search_string)
} }
#[tracing::instrument(skip(self))]
pub fn contains_url(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<bool> {
self.db.contains_url(room_id, pdu_id)
}
} }

65
src/utils/filter.rs Normal file
View file

@ -0,0 +1,65 @@
use ruma::{
api::client::filter::UrlFilter, serde::Raw, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId,
};
use serde::Deserialize;
use crate::services;
/// Generic inclusion/exclusion check shared by the public filter helpers.
///
/// Reads `field` from the raw event JSON and accepts the event when the value
/// * is contained in `include` (if an inclusion list was given), and
/// * is not contained in `exclude`.
///
/// If the field is missing or fails to deserialize, the event passes only
/// when no inclusion list was requested. This is a best-effort fallback:
/// panicking here (as an `.expect()` would) could take down a whole sync or
/// context request over one malformed event in the database.
fn inclusion<T, F: for<'a> Deserialize<'a> + PartialEq>(
    event: &Raw<T>,
    field: &str,
    include: Option<&Vec<F>>,
    exclude: &[F],
) -> bool {
    let Ok(Some(value)) = event.get_field::<F>(field) else {
        // NOTE(review): events produced by this server should normally carry
        // every filtered field; an absent/unreadable field counts as no match.
        return include.is_none();
    };

    include
        .map(|allowed| allowed.iter().any(|item| *item == value))
        .unwrap_or(true)
        && exclude.iter().all(|item| *item != value)
}
/// Applies the `rooms` / `not_rooms` parts of a Matrix filter to one event.
///
/// The event's `"room_id"` field must appear in `rooms` (when given) and
/// must not appear in `not_rooms` for the event to pass.
pub fn rooms<T>(
    event: &Raw<T>,
    rooms: Option<&Vec<OwnedRoomId>>,
    not_rooms: &[OwnedRoomId],
) -> bool {
    // Delegate to the shared inclusion/exclusion helper on the room id field.
    inclusion(event, "room_id", rooms, not_rooms)
}
/// Applies the `senders` / `not_senders` parts of a Matrix filter to one event.
///
/// The event's `"sender"` field must appear in `senders` (when given) and
/// must not appear in `not_senders` for the event to pass.
pub fn senders<T>(
    event: &Raw<T>,
    senders: Option<&Vec<OwnedUserId>>,
    not_senders: &[OwnedUserId],
) -> bool {
    // Delegate to the shared inclusion/exclusion helper on the sender field.
    inclusion(event, "sender", senders, not_senders)
}
/// Applies the `types` / `not_types` parts of a Matrix filter to one event.
///
/// Matrix events serialize their event type under the `"type"` key (see the
/// client-server spec's event schema), not `"event_type"` — callers in this
/// codebase also build probe objects as `json!({"type": …})`. Reading the
/// wrong key made the lookup always miss.
// TODO(review): the spec also allows a trailing `*` wildcard in type filters
// (e.g. "m.room.*"); plain equality does not support that yet.
pub fn types<T>(event: &Raw<T>, types: Option<&Vec<String>>, not_types: &[String]) -> bool {
    inclusion(event, "type", types, not_types)
}
/// Applies a `RoomEventFilter::url_filter` (`contains_url`) to one event.
///
/// Resolves the event's `"event_id"` to a local PDU id and asks the search
/// service whether a URL was indexed for that PDU.
///
/// An event whose id is missing, unreadable, or unknown to this server has
/// no indexed URL, so it only satisfies `EventsWithoutUrl` — the same
/// best-effort treatment the original gave an unknown PDU id, extended to
/// the id lookup itself instead of panicking mid-request.
pub fn url<T>(event: &Raw<T>, room_id: &RoomId, filter: UrlFilter) -> bool {
    let Ok(Some(event_id)) = event.get_field::<OwnedEventId>("event_id") else {
        return filter == UrlFilter::EventsWithoutUrl;
    };

    let Ok(Some(pdu_id)) = services().rooms.timeline.get_pdu_id(&event_id) else {
        return filter == UrlFilter::EventsWithoutUrl;
    };

    // `contains_url` is populated at indexing time; a DB error is treated as
    // "no URL" rather than failing the whole request.
    let contains_url = services()
        .rooms
        .search
        .contains_url(room_id, &pdu_id)
        .unwrap_or(false);

    match filter {
        UrlFilter::EventsWithUrl => contains_url,
        UrlFilter::EventsWithoutUrl => !contains_url,
    }
}

View file

@ -1,4 +1,5 @@
pub mod error; pub mod error;
pub mod filter;
use argon2::{Config, Variant}; use argon2::{Config, Variant};
use cmp::Ordering; use cmp::Ordering;