1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00

Work on rooms/state and database

This commit is contained in:
Timo Kösters 2022-06-25 16:12:23 +02:00
parent 03b2867a84
commit 7c166aa468
No known key found for this signature in database
GPG key ID: 356E705610F626D5
6 changed files with 144 additions and 136 deletions

View file

@ -26,7 +26,7 @@ pub mod persy;
))] ))]
pub mod watchers; pub mod watchers;
pub trait DatabaseEngine: Send + Sync { pub trait KeyValueDatabaseEngine: Send + Sync {
fn open(config: &Config) -> Result<Self> fn open(config: &Config) -> Result<Self>
where where
Self: Sized; Self: Sized;
@ -40,7 +40,7 @@ pub trait DatabaseEngine: Send + Sync {
} }
} }
pub trait Tree: Send + Sync { pub trait KeyValueTree: Send + Sync {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>; fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;

65
src/database/key_value.rs Normal file
View file

@ -0,0 +1,65 @@
use crate::service;
impl service::room::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
})?))
})
}
fn set_room_state(room_id: &RoomId, new_shortstatehash: u64
_mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
self.roomid_shortstatehash
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
}
fn set_event_state() -> Result<()> {
db.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(())
}
fn get_pdu_leaves(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
self.roomid_pduleaves
.scan_prefix(prefix)
.map(|(_, bytes)| {
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
})
.collect()
}
fn set_forward_extremities(
&self,
room_id: &RoomId,
event_ids: impl IntoIterator<Item = &'a EventId> + Debug,
_mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?;
}
for event_id in event_ids {
let mut key = prefix.to_owned();
key.extend_from_slice(event_id.as_bytes());
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
}
Ok(())
}
}

View file

