diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index 74f3a45a..29325bd6 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -26,7 +26,7 @@ pub mod persy; ))] pub mod watchers; -pub trait DatabaseEngine: Send + Sync { +pub trait KeyValueDatabaseEngine: Send + Sync { fn open(config: &Config) -> Result where Self: Sized; @@ -40,7 +40,7 @@ pub trait DatabaseEngine: Send + Sync { } } -pub trait Tree: Send + Sync { +pub trait KeyValueTree: Send + Sync { fn get(&self, key: &[u8]) -> Result>>; fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; diff --git a/src/database/key_value.rs b/src/database/key_value.rs new file mode 100644 index 00000000..8ae51eb8 --- /dev/null +++ b/src/database/key_value.rs @@ -0,0 +1,65 @@ +use crate::service; + +impl service::room::state::Data for KeyValueDatabase { + fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { + self.roomid_shortstatehash + .get(room_id.as_bytes())? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") + })?)) + }) + } + + fn set_room_state(room_id: &RoomId, new_shortstatehash: u64 + _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + self.roomid_shortstatehash + .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; + Ok(()) + } + + fn set_event_state() -> Result<()> { + db.shorteventid_shortstatehash + .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; + Ok(()) + } + + fn get_pdu_leaves(&self, room_id: &RoomId) -> Result>> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + self.roomid_pduleaves + .scan_prefix(prefix) + .map(|(_, bytes)| { + EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) + }) + .collect() + } + + fn set_forward_extremities( + &self, + room_id: &RoomId, + event_ids: impl IntoIterator + Debug, + _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { + self.roomid_pduleaves.remove(&key)?; + } + + for event_id in event_ids { + let mut key = prefix.to_owned(); + key.extend_from_slice(event_id.as_bytes()); + self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; + } + + Ok(()) + } + +} diff --git a/src/database/mod.rs b/src/database/mod.rs index a0937c29..a35228aa 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -15,7 +15,7 @@ pub mod users; use self::admin::create_admin_room; use crate::{utils, Config, Error, Result}; -use abstraction::DatabaseEngine; +use abstraction::KeyValueDatabaseEngine; use directories::ProjectDirs; use futures_util::{stream::FuturesUnordered, StreamExt}; use lru_cache::LruCache; @@ -39,8 +39,8 @@ use std::{ use tokio::sync::{mpsc, OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore}; use tracing::{debug, error, info, warn}; -pub struct Database { - _db: Arc, +pub struct KeyValueDatabase { + _db: Arc, pub globals: globals::Globals, pub users: users::Users, pub uiaa: uiaa::Uiaa, @@ -55,7 +55,7 @@ pub struct Database { pub pusher: pusher::PushData, } -impl Database { +impl KeyValueDatabase { /// Tries to remove the old database but ignores all errors. pub fn try_remove(server_name: &str) -> Result<()> { let mut path = ProjectDirs::from("xyz", "koesters", "conduit") @@ -124,7 +124,7 @@ impl Database { .map_err(|_| Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself."))?; } - let builder: Arc = match &*config.database_backend { + let builder: Arc = match &*config.database_backend { "sqlite" => { #[cfg(not(feature = "sqlite"))] return Err(Error::BadConfig("Database backend not found.")); @@ -955,7 +955,7 @@ impl Database { } /// Sets the emergency password and push rules for the @conduit account in case emergency password is set -fn set_emergency_access(db: &Database) -> Result { +fn set_emergency_access(db: &KeyValueDatabase) -> Result { let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) .expect("@conduit:server_name is a valid UserId"); @@ -979,39 +979,3 @@ fn set_emergency_access(db: &Database) -> Result { res } - -pub struct DatabaseGuard(OwnedRwLockReadGuard); - -impl Deref for DatabaseGuard { - type Target = OwnedRwLockReadGuard; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[cfg(feature = "conduit_bin")] -#[axum::async_trait] -impl axum::extract::FromRequest for DatabaseGuard -where - B: Send, -{ - type Rejection = axum::extract::rejection::ExtensionRejection; - - async fn from_request( - req: &mut axum::extract::RequestParts, - ) -> Result { - use axum::extract::Extension; - - let Extension(db): Extension>> = - Extension::from_request(req).await?; - - Ok(DatabaseGuard(db.read_owned().await)) - } -} - -impl From> for DatabaseGuard { - fn from(val: OwnedRwLockReadGuard) -> Self { - Self(val) - } -} diff --git a/src/main.rs b/src/main.rs index 9a0928a0..a1af9761 100644 --- a/src/main.rs +++ b/src/main.rs @@ -46,27 +46,26 @@ use tikv_jemallocator::Jemalloc; #[global_allocator] static GLOBAL: Jemalloc = Jemalloc; -#[tokio::main] -async fn main() { - let raw_config = - Figment::new() - .merge( - Toml::file(Env::var("CONDUIT_CONFIG").expect( - "The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml", - )) - .nested(), - ) - .merge(Env::prefixed("CONDUIT_").global()); +lazy_static! { + static ref DB: Database = { + let raw_config = + Figment::new() + .merge( + Toml::file(Env::var("CONDUIT_CONFIG").expect( + "The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml", + )) + .nested(), + ) + .merge(Env::prefixed("CONDUIT_").global()); - let config = match raw_config.extract::() { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e); - std::process::exit(1); - } - }; + let config = match raw_config.extract::() { + Ok(s) => s, + Err(e) => { + eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e); + std::process::exit(1); + } + }; - let start = async { config.warn_deprecated(); let db = match Database::load_or_create(&config).await { @@ -79,8 +78,15 @@ async fn main() { std::process::exit(1); } }; + }; +} - run_server(&config, db).await.unwrap(); +#[tokio::main] +async fn main() { + lazy_static::initialize(&DB); + + let start = async { + run_server(&config).await.unwrap(); }; if config.allow_jaeger { @@ -120,7 +126,8 @@ async fn main() { } } -async fn run_server(config: &Config, db: Arc>) -> io::Result<()> { +async fn run_server() -> io::Result<()> { + let config = DB.globals.config; let addr = SocketAddr::from((config.address, config.port)); let x_requested_with = HeaderName::from_static("x-requested-with"); diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 4b42ca8e..8aa76380 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,16 +1,24 @@ pub trait Data { + /// Returns the last state hash key added to the db for the given room. fn get_room_shortstatehash(room_id: &RoomId); + + /// Update the current state of the room. + fn set_room_state(room_id: &RoomId, new_shortstatehash: u64 + _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex + ); + + /// Associates a state with an event. + fn set_event_state(shorteventid: u64, shortstatehash: u64) -> Result<()> { + + /// Returns all events we would send as the prev_events of the next event. + fn get_forward_extremities(room_id: &RoomId) -> Result>>; + + /// Replace the forward extremities of the room. + fn set_forward_extremities( + room_id: &RoomId, + event_ids: impl IntoIterator + Debug, + _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { } - /// Returns the last state hash key added to the db for the given room. - #[tracing::instrument(skip(self))] - pub fn current_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.roomid_shortstatehash - .get(room_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") - })?)) - }) - } - +pub struct StateLock; diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index da03ad4c..bf926078 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -1,8 +1,13 @@ +mod data; +pub use data::Data; + +use crate::service::*; + pub struct Service { db: D, } -impl Service { +impl Service<_> { /// Set the room to the given statehash and update caches. #[tracing::instrument(skip(self, new_state_ids_compressed, db))] pub fn force_state( @@ -15,11 +20,11 @@ impl Service { ) -> Result<()> { for event_id in statediffnew.into_iter().filter_map(|new| { - self.parse_compressed_state_event(new) + state_compressor::parse_compressed_state_event(new) .ok() .map(|(_, id)| id) }) { - let pdu = match self.get_pdu_json(&event_id)? { + let pdu = match timeline::get_pdu_json(&event_id)? { Some(pdu) => pdu, None => continue, }; @@ -55,56 +60,12 @@ impl Service { Err(_) => continue, }; - self.update_membership(room_id, &user_id, membership, &pdu.sender, None, db, false)?; + room::state_cache::update_membership(room_id, &user_id, membership, &pdu.sender, None, db, false)?; } - self.update_joined_count(room_id, db)?; + room::state_cache::update_joined_count(room_id, db)?; - self.roomid_shortstatehash - .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; - - Ok(()) - } - - /// Returns the leaf pdus of a room. - #[tracing::instrument(skip(self))] - pub fn get_pdu_leaves(&self, room_id: &RoomId) -> Result>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.roomid_pduleaves - .scan_prefix(prefix) - .map(|(_, bytes)| { - EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) - }) - .collect() - } - - /// Replace the leaves of a room. - /// - /// The provided `event_ids` become the new leaves, this allows a room to have multiple - /// `prev_events`. - #[tracing::instrument(skip(self))] - pub fn replace_pdu_leaves<'a>( - &self, - room_id: &RoomId, - event_ids: impl IntoIterator + Debug, - ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { - self.roomid_pduleaves.remove(&key)?; - } - - for event_id in event_ids { - let mut key = prefix.to_owned(); - key.extend_from_slice(event_id.as_bytes()); - self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; - } + db.set_room_state(room_id, new_shortstatehash); Ok(()) } @@ -121,11 +82,11 @@ impl Service { state_ids_compressed: HashSet, globals: &super::globals::Globals, ) -> Result<()> { - let shorteventid = self.get_or_create_shorteventid(event_id, globals)?; + let shorteventid = short::get_or_create_shorteventid(event_id, globals)?; - let previous_shortstatehash = self.current_shortstatehash(room_id)?; + let previous_shortstatehash = db.get_room_shortstatehash(room_id)?; - let state_hash = self.calculate_hash( + let state_hash = super::calculate_hash( &state_ids_compressed .iter() .map(|s| &s[..]) @@ -133,11 +94,11 @@ impl Service { ); let (shortstatehash, already_existed) = - self.get_or_create_shortstatehash(&state_hash, globals)?; + short::get_or_create_shortstatehash(&state_hash, globals)?; if !already_existed { let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; + .map_or_else(|| Ok(Vec::new()), |p| room::state_compressor.load_shortstatehash_info(p))?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { @@ -156,7 +117,7 @@ impl Service { } else { (state_ids_compressed, HashSet::new()) }; - self.save_state_from_diff( + state_compressor::save_state_from_diff( shortstatehash, statediffnew, statediffremoved, @@ -165,8 +126,7 @@ impl Service { )?; } - self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; + db.set_event_state(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) } @@ -183,7 +143,7 @@ impl Service { ) -> Result { let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; - let previous_shortstatehash = self.current_shortstatehash(&new_pdu.room_id)?; + let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; if let Some(p) = previous_shortstatehash { self.shorteventid_shortstatehash @@ -293,4 +253,8 @@ impl Service { Ok(()) } + + pub fn db(&self) -> D { + &self.db + } }