commit d4ccfa16dc (parent 0066f20bdd)
Author: Timo Kösters
Date:   2022-04-03 11:48:25 +02:00
5 changed files with 129 additions and 57 deletions

Cargo.lock (generated)

@@ -96,6 +96,27 @@ dependencies = [
  "tokio",
 ]
 
+[[package]]
+name = "async-stream"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e"
+dependencies = [
+ "async-stream-impl",
+ "futures-core",
+]
+
+[[package]]
+name = "async-stream-impl"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "async-trait"
 version = "0.1.52"

@@ -389,6 +410,7 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
 name = "conduit"
 version = "0.3.0-next"
 dependencies = [
+ "async-stream",
  "axum",
  "axum-server",
  "base64 0.13.0",

Cargo.toml

@@ -27,6 +27,7 @@ ruma = { git = "https://github.com/ruma/ruma", rev = "fa2e3662a456bd8957b3e1293c
 # Async runtime and utilities
 tokio = { version = "1.11.0", features = ["fs", "macros", "signal", "sync"] }
+async-stream = "0.3.2"
 
 # Used for storing data permanently
 sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true }
 #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] }

@@ -76,7 +77,7 @@ crossbeam = { version = "0.8.1", optional = true }
 num_cpus = "1.13.0"
 threadpool = "1.8.1"
 heed = { git = "https://github.com/timokoesters/heed.git", rev = "f6f825da7fb2c758867e05ad973ef800a6fe1d5d", optional = true }
-rocksdb = { version = "0.17.0", default-features = false, features = ["multi-threaded-cf", "zstd"], optional = true }
+rocksdb = { version = "0.17.0", default-features = true, features = ["multi-threaded-cf", "zstd"], optional = true }
 thread_local = "1.1.3"
 # used for TURN server authentication

src/database/abstraction/rocksdb.rs

@@ -36,8 +36,8 @@ fn db_options(max_open_files: i32, rocksdb_cache: &rocksdb::Cache) -> rocksdb::O
     db_opts.set_level_compaction_dynamic_level_bytes(true);
     db_opts.set_target_file_size_base(256 * 1024 * 1024);
     //db_opts.set_compaction_readahead_size(2 * 1024 * 1024);
-    //db_opts.set_use_direct_reads(true);
-    //db_opts.set_use_direct_io_for_flush_and_compaction(true);
+    db_opts.set_use_direct_reads(true);
+    db_opts.set_use_direct_io_for_flush_and_compaction(true);
     db_opts.create_if_missing(true);
     db_opts.increase_parallelism(num_cpus::get() as i32);
     db_opts.set_max_open_files(max_open_files);
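
With `set_use_direct_reads` and `set_use_direct_io_for_flush_and_compaction` enabled, RocksDB opens files with O_DIRECT and bypasses the OS page cache, leaving RocksDB's own block cache as the only caching layer. A minimal standalone sketch of the same options with the rust-rocksdb crate (the helper name and path handling are illustrative, not conduit's code):

```rust
use rocksdb::{Options, DB};

// Sketch only: mirrors the two direct-I/O options enabled in the hunk above.
fn open_direct_io(path: &str) -> Result<DB, rocksdb::Error> {
    let mut db_opts = Options::default();
    db_opts.create_if_missing(true);
    // Read SST files with O_DIRECT, skipping the OS page cache.
    db_opts.set_use_direct_reads(true);
    // Use direct I/O for background flushes and compactions as well.
    db_opts.set_use_direct_io_for_flush_and_compaction(true);
    DB::open(&db_opts, path)
}
```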

src/database/rooms.rs

@@ -1,6 +1,7 @@
 mod edus;
 
 pub use edus::RoomEdus;
+use futures_util::Stream;
 
 use crate::{
     pdu::{EventHash, PduBuilder},

@@ -39,6 +40,7 @@ use std::{
     sync::{Arc, Mutex, RwLock},
 };
 use tokio::sync::MutexGuard;
+use async_stream::try_stream;
 use tracing::{error, warn};
 
 use super::{abstraction::Tree, pusher};

@@ -1083,6 +1085,38 @@ impl Rooms {
             .transpose()
     }
 
+    pub async fn get_pdu_async(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
+        if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) {
+            return Ok(Some(Arc::clone(p)));
+        }
+
+        let eventid_pduid = Arc::clone(&self.eventid_pduid);
+        let event_id_bytes = event_id.as_bytes().to_vec();
+
+        if let Some(pdu) = tokio::task::spawn_blocking(move || eventid_pduid.get(&event_id_bytes))
+            .await
+            .unwrap()?
+            .map_or_else(
+                || self.eventid_outlierpdu.get(event_id.as_bytes()),
+                |pduid| {
+                    Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| {
+                        Error::bad_database("Invalid pduid in eventid_pduid.")
+                    })?))
+                },
+            )?
+            .map(|pdu| {
+                serde_json::from_slice(&pdu)
+                    .map_err(|_| Error::bad_database("Invalid PDU in db."))
+                    .map(Arc::new)
+            })
+            .transpose()?
+        {
+            self.pdu_cache
+                .lock()
+                .unwrap()
+                .insert(event_id.to_owned(), Arc::clone(&pdu));
+            Ok(Some(pdu))
+        } else {
+            Ok(None)
+        }
+    }
 
     /// Returns the pdu.
     ///
     /// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
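
The new `get_pdu_async` consults the in-memory `pdu_cache` first and only then touches the database, with the tree lookup moved onto Tokio's blocking thread pool so a slow disk read cannot stall the async executor. A minimal sketch of that core pattern, with stand-ins for conduit's `Tree` trait and `Result` alias:

```rust
use std::sync::Arc;

// Illustrative stand-ins for conduit's types (super::abstraction::Tree, crate::Result).
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

trait Tree: Send + Sync {
    fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
}

// The core of get_pdu_async: hop to the blocking pool for the disk read.
// spawn_blocking needs a 'static closure, hence the owned Arc and Vec moved in.
async fn get_async(tree: Arc<dyn Tree>, key: Vec<u8>) -> Result<Option<Vec<u8>>> {
    tokio::task::spawn_blocking(move || tree.get(&key))
        .await
        .expect("blocking task panicked")
}
```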
@@ -2109,7 +2143,7 @@ impl Rooms {
         user_id: &UserId,
         room_id: &RoomId,
         from: u64,
-    ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> {
+    ) -> Result<impl Stream<Item = Result<(Vec<u8>, PduEvent)>>> {
         // Create the first part of the full pdu id
         let prefix = self
             .get_shortroomid(room_id)?

@@ -2124,18 +2158,23 @@
         let user_id = user_id.to_owned();
 
-        Ok(self
+        let iter = self
             .pduid_pdu
-            .iter_from(current, false)
-            .take_while(move |(k, _)| k.starts_with(&prefix))
-            .map(move |(pdu_id, v)| {
+            .iter_from(current, false);
+
+        Ok(try_stream! {
+            while let Some((k, v)) = tokio::task::spawn_blocking(|| iter.next()).await.unwrap() {
+                if !k.starts_with(&prefix) {
+                    return;
+                }
                 let mut pdu = serde_json::from_slice::<PduEvent>(&v)
                     .map_err(|_| Error::bad_database("PDU in db is invalid."))?;
                 if pdu.sender != user_id {
                     pdu.remove_transaction_id()?;
                 }
-                Ok((pdu_id, pdu))
-            }))
+                yield (k, pdu)
+            }
+        })
     }
 
     /// Replace a PDU with the redacted form.
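
`pdus_since` now returns `impl Stream` instead of `impl Iterator`: inside `async_stream::try_stream!`, `yield` emits an `Ok` item, `?` ends the stream with an `Err` item, and each loop iteration may await (here, the `spawn_blocking` hop). A self-contained sketch of the macro's mechanics, with integers standing in for PDU rows:

```rust
use async_stream::try_stream;
use futures_util::{pin_mut, Stream, StreamExt};

fn rows() -> impl Stream<Item = Result<u64, String>> {
    let mut iter = (0u64..3).map(Ok::<_, String>);
    try_stream! {
        while let Some(row) = iter.next() {
            // `?` would terminate the stream with an Err item.
            let row = row?;
            yield row;
        }
    }
}

#[tokio::main]
async fn main() {
    let stream = rows();
    pin_mut!(stream);
    while let Some(item) = stream.next().await {
        println!("{item:?}"); // Ok(0), Ok(1), Ok(2)
    }
}
```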

src/server_server.rs

@@ -49,6 +49,7 @@ use ruma::{
     },
     int,
     receipt::ReceiptType,
+    room_id,
     serde::{Base64, JsonObject, Raw},
     signatures::{CanonicalJsonObject, CanonicalJsonValue},
     state_res::{self, RoomVersion, StateMap},

@@ -681,7 +682,7 @@ pub async fn send_transaction_message_route(
             .roomid_mutex_federation
             .write()
             .unwrap()
-            .entry(room_id.clone())
+            .entry(room_id!("!somewhere:example.org").to_owned()) // only allow one room at a time
             .or_default(),
     );
     let mutex_lock = mutex.lock().await;

@@ -1141,7 +1142,7 @@ fn handle_outlier_pdu<'a>(
     // Build map of auth events
     let mut auth_events = HashMap::new();
     for id in &incoming_pdu.auth_events {
-        let auth_event = match db.rooms.get_pdu(id).map_err(|e| e.to_string())? {
+        let auth_event = match db.rooms.get_pdu_async(id).await.map_err(|e| e.to_string())? {
             Some(e) => e,
             None => {
                 warn!("Could not find auth event {}", id);

@@ -1182,7 +1183,7 @@ fn handle_outlier_pdu<'a>(
         && incoming_pdu.prev_events == incoming_pdu.auth_events
     {
         db.rooms
-            .get_pdu(&incoming_pdu.auth_events[0])
+            .get_pdu_async(&incoming_pdu.auth_events[0]).await
             .map_err(|e| e.to_string())?
             .filter(|maybe_create| **maybe_create == *create_event)
     } else {
@@ -1265,10 +1266,13 @@ async fn upgrade_outlier_to_timeline_pdu(
         if let Some(Ok(mut state)) = state {
             warn!("Using cached state");
-            let prev_pdu =
-                db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| {
-                    "Could not find prev event, but we know the state.".to_owned()
-                })?;
+            let prev_pdu = db
+                .rooms
+                .get_pdu_async(prev_event)
+                .await
+                .ok()
+                .flatten()
+                .ok_or_else(|| "Could not find prev event, but we know the state.".to_owned())?;
 
             if let Some(state_key) = &prev_pdu.state_key {
                 let shortstatekey = db

@@ -1288,7 +1292,7 @@
         let mut okay = true;
         for prev_eventid in &incoming_pdu.prev_events {
-            let prev_event = if let Ok(Some(pdu)) = db.rooms.get_pdu(prev_eventid) {
+            let prev_event = if let Ok(Some(pdu)) = db.rooms.get_pdu_async(prev_eventid).await {
                 pdu
             } else {
                 okay = false;

@@ -1337,7 +1341,7 @@
             }
 
             auth_chain_sets.push(
-                get_auth_chain(room_id, starting_events, db)
+                get_auth_chain(room_id, starting_events, db).await
                     .map_err(|_| "Failed to load auth chain.".to_owned())?
                     .collect(),
             );

@@ -1350,7 +1354,7 @@
                 &fork_states,
                 auth_chain_sets,
                 |id| {
-                    let res = db.rooms.get_pdu(id);
+                    let res = db.rooms.get_pdu_async(id).await;
                     if let Err(e) = &res {
                         error!("LOOK AT ME Failed to fetch event: {}", e);
                     }
@@ -1462,28 +1466,33 @@
             && incoming_pdu.prev_events == incoming_pdu.auth_events
         {
             db.rooms
-                .get_pdu(&incoming_pdu.auth_events[0])
+                .get_pdu_async(&incoming_pdu.auth_events[0])
+                .await
                 .map_err(|e| e.to_string())?
                 .filter(|maybe_create| **maybe_create == *create_event)
         } else {
             None
         };
 
-        let check_result = state_res::event_auth::auth_check(
-            &room_version,
-            &incoming_pdu,
-            previous_create.as_ref(),
-            None::<PduEvent>, // TODO: third party invite
-            |k, s| {
-                db.rooms
-                    .get_shortstatekey(k, s)
-                    .ok()
-                    .flatten()
-                    .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey))
-                    .and_then(|event_id| db.rooms.get_pdu(event_id).ok().flatten())
-            },
-        )
-        .map_err(|_e| "Auth check failed.".to_owned())?;
+        let check_result = tokio::task::spawn_blocking(move || {
+            state_res::event_auth::auth_check(
+                &room_version,
+                &incoming_pdu,
+                previous_create.as_ref(),
+                None::<PduEvent>, // TODO: third party invite
+                |k, s| {
+                    db.rooms
+                        .get_shortstatekey(k, s)
+                        .ok()
+                        .flatten()
+                        .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey))
+                        .and_then(|event_id| db.rooms.get_pdu(event_id).ok().flatten())
+                },
+            )
+            .map_err(|_e| "Auth check failed.".to_owned())
+        })
+        .await
+        .unwrap()?;
 
         if !check_result {
             return Err("Event has failed auth check with state at the event.".into());
@@ -1591,7 +1600,8 @@
     for id in dbg!(&extremities) {
         match db
             .rooms
-            .get_pdu(id)
+            .get_pdu_async(id)
+            .await
             .map_err(|_| "Failed to ask db for pdu.".to_owned())?
         {
             Some(leaf_pdu) => {

@@ -1664,7 +1674,7 @@
                     room_id,
                     state.iter().map(|(_, id)| id.clone()).collect(),
                     db,
-                )
+                ).await
                 .map_err(|_| "Failed to load auth chain.".to_owned())?
                 .collect(),
             );

@@ -1685,20 +1695,20 @@
             })
             .collect();
 
-        let state = match state_res::resolve(
-            room_version_id,
-            &fork_states,
-            auth_chain_sets,
-            |id| {
+        let state = match tokio::task::spawn_blocking(move || {
+            state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| {
                 let res = db.rooms.get_pdu(id);
                 if let Err(e) = &res {
                     error!("LOOK AT ME Failed to fetch event: {}", e);
                 }
                 res.ok().flatten()
-            },
-        ) {
-            Ok(new_state) => new_state,
-            Err(_) => {
+            })
+            .ok()
+        })
+        .await
+        .unwrap()
+        {
+            Some(new_state) => new_state,
+            None => {
                 return Err("State resolution failed, either an event could not be found or deserialization".into());
             }
         };
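
The same offloading is applied to CPU-bound work: `state_res::event_auth::auth_check` and `state_res::resolve` now run inside `spawn_blocking`, and the async task merely awaits the join handle. A minimal sketch of that shape, with a stand-in for the heavy call rather than the real `state_res` API:

```rust
// Stand-in for auth_check / state_res::resolve: purely CPU-bound work.
fn resolve_heavy(input: &[u8]) -> Result<usize, String> {
    Ok(input.iter().map(|&b| b as usize).sum())
}

// Run the heavy call on the blocking pool; the executor stays free meanwhile.
async fn resolve_offloaded(input: Vec<u8>) -> Result<usize, String> {
    tokio::task::spawn_blocking(move || resolve_heavy(&input))
        .await
        .map_err(|e| format!("blocking task panicked: {e}"))?
}
```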
@@ -1798,7 +1808,7 @@ pub(crate) fn fetch_and_handle_outliers<'a>(
             // a. Look in the main timeline (pduid_pdu tree)
             // b. Look at outlier pdu tree
             // (get_pdu_json checks both)
-            if let Ok(Some(local_pdu)) = db.rooms.get_pdu(id) {
+            if let Ok(Some(local_pdu)) = db.rooms.get_pdu_async(id).await {
                 trace!("Found {} in db", id);
                 pdus.push((local_pdu, None));
                 continue;

@@ -1815,7 +1825,7 @@
                         continue;
                     }
 
-                    if let Ok(Some(_)) = db.rooms.get_pdu(&next_id) {
+                    if let Ok(Some(_)) = db.rooms.get_pdu_async(&next_id).await {
                         trace!("Found {} in db", id);
                         continue;
                     }

@@ -2153,7 +2163,7 @@ fn append_incoming_pdu<'a>(
 }
 
 #[tracing::instrument(skip(starting_events, db))]
-pub(crate) fn get_auth_chain<'a>(
+pub(crate) async fn get_auth_chain<'a>(
     room_id: &RoomId,
     starting_events: Vec<Arc<EventId>>,
     db: &'a Database,

@@ -2194,7 +2204,7 @@ pub(crate) fn get_auth_chain<'a>(
             chunk_cache.extend(cached.iter().copied());
         } else {
             misses2 += 1;
-            let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id, db)?);
+            let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id, db).await?);
             db.rooms
                 .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
             println!(

@@ -2230,7 +2240,7 @@
 }
 
 #[tracing::instrument(skip(event_id, db))]
-fn get_auth_chain_inner(
+async fn get_auth_chain_inner(
     room_id: &RoomId,
     event_id: &EventId,
     db: &Database,

@@ -2239,7 +2249,7 @@ fn get_auth_chain_inner(
     let mut found = HashSet::new();
 
     while let Some(event_id) = todo.pop() {
-        match db.rooms.get_pdu(&event_id) {
+        match db.rooms.get_pdu_async(&event_id).await {
             Ok(Some(pdu)) => {
                 if pdu.room_id != room_id {
                     return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
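
Turning `get_auth_chain` and `get_auth_chain_inner` into `async fn` ripples outward: every caller, down to the federation route handlers in the hunks below, gains an `.await`. Schematically, with illustrative names rather than conduit's signatures:

```rust
// Before: fn get_chain(..) -> Result<..>, called as get_chain(..)?
// After: async fn, called as get_chain(..).await?
async fn get_chain(ids: &[u64]) -> Result<Vec<u64>, String> {
    let mut chain = Vec::new();
    for &id in ids {
        // Each lookup can now await an async DB call (e.g. get_pdu_async).
        chain.push(lookup(id).await?);
    }
    Ok(chain)
}

async fn lookup(id: u64) -> Result<u64, String> {
    Ok(id) // stub for the awaited per-event fetch
}
```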
@@ -2423,7 +2433,7 @@ pub async fn get_event_authorization_route(
     let room_id = <&RoomId>::try_from(room_id_str)
         .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?;
 
-    let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db)?;
+    let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db).await?;
 
     Ok(get_event_authorization::v1::Response {
         auth_chain: auth_chain_ids

@@ -2477,7 +2487,7 @@ pub async fn get_room_state_route(
         })
         .collect();
 
-    let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db)?;
+    let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?;
 
     Ok(get_room_state::v1::Response {
         auth_chain: auth_chain_ids

@@ -2532,7 +2542,7 @@
         .map(|(_, id)| (*id).to_owned())
         .collect();
 
-    let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db)?;
+    let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?;
 
     Ok(get_room_state_ids::v1::Response {
         auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(),

@@ -2795,7 +2805,7 @@ async fn create_join_event(
         room_id,
         state_ids.iter().map(|(_, id)| id.clone()).collect(),
         db,
-    )?;
+    ).await?;
 
     let servers = db
         .rooms