@ -15,7 +15,7 @@ pub mod users;
use self::admin::create_admin_room; use self::admin::create_admin_room;
use crate::{utils, Config, Error, Result}; use crate::{utils, Config, Error, Result};
use abstraction::DatabaseEngine; use abstraction::KeyValueDatabaseEngine;
use directories::ProjectDirs; use directories::ProjectDirs;
use futures_util::{stream::FuturesUnordered, StreamExt}; use futures_util::{stream::FuturesUnordered, StreamExt};
use lru_cache::LruCache; use lru_cache::LruCache;
@ -39,8 +39,8 @@ use std::{
use tokio::sync::{mpsc, OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore}; use tokio::sync::{mpsc, OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore};
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
pub struct Database { pub struct KeyValueDatabase {
_db: Arc<dyn DatabaseEngine>, _db: Arc<dyn KeyValueDatabaseEngine>,
pub globals: globals::Globals, pub globals: globals::Globals,
pub users: users::Users, pub users: users::Users,
pub uiaa: uiaa::Uiaa, pub uiaa: uiaa::Uiaa,
@ -55,7 +55,7 @@ pub struct Database {
pub pusher: pusher::PushData, pub pusher: pusher::PushData,
} }
impl Database { impl KeyValueDatabase {
/// Tries to remove the old database but ignores all errors. /// Tries to remove the old database but ignores all errors.
pub fn try_remove(server_name: &str) -> Result<()> { pub fn try_remove(server_name: &str) -> Result<()> {
let mut path = ProjectDirs::from("xyz", "koesters", "conduit") let mut path = ProjectDirs::from("xyz", "koesters", "conduit")
@ -124,7 +124,7 @@ impl Database {
.map_err(|_| Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself."))?; .map_err(|_| Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself."))?;
} }
let builder: Arc<dyn DatabaseEngine> = match &*config.database_backend { let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config.database_backend {
"sqlite" => { "sqlite" => {
#[cfg(not(feature = "sqlite"))] #[cfg(not(feature = "sqlite"))]
return Err(Error::BadConfig("Database backend not found.")); return Err(Error::BadConfig("Database backend not found."));
@ -955,7 +955,7 @@ impl Database {
} }
/// Sets the emergency password and push rules for the @conduit account in case emergency password is set /// Sets the emergency password and push rules for the @conduit account in case emergency password is set
fn set_emergency_access(db: &Database) -> Result<bool> { fn set_emergency_access(db: &KeyValueDatabase) -> Result<bool> {
let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name())
.expect("@conduit:server_name is a valid UserId"); .expect("@conduit:server_name is a valid UserId");
@ -979,39 +979,3 @@ fn set_emergency_access(db: &Database) -> Result<bool> {
res res
} }
pub struct DatabaseGuard(OwnedRwLockReadGuard<Database>);
impl Deref for DatabaseGuard {
type Target = OwnedRwLockReadGuard<Database>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[cfg(feature = "conduit_bin")]
#[axum::async_trait]
impl<B> axum::extract::FromRequest<B> for DatabaseGuard
where
B: Send,
{
type Rejection = axum::extract::rejection::ExtensionRejection;
async fn from_request(
req: &mut axum::extract::RequestParts<B>,
) -> Result<Self, Self::Rejection> {
use axum::extract::Extension;
let Extension(db): Extension<Arc<TokioRwLock<Database>>> =
Extension::from_request(req).await?;
Ok(DatabaseGuard(db.read_owned().await))
}
}
impl From<OwnedRwLockReadGuard<Database>> for DatabaseGuard {
fn from(val: OwnedRwLockReadGuard<Database>) -> Self {
Self(val)
}
}

View file

@ -46,27 +46,26 @@ use tikv_jemallocator::Jemalloc;
#[global_allocator] #[global_allocator]
static GLOBAL: Jemalloc = Jemalloc; static GLOBAL: Jemalloc = Jemalloc;
#[tokio::main] lazy_static! {
async fn main() { static ref DB: Database = {
let raw_config = let raw_config =
Figment::new() Figment::new()
.merge( .merge(
Toml::file(Env::var("CONDUIT_CONFIG").expect( Toml::file(Env::var("CONDUIT_CONFIG").expect(
"The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml", "The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml",
)) ))
.nested(), .nested(),
) )
.merge(Env::prefixed("CONDUIT_").global()); .merge(Env::prefixed("CONDUIT_").global());
let config = match raw_config.extract::<Config>() { let config = match raw_config.extract::<Config>() {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e); eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e);
std::process::exit(1); std::process::exit(1);
} }
}; };
let start = async {
config.warn_deprecated(); config.warn_deprecated();
let db = match Database::load_or_create(&config).await { let db = match Database::load_or_create(&config).await {
@ -79,8 +78,15 @@ async fn main() {
std::process::exit(1); std::process::exit(1);
} }
}; };
};
}
run_server(&config, db).await.unwrap(); #[tokio::main]
async fn main() {
lazy_static::initialize(&DB);
let start = async {
run_server(&config).await.unwrap();
}; };
if config.allow_jaeger { if config.allow_jaeger {
@ -120,7 +126,8 @@ async fn main() {
} }
} }
async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> io::Result<()> { async fn run_server() -> io::Result<()> {
let config = DB.globals.config;
let addr = SocketAddr::from((config.address, config.port)); let addr = SocketAddr::from((config.address, config.port));
let x_requested_with = HeaderName::from_static("x-requested-with"); let x_requested_with = HeaderName::from_static("x-requested-with");

View file

@ -1,16 +1,24 @@
pub trait Data { pub trait Data {
/// Returns the last state hash key added to the db for the given room.
fn get_room_shortstatehash(room_id: &RoomId); fn get_room_shortstatehash(room_id: &RoomId);
/// Update the current state of the room.
fn set_room_state(room_id: &RoomId, new_shortstatehash: u64
_mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex
);
/// Associates a state with an event.
fn set_event_state(shorteventid: u64, shortstatehash: u64) -> Result<()> {
/// Returns all events we would send as the prev_events of the next event.
fn get_forward_extremities(room_id: &RoomId) -> Result<HashSet<Arc<EventId>>>;
/// Replace the forward extremities of the room.
fn set_forward_extremities(
room_id: &RoomId,
event_ids: impl IntoIterator<Item = &'_ EventId> + Debug,
_mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
} }
/// Returns the last state hash key added to the db for the given room. pub struct StateLock;
#[tracing::instrument(skip(self))]
pub fn current_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
})?))
})
}

View file

