diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index 93660f9f..6086ba20 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -1,6 +1,7 @@ use super::Config; use crate::Result; +use async_trait::async_trait; use std::{future::Future, pin::Pin, sync::Arc}; #[cfg(feature = "sled")] @@ -26,16 +27,17 @@ pub mod persy; ))] pub mod watchers; +#[async_trait] pub trait KeyValueDatabaseEngine: Send + Sync { - fn open(config: &Config) -> Result + async fn open(config: &Config) -> Result where Self: Sized; - fn open_tree(&self, name: &'static str) -> Result>; - fn flush(&self) -> Result<()>; - fn cleanup(&self) -> Result<()> { + async fn open_tree(&self, name: &'static str) -> Result>; + async fn flush(&self) -> Result<()>; + async fn cleanup(&self) -> Result<()> { Ok(()) } - fn memory_usage(&self) -> Result { + async fn memory_usage(&self) -> Result { Ok("Current database engine does not support memory usage reporting.".to_owned()) } } diff --git a/src/database/abstraction/persy.rs b/src/database/abstraction/persy.rs index da7d4cf0..31197809 100644 --- a/src/database/abstraction/persy.rs +++ b/src/database/abstraction/persy.rs @@ -5,6 +5,7 @@ use crate::{ }, Result, }; +use async_trait::async_trait; use persy::{ByteVec, OpenOptions, Persy, Transaction, TransactionConfig, ValueMode}; use std::{future::Future, pin::Pin, sync::Arc}; @@ -15,8 +16,9 @@ pub struct Engine { persy: Persy, } +#[async_trait] impl KeyValueDatabaseEngine for Arc { - fn open(config: &Config) -> Result { + async fn open(config: &Config) -> Result { let mut cfg = persy::Config::new(); cfg.change_cache_size((config.db_cache_capacity_mb * 1024.0 * 1024.0) as u64); @@ -27,7 +29,7 @@ impl KeyValueDatabaseEngine for Arc { Ok(Arc::new(Engine { persy })) } - fn open_tree(&self, name: &'static str) -> Result> { + async fn open_tree(&self, name: &'static str) -> Result> { // Create if it doesn't exist if !self.persy.exists_index(name)? { let mut tx = self.persy.begin()?; @@ -42,7 +44,7 @@ impl KeyValueDatabaseEngine for Arc { })) } - fn flush(&self) -> Result<()> { + async fn flush(&self) -> Result<()> { Ok(()) } } diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index cf77e3dd..0dddfea0 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -1,5 +1,6 @@ use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree}; use crate::{utils, Result}; +use async_trait::async_trait; use std::{ future::Future, pin::Pin, @@ -56,8 +57,9 @@ fn db_options(max_open_files: i32, rocksdb_cache: &rocksdb::Cache) -> rocksdb::O db_opts } +#[async_trait] impl KeyValueDatabaseEngine for Arc { - fn open(config: &Config) -> Result { + async fn open(config: &Config) -> Result { let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize; let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes); @@ -88,7 +90,7 @@ impl KeyValueDatabaseEngine for Arc { })) } - fn open_tree(&self, name: &'static str) -> Result> { + async fn open_tree(&self, name: &'static str) -> Result> { if !self.old_cfs.contains(&name.to_owned()) { // Create if it didn't exist let _ = self @@ -104,12 +106,12 @@ impl KeyValueDatabaseEngine for Arc { })) } - fn flush(&self) -> Result<()> { + async fn flush(&self) -> Result<()> { // TODO? Ok(()) } - fn memory_usage(&self) -> Result { + async fn memory_usage(&self) -> Result { let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?; Ok(format!( diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index b448c3b6..a94b78d2 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -1,5 +1,6 @@ use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}; use crate::{database::Config, Result}; +use async_trait::async_trait; use parking_lot::{Mutex, MutexGuard}; use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; use std::{ @@ -80,8 +81,9 @@ impl Engine { } } +#[async_trait] impl KeyValueDatabaseEngine for Arc { - fn open(config: &Config) -> Result { + async fn open(config: &Config) -> Result { let path = Path::new(&config.database_path).join("conduit.db"); // calculates cache-size per permanent connection @@ -105,7 +107,7 @@ impl KeyValueDatabaseEngine for Arc { Ok(arc) } - fn open_tree(&self, name: &str) -> Result> { + async fn open_tree(&self, name: &'static str) -> Result> { self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [])?; Ok(Arc::new(SqliteTable { @@ -115,12 +117,12 @@ impl KeyValueDatabaseEngine for Arc { })) } - fn flush(&self) -> Result<()> { + async fn flush(&self) -> Result<()> { // we enabled PRAGMA synchronous=normal, so this should not be necessary Ok(()) } - fn cleanup(&self) -> Result<()> { + async fn cleanup(&self) -> Result<()> { self.flush_wal() } } diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index bd47cb42..968d6420 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -139,11 +139,11 @@ impl service::globals::Data for KeyValueDatabase { Ok(()) } - fn cleanup(&self) -> Result<()> { - self._db.cleanup() + async fn cleanup(&self) -> Result<()> { + self._db.cleanup().await } - fn memory_usage(&self) -> String { + async fn memory_usage(&self) -> String { let pdu_cache = self.pdu_cache.lock().unwrap().len(); let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len(); let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len(); @@ -164,7 +164,7 @@ our_real_users_cache: {our_real_users_cache} appservice_in_room_cache: {appservice_in_room_cache} lasttimelinecount_cache: {lasttimelinecount_cache}\n" ); - if let Ok(db_stats) = self._db.memory_usage() { + if let Ok(db_stats) = self._db.memory_usage().await { response += &db_stats; } diff --git a/src/database/mod.rs b/src/database/mod.rs index 5171d4bb..1d549f19 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -249,19 +249,19 @@ impl KeyValueDatabase { #[cfg(not(feature = "sqlite"))] return Err(Error::BadConfig("Database backend not found.")); #[cfg(feature = "sqlite")] - Arc::new(Arc::::open(&config)?) + Arc::new(Arc::::open(&config).await?) } "rocksdb" => { #[cfg(not(feature = "rocksdb"))] return Err(Error::BadConfig("Database backend not found.")); #[cfg(feature = "rocksdb")] - Arc::new(Arc::::open(&config)?) + Arc::new(Arc::::open(&config).await?) } "persy" => { #[cfg(not(feature = "persy"))] return Err(Error::BadConfig("Database backend not found.")); #[cfg(feature = "persy")] - Arc::new(Arc::::open(&config)?) + Arc::new(Arc::::open(&config).await?) } _ => { return Err(Error::BadConfig("Database backend not found.")); @@ -278,101 +278,102 @@ impl KeyValueDatabase { let db_raw = Box::new(Self { _db: builder.clone(), - userid_password: builder.open_tree("userid_password")?, - userid_displayname: builder.open_tree("userid_displayname")?, - userid_avatarurl: builder.open_tree("userid_avatarurl")?, - userid_blurhash: builder.open_tree("userid_blurhash")?, - userdeviceid_token: builder.open_tree("userdeviceid_token")?, - userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, - userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, - token_userdeviceid: builder.open_tree("token_userdeviceid")?, - onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, - userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, - keychangeid_userid: builder.open_tree("keychangeid_userid")?, - keyid_key: builder.open_tree("keyid_key")?, - userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, - userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, - userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, - openidtoken_expiresatuserid: builder.open_tree("openidtoken_expiresatuserid")?, - userfilterid_filter: builder.open_tree("userfilterid_filter")?, - todeviceid_events: builder.open_tree("todeviceid_events")?, + userid_password: builder.open_tree("userid_password").await?, + userid_displayname: builder.open_tree("userid_displayname").await?, + userid_avatarurl: builder.open_tree("userid_avatarurl").await?, + userid_blurhash: builder.open_tree("userid_blurhash").await?, + userdeviceid_token: builder.open_tree("userdeviceid_token").await?, + userdeviceid_metadata: builder.open_tree("userdeviceid_metadata").await?, + userid_devicelistversion: builder.open_tree("userid_devicelistversion").await?, + token_userdeviceid: builder.open_tree("token_userdeviceid").await?, + onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys").await?, + userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate").await?, + keychangeid_userid: builder.open_tree("keychangeid_userid").await?, + keyid_key: builder.open_tree("keyid_key").await?, + userid_masterkeyid: builder.open_tree("userid_masterkeyid").await?, + userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid").await?, + userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid").await?, + openidtoken_expiresatuserid: builder.open_tree("openidtoken_expiresatuserid").await?, + userfilterid_filter: builder.open_tree("userfilterid_filter").await?, + todeviceid_events: builder.open_tree("todeviceid_events").await?, - userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, + userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo").await?, userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), - readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, - roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt + readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt").await?, + roomuserid_privateread: builder.open_tree("roomuserid_privateread").await?, // "Private" read receipt roomuserid_lastprivatereadupdate: builder - .open_tree("roomuserid_lastprivatereadupdate")?, - presenceid_presence: builder.open_tree("presenceid_presence")?, - userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?, - pduid_pdu: builder.open_tree("pduid_pdu")?, - eventid_pduid: builder.open_tree("eventid_pduid")?, - roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, + .open_tree("roomuserid_lastprivatereadupdate") + .await?, + presenceid_presence: builder.open_tree("presenceid_presence").await?, + userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate").await?, + pduid_pdu: builder.open_tree("pduid_pdu").await?, + eventid_pduid: builder.open_tree("eventid_pduid").await?, + roomid_pduleaves: builder.open_tree("roomid_pduleaves").await?, - alias_roomid: builder.open_tree("alias_roomid")?, - aliasid_alias: builder.open_tree("aliasid_alias")?, - publicroomids: builder.open_tree("publicroomids")?, + alias_roomid: builder.open_tree("alias_roomid").await?, + aliasid_alias: builder.open_tree("aliasid_alias").await?, + publicroomids: builder.open_tree("publicroomids").await?, - threadid_userids: builder.open_tree("threadid_userids")?, + threadid_userids: builder.open_tree("threadid_userids").await?, - tokenids: builder.open_tree("tokenids")?, + tokenids: builder.open_tree("tokenids").await?, - roomserverids: builder.open_tree("roomserverids")?, - serverroomids: builder.open_tree("serverroomids")?, - userroomid_joined: builder.open_tree("userroomid_joined")?, - roomuserid_joined: builder.open_tree("roomuserid_joined")?, - roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, - roomid_invitedcount: builder.open_tree("roomid_invitedcount")?, - roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, - userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, - roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, - userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, - roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, + roomserverids: builder.open_tree("roomserverids").await?, + serverroomids: builder.open_tree("serverroomids").await?, + userroomid_joined: builder.open_tree("userroomid_joined").await?, + roomuserid_joined: builder.open_tree("roomuserid_joined").await?, + roomid_joinedcount: builder.open_tree("roomid_joinedcount").await?, + roomid_invitedcount: builder.open_tree("roomid_invitedcount").await?, + roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids").await?, + userroomid_invitestate: builder.open_tree("userroomid_invitestate").await?, + roomuserid_invitecount: builder.open_tree("roomuserid_invitecount").await?, + userroomid_leftstate: builder.open_tree("userroomid_leftstate").await?, + roomuserid_leftcount: builder.open_tree("roomuserid_leftcount").await?, - alias_userid: builder.open_tree("alias_userid")?, + alias_userid: builder.open_tree("alias_userid").await?, - disabledroomids: builder.open_tree("disabledroomids")?, + disabledroomids: builder.open_tree("disabledroomids").await?, - lazyloadedids: builder.open_tree("lazyloadedids")?, + lazyloadedids: builder.open_tree("lazyloadedids").await?, - userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, - userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, - roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?, + userroomid_notificationcount: builder.open_tree("userroomid_notificationcount").await?, + userroomid_highlightcount: builder.open_tree("userroomid_highlightcount").await?, + roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount").await?, - statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, - shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?, + statekey_shortstatekey: builder.open_tree("statekey_shortstatekey").await?, + shortstatekey_statekey: builder.open_tree("shortstatekey_statekey").await?, - shorteventid_authchain: builder.open_tree("shorteventid_authchain")?, + shorteventid_authchain: builder.open_tree("shorteventid_authchain").await?, - roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, + roomid_shortroomid: builder.open_tree("roomid_shortroomid").await?, - shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, - eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, - shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, - shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, - roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, - roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?, - statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, + shortstatehash_statediff: builder.open_tree("shortstatehash_statediff").await?, + eventid_shorteventid: builder.open_tree("eventid_shorteventid").await?, + shorteventid_eventid: builder.open_tree("shorteventid_eventid").await?, + shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash").await?, + roomid_shortstatehash: builder.open_tree("roomid_shortstatehash").await?, + roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash").await?, + statehash_shortstatehash: builder.open_tree("statehash_shortstatehash").await?, - eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, - softfailedeventids: builder.open_tree("softfailedeventids")?, + eventid_outlierpdu: builder.open_tree("eventid_outlierpdu").await?, + softfailedeventids: builder.open_tree("softfailedeventids").await?, - tofrom_relation: builder.open_tree("tofrom_relation")?, - referencedevents: builder.open_tree("referencedevents")?, - roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, - roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, - mediaid_file: builder.open_tree("mediaid_file")?, - backupid_algorithm: builder.open_tree("backupid_algorithm")?, - backupid_etag: builder.open_tree("backupid_etag")?, - backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, - userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, - servername_educount: builder.open_tree("servername_educount")?, - servernameevent_data: builder.open_tree("servernameevent_data")?, - servercurrentevent_data: builder.open_tree("servercurrentevent_data")?, - id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, - senderkey_pusher: builder.open_tree("senderkey_pusher")?, - global: builder.open_tree("global")?, - server_signingkeys: builder.open_tree("server_signingkeys")?, + tofrom_relation: builder.open_tree("tofrom_relation").await?, + referencedevents: builder.open_tree("referencedevents").await?, + roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata").await?, + roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid").await?, + mediaid_file: builder.open_tree("mediaid_file").await?, + backupid_algorithm: builder.open_tree("backupid_algorithm").await?, + backupid_etag: builder.open_tree("backupid_etag").await?, + backupkeyid_backup: builder.open_tree("backupkeyid_backup").await?, + userdevicetxnid_response: builder.open_tree("userdevicetxnid_response").await?, + servername_educount: builder.open_tree("servername_educount").await?, + servernameevent_data: builder.open_tree("servernameevent_data").await?, + servercurrentevent_data: builder.open_tree("servercurrentevent_data").await?, + id_appserviceregistrations: builder.open_tree("id_appserviceregistrations").await?, + senderkey_pusher: builder.open_tree("senderkey_pusher").await?, + global: builder.open_tree("global").await?, + server_signingkeys: builder.open_tree("server_signingkeys").await?, pdu_cache: Mutex::new(LruCache::new( config @@ -624,7 +625,7 @@ impl KeyValueDatabase { Ok::<_, Error>(()) }; - for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() { + for (k, seventid) in db._db.open_tree("stateid_shorteventid").await?.iter() { let sstatehash = utils::u64_from_bytes(&k[0..size_of::()]) .expect("number of bytes is correct"); let sstatekey = k[size_of::()..].to_vec(); @@ -808,7 +809,8 @@ impl KeyValueDatabase { if services().globals.database_version()? < 11 { db._db - .open_tree("userdevicesessionid_uiaarequest")? + .open_tree("userdevicesessionid_uiaarequest") + .await? .clear()?; services().globals.bump_database_version(11)?; @@ -998,10 +1000,10 @@ impl KeyValueDatabase { } #[tracing::instrument(skip(self))] - pub fn flush(&self) -> Result<()> { + pub async fn flush(&self) -> Result<()> { let start = std::time::Instant::now(); - let res = self._db.flush(); + let res = self._db.flush().await; debug!("flush: took {:?}", start.elapsed()); @@ -1094,7 +1096,7 @@ impl KeyValueDatabase { } let start = Instant::now(); - if let Err(e) = services().globals.cleanup() { + if let Err(e) = services().globals.cleanup().await { error!("cleanup: Errored: {}", e); } else { debug!("cleanup: Finished in {:?}", start.elapsed()); diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 70c63381..718219c4 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -522,7 +522,7 @@ impl Service { } AdminCommand::MemoryUsage => { let response1 = services().memory_usage().await; - let response2 = services().globals.db.memory_usage(); + let response2 = services().globals.db.memory_usage().await; RoomMessageEventContent::text_plain(format!( "Services:\n{response1}\n\nDatabase:\n{response2}" diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 167e823c..5fd84539 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -74,8 +74,8 @@ pub trait Data: Send + Sync { fn last_check_for_updates_id(&self) -> Result; fn update_check_for_updates_id(&self, id: u64) -> Result<()>; async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; - fn cleanup(&self) -> Result<()>; - fn memory_usage(&self) -> String; + async fn cleanup(&self) -> Result<()>; + async fn memory_usage(&self) -> String; fn clear_caches(&self, amount: u32); fn load_keypair(&self) -> Result; fn remove_keypair(&self) -> Result<()>; diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index c22ffef3..c0c0bd44 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -281,8 +281,8 @@ impl Service { self.db.watch(user_id, device_id).await } - pub fn cleanup(&self) -> Result<()> { - self.db.cleanup() + pub async fn cleanup(&self) -> Result<()> { + self.db.cleanup().await } pub fn server_name(&self) -> &ServerName {