From 7658414fc4d21e7dda2fa0fb161d337f910eb40d Mon Sep 17 00:00:00 2001 From: chayleaf Date: Sat, 22 Jun 2024 21:22:43 +0700 Subject: [PATCH] KvTree: asyncify clear and increment --- src/api/client_server/account.rs | 44 +++--- src/api/client_server/alias.rs | 3 +- src/api/client_server/backup.rs | 51 ++++--- src/api/client_server/config.rs | 42 +++--- src/api/client_server/device.rs | 11 +- src/api/client_server/keys.rs | 45 +++--- src/api/client_server/membership.rs | 112 +++++++++------ src/api/client_server/push.rs | 60 ++++---- src/api/client_server/read_marker.rs | 92 +++++++----- src/api/client_server/relations.rs | 7 +- src/api/client_server/room.rs | 15 +- src/api/client_server/session.rs | 25 ++-- src/api/client_server/sync.rs | 5 +- src/api/client_server/tag.rs | 30 ++-- src/api/client_server/to_device.rs | 69 +++++---- src/api/server_server.rs | 128 ++++++++++------- src/database/abstraction.rs | 5 +- src/database/abstraction/persy.rs | 3 +- src/database/abstraction/rocksdb.rs | 3 +- src/database/abstraction/sqlite.rs | 5 +- src/database/key_value/account_data.rs | 6 +- src/database/key_value/globals.rs | 4 +- src/database/key_value/key_backups.rs | 16 ++- src/database/key_value/rooms/alias.rs | 11 +- src/database/key_value/rooms/edus/presence.rs | 6 +- .../key_value/rooms/edus/read_receipt.rs | 10 +- src/database/key_value/rooms/short.rs | 18 +-- src/database/key_value/rooms/state_cache.rs | 10 +- src/database/key_value/rooms/user.rs | 6 +- src/database/key_value/sending.rs | 6 +- src/database/key_value/users.rs | 45 +++--- src/database/mod.rs | 66 +++++---- src/service/account_data/data.rs | 4 +- src/service/account_data/mod.rs | 4 +- src/service/admin/mod.rs | 42 +++--- src/service/globals/data.rs | 2 +- src/service/globals/mod.rs | 4 +- src/service/key_backups/data.rs | 8 +- src/service/key_backups/mod.rs | 13 +- src/service/rooms/alias/data.rs | 9 +- src/service/rooms/alias/mod.rs | 9 +- src/service/rooms/auth_chain/mod.rs | 17 ++- src/service/rooms/edus/presence/data.rs | 4 +- src/service/rooms/edus/read_receipt/data.rs | 6 +- src/service/rooms/edus/read_receipt/mod.rs | 13 +- src/service/rooms/edus/typing/mod.rs | 6 +- src/service/rooms/event_handler/mod.rs | 135 ++++++++++-------- src/service/rooms/pdu_metadata/mod.rs | 34 ++--- src/service/rooms/short/data.rs | 10 +- src/service/rooms/short/mod.rs | 18 +-- src/service/rooms/state/mod.rs | 34 ++--- src/service/rooms/state_accessor/mod.rs | 1 + src/service/rooms/state_cache/data.rs | 6 +- src/service/rooms/state_cache/mod.rs | 26 ++-- src/service/rooms/state_compressor/mod.rs | 10 +- src/service/rooms/timeline/mod.rs | 78 ++++++---- src/service/rooms/user/data.rs | 4 +- src/service/rooms/user/mod.rs | 8 +- src/service/sending/data.rs | 4 +- src/service/sending/mod.rs | 38 +++-- src/service/users/data.rs | 22 +-- src/service/users/mod.rs | 80 ++++++----- 62 files changed, 958 insertions(+), 650 deletions(-) diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index 36640b54..4e033d18 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -218,17 +218,20 @@ pub async fn register_route(body: Ruma) -> Result) -> Result bool>( } let json = serde_json::to_value(master_key).expect("to_value always works"); let raw = serde_json::from_value(json).expect("Raw::from_value always works"); - services().users.add_cross_signing_keys( - &user, &raw, &None, &None, - false, // Dont notify. A notification would trigger another key request resulting in an endless loop - )?; + services() + .users + .add_cross_signing_keys( + &user, &raw, &None, &None, + false, // Dont notify. A notification would trigger another key request resulting in an endless loop + ) + .await?; master_keys.insert(user, raw); } @@ -481,10 +490,10 @@ pub(crate) async fn claim_keys_helper( let mut container = BTreeMap::new(); for (device_id, key_algorithm) in map { - if let Some(one_time_keys) = - services() - .users - .take_one_time_key(user_id, device_id, key_algorithm)? + if let Some(one_time_keys) = services() + .users + .take_one_time_key(user_id, device_id, key_algorithm) + .await? { let mut c = BTreeMap::new(); c.insert(one_time_keys.0, one_time_keys.1); diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 1ca711e2..9650da3a 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -703,7 +703,11 @@ async fn join_room_by_id_helper( } } - services().rooms.short.get_or_create_shortroomid(room_id)?; + services() + .rooms + .short + .get_or_create_shortroomid(room_id) + .await?; info!("Parsing join event"); let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) @@ -744,7 +748,8 @@ async fn join_room_by_id_helper( let shortstatekey = services() .rooms .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) + .await?; state.insert(shortstatekey, pdu.event_id.clone()); } } @@ -781,8 +786,8 @@ async fn join_room_by_id_helper( &services() .rooms .short - .get_or_create_shortstatekey(&k.to_string().into(), s) - .ok()?, + .get_shortstatekey(&k.to_string().into(), s) + .ok()??, )?, ) .ok()? @@ -801,20 +806,23 @@ async fn join_room_by_id_helper( } info!("Saving state from send_join"); - let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state( - room_id, - Arc::new( - state - .into_iter() - .map(|(k, id)| { + let (statehash_before_join, new, removed) = services() + .rooms + .state_compressor + .save_state(room_id, { + let mut new_state = HashSet::new(); + for (k, id) in state { + new_state.insert( services() .rooms .state_compressor .compress_state_event(k, &id) - }) - .collect::>()?, - ), - )?; + .await?, + ); + } + Arc::new(new_state) + }) + .await?; services() .rooms @@ -827,7 +835,11 @@ async fn join_room_by_id_helper( // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. - let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; + let statehash_after_join = services() + .rooms + .state + .append_to_state(&parsed_join_pdu) + .await?; info!("Appending new room join event"); services() @@ -1253,18 +1265,22 @@ pub(crate) async fn invite_helper<'a>( }) .expect("member event is valid value"); - let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - sender_user, - room_id, - &state_lock, - )?; + let (pdu, pdu_json) = services() + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + sender_user, + room_id, + &state_lock, + ) + .await?; let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?; @@ -1335,7 +1351,7 @@ pub(crate) async fn invite_helper<'a>( .filter_map(|r| r.ok()) .filter(|server| &**server != services().globals.server_name()); - services().sending.send_pdu(servers, &pdu_id)?; + services().sending.send_pdu(servers, &pdu_id).await?; } else { if !services() .rooms @@ -1442,14 +1458,18 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option { error!("Trying to leave a room you are not a member of."); - services().rooms.state_cache.update_membership( - room_id, - user_id, - MembershipState::Leave, - user_id, - None, - true, - )?; + services() + .rooms + .state_cache + .update_membership( + room_id, + user_id, + MembershipState::Leave, + user_id, + None, + true, + ) + .await?; return Ok(()); } Some(e) => e, diff --git a/src/api/client_server/push.rs b/src/api/client_server/push.rs index 72768662..41e5fd36 100644 --- a/src/api/client_server/push.rs +++ b/src/api/client_server/push.rs @@ -143,12 +143,15 @@ pub async fn set_pushrule_route( return Err(err); } - services().account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule::v3::Response {}) } @@ -238,12 +241,15 @@ pub async fn set_pushrule_actions_route( )); } - services().account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule_actions::v3::Response {}) } @@ -332,12 +338,15 @@ pub async fn set_pushrule_enabled_route( )); } - services().account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule_enabled::v3::Response {}) } @@ -391,12 +400,15 @@ pub async fn delete_pushrule_route( return Err(err); } - services().account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(delete_pushrule::v3::Response {}) } diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs index a5553d25..9624dd8a 100644 --- a/src/api/client_server/read_marker.rs +++ b/src/api/client_server/read_marker.rs @@ -26,19 +26,23 @@ pub async fn set_read_marker_route( event_id: fully_read.clone(), }, }; - services().account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), - )?; + services() + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + ) + .await?; } if body.private_read_receipt.is_some() || body.read_receipt.is_some() { services() .rooms .user - .reset_notification_counts(sender_user, &body.room_id)?; + .reset_notification_counts(sender_user, &body.room_id) + .await?; } if let Some(event) = &body.private_read_receipt { @@ -63,7 +67,8 @@ pub async fn set_read_marker_route( .rooms .edus .read_receipt - .private_read_set(&body.room_id, sender_user, count)?; + .private_read_set(&body.room_id, sender_user, count) + .await?; } if let Some(event) = &body.read_receipt { @@ -82,14 +87,19 @@ pub async fn set_read_marker_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(event.to_owned(), receipts); - services().rooms.edus.read_receipt.readreceipt_update( - sender_user, - &body.room_id, - ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - )?; + services() + .rooms + .edus + .read_receipt + .readreceipt_update( + sender_user, + &body.room_id, + ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + ) + .await?; } Ok(set_read_marker::v3::Response {}) @@ -110,7 +120,8 @@ pub async fn create_receipt_route( services() .rooms .user - .reset_notification_counts(sender_user, &body.room_id)?; + .reset_notification_counts(sender_user, &body.room_id) + .await?; } match body.receipt_type { @@ -120,12 +131,15 @@ pub async fn create_receipt_route( event_id: body.event_id.clone(), }, }; - services().account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), - )?; + services() + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + ) + .await?; } create_receipt::v3::ReceiptType::Read => { let mut user_receipts = BTreeMap::new(); @@ -142,14 +156,19 @@ pub async fn create_receipt_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(body.event_id.to_owned(), receipts); - services().rooms.edus.read_receipt.readreceipt_update( - sender_user, - &body.room_id, - ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - )?; + services() + .rooms + .edus + .read_receipt + .readreceipt_update( + sender_user, + &body.room_id, + ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + ) + .await?; } create_receipt::v3::ReceiptType::ReadPrivate => { let count = services() @@ -169,11 +188,12 @@ pub async fn create_receipt_route( } PduCount::Normal(c) => c, }; - services().rooms.edus.read_receipt.private_read_set( - &body.room_id, - sender_user, - count, - )?; + services() + .rooms + .edus + .read_receipt + .private_read_set(&body.room_id, sender_user, count) + .await?; } _ => return Err(Error::bad_database("Unsupported receipt type")), } diff --git a/src/api/client_server/relations.rs b/src/api/client_server/relations.rs index 27c00729..6c4dd098 100644 --- a/src/api/client_server/relations.rs +++ b/src/api/client_server/relations.rs @@ -25,7 +25,8 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route( body.limit, body.recurse, &body.dir, - )?; + ) + .await?; Ok( get_relating_events_with_rel_type_and_event_type::v1::Response { @@ -57,7 +58,8 @@ pub async fn get_relating_events_with_rel_type_route( body.limit, body.recurse, &body.dir, - )?; + ) + .await?; Ok(get_relating_events_with_rel_type::v1::Response { chunk: res.chunk, @@ -88,4 +90,5 @@ pub async fn get_relating_events_route( body.recurse, &body.dir, ) + .await } diff --git a/src/api/client_server/room.rs b/src/api/client_server/room.rs index 890ff9cb..cd5ea8bb 100644 --- a/src/api/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -54,7 +54,11 @@ pub async fn create_room_route( let room_id = RoomId::new(services().globals.server_name()); - services().rooms.short.get_or_create_shortroomid(&room_id)?; + services() + .rooms + .short + .get_or_create_shortroomid(&room_id) + .await?; let mutex_state = Arc::clone( services() @@ -488,7 +492,8 @@ pub async fn create_room_route( services() .rooms .alias - .set_alias(&alias, &room_id, sender_user)?; + .set_alias(&alias, &room_id, sender_user) + .await?; } if body.visibility == room::Visibility::Public { @@ -600,7 +605,8 @@ pub async fn upgrade_room_route( services() .rooms .short - .get_or_create_shortroomid(&replacement_room)?; + .get_or_create_shortroomid(&replacement_room) + .await?; let mutex_state = Arc::clone( services() @@ -818,7 +824,8 @@ pub async fn upgrade_room_route( services() .rooms .alias - .set_alias(&alias, &replacement_room, sender_user)?; + .set_alias(&alias, &replacement_room, sender_user) + .await?; } // Get the old room power levels diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index 07078328..39fc7e6a 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -192,12 +192,15 @@ pub async fn login_route(body: Ruma) -> Result) -> Result { - services().users.add_to_device_event( - sender_user, - target_user_id, - target_device_id, - &body.event_type.to_string(), - event.deserialize_as().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") - })?, - )? - } - - DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services().users.all_device_ids(target_user_id) { - services().users.add_to_device_event( + services() + .users + .add_to_device_event( sender_user, target_user_id, - &target_device_id?, + target_device_id, &body.event_type.to_string(), event.deserialize_as().map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") })?, - )?; + ) + .await? + } + + DeviceIdOrAllDevices::AllDevices => { + for target_device_id in services().users.all_device_ids(target_user_id) { + services() + .users + .add_to_device_event( + sender_user, + target_user_id, + &target_device_id?, + &body.event_type.to_string(), + event.deserialize_as().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") + })?, + ) + .await?; } } } diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 605a4672..5038de2e 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -814,7 +814,8 @@ pub async fn send_transaction_message_route( .rooms .edus .read_receipt - .readreceipt_update(&user_id, &room_id, event)?; + .readreceipt_update(&user_id, &room_id, event) + .await?; } else { // TODO fetch missing events debug!("No known event ids in read receipt: {:?}", user_updates); @@ -853,7 +854,7 @@ pub async fn send_transaction_message_route( } Edu::DeviceListUpdate(DeviceListUpdateContent { user_id, .. }) => { if user_id.server_name() == sender_servername { - services().users.mark_device_key_update(&user_id)?; + services().users.mark_device_key_update(&user_id).await?; } } Edu::DirectToDevice(DirectDeviceContent { @@ -873,37 +874,43 @@ pub async fn send_transaction_message_route( for (target_device_id_maybe, event) in map { match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services().users.add_to_device_event( - &sender, - target_user_id, - target_device_id, - &ev_type.to_string(), - event.deserialize_as().map_err(|e| { - warn!("To-Device event is invalid: {event:?} {e}"); - Error::BadRequest( - ErrorKind::InvalidParam, - "Event is invalid", - ) - })?, - )? + services() + .users + .add_to_device_event( + &sender, + target_user_id, + target_device_id, + &ev_type.to_string(), + event.deserialize_as().map_err(|e| { + warn!("To-Device event is invalid: {event:?} {e}"); + Error::BadRequest( + ErrorKind::InvalidParam, + "Event is invalid", + ) + })?, + ) + .await? } DeviceIdOrAllDevices::AllDevices => { for target_device_id in services().users.all_device_ids(target_user_id) { - services().users.add_to_device_event( - &sender, - target_user_id, - &target_device_id?, - &ev_type.to_string(), - event.deserialize_as().map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Event is invalid", - ) - })?, - )?; + services() + .users + .add_to_device_event( + &sender, + target_user_id, + &target_device_id?, + &ev_type.to_string(), + event.deserialize_as().map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Event is invalid", + ) + })?, + ) + .await?; } } } @@ -923,13 +930,16 @@ pub async fn send_transaction_message_route( }) => { if user_id.server_name() == sender_servername { if let Some(master_key) = master_key { - services().users.add_cross_signing_keys( - &user_id, - &master_key, - &self_signing_key, - &None, - true, - )?; + services() + .users + .add_cross_signing_keys( + &user_id, + &master_key, + &self_signing_key, + &None, + true, + ) + .await?; } } } @@ -1438,18 +1448,22 @@ pub async fn create_join_event_template_route( }) .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services().rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - }, - &body.user_id, - &body.room_id, - &state_lock, - )?; + let (_pdu, mut pdu_json) = services() + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, + &body.user_id, + &body.room_id, + &state_lock, + ) + .await?; drop(state_lock); @@ -1581,7 +1595,7 @@ async fn create_join_event( .filter_map(|r| r.ok()) .filter(|server| &**server != services().globals.server_name()); - services().sending.send_pdu(servers, &pdu_id)?; + services().sending.send_pdu(servers, &pdu_id).await?; Ok(create_join_event::v1::RoomState { auth_chain: auth_chain_ids @@ -1738,14 +1752,18 @@ pub async fn create_invite_route( .state_cache .server_in_room(services().globals.server_name(), &body.room_id)? { - services().rooms.state_cache.update_membership( - &body.room_id, - &invited_user, - MembershipState::Invite, - &sender, - Some(invite_state), - true, - )?; + services() + .rooms + .state_cache + .update_membership( + &body.room_id, + &invited_user, + MembershipState::Invite, + &sender, + Some(invite_state), + true, + ) + .await?; } Ok(create_invite::v2::Response { diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index bc9f09dc..a4336928 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -42,6 +42,7 @@ pub trait KeyValueDatabaseEngine: Send + Sync { } } +#[async_trait] pub trait KvTree: Send + Sync { fn get(&self, key: &[u8]) -> Result>>; @@ -58,7 +59,7 @@ pub trait KvTree: Send + Sync { backwards: bool, ) -> Box, Vec)> + 'a>; - fn increment(&self, key: &[u8]) -> Result>; + async fn increment(&self, key: &[u8]) -> Result>; fn increment_batch(&self, iter: &mut dyn Iterator>) -> Result<()>; fn scan_prefix<'a>( @@ -68,7 +69,7 @@ pub trait KvTree: Send + Sync { fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>>; - fn clear(&self) -> Result<()> { + async fn clear(&self) -> Result<()> { for (key, _) in self.iter() { self.remove(&key)?; } diff --git a/src/database/abstraction/persy.rs b/src/database/abstraction/persy.rs index 5c146eb0..4eaaa8e8 100644 --- a/src/database/abstraction/persy.rs +++ b/src/database/abstraction/persy.rs @@ -63,6 +63,7 @@ impl PersyTree { } } +#[async_trait] impl KvTree for PersyTree { fn get(&self, key: &[u8]) -> Result>> { let result = self @@ -160,7 +161,7 @@ impl KvTree for PersyTree { } } - fn increment(&self, key: &[u8]) -> Result> { + async fn increment(&self, key: &[u8]) -> Result> { self.increment_batch(&mut Some(key.to_owned()).into_iter())?; Ok(self.get(key)?.unwrap()) } diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index 72af45ed..3013ebe8 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -136,6 +136,7 @@ impl RocksDbEngineTree<'_> { } } +#[async_trait] impl KvTree for RocksDbEngineTree<'_> { fn get(&self, key: &[u8]) -> Result>> { let readoptions = rocksdb::ReadOptions::default(); @@ -214,7 +215,7 @@ impl KvTree for RocksDbEngineTree<'_> { ) } - fn increment(&self, key: &[u8]) -> Result> { + async fn increment(&self, key: &[u8]) -> Result> { let readoptions = rocksdb::ReadOptions::default(); let writeoptions = rocksdb::WriteOptions::default(); diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 12ed9361..f0990f1b 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -166,6 +166,7 @@ impl SqliteTable { } } +#[async_trait] impl KvTree for SqliteTable { fn get(&self, key: &[u8]) -> Result>> { self.get_with_guard(self.engine.read_lock(), key) @@ -268,7 +269,7 @@ impl KvTree for SqliteTable { Box::new(rx.into_iter()) } - fn increment(&self, key: &[u8]) -> Result> { + async fn increment(&self, key: &[u8]) -> Result> { let guard = self.engine.write_lock(); let old = self.get_with_guard(&guard, key)?; @@ -295,7 +296,7 @@ impl KvTree for SqliteTable { self.watchers.watch(prefix) } - fn clear(&self) -> Result<()> { + async fn clear(&self) -> Result<()> { debug!("clear: running"); self.engine .write_lock() diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs index 970b36b5..53af9e86 100644 --- a/src/database/key_value/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use async_trait::async_trait; use ruma::{ api::client::error::ErrorKind, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, @@ -9,10 +10,11 @@ use ruma::{ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::account_data::Data for KeyValueDatabase { /// Places one event in the account data of the user and removes the previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - fn update( + async fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, @@ -29,7 +31,7 @@ impl service::account_data::Data for KeyValueDatabase { prefix.push(0xff); let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + roomuserdataid.extend_from_slice(&services().globals.next_count().await?.to_be_bytes()); roomuserdataid.push(0xff); roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index 968d6420..1c0c57f9 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -20,8 +20,8 @@ pub const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; #[async_trait] impl service::globals::Data for KeyValueDatabase { - fn next_count(&self) -> Result { - utils::u64_from_bytes(&self.global.increment(COUNTER)?) + async fn next_count(&self) -> Result { + utils::u64_from_bytes(&self.global.increment(COUNTER).await?) .map_err(|_| Error::bad_database("Count has invalid bytes.")) } diff --git a/src/database/key_value/key_backups.rs b/src/database/key_value/key_backups.rs index 900b700b..43ded671 100644 --- a/src/database/key_value/key_backups.rs +++ b/src/database/key_value/key_backups.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; +use async_trait::async_trait; use ruma::{ api::client::{ backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, @@ -11,13 +12,14 @@ use ruma::{ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::key_backups::Data for KeyValueDatabase { - fn create_backup( + async fn create_backup( &self, user_id: &UserId, backup_metadata: &Raw, ) -> Result { - let version = services().globals.next_count()?.to_string(); + let version = services().globals.next_count().await?.to_string(); let mut key = user_id.as_bytes().to_vec(); key.push(0xff); @@ -28,7 +30,7 @@ impl service::key_backups::Data for KeyValueDatabase { &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), )?; self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count().await?.to_be_bytes())?; Ok(version) } @@ -49,7 +51,7 @@ impl service::key_backups::Data for KeyValueDatabase { Ok(()) } - fn update_backup( + async fn update_backup( &self, user_id: &UserId, version: &str, @@ -69,7 +71,7 @@ impl service::key_backups::Data for KeyValueDatabase { self.backupid_algorithm .insert(&key, backup_metadata.json().get().as_bytes())?; self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count().await?.to_be_bytes())?; Ok(version.to_owned()) } @@ -138,7 +140,7 @@ impl service::key_backups::Data for KeyValueDatabase { }) } - fn add_key( + async fn add_key( &self, user_id: &UserId, version: &str, @@ -158,7 +160,7 @@ impl service::key_backups::Data for KeyValueDatabase { } self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count().await?.to_be_bytes())?; key.push(0xff); key.extend_from_slice(room_id.as_bytes()); diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index 1a27fbac..66f84441 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use ruma::{ api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId, @@ -5,8 +6,14 @@ use ruma::{ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::rooms::alias::Data for KeyValueDatabase { - fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { + async fn set_alias( + &self, + alias: &RoomAliasId, + room_id: &RoomId, + user_id: &UserId, + ) -> Result<()> { // Comes first as we don't want a stuck alias self.alias_userid .insert(alias.alias().as_bytes(), user_id.as_bytes())?; @@ -14,7 +21,7 @@ impl service::rooms::alias::Data for KeyValueDatabase { .insert(alias.alias().as_bytes(), room_id.as_bytes())?; let mut aliasid = room_id.as_bytes().to_vec(); aliasid.push(0xff); - aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + aliasid.extend_from_slice(&services().globals.next_count().await?.to_be_bytes()); self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; Ok(()) } diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs index 904b1c44..ba29be8f 100644 --- a/src/database/key_value/rooms/edus/presence.rs +++ b/src/database/key_value/rooms/edus/presence.rs @@ -1,13 +1,15 @@ use std::collections::HashMap; +use async_trait::async_trait; use ruma::{ events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId, }; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::rooms::edus::presence::Data for KeyValueDatabase { - fn update_presence( + async fn update_presence( &self, user_id: &UserId, room_id: &RoomId, @@ -15,7 +17,7 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase { ) -> Result<()> { // TODO: Remove old entry? Or maybe just wipe completely from time to time? - let count = services().globals.next_count()?.to_be_bytes(); + let count = services().globals.next_count().await?.to_be_bytes(); let mut presence_id = room_id.as_bytes().to_vec(); presence_id.push(0xff); diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index fb7c9e99..f3f0fa27 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -1,13 +1,15 @@ use std::mem; +use async_trait::async_trait; use ruma::{ events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId, }; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { - fn readreceipt_update( + async fn readreceipt_update( &self, user_id: &UserId, room_id: &RoomId, @@ -36,7 +38,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { } let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + room_latest_id.extend_from_slice(&services().globals.next_count().await?.to_be_bytes()); room_latest_id.push(0xff); room_latest_id.extend_from_slice(user_id.as_bytes()); @@ -106,7 +108,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { ) } - fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { + async fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { let mut key = room_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(user_id.as_bytes()); @@ -115,7 +117,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { .insert(&key, &count.to_be_bytes())?; self.roomuserid_lastprivatereadupdate - .insert(&key, &services().globals.next_count()?.to_be_bytes()) + .insert(&key, &services().globals.next_count().await?.to_be_bytes()) } fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs index 98cfa48a..8d79a339 100644 --- a/src/database/key_value/rooms/short.rs +++ b/src/database/key_value/rooms/short.rs @@ -1,11 +1,13 @@ use std::sync::Arc; +use async_trait::async_trait; use ruma::{events::StateEventType, EventId, RoomId}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::rooms::short::Data for KeyValueDatabase { - fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { + async fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { return Ok(*short); } @@ -14,7 +16,7 @@ impl service::rooms::short::Data for KeyValueDatabase { Some(shorteventid) => utils::u64_from_bytes(&shorteventid) .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, None => { - let shorteventid = services().globals.next_count()?; + let shorteventid = services().globals.next_count().await?; self.eventid_shorteventid .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; self.shorteventid_eventid @@ -68,7 +70,7 @@ impl service::rooms::short::Data for KeyValueDatabase { Ok(short) } - fn get_or_create_shortstatekey( + async fn get_or_create_shortstatekey( &self, event_type: &StateEventType, state_key: &str, @@ -90,7 +92,7 @@ impl service::rooms::short::Data for KeyValueDatabase { Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, None => { - let shortstatekey = services().globals.next_count()?; + let shortstatekey = services().globals.next_count().await?; self.statekey_shortstatekey .insert(&statekey, &shortstatekey.to_be_bytes())?; self.shortstatekey_statekey @@ -176,7 +178,7 @@ impl service::rooms::short::Data for KeyValueDatabase { } /// Returns (shortstatehash, already_existed) - fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { + async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { Ok(match self.statehash_shortstatehash.get(state_hash)? { Some(shortstatehash) => ( utils::u64_from_bytes(&shortstatehash) @@ -184,7 +186,7 @@ impl service::rooms::short::Data for KeyValueDatabase { true, ), None => { - let shortstatehash = services().globals.next_count()?; + let shortstatehash = services().globals.next_count().await?; self.statehash_shortstatehash .insert(state_hash, &shortstatehash.to_be_bytes())?; (shortstatehash, false) @@ -202,12 +204,12 @@ impl service::rooms::short::Data for KeyValueDatabase { .transpose() } - fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { + async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { Some(short) => utils::u64_from_bytes(&short) .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, None => { - let short = services().globals.next_count()?; + let short = services().globals.next_count().await?; self.roomid_shortroomid .insert(room_id.as_bytes(), &short.to_be_bytes())?; short diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 29be2608..7cb281f6 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -1,5 +1,6 @@ use std::{collections::HashSet, sync::Arc}; +use async_trait::async_trait; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, @@ -12,6 +13,7 @@ use crate::{ services, utils, Error, Result, }; +#[async_trait] impl service::rooms::state_cache::Data for KeyValueDatabase { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); @@ -39,7 +41,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { Ok(()) } - fn mark_as_invited( + async fn mark_as_invited( &self, user_id: &UserId, room_id: &RoomId, @@ -60,7 +62,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { )?; self.roomuserid_invitecount.insert( &roomuser_id, - &services().globals.next_count()?.to_be_bytes(), + &services().globals.next_count().await?.to_be_bytes(), )?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; @@ -70,7 +72,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { Ok(()) } - fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + async fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut roomuser_id = room_id.as_bytes().to_vec(); roomuser_id.push(0xff); roomuser_id.extend_from_slice(user_id.as_bytes()); @@ -85,7 +87,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { )?; // TODO self.roomuserid_leftcount.insert( &roomuser_id, - &services().globals.next_count()?.to_be_bytes(), + &services().globals.next_count().await?.to_be_bytes(), )?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 3d2d4a8f..4ef93a43 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -1,9 +1,11 @@ +use async_trait::async_trait; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +#[async_trait] impl service::rooms::user::Data for KeyValueDatabase { - fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + async fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -18,7 +20,7 @@ impl service::rooms::user::Data for KeyValueDatabase { self.roomuserid_lastnotificationread.insert( &roomuser_id, - &services().globals.next_count()?.to_be_bytes(), + &services().globals.next_count().await?.to_be_bytes(), )?; Ok(()) diff --git a/src/database/key_value/sending.rs b/src/database/key_value/sending.rs index 58380a05..8a0fa2f4 100644 --- a/src/database/key_value/sending.rs +++ b/src/database/key_value/sending.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use ruma::{ServerName, UserId}; use crate::{ @@ -9,6 +10,7 @@ use crate::{ services, utils, Error, Result, }; +#[async_trait] impl service::sending::Data for KeyValueDatabase { fn active_requests<'a>( &'a self, @@ -59,7 +61,7 @@ impl service::sending::Data for KeyValueDatabase { Ok(()) } - fn queue_requests( + async fn queue_requests( &self, requests: &[(&OutgoingKind, SendingEventType)], ) -> Result>> { @@ -70,7 +72,7 @@ impl service::sending::Data for KeyValueDatabase { if let SendingEventType::Pdu(value) = &event { key.extend_from_slice(value) } else { - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()) + key.extend_from_slice(&services().globals.next_count().await?.to_be_bytes()) } let value = if let SendingEventType::Edu(value) = &event { &**value diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 7fa24a12..6e872d65 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -1,5 +1,6 @@ use std::{collections::BTreeMap, mem::size_of}; +use async_trait::async_trait; use ruma::{ api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, @@ -17,6 +18,7 @@ use crate::{ services, utils, Error, Result, }; +#[async_trait] impl service::users::Data for KeyValueDatabase { /// Check if a user has an account on this homeserver. fn exists(&self, user_id: &UserId) -> Result { @@ -192,7 +194,7 @@ impl service::users::Data for KeyValueDatabase { } /// Adds a new device to a user. - fn create_device( + async fn create_device( &self, user_id: &UserId, device_id: &DeviceId, @@ -207,7 +209,8 @@ impl service::users::Data for KeyValueDatabase { userdeviceid.extend_from_slice(device_id.as_bytes()); self.userid_devicelistversion - .increment(user_id.as_bytes())?; + .increment(user_id.as_bytes()) + .await?; self.userdeviceid_metadata.insert( &userdeviceid, @@ -226,7 +229,7 @@ impl service::users::Data for KeyValueDatabase { } /// Removes a device from a user. - fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xff); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -248,7 +251,8 @@ impl service::users::Data for KeyValueDatabase { // TODO: Remove onetimekeys self.userid_devicelistversion - .increment(user_id.as_bytes())?; + .increment(user_id.as_bytes()) + .await?; self.userdeviceid_metadata.remove(&userdeviceid)?; @@ -304,7 +308,7 @@ impl service::users::Data for KeyValueDatabase { Ok(()) } - fn add_one_time_key( + async fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, @@ -335,7 +339,7 @@ impl service::users::Data for KeyValueDatabase { self.userid_lastonetimekeyupdate.insert( user_id.as_bytes(), - &services().globals.next_count()?.to_be_bytes(), + &services().globals.next_count().await?.to_be_bytes(), )?; Ok(()) @@ -352,7 +356,7 @@ impl service::users::Data for KeyValueDatabase { .unwrap_or(Ok(0)) } - fn take_one_time_key( + async fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, @@ -368,7 +372,7 @@ impl service::users::Data for KeyValueDatabase { self.userid_lastonetimekeyupdate.insert( user_id.as_bytes(), - &services().globals.next_count()?.to_be_bytes(), + &services().globals.next_count().await?.to_be_bytes(), )?; self.onetimekeyid_onetimekeys @@ -423,7 +427,7 @@ impl service::users::Data for KeyValueDatabase { Ok(counts) } - fn add_device_keys( + async fn add_device_keys( &self, user_id: &UserId, device_id: &DeviceId, @@ -438,12 +442,12 @@ impl service::users::Data for KeyValueDatabase { &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), )?; - self.mark_device_key_update(user_id)?; + self.mark_device_key_update(user_id).await?; Ok(()) } - fn add_cross_signing_keys( + async fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, @@ -532,13 +536,13 @@ impl service::users::Data for KeyValueDatabase { } if notify { - self.mark_device_key_update(user_id)?; + self.mark_device_key_update(user_id).await?; } Ok(()) } - fn sign_key( + async fn sign_key( &self, target_id: &UserId, key_id: &str, @@ -574,7 +578,7 @@ impl service::users::Data for KeyValueDatabase { &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), )?; - self.mark_device_key_update(target_id)?; + self.mark_device_key_update(target_id).await?; Ok(()) } @@ -623,8 +627,8 @@ impl service::users::Data for KeyValueDatabase { ) } - fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - let count = services().globals.next_count()?.to_be_bytes(); + async fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { + let count = services().globals.next_count().await?.to_be_bytes(); for room_id in services() .rooms .state_cache @@ -761,7 +765,7 @@ impl service::users::Data for KeyValueDatabase { }) } - fn add_to_device_event( + async fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, @@ -773,7 +777,7 @@ impl service::users::Data for KeyValueDatabase { key.push(0xff); key.extend_from_slice(target_device_id.as_bytes()); key.push(0xff); - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + key.extend_from_slice(&services().globals.next_count().await?.to_be_bytes()); let mut json = serde_json::Map::new(); json.insert("type".to_owned(), event_type.to_owned().into()); @@ -843,7 +847,7 @@ impl service::users::Data for KeyValueDatabase { Ok(()) } - fn update_device_metadata( + async fn update_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, @@ -857,7 +861,8 @@ impl service::users::Data for KeyValueDatabase { assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); self.userid_devicelistversion - .increment(user_id.as_bytes())?; + .increment(user_id.as_bytes()) + .await?; self.userdeviceid_metadata.insert( &userdeviceid, diff --git a/src/database/mod.rs b/src/database/mod.rs index 1d549f19..a50d4920 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -680,7 +680,7 @@ impl KeyValueDatabase { if services().globals.database_version()? < 8 { // Generate short room ids for all rooms for (room_id, _) in db.roomid_shortstatehash.iter() { - let shortroomid = services().globals.next_count()?.to_be_bytes(); + let shortroomid = services().globals.next_count().await?.to_be_bytes(); db.roomid_shortroomid.insert(&room_id, &shortroomid)?; info!("Migration: 8"); } @@ -799,7 +799,7 @@ impl KeyValueDatabase { // Force E2EE device list updates so we can send them over federation for user_id in services().users.iter().filter_map(|r| r.ok()) { - services().users.mark_device_key_update(&user_id)?; + services().users.mark_device_key_update(&user_id).await?; } services().globals.bump_database_version(10)?; @@ -811,7 +811,8 @@ impl KeyValueDatabase { db._db .open_tree("userdevicesessionid_uiaarequest") .await? - .clear()?; + .clear() + .await?; services().globals.bump_database_version(11)?; warn!("Migration: 10 -> 11 finished"); @@ -884,12 +885,16 @@ impl KeyValueDatabase { } } - services().account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data) + .expect("to json value always works"), + ) + .await?; } services().globals.bump_database_version(12)?; @@ -930,12 +935,16 @@ impl KeyValueDatabase { .global .update_with_server_default(user_default_rules); - services().account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data) + .expect("to json value always works"), + ) + .await?; } services().globals.bump_database_version(13)?; @@ -969,12 +978,12 @@ impl KeyValueDatabase { } // This data is probably outdated - db.presenceid_presence.clear()?; + db.presenceid_presence.clear().await?; services().admin.start_handler(); // Set emergency access for the conduit user - match set_emergency_access() { + match set_emergency_access().await { Ok(pwd_set) => { if pwd_set { warn!("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!"); @@ -1107,7 +1116,7 @@ impl KeyValueDatabase { } /// Sets the emergency password and push rules for the @conduit account in case emergency password is set -fn set_emergency_access() -> Result { +async fn set_emergency_access() -> Result { let conduit_user = services().globals.server_user(); services().users.set_password( @@ -1120,15 +1129,18 @@ fn set_emergency_access() -> Result { None => (Ruleset::new(), Ok(false)), }; - services().account_data.update( - None, - conduit_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(&GlobalAccountDataEvent { - content: PushRulesEventContent { global: ruleset }, - }) - .expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + conduit_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(&GlobalAccountDataEvent { + content: PushRulesEventContent { global: ruleset }, + }) + .expect("to json value always works"), + ) + .await?; res } diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs index c7c92981..1b08a820 100644 --- a/src/service/account_data/data.rs +++ b/src/service/account_data/data.rs @@ -1,15 +1,17 @@ use std::collections::HashMap; use crate::Result; +use async_trait::async_trait; use ruma::{ events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, serde::Raw, RoomId, UserId, }; +#[async_trait] pub trait Data: Send + Sync { /// Places one event in the account data of the user and removes the previous entry. - fn update( + async fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index f9c49b1a..ee2f7498 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -19,14 +19,14 @@ pub struct Service { impl Service { /// Places one event in the account data of the user and removes the previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - pub fn update( + pub async fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value, ) -> Result<()> { - self.db.update(room_id, user_id, event_type, data) + self.db.update(room_id, user_id, event_type, data).await } /// Searches the account data for a specific kind. diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 718219c4..9b4b3cfd 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -639,19 +639,22 @@ impl Service { .set_displayname(&user_id, Some(displayname))?; // Initial account data - services().account_data.update( - None, - &user_id, - ruma::events::GlobalAccountDataEventType::PushRules - .to_string() - .into(), - &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: ruma::push::Ruleset::server_default(&user_id), - }, - }) - .expect("to json value always works"), - )?; + services() + .account_data + .update( + None, + &user_id, + ruma::events::GlobalAccountDataEventType::PushRules + .to_string() + .into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: ruma::push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json value always works"), + ) + .await?; // we dont add a device since we're not the user, just the creator @@ -704,7 +707,7 @@ impl Service { "Making {user_id} leave all rooms before deactivation..." )); - services().users.deactivate_account(&user_id)?; + services().users.deactivate_account(&user_id).await?; if leave_rooms { leave_all_rooms(&user_id).await?; @@ -800,7 +803,7 @@ impl Service { } for &user_id in &user_ids { - if services().users.deactivate_account(user_id).is_ok() { + if services().users.deactivate_account(user_id).await.is_ok() { deactivation_count += 1 } } @@ -1057,7 +1060,11 @@ impl Service { pub(crate) async fn create_admin_room(&self) -> Result<()> { let room_id = RoomId::new(services().globals.server_name()); - services().rooms.short.get_or_create_shortroomid(&room_id)?; + services() + .rooms + .short + .get_or_create_shortroomid(&room_id) + .await?; let mutex_state = Arc::clone( services() @@ -1293,7 +1300,8 @@ impl Service { services() .rooms .alias - .set_alias(&alias, &room_id, conduit_user)?; + .set_alias(&alias, &room_id, conduit_user) + .await?; Ok(()) } diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 5fd84539..eb0b53fb 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -69,7 +69,7 @@ impl From for SigningKeys { #[async_trait] pub trait Data: Send + Sync { - fn next_count(&self) -> Result; + async fn next_count(&self) -> Result; fn current_count(&self) -> Result; fn last_check_for_updates_id(&self) -> Result; fn update_check_for_updates_id(&self, id: u64) -> Result<()>; diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index c0c0bd44..6592b7e4 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -258,8 +258,8 @@ impl Service { } #[tracing::instrument(skip(self))] - pub fn next_count(&self) -> Result { - self.db.next_count() + pub async fn next_count(&self) -> Result { + self.db.next_count().await } #[tracing::instrument(skip(self))] diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs index bf640015..6624ffb2 100644 --- a/src/service/key_backups/data.rs +++ b/src/service/key_backups/data.rs @@ -1,14 +1,16 @@ use std::collections::BTreeMap; use crate::Result; +use async_trait::async_trait; use ruma::{ api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, serde::Raw, OwnedRoomId, RoomId, UserId, }; +#[async_trait] pub trait Data: Send + Sync { - fn create_backup( + async fn create_backup( &self, user_id: &UserId, backup_metadata: &Raw, @@ -16,7 +18,7 @@ pub trait Data: Send + Sync { fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>; - fn update_backup( + async fn update_backup( &self, user_id: &UserId, version: &str, @@ -30,7 +32,7 @@ pub trait Data: Send + Sync { fn get_backup(&self, user_id: &UserId, version: &str) -> Result>>; - fn add_key( + async fn add_key( &self, user_id: &UserId, version: &str, diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 5fc52ced..1bda7d86 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -14,25 +14,27 @@ pub struct Service { } impl Service { - pub fn create_backup( + pub async fn create_backup( &self, user_id: &UserId, backup_metadata: &Raw, ) -> Result { - self.db.create_backup(user_id, backup_metadata) + self.db.create_backup(user_id, backup_metadata).await } pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { self.db.delete_backup(user_id, version) } - pub fn update_backup( + pub async fn update_backup( &self, user_id: &UserId, version: &str, backup_metadata: &Raw, ) -> Result { - self.db.update_backup(user_id, version, backup_metadata) + self.db + .update_backup(user_id, version, backup_metadata) + .await } pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { @@ -54,7 +56,7 @@ impl Service { self.db.get_backup(user_id, version) } - pub fn add_key( + pub async fn add_key( &self, user_id: &UserId, version: &str, @@ -64,6 +66,7 @@ impl Service { ) -> Result<()> { self.db .add_key(user_id, version, room_id, session_id, key_data) + .await } pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result { diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index c73799e4..30979d2d 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -1,9 +1,16 @@ use crate::Result; +use async_trait::async_trait; use ruma::{OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId}; +#[async_trait] pub trait Data: Send + Sync { /// Creates or updates the alias to the given room id. - fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()>; + async fn set_alias( + &self, + alias: &RoomAliasId, + room_id: &RoomId, + user_id: &UserId, + ) -> Result<()>; /// Finds the user who assigned the given alias to a room fn who_created_alias(&self, alias: &RoomAliasId) -> Result>; diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index bd5693f1..af50ec48 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -19,7 +19,12 @@ pub struct Service { impl Service { #[tracing::instrument(skip(self))] - pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { + pub async fn set_alias( + &self, + alias: &RoomAliasId, + room_id: &RoomId, + user_id: &UserId, + ) -> Result<()> { if alias == services().globals.admin_alias() && user_id != services().globals.server_user() { Err(Error::BadRequest( @@ -27,7 +32,7 @@ impl Service { "Only the server user can set this alias", )) } else { - self.db.set_alias(alias, room_id, user_id) + self.db.set_alias(alias, room_id, user_id).await } } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 1a8a3ad7..e947df29 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -36,7 +36,11 @@ impl Service { let mut i = 0; for id in starting_events { - let short = services().rooms.short.get_or_create_shorteventid(&id)?; + let short = services() + .rooms + .short + .get_or_create_shorteventid(&id) + .await?; let bucket_id = (short % NUM_BUCKETS as u64) as usize; buckets[bucket_id].insert((short, id.clone())); i += 1; @@ -80,7 +84,7 @@ impl Service { chunk_cache.extend(cached.iter().copied()); } else { misses2 += 1; - let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?); + let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id).await?); services() .rooms .auth_chain @@ -125,7 +129,11 @@ impl Service { } #[tracing::instrument(skip(self, event_id))] - fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { + async fn get_auth_chain_inner( + &self, + room_id: &RoomId, + event_id: &EventId, + ) -> Result> { let mut todo = vec![Arc::from(event_id)]; let mut found = HashSet::new(); @@ -142,7 +150,8 @@ impl Service { let sauthevent = services() .rooms .short - .get_or_create_shorteventid(auth_event)?; + .get_or_create_shorteventid(auth_event) + .await?; if !found.contains(&sauthevent) { found.insert(sauthevent); diff --git a/src/service/rooms/edus/presence/data.rs b/src/service/rooms/edus/presence/data.rs index 53329e08..a9e54558 100644 --- a/src/service/rooms/edus/presence/data.rs +++ b/src/service/rooms/edus/presence/data.rs @@ -1,14 +1,16 @@ use std::collections::HashMap; use crate::Result; +use async_trait::async_trait; use ruma::{events::presence::PresenceEvent, OwnedUserId, RoomId, UserId}; +#[async_trait] pub trait Data: Send + Sync { /// Adds a presence event which will be saved until a new event replaces it. /// /// Note: This method takes a RoomId because presence updates are always bound to rooms to /// make sure users outside these rooms can't see them. - fn update_presence( + async fn update_presence( &self, user_id: &UserId, room_id: &RoomId, diff --git a/src/service/rooms/edus/read_receipt/data.rs b/src/service/rooms/edus/read_receipt/data.rs index 41a33eed..d0761ad7 100644 --- a/src/service/rooms/edus/read_receipt/data.rs +++ b/src/service/rooms/edus/read_receipt/data.rs @@ -1,9 +1,11 @@ use crate::Result; +use async_trait::async_trait; use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; +#[async_trait] pub trait Data: Send + Sync { /// Replaces the previous read receipt. - fn readreceipt_update( + async fn readreceipt_update( &self, user_id: &UserId, room_id: &RoomId, @@ -28,7 +30,7 @@ pub trait Data: Send + Sync { >; /// Sets a private read marker at `count`. - fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>; + async fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>; /// Returns the private read marker. fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result>; diff --git a/src/service/rooms/edus/read_receipt/mod.rs b/src/service/rooms/edus/read_receipt/mod.rs index c6035280..89f07fa7 100644 --- a/src/service/rooms/edus/read_receipt/mod.rs +++ b/src/service/rooms/edus/read_receipt/mod.rs @@ -11,13 +11,13 @@ pub struct Service { impl Service { /// Replaces the previous read receipt. - pub fn readreceipt_update( + pub async fn readreceipt_update( &self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent, ) -> Result<()> { - self.db.readreceipt_update(user_id, room_id, event) + self.db.readreceipt_update(user_id, room_id, event).await } /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. @@ -38,8 +38,13 @@ impl Service { /// Sets a private read marker at `count`. #[tracing::instrument(skip(self))] - pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { - self.db.private_read_set(room_id, user_id, count) + pub async fn private_read_set( + &self, + room_id: &RoomId, + user_id: &UserId, + count: u64, + ) -> Result<()> { + self.db.private_read_set(room_id, user_id, count).await } /// Returns the private read marker. diff --git a/src/service/rooms/edus/typing/mod.rs b/src/service/rooms/edus/typing/mod.rs index 7546aa84..f86162d7 100644 --- a/src/service/rooms/edus/typing/mod.rs +++ b/src/service/rooms/edus/typing/mod.rs @@ -23,7 +23,7 @@ impl Service { self.last_typing_update .write() .await - .insert(room_id.to_owned(), services().globals.next_count()?); + .insert(room_id.to_owned(), services().globals.next_count().await?); let _ = self.typing_update_sender.send(room_id.to_owned()); Ok(()) } @@ -39,7 +39,7 @@ impl Service { self.last_typing_update .write() .await - .insert(room_id.to_owned(), services().globals.next_count()?); + .insert(room_id.to_owned(), services().globals.next_count().await?); let _ = self.typing_update_sender.send(room_id.to_owned()); Ok(()) } @@ -80,7 +80,7 @@ impl Service { self.last_typing_update .write() .await - .insert(room_id.to_owned(), services().globals.next_count()?); + .insert(room_id.to_owned(), services().globals.next_count().await?); let _ = self.typing_update_sender.send(room_id.to_owned()); } Ok(()) diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 0bdfd4ae..21cfdd48 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -589,10 +589,11 @@ impl Service { })?; if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &prev_pdu.kind.to_string().into(), - state_key, - )?; + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key) + .await?; state.insert(shortstatekey, Arc::from(prev_event)); // Now it's the state after the pdu @@ -640,10 +641,14 @@ impl Service { .await?; if let Some(state_key) = &prev_event.state_key { - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &prev_event.kind.to_string().into(), - state_key, - )?; + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey( + &prev_event.kind.to_string().into(), + state_key, + ) + .await?; leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); // Now it's the state after the pdu } @@ -677,34 +682,38 @@ impl Service { let lock = services().globals.stateres_mutex.lock(); - let result = + let new_state = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { let res = services().rooms.timeline.get_pdu(id); if let Err(e) = &res { error!("LOOK AT ME Failed to fetch event: {}", e); } res.ok().flatten() - }); + }) + .map_err(|e| { + warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e); + e + }) + .ok(); drop(lock); - state_at_incoming_event = match result { - Ok(new_state) => Some( - new_state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = - services().rooms.short.get_or_create_shortstatekey( - &event_type.to_string().into(), - &state_key, - )?; - Ok((shortstatekey, event_id)) - }) - .collect::>()?, - ), - Err(e) => { - warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e); - None + state_at_incoming_event = match new_state { + Some(new_state) => { + let mut state_at_incoming_event = HashMap::with_capacity(new_state.len()); + for ((event_type, state_key), event_id) in new_state { + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey( + &event_type.to_string().into(), + &state_key, + ) + .await?; + state_at_incoming_event.insert(shortstatekey, event_id); + } + Some(state_at_incoming_event) } + None => None, } } } @@ -748,10 +757,11 @@ impl Service { Error::bad_database("Found non-state pdu in state events.") })?; - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &pdu.kind.to_string().into(), - &state_key, - )?; + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key) + .await?; match state.entry(shortstatekey) { hash_map::Entry::Vacant(v) => { @@ -915,17 +925,17 @@ impl Service { }); debug!("Compressing state at event"); - let state_ids_compressed = Arc::new( - state_at_incoming_event - .iter() - .map(|(shortstatekey, id)| { - services() - .rooms - .state_compressor - .compress_state_event(*shortstatekey, id) - }) - .collect::>()?, - ); + let mut state_ids_compressed = HashSet::new(); + for (shortstatekey, id) in &state_at_incoming_event { + state_ids_compressed.insert( + services() + .rooms + .state_compressor + .compress_state_event(*shortstatekey, id) + .await?, + ); + } + let state_ids_compressed = Arc::new(state_ids_compressed); if incoming_pdu.state_key.is_some() { debug!("Preparing for stateres to derive new room state"); @@ -933,10 +943,11 @@ impl Service { // We also add state after incoming event to the fork states let mut state_after = state_at_incoming_event.clone(); if let Some(state_key) = &incoming_pdu.state_key { - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &incoming_pdu.kind.to_string().into(), - state_key, - )?; + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key) + .await?; state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); } @@ -951,7 +962,8 @@ impl Service { let (sstatehash, new, removed) = services() .rooms .state_compressor - .save_state(room_id, new_room_state)?; + .save_state(room_id, new_room_state) + .await?; services() .rooms @@ -1078,35 +1090,32 @@ impl Service { }; let lock = services().globals.stateres_mutex.lock(); - let state = match state_res::resolve( + let state = state_res::resolve( room_version_id, &fork_states, auth_chain_sets, fetch_event, - ) { - Ok(new_state) => new_state, - Err(_) => { - return Err(Error::bad_database("State resolution failed, either an event could not be found or deserialization")); - } - }; + ).map_err(|_| Error::bad_database("State resolution failed, either an event could not be found or deserialization"))?; drop(lock); debug!("State resolution done. Compressing state"); - let new_room_state = state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; + let mut new_room_state = HashSet::new(); + for ((event_type, state_key), event_id) in state { + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key) + .await?; + new_room_state.insert( services() .rooms .state_compressor .compress_state_event(shortstatekey, &event_id) - }) - .collect::>()?; + .await?, + ); + } Ok(Arc::new(new_room_state)) } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index 5ffe8846..df55df33 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -41,7 +41,7 @@ impl Service { } #[allow(clippy::too_many_arguments)] - pub fn paginate_relations_with_filter( + pub async fn paginate_relations_with_filter( &self, sender_user: &UserId, room_id: &RoomId, @@ -77,13 +77,11 @@ impl Service { match dir { Direction::Forward => { - let relations_until = &services().rooms.pdu_metadata.relations_until( - sender_user, - room_id, - target, - from, - depth, - )?; + let relations_until = &services() + .rooms + .pdu_metadata + .relations_until(sender_user, room_id, target, from, depth) + .await?; let events_after: Vec<_> = relations_until // TODO: should be relations_after .iter() .filter(|(_, pdu)| { @@ -125,13 +123,11 @@ impl Service { }) } Direction::Backward => { - let relations_until = &services().rooms.pdu_metadata.relations_until( - sender_user, - room_id, - target, - from, - depth, - )?; + let relations_until = &services() + .rooms + .pdu_metadata + .relations_until(sender_user, room_id, target, from, depth) + .await?; let events_before: Vec<_> = relations_until .iter() .filter(|(_, pdu)| { @@ -174,7 +170,7 @@ impl Service { } } - pub fn relations_until<'a>( + pub async fn relations_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, @@ -182,7 +178,11 @@ impl Service { until: PduCount, max_depth: u8, ) -> Result> { - let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; + let room_id = services() + .rooms + .short + .get_or_create_shortroomid(room_id) + .await?; let target = match services().rooms.timeline.get_pdu_count(target)? { Some(PduCount::Normal(c)) => c, // TODO: Support backfilled relations diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs index 652c525b..b22cdc6b 100644 --- a/src/service/rooms/short/data.rs +++ b/src/service/rooms/short/data.rs @@ -1,10 +1,12 @@ use std::sync::Arc; use crate::Result; +use async_trait::async_trait; use ruma::{events::StateEventType, EventId, RoomId}; +#[async_trait] pub trait Data: Send + Sync { - fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result; + async fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result; fn get_shortstatekey( &self, @@ -12,7 +14,7 @@ pub trait Data: Send + Sync { state_key: &str, ) -> Result>; - fn get_or_create_shortstatekey( + async fn get_or_create_shortstatekey( &self, event_type: &StateEventType, state_key: &str, @@ -23,9 +25,9 @@ pub trait Data: Send + Sync { fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>; /// Returns (shortstatehash, already_existed) - fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>; + async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>; fn get_shortroomid(&self, room_id: &RoomId) -> Result>; - fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result; + async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result; } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 45fadd74..6ec5a030 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -11,8 +11,8 @@ pub struct Service { } impl Service { - pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { - self.db.get_or_create_shorteventid(event_id) + pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { + self.db.get_or_create_shorteventid(event_id).await } pub fn get_shortstatekey( @@ -23,12 +23,14 @@ impl Service { self.db.get_shortstatekey(event_type, state_key) } - pub fn get_or_create_shortstatekey( + pub async fn get_or_create_shortstatekey( &self, event_type: &StateEventType, state_key: &str, ) -> Result { - self.db.get_or_create_shortstatekey(event_type, state_key) + self.db + .get_or_create_shortstatekey(event_type, state_key) + .await } pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { @@ -40,15 +42,15 @@ impl Service { } /// Returns (shortstatehash, already_existed) - pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - self.db.get_or_create_shortstatehash(state_hash) + pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { + self.db.get_or_create_shortstatehash(state_hash).await } pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { self.db.get_shortroomid(room_id) } - pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { - self.db.get_or_create_shortroomid(room_id) + pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { + self.db.get_or_create_shortroomid(room_id).await } } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index f6581bb5..e76d88f9 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -80,14 +80,11 @@ impl Service { Err(_) => continue, }; - services().rooms.state_cache.update_membership( - room_id, - &user_id, - membership, - &pdu.sender, - None, - false, - )?; + services() + .rooms + .state_cache + .update_membership(room_id, &user_id, membership, &pdu.sender, None, false) + .await?; } TimelineEventType::SpaceChild => { services() @@ -115,7 +112,7 @@ impl Service { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, state_ids_compressed))] - pub fn set_event_state( + pub async fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, @@ -124,7 +121,8 @@ impl Service { let shorteventid = services() .rooms .short - .get_or_create_shorteventid(event_id)?; + .get_or_create_shorteventid(event_id) + .await?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; @@ -138,7 +136,8 @@ impl Service { let (shortstatehash, already_existed) = services() .rooms .short - .get_or_create_shortstatehash(&state_hash)?; + .get_or_create_shortstatehash(&state_hash) + .await?; if !already_existed { let states_parents = previous_shortstatehash.map_or_else( @@ -187,11 +186,12 @@ impl Service { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, new_pdu))] - pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result { + pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result { let shorteventid = services() .rooms .short - .get_or_create_shorteventid(&new_pdu.event_id)?; + .get_or_create_shorteventid(&new_pdu.event_id) + .await?; let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; @@ -213,12 +213,14 @@ impl Service { let shortstatekey = services() .rooms .short - .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key) + .await?; let new = services() .rooms .state_compressor - .compress_state_event(shortstatekey, &new_pdu.event_id)?; + .compress_state_event(shortstatekey, &new_pdu.event_id) + .await?; let replaces = states_parents .last() @@ -234,7 +236,7 @@ impl Service { } // TODO: statehash with deterministic inputs - let shortstatehash = services().globals.next_count()?; + let shortstatehash = services().globals.next_count().await?; let mut statediffnew = HashSet::new(); statediffnew.insert(new); diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 53e3176f..2a678830 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -327,6 +327,7 @@ impl Service { .rooms .timeline .create_hash_and_sign_event(new_event, sender, room_id, state_lock) + .await .is_ok()) } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 76dcc6cc..c1134992 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,22 +1,24 @@ use std::{collections::HashSet, sync::Arc}; use crate::{service::appservice::RegistrationInfo, Result}; +use async_trait::async_trait; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; +#[async_trait] pub trait Data: Send + Sync { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; - fn mark_as_invited( + async fn mark_as_invited( &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, ) -> Result<()>; - fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + async fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index c108695d..500dc9eb 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -25,7 +25,7 @@ pub struct Service { impl Service { /// Update current membership data. #[tracing::instrument(skip(self, last_state))] - pub fn update_membership( + pub async fn update_membership( &self, room_id: &RoomId, user_id: &UserId, @@ -103,6 +103,7 @@ impl Service { RoomAccountDataEventType::Tag, &tag_event?, ) + .await .ok(); }; @@ -132,13 +133,16 @@ impl Service { } if room_ids_updated { - services().account_data.update( - None, - user_id, - GlobalAccountDataEventType::Direct.to_string().into(), - &serde_json::to_value(&direct_event) - .expect("to json always works"), - )?; + services() + .account_data + .update( + None, + user_id, + GlobalAccountDataEventType::Direct.to_string().into(), + &serde_json::to_value(&direct_event) + .expect("to json always works"), + ) + .await?; } }; } @@ -176,10 +180,12 @@ impl Service { return Ok(()); } - self.db.mark_as_invited(user_id, room_id, last_state)?; + self.db + .mark_as_invited(user_id, room_id, last_state) + .await?; } MembershipState::Leave | MembershipState::Ban => { - self.db.mark_as_left(user_id, room_id)?; + self.db.mark_as_left(user_id, room_id).await?; } _ => {} } diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 6118e06b..f3bb6816 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -89,7 +89,7 @@ impl Service { } } - pub fn compress_state_event( + pub async fn compress_state_event( &self, shortstatekey: u64, event_id: &EventId, @@ -99,7 +99,8 @@ impl Service { &services() .rooms .short - .get_or_create_shorteventid(event_id)? + .get_or_create_shorteventid(event_id) + .await? .to_be_bytes(), ); Ok(v.try_into().expect("we checked the size above")) @@ -257,7 +258,7 @@ impl Service { /// Returns the new shortstatehash, and the state diff from the previous room state #[allow(clippy::type_complexity)] - pub fn save_state( + pub async fn save_state( &self, room_id: &RoomId, new_state_ids_compressed: Arc>, @@ -278,7 +279,8 @@ impl Service { let (new_shortstatehash, already_existed) = services() .rooms .short - .get_or_create_shortstatehash(&state_hash)?; + .get_or_create_shortstatehash(&state_hash) + .await?; if Some(new_shortstatehash) == previous_shortstatehash { return Ok(( diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 29d8339d..416aa6d0 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -267,20 +267,22 @@ impl Service { ); let insert_lock = mutex_insert.lock().await; - let count1 = services().globals.next_count()?; + let count1 = services().globals.next_count().await?; // Mark as read first so the sending client doesn't get a notification even if appending // fails services() .rooms .edus .read_receipt - .private_read_set(&pdu.room_id, &pdu.sender, count1)?; + .private_read_set(&pdu.room_id, &pdu.sender, count1) + .await?; services() .rooms .user - .reset_notification_counts(&pdu.sender, &pdu.room_id)?; + .reset_notification_counts(&pdu.sender, &pdu.room_id) + .await?; - let count2 = services().globals.next_count()?; + let count2 = services().globals.next_count().await?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&count2.to_be_bytes()); @@ -373,7 +375,10 @@ impl Service { } for push_key in services().pusher.get_pushkeys(user) { - services().sending.send_push_pdu(&pdu_id, user, push_key?)?; + services() + .sending + .send_push_pdu(&pdu_id, user, push_key?) + .await?; } } @@ -460,14 +465,18 @@ impl Service { // Update our membership info, we do this here incase a user is invited // and immediately leaves we need the DB to record the invite event for auth - services().rooms.state_cache.update_membership( - &pdu.room_id, - &target_user_id, - content.membership, - &pdu.sender, - invite_state, - true, - )?; + services() + .rooms + .state_cache + .update_membership( + &pdu.room_id, + &target_user_id, + content.membership, + &pdu.sender, + invite_state, + true, + ) + .await?; } } TimelineEventType::RoomMessage => { @@ -578,7 +587,8 @@ impl Service { { services() .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone()) + .await?; continue; } @@ -592,10 +602,10 @@ impl Service { { let appservice_uid = appservice.registration.sender_localpart.as_str(); if state_key_uid == appservice_uid { - services().sending.send_pdu_appservice( - appservice.registration.id.clone(), - pdu_id.clone(), - )?; + services() + .sending + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone()) + .await?; continue; } } @@ -645,14 +655,15 @@ impl Service { { services() .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone()) + .await?; } } Ok(pdu_id) } - pub fn create_hash_and_sign_event( + pub async fn create_hash_and_sign_event( &self, pdu_builder: PduBuilder, sender: &UserId, @@ -827,7 +838,8 @@ impl Service { let _shorteventid = services() .rooms .short - .get_or_create_shorteventid(&pdu.event_id)?; + .get_or_create_shorteventid(&pdu.event_id) + .await?; Ok((pdu, pdu_json)) } @@ -842,8 +854,9 @@ impl Service { room_id: &RoomId, state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result> { - let (pdu, pdu_json) = - self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; + let (pdu, pdu_json) = self + .create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock) + .await?; if let Some(admin_room) = services().admin.get_admin_room()? { if admin_room == room_id { @@ -986,7 +999,7 @@ impl Service { // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. - let statehashid = services().rooms.state.append_to_state(&pdu)?; + let statehashid = services().rooms.state.append_to_state(&pdu).await?; let pdu_id = self .append_pdu( @@ -1027,7 +1040,10 @@ impl Service { // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above servers.remove(services().globals.server_name()); - services().sending.send_pdu(servers.into_iter(), &pdu_id)?; + services() + .sending + .send_pdu(servers.into_iter(), &pdu_id) + .await?; Ok(pdu.event_id) } @@ -1046,11 +1062,11 @@ impl Service { ) -> Result>> { // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. - services().rooms.state.set_event_state( - &pdu.event_id, - &pdu.room_id, - state_ids_compressed, - )?; + services() + .rooms + .state + .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed) + .await?; if soft_fail { services() @@ -1264,7 +1280,7 @@ impl Service { ); let insert_lock = mutex_insert.lock().await; - let count = services().globals.next_count()?; + let count = services().globals.next_count().await?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&0_u64.to_be_bytes()); pdu_id.extend_from_slice(&(u64::MAX - count).to_be_bytes()); diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index 5544af2c..4e37374c 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,8 +1,10 @@ use crate::Result; +use async_trait::async_trait; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +#[async_trait] pub trait Data: Send + Sync { - fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + async fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 672e502d..7385325e 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -10,8 +10,12 @@ pub struct Service { } impl Service { - pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.reset_notification_counts(user_id, room_id) + pub async fn reset_notification_counts( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<()> { + self.db.reset_notification_counts(user_id, room_id).await } pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 78d3f1e1..060200e3 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -1,9 +1,11 @@ +use async_trait::async_trait; use ruma::ServerName; use crate::Result; use super::{OutgoingKind, SendingEventType}; +#[async_trait] pub trait Data: Send + Sync { #[allow(clippy::type_complexity)] fn active_requests<'a>( @@ -16,7 +18,7 @@ pub trait Data: Send + Sync { fn delete_active_request(&self, key: Vec) -> Result<()>; fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; - fn queue_requests( + async fn queue_requests( &self, requests: &[(&OutgoingKind, SendingEventType)], ) -> Result>>; diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index fa14f123..052554e6 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -370,10 +370,13 @@ impl Service { } #[tracing::instrument(skip(self, pdu_id, user, pushkey))] - pub fn send_push_pdu(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { + pub async fn send_push_pdu(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { let outgoing_kind = OutgoingKind::Push(user.to_owned(), pushkey); let event = SendingEventType::Pdu(pdu_id.to_owned()); - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + let keys = self + .db + .queue_requests(&[(&outgoing_kind, event.clone())]) + .await?; self.sender .send((outgoing_kind, event, keys.into_iter().next().unwrap())) .unwrap(); @@ -382,7 +385,7 @@ impl Service { } #[tracing::instrument(skip(self, servers, pdu_id))] - pub fn send_pdu>( + pub async fn send_pdu>( &self, servers: I, pdu_id: &[u8], @@ -396,12 +399,15 @@ impl Service { ) }) .collect::>(); - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), - )?; + let keys = self + .db + .queue_requests( + &requests + .iter() + .map(|(o, e)| (o, e.clone())) + .collect::>(), + ) + .await?; for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { self.sender .send((outgoing_kind.to_owned(), event, key)) @@ -412,7 +418,7 @@ impl Service { } #[tracing::instrument(skip(self, server, serialized))] - pub fn send_reliable_edu( + pub async fn send_reliable_edu( &self, server: &ServerName, serialized: Vec, @@ -420,7 +426,10 @@ impl Service { ) -> Result<()> { let outgoing_kind = OutgoingKind::Normal(server.to_owned()); let event = SendingEventType::Edu(serialized); - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + let keys = self + .db + .queue_requests(&[(&outgoing_kind, event.clone())]) + .await?; self.sender .send((outgoing_kind, event, keys.into_iter().next().unwrap())) .unwrap(); @@ -429,10 +438,13 @@ impl Service { } #[tracing::instrument(skip(self))] - pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { + pub async fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { let outgoing_kind = OutgoingKind::Appservice(appservice_id); let event = SendingEventType::Pdu(pdu_id); - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + let keys = self + .db + .queue_requests(&[(&outgoing_kind, event.clone())]) + .await?; self.sender .send((outgoing_kind, event, keys.into_iter().next().unwrap())) .unwrap(); diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 75d7eb2c..9d4c8eb9 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -1,4 +1,5 @@ use crate::Result; +use async_trait::async_trait; use ruma::{ api::client::{device::Device, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, @@ -9,6 +10,7 @@ use ruma::{ }; use std::collections::BTreeMap; +#[async_trait] pub trait Data: Send + Sync { /// Check if a user has an account on this homeserver. fn exists(&self, user_id: &UserId) -> Result; @@ -55,7 +57,7 @@ pub trait Data: Send + Sync { fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()>; /// Adds a new device to a user. - fn create_device( + async fn create_device( &self, user_id: &UserId, device_id: &DeviceId, @@ -64,7 +66,7 @@ pub trait Data: Send + Sync { ) -> Result<()>; /// Removes a device from a user. - fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; + async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; /// Returns an iterator over all device ids of this user. fn all_device_ids<'a>( @@ -75,7 +77,7 @@ pub trait Data: Send + Sync { /// Replaces the access token of one device. fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; - fn add_one_time_key( + async fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, @@ -85,7 +87,7 @@ pub trait Data: Send + Sync { fn last_one_time_keys_update(&self, user_id: &UserId) -> Result; - fn take_one_time_key( + async fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, @@ -98,14 +100,14 @@ pub trait Data: Send + Sync { device_id: &DeviceId, ) -> Result>; - fn add_device_keys( + async fn add_device_keys( &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw, ) -> Result<()>; - fn add_cross_signing_keys( + async fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, @@ -114,7 +116,7 @@ pub trait Data: Send + Sync { notify: bool, ) -> Result<()>; - fn sign_key( + async fn sign_key( &self, target_id: &UserId, key_id: &str, @@ -129,7 +131,7 @@ pub trait Data: Send + Sync { to: Option, ) -> Box> + 'a>; - fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>; + async fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>; fn get_device_keys( &self, @@ -167,7 +169,7 @@ pub trait Data: Send + Sync { fn get_user_signing_key(&self, user_id: &UserId) -> Result>>; - fn add_to_device_event( + async fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, @@ -189,7 +191,7 @@ pub trait Data: Send + Sync { until: u64, ) -> Result<()>; - fn update_device_metadata( + async fn update_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index a5694a10..1da8d5a9 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -340,7 +340,7 @@ impl Service { } /// Adds a new device to a user. - pub fn create_device( + pub async fn create_device( &self, user_id: &UserId, device_id: &DeviceId, @@ -349,18 +349,19 @@ impl Service { ) -> Result<()> { self.db .create_device(user_id, device_id, token, initial_device_display_name) + .await } /// Removes a device from a user. - pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - self.db.remove_device(user_id, device_id) + pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + self.db.remove_device(user_id, device_id).await } /// Returns an iterator over all device ids of this user. pub fn all_device_ids<'a>( &'a self, user_id: &UserId, - ) -> impl Iterator> + 'a { + ) -> impl Send + Iterator> + 'a { self.db.all_device_ids(user_id) } @@ -369,7 +370,7 @@ impl Service { self.db.set_token(user_id, device_id, token) } - pub fn add_one_time_key( + pub async fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, @@ -378,19 +379,22 @@ impl Service { ) -> Result<()> { self.db .add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) + .await } pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { self.db.last_one_time_keys_update(user_id) } - pub fn take_one_time_key( + pub async fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, ) -> Result)>> { - self.db.take_one_time_key(user_id, device_id, key_algorithm) + self.db + .take_one_time_key(user_id, device_id, key_algorithm) + .await } pub fn count_one_time_keys( @@ -401,16 +405,18 @@ impl Service { self.db.count_one_time_keys(user_id, device_id) } - pub fn add_device_keys( + pub async fn add_device_keys( &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw, ) -> Result<()> { - self.db.add_device_keys(user_id, device_id, device_keys) + self.db + .add_device_keys(user_id, device_id, device_keys) + .await } - pub fn add_cross_signing_keys( + pub async fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, @@ -418,23 +424,27 @@ impl Service { user_signing_key: &Option>, notify: bool, ) -> Result<()> { - self.db.add_cross_signing_keys( - user_id, - master_key, - self_signing_key, - user_signing_key, - notify, - ) + self.db + .add_cross_signing_keys( + user_id, + master_key, + self_signing_key, + user_signing_key, + notify, + ) + .await } - pub fn sign_key( + pub async fn sign_key( &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, ) -> Result<()> { - self.db.sign_key(target_id, key_id, signature, sender_id) + self.db + .sign_key(target_id, key_id, signature, sender_id) + .await } pub fn keys_changed<'a>( @@ -446,8 +456,8 @@ impl Service { self.db.keys_changed(user_or_room_id, from, to) } - pub fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - self.db.mark_device_key_update(user_id) + pub async fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { + self.db.mark_device_key_update(user_id).await } pub fn get_device_keys( @@ -501,7 +511,7 @@ impl Service { self.db.get_user_signing_key(user_id) } - pub fn add_to_device_event( + pub async fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, @@ -509,13 +519,15 @@ impl Service { event_type: &str, content: serde_json::Value, ) -> Result<()> { - self.db.add_to_device_event( - sender, - target_user_id, - target_device_id, - event_type, - content, - ) + self.db + .add_to_device_event( + sender, + target_user_id, + target_device_id, + event_type, + content, + ) + .await } pub fn get_to_device_events( @@ -535,13 +547,15 @@ impl Service { self.db.remove_to_device_events(user_id, device_id, until) } - pub fn update_device_metadata( + pub async fn update_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, device: &Device, ) -> Result<()> { - self.db.update_device_metadata(user_id, device_id, device) + self.db + .update_device_metadata(user_id, device_id, device) + .await } /// Get device metadata. @@ -565,10 +579,10 @@ impl Service { } /// Deactivate account - pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { + pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> { // Remove all associated devices for device_id in self.all_device_ids(user_id) { - self.remove_device(user_id, &device_id?)?; + self.remove_device(user_id, &device_id?).await?; } // Set the password to "" to indicate a deactivated account. Hashes will never result in an