@ -1,8 +1,13 @@
mod data;
pub use data::Data;
use crate::service::*;
pub struct Service<D: Data> { pub struct Service<D: Data> {
db: D, db: D,
} }
impl Service { impl Service<_> {
/// Set the room to the given statehash and update caches. /// Set the room to the given statehash and update caches.
#[tracing::instrument(skip(self, new_state_ids_compressed, db))] #[tracing::instrument(skip(self, new_state_ids_compressed, db))]
pub fn force_state( pub fn force_state(
@ -15,11 +20,11 @@ impl Service {
) -> Result<()> { ) -> Result<()> {
for event_id in statediffnew.into_iter().filter_map(|new| { for event_id in statediffnew.into_iter().filter_map(|new| {
self.parse_compressed_state_event(new) state_compressor::parse_compressed_state_event(new)
.ok() .ok()
.map(|(_, id)| id) .map(|(_, id)| id)
}) { }) {
let pdu = match self.get_pdu_json(&event_id)? { let pdu = match timeline::get_pdu_json(&event_id)? {
Some(pdu) => pdu, Some(pdu) => pdu,
None => continue, None => continue,
}; };
@ -55,56 +60,12 @@ impl Service {
Err(_) => continue, Err(_) => continue,
}; };
self.update_membership(room_id, &user_id, membership, &pdu.sender, None, db, false)?; room::state_cache::update_membership(room_id, &user_id, membership, &pdu.sender, None, db, false)?;
} }
self.update_joined_count(room_id, db)?; room::state_cache::update_joined_count(room_id, db)?;
self.roomid_shortstatehash db.set_room_state(room_id, new_shortstatehash);
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
}
/// Returns the leaf pdus of a room.
#[tracing::instrument(skip(self))]
pub fn get_pdu_leaves(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
self.roomid_pduleaves
.scan_prefix(prefix)
.map(|(_, bytes)| {
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
})
.collect()
}
/// Replace the leaves of a room.
///
/// The provided `event_ids` become the new leaves, this allows a room to have multiple
/// `prev_events`.
#[tracing::instrument(skip(self))]
pub fn replace_pdu_leaves<'a>(
&self,
room_id: &RoomId,
event_ids: impl IntoIterator<Item = &'a EventId> + Debug,
) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?;
}
for event_id in event_ids {
let mut key = prefix.to_owned();
key.extend_from_slice(event_id.as_bytes());
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
}
Ok(()) Ok(())
} }
@ -121,11 +82,11 @@ impl Service {
state_ids_compressed: HashSet<CompressedStateEvent>, state_ids_compressed: HashSet<CompressedStateEvent>,
globals: &super::globals::Globals, globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
let shorteventid = self.get_or_create_shorteventid(event_id, globals)?; let shorteventid = short::get_or_create_shorteventid(event_id, globals)?;
let previous_shortstatehash = self.current_shortstatehash(room_id)?; let previous_shortstatehash = db.get_room_shortstatehash(room_id)?;
let state_hash = self.calculate_hash( let state_hash = super::calculate_hash(
&state_ids_compressed &state_ids_compressed
.iter() .iter()
.map(|s| &s[..]) .map(|s| &s[..])
@ -133,11 +94,11 @@ impl Service {
); );
let (shortstatehash, already_existed) = let (shortstatehash, already_existed) =
self.get_or_create_shortstatehash(&state_hash, globals)?; short::get_or_create_shortstatehash(&state_hash, globals)?;
if !already_existed { if !already_existed {
let states_parents = previous_shortstatehash let states_parents = previous_shortstatehash
.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; .map_or_else(|| Ok(Vec::new()), |p| room::state_compressor.load_shortstatehash_info(p))?;
let (statediffnew, statediffremoved) = let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() { if let Some(parent_stateinfo) = states_parents.last() {
@ -156,7 +117,7 @@ impl Service {
} else { } else {
(state_ids_compressed, HashSet::new()) (state_ids_compressed, HashSet::new())
}; };
self.save_state_from_diff( state_compressor::save_state_from_diff(
shortstatehash, shortstatehash,
statediffnew, statediffnew,
statediffremoved, statediffremoved,
@ -165,8 +126,7 @@ impl Service {
)?; )?;
} }
self.shorteventid_shortstatehash db.set_event_state(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(()) Ok(())
} }
@ -183,7 +143,7 @@ impl Service {
) -> Result<u64> { ) -> Result<u64> {
let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?;
let previous_shortstatehash = self.current_shortstatehash(&new_pdu.room_id)?; let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?;
if let Some(p) = previous_shortstatehash { if let Some(p) = previous_shortstatehash {
self.shorteventid_shortstatehash self.shorteventid_shortstatehash
@ -293,4 +253,8 @@ impl Service {
Ok(()) Ok(())
} }
pub fn db(&self) -> D {
&self.db
}
} }