From a7e34eb0b3851954bb40925ffd306457f7b8fb4f Mon Sep 17 00:00:00 2001 From: chayleaf Date: Sat, 22 Jun 2024 21:08:17 +0700 Subject: [PATCH 1/3] asyncify KeyValueDatabaseEngine --- src/database/abstraction.rs | 12 +- src/database/abstraction/persy.rs | 8 +- src/database/abstraction/rocksdb.rs | 10 +- src/database/abstraction/sqlite.rs | 10 +- src/database/key_value/globals.rs | 8 +- src/database/mod.rs | 174 ++++++++++++++-------------- src/service/admin/mod.rs | 2 +- src/service/globals/data.rs | 4 +- src/service/globals/mod.rs | 4 +- 9 files changed, 121 insertions(+), 111 deletions(-) diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index 93660f9f..6086ba20 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -1,6 +1,7 @@ use super::Config; use crate::Result; +use async_trait::async_trait; use std::{future::Future, pin::Pin, sync::Arc}; #[cfg(feature = "sled")] @@ -26,16 +27,17 @@ pub mod persy; ))] pub mod watchers; +#[async_trait] pub trait KeyValueDatabaseEngine: Send + Sync { - fn open(config: &Config) -> Result + async fn open(config: &Config) -> Result where Self: Sized; - fn open_tree(&self, name: &'static str) -> Result>; - fn flush(&self) -> Result<()>; - fn cleanup(&self) -> Result<()> { + async fn open_tree(&self, name: &'static str) -> Result>; + async fn flush(&self) -> Result<()>; + async fn cleanup(&self) -> Result<()> { Ok(()) } - fn memory_usage(&self) -> Result { + async fn memory_usage(&self) -> Result { Ok("Current database engine does not support memory usage reporting.".to_owned()) } } diff --git a/src/database/abstraction/persy.rs b/src/database/abstraction/persy.rs index da7d4cf0..31197809 100644 --- a/src/database/abstraction/persy.rs +++ b/src/database/abstraction/persy.rs @@ -5,6 +5,7 @@ use crate::{ }, Result, }; +use async_trait::async_trait; use persy::{ByteVec, OpenOptions, Persy, Transaction, TransactionConfig, ValueMode}; use std::{future::Future, pin::Pin, sync::Arc}; @@ -15,8 +16,9 @@ pub struct Engine { persy: Persy, } +#[async_trait] impl KeyValueDatabaseEngine for Arc { - fn open(config: &Config) -> Result { + async fn open(config: &Config) -> Result { let mut cfg = persy::Config::new(); cfg.change_cache_size((config.db_cache_capacity_mb * 1024.0 * 1024.0) as u64); @@ -27,7 +29,7 @@ impl KeyValueDatabaseEngine for Arc { Ok(Arc::new(Engine { persy })) } - fn open_tree(&self, name: &'static str) -> Result> { + async fn open_tree(&self, name: &'static str) -> Result> { // Create if it doesn't exist if !self.persy.exists_index(name)? { let mut tx = self.persy.begin()?; @@ -42,7 +44,7 @@ impl KeyValueDatabaseEngine for Arc { })) } - fn flush(&self) -> Result<()> { + async fn flush(&self) -> Result<()> { Ok(()) } } diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index cf77e3dd..0dddfea0 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -1,5 +1,6 @@ use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree}; use crate::{utils, Result}; +use async_trait::async_trait; use std::{ future::Future, pin::Pin, @@ -56,8 +57,9 @@ fn db_options(max_open_files: i32, rocksdb_cache: &rocksdb::Cache) -> rocksdb::O db_opts } +#[async_trait] impl KeyValueDatabaseEngine for Arc { - fn open(config: &Config) -> Result { + async fn open(config: &Config) -> Result { let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize; let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes); @@ -88,7 +90,7 @@ impl KeyValueDatabaseEngine for Arc { })) } - fn open_tree(&self, name: &'static str) -> Result> { + async fn open_tree(&self, name: &'static str) -> Result> { if !self.old_cfs.contains(&name.to_owned()) { // Create if it didn't exist let _ = self @@ -104,12 +106,12 @@ impl KeyValueDatabaseEngine for Arc { })) } - fn flush(&self) -> Result<()> { + async fn flush(&self) -> Result<()> { // TODO? Ok(()) } - fn memory_usage(&self) -> Result { + async fn memory_usage(&self) -> Result { let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?; Ok(format!( diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index b448c3b6..a94b78d2 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -1,5 +1,6 @@ use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}; use crate::{database::Config, Result}; +use async_trait::async_trait; use parking_lot::{Mutex, MutexGuard}; use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; use std::{ @@ -80,8 +81,9 @@ impl Engine { } } +#[async_trait] impl KeyValueDatabaseEngine for Arc { - fn open(config: &Config) -> Result { + async fn open(config: &Config) -> Result { let path = Path::new(&config.database_path).join("conduit.db"); // calculates cache-size per permanent connection @@ -105,7 +107,7 @@ impl KeyValueDatabaseEngine for Arc { Ok(arc) } - fn open_tree(&self, name: &str) -> Result> { + async fn open_tree(&self, name: &'static str) -> Result> { self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [])?; Ok(Arc::new(SqliteTable { @@ -115,12 +117,12 @@ impl KeyValueDatabaseEngine for Arc { })) } - fn flush(&self) -> Result<()> { + async fn flush(&self) -> Result<()> { // we enabled PRAGMA synchronous=normal, so this should not be necessary Ok(()) } - fn cleanup(&self) -> Result<()> { + async fn cleanup(&self) -> Result<()> { self.flush_wal() } } diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index bd47cb42..968d6420 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -139,11 +139,11 @@ impl service::globals::Data for KeyValueDatabase { Ok(()) } - fn cleanup(&self) -> Result<()> { - self._db.cleanup() + async fn cleanup(&self) -> Result<()> { + self._db.cleanup().await } - fn memory_usage(&self) -> String { + async fn memory_usage(&self) -> String { let pdu_cache = self.pdu_cache.lock().unwrap().len(); let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len(); let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len(); @@ -164,7 +164,7 @@ our_real_users_cache: {our_real_users_cache} appservice_in_room_cache: {appservice_in_room_cache} lasttimelinecount_cache: {lasttimelinecount_cache}\n" ); - if let Ok(db_stats) = self._db.memory_usage() { + if let Ok(db_stats) = self._db.memory_usage().await { response += &db_stats; } diff --git a/src/database/mod.rs b/src/database/mod.rs index 5171d4bb..1d549f19 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -249,19 +249,19 @@ impl KeyValueDatabase { #[cfg(not(feature = "sqlite"))] return Err(Error::BadConfig("Database backend not found.")); #[cfg(feature = "sqlite")] - Arc::new(Arc::::open(&config)?) + Arc::new(Arc::::open(&config).await?) } "rocksdb" => { #[cfg(not(feature = "rocksdb"))] return Err(Error::BadConfig("Database backend not found.")); #[cfg(feature = "rocksdb")] - Arc::new(Arc::::open(&config)?) + Arc::new(Arc::::open(&config).await?) } "persy" => { #[cfg(not(feature = "persy"))] return Err(Error::BadConfig("Database backend not found.")); #[cfg(feature = "persy")] - Arc::new(Arc::::open(&config)?) + Arc::new(Arc::::open(&config).await?) } _ => { return Err(Error::BadConfig("Database backend not found.")); @@ -278,101 +278,102 @@ impl KeyValueDatabase { let db_raw = Box::new(Self { _db: builder.clone(), - userid_password: builder.open_tree("userid_password")?, - userid_displayname: builder.open_tree("userid_displayname")?, - userid_avatarurl: builder.open_tree("userid_avatarurl")?, - userid_blurhash: builder.open_tree("userid_blurhash")?, - userdeviceid_token: builder.open_tree("userdeviceid_token")?, - userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, - userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, - token_userdeviceid: builder.open_tree("token_userdeviceid")?, - onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, - userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, - keychangeid_userid: builder.open_tree("keychangeid_userid")?, - keyid_key: builder.open_tree("keyid_key")?, - userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, - userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, - userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, - openidtoken_expiresatuserid: builder.open_tree("openidtoken_expiresatuserid")?, - userfilterid_filter: builder.open_tree("userfilterid_filter")?, - todeviceid_events: builder.open_tree("todeviceid_events")?, + userid_password: builder.open_tree("userid_password").await?, + userid_displayname: builder.open_tree("userid_displayname").await?, + userid_avatarurl: builder.open_tree("userid_avatarurl").await?, + userid_blurhash: builder.open_tree("userid_blurhash").await?, + userdeviceid_token: builder.open_tree("userdeviceid_token").await?, + userdeviceid_metadata: builder.open_tree("userdeviceid_metadata").await?, + userid_devicelistversion: builder.open_tree("userid_devicelistversion").await?, + token_userdeviceid: builder.open_tree("token_userdeviceid").await?, + onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys").await?, + userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate").await?, + keychangeid_userid: builder.open_tree("keychangeid_userid").await?, + keyid_key: builder.open_tree("keyid_key").await?, + userid_masterkeyid: builder.open_tree("userid_masterkeyid").await?, + userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid").await?, + userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid").await?, + openidtoken_expiresatuserid: builder.open_tree("openidtoken_expiresatuserid").await?, + userfilterid_filter: builder.open_tree("userfilterid_filter").await?, + todeviceid_events: builder.open_tree("todeviceid_events").await?, - userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, + userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo").await?, userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), - readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, - roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt + readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt").await?, + roomuserid_privateread: builder.open_tree("roomuserid_privateread").await?, // "Private" read receipt roomuserid_lastprivatereadupdate: builder - .open_tree("roomuserid_lastprivatereadupdate")?, - presenceid_presence: builder.open_tree("presenceid_presence")?, - userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?, - pduid_pdu: builder.open_tree("pduid_pdu")?, - eventid_pduid: builder.open_tree("eventid_pduid")?, - roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, + .open_tree("roomuserid_lastprivatereadupdate") + .await?, + presenceid_presence: builder.open_tree("presenceid_presence").await?, + userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate").await?, + pduid_pdu: builder.open_tree("pduid_pdu").await?, + eventid_pduid: builder.open_tree("eventid_pduid").await?, + roomid_pduleaves: builder.open_tree("roomid_pduleaves").await?, - alias_roomid: builder.open_tree("alias_roomid")?, - aliasid_alias: builder.open_tree("aliasid_alias")?, - publicroomids: builder.open_tree("publicroomids")?, + alias_roomid: builder.open_tree("alias_roomid").await?, + aliasid_alias: builder.open_tree("aliasid_alias").await?, + publicroomids: builder.open_tree("publicroomids").await?, - threadid_userids: builder.open_tree("threadid_userids")?, + threadid_userids: builder.open_tree("threadid_userids").await?, - tokenids: builder.open_tree("tokenids")?, + tokenids: builder.open_tree("tokenids").await?, - roomserverids: builder.open_tree("roomserverids")?, - serverroomids: builder.open_tree("serverroomids")?, - userroomid_joined: builder.open_tree("userroomid_joined")?, - roomuserid_joined: builder.open_tree("roomuserid_joined")?, - roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, - roomid_invitedcount: builder.open_tree("roomid_invitedcount")?, - roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, - userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, - roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, - userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, - roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, + roomserverids: builder.open_tree("roomserverids").await?, + serverroomids: builder.open_tree("serverroomids").await?, + userroomid_joined: builder.open_tree("userroomid_joined").await?, + roomuserid_joined: builder.open_tree("roomuserid_joined").await?, + roomid_joinedcount: builder.open_tree("roomid_joinedcount").await?, + roomid_invitedcount: builder.open_tree("roomid_invitedcount").await?, + roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids").await?, + userroomid_invitestate: builder.open_tree("userroomid_invitestate").await?, + roomuserid_invitecount: builder.open_tree("roomuserid_invitecount").await?, + userroomid_leftstate: builder.open_tree("userroomid_leftstate").await?, + roomuserid_leftcount: builder.open_tree("roomuserid_leftcount").await?, - alias_userid: builder.open_tree("alias_userid")?, + alias_userid: builder.open_tree("alias_userid").await?, - disabledroomids: builder.open_tree("disabledroomids")?, + disabledroomids: builder.open_tree("disabledroomids").await?, - lazyloadedids: builder.open_tree("lazyloadedids")?, + lazyloadedids: builder.open_tree("lazyloadedids").await?, - userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, - userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, - roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?, + userroomid_notificationcount: builder.open_tree("userroomid_notificationcount").await?, + userroomid_highlightcount: builder.open_tree("userroomid_highlightcount").await?, + roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount").await?, - statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, - shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?, + statekey_shortstatekey: builder.open_tree("statekey_shortstatekey").await?, + shortstatekey_statekey: builder.open_tree("shortstatekey_statekey").await?, - shorteventid_authchain: builder.open_tree("shorteventid_authchain")?, + shorteventid_authchain: builder.open_tree("shorteventid_authchain").await?, - roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, + roomid_shortroomid: builder.open_tree("roomid_shortroomid").await?, - shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, - eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, - shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, - shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, - roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, - roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?, - statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, + shortstatehash_statediff: builder.open_tree("shortstatehash_statediff").await?, + eventid_shorteventid: builder.open_tree("eventid_shorteventid").await?, + shorteventid_eventid: builder.open_tree("shorteventid_eventid").await?, + shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash").await?, + roomid_shortstatehash: builder.open_tree("roomid_shortstatehash").await?, + roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash").await?, + statehash_shortstatehash: builder.open_tree("statehash_shortstatehash").await?, - eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, - softfailedeventids: builder.open_tree("softfailedeventids")?, + eventid_outlierpdu: builder.open_tree("eventid_outlierpdu").await?, + softfailedeventids: builder.open_tree("softfailedeventids").await?, - tofrom_relation: builder.open_tree("tofrom_relation")?, - referencedevents: builder.open_tree("referencedevents")?, - roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, - roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, - mediaid_file: builder.open_tree("mediaid_file")?, - backupid_algorithm: builder.open_tree("backupid_algorithm")?, - backupid_etag: builder.open_tree("backupid_etag")?, - backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, - userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, - servername_educount: builder.open_tree("servername_educount")?, - servernameevent_data: builder.open_tree("servernameevent_data")?, - servercurrentevent_data: builder.open_tree("servercurrentevent_data")?, - id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, - senderkey_pusher: builder.open_tree("senderkey_pusher")?, - global: builder.open_tree("global")?, - server_signingkeys: builder.open_tree("server_signingkeys")?, + tofrom_relation: builder.open_tree("tofrom_relation").await?, + referencedevents: builder.open_tree("referencedevents").await?, + roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata").await?, + roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid").await?, + mediaid_file: builder.open_tree("mediaid_file").await?, + backupid_algorithm: builder.open_tree("backupid_algorithm").await?, + backupid_etag: builder.open_tree("backupid_etag").await?, + backupkeyid_backup: builder.open_tree("backupkeyid_backup").await?, + userdevicetxnid_response: builder.open_tree("userdevicetxnid_response").await?, + servername_educount: builder.open_tree("servername_educount").await?, + servernameevent_data: builder.open_tree("servernameevent_data").await?, + servercurrentevent_data: builder.open_tree("servercurrentevent_data").await?, + id_appserviceregistrations: builder.open_tree("id_appserviceregistrations").await?, + senderkey_pusher: builder.open_tree("senderkey_pusher").await?, + global: builder.open_tree("global").await?, + server_signingkeys: builder.open_tree("server_signingkeys").await?, pdu_cache: Mutex::new(LruCache::new( config @@ -624,7 +625,7 @@ impl KeyValueDatabase { Ok::<_, Error>(()) }; - for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() { + for (k, seventid) in db._db.open_tree("stateid_shorteventid").await?.iter() { let sstatehash = utils::u64_from_bytes(&k[0..size_of::()]) .expect("number of bytes is correct"); let sstatekey = k[size_of::()..].to_vec(); @@ -808,7 +809,8 @@ impl KeyValueDatabase { if services().globals.database_version()? < 11 { db._db - .open_tree("userdevicesessionid_uiaarequest")? + .open_tree("userdevicesessionid_uiaarequest") + .await? .clear()?; services().globals.bump_database_version(11)?; @@ -998,10 +1000,10 @@ impl KeyValueDatabase { } #[tracing::instrument(skip(self))] - pub fn flush(&self) -> Result<()> { + pub async fn flush(&self) -> Result<()> { let start = std::time::Instant::now(); - let res = self._db.flush(); + let res = self._db.flush().await; debug!("flush: took {:?}", start.elapsed()); @@ -1094,7 +1096,7 @@ impl KeyValueDatabase { } let start = Instant::now(); - if let Err(e) = services().globals.cleanup() { + if let Err(e) = services().globals.cleanup().await { error!("cleanup: Errored: {}", e); } else { debug!("cleanup: Finished in {:?}", start.elapsed()); diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 70c63381..718219c4 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -522,7 +522,7 @@ impl Service { } AdminCommand::MemoryUsage => { let response1 = services().memory_usage().await; - let response2 = services().globals.db.memory_usage(); + let response2 = services().globals.db.memory_usage().await; RoomMessageEventContent::text_plain(format!( "Services:\n{response1}\n\nDatabase:\n{response2}" diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 167e823c..5fd84539 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -74,8 +74,8 @@ pub trait Data: Send + Sync { fn last_check_for_updates_id(&self) -> Result; fn update_check_for_updates_id(&self, id: u64) -> Result<()>; async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; - fn cleanup(&self) -> Result<()>; - fn memory_usage(&self) -> String; + async fn cleanup(&self) -> Result<()>; + async fn memory_usage(&self) -> String; fn clear_caches(&self, amount: u32); fn load_keypair(&self) -> Result; fn remove_keypair(&self) -> Result<()>; diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index c22ffef3..c0c0bd44 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -281,8 +281,8 @@ impl Service { self.db.watch(user_id, device_id).await } - pub fn cleanup(&self) -> Result<()> { - self.db.cleanup() + pub async fn cleanup(&self) -> Result<()> { + self.db.cleanup().await } pub fn server_name(&self) -> &ServerName { From a8c9e3eebec0e07c1f9c48e14377bac0b3a1836f Mon Sep 17 00:00:00 2001 From: chayleaf Date: Sun, 23 Jun 2024 00:31:27 +0700 Subject: [PATCH 2/3] switch Iterator to Send + Iterator in return types This changes the SQLite implementation quite a bit. It may potentially reduce performance, but allows using the new async API and removes some usage of unsafe --- src/database/abstraction.rs | 6 +- src/database/abstraction/heed.rs | 8 +- src/database/abstraction/persy.rs | 6 +- src/database/abstraction/rocksdb.rs | 6 +- src/database/abstraction/sled.rs | 6 +- src/database/abstraction/sqlite.rs | 150 +++++++----------- src/database/key_value/appservice.rs | 2 +- src/database/key_value/pusher.rs | 2 +- src/database/key_value/rooms/alias.rs | 2 +- src/database/key_value/rooms/directory.rs | 2 +- .../key_value/rooms/edus/read_receipt.rs | 3 +- src/database/key_value/rooms/metadata.rs | 2 +- src/database/key_value/rooms/pdu_metadata.rs | 2 +- src/database/key_value/rooms/search.rs | 2 +- src/database/key_value/rooms/state_cache.rs | 18 ++- src/database/key_value/rooms/threads.rs | 2 +- src/database/key_value/rooms/timeline.rs | 4 +- src/database/key_value/rooms/user.rs | 2 +- src/database/key_value/sending.rs | 7 +- src/database/key_value/users.rs | 8 +- src/service/appservice/data.rs | 2 +- src/service/pusher/data.rs | 6 +- src/service/pusher/mod.rs | 2 +- src/service/rooms/alias/data.rs | 2 +- src/service/rooms/alias/mod.rs | 2 +- src/service/rooms/directory/data.rs | 2 +- src/service/rooms/edus/read_receipt/data.rs | 3 +- src/service/rooms/metadata/data.rs | 2 +- src/service/rooms/metadata/mod.rs | 2 +- src/service/rooms/pdu_metadata/data.rs | 2 +- src/service/rooms/search/data.rs | 2 +- src/service/rooms/state_cache/data.rs | 16 +- src/service/rooms/threads/data.rs | 2 +- src/service/rooms/timeline/data.rs | 4 +- src/service/rooms/user/data.rs | 2 +- src/service/sending/data.rs | 6 +- src/service/users/data.rs | 8 +- 37 files changed, 139 insertions(+), 166 deletions(-) diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index 6086ba20..bc9f09dc 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -50,13 +50,13 @@ pub trait KvTree: Send + Sync { fn remove(&self, key: &[u8]) -> Result<()>; - fn iter<'a>(&'a self) -> Box, Vec)> + 'a>; + fn iter<'a>(&'a self) -> Box, Vec)> + 'a>; fn iter_from<'a>( &'a self, from: &[u8], backwards: bool, - ) -> Box, Vec)> + 'a>; + ) -> Box, Vec)> + 'a>; fn increment(&self, key: &[u8]) -> Result>; fn increment_batch(&self, iter: &mut dyn Iterator>) -> Result<()>; @@ -64,7 +64,7 @@ pub trait KvTree: Send + Sync { fn scan_prefix<'a>( &'a self, prefix: Vec, - ) -> Box, Vec)> + 'a>; + ) -> Box, Vec)> + 'a>; fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>>; diff --git a/src/database/abstraction/heed.rs b/src/database/abstraction/heed.rs index 9cca0975..63ea5819 100644 --- a/src/database/abstraction/heed.rs +++ b/src/database/abstraction/heed.rs @@ -74,7 +74,7 @@ impl EngineTree { tree: Arc, from: Vec, backwards: bool, - ) -> Box + Send + Sync> { + ) -> Box + Send + Sync> { let (s, r) = bounded::(100); let engine = Arc::clone(&self.engine); @@ -150,7 +150,7 @@ impl Tree for EngineTree { Ok(()) } - fn iter<'a>(&'a self) -> Box, Vec)> + Send + 'a> { + fn iter<'a>(&'a self) -> Box, Vec)> + Send + 'a> { self.iter_from(&[], false) } @@ -158,7 +158,7 @@ impl Tree for EngineTree { &self, from: &[u8], backwards: bool, - ) -> Box, Vec)> + Send> { + ) -> Box, Vec)> + Send> { self.iter_from_thread(Arc::clone(&self.tree), from.to_vec(), backwards) } @@ -181,7 +181,7 @@ impl Tree for EngineTree { fn scan_prefix<'a>( &'a self, prefix: Vec, - ) -> Box, Vec)> + Send + 'a> { + ) -> Box, Vec)> + Send + 'a> { Box::new( self.iter_from(&prefix, false) .take_while(move |(key, _)| key.starts_with(&prefix)), diff --git a/src/database/abstraction/persy.rs b/src/database/abstraction/persy.rs index 31197809..5c146eb0 100644 --- a/src/database/abstraction/persy.rs +++ b/src/database/abstraction/persy.rs @@ -113,7 +113,7 @@ impl KvTree for PersyTree { Ok(()) } - fn iter<'a>(&'a self) -> Box, Vec)> + 'a> { + fn iter<'a>(&'a self) -> Box, Vec)> + 'a> { let iter = self.persy.range::(&self.name, ..); match iter { Ok(iter) => Box::new(iter.filter_map(|(k, v)| { @@ -132,7 +132,7 @@ impl KvTree for PersyTree { &'a self, from: &[u8], backwards: bool, - ) -> Box, Vec)> + 'a> { + ) -> Box, Vec)> + 'a> { let range = if backwards { self.persy .range::(&self.name, ..=ByteVec::from(from)) @@ -168,7 +168,7 @@ impl KvTree for PersyTree { fn scan_prefix<'a>( &'a self, prefix: Vec, - ) -> Box, Vec)> + 'a> { + ) -> Box, Vec)> + 'a> { let range_prefix = ByteVec::from(prefix.clone()); let range = self .persy diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index 0dddfea0..72af45ed 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -175,7 +175,7 @@ impl KvTree for RocksDbEngineTree<'_> { .delete_cf_opt(&self.cf(), key, &writeoptions)?) } - fn iter<'a>(&'a self) -> Box, Vec)> + 'a> { + fn iter<'a>(&'a self) -> Box, Vec)> + 'a> { let readoptions = rocksdb::ReadOptions::default(); Box::new( @@ -191,7 +191,7 @@ impl KvTree for RocksDbEngineTree<'_> { &'a self, from: &[u8], backwards: bool, - ) -> Box, Vec)> + 'a> { + ) -> Box, Vec)> + 'a> { let readoptions = rocksdb::ReadOptions::default(); Box::new( @@ -252,7 +252,7 @@ impl KvTree for RocksDbEngineTree<'_> { fn scan_prefix<'a>( &'a self, prefix: Vec, - ) -> Box, Vec)> + 'a> { + ) -> Box, Vec)> + 'a> { let readoptions = rocksdb::ReadOptions::default(); Box::new( diff --git a/src/database/abstraction/sled.rs b/src/database/abstraction/sled.rs index 87defc57..6a4bc614 100644 --- a/src/database/abstraction/sled.rs +++ b/src/database/abstraction/sled.rs @@ -52,7 +52,7 @@ impl Tree for SledEngineTree { Ok(()) } - fn iter<'a>(&'a self) -> Box, Vec)> + 'a> { + fn iter<'a>(&'a self) -> Box, Vec)> + 'a> { Box::new( self.0 .iter() @@ -70,7 +70,7 @@ impl Tree for SledEngineTree { &self, from: &[u8], backwards: bool, - ) -> Box, Vec)>> { + ) -> Box, Vec)>> { let iter = if backwards { self.0.range(..=from) } else { @@ -103,7 +103,7 @@ impl Tree for SledEngineTree { fn scan_prefix<'a>( &'a self, prefix: Vec, - ) -> Box, Vec)> + 'a> { + ) -> Box, Vec)> + 'a> { let iter = self .0 .scan_prefix(prefix) diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index a94b78d2..12ed9361 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -8,7 +8,7 @@ use std::{ future::Future, path::{Path, PathBuf}, pin::Pin, - sync::Arc, + sync::{mpsc, Arc}, }; use thread_local::ThreadLocal; use tracing::debug; @@ -18,26 +18,6 @@ thread_local! { static READ_CONNECTION_ITERATOR: RefCell> = const { RefCell::new(None) }; } -struct PreparedStatementIterator<'a> { - pub iterator: Box + 'a>, - pub _statement_ref: NonAliasingBox>, -} - -impl Iterator for PreparedStatementIterator<'_> { - type Item = TupleOfBytes; - - fn next(&mut self) -> Option { - self.iterator.next() - } -} - -struct NonAliasingBox(*mut T); -impl Drop for NonAliasingBox { - fn drop(&mut self) { - drop(unsafe { Box::from_raw(self.0) }); - } -} - pub struct Engine { writer: Mutex, read_conn_tls: ThreadLocal, @@ -135,6 +115,22 @@ pub struct SqliteTable { type TupleOfBytes = (Vec, Vec); +fn iter_stmt( + from: impl rusqlite::Params, + mut statement: rusqlite::Statement, + out: mpsc::SyncSender, +) { + for item in statement + .query_map(from, |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(|r| r.unwrap()) + { + if out.send(item).is_err() { + break; + } + } +} + impl SqliteTable { fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result>> { Ok(guard @@ -155,34 +151,18 @@ impl SqliteTable { Ok(()) } - pub fn iter_with_guard<'a>( - &'a self, - guard: &'a Connection, - ) -> Box + 'a> { - let statement = Box::leak(Box::new( - guard - .prepare(&format!( - "SELECT key, value FROM {} ORDER BY key ASC", - &self.name - )) - .unwrap(), - )); - - let statement_ref = NonAliasingBox(statement); + pub fn iter_with_guard( + name: &str, + guard: &Connection, + out: mpsc::SyncSender, + ) { + let statement = guard + .prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", name)) + .unwrap(); //let name = self.name.clone(); - let iterator = Box::new( - statement - .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) - .unwrap() - .map(move |r| r.unwrap()), - ); - - Box::new(PreparedStatementIterator { - iterator, - _statement_ref: statement_ref, - }) + iter_stmt([], statement, out) } } @@ -241,68 +221,51 @@ impl KvTree for SqliteTable { Ok(()) } - fn iter<'a>(&'a self) -> Box + 'a> { - let guard = self.engine.read_lock_iterator(); + fn iter<'a>(&'a self) -> Box + 'a> { + let (tx, rx) = mpsc::sync_channel(1); + let engine = self.engine.clone(); + let name = self.name.clone(); + tokio::task::spawn_blocking(move || { + let guard = engine.read_lock_iterator(); - self.iter_with_guard(guard) + Self::iter_with_guard(&name, guard, tx) + }); + Box::new(rx.into_iter()) } fn iter_from<'a>( &'a self, from: &[u8], backwards: bool, - ) -> Box + 'a> { - let guard = self.engine.read_lock_iterator(); + ) -> Box + 'a> { + let (tx, rx) = mpsc::sync_channel(1); + let engine = self.engine.clone(); + let name = self.name.clone(); let from = from.to_vec(); // TODO change interface? + tokio::task::spawn_blocking(move || { + let guard = engine.read_lock_iterator(); - //let name = self.name.clone(); - - if backwards { - let statement = Box::leak(Box::new( - guard + if backwards { + let statement = guard .prepare(&format!( "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", - &self.name + &name )) - .unwrap(), - )); + .unwrap(); - let statement_ref = NonAliasingBox(statement); - - let iterator = Box::new( - statement - .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) - .unwrap() - .map(move |r| r.unwrap()), - ); - Box::new(PreparedStatementIterator { - iterator, - _statement_ref: statement_ref, - }) - } else { - let statement = Box::leak(Box::new( - guard + iter_stmt([from], statement, tx) + } else { + let statement = guard .prepare(&format!( "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", - &self.name + &name )) - .unwrap(), - )); + .unwrap(); - let statement_ref = NonAliasingBox(statement); - - let iterator = Box::new( - statement - .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) - .unwrap() - .map(move |r| r.unwrap()), - ); - - Box::new(PreparedStatementIterator { - iterator, - _statement_ref: statement_ref, - }) - } + iter_stmt([from], statement, tx) + } + }); + Box::new(rx.into_iter()) } fn increment(&self, key: &[u8]) -> Result> { @@ -318,7 +281,10 @@ impl KvTree for SqliteTable { Ok(new) } - fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box + 'a> { + fn scan_prefix<'a>( + &'a self, + prefix: Vec, + ) -> Box + 'a> { Box::new( self.iter_from(&prefix, false) .take_while(move |(key, _)| key.starts_with(&prefix)), diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index b547e66a..91c29ef6 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -36,7 +36,7 @@ impl service::appservice::Data for KeyValueDatabase { .transpose() } - fn iter_ids<'a>(&'a self) -> Result> + 'a>> { + fn iter_ids<'a>(&'a self) -> Result> + 'a>> { Ok(Box::new(self.id_appserviceregistrations.iter().map( |(id, _)| { utils::string_from_bytes(&id).map_err(|_| { diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index 50a6faca..b956bc7a 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -60,7 +60,7 @@ impl service::pusher::Data for KeyValueDatabase { fn get_pushkeys<'a>( &'a self, sender: &UserId, - ) -> Box> + 'a> { + ) -> Box> + 'a> { let mut prefix = sender.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index 2f7df781..1a27fbac 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -52,7 +52,7 @@ impl service::rooms::alias::Data for KeyValueDatabase { fn local_aliases_for_room<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a> { + ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs index e05dee82..0eefc4d1 100644 --- a/src/database/key_value/rooms/directory.rs +++ b/src/database/key_value/rooms/directory.rs @@ -15,7 +15,7 @@ impl service::rooms::directory::Data for KeyValueDatabase { Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) } - fn public_rooms<'a>(&'a self) -> Box> + 'a> { + fn public_rooms<'a>(&'a self) -> Box> + 'a> { Box::new(self.publicroomids.iter().map(|(bytes, _)| { RoomId::parse( utils::string_from_bytes(&bytes).map_err(|_| { diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index fa97ea34..fb7c9e99 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -53,7 +53,8 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { room_id: &RoomId, since: u64, ) -> Box< - dyn Iterator< + dyn Send + + Iterator< Item = Result<( OwnedUserId, u64, diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs index 57540c40..3959334e 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/database/key_value/rooms/metadata.rs @@ -18,7 +18,7 @@ impl service::rooms::metadata::Data for KeyValueDatabase { .is_some()) } - fn iter_ids<'a>(&'a self) -> Box> + 'a> { + fn iter_ids<'a>(&'a self) -> Box> + 'a> { Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { RoomId::parse( utils::string_from_bytes(&bytes).map_err(|_| { diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs index 0641f9d8..488ee96d 100644 --- a/src/database/key_value/rooms/pdu_metadata.rs +++ b/src/database/key_value/rooms/pdu_metadata.rs @@ -22,7 +22,7 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase { shortroomid: u64, target: u64, until: PduCount, - ) -> Result> + 'a>> { + ) -> Result> + 'a>> { let prefix = target.to_be_bytes().to_vec(); let mut current = prefix.clone(); diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index 8a2769bd..d57d6145 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -46,7 +46,7 @@ impl service::rooms::search::Data for KeyValueDatabase { &'a self, room_id: &RoomId, search_string: &str, - ) -> Result> + 'a>, Vec)>> { + ) -> Result> + 'a>, Vec)>> { let prefix = services() .rooms .short diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 49e3842b..29be2608 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -244,7 +244,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { fn room_servers<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a> { + ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -277,7 +277,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { fn server_rooms<'a>( &'a self, server: &ServerName, - ) -> Box> + 'a> { + ) -> Box> + 'a> { let mut prefix = server.as_bytes().to_vec(); prefix.push(0xff); @@ -299,7 +299,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { fn room_members<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a> { + ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -345,7 +345,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { fn room_useroncejoined<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a> { + ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -375,7 +375,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { fn room_members_invited<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a> { + ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -433,7 +433,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { fn rooms_joined<'a>( &'a self, user_id: &UserId, - ) -> Box> + 'a> { + ) -> Box> + 'a> { Box::new( self.userroomid_joined .scan_prefix(user_id.as_bytes().to_vec()) @@ -459,7 +459,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { fn rooms_invited<'a>( &'a self, user_id: &UserId, - ) -> Box>)>> + 'a> { + ) -> Box>)>> + 'a> + { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); @@ -538,7 +539,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { fn rooms_left<'a>( &'a self, user_id: &UserId, - ) -> Box>)>> + 'a> { + ) -> Box>)>> + 'a> + { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/key_value/rooms/threads.rs b/src/database/key_value/rooms/threads.rs index 5e3dc970..74a65071 100644 --- a/src/database/key_value/rooms/threads.rs +++ b/src/database/key_value/rooms/threads.rs @@ -11,7 +11,7 @@ impl service::rooms::threads::Data for KeyValueDatabase { room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, - ) -> Result> + 'a>> { + ) -> Result> + 'a>> { let prefix = services() .rooms .short diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 0331a624..ba862471 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -228,7 +228,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { user_id: &UserId, room_id: &RoomId, until: PduCount, - ) -> Result> + 'a>> { + ) -> Result> + 'a>> { let (prefix, current) = count_to_id(room_id, until, 1, true)?; let user_id = user_id.to_owned(); @@ -255,7 +255,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { user_id: &UserId, room_id: &RoomId, from: PduCount, - ) -> Result> + 'a>> { + ) -> Result> + 'a>> { let (prefix, current) = count_to_id(room_id, from, 1, false)?; let user_id = user_id.to_owned(); diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 4c435720..3d2d4a8f 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -111,7 +111,7 @@ impl service::rooms::user::Data for KeyValueDatabase { fn get_shared_rooms<'a>( &'a self, users: Vec, - ) -> Result> + 'a>> { + ) -> Result> + 'a>> { let iterators = users.into_iter().map(move |user_id| { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/key_value/sending.rs b/src/database/key_value/sending.rs index 3fc3e042..58380a05 100644 --- a/src/database/key_value/sending.rs +++ b/src/database/key_value/sending.rs @@ -12,7 +12,8 @@ use crate::{ impl service::sending::Data for KeyValueDatabase { fn active_requests<'a>( &'a self, - ) -> Box, OutgoingKind, SendingEventType)>> + 'a> { + ) -> Box, OutgoingKind, SendingEventType)>> + 'a> + { Box::new( self.servercurrentevent_data .iter() @@ -23,7 +24,7 @@ impl service::sending::Data for KeyValueDatabase { fn active_requests_for<'a>( &'a self, outgoing_kind: &OutgoingKind, - ) -> Box, SendingEventType)>> + 'a> { + ) -> Box, SendingEventType)>> + 'a> { let prefix = outgoing_kind.get_prefix(); Box::new( self.servercurrentevent_data @@ -87,7 +88,7 @@ impl service::sending::Data for KeyValueDatabase { fn queued_requests<'a>( &'a self, outgoing_kind: &OutgoingKind, - ) -> Box)>> + 'a> { + ) -> Box)>> + 'a> { let prefix = outgoing_kind.get_prefix(); return Box::new( self.servernameevent_data diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 63321a40..7fa24a12 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -68,7 +68,7 @@ impl service::users::Data for KeyValueDatabase { } /// Returns an iterator over all users on this homeserver. - fn iter<'a>(&'a self) -> Box> + 'a> { + fn iter<'a>(&'a self) -> Box> + 'a> { Box::new(self.userid_password.iter().map(|(bytes, _)| { UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("User ID in userid_password is invalid unicode.") @@ -259,7 +259,7 @@ impl service::users::Data for KeyValueDatabase { fn all_device_ids<'a>( &'a self, user_id: &UserId, - ) -> Box> + 'a> { + ) -> Box> + 'a> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); // All devices have metadata @@ -584,7 +584,7 @@ impl service::users::Data for KeyValueDatabase { user_or_room_id: &str, from: u64, to: Option, - ) -> Box> + 'a> { + ) -> Box> + 'a> { let mut prefix = user_or_room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -899,7 +899,7 @@ impl service::users::Data for KeyValueDatabase { fn all_devices_metadata<'a>( &'a self, user_id: &UserId, - ) -> Box> + 'a> { + ) -> Box> + 'a> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index ab19a50c..5611edf9 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -15,7 +15,7 @@ pub trait Data: Send + Sync { fn get_registration(&self, id: &str) -> Result>; - fn iter_ids<'a>(&'a self) -> Result> + 'a>>; + fn iter_ids<'a>(&'a self) -> Result> + 'a>>; fn all(&self) -> Result>; } diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs index 2062f567..50f95111 100644 --- a/src/service/pusher/data.rs +++ b/src/service/pusher/data.rs @@ -11,6 +11,8 @@ pub trait Data: Send + Sync { fn get_pushers(&self, sender: &UserId) -> Result>; - fn get_pushkeys<'a>(&'a self, sender: &UserId) - -> Box> + 'a>; + fn get_pushkeys<'a>( + &'a self, + sender: &UserId, + ) -> Box> + 'a>; } diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 83127e63..37973821 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -39,7 +39,7 @@ impl Service { self.db.get_pushers(sender) } - pub fn get_pushkeys(&self, sender: &UserId) -> Box>> { + pub fn get_pushkeys(&self, sender: &UserId) -> Box>> { self.db.get_pushkeys(sender) } diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index dd514072..c73799e4 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -18,5 +18,5 @@ pub trait Data: Send + Sync { fn local_aliases_for_room<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a>; + ) -> Box> + 'a>; } diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 95d52ad3..bd5693f1 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -95,7 +95,7 @@ impl Service { pub fn local_aliases_for_room<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a> { + ) -> Box> + 'a> { self.db.local_aliases_for_room(room_id) } } diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs index aca731ce..9bd8ef4a 100644 --- a/src/service/rooms/directory/data.rs +++ b/src/service/rooms/directory/data.rs @@ -12,5 +12,5 @@ pub trait Data: Send + Sync { fn is_public_room(&self, room_id: &RoomId) -> Result; /// Returns the unsorted public room directory - fn public_rooms<'a>(&'a self) -> Box> + 'a>; + fn public_rooms<'a>(&'a self) -> Box> + 'a>; } diff --git a/src/service/rooms/edus/read_receipt/data.rs b/src/service/rooms/edus/read_receipt/data.rs index 044dad82..41a33eed 100644 --- a/src/service/rooms/edus/read_receipt/data.rs +++ b/src/service/rooms/edus/read_receipt/data.rs @@ -17,7 +17,8 @@ pub trait Data: Send + Sync { room_id: &RoomId, since: u64, ) -> Box< - dyn Iterator< + dyn Send + + Iterator< Item = Result<( OwnedUserId, u64, diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index 339db573..4abdb3b1 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -3,7 +3,7 @@ use ruma::{OwnedRoomId, RoomId}; pub trait Data: Send + Sync { fn exists(&self, room_id: &RoomId) -> Result; - fn iter_ids<'a>(&'a self) -> Box> + 'a>; + fn iter_ids<'a>(&'a self) -> Box> + 'a>; fn is_disabled(&self, room_id: &RoomId) -> Result; fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>; } diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index d1884691..cfcc77fa 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -16,7 +16,7 @@ impl Service { self.db.exists(room_id) } - pub fn iter_ids<'a>(&'a self) -> Box> + 'a> { + pub fn iter_ids<'a>(&'a self) -> Box> + 'a> { self.db.iter_ids() } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index a4df34cc..8a9e1e5d 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -12,7 +12,7 @@ pub trait Data: Send + Sync { room_id: u64, target: u64, until: PduCount, - ) -> Result> + 'a>>; + ) -> Result> + 'a>>; fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()>; fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result; fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>; diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 7dbfd56a..18f92b2f 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -11,5 +11,5 @@ pub trait Data: Send + Sync { &'a self, room_id: &RoomId, search_string: &str, - ) -> Result> + 'a>, Vec)>>; + ) -> Result> + 'a>, Vec)>>; } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index b511919a..76dcc6cc 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -31,7 +31,7 @@ pub trait Data: Send + Sync { fn room_servers<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a>; + ) -> Box> + 'a>; fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result; @@ -39,13 +39,13 @@ pub trait Data: Send + Sync { fn server_rooms<'a>( &'a self, server: &ServerName, - ) -> Box> + 'a>; + ) -> Box> + 'a>; /// Returns an iterator over all joined members of a room. fn room_members<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a>; + ) -> Box> + 'a>; fn room_joined_count(&self, room_id: &RoomId) -> Result>; @@ -55,13 +55,13 @@ pub trait Data: Send + Sync { fn room_useroncejoined<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a>; + ) -> Box> + 'a>; /// Returns an iterator over all invited members of a room. fn room_members_invited<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a>; + ) -> Box> + 'a>; fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; @@ -71,14 +71,14 @@ pub trait Data: Send + Sync { fn rooms_joined<'a>( &'a self, user_id: &UserId, - ) -> Box> + 'a>; + ) -> Box> + 'a>; /// Returns an iterator over all rooms a user was invited to. #[allow(clippy::type_complexity)] fn rooms_invited<'a>( &'a self, user_id: &UserId, - ) -> Box>)>> + 'a>; + ) -> Box>)>> + 'a>; fn invite_state( &self, @@ -97,7 +97,7 @@ pub trait Data: Send + Sync { fn rooms_left<'a>( &'a self, user_id: &UserId, - ) -> Box>)>> + 'a>; + ) -> Box>)>> + 'a>; fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result; diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index e7159de0..9612a162 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -9,7 +9,7 @@ pub trait Data: Send + Sync { room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, - ) -> Result> + 'a>>; + ) -> Result> + 'a>>; fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()>; fn get_participants(&self, root_id: &[u8]) -> Result>>; diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 6290b8cc..afec6f78 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -72,7 +72,7 @@ pub trait Data: Send + Sync { user_id: &UserId, room_id: &RoomId, until: PduCount, - ) -> Result> + 'a>>; + ) -> Result> + 'a>>; /// Returns an iterator over all events in a room that happened after the event with id `from` /// in chronological order. @@ -82,7 +82,7 @@ pub trait Data: Send + Sync { user_id: &UserId, room_id: &RoomId, from: PduCount, - ) -> Result> + 'a>>; + ) -> Result> + 'a>>; fn increment_notification_counts( &self, diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index 4b8a4eca..5544af2c 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -23,5 +23,5 @@ pub trait Data: Send + Sync { fn get_shared_rooms<'a>( &'a self, users: Vec, - ) -> Result> + 'a>>; + ) -> Result> + 'a>>; } diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 8b4d236f..78d3f1e1 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -8,11 +8,11 @@ pub trait Data: Send + Sync { #[allow(clippy::type_complexity)] fn active_requests<'a>( &'a self, - ) -> Box, OutgoingKind, SendingEventType)>> + 'a>; + ) -> Box, OutgoingKind, SendingEventType)>> + 'a>; fn active_requests_for<'a>( &'a self, outgoing_kind: &OutgoingKind, - ) -> Box, SendingEventType)>> + 'a>; + ) -> Box, SendingEventType)>> + 'a>; fn delete_active_request(&self, key: Vec) -> Result<()>; fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; @@ -23,7 +23,7 @@ pub trait Data: Send + Sync { fn queued_requests<'a>( &'a self, outgoing_kind: &OutgoingKind, - ) -> Box)>> + 'a>; + ) -> Box)>> + 'a>; fn mark_as_active(&self, events: &[(SendingEventType, Vec)]) -> Result<()>; fn set_latest_educount(&self, server_name: &ServerName, educount: u64) -> Result<()>; fn get_latest_educount(&self, server_name: &ServerName) -> Result; diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 4566c36d..75d7eb2c 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -23,7 +23,7 @@ pub trait Data: Send + Sync { fn find_from_token(&self, token: &str) -> Result>; /// Returns an iterator over all users on this homeserver. - fn iter<'a>(&'a self) -> Box> + 'a>; + fn iter<'a>(&'a self) -> Box> + 'a>; /// Returns a list of local users as list of usernames. /// @@ -70,7 +70,7 @@ pub trait Data: Send + Sync { fn all_device_ids<'a>( &'a self, user_id: &UserId, - ) -> Box> + 'a>; + ) -> Box> + 'a>; /// Replaces the access token of one device. fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; @@ -127,7 +127,7 @@ pub trait Data: Send + Sync { user_or_room_id: &str, from: u64, to: Option, - ) -> Box> + 'a>; + ) -> Box> + 'a>; fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>; @@ -205,7 +205,7 @@ pub trait Data: Send + Sync { fn all_devices_metadata<'a>( &'a self, user_id: &UserId, - ) -> Box> + 'a>; + ) -> Box> + 'a>; /// Creates a new sync filter. Returns the filter id. fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result; From 7658414fc4d21e7dda2fa0fb161d337f910eb40d Mon Sep 17 00:00:00 2001 From: chayleaf Date: Sat, 22 Jun 2024 21:22:43 +0700 Subject: [PATCH 3/3] KvTree: asyncify clear and increment --- src/api/client_server/account.rs | 44 +++--- src/api/client_server/alias.rs | 3 +- src/api/client_server/backup.rs | 51 ++++--- src/api/client_server/config.rs | 42 +++--- src/api/client_server/device.rs | 11 +- src/api/client_server/keys.rs | 45 +++--- src/api/client_server/membership.rs | 112 +++++++++------ src/api/client_server/push.rs | 60 ++++---- src/api/client_server/read_marker.rs | 92 +++++++----- src/api/client_server/relations.rs | 7 +- src/api/client_server/room.rs | 15 +- src/api/client_server/session.rs | 25 ++-- src/api/client_server/sync.rs | 5 +- src/api/client_server/tag.rs | 30 ++-- src/api/client_server/to_device.rs | 69 +++++---- src/api/server_server.rs | 128 ++++++++++------- src/database/abstraction.rs | 5 +- src/database/abstraction/persy.rs | 3 +- src/database/abstraction/rocksdb.rs | 3 +- src/database/abstraction/sqlite.rs | 5 +- src/database/key_value/account_data.rs | 6 +- src/database/key_value/globals.rs | 4 +- src/database/key_value/key_backups.rs | 16 ++- src/database/key_value/rooms/alias.rs | 11 +- src/database/key_value/rooms/edus/presence.rs | 6 +- .../key_value/rooms/edus/read_receipt.rs | 10 +- src/database/key_value/rooms/short.rs | 18 +-- src/database/key_value/rooms/state_cache.rs | 10 +- src/database/key_value/rooms/user.rs | 6 +- src/database/key_value/sending.rs | 6 +- src/database/key_value/users.rs | 45 +++--- src/database/mod.rs | 66 +++++---- src/service/account_data/data.rs | 4 +- src/service/account_data/mod.rs | 4 +- src/service/admin/mod.rs | 42 +++--- src/service/globals/data.rs | 2 +- src/service/globals/mod.rs | 4 +- src/service/key_backups/data.rs | 8 +- src/service/key_backups/mod.rs | 13 +- src/service/rooms/alias/data.rs | 9 +- src/service/rooms/alias/mod.rs | 9 +- src/service/rooms/auth_chain/mod.rs | 17 ++- src/service/rooms/edus/presence/data.rs | 4 +- src/service/rooms/edus/read_receipt/data.rs | 6 +- src/service/rooms/edus/read_receipt/mod.rs | 13 +- src/service/rooms/edus/typing/mod.rs | 6 +- src/service/rooms/event_handler/mod.rs | 135 ++++++++++-------- src/service/rooms/pdu_metadata/mod.rs | 34 ++--- src/service/rooms/short/data.rs | 10 +- src/service/rooms/short/mod.rs | 18 +-- src/service/rooms/state/mod.rs | 34 ++--- src/service/rooms/state_accessor/mod.rs | 1 + src/service/rooms/state_cache/data.rs | 6 +- src/service/rooms/state_cache/mod.rs | 26 ++-- src/service/rooms/state_compressor/mod.rs | 10 +- src/service/rooms/timeline/mod.rs | 78 ++++++---- src/service/rooms/user/data.rs | 4 +- src/service/rooms/user/mod.rs | 8 +- src/service/sending/data.rs | 4 +- src/service/sending/mod.rs | 38 +++-- src/service/users/data.rs | 22 +-- src/service/users/mod.rs | 80 ++++++----- 62 files changed, 958 insertions(+), 650 deletions(-) diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index 36640b54..4e033d18 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -218,17 +218,20 @@ pub async fn register_route(body: Ruma) -> Result) -> Result bool>( } let json = serde_json::to_value(master_key).expect("to_value always works"); let raw = serde_json::from_value(json).expect("Raw::from_value always works"); - services().users.add_cross_signing_keys( - &user, &raw, &None, &None, - false, // Dont notify. A notification would trigger another key request resulting in an endless loop - )?; + services() + .users + .add_cross_signing_keys( + &user, &raw, &None, &None, + false, // Dont notify. A notification would trigger another key request resulting in an endless loop + ) + .await?; master_keys.insert(user, raw); } @@ -481,10 +490,10 @@ pub(crate) async fn claim_keys_helper( let mut container = BTreeMap::new(); for (device_id, key_algorithm) in map { - if let Some(one_time_keys) = - services() - .users - .take_one_time_key(user_id, device_id, key_algorithm)? + if let Some(one_time_keys) = services() + .users + .take_one_time_key(user_id, device_id, key_algorithm) + .await? { let mut c = BTreeMap::new(); c.insert(one_time_keys.0, one_time_keys.1); diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 1ca711e2..9650da3a 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -703,7 +703,11 @@ async fn join_room_by_id_helper( } } - services().rooms.short.get_or_create_shortroomid(room_id)?; + services() + .rooms + .short + .get_or_create_shortroomid(room_id) + .await?; info!("Parsing join event"); let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) @@ -744,7 +748,8 @@ async fn join_room_by_id_helper( let shortstatekey = services() .rooms .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) + .await?; state.insert(shortstatekey, pdu.event_id.clone()); } } @@ -781,8 +786,8 @@ async fn join_room_by_id_helper( &services() .rooms .short - .get_or_create_shortstatekey(&k.to_string().into(), s) - .ok()?, + .get_shortstatekey(&k.to_string().into(), s) + .ok()??, )?, ) .ok()? @@ -801,20 +806,23 @@ async fn join_room_by_id_helper( } info!("Saving state from send_join"); - let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state( - room_id, - Arc::new( - state - .into_iter() - .map(|(k, id)| { + let (statehash_before_join, new, removed) = services() + .rooms + .state_compressor + .save_state(room_id, { + let mut new_state = HashSet::new(); + for (k, id) in state { + new_state.insert( services() .rooms .state_compressor .compress_state_event(k, &id) - }) - .collect::>()?, - ), - )?; + .await?, + ); + } + Arc::new(new_state) + }) + .await?; services() .rooms @@ -827,7 +835,11 @@ async fn join_room_by_id_helper( // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. - let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; + let statehash_after_join = services() + .rooms + .state + .append_to_state(&parsed_join_pdu) + .await?; info!("Appending new room join event"); services() @@ -1253,18 +1265,22 @@ pub(crate) async fn invite_helper<'a>( }) .expect("member event is valid value"); - let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - sender_user, - room_id, - &state_lock, - )?; + let (pdu, pdu_json) = services() + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + sender_user, + room_id, + &state_lock, + ) + .await?; let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?; @@ -1335,7 +1351,7 @@ pub(crate) async fn invite_helper<'a>( .filter_map(|r| r.ok()) .filter(|server| &**server != services().globals.server_name()); - services().sending.send_pdu(servers, &pdu_id)?; + services().sending.send_pdu(servers, &pdu_id).await?; } else { if !services() .rooms @@ -1442,14 +1458,18 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option { error!("Trying to leave a room you are not a member of."); - services().rooms.state_cache.update_membership( - room_id, - user_id, - MembershipState::Leave, - user_id, - None, - true, - )?; + services() + .rooms + .state_cache + .update_membership( + room_id, + user_id, + MembershipState::Leave, + user_id, + None, + true, + ) + .await?; return Ok(()); } Some(e) => e, diff --git a/src/api/client_server/push.rs b/src/api/client_server/push.rs index 72768662..41e5fd36 100644 --- a/src/api/client_server/push.rs +++ b/src/api/client_server/push.rs @@ -143,12 +143,15 @@ pub async fn set_pushrule_route( return Err(err); } - services().account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule::v3::Response {}) } @@ -238,12 +241,15 @@ pub async fn set_pushrule_actions_route( )); } - services().account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule_actions::v3::Response {}) } @@ -332,12 +338,15 @@ pub async fn set_pushrule_enabled_route( )); } - services().account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule_enabled::v3::Response {}) } @@ -391,12 +400,15 @@ pub async fn delete_pushrule_route( return Err(err); } - services().account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(delete_pushrule::v3::Response {}) } diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs index a5553d25..9624dd8a 100644 --- a/src/api/client_server/read_marker.rs +++ b/src/api/client_server/read_marker.rs @@ -26,19 +26,23 @@ pub async fn set_read_marker_route( event_id: fully_read.clone(), }, }; - services().account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), - )?; + services() + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + ) + .await?; } if body.private_read_receipt.is_some() || body.read_receipt.is_some() { services() .rooms .user - .reset_notification_counts(sender_user, &body.room_id)?; + .reset_notification_counts(sender_user, &body.room_id) + .await?; } if let Some(event) = &body.private_read_receipt { @@ -63,7 +67,8 @@ pub async fn set_read_marker_route( .rooms .edus .read_receipt - .private_read_set(&body.room_id, sender_user, count)?; + .private_read_set(&body.room_id, sender_user, count) + .await?; } if let Some(event) = &body.read_receipt { @@ -82,14 +87,19 @@ pub async fn set_read_marker_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(event.to_owned(), receipts); - services().rooms.edus.read_receipt.readreceipt_update( - sender_user, - &body.room_id, - ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - )?; + services() + .rooms + .edus + .read_receipt + .readreceipt_update( + sender_user, + &body.room_id, + ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + ) + .await?; } Ok(set_read_marker::v3::Response {}) @@ -110,7 +120,8 @@ pub async fn create_receipt_route( services() .rooms .user - .reset_notification_counts(sender_user, &body.room_id)?; + .reset_notification_counts(sender_user, &body.room_id) + .await?; } match body.receipt_type { @@ -120,12 +131,15 @@ pub async fn create_receipt_route( event_id: body.event_id.clone(), }, }; - services().account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), - )?; + services() + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + ) + .await?; } create_receipt::v3::ReceiptType::Read => { let mut user_receipts = BTreeMap::new(); @@ -142,14 +156,19 @@ pub async fn create_receipt_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(body.event_id.to_owned(), receipts); - services().rooms.edus.read_receipt.readreceipt_update( - sender_user, - &body.room_id, - ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - )?; + services() + .rooms + .edus + .read_receipt + .readreceipt_update( + sender_user, + &body.room_id, + ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + ) + .await?; } create_receipt::v3::ReceiptType::ReadPrivate => { let count = services() @@ -169,11 +188,12 @@ pub async fn create_receipt_route( } PduCount::Normal(c) => c, }; - services().rooms.edus.read_receipt.private_read_set( - &body.room_id, - sender_user, - count, - )?; + services() + .rooms + .edus + .read_receipt + .private_read_set(&body.room_id, sender_user, count) + .await?; } _ => return Err(Error::bad_database("Unsupported receipt type")), } diff --git a/src/api/client_server/relations.rs b/src/api/client_server/relations.rs index 27c00729..6c4dd098 100644 --- a/src/api/client_server/relations.rs +++ b/src/api/client_server/relations.rs @@ -25,7 +25,8 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route( body.limit, body.recurse, &body.dir, - )?; + ) + .await?; Ok( get_relating_events_with_rel_type_and_event_type::v1::Response { @@ -57,7 +58,8 @@ pub async fn get_relating_events_with_rel_type_route( body.limit, body.recurse, &body.dir, - )?; + ) + .await?; Ok(get_relating_events_with_rel_type::v1::Response { chunk: res.chunk, @@ -88,4 +90,5 @@ pub async fn get_relating_events_route( body.recurse, &body.dir, ) + .await } diff --git a/src/api/client_server/room.rs b/src/api/client_server/room.rs index 890ff9cb..cd5ea8bb 100644 --- a/src/api/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -54,7 +54,11 @@ pub async fn create_room_route( let room_id = RoomId::new(services().globals.server_name()); - services().rooms.short.get_or_create_shortroomid(&room_id)?; + services() + .rooms + .short + .get_or_create_shortroomid(&room_id) + .await?; let mutex_state = Arc::clone( services() @@ -488,7 +492,8 @@ pub async fn create_room_route( services() .rooms .alias - .set_alias(&alias, &room_id, sender_user)?; + .set_alias(&alias, &room_id, sender_user) + .await?; } if body.visibility == room::Visibility::Public { @@ -600,7 +605,8 @@ pub async fn upgrade_room_route( services() .rooms .short - .get_or_create_shortroomid(&replacement_room)?; + .get_or_create_shortroomid(&replacement_room) + .await?; let mutex_state = Arc::clone( services() @@ -818,7 +824,8 @@ pub async fn upgrade_room_route( services() .rooms .alias - .set_alias(&alias, &replacement_room, sender_user)?; + .set_alias(&alias, &replacement_room, sender_user) + .await?; } // Get the old room power levels diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index 07078328..39fc7e6a 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -192,12 +192,15 @@ pub async fn login_route(body: Ruma) -> Result) -> Result { - services().users.add_to_device_event( - sender_user, - target_user_id, - target_device_id, - &body.event_type.to_string(), - event.deserialize_as().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") - })?, - )? - } - - DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services().users.all_device_ids(target_user_id) { - services().users.add_to_device_event( + services() + .users + .add_to_device_event( sender_user, target_user_id, - &target_device_id?, + target_device_id, &body.event_type.to_string(), event.deserialize_as().map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") })?, - )?; + ) + .await? + } + + DeviceIdOrAllDevices::AllDevices => { + for target_device_id in services().users.all_device_ids(target_user_id) { + services() + .users + .add_to_device_event( + sender_user, + target_user_id, + &target_device_id?, + &body.event_type.to_string(), + event.deserialize_as().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") + })?, + ) + .await?; } } } diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 605a4672..5038de2e 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -814,7 +814,8 @@ pub async fn send_transaction_message_route( .rooms .edus .read_receipt - .readreceipt_update(&user_id, &room_id, event)?; + .readreceipt_update(&user_id, &room_id, event) + .await?; } else { // TODO fetch missing events debug!("No known event ids in read receipt: {:?}", user_updates); @@ -853,7 +854,7 @@ pub async fn send_transaction_message_route( } Edu::DeviceListUpdate(DeviceListUpdateContent { user_id, .. }) => { if user_id.server_name() == sender_servername { - services().users.mark_device_key_update(&user_id)?; + services().users.mark_device_key_update(&user_id).await?; } } Edu::DirectToDevice(DirectDeviceContent { @@ -873,37 +874,43 @@ pub async fn send_transaction_message_route( for (target_device_id_maybe, event) in map { match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services().users.add_to_device_event( - &sender, - target_user_id, - target_device_id, - &ev_type.to_string(), - event.deserialize_as().map_err(|e| { - warn!("To-Device event is invalid: {event:?} {e}"); - Error::BadRequest( - ErrorKind::InvalidParam, - "Event is invalid", - ) - })?, - )? + services() + .users + .add_to_device_event( + &sender, + target_user_id, + target_device_id, + &ev_type.to_string(), + event.deserialize_as().map_err(|e| { + warn!("To-Device event is invalid: {event:?} {e}"); + Error::BadRequest( + ErrorKind::InvalidParam, + "Event is invalid", + ) + })?, + ) + .await? } DeviceIdOrAllDevices::AllDevices => { for target_device_id in services().users.all_device_ids(target_user_id) { - services().users.add_to_device_event( - &sender, - target_user_id, - &target_device_id?, - &ev_type.to_string(), - event.deserialize_as().map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Event is invalid", - ) - })?, - )?; + services() + .users + .add_to_device_event( + &sender, + target_user_id, + &target_device_id?, + &ev_type.to_string(), + event.deserialize_as().map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Event is invalid", + ) + })?, + ) + .await?; } } } @@ -923,13 +930,16 @@ pub async fn send_transaction_message_route( }) => { if user_id.server_name() == sender_servername { if let Some(master_key) = master_key { - services().users.add_cross_signing_keys( - &user_id, - &master_key, - &self_signing_key, - &None, - true, - )?; + services() + .users + .add_cross_signing_keys( + &user_id, + &master_key, + &self_signing_key, + &None, + true, + ) + .await?; } } } @@ -1438,18 +1448,22 @@ pub async fn create_join_event_template_route( }) .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services().rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - }, - &body.user_id, - &body.room_id, - &state_lock, - )?; + let (_pdu, mut pdu_json) = services() + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, + &body.user_id, + &body.room_id, + &state_lock, + ) + .await?; drop(state_lock); @@ -1581,7 +1595,7 @@ async fn create_join_event( .filter_map(|r| r.ok()) .filter(|server| &**server != services().globals.server_name()); - services().sending.send_pdu(servers, &pdu_id)?; + services().sending.send_pdu(servers, &pdu_id).await?; Ok(create_join_event::v1::RoomState { auth_chain: auth_chain_ids @@ -1738,14 +1752,18 @@ pub async fn create_invite_route( .state_cache .server_in_room(services().globals.server_name(), &body.room_id)? { - services().rooms.state_cache.update_membership( - &body.room_id, - &invited_user, - MembershipState::Invite, - &sender, - Some(invite_state), - true, - )?; + services() + .rooms + .state_cache + .update_membership( + &body.room_id, + &invited_user, + MembershipState::Invite, + &sender, + Some(invite_state), + true, + ) + .await?; } Ok(create_invite::v2::Response { diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index bc9f09dc..a4336928 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -42,6 +42,7 @@ pub trait KeyValueDatabaseEngine: Send + Sync { } } +#[async_trait] pub trait KvTree: Send + Sync { fn get(&self, key: &[u8]) -> Result>>; @@ -58,7 +59,7 @@ pub trait KvTree: Send + Sync { backwards: bool, ) -> Box, Vec)> + 'a>; - fn increment(&self, key: &[u8]) -> Result>; + async fn increment(&self, key: &[u8]) -> Result>; fn increment_batch(&self, iter: &mut dyn Iterator>) -> Result<()>; fn scan_prefix<'a>( @@ -68,7 +69,7 @@ pub trait KvTree: Send + Sync { fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>>; - fn clear(&self) -> Result<()> { + async fn clear(&self) -> Result<()> { for (key, _) in self.iter() { self.remove(&key)?; } diff --git a/src/database/abstraction/persy.rs b/src/database/abstraction/persy.rs index 5c146eb0..4eaaa8e8 100644 --- a/src/database/abstraction/persy.rs +++ b/src/database/abstraction/persy.rs @@ -63,6 +63,7 @@ impl PersyTree { } } +#[async_trait] impl KvTree for PersyTree { fn get(&self, key: &[u8]) -> Result>> { let result = self @@ -160,7 +161,7 @@ impl KvTree for PersyTree { } } - fn increment(&self, key: &[u8]) -> Result> { + async fn increment(&self, key: &[u8]) -> Result> { self.increment_batch(&mut Some(key.to_owned()).into_iter())?; Ok(self.get(key)?.unwrap()) } diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index 72af45ed..3013ebe8 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -136,6 +136,7 @@ impl RocksDbEngineTree<'_> { } } +#[async_trait] impl KvTree for RocksDbEngineTree<'_> { fn get(&self, key: &[u8]) -> Result>> { let readoptions = rocksdb::ReadOptions::default(); @@ -214,7 +215,7 @@ impl KvTree for RocksDbEngineTree<'_> { ) } - fn increment(&self, key: &[u8]) -> Result> { + async fn increment(&self, key: &[u8]) -> Result> { let readoptions = rocksdb::ReadOptions::default(); let writeoptions = rocksdb::WriteOptions::default(); diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 12ed9361..f0990f1b 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -166,6 +166,7 @@ impl SqliteTable { } } +#[async_trait] impl KvTree for SqliteTable { fn get(&self, key: &[u8]) -> Result>> { self.get_with_guard(self.engine.read_lock(), key) @@ -268,7 +269,7 @@ impl KvTree for SqliteTable { Box::new(rx.into_iter()) } - fn increment(&self, key: &[u8]) -> Result> { + async fn increment(&self, key: &[u8]) -> Result> { let guard = self.engine.write_lock(); let old = self.get_with_guard(&guard, key)?; @@ -295,7 +296,7 @@ impl KvTree for SqliteTable { self.watchers.watch(prefix) } - fn clear(&self) -> Result<()> { + async fn clear(&self) -> Result<()> { debug!("clear: running"); self.engine .write_lock() diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs index 970b36b5..53af9e86 100644 --- a/src/database/key_value/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use async_trait::async_trait; use ruma::{ api::client::error::ErrorKind, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, @@ -9,10 +10,11 @@ use ruma::{ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::account_data::Data for KeyValueDatabase { /// Places one event in the account data of the user and removes the previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - fn update( + async fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, @@ -29,7 +31,7 @@ impl service::account_data::Data for KeyValueDatabase { prefix.push(0xff); let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + roomuserdataid.extend_from_slice(&services().globals.next_count().await?.to_be_bytes()); roomuserdataid.push(0xff); roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index 968d6420..1c0c57f9 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -20,8 +20,8 @@ pub const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; #[async_trait] impl service::globals::Data for KeyValueDatabase { - fn next_count(&self) -> Result { - utils::u64_from_bytes(&self.global.increment(COUNTER)?) + async fn next_count(&self) -> Result { + utils::u64_from_bytes(&self.global.increment(COUNTER).await?) .map_err(|_| Error::bad_database("Count has invalid bytes.")) } diff --git a/src/database/key_value/key_backups.rs b/src/database/key_value/key_backups.rs index 900b700b..43ded671 100644 --- a/src/database/key_value/key_backups.rs +++ b/src/database/key_value/key_backups.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; +use async_trait::async_trait; use ruma::{ api::client::{ backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, @@ -11,13 +12,14 @@ use ruma::{ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::key_backups::Data for KeyValueDatabase { - fn create_backup( + async fn create_backup( &self, user_id: &UserId, backup_metadata: &Raw, ) -> Result { - let version = services().globals.next_count()?.to_string(); + let version = services().globals.next_count().await?.to_string(); let mut key = user_id.as_bytes().to_vec(); key.push(0xff); @@ -28,7 +30,7 @@ impl service::key_backups::Data for KeyValueDatabase { &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), )?; self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count().await?.to_be_bytes())?; Ok(version) } @@ -49,7 +51,7 @@ impl service::key_backups::Data for KeyValueDatabase { Ok(()) } - fn update_backup( + async fn update_backup( &self, user_id: &UserId, version: &str, @@ -69,7 +71,7 @@ impl service::key_backups::Data for KeyValueDatabase { self.backupid_algorithm .insert(&key, backup_metadata.json().get().as_bytes())?; self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count().await?.to_be_bytes())?; Ok(version.to_owned()) } @@ -138,7 +140,7 @@ impl service::key_backups::Data for KeyValueDatabase { }) } - fn add_key( + async fn add_key( &self, user_id: &UserId, version: &str, @@ -158,7 +160,7 @@ impl service::key_backups::Data for KeyValueDatabase { } self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count().await?.to_be_bytes())?; key.push(0xff); key.extend_from_slice(room_id.as_bytes()); diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index 1a27fbac..66f84441 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use ruma::{ api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId, @@ -5,8 +6,14 @@ use ruma::{ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::rooms::alias::Data for KeyValueDatabase { - fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { + async fn set_alias( + &self, + alias: &RoomAliasId, + room_id: &RoomId, + user_id: &UserId, + ) -> Result<()> { // Comes first as we don't want a stuck alias self.alias_userid .insert(alias.alias().as_bytes(), user_id.as_bytes())?; @@ -14,7 +21,7 @@ impl service::rooms::alias::Data for KeyValueDatabase { .insert(alias.alias().as_bytes(), room_id.as_bytes())?; let mut aliasid = room_id.as_bytes().to_vec(); aliasid.push(0xff); - aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + aliasid.extend_from_slice(&services().globals.next_count().await?.to_be_bytes()); self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; Ok(()) } diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs index 904b1c44..ba29be8f 100644 --- a/src/database/key_value/rooms/edus/presence.rs +++ b/src/database/key_value/rooms/edus/presence.rs @@ -1,13 +1,15 @@ use std::collections::HashMap; +use async_trait::async_trait; use ruma::{ events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId, }; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::rooms::edus::presence::Data for KeyValueDatabase { - fn update_presence( + async fn update_presence( &self, user_id: &UserId, room_id: &RoomId, @@ -15,7 +17,7 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase { ) -> Result<()> { // TODO: Remove old entry? Or maybe just wipe completely from time to time? - let count = services().globals.next_count()?.to_be_bytes(); + let count = services().globals.next_count().await?.to_be_bytes(); let mut presence_id = room_id.as_bytes().to_vec(); presence_id.push(0xff); diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index fb7c9e99..f3f0fa27 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -1,13 +1,15 @@ use std::mem; +use async_trait::async_trait; use ruma::{ events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId, }; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { - fn readreceipt_update( + async fn readreceipt_update( &self, user_id: &UserId, room_id: &RoomId, @@ -36,7 +38,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { } let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + room_latest_id.extend_from_slice(&services().globals.next_count().await?.to_be_bytes()); room_latest_id.push(0xff); room_latest_id.extend_from_slice(user_id.as_bytes()); @@ -106,7 +108,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { ) } - fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { + async fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { let mut key = room_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(user_id.as_bytes()); @@ -115,7 +117,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { .insert(&key, &count.to_be_bytes())?; self.roomuserid_lastprivatereadupdate - .insert(&key, &services().globals.next_count()?.to_be_bytes()) + .insert(&key, &services().globals.next_count().await?.to_be_bytes()) } fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs index 98cfa48a..8d79a339 100644 --- a/src/database/key_value/rooms/short.rs +++ b/src/database/key_value/rooms/short.rs @@ -1,11 +1,13 @@ use std::sync::Arc; +use async_trait::async_trait; use ruma::{events::StateEventType, EventId, RoomId}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::rooms::short::Data for KeyValueDatabase { - fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { + async fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { return Ok(*short); } @@ -14,7 +16,7 @@ impl service::rooms::short::Data for KeyValueDatabase { Some(shorteventid) => utils::u64_from_bytes(&shorteventid) .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, None => { - let shorteventid = services().globals.next_count()?; + let shorteventid = services().globals.next_count().await?; self.eventid_shorteventid .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; self.shorteventid_eventid @@ -68,7 +70,7 @@ impl service::rooms::short::Data for KeyValueDatabase { Ok(short) } - fn get_or_create_shortstatekey( + async fn get_or_create_shortstatekey( &self, event_type: &StateEventType, state_key: &str, @@ -90,7 +92,7 @@ impl service::rooms::short::Data for KeyValueDatabase { Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, None => { - let shortstatekey = services().globals.next_count()?; + let shortstatekey = services().globals.next_count().await?; self.statekey_shortstatekey .insert(&statekey, &shortstatekey.to_be_bytes())?; self.shortstatekey_statekey @@ -176,7 +178,7 @@ impl service::rooms::short::Data for KeyValueDatabase { } /// Returns (shortstatehash, already_existed) - fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { + async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { Ok(match self.statehash_shortstatehash.get(state_hash)? { Some(shortstatehash) => ( utils::u64_from_bytes(&shortstatehash) @@ -184,7 +186,7 @@ impl service::rooms::short::Data for KeyValueDatabase { true, ), None => { - let shortstatehash = services().globals.next_count()?; + let shortstatehash = services().globals.next_count().await?; self.statehash_shortstatehash .insert(state_hash, &shortstatehash.to_be_bytes())?; (shortstatehash, false) @@ -202,12 +204,12 @@ impl service::rooms::short::Data for KeyValueDatabase { .transpose() } - fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { + async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { Some(short) => utils::u64_from_bytes(&short) .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, None => { - let short = services().globals.next_count()?; + let short = services().globals.next_count().await?; self.roomid_shortroomid .insert(room_id.as_bytes(), &short.to_be_bytes())?; short diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 29be2608..7cb281f6 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -1,5 +1,6 @@ use std::{collections::HashSet, sync::Arc}; +use async_trait::async_trait; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, @@ -12,6 +13,7 @@ use crate::{ services, utils, Error, Result, }; +#[async_trait] impl service::rooms::state_cache::Data for KeyValueDatabase { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); @@ -39,7 +41,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { Ok(()) } - fn mark_as_invited( + async fn mark_as_invited( &self, user_id: &UserId, room_id: &RoomId, @@ -60,7 +62,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { )?; self.roomuserid_invitecount.insert( &roomuser_id, - &services().globals.next_count()?.to_be_bytes(), + &services().globals.next_count().await?.to_be_bytes(), )?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; @@ -70,7 +72,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { Ok(()) } - fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + async fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut roomuser_id = room_id.as_bytes().to_vec(); roomuser_id.push(0xff); roomuser_id.extend_from_slice(user_id.as_bytes()); @@ -85,7 +87,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { )?; // TODO self.roomuserid_leftcount.insert( &roomuser_id, - &services().globals.next_count()?.to_be_bytes(), + &services().globals.next_count().await?.to_be_bytes(), )?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 3d2d4a8f..4ef93a43 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -1,9 +1,11 @@ +use async_trait::async_trait; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::rooms::user::Data for KeyValueDatabase { - fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + async fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -18,7 +20,7 @@ impl service::rooms::user::Data for KeyValueDatabase { self.roomuserid_lastnotificationread.insert( &roomuser_id, - &services().globals.next_count()?.to_be_bytes(), + &services().globals.next_count().await?.to_be_bytes(), )?; Ok(()) diff --git a/src/database/key_value/sending.rs b/src/database/key_value/sending.rs index 58380a05..8a0fa2f4 100644 --- a/src/database/key_value/sending.rs +++ b/src/database/key_value/sending.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use ruma::{ServerName, UserId}; use crate::{ @@ -9,6 +10,7 @@ use crate::{ services, utils, Error, Result, }; +#[async_trait] impl service::sending::Data for KeyValueDatabase { fn active_requests<'a>( &'a self, @@ -59,7 +61,7 @@ impl service::sending::Data for KeyValueDatabase { Ok(()) } - fn queue_requests( + async fn queue_requests( &self, requests: &[(&OutgoingKind, SendingEventType)], ) -> Result>> { @@ -70,7 +72,7 @@ impl service::sending::Data for KeyValueDatabase { if let SendingEventType::Pdu(value) = &event { key.extend_from_slice(value) } else { - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()) + key.extend_from_slice(&services().globals.next_count().await?.to_be_bytes()) } let value = if let SendingEventType::Edu(value) = &event { &**value diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 7fa24a12..6e872d65 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -1,5 +1,6 @@ use std::{collections::BTreeMap, mem::size_of}; +use async_trait::async_trait; use ruma::{ api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, @@ -17,6 +18,7 @@ use crate::{ services, utils, Error, Result, }; +#[async_trait] impl service::users::Data for KeyValueDatabase { /// Check if a user has an account on this homeserver. fn exists(&self, user_id: &UserId) -> Result { @@ -192,7 +194,7 @@ impl service::users::Data for KeyValueDatabase { } /// Adds a new device to a user. - fn create_device( + async fn create_device( &self, user_id: &UserId, device_id: &DeviceId, @@ -207,7 +209,8 @@ impl service::users::Data for KeyValueDatabase { userdeviceid.extend_from_slice(device_id.as_bytes()); self.userid_devicelistversion - .increment(user_id.as_bytes())?; + .increment(user_id.as_bytes()) + .await?; self.userdeviceid_metadata.insert( &userdeviceid, @@ -226,7 +229,7 @@ impl service::users::Data for KeyValueDatabase { } /// Removes a device from a user. - fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xff); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -248,7 +251,8 @@ impl service::users::Data for KeyValueDatabase { // TODO: Remove onetimekeys self.userid_devicelistversion - .increment(user_id.as_bytes())?; + .increment(user_id.as_bytes()) + .await?; self.userdeviceid_metadata.remove(&userdeviceid)?; @@ -304,7 +308,7 @@ impl service::users::Data for KeyValueDatabase { Ok(()) } - fn add_one_time_key( + async fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, @@ -335,7 +339,7 @@ impl service::users::Data for KeyValueDatabase { self.userid_lastonetimekeyupdate.insert( user_id.as_bytes(), - &services().globals.next_count()?.to_be_bytes(), + &services().globals.next_count().await?.to_be_bytes(), )?; Ok(()) @@ -352,7 +356,7 @@ impl service::users::Data for KeyValueDatabase { .unwrap_or(Ok(0)) } - fn take_one_time_key( + async fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, @@ -368,7 +372,7 @@ impl service::users::Data for KeyValueDatabase { self.userid_lastonetimekeyupdate.insert( user_id.as_bytes(), - &services().globals.next_count()?.to_be_bytes(), + &services().globals.next_count().await?.to_be_bytes(), )?; self.onetimekeyid_onetimekeys @@ -423,7 +427,7 @@ impl service::users::Data for KeyValueDatabase { Ok(counts) } - fn add_device_keys( + async fn add_device_keys( &self, user_id: &UserId, device_id: &DeviceId, @@ -438,12 +442,12 @@ impl service::users::Data for KeyValueDatabase { &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), )?; - self.mark_device_key_update(user_id)?; + self.mark_device_key_update(user_id).await?; Ok(()) } - fn add_cross_signing_keys( + async fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, @@ -532,13 +536,13 @@ impl service::users::Data for KeyValueDatabase { } if notify { - self.mark_device_key_update(user_id)?; + self.mark_device_key_update(user_id).await?; } Ok(()) } - fn sign_key( + async fn sign_key( &self, target_id: &UserId, key_id: &str, @@ -574,7 +578,7 @@ impl service::users::Data for KeyValueDatabase { &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), )?; - self.mark_device_key_update(target_id)?; + self.mark_device_key_update(target_id).await?; Ok(()) } @@ -623,8 +627,8 @@ impl service::users::Data for KeyValueDatabase { ) } - fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - let count = services().globals.next_count()?.to_be_bytes(); + async fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { + let count = services().globals.next_count().await?.to_be_bytes(); for room_id in services() .rooms .state_cache @@ -761,7 +765,7 @@ impl service::users::Data for KeyValueDatabase { }) } - fn add_to_device_event( + async fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, @@ -773,7 +777,7 @@ impl service::users::Data for KeyValueDatabase { key.push(0xff); key.extend_from_slice(target_device_id.as_bytes()); key.push(0xff); - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + key.extend_from_slice(&services().globals.next_count().await?.to_be_bytes()); let mut json = serde_json::Map::new(); json.insert("type".to_owned(), event_type.to_owned().into()); @@ -843,7 +847,7 @@ impl service::users::Data for KeyValueDatabase { Ok(()) } - fn update_device_metadata( + async fn update_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, @@ -857,7 +861,8 @@ impl service::users::Data for KeyValueDatabase { assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); self.userid_devicelistversion - .increment(user_id.as_bytes())?; + .increment(user_id.as_bytes()) + .await?; self.userdeviceid_metadata.insert( &userdeviceid, diff --git a/src/database/mod.rs b/src/database/mod.rs index 1d549f19..a50d4920 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -680,7 +680,7 @@ impl KeyValueDatabase { if services().globals.database_version()? < 8 { // Generate short room ids for all rooms for (room_id, _) in db.roomid_shortstatehash.iter() { - let shortroomid = services().globals.next_count()?.to_be_bytes(); + let shortroomid = services().globals.next_count().await?.to_be_bytes(); db.roomid_shortroomid.insert(&room_id, &shortroomid)?; info!("Migration: 8"); } @@ -799,7 +799,7 @@ impl KeyValueDatabase { // Force E2EE device list updates so we can send them over federation for user_id in services().users.iter().filter_map(|r| r.ok()) { - services().users.mark_device_key_update(&user_id)?; + services().users.mark_device_key_update(&user_id).await?; } services().globals.bump_database_version(10)?; @@ -811,7 +811,8 @@ impl KeyValueDatabase { db._db .open_tree("userdevicesessionid_uiaarequest") .await? - .clear()?; + .clear() + .await?; services().globals.bump_database_version(11)?; warn!("Migration: 10 -> 11 finished"); @@ -884,12 +885,16 @@ impl KeyValueDatabase { } } - services().account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data) + .expect("to json value always works"), + ) + .await?; } services().globals.bump_database_version(12)?; @@ -930,12 +935,16 @@ impl KeyValueDatabase { .global .update_with_server_default(user_default_rules); - services().account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data) + .expect("to json value always works"), + ) + .await?; } services().globals.bump_database_version(13)?; @@ -969,12 +978,12 @@ impl KeyValueDatabase { } // This data is probably outdated - db.presenceid_presence.clear()?; + db.presenceid_presence.clear().await?; services().admin.start_handler(); // Set emergency access for the conduit user - match set_emergency_access() { + match set_emergency_access().await { Ok(pwd_set) => { if pwd_set { warn!("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!"); @@ -1107,7 +1116,7 @@ impl KeyValueDatabase { } /// Sets the emergency password and push rules for the @conduit account in case emergency password is set -fn set_emergency_access() -> Result { +async fn set_emergency_access() -> Result { let conduit_user = services().globals.server_user(); services().users.set_password( @@ -1120,15 +1129,18 @@ fn set_emergency_access() -> Result { None => (Ruleset::new(), Ok(false)), }; - services().account_data.update( - None, - conduit_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(&GlobalAccountDataEvent { - content: PushRulesEventContent { global: ruleset }, - }) - .expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + conduit_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(&GlobalAccountDataEvent { + content: PushRulesEventContent { global: ruleset }, + }) + .expect("to json value always works"), + ) + .await?; res } diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs index c7c92981..1b08a820 100644 --- a/src/service/account_data/data.rs +++ b/src/service/account_data/data.rs @@ -1,15 +1,17 @@ use std::collections::HashMap; use crate::Result; +use async_trait::async_trait; use ruma::{ events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, serde::Raw, RoomId, UserId, }; +#[async_trait] pub trait Data: Send + Sync { /// Places one event in the account data of the user and removes the previous entry. - fn update( + async fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index f9c49b1a..ee2f7498 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -19,14 +19,14 @@ pub struct Service { impl Service { /// Places one event in the account data of the user and removes the previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - pub fn update( + pub async fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value, ) -> Result<()> { - self.db.update(room_id, user_id, event_type, data) + self.db.update(room_id, user_id, event_type, data).await } /// Searches the account data for a specific kind. diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 718219c4..9b4b3cfd 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -639,19 +639,22 @@ impl Service { .set_displayname(&user_id, Some(displayname))?; // Initial account data - services().account_data.update( - None, - &user_id, - ruma::events::GlobalAccountDataEventType::PushRules - .to_string() - .into(), - &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: ruma::push::Ruleset::server_default(&user_id), - }, - }) - .expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + &user_id, + ruma::events::GlobalAccountDataEventType::PushRules + .to_string() + .into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: ruma::push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json value always works"), + ) + .await?; // we dont add a device since we're not the user, just the creator @@ -704,7 +707,7 @@ impl Service { "Making {user_id} leave all rooms before deactivation..." )); - services().users.deactivate_account(&user_id)?; + services().users.deactivate_account(&user_id).await?; if leave_rooms { leave_all_rooms(&user_id).await?; @@ -800,7 +803,7 @@ impl Service { } for &user_id in &user_ids { - if services().users.deactivate_account(user_id).is_ok() { + if services().users.deactivate_account(user_id).await.is_ok() { deactivation_count += 1 } } @@ -1057,7 +1060,11 @@ impl Service { pub(crate) async fn create_admin_room(&self) -> Result<()> { let room_id = RoomId::new(services().globals.server_name()); - services().rooms.short.get_or_create_shortroomid(&room_id)?; + services() + .rooms + .short + .get_or_create_shortroomid(&room_id) + .await?; let mutex_state = Arc::clone( services() @@ -1293,7 +1300,8 @@ impl Service { services() .rooms .alias - .set_alias(&alias, &room_id, conduit_user)?; + .set_alias(&alias, &room_id, conduit_user) + .await?; Ok(()) } diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 5fd84539..eb0b53fb 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -69,7 +69,7 @@ impl From for SigningKeys { #[async_trait] pub trait Data: Send + Sync { - fn next_count(&self) -> Result; + async fn next_count(&self) -> Result; fn current_count(&self) -> Result; fn last_check_for_updates_id(&self) -> Result; fn update_check_for_updates_id(&self, id: u64) -> Result<()>; diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index c0c0bd44..6592b7e4 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -258,8 +258,8 @@ impl Service { } #[tracing::instrument(skip(self))] - pub fn next_count(&self) -> Result { - self.db.next_count() + pub async fn next_count(&self) -> Result { + self.db.next_count().await } #[tracing::instrument(skip(self))] diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs index bf640015..6624ffb2 100644 --- a/src/service/key_backups/data.rs +++ b/src/service/key_backups/data.rs @@ -1,14 +1,16 @@ use std::collections::BTreeMap; use crate::Result; +use async_trait::async_trait; use ruma::{ api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, serde::Raw, OwnedRoomId, RoomId, UserId, }; +#[async_trait] pub trait Data: Send + Sync { - fn create_backup( + async fn create_backup( &self, user_id: &UserId, backup_metadata: &Raw, @@ -16,7 +18,7 @@ pub trait Data: Send + Sync { fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>; - fn update_backup( + async fn update_backup( &self, user_id: &UserId, version: &str, @@ -30,7 +32,7 @@ pub trait Data: Send + Sync { fn get_backup(&self, user_id: &UserId, version: &str) -> Result>>; - fn add_key( + async fn add_key( &self, user_id: &UserId, version: &str, diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 5fc52ced..1bda7d86 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -14,25 +14,27 @@ pub struct Service { } impl Service { - pub fn create_backup( + pub async fn create_backup( &self, user_id: &UserId, backup_metadata: &Raw, ) -> Result { - self.db.create_backup(user_id, backup_metadata) + self.db.create_backup(user_id, backup_metadata).await } pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { self.db.delete_backup(user_id, version) } - pub fn update_backup( + pub async fn update_backup( &self, user_id: &UserId, version: &str, backup_metadata: &Raw, ) -> Result { - self.db.update_backup(user_id, version, backup_metadata) + self.db + .update_backup(user_id, version, backup_metadata) + .await } pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { @@ -54,7 +56,7 @@ impl Service { self.db.get_backup(user_id, version) } - pub fn add_key( + pub async fn add_key( &self, user_id: &UserId, version: &str, @@ -64,6 +66,7 @@ impl Service { ) -> Result<()> { self.db .add_key(user_id, version, room_id, session_id, key_data) + .await } pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result { diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index c73799e4..30979d2d 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -1,9 +1,16 @@ use crate::Result; +use async_trait::async_trait; use ruma::{OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId}; +#[async_trait] pub trait Data: Send + Sync { /// Creates or updates the alias to the given room id. - fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()>; + async fn set_alias( + &self, + alias: &RoomAliasId, + room_id: &RoomId, + user_id: &UserId, + ) -> Result<()>; /// Finds the user who assigned the given alias to a room fn who_created_alias(&self, alias: &RoomAliasId) -> Result>; diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index bd5693f1..af50ec48 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -19,7 +19,12 @@ pub struct Service { impl Service { #[tracing::instrument(skip(self))] - pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { + pub async fn set_alias( + &self, + alias: &RoomAliasId, + room_id: &RoomId, + user_id: &UserId, + ) -> Result<()> { if alias == services().globals.admin_alias() && user_id != services().globals.server_user() { Err(Error::BadRequest( @@ -27,7 +32,7 @@ impl Service { "Only the server user can set this alias", )) } else { - self.db.set_alias(alias, room_id, user_id) + self.db.set_alias(alias, room_id, user_id).await } } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 1a8a3ad7..e947df29 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -36,7 +36,11 @@ impl Service { let mut i = 0; for id in starting_events { - let short = services().rooms.short.get_or_create_shorteventid(&id)?; + let short = services() + .rooms + .short + .get_or_create_shorteventid(&id) + .await?; let bucket_id = (short % NUM_BUCKETS as u64) as usize; buckets[bucket_id].insert((short, id.clone())); i += 1; @@ -80,7 +84,7 @@ impl Service { chunk_cache.extend(cached.iter().copied()); } else { misses2 += 1; - let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?); + let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id).await?); services() .rooms .auth_chain @@ -125,7 +129,11 @@ impl Service { } #[tracing::instrument(skip(self, event_id))] - fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { + async fn get_auth_chain_inner( + &self, + room_id: &RoomId, + event_id: &EventId, + ) -> Result> { let mut todo = vec![Arc::from(event_id)]; let mut found = HashSet::new(); @@ -142,7 +150,8 @@ impl Service { let sauthevent = services() .rooms .short - .get_or_create_shorteventid(auth_event)?; + .get_or_create_shorteventid(auth_event) + .await?; if !found.contains(&sauthevent) { found.insert(sauthevent); diff --git a/src/service/rooms/edus/presence/data.rs b/src/service/rooms/edus/presence/data.rs index 53329e08..a9e54558 100644 --- a/src/service/rooms/edus/presence/data.rs +++ b/src/service/rooms/edus/presence/data.rs @@ -1,14 +1,16 @@ use std::collections::HashMap; use crate::Result; +use async_trait::async_trait; use ruma::{events::presence::PresenceEvent, OwnedUserId, RoomId, UserId}; +#[async_trait] pub trait Data: Send + Sync { /// Adds a presence event which will be saved until a new event replaces it. /// /// Note: This method takes a RoomId because presence updates are always bound to rooms to /// make sure users outside these rooms can't see them. - fn update_presence( + async fn update_presence( &self, user_id: &UserId, room_id: &RoomId, diff --git a/src/service/rooms/edus/read_receipt/data.rs b/src/service/rooms/edus/read_receipt/data.rs index 41a33eed..d0761ad7 100644 --- a/src/service/rooms/edus/read_receipt/data.rs +++ b/src/service/rooms/edus/read_receipt/data.rs @@ -1,9 +1,11 @@ use crate::Result; +use async_trait::async_trait; use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; +#[async_trait] pub trait Data: Send + Sync { /// Replaces the previous read receipt. - fn readreceipt_update( + async fn readreceipt_update( &self, user_id: &UserId, room_id: &RoomId, @@ -28,7 +30,7 @@ pub trait Data: Send + Sync { >; /// Sets a private read marker at `count`. - fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>; + async fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>; /// Returns the private read marker. fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result>; diff --git a/src/service/rooms/edus/read_receipt/mod.rs b/src/service/rooms/edus/read_receipt/mod.rs index c6035280..89f07fa7 100644 --- a/src/service/rooms/edus/read_receipt/mod.rs +++ b/src/service/rooms/edus/read_receipt/mod.rs @@ -11,13 +11,13 @@ pub struct Service { impl Service { /// Replaces the previous read receipt. - pub fn readreceipt_update( + pub async fn readreceipt_update( &self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent, ) -> Result<()> { - self.db.readreceipt_update(user_id, room_id, event) + self.db.readreceipt_update(user_id, room_id, event).await } /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. @@ -38,8 +38,13 @@ impl Service { /// Sets a private read marker at `count`. #[tracing::instrument(skip(self))] - pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { - self.db.private_read_set(room_id, user_id, count) + pub async fn private_read_set( + &self, + room_id: &RoomId, + user_id: &UserId, + count: u64, + ) -> Result<()> { + self.db.private_read_set(room_id, user_id, count).await } /// Returns the private read marker. diff --git a/src/service/rooms/edus/typing/mod.rs b/src/service/rooms/edus/typing/mod.rs index 7546aa84..f86162d7 100644 --- a/src/service/rooms/edus/typing/mod.rs +++ b/src/service/rooms/edus/typing/mod.rs @@ -23,7 +23,7 @@ impl Service { self.last_typing_update .write() .await - .insert(room_id.to_owned(), services().globals.next_count()?); + .insert(room_id.to_owned(), services().globals.next_count().await?); let _ = self.typing_update_sender.send(room_id.to_owned()); Ok(()) } @@ -39,7 +39,7 @@ impl Service { self.last_typing_update .write() .await - .insert(room_id.to_owned(), services().globals.next_count()?); + .insert(room_id.to_owned(), services().globals.next_count().await?); let _ = self.typing_update_sender.send(room_id.to_owned()); Ok(()) } @@ -80,7 +80,7 @@ impl Service { self.last_typing_update .write() .await - .insert(room_id.to_owned(), services().globals.next_count()?); + .insert(room_id.to_owned(), services().globals.next_count().await?); let _ = self.typing_update_sender.send(room_id.to_owned()); } Ok(()) diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 0bdfd4ae..21cfdd48 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -589,10 +589,11 @@ impl Service { })?; if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &prev_pdu.kind.to_string().into(), - state_key, - )?; + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key) + .await?; state.insert(shortstatekey, Arc::from(prev_event)); // Now it's the state after the pdu @@ -640,10 +641,14 @@ impl Service { .await?; if let Some(state_key) = &prev_event.state_key { - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &prev_event.kind.to_string().into(), - state_key, - )?; + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey( + &prev_event.kind.to_string().into(), + state_key, + ) + .await?; leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); // Now it's the state after the pdu } @@ -677,34 +682,38 @@ impl Service { let lock = services().globals.stateres_mutex.lock(); - let result = + let new_state = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { let res = services().rooms.timeline.get_pdu(id); if let Err(e) = &res { error!("LOOK AT ME Failed to fetch event: {}", e); } res.ok().flatten() - }); + }) + .map_err(|e| { + warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e); + e + }) + .ok(); drop(lock); - state_at_incoming_event = match result { - Ok(new_state) => Some( - new_state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = - services().rooms.short.get_or_create_shortstatekey( - &event_type.to_string().into(), - &state_key, - )?; - Ok((shortstatekey, event_id)) - }) - .collect::>()?, - ), - Err(e) => { - warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e); - None + state_at_incoming_event = match new_state { + Some(new_state) => { + let mut state_at_incoming_event = HashMap::with_capacity(new_state.len()); + for ((event_type, state_key), event_id) in new_state { + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey( + &event_type.to_string().into(), + &state_key, + ) + .await?; + state_at_incoming_event.insert(shortstatekey, event_id); + } + Some(state_at_incoming_event) } + None => None, } } } @@ -748,10 +757,11 @@ impl Service { Error::bad_database("Found non-state pdu in state events.") })?; - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &pdu.kind.to_string().into(), - &state_key, - )?; + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key) + .await?; match state.entry(shortstatekey) { hash_map::Entry::Vacant(v) => { @@ -915,17 +925,17 @@ impl Service { }); debug!("Compressing state at event"); - let state_ids_compressed = Arc::new( - state_at_incoming_event - .iter() - .map(|(shortstatekey, id)| { - services() - .rooms - .state_compressor - .compress_state_event(*shortstatekey, id) - }) - .collect::>()?, - ); + let mut state_ids_compressed = HashSet::new(); + for (shortstatekey, id) in &state_at_incoming_event { + state_ids_compressed.insert( + services() + .rooms + .state_compressor + .compress_state_event(*shortstatekey, id) + .await?, + ); + } + let state_ids_compressed = Arc::new(state_ids_compressed); if incoming_pdu.state_key.is_some() { debug!("Preparing for stateres to derive new room state"); @@ -933,10 +943,11 @@ impl Service { // We also add state after incoming event to the fork states let mut state_after = state_at_incoming_event.clone(); if let Some(state_key) = &incoming_pdu.state_key { - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &incoming_pdu.kind.to_string().into(), - state_key, - )?; + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key) + .await?; state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); } @@ -951,7 +962,8 @@ impl Service { let (sstatehash, new, removed) = services() .rooms .state_compressor - .save_state(room_id, new_room_state)?; + .save_state(room_id, new_room_state) + .await?; services() .rooms @@ -1078,35 +1090,32 @@ impl Service { }; let lock = services().globals.stateres_mutex.lock(); - let state = match state_res::resolve( + let state = state_res::resolve( room_version_id, &fork_states, auth_chain_sets, fetch_event, - ) { - Ok(new_state) => new_state, - Err(_) => { - return Err(Error::bad_database("State resolution failed, either an event could not be found or deserialization")); - } - }; + ).map_err(|_| Error::bad_database("State resolution failed, either an event could not be found or deserialization"))?; drop(lock); debug!("State resolution done. Compressing state"); - let new_room_state = state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; + let mut new_room_state = HashSet::new(); + for ((event_type, state_key), event_id) in state { + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key) + .await?; + new_room_state.insert( services() .rooms .state_compressor .compress_state_event(shortstatekey, &event_id) - }) - .collect::>()?; + .await?, + ); + } Ok(Arc::new(new_room_state)) } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index 5ffe8846..df55df33 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -41,7 +41,7 @@ impl Service { } #[allow(clippy::too_many_arguments)] - pub fn paginate_relations_with_filter( + pub async fn paginate_relations_with_filter( &self, sender_user: &UserId, room_id: &RoomId, @@ -77,13 +77,11 @@ impl Service { match dir { Direction::Forward => { - let relations_until = &services().rooms.pdu_metadata.relations_until( - sender_user, - room_id, - target, - from, - depth, - )?; + let relations_until = &services() + .rooms + .pdu_metadata + .relations_until(sender_user, room_id, target, from, depth) + .await?; let events_after: Vec<_> = relations_until // TODO: should be relations_after .iter() .filter(|(_, pdu)| { @@ -125,13 +123,11 @@ impl Service { }) } Direction::Backward => { - let relations_until = &services().rooms.pdu_metadata.relations_until( - sender_user, - room_id, - target, - from, - depth, - )?; + let relations_until = &services() + .rooms + .pdu_metadata + .relations_until(sender_user, room_id, target, from, depth) + .await?; let events_before: Vec<_> = relations_until .iter() .filter(|(_, pdu)| { @@ -174,7 +170,7 @@ impl Service { } } - pub fn relations_until<'a>( + pub async fn relations_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, @@ -182,7 +178,11 @@ impl Service { until: PduCount, max_depth: u8, ) -> Result> { - let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; + let room_id = services() + .rooms + .short + .get_or_create_shortroomid(room_id) + .await?; let target = match services().rooms.timeline.get_pdu_count(target)? { Some(PduCount::Normal(c)) => c, // TODO: Support backfilled relations diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs index 652c525b..b22cdc6b 100644 --- a/src/service/rooms/short/data.rs +++ b/src/service/rooms/short/data.rs @@ -1,10 +1,12 @@ use std::sync::Arc; use crate::Result; +use async_trait::async_trait; use ruma::{events::StateEventType, EventId, RoomId}; +#[async_trait] pub trait Data: Send + Sync { - fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result; + async fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result; fn get_shortstatekey( &self, @@ -12,7 +14,7 @@ pub trait Data: Send + Sync { state_key: &str, ) -> Result>; - fn get_or_create_shortstatekey( + async fn get_or_create_shortstatekey( &self, event_type: &StateEventType, state_key: &str, @@ -23,9 +25,9 @@ pub trait Data: Send + Sync { fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>; /// Returns (shortstatehash, already_existed) - fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>; + async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>; fn get_shortroomid(&self, room_id: &RoomId) -> Result>; - fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result; + async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result; } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 45fadd74..6ec5a030 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -11,8 +11,8 @@ pub struct Service { } impl Service { - pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { - self.db.get_or_create_shorteventid(event_id) + pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { + self.db.get_or_create_shorteventid(event_id).await } pub fn get_shortstatekey( @@ -23,12 +23,14 @@ impl Service { self.db.get_shortstatekey(event_type, state_key) } - pub fn get_or_create_shortstatekey( + pub async fn get_or_create_shortstatekey( &self, event_type: &StateEventType, state_key: &str, ) -> Result { - self.db.get_or_create_shortstatekey(event_type, state_key) + self.db + .get_or_create_shortstatekey(event_type, state_key) + .await } pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { @@ -40,15 +42,15 @@ impl Service { } /// Returns (shortstatehash, already_existed) - pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - self.db.get_or_create_shortstatehash(state_hash) + pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { + self.db.get_or_create_shortstatehash(state_hash).await } pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { self.db.get_shortroomid(room_id) } - pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { - self.db.get_or_create_shortroomid(room_id) + pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { + self.db.get_or_create_shortroomid(room_id).await } } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index f6581bb5..e76d88f9 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -80,14 +80,11 @@ impl Service { Err(_) => continue, }; - services().rooms.state_cache.update_membership( - room_id, - &user_id, - membership, - &pdu.sender, - None, - false, - )?; + services() + .rooms + .state_cache + .update_membership(room_id, &user_id, membership, &pdu.sender, None, false) + .await?; } TimelineEventType::SpaceChild => { services() @@ -115,7 +112,7 @@ impl Service { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, state_ids_compressed))] - pub fn set_event_state( + pub async fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, @@ -124,7 +121,8 @@ impl Service { let shorteventid = services() .rooms .short - .get_or_create_shorteventid(event_id)?; + .get_or_create_shorteventid(event_id) + .await?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; @@ -138,7 +136,8 @@ impl Service { let (shortstatehash, already_existed) = services() .rooms .short - .get_or_create_shortstatehash(&state_hash)?; + .get_or_create_shortstatehash(&state_hash) + .await?; if !already_existed { let states_parents = previous_shortstatehash.map_or_else( @@ -187,11 +186,12 @@ impl Service { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, new_pdu))] - pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result { + pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result { let shorteventid = services() .rooms .short - .get_or_create_shorteventid(&new_pdu.event_id)?; + .get_or_create_shorteventid(&new_pdu.event_id) + .await?; let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; @@ -213,12 +213,14 @@ impl Service { let shortstatekey = services() .rooms .short - .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key) + .await?; let new = services() .rooms .state_compressor - .compress_state_event(shortstatekey, &new_pdu.event_id)?; + .compress_state_event(shortstatekey, &new_pdu.event_id) + .await?; let replaces = states_parents .last() @@ -234,7 +236,7 @@ impl Service { } // TODO: statehash with deterministic inputs - let shortstatehash = services().globals.next_count()?; + let shortstatehash = services().globals.next_count().await?; let mut statediffnew = HashSet::new(); statediffnew.insert(new); diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 53e3176f..2a678830 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -327,6 +327,7 @@ impl Service { .rooms .timeline .create_hash_and_sign_event(new_event, sender, room_id, state_lock) + .await .is_ok()) } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 76dcc6cc..c1134992 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,22 +1,24 @@ use std::{collections::HashSet, sync::Arc}; use crate::{service::appservice::RegistrationInfo, Result}; +use async_trait::async_trait; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; +#[async_trait] pub trait Data: Send + Sync { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; - fn mark_as_invited( + async fn mark_as_invited( &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, ) -> Result<()>; - fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + async fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index c108695d..500dc9eb 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -25,7 +25,7 @@ pub struct Service { impl Service { /// Update current membership data. #[tracing::instrument(skip(self, last_state))] - pub fn update_membership( + pub async fn update_membership( &self, room_id: &RoomId, user_id: &UserId, @@ -103,6 +103,7 @@ impl Service { RoomAccountDataEventType::Tag, &tag_event?, ) + .await .ok(); }; @@ -132,13 +133,16 @@ impl Service { } if room_ids_updated { - services().account_data.update( - None, - user_id, - GlobalAccountDataEventType::Direct.to_string().into(), - &serde_json::to_value(&direct_event) - .expect("to json always works"), - )?; + services() + .account_data + .update( + None, + user_id, + GlobalAccountDataEventType::Direct.to_string().into(), + &serde_json::to_value(&direct_event) + .expect("to json always works"), + ) + .await?; } }; } @@ -176,10 +180,12 @@ impl Service { return Ok(()); } - self.db.mark_as_invited(user_id, room_id, last_state)?; + self.db + .mark_as_invited(user_id, room_id, last_state) + .await?; } MembershipState::Leave | MembershipState::Ban => { - self.db.mark_as_left(user_id, room_id)?; + self.db.mark_as_left(user_id, room_id).await?; } _ => {} } diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 6118e06b..f3bb6816 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -89,7 +89,7 @@ impl Service { } } - pub fn compress_state_event( + pub async fn compress_state_event( &self, shortstatekey: u64, event_id: &EventId, @@ -99,7 +99,8 @@ impl Service { &services() .rooms .short - .get_or_create_shorteventid(event_id)? + .get_or_create_shorteventid(event_id) + .await? .to_be_bytes(), ); Ok(v.try_into().expect("we checked the size above")) @@ -257,7 +258,7 @@ impl Service { /// Returns the new shortstatehash, and the state diff from the previous room state #[allow(clippy::type_complexity)] - pub fn save_state( + pub async fn save_state( &self, room_id: &RoomId, new_state_ids_compressed: Arc>, @@ -278,7 +279,8 @@ impl Service { let (new_shortstatehash, already_existed) = services() .rooms .short - .get_or_create_shortstatehash(&state_hash)?; + .get_or_create_shortstatehash(&state_hash) + .await?; if Some(new_shortstatehash) == previous_shortstatehash { return Ok(( diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 29d8339d..416aa6d0 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -267,20 +267,22 @@ impl Service { ); let insert_lock = mutex_insert.lock().await; - let count1 = services().globals.next_count()?; + let count1 = services().globals.next_count().await?; // Mark as read first so the sending client doesn't get a notification even if appending // fails services() .rooms .edus .read_receipt - .private_read_set(&pdu.room_id, &pdu.sender, count1)?; + .private_read_set(&pdu.room_id, &pdu.sender, count1) + .await?; services() .rooms .user - .reset_notification_counts(&pdu.sender, &pdu.room_id)?; + .reset_notification_counts(&pdu.sender, &pdu.room_id) + .await?; - let count2 = services().globals.next_count()?; + let count2 = services().globals.next_count().await?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&count2.to_be_bytes()); @@ -373,7 +375,10 @@ impl Service { } for push_key in services().pusher.get_pushkeys(user) { - services().sending.send_push_pdu(&pdu_id, user, push_key?)?; + services() + .sending + .send_push_pdu(&pdu_id, user, push_key?) + .await?; } } @@ -460,14 +465,18 @@ impl Service { // Update our membership info, we do this here incase a user is invited // and immediately leaves we need the DB to record the invite event for auth - services().rooms.state_cache.update_membership( - &pdu.room_id, - &target_user_id, - content.membership, - &pdu.sender, - invite_state, - true, - )?; + services() + .rooms + .state_cache + .update_membership( + &pdu.room_id, + &target_user_id, + content.membership, + &pdu.sender, + invite_state, + true, + ) + .await?; } } TimelineEventType::RoomMessage => { @@ -578,7 +587,8 @@ impl Service { { services() .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone()) + .await?; continue; } @@ -592,10 +602,10 @@ impl Service { { let appservice_uid = appservice.registration.sender_localpart.as_str(); if state_key_uid == appservice_uid { - services().sending.send_pdu_appservice( - appservice.registration.id.clone(), - pdu_id.clone(), - )?; + services() + .sending + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone()) + .await?; continue; } } @@ -645,14 +655,15 @@ impl Service { { services() .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone()) + .await?; } } Ok(pdu_id) } - pub fn create_hash_and_sign_event( + pub async fn create_hash_and_sign_event( &self, pdu_builder: PduBuilder, sender: &UserId, @@ -827,7 +838,8 @@ impl Service { let _shorteventid = services() .rooms .short - .get_or_create_shorteventid(&pdu.event_id)?; + .get_or_create_shorteventid(&pdu.event_id) + .await?; Ok((pdu, pdu_json)) } @@ -842,8 +854,9 @@ impl Service { room_id: &RoomId, state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result> { - let (pdu, pdu_json) = - self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; + let (pdu, pdu_json) = self + .create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock) + .await?; if let Some(admin_room) = services().admin.get_admin_room()? { if admin_room == room_id { @@ -986,7 +999,7 @@ impl Service { // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. - let statehashid = services().rooms.state.append_to_state(&pdu)?; + let statehashid = services().rooms.state.append_to_state(&pdu).await?; let pdu_id = self .append_pdu( @@ -1027,7 +1040,10 @@ impl Service { // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above servers.remove(services().globals.server_name()); - services().sending.send_pdu(servers.into_iter(), &pdu_id)?; + services() + .sending + .send_pdu(servers.into_iter(), &pdu_id) + .await?; Ok(pdu.event_id) } @@ -1046,11 +1062,11 @@ impl Service { ) -> Result>> { // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. - services().rooms.state.set_event_state( - &pdu.event_id, - &pdu.room_id, - state_ids_compressed, - )?; + services() + .rooms + .state + .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed) + .await?; if soft_fail { services() @@ -1264,7 +1280,7 @@ impl Service { ); let insert_lock = mutex_insert.lock().await; - let count = services().globals.next_count()?; + let count = services().globals.next_count().await?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&0_u64.to_be_bytes()); pdu_id.extend_from_slice(&(u64::MAX - count).to_be_bytes()); diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index 5544af2c..4e37374c 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,8 +1,10 @@ use crate::Result; +use async_trait::async_trait; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +#[async_trait] pub trait Data: Send + Sync { - fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + async fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 672e502d..7385325e 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -10,8 +10,12 @@ pub struct Service { } impl Service { - pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.reset_notification_counts(user_id, room_id) + pub async fn reset_notification_counts( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<()> { + self.db.reset_notification_counts(user_id, room_id).await } pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 78d3f1e1..060200e3 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -1,9 +1,11 @@ +use async_trait::async_trait; use ruma::ServerName; use crate::Result; use super::{OutgoingKind, SendingEventType}; +#[async_trait] pub trait Data: Send + Sync { #[allow(clippy::type_complexity)] fn active_requests<'a>( @@ -16,7 +18,7 @@ pub trait Data: Send + Sync { fn delete_active_request(&self, key: Vec) -> Result<()>; fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; - fn queue_requests( + async fn queue_requests( &self, requests: &[(&OutgoingKind, SendingEventType)], ) -> Result>>; diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index fa14f123..052554e6 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -370,10 +370,13 @@ impl Service { } #[tracing::instrument(skip(self, pdu_id, user, pushkey))] - pub fn send_push_pdu(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { + pub async fn send_push_pdu(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { let outgoing_kind = OutgoingKind::Push(user.to_owned(), pushkey); let event = SendingEventType::Pdu(pdu_id.to_owned()); - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + let keys = self + .db + .queue_requests(&[(&outgoing_kind, event.clone())]) + .await?; self.sender .send((outgoing_kind, event, keys.into_iter().next().unwrap())) .unwrap(); @@ -382,7 +385,7 @@ impl Service { } #[tracing::instrument(skip(self, servers, pdu_id))] - pub fn send_pdu>( + pub async fn send_pdu>( &self, servers: I, pdu_id: &[u8], @@ -396,12 +399,15 @@ impl Service { ) }) .collect::>(); - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), - )?; + let keys = self + .db + .queue_requests( + &requests + .iter() + .map(|(o, e)| (o, e.clone())) + .collect::>(), + ) + .await?; for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { self.sender .send((outgoing_kind.to_owned(), event, key)) @@ -412,7 +418,7 @@ impl Service { } #[tracing::instrument(skip(self, server, serialized))] - pub fn send_reliable_edu( + pub async fn send_reliable_edu( &self, server: &ServerName, serialized: Vec, @@ -420,7 +426,10 @@ impl Service { ) -> Result<()> { let outgoing_kind = OutgoingKind::Normal(server.to_owned()); let event = SendingEventType::Edu(serialized); - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + let keys = self + .db + .queue_requests(&[(&outgoing_kind, event.clone())]) + .await?; self.sender .send((outgoing_kind, event, keys.into_iter().next().unwrap())) .unwrap(); @@ -429,10 +438,13 @@ impl Service { } #[tracing::instrument(skip(self))] - pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { + pub async fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { let outgoing_kind = OutgoingKind::Appservice(appservice_id); let event = SendingEventType::Pdu(pdu_id); - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + let keys = self + .db + .queue_requests(&[(&outgoing_kind, event.clone())]) + .await?; self.sender .send((outgoing_kind, event, keys.into_iter().next().unwrap())) .unwrap(); diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 75d7eb2c..9d4c8eb9 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -1,4 +1,5 @@ use crate::Result; +use async_trait::async_trait; use ruma::{ api::client::{device::Device, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, @@ -9,6 +10,7 @@ use ruma::{ }; use std::collections::BTreeMap; +#[async_trait] pub trait Data: Send + Sync { /// Check if a user has an account on this homeserver. fn exists(&self, user_id: &UserId) -> Result; @@ -55,7 +57,7 @@ pub trait Data: Send + Sync { fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()>; /// Adds a new device to a user. - fn create_device( + async fn create_device( &self, user_id: &UserId, device_id: &DeviceId, @@ -64,7 +66,7 @@ pub trait Data: Send + Sync { ) -> Result<()>; /// Removes a device from a user. - fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; + async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; /// Returns an iterator over all device ids of this user. fn all_device_ids<'a>( @@ -75,7 +77,7 @@ pub trait Data: Send + Sync { /// Replaces the access token of one device. fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; - fn add_one_time_key( + async fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, @@ -85,7 +87,7 @@ pub trait Data: Send + Sync { fn last_one_time_keys_update(&self, user_id: &UserId) -> Result; - fn take_one_time_key( + async fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, @@ -98,14 +100,14 @@ pub trait Data: Send + Sync { device_id: &DeviceId, ) -> Result>; - fn add_device_keys( + async fn add_device_keys( &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw, ) -> Result<()>; - fn add_cross_signing_keys( + async fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, @@ -114,7 +116,7 @@ pub trait Data: Send + Sync { notify: bool, ) -> Result<()>; - fn sign_key( + async fn sign_key( &self, target_id: &UserId, key_id: &str, @@ -129,7 +131,7 @@ pub trait Data: Send + Sync { to: Option, ) -> Box> + 'a>; - fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>; + async fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>; fn get_device_keys( &self, @@ -167,7 +169,7 @@ pub trait Data: Send + Sync { fn get_user_signing_key(&self, user_id: &UserId) -> Result>>; - fn add_to_device_event( + async fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, @@ -189,7 +191,7 @@ pub trait Data: Send + Sync { until: u64, ) -> Result<()>; - fn update_device_metadata( + async fn update_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index a5694a10..1da8d5a9 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -340,7 +340,7 @@ impl Service { } /// Adds a new device to a user. - pub fn create_device( + pub async fn create_device( &self, user_id: &UserId, device_id: &DeviceId, @@ -349,18 +349,19 @@ impl Service { ) -> Result<()> { self.db .create_device(user_id, device_id, token, initial_device_display_name) + .await } /// Removes a device from a user. - pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - self.db.remove_device(user_id, device_id) + pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + self.db.remove_device(user_id, device_id).await } /// Returns an iterator over all device ids of this user. pub fn all_device_ids<'a>( &'a self, user_id: &UserId, - ) -> impl Iterator> + 'a { + ) -> impl Send + Iterator> + 'a { self.db.all_device_ids(user_id) } @@ -369,7 +370,7 @@ impl Service { self.db.set_token(user_id, device_id, token) } - pub fn add_one_time_key( + pub async fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, @@ -378,19 +379,22 @@ impl Service { ) -> Result<()> { self.db .add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) + .await } pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { self.db.last_one_time_keys_update(user_id) } - pub fn take_one_time_key( + pub async fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, ) -> Result)>> { - self.db.take_one_time_key(user_id, device_id, key_algorithm) + self.db + .take_one_time_key(user_id, device_id, key_algorithm) + .await } pub fn count_one_time_keys( @@ -401,16 +405,18 @@ impl Service { self.db.count_one_time_keys(user_id, device_id) } - pub fn add_device_keys( + pub async fn add_device_keys( &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw, ) -> Result<()> { - self.db.add_device_keys(user_id, device_id, device_keys) + self.db + .add_device_keys(user_id, device_id, device_keys) + .await } - pub fn add_cross_signing_keys( + pub async fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, @@ -418,23 +424,27 @@ impl Service { user_signing_key: &Option>, notify: bool, ) -> Result<()> { - self.db.add_cross_signing_keys( - user_id, - master_key, - self_signing_key, - user_signing_key, - notify, - ) + self.db + .add_cross_signing_keys( + user_id, + master_key, + self_signing_key, + user_signing_key, + notify, + ) + .await } - pub fn sign_key( + pub async fn sign_key( &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, ) -> Result<()> { - self.db.sign_key(target_id, key_id, signature, sender_id) + self.db + .sign_key(target_id, key_id, signature, sender_id) + .await } pub fn keys_changed<'a>( @@ -446,8 +456,8 @@ impl Service { self.db.keys_changed(user_or_room_id, from, to) } - pub fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - self.db.mark_device_key_update(user_id) + pub async fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { + self.db.mark_device_key_update(user_id).await } pub fn get_device_keys( @@ -501,7 +511,7 @@ impl Service { self.db.get_user_signing_key(user_id) } - pub fn add_to_device_event( + pub async fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, @@ -509,13 +519,15 @@ impl Service { event_type: &str, content: serde_json::Value, ) -> Result<()> { - self.db.add_to_device_event( - sender, - target_user_id, - target_device_id, - event_type, - content, - ) + self.db + .add_to_device_event( + sender, + target_user_id, + target_device_id, + event_type, + content, + ) + .await } pub fn get_to_device_events( @@ -535,13 +547,15 @@ impl Service { self.db.remove_to_device_events(user_id, device_id, until) } - pub fn update_device_metadata( + pub async fn update_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, device: &Device, ) -> Result<()> { - self.db.update_device_metadata(user_id, device_id, device) + self.db + .update_device_metadata(user_id, device_id, device) + .await } /// Get device metadata. @@ -565,10 +579,10 @@ impl Service { } /// Deactivate account - pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { + pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> { // Remove all associated devices for device_id in self.all_device_ids(user_id) { - self.remove_device(user_id, &device_id?)?; + self.remove_device(user_id, &device_id?).await?; } // Set the password to "" to indicate a deactivated account. Hashes will never result in an