From d1138204a68e1afc501b5378ec597e489c404e14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sun, 7 Aug 2022 19:42:22 +0200 Subject: [PATCH] Refactor appservices, pusher, timeline, transactionids, users --- .../key_value}/appservice.rs | 24 +- src/database/key_value/pusher.rs | 56 +++ src/database/key_value/rooms/timeline.rs | 282 ++++++++++++++ .../key_value}/transaction_ids.rs | 13 +- src/{service => database/key_value}/users.rs | 148 +------- src/service/appservice/data.rs | 17 + src/service/appservice/mod.rs | 36 ++ src/service/globals.rs | 14 +- src/service/pusher.rs | 348 ----------------- src/service/pusher/data.rs | 12 + src/service/pusher/mod.rs | 287 ++++++++++++++ src/service/rooms/short/mod.rs | 11 +- src/service/rooms/timeline/data.rs | 66 ++++ src/service/rooms/timeline/mod.rs | 232 ++---------- src/service/transaction_ids/data.rs | 16 + src/service/transaction_ids/mod.rs | 44 +++ src/service/users/data.rs | 228 +++++++++++ src/service/users/mod.rs | 354 ++++++++++++++++++ 18 files changed, 1452 insertions(+), 736 deletions(-) rename src/{service => database/key_value}/appservice.rs (77%) create mode 100644 src/database/key_value/pusher.rs create mode 100644 src/database/key_value/rooms/timeline.rs rename src/{service => database/key_value}/transaction_ids.rs (77%) rename src/{service => database/key_value}/users.rs (84%) create mode 100644 src/service/appservice/data.rs create mode 100644 src/service/appservice/mod.rs delete mode 100644 src/service/pusher.rs create mode 100644 src/service/pusher/data.rs create mode 100644 src/service/pusher/mod.rs create mode 100644 src/service/rooms/timeline/data.rs create mode 100644 src/service/transaction_ids/data.rs create mode 100644 src/service/transaction_ids/mod.rs create mode 100644 src/service/users/data.rs create mode 100644 src/service/users/mod.rs diff --git a/src/service/appservice.rs b/src/database/key_value/appservice.rs similarity index 77% rename from src/service/appservice.rs rename to src/database/key_value/appservice.rs index edd5009b..66a2a5c8 100644 --- a/src/service/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -1,19 +1,5 @@ -use crate::{utils, Error, Result}; -use std::{ - collections::HashMap, - sync::{Arc, RwLock}, -}; - -use super::abstraction::Tree; - -pub struct Appservice { - pub(super) cached_registrations: Arc>>, - pub(super) id_appserviceregistrations: Arc, -} - -impl Appservice { +impl service::appservice::Data for KeyValueDatabase { /// Registers an appservice and returns the ID to the caller - /// pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result { // TODO: Rumaify let id = yaml.get("id").unwrap().as_str().unwrap(); @@ -34,7 +20,7 @@ impl Appservice { /// # Arguments /// /// * `service_name` - the name you send to register the service previously - pub fn unregister_appservice(&self, service_name: &str) -> Result<()> { + fn unregister_appservice(&self, service_name: &str) -> Result<()> { self.id_appserviceregistrations .remove(service_name.as_bytes())?; self.cached_registrations @@ -44,7 +30,7 @@ impl Appservice { Ok(()) } - pub fn get_registration(&self, id: &str) -> Result> { + fn get_registration(&self, id: &str) -> Result> { self.cached_registrations .read() .unwrap() @@ -66,14 +52,14 @@ impl Appservice { ) } - pub fn iter_ids(&self) -> Result> + '_> { + fn iter_ids(&self) -> Result> + '_> { Ok(self.id_appserviceregistrations.iter().map(|(id, _)| { utils::string_from_bytes(&id) .map_err(|_| Error::bad_database("Invalid id bytes in 
id_appserviceregistrations.")) })) } - pub fn all(&self) -> Result> { + fn all(&self) -> Result> { self.iter_ids()? .filter_map(|id| id.ok()) .map(move |id| { diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs new file mode 100644 index 00000000..94374ab2 --- /dev/null +++ b/src/database/key_value/pusher.rs @@ -0,0 +1,56 @@ +impl service::pusher::Data for KeyValueDatabase { + fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { + let mut key = sender.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(pusher.pushkey.as_bytes()); + + // There are 2 kinds of pushers but the spec says: null deletes the pusher. + if pusher.kind.is_none() { + return self + .senderkey_pusher + .remove(&key) + .map(|_| ()) + .map_err(Into::into); + } + + self.senderkey_pusher.insert( + &key, + &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), + )?; + + Ok(()) + } + + fn get_pusher(&self, senderkey: &[u8]) -> Result> { + self.senderkey_pusher + .get(senderkey)? + .map(|push| { + serde_json::from_slice(&*push) + .map_err(|_| Error::bad_database("Invalid Pusher in db.")) + }) + .transpose() + } + + fn get_pushers(&self, sender: &UserId) -> Result> { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xff); + + self.senderkey_pusher + .scan_prefix(prefix) + .map(|(_, push)| { + serde_json::from_slice(&*push) + .map_err(|_| Error::bad_database("Invalid Pusher in db.")) + }) + .collect() + } + + fn get_pusher_senderkeys<'a>( + &'a self, + sender: &UserId, + ) -> impl Iterator> + 'a { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xff); + + self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) + } +} diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs new file mode 100644 index 00000000..58884ec3 --- /dev/null +++ b/src/database/key_value/rooms/timeline.rs @@ -0,0 +1,282 @@ +impl service::room::timeline::Data for KeyValueDatabase { + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + match self + .lasttimelinecount_cache + .lock() + .unwrap() + .entry(room_id.to_owned()) + { + hash_map::Entry::Vacant(v) => { + if let Some(last_count) = self + .pdus_until(&sender_user, &room_id, u64::MAX)? + .filter_map(|r| { + // Filter out buggy events + if r.is_err() { + error!("Bad pdu in pdus_since: {:?}", r); + } + r.ok() + }) + .map(|(pduid, _)| self.pdu_count(&pduid)) + .next() + { + Ok(*v.insert(last_count?)) + } else { + Ok(0) + } + } + hash_map::Entry::Occupied(o) => Ok(*o.get()), + } + } + + /// Returns the `count` of this pdu's id. + fn get_pdu_count(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pdu_id| self.pdu_count(&pdu_id)) + .transpose() + } + + /// Returns the json of a pdu. + pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map_or_else( + || self.eventid_outlierpdu.get(event_id.as_bytes()), + |pduid| { + Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + Error::bad_database("Invalid pduid in eventid_pduid.") + })?)) + }, + )? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the json of a pdu. + pub fn get_non_outlier_pdu_json( + &self, + event_id: &EventId, + ) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pduid| { + self.pduid_pdu + .get(&pduid)? 
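+                    // a pdu id that no longer resolves in pduid_pdu means eventid_pduid is corrupted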
+ .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + }) + .transpose()? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the pdu's id. + pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { + self.eventid_pduid.get(event_id.as_bytes()) + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pduid| { + self.pduid_pdu + .get(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + }) + .transpose()? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn get_pdu(&self, event_id: &EventId) -> Result>> { + if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { + return Ok(Some(Arc::clone(p))); + } + + if let Some(pdu) = self + .eventid_pduid + .get(event_id.as_bytes())? + .map_or_else( + || self.eventid_outlierpdu.get(event_id.as_bytes()), + |pduid| { + Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + Error::bad_database("Invalid pduid in eventid_pduid.") + })?)) + }, + )? + .map(|pdu| { + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db.")) + .map(Arc::new) + }) + .transpose()? + { + self.pdu_cache + .lock() + .unwrap() + .insert(event_id.to_owned(), Arc::clone(&pdu)); + Ok(Some(pdu)) + } else { + Ok(None) + } + } + + /// Returns the pdu. + /// + /// This does __NOT__ check the outliers `Tree`. + pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { + self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } + + /// Returns the pdu as a `BTreeMap`. + pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { + self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } + + /// Returns the `count` of this pdu's id. + pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { + utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) + .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) + } + + /// Removes a pdu and creates a new one with the same id. + fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { + if self.pduid_pdu.get(pdu_id)?.is_some() { + self.pduid_pdu.insert( + pdu_id, + &serde_json::to_vec(pdu).expect("PduEvent::to_vec always works"), + )?; + Ok(()) + } else { + Err(Error::BadRequest( + ErrorKind::NotFound, + "PDU does not exist.", + )) + } + } + + /// Returns an iterator over all events in a room that happened after the event with id `since` + /// in chronological order. + pub fn pdus_since<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + since: u64, + ) -> Result, PduEvent)>> + 'a> { + let prefix = self + .get_shortroomid(room_id)? 
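+            // pdu ids are shortroomid ++ count, so the shortroomid acts as the per-room scan prefix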
+ .expect("room exists") + .to_be_bytes() + .to_vec(); + + // Skip the first pdu if it's exactly at since, because we sent that last time + let mut first_pdu_id = prefix.clone(); + first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); + + let user_id = user_id.to_owned(); + + Ok(self + .pduid_pdu + .iter_from(&first_pdu_id, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((pdu_id, pdu)) + })) + } + + /// Returns an iterator over all events and their tokens in a room that happened before the + /// event with id `until` in reverse-chronological order. + pub fn pdus_until<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + until: u64, + ) -> Result, PduEvent)>> + 'a> { + // Create the first part of the full pdu id + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut current = prefix.clone(); + current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` + + let current: &[u8] = ¤t; + + let user_id = user_id.to_owned(); + + Ok(self + .pduid_pdu + .iter_from(current, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((pdu_id, pdu)) + })) + } + + pub fn pdus_after<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + from: u64, + ) -> Result, PduEvent)>> + 'a> { + // Create the first part of the full pdu id + let prefix = self + .get_shortroomid(room_id)? 
+ .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut current = prefix.clone(); + current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event + + let current: &[u8] = ¤t; + + let user_id = user_id.to_owned(); + + Ok(self + .pduid_pdu + .iter_from(current, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((pdu_id, pdu)) + })) + } +} diff --git a/src/service/transaction_ids.rs b/src/database/key_value/transaction_ids.rs similarity index 77% rename from src/service/transaction_ids.rs rename to src/database/key_value/transaction_ids.rs index ed0970d1..81c1197d 100644 --- a/src/service/transaction_ids.rs +++ b/src/database/key_value/transaction_ids.rs @@ -1,15 +1,4 @@ -use std::sync::Arc; - -use crate::Result; -use ruma::{DeviceId, TransactionId, UserId}; - -use super::abstraction::Tree; - -pub struct TransactionIds { - pub(super) userdevicetxnid_response: Arc, // Response can be empty (/sendToDevice) or the event id (/send) -} - -impl TransactionIds { +impl service::pusher::Data for KeyValueDatabase { pub fn add_txnid( &self, user_id: &UserId, diff --git a/src/service/users.rs b/src/database/key_value/users.rs similarity index 84% rename from src/service/users.rs rename to src/database/key_value/users.rs index 7c15f1d8..5ef058f3 100644 --- a/src/service/users.rs +++ b/src/database/key_value/users.rs @@ -1,49 +1,10 @@ -use crate::{utils, Error, Result}; -use ruma::{ - api::client::{device::Device, error::ErrorKind, filter::IncomingFilterDefinition}, - encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::{AnyToDeviceEvent, StateEventType}, - serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, MxcUri, RoomAliasId, - UInt, UserId, -}; -use std::{collections::BTreeMap, mem, sync::Arc}; -use tracing::warn; - -use super::abstraction::Tree; - -pub struct Users { - pub(super) userid_password: Arc, - pub(super) userid_displayname: Arc, - pub(super) userid_avatarurl: Arc, - pub(super) userid_blurhash: Arc, - pub(super) userdeviceid_token: Arc, - pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists - pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 - pub(super) token_userdeviceid: Arc, - - pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId - pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count - pub(super) keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count - pub(super) keyid_key: Arc, // KeyId = UserId + KeyId (depends on key type) - pub(super) userid_masterkeyid: Arc, - pub(super) userid_selfsigningkeyid: Arc, - pub(super) userid_usersigningkeyid: Arc, - - pub(super) userfilterid_filter: Arc, // UserFilterId = UserId + FilterId - - pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count -} - -impl Users { +impl service::users::Data for KeyValueDatabase { /// Check if a user has an account on this homeserver. 
- #[tracing::instrument(skip(self, user_id))] pub fn exists(&self, user_id: &UserId) -> Result { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) } /// Check if account is deactivated - #[tracing::instrument(skip(self, user_id))] pub fn is_deactivated(&self, user_id: &UserId) -> Result { Ok(self .userid_password @@ -56,7 +17,6 @@ impl Users { } /// Check if a user is an admin - #[tracing::instrument(skip(self, user_id, rooms, globals))] pub fn is_admin( &self, user_id: &UserId, @@ -71,20 +31,17 @@ impl Users { } /// Create a new user account on this homeserver. - #[tracing::instrument(skip(self, user_id, password))] pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { self.set_password(user_id, password)?; Ok(()) } /// Returns the number of users registered on this server. - #[tracing::instrument(skip(self))] pub fn count(&self) -> Result { Ok(self.userid_password.iter().count()) } /// Find out which user an access token belongs to. - #[tracing::instrument(skip(self, token))] pub fn find_from_token(&self, token: &str) -> Result, String)>> { self.token_userdeviceid .get(token.as_bytes())? @@ -112,7 +69,6 @@ impl Users { } /// Returns an iterator over all users on this homeserver. - #[tracing::instrument(skip(self))] pub fn iter(&self) -> impl Iterator>> + '_ { self.userid_password.iter().map(|(bytes, _)| { UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { @@ -125,7 +81,6 @@ impl Users { /// Returns a list of local users as list of usernames. /// /// A user account is considered `local` if the length of it's password is greater then zero. - #[tracing::instrument(skip(self))] pub fn list_local_users(&self) -> Result> { let users: Vec = self .userid_password @@ -139,7 +94,6 @@ impl Users { /// username could be successfully parsed. /// If utils::string_from_bytes(...) returns an error that username will be skipped /// and the error will be logged. - #[tracing::instrument(skip(self))] fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option { // A valid password is not empty if password.is_empty() { @@ -159,7 +113,6 @@ impl Users { } /// Returns the password hash for the given user. - #[tracing::instrument(skip(self, user_id))] pub fn password_hash(&self, user_id: &UserId) -> Result> { self.userid_password .get(user_id.as_bytes())? @@ -171,7 +124,6 @@ impl Users { } /// Hash and set the user's password to the Argon2 hash - #[tracing::instrument(skip(self, user_id, password))] pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { if let Some(password) = password { if let Ok(hash) = utils::calculate_hash(password) { @@ -191,7 +143,6 @@ impl Users { } /// Returns the displayname of a user on this homeserver. - #[tracing::instrument(skip(self, user_id))] pub fn displayname(&self, user_id: &UserId) -> Result> { self.userid_displayname .get(user_id.as_bytes())? @@ -203,7 +154,6 @@ impl Users { } /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. - #[tracing::instrument(skip(self, user_id, displayname))] pub fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { if let Some(displayname) = displayname { self.userid_displayname @@ -216,7 +166,6 @@ impl Users { } /// Get the avatar_url of a user. - #[tracing::instrument(skip(self, user_id))] pub fn avatar_url(&self, user_id: &UserId) -> Result>> { self.userid_avatarurl .get(user_id.as_bytes())? 
@@ -230,7 +179,6 @@ impl Users { } /// Sets a new avatar_url or removes it if avatar_url is None. - #[tracing::instrument(skip(self, user_id, avatar_url))] pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option>) -> Result<()> { if let Some(avatar_url) = avatar_url { self.userid_avatarurl @@ -243,7 +191,6 @@ impl Users { } /// Get the blurhash of a user. - #[tracing::instrument(skip(self, user_id))] pub fn blurhash(&self, user_id: &UserId) -> Result> { self.userid_blurhash .get(user_id.as_bytes())? @@ -257,7 +204,6 @@ impl Users { } /// Sets a new avatar_url or removes it if avatar_url is None. - #[tracing::instrument(skip(self, user_id, blurhash))] pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { if let Some(blurhash) = blurhash { self.userid_blurhash @@ -270,7 +216,6 @@ impl Users { } /// Adds a new device to a user. - #[tracing::instrument(skip(self, user_id, device_id, token, initial_device_display_name))] pub fn create_device( &self, user_id: &UserId, @@ -305,7 +250,6 @@ impl Users { } /// Removes a device from a user. - #[tracing::instrument(skip(self, user_id, device_id))] pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xff); @@ -336,7 +280,6 @@ impl Users { } /// Returns an iterator over all device ids of this user. - #[tracing::instrument(skip(self, user_id))] pub fn all_device_ids<'a>( &'a self, user_id: &UserId, @@ -359,7 +302,6 @@ impl Users { } /// Replaces the access token of one device. - #[tracing::instrument(skip(self, user_id, device_id, token))] pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xff); @@ -383,14 +325,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip( - self, - user_id, - device_id, - one_time_key_key, - one_time_key_value, - globals - ))] pub fn add_one_time_key( &self, user_id: &UserId, @@ -427,7 +361,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_id))] pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { self.userid_lastonetimekeyupdate .get(user_id.as_bytes())? 
@@ -439,7 +372,6 @@ impl Users { .unwrap_or(Ok(0)) } - #[tracing::instrument(skip(self, user_id, device_id, key_algorithm, globals))] pub fn take_one_time_key( &self, user_id: &UserId, @@ -479,7 +411,6 @@ impl Users { .transpose() } - #[tracing::instrument(skip(self, user_id, device_id))] pub fn count_one_time_keys( &self, user_id: &UserId, @@ -512,7 +443,6 @@ impl Users { Ok(counts) } - #[tracing::instrument(skip(self, user_id, device_id, device_keys, rooms, globals))] pub fn add_device_keys( &self, user_id: &UserId, @@ -535,14 +465,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip( - self, - master_key, - self_signing_key, - user_signing_key, - rooms, - globals - ))] pub fn add_cross_signing_keys( &self, user_id: &UserId, @@ -658,7 +580,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, target_id, key_id, signature, sender_id, rooms, globals))] pub fn sign_key( &self, target_id: &UserId, @@ -703,7 +624,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_or_room_id, from, to))] pub fn keys_changed<'a>( &'a self, user_or_room_id: &str, @@ -742,7 +662,6 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id, rooms, globals))] pub fn mark_device_key_update( &self, user_id: &UserId, @@ -774,7 +693,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_id, device_id))] pub fn get_device_keys( &self, user_id: &UserId, @@ -791,7 +709,6 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id, allowed_signatures))] pub fn get_master_key bool>( &self, user_id: &UserId, @@ -813,7 +730,6 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id, allowed_signatures))] pub fn get_self_signing_key bool>( &self, user_id: &UserId, @@ -835,7 +751,6 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id))] pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { self.userid_usersigningkeyid .get(user_id.as_bytes())? @@ -848,15 +763,6 @@ impl Users { }) } - #[tracing::instrument(skip( - self, - sender, - target_user_id, - target_device_id, - event_type, - content, - globals - ))] pub fn add_to_device_event( &self, sender: &UserId, @@ -884,7 +790,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_id, device_id))] pub fn get_to_device_events( &self, user_id: &UserId, @@ -907,7 +812,6 @@ impl Users { Ok(events) } - #[tracing::instrument(skip(self, user_id, device_id, until))] pub fn remove_to_device_events( &self, user_id: &UserId, @@ -942,7 +846,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_id, device_id, device))] pub fn update_device_metadata( &self, user_id: &UserId, @@ -968,7 +871,6 @@ impl Users { } /// Get device metadata. - #[tracing::instrument(skip(self, user_id, device_id))] pub fn get_device_metadata( &self, user_id: &UserId, @@ -987,7 +889,6 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id))] pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> { self.userid_devicelistversion .get(user_id.as_bytes())? @@ -998,7 +899,6 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id))] pub fn all_devices_metadata<'a>( &'a self, user_id: &UserId, @@ -1014,25 +914,7 @@ impl Users { }) } - /// Deactivate account - #[tracing::instrument(skip(self, user_id))] - pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { - // Remove all associated devices - for device_id in self.all_device_ids(user_id) { - self.remove_device(user_id, &device_id?)?; - } - - // Set the password to "" to indicate a deactivated account. 
Hashes will never result in an - // empty string, so the user will not be able to log in again. Systems like changing the - // password without logging in should check if the account is deactivated. - self.userid_password.insert(user_id.as_bytes(), &[])?; - - // TODO: Unhook 3PID - Ok(()) - } - /// Creates a new sync filter. Returns the filter id. - #[tracing::instrument(skip(self))] pub fn create_filter( &self, user_id: &UserId, @@ -1052,7 +934,6 @@ impl Users { Ok(filter_id) } - #[tracing::instrument(skip(self))] pub fn get_filter( &self, user_id: &UserId, @@ -1072,30 +953,3 @@ impl Users { } } } - -/// Ensure that a user only sees signatures from themselves and the target user -fn clean_signatures bool>( - cross_signing_key: &mut serde_json::Value, - user_id: &UserId, - allowed_signatures: F, -) -> Result<(), Error> { - if let Some(signatures) = cross_signing_key - .get_mut("signatures") - .and_then(|v| v.as_object_mut()) - { - // Don't allocate for the full size of the current signatures, but require - // at most one resize if nothing is dropped - let new_capacity = signatures.len() / 2; - for (user, signature) in - mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) - { - let id = <&UserId>::try_from(user.as_str()) - .map_err(|_| Error::bad_database("Invalid user ID in database."))?; - if id == user_id || allowed_signatures(id) { - signatures.insert(user, signature); - } - } - } - - Ok(()) -} diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs new file mode 100644 index 00000000..fe57451f --- /dev/null +++ b/src/service/appservice/data.rs @@ -0,0 +1,17 @@ +pub trait Data { + /// Registers an appservice and returns the ID to the caller + pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result; + + /// Remove an appservice registration + /// + /// # Arguments + /// + /// * `service_name` - the name you send to register the service previously + pub fn unregister_appservice(&self, service_name: &str) -> Result<()>; + + pub fn get_registration(&self, id: &str) -> Result>; + + pub fn iter_ids(&self) -> Result> + '_>; + + pub fn all(&self) -> Result>; +} diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs new file mode 100644 index 00000000..ec4ffc56 --- /dev/null +++ b/src/service/appservice/mod.rs @@ -0,0 +1,36 @@ +mod data; +pub use data::Data; + +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { + /// Registers an appservice and returns the ID to the caller + pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result { + self.db.register_appservice(yaml) + } + + /// Remove an appservice registration + /// + /// # Arguments + /// + /// * `service_name` - the name you send to register the service previously + pub fn unregister_appservice(&self, service_name: &str) -> Result<()> { + self.db.unregister_appservice(service_name) + } + + pub fn get_registration(&self, id: &str) -> Result> { + self.db.get_registration(id) + } + + pub fn iter_ids(&self) -> Result> + '_> { + self.db.iter_ids() + } + + pub fn all(&self) -> Result> { + self.db.all() + } +} diff --git a/src/service/globals.rs b/src/service/globals.rs index 7e09128e..2b47e5b1 100644 --- a/src/service/globals.rs +++ b/src/service/globals.rs @@ -1,3 +1,8 @@ +mod data; +pub use data::Data; + +use crate::service::*; + use crate::{database::Config, server_server::FedDest, utils, Error, Result}; use ruma::{ api::{ @@ -32,10 +37,11 @@ type SyncHandle = ( Receiver>>, // rx ); -pub struct Globals { +pub struct Service { + db: 
D, + pub actual_destination_cache: Arc>, // actual_destination, host pub tls_name_override: Arc>, - pub(super) globals: Arc, pub config: Config, keypair: Arc, dns_resolver: TokioAsyncResolver, @@ -44,7 +50,6 @@ pub struct Globals { default_client: reqwest::Client, pub stable_room_versions: Vec, pub unstable_room_versions: Vec, - pub(super) server_signingkeys: Arc, pub bad_event_ratelimiter: Arc, RateLimitState>>>, pub bad_signature_ratelimiter: Arc, RateLimitState>>>, pub servername_ratelimiter: Arc, Arc>>>, @@ -87,7 +92,8 @@ impl Default for RotationHandler { } } -impl Globals { + +impl Service<_> { pub fn load( globals: Arc, server_signingkeys: Arc, diff --git a/src/service/pusher.rs b/src/service/pusher.rs deleted file mode 100644 index 6b906c24..00000000 --- a/src/service/pusher.rs +++ /dev/null @@ -1,348 +0,0 @@ -use crate::{Database, Error, PduEvent, Result}; -use bytes::BytesMut; -use ruma::{ - api::{ - client::push::{get_pushers, set_pusher, PusherKind}, - push_gateway::send_event_notification::{ - self, - v1::{Device, Notification, NotificationCounts, NotificationPriority}, - }, - IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, - }, - events::{ - room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent}, - AnySyncRoomEvent, RoomEventType, StateEventType, - }, - push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, - serde::Raw, - uint, RoomId, UInt, UserId, -}; -use tracing::{error, info, warn}; - -use std::{fmt::Debug, mem, sync::Arc}; - -use super::abstraction::Tree; - -pub struct PushData { - /// UserId + pushkey -> Pusher - pub(super) senderkey_pusher: Arc, -} - -impl PushData { - #[tracing::instrument(skip(self, sender, pusher))] - pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { - let mut key = sender.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(pusher.pushkey.as_bytes()); - - // There are 2 kinds of pushers but the spec says: null deletes the pusher. - if pusher.kind.is_none() { - return self - .senderkey_pusher - .remove(&key) - .map(|_| ()) - .map_err(Into::into); - } - - self.senderkey_pusher.insert( - &key, - &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), - )?; - - Ok(()) - } - - #[tracing::instrument(skip(self, senderkey))] - pub fn get_pusher(&self, senderkey: &[u8]) -> Result> { - self.senderkey_pusher - .get(senderkey)? 
- .map(|push| { - serde_json::from_slice(&*push) - .map_err(|_| Error::bad_database("Invalid Pusher in db.")) - }) - .transpose() - } - - #[tracing::instrument(skip(self, sender))] - pub fn get_pushers(&self, sender: &UserId) -> Result> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); - - self.senderkey_pusher - .scan_prefix(prefix) - .map(|(_, push)| { - serde_json::from_slice(&*push) - .map_err(|_| Error::bad_database("Invalid Pusher in db.")) - }) - .collect() - } - - #[tracing::instrument(skip(self, sender))] - pub fn get_pusher_senderkeys<'a>( - &'a self, - sender: &UserId, - ) -> impl Iterator> + 'a { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); - - self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) - } -} - -#[tracing::instrument(skip(globals, destination, request))] -pub async fn send_request( - globals: &crate::database::globals::Globals, - destination: &str, - request: T, -) -> Result -where - T: Debug, -{ - let destination = destination.replace("/_matrix/push/v1/notify", ""); - - let http_request = request - .try_into_http_request::( - &destination, - SendAccessToken::IfRequired(""), - &[MatrixVersion::V1_0], - ) - .map_err(|e| { - warn!("Failed to find destination {}: {}", destination, e); - Error::BadServerResponse("Invalid destination") - })? - .map(|body| body.freeze()); - - let reqwest_request = reqwest::Request::try_from(http_request) - .expect("all http requests are valid reqwest requests"); - - // TODO: we could keep this very short and let expo backoff do it's thing... - //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); - - let url = reqwest_request.url().clone(); - let response = globals.default_client().execute(reqwest_request).await; - - match response { - Ok(mut response) => { - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); - - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if status != 200 { - info!( - "Push gateway returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - crate::utils::string_from_bytes(&body) - ); - } - - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), - ); - response.map_err(|_| { - info!( - "Push gateway returned invalid response bytes {}\n{}", - destination, url - ); - Error::BadServerResponse("Push gateway returned bad response.") - }) - } - Err(e) => Err(e.into()), - } -} - -#[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] -pub async fn send_push_notice( - user: &UserId, - unread: UInt, - pusher: &get_pushers::v3::Pusher, - ruleset: Ruleset, - pdu: &PduEvent, - db: &Database, -) -> Result<()> { - let mut notify = None; - let mut tweaks = Vec::new(); - - let power_levels: RoomPowerLevelsEventContent = db - .rooms - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? - .unwrap_or_default(); - - for action in get_actions( - user, - &ruleset, - &power_levels, - &pdu.to_sync_room_event(), - &pdu.room_id, - db, - )? 
{ - let n = match action { - Action::DontNotify => false, - // TODO: Implement proper support for coalesce - Action::Notify | Action::Coalesce => true, - Action::SetTweak(tweak) => { - tweaks.push(tweak.clone()); - continue; - } - }; - - if notify.is_some() { - return Err(Error::bad_database( - r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, - )); - } - - notify = Some(n); - } - - if notify == Some(true) { - send_notice(unread, pusher, tweaks, pdu, db).await?; - } - // Else the event triggered no actions - - Ok(()) -} - -#[tracing::instrument(skip(user, ruleset, pdu, db))] -pub fn get_actions<'a>( - user: &UserId, - ruleset: &'a Ruleset, - power_levels: &RoomPowerLevelsEventContent, - pdu: &Raw, - room_id: &RoomId, - db: &Database, -) -> Result<&'a [Action]> { - let ctx = PushConditionRoomCtx { - room_id: room_id.to_owned(), - member_count: 10_u32.into(), // TODO: get member count efficiently - user_display_name: db - .users - .displayname(user)? - .unwrap_or_else(|| user.localpart().to_owned()), - users_power_levels: power_levels.users.clone(), - default_power_level: power_levels.users_default, - notification_power_levels: power_levels.notifications.clone(), - }; - - Ok(ruleset.get_actions(pdu, &ctx)) -} - -#[tracing::instrument(skip(unread, pusher, tweaks, event, db))] -async fn send_notice( - unread: UInt, - pusher: &get_pushers::v3::Pusher, - tweaks: Vec, - event: &PduEvent, - db: &Database, -) -> Result<()> { - // TODO: email - if pusher.kind == PusherKind::Email { - return Ok(()); - } - - // TODO: - // Two problems with this - // 1. if "event_id_only" is the only format kind it seems we should never add more info - // 2. can pusher/devices have conflicting formats - let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly); - let url = if let Some(url) = &pusher.data.url { - url - } else { - error!("Http Pusher must have URL specified."); - return Ok(()); - }; - - let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone()); - let mut data_minus_url = pusher.data.clone(); - // The url must be stripped off according to spec - data_minus_url.url = None; - device.data = data_minus_url; - - // Tweaks are only added if the format is NOT event_id_only - if !event_id_only { - device.tweaks = tweaks.clone(); - } - - let d = &[device]; - let mut notifi = Notification::new(d); - - notifi.prio = NotificationPriority::Low; - notifi.event_id = Some(&event.event_id); - notifi.room_id = Some(&event.room_id); - // TODO: missed calls - notifi.counts = NotificationCounts::new(unread, uint!(0)); - - if event.kind == RoomEventType::RoomEncrypted - || tweaks - .iter() - .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) - { - notifi.prio = NotificationPriority::High - } - - if event_id_only { - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; - } else { - notifi.sender = Some(&event.sender); - notifi.event_type = Some(&event.kind); - let content = serde_json::value::to_raw_value(&event.content).ok(); - notifi.content = content.as_deref(); - - if event.kind == RoomEventType::RoomMember { - notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); - } - - let user_name = db.users.displayname(&event.sender)?; - notifi.sender_display_name = user_name.as_deref(); - - let room_name = if let Some(room_name_pdu) = - db.rooms - .room_state_get(&event.room_id, &StateEventType::RoomName, "")? 
- { - serde_json::from_str::(room_name_pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid room name event in database."))? - .name - } else { - None - }; - - notifi.room_name = room_name.as_deref(); - - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; - } - - // TODO: email - - Ok(()) -} diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs new file mode 100644 index 00000000..468ad8b4 --- /dev/null +++ b/src/service/pusher/data.rs @@ -0,0 +1,12 @@ +pub trait Data { + fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()>; + + pub fn get_pusher(&self, senderkey: &[u8]) -> Result>; + + pub fn get_pushers(&self, sender: &UserId) -> Result>; + + pub fn get_pusher_senderkeys<'a>( + &'a self, + sender: &UserId, + ) -> impl Iterator> + 'a; +} diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs new file mode 100644 index 00000000..342763e8 --- /dev/null +++ b/src/service/pusher/mod.rs @@ -0,0 +1,287 @@ +mod data; +pub use data::Data; + +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { + pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { + self.db.set_pusher(sender, pusher) + } + + pub fn get_pusher(&self, senderkey: &[u8]) -> Result> { + self.db.get_pusher(senderkey) + } + + pub fn get_pushers(&self, sender: &UserId) -> Result> { + self.db.get_pushers(sender) + } + + pub fn get_pusher_senderkeys<'a>( + &'a self, + sender: &UserId, + ) -> impl Iterator> + 'a { + self.db.get_pusher_senderkeys(sender) + } + + #[tracing::instrument(skip(globals, destination, request))] + pub async fn send_request( + globals: &crate::database::globals::Globals, + destination: &str, + request: T, + ) -> Result + where + T: Debug, + { + let destination = destination.replace("/_matrix/push/v1/notify", ""); + + let http_request = request + .try_into_http_request::( + &destination, + SendAccessToken::IfRequired(""), + &[MatrixVersion::V1_0], + ) + .map_err(|e| { + warn!("Failed to find destination {}: {}", destination, e); + Error::BadServerResponse("Invalid destination") + })? + .map(|body| body.freeze()); + + let reqwest_request = reqwest::Request::try_from(http_request) + .expect("all http requests are valid reqwest requests"); + + // TODO: we could keep this very short and let expo backoff do it's thing... 
+ //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); + + let url = reqwest_request.url().clone(); + let response = globals.default_client().execute(reqwest_request).await; + + match response { + Ok(mut response) => { + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error {}", e); + Vec::new().into() + }); // TODO: handle timeout + + if status != 200 { + info!( + "Push gateway returned bad response {} {}\n{}\n{:?}", + destination, + status, + url, + crate::utils::string_from_bytes(&body) + ); + } + + let response = T::IncomingResponse::try_from_http_response( + http_response_builder + .body(body) + .expect("reqwest body is valid http body"), + ); + response.map_err(|_| { + info!( + "Push gateway returned invalid response bytes {}\n{}", + destination, url + ); + Error::BadServerResponse("Push gateway returned bad response.") + }) + } + Err(e) => Err(e.into()), + } + } + + #[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] + pub async fn send_push_notice( + user: &UserId, + unread: UInt, + pusher: &get_pushers::v3::Pusher, + ruleset: Ruleset, + pdu: &PduEvent, + db: &Database, + ) -> Result<()> { + let mut notify = None; + let mut tweaks = Vec::new(); + + let power_levels: RoomPowerLevelsEventContent = db + .rooms + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? + .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? + .unwrap_or_default(); + + for action in get_actions( + user, + &ruleset, + &power_levels, + &pdu.to_sync_room_event(), + &pdu.room_id, + db, + )? { + let n = match action { + Action::DontNotify => false, + // TODO: Implement proper support for coalesce + Action::Notify | Action::Coalesce => true, + Action::SetTweak(tweak) => { + tweaks.push(tweak.clone()); + continue; + } + }; + + if notify.is_some() { + return Err(Error::bad_database( + r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, + )); + } + + notify = Some(n); + } + + if notify == Some(true) { + send_notice(unread, pusher, tweaks, pdu, db).await?; + } + // Else the event triggered no actions + + Ok(()) + } + + #[tracing::instrument(skip(user, ruleset, pdu, db))] + pub fn get_actions<'a>( + user: &UserId, + ruleset: &'a Ruleset, + power_levels: &RoomPowerLevelsEventContent, + pdu: &Raw, + room_id: &RoomId, + db: &Database, + ) -> Result<&'a [Action]> { + let ctx = PushConditionRoomCtx { + room_id: room_id.to_owned(), + member_count: 10_u32.into(), // TODO: get member count efficiently + user_display_name: db + .users + .displayname(user)? 
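+                // fall back to the bare localpart if the user has not set a display name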
+ .unwrap_or_else(|| user.localpart().to_owned()), + users_power_levels: power_levels.users.clone(), + default_power_level: power_levels.users_default, + notification_power_levels: power_levels.notifications.clone(), + }; + + Ok(ruleset.get_actions(pdu, &ctx)) + } + + #[tracing::instrument(skip(unread, pusher, tweaks, event, db))] + async fn send_notice( + unread: UInt, + pusher: &get_pushers::v3::Pusher, + tweaks: Vec, + event: &PduEvent, + db: &Database, + ) -> Result<()> { + // TODO: email + if pusher.kind == PusherKind::Email { + return Ok(()); + } + + // TODO: + // Two problems with this + // 1. if "event_id_only" is the only format kind it seems we should never add more info + // 2. can pusher/devices have conflicting formats + let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly); + let url = if let Some(url) = &pusher.data.url { + url + } else { + error!("Http Pusher must have URL specified."); + return Ok(()); + }; + + let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone()); + let mut data_minus_url = pusher.data.clone(); + // The url must be stripped off according to spec + data_minus_url.url = None; + device.data = data_minus_url; + + // Tweaks are only added if the format is NOT event_id_only + if !event_id_only { + device.tweaks = tweaks.clone(); + } + + let d = &[device]; + let mut notifi = Notification::new(d); + + notifi.prio = NotificationPriority::Low; + notifi.event_id = Some(&event.event_id); + notifi.room_id = Some(&event.room_id); + // TODO: missed calls + notifi.counts = NotificationCounts::new(unread, uint!(0)); + + if event.kind == RoomEventType::RoomEncrypted + || tweaks + .iter() + .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) + { + notifi.prio = NotificationPriority::High + } + + if event_id_only { + send_request( + &db.globals, + url, + send_event_notification::v1::Request::new(notifi), + ) + .await?; + } else { + notifi.sender = Some(&event.sender); + notifi.event_type = Some(&event.kind); + let content = serde_json::value::to_raw_value(&event.content).ok(); + notifi.content = content.as_deref(); + + if event.kind == RoomEventType::RoomMember { + notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); + } + + let user_name = db.users.displayname(&event.sender)?; + notifi.sender_display_name = user_name.as_deref(); + + let room_name = if let Some(room_name_pdu) = + db.rooms + .room_state_get(&event.room_id, &StateEventType::RoomName, "")? + { + serde_json::from_str::(room_name_pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid room name event in database."))? 
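+                    // only the name field of the m.room.name content is needed for the notification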
+ .name + } else { + None + }; + + notifi.room_name = room_name.as_deref(); + + send_request( + &db.globals, + url, + send_event_notification::v1::Request::new(notifi), + ) + .await?; + } + + // TODO: email + + Ok(()) + } +} diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index c44d357c..a8e87b91 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,4 +1,13 @@ +mod data; +pub use data::Data; +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { pub fn get_or_create_shorteventid( &self, event_id: &EventId, @@ -222,4 +231,4 @@ } }) } - +} diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs new file mode 100644 index 00000000..4e5c3796 --- /dev/null +++ b/src/service/rooms/timeline/data.rs @@ -0,0 +1,66 @@ +pub trait Data { + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; + + /// Returns the `count` of this pdu's id. + fn get_pdu_count(&self, event_id: &EventId) -> Result>; + + /// Returns the json of a pdu. + pub fn get_pdu_json(&self, event_id: &EventId) -> Result>; + + /// Returns the json of a pdu. + pub fn get_non_outlier_pdu_json( + + /// Returns the pdu's id. + pub fn get_pdu_id(&self, event_id: &EventId) -> Result>>; + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result>; + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn get_pdu(&self, event_id: &EventId) -> Result>>; + + /// Returns the pdu. + /// + /// This does __NOT__ check the outliers `Tree`. + pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result>; + + /// Returns the pdu as a `BTreeMap`. + pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result>; + + /// Returns the `count` of this pdu's id. + pub fn pdu_count(&self, pdu_id: &[u8]) -> Result; + + /// Removes a pdu and creates a new one with the same id. + fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()>; + + /// Returns an iterator over all events in a room that happened after the event with id `since` + /// in chronological order. + #[tracing::instrument(skip(self))] + pub fn pdus_since<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + since: u64, + ) -> Result, PduEvent)>> + 'a>; + + /// Returns an iterator over all events and their tokens in a room that happened before the + /// event with id `until` in reverse-chronological order. + #[tracing::instrument(skip(self))] + pub fn pdus_until<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + until: u64, + ) -> Result, PduEvent)>> + 'a>; + + pub fn pdus_after<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + from: u64, + ) -> Result, PduEvent)>> + 'a>; +} diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 5b423d2d..c6393c68 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,4 +1,14 @@ +mod data; +pub use data::Data; +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { + /* /// Checks if a room exists. 
#[tracing::instrument(skip(self))] pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { @@ -20,38 +30,15 @@ .next() .transpose() } + */ #[tracing::instrument(skip(self))] pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { - match self - .lasttimelinecount_cache - .lock() - .unwrap() - .entry(room_id.to_owned()) - { - hash_map::Entry::Vacant(v) => { - if let Some(last_count) = self - .pdus_until(&sender_user, &room_id, u64::MAX)? - .filter_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) - .map(|(pduid, _)| self.pdu_count(&pduid)) - .next() - { - Ok(*v.insert(last_count?)) - } else { - Ok(0) - } - } - hash_map::Entry::Occupied(o) => Ok(*o.get()), - } + self.db.last_timeline_count(sender_user: &UserId, room_id: &RoomId) } // TODO Is this the same as the function above? + /* #[tracing::instrument(skip(self))] pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { let prefix = self @@ -71,33 +58,16 @@ .transpose() .map(|op| op.unwrap_or_default()) } - - + */ /// Returns the `count` of this pdu's id. pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pdu_id| self.pdu_count(&pdu_id)) - .transpose() + self.db.get_pdu_count(event_id) } /// Returns the json of a pdu. pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map_or_else( - || self.eventid_outlierpdu.get(event_id.as_bytes()), - |pduid| { - Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { - Error::bad_database("Invalid pduid in eventid_pduid.") - })?)) - }, - )? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() + self.db.get_pdu_json(event_id) } /// Returns the json of a pdu. @@ -105,122 +75,49 @@ &self, event_id: &EventId, ) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() + self.db.get_non_outlier_pdu(event_id) } /// Returns the pdu's id. pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { - self.eventid_pduid.get(event_id.as_bytes()) + self.db.get_pdu_id(event_id) } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() + self.db.get_non_outlier_pdu(event_id) } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. pub fn get_pdu(&self, event_id: &EventId) -> Result>> { - if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { - return Ok(Some(Arc::clone(p))); - } - - if let Some(pdu) = self - .eventid_pduid - .get(event_id.as_bytes())? - .map_or_else( - || self.eventid_outlierpdu.get(event_id.as_bytes()), - |pduid| { - Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { - Error::bad_database("Invalid pduid in eventid_pduid.") - })?)) - }, - )? 
- .map(|pdu| { - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db.")) - .map(Arc::new) - }) - .transpose()? - { - self.pdu_cache - .lock() - .unwrap() - .insert(event_id.to_owned(), Arc::clone(&pdu)); - Ok(Some(pdu)) - } else { - Ok(None) - } + self.db.get_pdu(event_id) } /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) + self.db.get_pdu_from_id(pdu_id) } /// Returns the pdu as a `BTreeMap`. pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) + self.db.get_pdu_json_from_id(pdu_id) } /// Returns the `count` of this pdu's id. pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { - utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) - .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) + self.db.pdu_count(pdu_id) } /// Removes a pdu and creates a new one with the same id. #[tracing::instrument(skip(self))] fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { - if self.pduid_pdu.get(pdu_id)?.is_some() { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(pdu).expect("PduEvent::to_vec always works"), - )?; - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::NotFound, - "PDU does not exist.", - )) - } + self.db.pdu_count(pdu_id, pdu: &PduEvent) } /// Creates a new persisted data unit and adds it to a room. @@ -803,7 +700,6 @@ } /// Returns an iterator over all PDUs in a room. - #[tracing::instrument(skip(self))] pub fn all_pdus<'a>( &'a self, user_id: &UserId, @@ -814,37 +710,13 @@ /// Returns an iterator over all events in a room that happened after the event with id `since` /// in chronological order. - #[tracing::instrument(skip(self))] pub fn pdus_since<'a>( &'a self, user_id: &UserId, room_id: &RoomId, since: u64, ) -> Result, PduEvent)>> + 'a> { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - // Skip the first pdu if it's exactly at since, because we sent that last time - let mut first_pdu_id = prefix.clone(); - first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); - - let user_id = user_id.to_owned(); - - Ok(self - .pduid_pdu - .iter_from(&first_pdu_id, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - })) + self.db.pdus_since(user_id, room_id, since) } /// Returns an iterator over all events and their tokens in a room that happened before the @@ -856,32 +728,7 @@ room_id: &RoomId, until: u64, ) -> Result, PduEvent)>> + 'a> { - // Create the first part of the full pdu id - let prefix = self - .get_shortroomid(room_id)? 
- .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut current = prefix.clone(); - current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` - - let current: &[u8] = ¤t; - - let user_id = user_id.to_owned(); - - Ok(self - .pduid_pdu - .iter_from(current, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - })) + self.db.pdus_until(user_id, room_id, until) } /// Returns an iterator over all events and their token in a room that happened after the event @@ -893,32 +740,7 @@ room_id: &RoomId, from: u64, ) -> Result, PduEvent)>> + 'a> { - // Create the first part of the full pdu id - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut current = prefix.clone(); - current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event - - let current: &[u8] = ¤t; - - let user_id = user_id.to_owned(); - - Ok(self - .pduid_pdu - .iter_from(current, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - })) + self.db.pdus_after(user_id, room_id, from) } /// Replace a PDU with the redacted form. diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs new file mode 100644 index 00000000..f1ff5f88 --- /dev/null +++ b/src/service/transaction_ids/data.rs @@ -0,0 +1,16 @@ +pub trait Data { + pub fn add_txnid( + &self, + user_id: &UserId, + device_id: Option<&DeviceId>, + txn_id: &TransactionId, + data: &[u8], + ) -> Result<()>; + + pub fn existing_txnid( + &self, + user_id: &UserId, + device_id: Option<&DeviceId>, + txn_id: &TransactionId, + ) -> Result>>; +} diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs new file mode 100644 index 00000000..d944847e --- /dev/null +++ b/src/service/transaction_ids/mod.rs @@ -0,0 +1,44 @@ +mod data; +pub use data::Data; + +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { + pub fn add_txnid( + &self, + user_id: &UserId, + device_id: Option<&DeviceId>, + txn_id: &TransactionId, + data: &[u8], + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(device_id.map(|d| d.as_bytes()).unwrap_or_default()); + key.push(0xff); + key.extend_from_slice(txn_id.as_bytes()); + + self.userdevicetxnid_response.insert(&key, data)?; + + Ok(()) + } + + pub fn existing_txnid( + &self, + user_id: &UserId, + device_id: Option<&DeviceId>, + txn_id: &TransactionId, + ) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(device_id.map(|d| d.as_bytes()).unwrap_or_default()); + key.push(0xff); + key.extend_from_slice(txn_id.as_bytes()); + + // If there's no entry, this is a new transaction + self.userdevicetxnid_response.get(&key) + } +} diff --git a/src/service/users/data.rs b/src/service/users/data.rs new file mode 100644 index 00000000..d99d0328 --- /dev/null +++ b/src/service/users/data.rs @@ -0,0 +1,228 @@ +pub trait Data { + /// Check if a user has an account on this homeserver. 
+    fn exists(&self, user_id: &UserId) -> Result<bool>;
+
+    /// Check if account is deactivated
+    fn is_deactivated(&self, user_id: &UserId) -> Result<bool>;
+
+    /// Check if a user is an admin
+    fn is_admin(
+        &self,
+        user_id: &UserId,
+        rooms: &super::rooms::Rooms,
+        globals: &super::globals::Globals,
+    ) -> Result<bool>;
+
+    /// Create a new user account on this homeserver.
+    fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()>;
+
+    /// Returns the number of users registered on this server.
+    fn count(&self) -> Result<usize>;
+
+    /// Find out which user an access token belongs to.
+    fn find_from_token(&self, token: &str) -> Result<Option<(Box<UserId>, String)>>;
+
+    /// Returns an iterator over all users on this homeserver.
+    fn iter(&self) -> impl Iterator<Item = Result<Box<UserId>>> + '_;
+
+    /// Returns a list of local users as list of usernames.
+    ///
+    /// A user account is considered `local` if the length of its password is greater than zero.
+    fn list_local_users(&self) -> Result<Vec<String>>;
+
+    /// Will only return with Some(username) if the password was not empty and the
+    /// username could be successfully parsed.
+    /// If utils::string_from_bytes(...) returns an error that username will be skipped
+    /// and the error will be logged.
+    fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option<String>;
+
+    /// Returns the password hash for the given user.
+    fn password_hash(&self, user_id: &UserId) -> Result<Option<String>>;
+
+    /// Hash and set the user's password to the Argon2 hash
+    fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()>;
+
+    /// Returns the displayname of a user on this homeserver.
+    fn displayname(&self, user_id: &UserId) -> Result<Option<String>>;
+
+    /// Sets a new displayname or removes it if displayname is None. You still need to notify all rooms of this change.
+    fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()>;
+
+    /// Get the avatar_url of a user.
+    fn avatar_url(&self, user_id: &UserId) -> Result<Option<Box<MxcUri>>>;
+
+    /// Sets a new avatar_url or removes it if avatar_url is None.
+    fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<Box<MxcUri>>) -> Result<()>;
+
+    /// Get the blurhash of a user.
+    fn blurhash(&self, user_id: &UserId) -> Result<Option<String>>;
+
+    /// Sets a new blurhash or removes it if blurhash is None.
+    fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()>;
+
+    /// Adds a new device to a user.
+    fn create_device(
+        &self,
+        user_id: &UserId,
+        device_id: &DeviceId,
+        token: &str,
+        initial_device_display_name: Option<String>,
+    ) -> Result<()>;
+
+    /// Removes a device from a user.
+    fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>;
+
+    /// Returns an iterator over all device ids of this user.
+    fn all_device_ids<'a>(
+        &'a self,
+        user_id: &UserId,
+    ) -> impl Iterator<Item = Result<Box<DeviceId>>> + 'a;
+
+    /// Replaces the access token of one device.
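+    // For reference: the key-value implementations behind these device methods are
+    // expected to build composite keys from 0xff-separated segments, mirroring the
+    // pusher and transaction-id code in this patch. A hedged sketch for a token
+    // lookup key (`userdeviceid_token` is an assumed tree name, not introduced here):
+    //
+    //     let mut key = user_id.as_bytes().to_vec();
+    //     key.push(0xff);
+    //     key.extend_from_slice(device_id.as_bytes());
+    //     // `key` now uniquely identifies the (user, device) pair in userdeviceid_token
+    //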
+ pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; + + pub fn add_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + one_time_key_key: &DeviceKeyId, + one_time_key_value: &Raw, + globals: &super::globals::Globals, + ) -> Result<()>; + + pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result; + + pub fn take_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + key_algorithm: &DeviceKeyAlgorithm, + globals: &super::globals::Globals, + ) -> Result, Raw)>>; + + pub fn count_one_time_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result>; + + pub fn add_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + device_keys: &Raw, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()>; + + pub fn add_cross_signing_keys( + &self, + user_id: &UserId, + master_key: &Raw, + self_signing_key: &Option>, + user_signing_key: &Option>, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()>; + + pub fn sign_key( + &self, + target_id: &UserId, + key_id: &str, + signature: (String, String), + sender_id: &UserId, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()>; + + pub fn keys_changed<'a>( + &'a self, + user_or_room_id: &str, + from: u64, + to: Option, + ) -> impl Iterator>> + 'a; + + pub fn mark_device_key_update( + &self, + user_id: &UserId, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()>; + + pub fn get_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result>>; + + pub fn get_master_key bool>( + &self, + user_id: &UserId, + allowed_signatures: F, + ) -> Result>>; + + pub fn get_self_signing_key bool>( + &self, + user_id: &UserId, + allowed_signatures: F, + ) -> Result>>; + + pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>>; + + pub fn add_to_device_event( + &self, + sender: &UserId, + target_user_id: &UserId, + target_device_id: &DeviceId, + event_type: &str, + content: serde_json::Value, + globals: &super::globals::Globals, + ) -> Result<()>; + + pub fn get_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result>>; + + pub fn remove_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + until: u64, + ) -> Result<()>; + + pub fn update_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, + device: &Device, + ) -> Result<()>; + + /// Get device metadata. + pub fn get_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result>; + + pub fn get_devicelist_version(&self, user_id: &UserId) -> Result>; + + pub fn all_devices_metadata<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator> + 'a; + + /// Creates a new sync filter. Returns the filter id. + pub fn create_filter( + &self, + user_id: &UserId, + filter: &IncomingFilterDefinition, + ) -> Result; + + pub fn get_filter( + &self, + user_id: &UserId, + filter_id: &str, + ) -> Result>; +} diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs new file mode 100644 index 00000000..93d6ea52 --- /dev/null +++ b/src/service/users/mod.rs @@ -0,0 +1,354 @@ +mod data; +pub use data::Data; + +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { + /// Check if a user has an account on this homeserver. 
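+    // For reference: this wrapper is meant to be generic over its storage backend.
+    // A hedged sketch of the intended shape (bounds and constructor are assumptions,
+    // not part of this patch):
+    //
+    //     pub struct Service<D: Data> {
+    //         db: D,
+    //     }
+    //
+    //     impl<D: Data> Service<D> {
+    //         pub fn new(db: D) -> Self {
+    //             Self { db }
+    //         }
+    //     }
+    //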
+    pub fn exists(&self, user_id: &UserId) -> Result<bool> {
+        self.db.exists(user_id)
+    }
+
+    /// Check if account is deactivated
+    pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
+        self.db.is_deactivated(user_id)
+    }
+
+    /// Check if a user is an admin
+    pub fn is_admin(
+        &self,
+        user_id: &UserId,
+    ) -> Result<bool> {
+        self.db.is_admin(user_id)
+    }
+
+    /// Create a new user account on this homeserver.
+    pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
+        self.db.create(user_id, password)
+    }
+
+    /// Returns the number of users registered on this server.
+    pub fn count(&self) -> Result<usize> {
+        self.db.count()
+    }
+
+    /// Find out which user an access token belongs to.
+    pub fn find_from_token(&self, token: &str) -> Result<Option<(Box<UserId>, String)>> {
+        self.db.find_from_token(token)
+    }
+
+    /// Returns an iterator over all users on this homeserver.
+    pub fn iter(&self) -> impl Iterator<Item = Result<Box<UserId>>> + '_ {
+        self.db.iter()
+    }
+
+    /// Returns a list of local users as list of usernames.
+    ///
+    /// A user account is considered `local` if the length of its password is greater than zero.
+    pub fn list_local_users(&self) -> Result<Vec<String>> {
+        self.db.list_local_users()
+    }
+
+    /// Will only return with Some(username) if the password was not empty and the
+    /// username could be successfully parsed.
+    /// If utils::string_from_bytes(...) returns an error that username will be skipped
+    /// and the error will be logged.
+    fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option<String> {
+        self.db.get_username_with_valid_password(username, password)
+    }
+
+    /// Returns the password hash for the given user.
+    pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
+        self.db.password_hash(user_id)
+    }
+
+    /// Hash and set the user's password to the Argon2 hash
+    pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
+        self.db.set_password(user_id, password)
+    }
+
+    /// Returns the displayname of a user on this homeserver.
+    pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
+        self.db.displayname(user_id)
+    }
+
+    /// Sets a new displayname or removes it if displayname is None. You still need to notify all rooms of this change.
+    pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
+        self.db.set_displayname(user_id, displayname)
+    }
+
+    /// Get the avatar_url of a user.
+    pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<Box<MxcUri>>> {
+        self.db.avatar_url(user_id)
+    }
+
+    /// Sets a new avatar_url or removes it if avatar_url is None.
+    pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<Box<MxcUri>>) -> Result<()> {
+        self.db.set_avatar_url(user_id, avatar_url)
+    }
+
+    /// Get the blurhash of a user.
+    pub fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> {
+        self.db.blurhash(user_id)
+    }
+
+    /// Sets a new blurhash or removes it if blurhash is None.
+    pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
+        self.db.set_blurhash(user_id, blurhash)
+    }
+
+    /// Adds a new device to a user.
+    pub fn create_device(
+        &self,
+        user_id: &UserId,
+        device_id: &DeviceId,
+        token: &str,
+        initial_device_display_name: Option<String>,
+    ) -> Result<()> {
+        self.db
+            .create_device(user_id, device_id, token, initial_device_display_name)
+    }
+
+    /// Removes a device from a user.
+    pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
+        self.db.remove_device(user_id, device_id)
+    }
+
+    /// Returns an iterator over all device ids of this user.
+ pub fn all_device_ids<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator>> + 'a { + self.db.all_device_ids(user_id) + } + + /// Replaces the access token of one device. + pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + self.db.set_token(user_id, device_id, token) + } + + pub fn add_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + one_time_key_key: &DeviceKeyId, + one_time_key_value: &Raw, + globals: &super::globals::Globals, + ) -> Result<()> { + self.db.add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) + } + + pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { + self.db.last_one_time_keys_update(user_id) + } + + pub fn take_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + key_algorithm: &DeviceKeyAlgorithm, + ) -> Result, Raw)>> { + self.db.take_one_time_key(user_id, device_id, key_algorithm) + } + + pub fn count_one_time_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result> { + self.db.count_one_time_keys(user_id, device_id) + } + + pub fn add_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + device_keys: &Raw, + ) -> Result<()> { + self.db.add_device_keys(user_id, device_id, device_keys) + } + + pub fn add_cross_signing_keys( + &self, + user_id: &UserId, + master_key: &Raw, + self_signing_key: &Option>, + user_signing_key: &Option>, + ) -> Result<()> { + self.db.add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key) + } + + pub fn sign_key( + &self, + target_id: &UserId, + key_id: &str, + signature: (String, String), + sender_id: &UserId, + ) -> Result<()> { + self.db.sign_key(target_id, key_id, signature, sender_id) + } + + pub fn keys_changed<'a>( + &'a self, + user_or_room_id: &str, + from: u64, + to: Option, + ) -> impl Iterator>> + 'a { + self.db.keys_changed(user_or_room_id, from, to) + } + + pub fn mark_device_key_update( + &self, + user_id: &UserId, + ) -> Result<()> { + self.db.mark_device_key_update(user_id) + } + + pub fn get_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result>> { + self.db.get_device_keys(user_id, device_id) + } + + pub fn get_master_key bool>( + &self, + user_id: &UserId, + allowed_signatures: F, + ) -> Result>> { + self.db.get_master_key(user_id, allow_signatures) + } + + pub fn get_self_signing_key bool>( + &self, + user_id: &UserId, + allowed_signatures: F, + ) -> Result>> { + self.db.get_self_signing_key(user_id, allowed_signatures) + } + + pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { + self.db.get_user_signing_key(user_id) + } + + pub fn add_to_device_event( + &self, + sender: &UserId, + target_user_id: &UserId, + target_device_id: &DeviceId, + event_type: &str, + content: serde_json::Value, + ) -> Result<()> { + self.db.add_to_device_event(sender, target_user_id, target_device_id, event_type, content) + } + + pub fn get_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result>> { + self.get_to_device_events(user_id, device_id) + } + + pub fn remove_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + until: u64, + ) -> Result<()> { + self.db.remove_to_device_events(user_id, device_id, until) + } + + pub fn update_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, + device: &Device, + ) -> Result<()> { + self.db.update_device_metadata(user_id, device_id, device) + } + + /// Get device metadata. 
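+    // For reference: `allowed_signatures` in get_master_key/get_self_signing_key
+    // above is a caller-supplied filter over user IDs. A hedged usage sketch from a
+    // hypothetical /keys/query handler (variable names are assumptions, not part of
+    // this patch):
+    //
+    //     let master_key = users.get_master_key(&target_user, |id| id == sender_user)?;
+    //     // keeps the target's own signatures plus any from the requesting user
+    //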
+    pub fn get_device_metadata(
+        &self,
+        user_id: &UserId,
+        device_id: &DeviceId,
+    ) -> Result<Option<Device>> {
+        self.db.get_device_metadata(user_id, device_id)
+    }
+
+    pub fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
+        self.db.get_devicelist_version(user_id)
+    }
+
+    pub fn all_devices_metadata<'a>(
+        &'a self,
+        user_id: &UserId,
+    ) -> impl Iterator<Item = Result<Device>> + 'a {
+        self.db.all_devices_metadata(user_id)
+    }
+
+    /// Deactivate account
+    pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
+        // Remove all associated devices
+        for device_id in self.all_device_ids(user_id) {
+            self.remove_device(user_id, &device_id?)?;
+        }
+
+        // Set the password to "" to indicate a deactivated account. Hashes will never result in an
+        // empty string, so the user will not be able to log in again. Systems like changing the
+        // password without logging in should check if the account is deactivated.
+        self.db.set_password(user_id, None)?;
+
+        // TODO: Unhook 3PID
+        Ok(())
+    }
+
+    /// Creates a new sync filter. Returns the filter id.
+    pub fn create_filter(
+        &self,
+        user_id: &UserId,
+        filter: &IncomingFilterDefinition,
+    ) -> Result<String> {
+        self.db.create_filter(user_id, filter)
+    }
+
+    pub fn get_filter(
+        &self,
+        user_id: &UserId,
+        filter_id: &str,
+    ) -> Result<Option<IncomingFilterDefinition>> {
+        self.db.get_filter(user_id, filter_id)
+    }
+}
+
+/// Ensure that a user only sees signatures from themselves and the target user
+fn clean_signatures<F: Fn(&UserId) -> bool>(
+    cross_signing_key: &mut serde_json::Value,
+    user_id: &UserId,
+    allowed_signatures: F,
+) -> Result<(), Error> {
+    if let Some(signatures) = cross_signing_key
+        .get_mut("signatures")
+        .and_then(|v| v.as_object_mut())
+    {
+        // Don't allocate for the full size of the current signatures, but require
+        // at most one resize if nothing is dropped
+        let new_capacity = signatures.len() / 2;
+        for (user, signature) in
+            mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
+        {
+            let id = <&UserId>::try_from(user.as_str())
+                .map_err(|_| Error::bad_database("Invalid user ID in database."))?;
+            if id == user_id || allowed_signatures(id) {
+                signatures.insert(user, signature);
+            }
+        }
+    }
+
+    Ok(())
+}
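+
+// For reference: a hedged sketch of how `clean_signatures` is meant to be used by
+// the key getters above; deserializing the stored key into a `serde_json::Value`
+// first is an assumption about the calling code, not part of this patch:
+//
+//     let mut master_key: serde_json::Value = serde_json::from_slice(&raw_key_bytes)?;
+//     clean_signatures(&mut master_key, &target_user, |id| id == sender_user)?;
+//     // only signatures from `target_user` itself and from `sender_user` remain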