1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-06-27 16:35:59 +00:00

KvTree: asyncify clear and increment

This commit is contained in:
chayleaf 2024-06-22 21:22:43 +07:00
parent a8c9e3eebe
commit 7658414fc4
No known key found for this signature in database
GPG key ID: 78171AD46227E68E
62 changed files with 958 additions and 650 deletions

View file

@ -218,17 +218,20 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
.set_displayname(&user_id, Some(displayname.clone()))?; .set_displayname(&user_id, Some(displayname.clone()))?;
// Initial account data // Initial account data
services().account_data.update( services()
None, .account_data
&user_id, .update(
GlobalAccountDataEventType::PushRules.to_string().into(), None,
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent { &user_id,
content: ruma::events::push_rules::PushRulesEventContent { GlobalAccountDataEventType::PushRules.to_string().into(),
global: push::Ruleset::server_default(&user_id), &serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
}, content: ruma::events::push_rules::PushRulesEventContent {
}) global: push::Ruleset::server_default(&user_id),
.expect("to json always works"), },
)?; })
.expect("to json always works"),
)
.await?;
// Inhibit login does not work for guests // Inhibit login does not work for guests
if !is_guest && body.inhibit_login { if !is_guest && body.inhibit_login {
@ -253,12 +256,15 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
let token = utils::random_string(TOKEN_LENGTH); let token = utils::random_string(TOKEN_LENGTH);
// Create device for this account // Create device for this account
services().users.create_device( services()
&user_id, .users
&device_id, .create_device(
&token, &user_id,
body.initial_device_display_name.clone(), &device_id,
)?; &token,
body.initial_device_display_name.clone(),
)
.await?;
info!("New user {} registered on this server.", user_id); info!("New user {} registered on this server.", user_id);
if body.appservice_info.is_none() && !is_guest { if body.appservice_info.is_none() && !is_guest {
@ -359,7 +365,7 @@ pub async fn change_password_route(
.filter_map(|id| id.ok()) .filter_map(|id| id.ok())
.filter(|id| id != sender_device) .filter(|id| id != sender_device)
{ {
services().users.remove_device(sender_user, &id)?; services().users.remove_device(sender_user, &id).await?;
} }
} }
@ -438,7 +444,7 @@ pub async fn deactivate_route(
client_server::leave_all_rooms(sender_user).await?; client_server::leave_all_rooms(sender_user).await?;
// Remove devices and mark account as deactivated // Remove devices and mark account as deactivated
services().users.deactivate_account(sender_user)?; services().users.deactivate_account(sender_user).await?;
info!("User {} deactivated their account.", sender_user); info!("User {} deactivated their account.", sender_user);
services() services()

View file

@ -57,7 +57,8 @@ pub async fn create_alias_route(
services() services()
.rooms .rooms
.alias .alias
.set_alias(&body.room_alias, &body.room_id, sender_user)?; .set_alias(&body.room_alias, &body.room_id, sender_user)
.await?;
Ok(create_alias::v3::Response::new()) Ok(create_alias::v3::Response::new())
} }

View file

@ -19,7 +19,8 @@ pub async fn create_backup_version_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let version = services() let version = services()
.key_backups .key_backups
.create_backup(sender_user, &body.algorithm)?; .create_backup(sender_user, &body.algorithm)
.await?;
Ok(create_backup_version::v3::Response { version }) Ok(create_backup_version::v3::Response { version })
} }
@ -33,7 +34,8 @@ pub async fn update_backup_version_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services()
.key_backups .key_backups
.update_backup(sender_user, &body.version, &body.algorithm)?; .update_backup(sender_user, &body.version, &body.algorithm)
.await?;
Ok(update_backup_version::v3::Response {}) Ok(update_backup_version::v3::Response {})
} }
@ -133,13 +135,10 @@ pub async fn add_backup_keys_route(
for (room_id, room) in &body.rooms { for (room_id, room) in &body.rooms {
for (session_id, key_data) in &room.sessions { for (session_id, key_data) in &room.sessions {
services().key_backups.add_key( services()
sender_user, .key_backups
&body.version, .add_key(sender_user, &body.version, room_id, session_id, key_data)
room_id, .await?
session_id,
key_data,
)?
} }
} }
@ -179,13 +178,16 @@ pub async fn add_backup_keys_for_room_route(
} }
for (session_id, key_data) in &body.sessions { for (session_id, key_data) in &body.sessions {
services().key_backups.add_key( services()
sender_user, .key_backups
&body.version, .add_key(
&body.room_id, sender_user,
session_id, &body.version,
key_data, &body.room_id,
)? session_id,
key_data,
)
.await?
} }
Ok(add_backup_keys_for_room::v3::Response { Ok(add_backup_keys_for_room::v3::Response {
@ -223,13 +225,16 @@ pub async fn add_backup_keys_for_session_route(
)); ));
} }
services().key_backups.add_key( services()
sender_user, .key_backups
&body.version, .add_key(
&body.room_id, sender_user,
&body.session_id, &body.version,
&body.session_data, &body.room_id,
)?; &body.session_id,
&body.session_data,
)
.await?;
Ok(add_backup_keys_for_session::v3::Response { Ok(add_backup_keys_for_session::v3::Response {
count: (services() count: (services()

View file

@ -26,15 +26,18 @@ pub async fn set_global_account_data_route(
let event_type = body.event_type.to_string(); let event_type = body.event_type.to_string();
services().account_data.update( services()
None, .account_data
sender_user, .update(
event_type.clone().into(), None,
&json!({ sender_user,
"type": event_type, event_type.clone().into(),
"content": data, &json!({
}), "type": event_type,
)?; "content": data,
}),
)
.await?;
Ok(set_global_account_data::v3::Response {}) Ok(set_global_account_data::v3::Response {})
} }
@ -52,15 +55,18 @@ pub async fn set_room_account_data_route(
let event_type = body.event_type.to_string(); let event_type = body.event_type.to_string();
services().account_data.update( services()
Some(&body.room_id), .account_data
sender_user, .update(
event_type.clone().into(), Some(&body.room_id),
&json!({ sender_user,
"type": event_type, event_type.clone().into(),
"content": data, &json!({
}), "type": event_type,
)?; "content": data,
}),
)
.await?;
Ok(set_room_account_data::v3::Response {}) Ok(set_room_account_data::v3::Response {})
} }

View file

@ -57,7 +57,8 @@ pub async fn update_device_route(
services() services()
.users .users
.update_device_metadata(sender_user, &body.device_id, &device)?; .update_device_metadata(sender_user, &body.device_id, &device)
.await?;
Ok(update_device::v3::Response {}) Ok(update_device::v3::Response {})
} }
@ -109,7 +110,8 @@ pub async fn delete_device_route(
services() services()
.users .users
.remove_device(sender_user, &body.device_id)?; .remove_device(sender_user, &body.device_id)
.await?;
Ok(delete_device::v3::Response {}) Ok(delete_device::v3::Response {})
} }
@ -162,7 +164,10 @@ pub async fn delete_devices_route(
} }
for device_id in &body.devices { for device_id in &body.devices {
services().users.remove_device(sender_user, device_id)? services()
.users
.remove_device(sender_user, device_id)
.await?
} }
Ok(delete_devices::v3::Response {}) Ok(delete_devices::v3::Response {})

View file

@ -38,7 +38,8 @@ pub async fn upload_keys_route(
for (key_key, key_value) in &body.one_time_keys { for (key_key, key_value) in &body.one_time_keys {
services() services()
.users .users
.add_one_time_key(sender_user, sender_device, key_key, key_value)?; .add_one_time_key(sender_user, sender_device, key_key, key_value)
.await?;
} }
if let Some(device_keys) = &body.device_keys { if let Some(device_keys) = &body.device_keys {
@ -51,7 +52,8 @@ pub async fn upload_keys_route(
{ {
services() services()
.users .users
.add_device_keys(sender_user, sender_device, device_keys)?; .add_device_keys(sender_user, sender_device, device_keys)
.await?;
} }
} }
@ -131,13 +133,16 @@ pub async fn upload_signing_keys_route(
} }
if let Some(master_key) = &body.master_key { if let Some(master_key) = &body.master_key {
services().users.add_cross_signing_keys( services()
sender_user, .users
master_key, .add_cross_signing_keys(
&body.self_signing_key, sender_user,
&body.user_signing_key, master_key,
true, // notify so that other users see the new keys &body.self_signing_key,
)?; &body.user_signing_key,
true, // notify so that other users see the new keys
)
.await?;
} }
Ok(upload_signing_keys::v3::Response {}) Ok(upload_signing_keys::v3::Response {})
@ -189,7 +194,8 @@ pub async fn upload_signatures_route(
); );
services() services()
.users .users
.sign_key(user_id, key_id, signature, sender_user)?; .sign_key(user_id, key_id, signature, sender_user)
.await?;
} }
} }
} }
@ -419,10 +425,13 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
} }
let json = serde_json::to_value(master_key).expect("to_value always works"); let json = serde_json::to_value(master_key).expect("to_value always works");
let raw = serde_json::from_value(json).expect("Raw::from_value always works"); let raw = serde_json::from_value(json).expect("Raw::from_value always works");
services().users.add_cross_signing_keys( services()
&user, &raw, &None, &None, .users
false, // Dont notify. A notification would trigger another key request resulting in an endless loop .add_cross_signing_keys(
)?; &user, &raw, &None, &None,
false, // Dont notify. A notification would trigger another key request resulting in an endless loop
)
.await?;
master_keys.insert(user, raw); master_keys.insert(user, raw);
} }
@ -481,10 +490,10 @@ pub(crate) async fn claim_keys_helper(
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
for (device_id, key_algorithm) in map { for (device_id, key_algorithm) in map {
if let Some(one_time_keys) = if let Some(one_time_keys) = services()
services() .users
.users .take_one_time_key(user_id, device_id, key_algorithm)
.take_one_time_key(user_id, device_id, key_algorithm)? .await?
{ {
let mut c = BTreeMap::new(); let mut c = BTreeMap::new();
c.insert(one_time_keys.0, one_time_keys.1); c.insert(one_time_keys.0, one_time_keys.1);

View file

@ -703,7 +703,11 @@ async fn join_room_by_id_helper(
} }
} }
services().rooms.short.get_or_create_shortroomid(room_id)?; services()
.rooms
.short
.get_or_create_shortroomid(room_id)
.await?;
info!("Parsing join event"); info!("Parsing join event");
let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone())
@ -744,7 +748,8 @@ async fn join_room_by_id_helper(
let shortstatekey = services() let shortstatekey = services()
.rooms .rooms
.short .short
.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)
.await?;
state.insert(shortstatekey, pdu.event_id.clone()); state.insert(shortstatekey, pdu.event_id.clone());
} }
} }
@ -781,8 +786,8 @@ async fn join_room_by_id_helper(
&services() &services()
.rooms .rooms
.short .short
.get_or_create_shortstatekey(&k.to_string().into(), s) .get_shortstatekey(&k.to_string().into(), s)
.ok()?, .ok()??,
)?, )?,
) )
.ok()? .ok()?
@ -801,20 +806,23 @@ async fn join_room_by_id_helper(
} }
info!("Saving state from send_join"); info!("Saving state from send_join");
let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state( let (statehash_before_join, new, removed) = services()
room_id, .rooms
Arc::new( .state_compressor
state .save_state(room_id, {
.into_iter() let mut new_state = HashSet::new();
.map(|(k, id)| { for (k, id) in state {
new_state.insert(
services() services()
.rooms .rooms
.state_compressor .state_compressor
.compress_state_event(k, &id) .compress_state_event(k, &id)
}) .await?,
.collect::<Result<_>>()?, );
), }
)?; Arc::new(new_state)
})
.await?;
services() services()
.rooms .rooms
@ -827,7 +835,11 @@ async fn join_room_by_id_helper(
// We append to state before appending the pdu, so we don't have a moment in time with the // We append to state before appending the pdu, so we don't have a moment in time with the
// pdu without it's state. This is okay because append_pdu can't fail. // pdu without it's state. This is okay because append_pdu can't fail.
let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; let statehash_after_join = services()
.rooms
.state
.append_to_state(&parsed_join_pdu)
.await?;
info!("Appending new room join event"); info!("Appending new room join event");
services() services()
@ -1253,18 +1265,22 @@ pub(crate) async fn invite_helper<'a>(
}) })
.expect("member event is valid value"); .expect("member event is valid value");
let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event( let (pdu, pdu_json) = services()
PduBuilder { .rooms
event_type: TimelineEventType::RoomMember, .timeline
content, .create_hash_and_sign_event(
unsigned: None, PduBuilder {
state_key: Some(user_id.to_string()), event_type: TimelineEventType::RoomMember,
redacts: None, content,
}, unsigned: None,
sender_user, state_key: Some(user_id.to_string()),
room_id, redacts: None,
&state_lock, },
)?; sender_user,
room_id,
&state_lock,
)
.await?;
let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?; let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?;
@ -1335,7 +1351,7 @@ pub(crate) async fn invite_helper<'a>(
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.filter(|server| &**server != services().globals.server_name()); .filter(|server| &**server != services().globals.server_name());
services().sending.send_pdu(servers, &pdu_id)?; services().sending.send_pdu(servers, &pdu_id).await?;
} else { } else {
if !services() if !services()
.rooms .rooms
@ -1442,14 +1458,18 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<Strin
)?; )?;
// We always drop the invite, we can't rely on other servers // We always drop the invite, we can't rely on other servers
services().rooms.state_cache.update_membership( services()
room_id, .rooms
user_id, .state_cache
MembershipState::Leave, .update_membership(
user_id, room_id,
last_state, user_id,
true, MembershipState::Leave,
)?; user_id,
last_state,
true,
)
.await?;
} else { } else {
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
services() services()
@ -1473,14 +1493,18 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<Strin
None => { None => {
error!("Trying to leave a room you are not a member of."); error!("Trying to leave a room you are not a member of.");
services().rooms.state_cache.update_membership( services()
room_id, .rooms
user_id, .state_cache
MembershipState::Leave, .update_membership(
user_id, room_id,
None, user_id,
true, MembershipState::Leave,
)?; user_id,
None,
true,
)
.await?;
return Ok(()); return Ok(());
} }
Some(e) => e, Some(e) => e,

View file

@ -143,12 +143,15 @@ pub async fn set_pushrule_route(
return Err(err); return Err(err);
} }
services().account_data.update( services()
None, .account_data
sender_user, .update(
GlobalAccountDataEventType::PushRules.to_string().into(), None,
&serde_json::to_value(account_data).expect("to json value always works"), sender_user,
)?; GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)
.await?;
Ok(set_pushrule::v3::Response {}) Ok(set_pushrule::v3::Response {})
} }
@ -238,12 +241,15 @@ pub async fn set_pushrule_actions_route(
)); ));
} }
services().account_data.update( services()
None, .account_data
sender_user, .update(
GlobalAccountDataEventType::PushRules.to_string().into(), None,
&serde_json::to_value(account_data).expect("to json value always works"), sender_user,
)?; GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)
.await?;
Ok(set_pushrule_actions::v3::Response {}) Ok(set_pushrule_actions::v3::Response {})
} }
@ -332,12 +338,15 @@ pub async fn set_pushrule_enabled_route(
)); ));
} }
services().account_data.update( services()
None, .account_data
sender_user, .update(
GlobalAccountDataEventType::PushRules.to_string().into(), None,
&serde_json::to_value(account_data).expect("to json value always works"), sender_user,
)?; GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)
.await?;
Ok(set_pushrule_enabled::v3::Response {}) Ok(set_pushrule_enabled::v3::Response {})
} }
@ -391,12 +400,15 @@ pub async fn delete_pushrule_route(
return Err(err); return Err(err);
} }
services().account_data.update( services()
None, .account_data
sender_user, .update(
GlobalAccountDataEventType::PushRules.to_string().into(), None,
&serde_json::to_value(account_data).expect("to json value always works"), sender_user,
)?; GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)
.await?;
Ok(delete_pushrule::v3::Response {}) Ok(delete_pushrule::v3::Response {})
} }

View file

@ -26,19 +26,23 @@ pub async fn set_read_marker_route(
event_id: fully_read.clone(), event_id: fully_read.clone(),
}, },
}; };
services().account_data.update( services()
Some(&body.room_id), .account_data
sender_user, .update(
RoomAccountDataEventType::FullyRead, Some(&body.room_id),
&serde_json::to_value(fully_read_event).expect("to json value always works"), sender_user,
)?; RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"),
)
.await?;
} }
if body.private_read_receipt.is_some() || body.read_receipt.is_some() { if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
services() services()
.rooms .rooms
.user .user
.reset_notification_counts(sender_user, &body.room_id)?; .reset_notification_counts(sender_user, &body.room_id)
.await?;
} }
if let Some(event) = &body.private_read_receipt { if let Some(event) = &body.private_read_receipt {
@ -63,7 +67,8 @@ pub async fn set_read_marker_route(
.rooms .rooms
.edus .edus
.read_receipt .read_receipt
.private_read_set(&body.room_id, sender_user, count)?; .private_read_set(&body.room_id, sender_user, count)
.await?;
} }
if let Some(event) = &body.read_receipt { if let Some(event) = &body.read_receipt {
@ -82,14 +87,19 @@ pub async fn set_read_marker_route(
let mut receipt_content = BTreeMap::new(); let mut receipt_content = BTreeMap::new();
receipt_content.insert(event.to_owned(), receipts); receipt_content.insert(event.to_owned(), receipts);
services().rooms.edus.read_receipt.readreceipt_update( services()
sender_user, .rooms
&body.room_id, .edus
ruma::events::receipt::ReceiptEvent { .read_receipt
content: ruma::events::receipt::ReceiptEventContent(receipt_content), .readreceipt_update(
room_id: body.room_id.clone(), sender_user,
}, &body.room_id,
)?; ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
room_id: body.room_id.clone(),
},
)
.await?;
} }
Ok(set_read_marker::v3::Response {}) Ok(set_read_marker::v3::Response {})
@ -110,7 +120,8 @@ pub async fn create_receipt_route(
services() services()
.rooms .rooms
.user .user
.reset_notification_counts(sender_user, &body.room_id)?; .reset_notification_counts(sender_user, &body.room_id)
.await?;
} }
match body.receipt_type { match body.receipt_type {
@ -120,12 +131,15 @@ pub async fn create_receipt_route(
event_id: body.event_id.clone(), event_id: body.event_id.clone(),
}, },
}; };
services().account_data.update( services()
Some(&body.room_id), .account_data
sender_user, .update(
RoomAccountDataEventType::FullyRead, Some(&body.room_id),
&serde_json::to_value(fully_read_event).expect("to json value always works"), sender_user,
)?; RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"),
)
.await?;
} }
create_receipt::v3::ReceiptType::Read => { create_receipt::v3::ReceiptType::Read => {
let mut user_receipts = BTreeMap::new(); let mut user_receipts = BTreeMap::new();
@ -142,14 +156,19 @@ pub async fn create_receipt_route(
let mut receipt_content = BTreeMap::new(); let mut receipt_content = BTreeMap::new();
receipt_content.insert(body.event_id.to_owned(), receipts); receipt_content.insert(body.event_id.to_owned(), receipts);
services().rooms.edus.read_receipt.readreceipt_update( services()
sender_user, .rooms
&body.room_id, .edus
ruma::events::receipt::ReceiptEvent { .read_receipt
content: ruma::events::receipt::ReceiptEventContent(receipt_content), .readreceipt_update(
room_id: body.room_id.clone(), sender_user,
}, &body.room_id,
)?; ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
room_id: body.room_id.clone(),
},
)
.await?;
} }
create_receipt::v3::ReceiptType::ReadPrivate => { create_receipt::v3::ReceiptType::ReadPrivate => {
let count = services() let count = services()
@ -169,11 +188,12 @@ pub async fn create_receipt_route(
} }
PduCount::Normal(c) => c, PduCount::Normal(c) => c,
}; };
services().rooms.edus.read_receipt.private_read_set( services()
&body.room_id, .rooms
sender_user, .edus
count, .read_receipt
)?; .private_read_set(&body.room_id, sender_user, count)
.await?;
} }
_ => return Err(Error::bad_database("Unsupported receipt type")), _ => return Err(Error::bad_database("Unsupported receipt type")),
} }

View file

@ -25,7 +25,8 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route(
body.limit, body.limit,
body.recurse, body.recurse,
&body.dir, &body.dir,
)?; )
.await?;
Ok( Ok(
get_relating_events_with_rel_type_and_event_type::v1::Response { get_relating_events_with_rel_type_and_event_type::v1::Response {
@ -57,7 +58,8 @@ pub async fn get_relating_events_with_rel_type_route(
body.limit, body.limit,
body.recurse, body.recurse,
&body.dir, &body.dir,
)?; )
.await?;
Ok(get_relating_events_with_rel_type::v1::Response { Ok(get_relating_events_with_rel_type::v1::Response {
chunk: res.chunk, chunk: res.chunk,
@ -88,4 +90,5 @@ pub async fn get_relating_events_route(
body.recurse, body.recurse,
&body.dir, &body.dir,
) )
.await
} }

View file

@ -54,7 +54,11 @@ pub async fn create_room_route(
let room_id = RoomId::new(services().globals.server_name()); let room_id = RoomId::new(services().globals.server_name());
services().rooms.short.get_or_create_shortroomid(&room_id)?; services()
.rooms
.short
.get_or_create_shortroomid(&room_id)
.await?;
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
services() services()
@ -488,7 +492,8 @@ pub async fn create_room_route(
services() services()
.rooms .rooms
.alias .alias
.set_alias(&alias, &room_id, sender_user)?; .set_alias(&alias, &room_id, sender_user)
.await?;
} }
if body.visibility == room::Visibility::Public { if body.visibility == room::Visibility::Public {
@ -600,7 +605,8 @@ pub async fn upgrade_room_route(
services() services()
.rooms .rooms
.short .short
.get_or_create_shortroomid(&replacement_room)?; .get_or_create_shortroomid(&replacement_room)
.await?;
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
services() services()
@ -818,7 +824,8 @@ pub async fn upgrade_room_route(
services() services()
.rooms .rooms
.alias .alias
.set_alias(&alias, &replacement_room, sender_user)?; .set_alias(&alias, &replacement_room, sender_user)
.await?;
} }
// Get the old room power levels // Get the old room power levels

View file

@ -192,12 +192,15 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
if device_exists { if device_exists {
services().users.set_token(&user_id, &device_id, &token)?; services().users.set_token(&user_id, &device_id, &token)?;
} else { } else {
services().users.create_device( services()
&user_id, .users
&device_id, .create_device(
&token, &user_id,
body.initial_device_display_name.clone(), &device_id,
)?; &token,
body.initial_device_display_name.clone(),
)
.await?;
} }
info!("{} logged in", user_id); info!("{} logged in", user_id);
@ -236,7 +239,10 @@ pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3:
} }
} }
services().users.remove_device(sender_user, sender_device)?; services()
.users
.remove_device(sender_user, sender_device)
.await?;
Ok(logout::v3::Response::new()) Ok(logout::v3::Response::new())
} }
@ -272,7 +278,10 @@ pub async fn logout_all_route(
} }
for device_id in services().users.all_device_ids(sender_user).flatten() { for device_id in services().users.all_device_ids(sender_user).flatten() {
services().users.remove_device(sender_user, &device_id)?; services()
.users
.remove_device(sender_user, &device_id)
.await?;
} }
Ok(logout_all::v3::Response::new()) Ok(logout_all::v3::Response::new())

View file

@ -409,7 +409,8 @@ async fn sync_helper(
let leave_shortstatekey = services() let leave_shortstatekey = services()
.rooms .rooms
.short .short
.get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())
.await?;
left_state_ids.insert(leave_shortstatekey, left_event_id); left_state_ids.insert(leave_shortstatekey, left_event_id);
@ -1234,7 +1235,7 @@ pub async fn sync_events_v4_route(
// Setup watchers, so if there's no response, we can wait for them // Setup watchers, so if there's no response, we can wait for them
let watcher = services().globals.watch(&sender_user, &sender_device); let watcher = services().globals.watch(&sender_user, &sender_device);
let next_batch = services().globals.next_count()?; let next_batch = services().globals.next_count().await?;
let globalsince = body let globalsince = body
.pos .pos

View file

@ -42,12 +42,15 @@ pub async fn update_tag_route(
.tags .tags
.insert(body.tag.clone().into(), body.tag_info.clone()); .insert(body.tag.clone().into(), body.tag_info.clone());
services().account_data.update( services()
Some(&body.room_id), .account_data
sender_user, .update(
RoomAccountDataEventType::Tag, Some(&body.room_id),
&serde_json::to_value(tags_event).expect("to json value always works"), sender_user,
)?; RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"),
)
.await?;
Ok(create_tag::v3::Response {}) Ok(create_tag::v3::Response {})
} }
@ -83,12 +86,15 @@ pub async fn delete_tag_route(
tags_event.content.tags.remove(&body.tag.clone().into()); tags_event.content.tags.remove(&body.tag.clone().into());
services().account_data.update( services()
Some(&body.room_id), .account_data
sender_user, .update(
RoomAccountDataEventType::Tag, Some(&body.room_id),
&serde_json::to_value(tags_event).expect("to json value always works"), sender_user,
)?; RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"),
)
.await?;
Ok(delete_tag::v3::Response {}) Ok(delete_tag::v3::Response {})
} }

View file

@ -34,49 +34,58 @@ pub async fn send_event_to_device_route(
map.insert(target_device_id_maybe.clone(), event.clone()); map.insert(target_device_id_maybe.clone(), event.clone());
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
messages.insert(target_user_id.clone(), map); messages.insert(target_user_id.clone(), map);
let count = services().globals.next_count()?; let count = services().globals.next_count().await?;
services().sending.send_reliable_edu( services()
target_user_id.server_name(), .sending
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice( .send_reliable_edu(
DirectDeviceContent { target_user_id.server_name(),
sender: sender_user.clone(), serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(
ev_type: body.event_type.clone(), DirectDeviceContent {
message_id: count.to_string().into(), sender: sender_user.clone(),
messages, ev_type: body.event_type.clone(),
}, message_id: count.to_string().into(),
)) messages,
.expect("DirectToDevice EDU can be serialized"), },
count, ))
)?; .expect("DirectToDevice EDU can be serialized"),
count,
)
.await?;
continue; continue;
} }
match target_device_id_maybe { match target_device_id_maybe {
DeviceIdOrAllDevices::DeviceId(target_device_id) => { DeviceIdOrAllDevices::DeviceId(target_device_id) => {
services().users.add_to_device_event( services()
sender_user, .users
target_user_id, .add_to_device_event(
target_device_id,
&body.event_type.to_string(),
event.deserialize_as().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
})?,
)?
}
DeviceIdOrAllDevices::AllDevices => {
for target_device_id in services().users.all_device_ids(target_user_id) {
services().users.add_to_device_event(
sender_user, sender_user,
target_user_id, target_user_id,
&target_device_id?, target_device_id,
&body.event_type.to_string(), &body.event_type.to_string(),
event.deserialize_as().map_err(|_| { event.deserialize_as().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
})?, })?,
)?; )
.await?
}
DeviceIdOrAllDevices::AllDevices => {
for target_device_id in services().users.all_device_ids(target_user_id) {
services()
.users
.add_to_device_event(
sender_user,
target_user_id,
&target_device_id?,
&body.event_type.to_string(),
event.deserialize_as().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
})?,
)
.await?;
} }
} }
} }

View file

@ -814,7 +814,8 @@ pub async fn send_transaction_message_route(
.rooms .rooms
.edus .edus
.read_receipt .read_receipt
.readreceipt_update(&user_id, &room_id, event)?; .readreceipt_update(&user_id, &room_id, event)
.await?;
} else { } else {
// TODO fetch missing events // TODO fetch missing events
debug!("No known event ids in read receipt: {:?}", user_updates); debug!("No known event ids in read receipt: {:?}", user_updates);
@ -853,7 +854,7 @@ pub async fn send_transaction_message_route(
} }
Edu::DeviceListUpdate(DeviceListUpdateContent { user_id, .. }) => { Edu::DeviceListUpdate(DeviceListUpdateContent { user_id, .. }) => {
if user_id.server_name() == sender_servername { if user_id.server_name() == sender_servername {
services().users.mark_device_key_update(&user_id)?; services().users.mark_device_key_update(&user_id).await?;
} }
} }
Edu::DirectToDevice(DirectDeviceContent { Edu::DirectToDevice(DirectDeviceContent {
@ -873,37 +874,43 @@ pub async fn send_transaction_message_route(
for (target_device_id_maybe, event) in map { for (target_device_id_maybe, event) in map {
match target_device_id_maybe { match target_device_id_maybe {
DeviceIdOrAllDevices::DeviceId(target_device_id) => { DeviceIdOrAllDevices::DeviceId(target_device_id) => {
services().users.add_to_device_event( services()
&sender, .users
target_user_id, .add_to_device_event(
target_device_id, &sender,
&ev_type.to_string(), target_user_id,
event.deserialize_as().map_err(|e| { target_device_id,
warn!("To-Device event is invalid: {event:?} {e}"); &ev_type.to_string(),
Error::BadRequest( event.deserialize_as().map_err(|e| {
ErrorKind::InvalidParam, warn!("To-Device event is invalid: {event:?} {e}");
"Event is invalid", Error::BadRequest(
) ErrorKind::InvalidParam,
})?, "Event is invalid",
)? )
})?,
)
.await?
} }
DeviceIdOrAllDevices::AllDevices => { DeviceIdOrAllDevices::AllDevices => {
for target_device_id in for target_device_id in
services().users.all_device_ids(target_user_id) services().users.all_device_ids(target_user_id)
{ {
services().users.add_to_device_event( services()
&sender, .users
target_user_id, .add_to_device_event(
&target_device_id?, &sender,
&ev_type.to_string(), target_user_id,
event.deserialize_as().map_err(|_| { &target_device_id?,
Error::BadRequest( &ev_type.to_string(),
ErrorKind::InvalidParam, event.deserialize_as().map_err(|_| {
"Event is invalid", Error::BadRequest(
) ErrorKind::InvalidParam,
})?, "Event is invalid",
)?; )
})?,
)
.await?;
} }
} }
} }
@ -923,13 +930,16 @@ pub async fn send_transaction_message_route(
}) => { }) => {
if user_id.server_name() == sender_servername { if user_id.server_name() == sender_servername {
if let Some(master_key) = master_key { if let Some(master_key) = master_key {
services().users.add_cross_signing_keys( services()
&user_id, .users
&master_key, .add_cross_signing_keys(
&self_signing_key, &user_id,
&None, &master_key,
true, &self_signing_key,
)?; &None,
true,
)
.await?;
} }
} }
} }
@ -1438,18 +1448,22 @@ pub async fn create_join_event_template_route(
}) })
.expect("member event is valid value"); .expect("member event is valid value");
let (_pdu, mut pdu_json) = services().rooms.timeline.create_hash_and_sign_event( let (_pdu, mut pdu_json) = services()
PduBuilder { .rooms
event_type: TimelineEventType::RoomMember, .timeline
content, .create_hash_and_sign_event(
unsigned: None, PduBuilder {
state_key: Some(body.user_id.to_string()), event_type: TimelineEventType::RoomMember,
redacts: None, content,
}, unsigned: None,
&body.user_id, state_key: Some(body.user_id.to_string()),
&body.room_id, redacts: None,
&state_lock, },
)?; &body.user_id,
&body.room_id,
&state_lock,
)
.await?;
drop(state_lock); drop(state_lock);
@ -1581,7 +1595,7 @@ async fn create_join_event(
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.filter(|server| &**server != services().globals.server_name()); .filter(|server| &**server != services().globals.server_name());
services().sending.send_pdu(servers, &pdu_id)?; services().sending.send_pdu(servers, &pdu_id).await?;
Ok(create_join_event::v1::RoomState { Ok(create_join_event::v1::RoomState {
auth_chain: auth_chain_ids auth_chain: auth_chain_ids
@ -1738,14 +1752,18 @@ pub async fn create_invite_route(
.state_cache .state_cache
.server_in_room(services().globals.server_name(), &body.room_id)? .server_in_room(services().globals.server_name(), &body.room_id)?
{ {
services().rooms.state_cache.update_membership( services()
&body.room_id, .rooms
&invited_user, .state_cache
MembershipState::Invite, .update_membership(
&sender, &body.room_id,
Some(invite_state), &invited_user,
true, MembershipState::Invite,
)?; &sender,
Some(invite_state),
true,
)
.await?;
} }
Ok(create_invite::v2::Response { Ok(create_invite::v2::Response {

View file

@ -42,6 +42,7 @@ pub trait KeyValueDatabaseEngine: Send + Sync {
} }
} }
#[async_trait]
pub trait KvTree: Send + Sync { pub trait KvTree: Send + Sync {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>; fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
@ -58,7 +59,7 @@ pub trait KvTree: Send + Sync {
backwards: bool, backwards: bool,
) -> Box<dyn Send + Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>; ) -> Box<dyn Send + Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>; async fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>; fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>;
fn scan_prefix<'a>( fn scan_prefix<'a>(
@ -68,7 +69,7 @@ pub trait KvTree: Send + Sync {
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>; fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
fn clear(&self) -> Result<()> { async fn clear(&self) -> Result<()> {
for (key, _) in self.iter() { for (key, _) in self.iter() {
self.remove(&key)?; self.remove(&key)?;
} }

View file

@ -63,6 +63,7 @@ impl PersyTree {
} }
} }
#[async_trait]
impl KvTree for PersyTree { impl KvTree for PersyTree {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
let result = self let result = self
@ -160,7 +161,7 @@ impl KvTree for PersyTree {
} }
} }
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> { async fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
self.increment_batch(&mut Some(key.to_owned()).into_iter())?; self.increment_batch(&mut Some(key.to_owned()).into_iter())?;
Ok(self.get(key)?.unwrap()) Ok(self.get(key)?.unwrap())
} }

View file

@ -136,6 +136,7 @@ impl RocksDbEngineTree<'_> {
} }
} }
#[async_trait]
impl KvTree for RocksDbEngineTree<'_> { impl KvTree for RocksDbEngineTree<'_> {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
let readoptions = rocksdb::ReadOptions::default(); let readoptions = rocksdb::ReadOptions::default();
@ -214,7 +215,7 @@ impl KvTree for RocksDbEngineTree<'_> {
) )
} }
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> { async fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
let readoptions = rocksdb::ReadOptions::default(); let readoptions = rocksdb::ReadOptions::default();
let writeoptions = rocksdb::WriteOptions::default(); let writeoptions = rocksdb::WriteOptions::default();

View file

@ -166,6 +166,7 @@ impl SqliteTable {
} }
} }
#[async_trait]
impl KvTree for SqliteTable { impl KvTree for SqliteTable {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
self.get_with_guard(self.engine.read_lock(), key) self.get_with_guard(self.engine.read_lock(), key)
@ -268,7 +269,7 @@ impl KvTree for SqliteTable {
Box::new(rx.into_iter()) Box::new(rx.into_iter())
} }
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> { async fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
let guard = self.engine.write_lock(); let guard = self.engine.write_lock();
let old = self.get_with_guard(&guard, key)?; let old = self.get_with_guard(&guard, key)?;
@ -295,7 +296,7 @@ impl KvTree for SqliteTable {
self.watchers.watch(prefix) self.watchers.watch(prefix)
} }
fn clear(&self) -> Result<()> { async fn clear(&self) -> Result<()> {
debug!("clear: running"); debug!("clear: running");
self.engine self.engine
.write_lock() .write_lock()

View file

@ -1,5 +1,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use async_trait::async_trait;
use ruma::{ use ruma::{
api::client::error::ErrorKind, api::client::error::ErrorKind,
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
@ -9,10 +10,11 @@ use ruma::{
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
#[async_trait]
impl service::account_data::Data for KeyValueDatabase { impl service::account_data::Data for KeyValueDatabase {
/// Places one event in the account data of the user and removes the previous entry. /// Places one event in the account data of the user and removes the previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))] #[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
fn update( async fn update(
&self, &self,
room_id: Option<&RoomId>, room_id: Option<&RoomId>,
user_id: &UserId, user_id: &UserId,
@ -29,7 +31,7 @@ impl service::account_data::Data for KeyValueDatabase {
prefix.push(0xff); prefix.push(0xff);
let mut roomuserdataid = prefix.clone(); let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); roomuserdataid.extend_from_slice(&services().globals.next_count().await?.to_be_bytes());
roomuserdataid.push(0xff); roomuserdataid.push(0xff);
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());

View file

@ -20,8 +20,8 @@ pub const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
#[async_trait] #[async_trait]
impl service::globals::Data for KeyValueDatabase { impl service::globals::Data for KeyValueDatabase {
fn next_count(&self) -> Result<u64> { async fn next_count(&self) -> Result<u64> {
utils::u64_from_bytes(&self.global.increment(COUNTER)?) utils::u64_from_bytes(&self.global.increment(COUNTER).await?)
.map_err(|_| Error::bad_database("Count has invalid bytes.")) .map_err(|_| Error::bad_database("Count has invalid bytes."))
} }

View file

@ -1,5 +1,6 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use async_trait::async_trait;
use ruma::{ use ruma::{
api::client::{ api::client::{
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
@ -11,13 +12,14 @@ use ruma::{
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
#[async_trait]
impl service::key_backups::Data for KeyValueDatabase { impl service::key_backups::Data for KeyValueDatabase {
fn create_backup( async fn create_backup(
&self, &self,
user_id: &UserId, user_id: &UserId,
backup_metadata: &Raw<BackupAlgorithm>, backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> { ) -> Result<String> {
let version = services().globals.next_count()?.to_string(); let version = services().globals.next_count().await?.to_string();
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
@ -28,7 +30,7 @@ impl service::key_backups::Data for KeyValueDatabase {
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
)?; )?;
self.backupid_etag self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?; .insert(&key, &services().globals.next_count().await?.to_be_bytes())?;
Ok(version) Ok(version)
} }
@ -49,7 +51,7 @@ impl service::key_backups::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn update_backup( async fn update_backup(
&self, &self,
user_id: &UserId, user_id: &UserId,
version: &str, version: &str,
@ -69,7 +71,7 @@ impl service::key_backups::Data for KeyValueDatabase {
self.backupid_algorithm self.backupid_algorithm
.insert(&key, backup_metadata.json().get().as_bytes())?; .insert(&key, backup_metadata.json().get().as_bytes())?;
self.backupid_etag self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?; .insert(&key, &services().globals.next_count().await?.to_be_bytes())?;
Ok(version.to_owned()) Ok(version.to_owned())
} }
@ -138,7 +140,7 @@ impl service::key_backups::Data for KeyValueDatabase {
}) })
} }
fn add_key( async fn add_key(
&self, &self,
user_id: &UserId, user_id: &UserId,
version: &str, version: &str,
@ -158,7 +160,7 @@ impl service::key_backups::Data for KeyValueDatabase {
} }
self.backupid_etag self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?; .insert(&key, &services().globals.next_count().await?.to_be_bytes())?;
key.push(0xff); key.push(0xff);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());

View file

@ -1,3 +1,4 @@
use async_trait::async_trait;
use ruma::{ use ruma::{
api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId,
UserId, UserId,
@ -5,8 +6,14 @@ use ruma::{
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
#[async_trait]
impl service::rooms::alias::Data for KeyValueDatabase { impl service::rooms::alias::Data for KeyValueDatabase {
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { async fn set_alias(
&self,
alias: &RoomAliasId,
room_id: &RoomId,
user_id: &UserId,
) -> Result<()> {
// Comes first as we don't want a stuck alias // Comes first as we don't want a stuck alias
self.alias_userid self.alias_userid
.insert(alias.alias().as_bytes(), user_id.as_bytes())?; .insert(alias.alias().as_bytes(), user_id.as_bytes())?;
@ -14,7 +21,7 @@ impl service::rooms::alias::Data for KeyValueDatabase {
.insert(alias.alias().as_bytes(), room_id.as_bytes())?; .insert(alias.alias().as_bytes(), room_id.as_bytes())?;
let mut aliasid = room_id.as_bytes().to_vec(); let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xff); aliasid.push(0xff);
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); aliasid.extend_from_slice(&services().globals.next_count().await?.to_be_bytes());
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
Ok(()) Ok(())
} }

View file

@ -1,13 +1,15 @@
use std::collections::HashMap; use std::collections::HashMap;
use async_trait::async_trait;
use ruma::{ use ruma::{
events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId, events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId,
}; };
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
#[async_trait]
impl service::rooms::edus::presence::Data for KeyValueDatabase { impl service::rooms::edus::presence::Data for KeyValueDatabase {
fn update_presence( async fn update_presence(
&self, &self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
@ -15,7 +17,7 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
) -> Result<()> { ) -> Result<()> {
// TODO: Remove old entry? Or maybe just wipe completely from time to time? // TODO: Remove old entry? Or maybe just wipe completely from time to time?
let count = services().globals.next_count()?.to_be_bytes(); let count = services().globals.next_count().await?.to_be_bytes();
let mut presence_id = room_id.as_bytes().to_vec(); let mut presence_id = room_id.as_bytes().to_vec();
presence_id.push(0xff); presence_id.push(0xff);

View file

@ -1,13 +1,15 @@
use std::mem; use std::mem;
use async_trait::async_trait;
use ruma::{ use ruma::{
events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId, events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId,
}; };
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
#[async_trait]
impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
fn readreceipt_update( async fn readreceipt_update(
&self, &self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
@ -36,7 +38,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
} }
let mut room_latest_id = prefix; let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); room_latest_id.extend_from_slice(&services().globals.next_count().await?.to_be_bytes());
room_latest_id.push(0xff); room_latest_id.push(0xff);
room_latest_id.extend_from_slice(user_id.as_bytes()); room_latest_id.extend_from_slice(user_id.as_bytes());
@ -106,7 +108,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
) )
} }
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { async fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
@ -115,7 +117,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
.insert(&key, &count.to_be_bytes())?; .insert(&key, &count.to_be_bytes())?;
self.roomuserid_lastprivatereadupdate self.roomuserid_lastprivatereadupdate
.insert(&key, &services().globals.next_count()?.to_be_bytes()) .insert(&key, &services().globals.next_count().await?.to_be_bytes())
} }
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {

View file

@ -1,11 +1,13 @@
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait;
use ruma::{events::StateEventType, EventId, RoomId}; use ruma::{events::StateEventType, EventId, RoomId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
#[async_trait]
impl service::rooms::short::Data for KeyValueDatabase { impl service::rooms::short::Data for KeyValueDatabase {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> { async fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
return Ok(*short); return Ok(*short);
} }
@ -14,7 +16,7 @@ impl service::rooms::short::Data for KeyValueDatabase {
Some(shorteventid) => utils::u64_from_bytes(&shorteventid) Some(shorteventid) => utils::u64_from_bytes(&shorteventid)
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
None => { None => {
let shorteventid = services().globals.next_count()?; let shorteventid = services().globals.next_count().await?;
self.eventid_shorteventid self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid self.shorteventid_eventid
@ -68,7 +70,7 @@ impl service::rooms::short::Data for KeyValueDatabase {
Ok(short) Ok(short)
} }
fn get_or_create_shortstatekey( async fn get_or_create_shortstatekey(
&self, &self,
event_type: &StateEventType, event_type: &StateEventType,
state_key: &str, state_key: &str,
@ -90,7 +92,7 @@ impl service::rooms::short::Data for KeyValueDatabase {
Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
None => { None => {
let shortstatekey = services().globals.next_count()?; let shortstatekey = services().globals.next_count().await?;
self.statekey_shortstatekey self.statekey_shortstatekey
.insert(&statekey, &shortstatekey.to_be_bytes())?; .insert(&statekey, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey self.shortstatekey_statekey
@ -176,7 +178,7 @@ impl service::rooms::short::Data for KeyValueDatabase {
} }
/// Returns (shortstatehash, already_existed) /// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
Ok(match self.statehash_shortstatehash.get(state_hash)? { Ok(match self.statehash_shortstatehash.get(state_hash)? {
Some(shortstatehash) => ( Some(shortstatehash) => (
utils::u64_from_bytes(&shortstatehash) utils::u64_from_bytes(&shortstatehash)
@ -184,7 +186,7 @@ impl service::rooms::short::Data for KeyValueDatabase {
true, true,
), ),
None => { None => {
let shortstatehash = services().globals.next_count()?; let shortstatehash = services().globals.next_count().await?;
self.statehash_shortstatehash self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?; .insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false) (shortstatehash, false)
@ -202,12 +204,12 @@ impl service::rooms::short::Data for KeyValueDatabase {
.transpose() .transpose()
} }
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> { async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
Some(short) => utils::u64_from_bytes(&short) Some(short) => utils::u64_from_bytes(&short)
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
None => { None => {
let short = services().globals.next_count()?; let short = services().globals.next_count().await?;
self.roomid_shortroomid self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?; .insert(room_id.as_bytes(), &short.to_be_bytes())?;
short short

View file

@ -1,5 +1,6 @@
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, sync::Arc};
use async_trait::async_trait;
use ruma::{ use ruma::{
events::{AnyStrippedStateEvent, AnySyncStateEvent}, events::{AnyStrippedStateEvent, AnySyncStateEvent},
serde::Raw, serde::Raw,
@ -12,6 +13,7 @@ use crate::{
services, utils, Error, Result, services, utils, Error, Result,
}; };
#[async_trait]
impl service::rooms::state_cache::Data for KeyValueDatabase { impl service::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
@ -39,7 +41,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn mark_as_invited( async fn mark_as_invited(
&self, &self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
@ -60,7 +62,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
)?; )?;
self.roomuserid_invitecount.insert( self.roomuserid_invitecount.insert(
&roomuser_id, &roomuser_id,
&services().globals.next_count()?.to_be_bytes(), &services().globals.next_count().await?.to_be_bytes(),
)?; )?;
self.userroomid_joined.remove(&userroom_id)?; self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?; self.roomuserid_joined.remove(&roomuser_id)?;
@ -70,7 +72,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { async fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec(); let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff); roomuser_id.push(0xff);
roomuser_id.extend_from_slice(user_id.as_bytes()); roomuser_id.extend_from_slice(user_id.as_bytes());
@ -85,7 +87,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
)?; // TODO )?; // TODO
self.roomuserid_leftcount.insert( self.roomuserid_leftcount.insert(
&roomuser_id, &roomuser_id,
&services().globals.next_count()?.to_be_bytes(), &services().globals.next_count().await?.to_be_bytes(),
)?; )?;
self.userroomid_joined.remove(&userroom_id)?; self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?; self.roomuserid_joined.remove(&roomuser_id)?;

View file

@ -1,9 +1,11 @@
use async_trait::async_trait;
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
#[async_trait]
impl service::rooms::user::Data for KeyValueDatabase { impl service::rooms::user::Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { async fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
@ -18,7 +20,7 @@ impl service::rooms::user::Data for KeyValueDatabase {
self.roomuserid_lastnotificationread.insert( self.roomuserid_lastnotificationread.insert(
&roomuser_id, &roomuser_id,
&services().globals.next_count()?.to_be_bytes(), &services().globals.next_count().await?.to_be_bytes(),
)?; )?;
Ok(()) Ok(())

View file

@ -1,3 +1,4 @@
use async_trait::async_trait;
use ruma::{ServerName, UserId}; use ruma::{ServerName, UserId};
use crate::{ use crate::{
@ -9,6 +10,7 @@ use crate::{
services, utils, Error, Result, services, utils, Error, Result,
}; };
#[async_trait]
impl service::sending::Data for KeyValueDatabase { impl service::sending::Data for KeyValueDatabase {
fn active_requests<'a>( fn active_requests<'a>(
&'a self, &'a self,
@ -59,7 +61,7 @@ impl service::sending::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn queue_requests( async fn queue_requests(
&self, &self,
requests: &[(&OutgoingKind, SendingEventType)], requests: &[(&OutgoingKind, SendingEventType)],
) -> Result<Vec<Vec<u8>>> { ) -> Result<Vec<Vec<u8>>> {
@ -70,7 +72,7 @@ impl service::sending::Data for KeyValueDatabase {
if let SendingEventType::Pdu(value) = &event { if let SendingEventType::Pdu(value) = &event {
key.extend_from_slice(value) key.extend_from_slice(value)
} else { } else {
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()) key.extend_from_slice(&services().globals.next_count().await?.to_be_bytes())
} }
let value = if let SendingEventType::Edu(value) = &event { let value = if let SendingEventType::Edu(value) = &event {
&**value &**value

View file

@ -1,5 +1,6 @@
use std::{collections::BTreeMap, mem::size_of}; use std::{collections::BTreeMap, mem::size_of};
use async_trait::async_trait;
use ruma::{ use ruma::{
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
@ -17,6 +18,7 @@ use crate::{
services, utils, Error, Result, services, utils, Error, Result,
}; };
#[async_trait]
impl service::users::Data for KeyValueDatabase { impl service::users::Data for KeyValueDatabase {
/// Check if a user has an account on this homeserver. /// Check if a user has an account on this homeserver.
fn exists(&self, user_id: &UserId) -> Result<bool> { fn exists(&self, user_id: &UserId) -> Result<bool> {
@ -192,7 +194,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Adds a new device to a user. /// Adds a new device to a user.
fn create_device( async fn create_device(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -207,7 +209,8 @@ impl service::users::Data for KeyValueDatabase {
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
self.userid_devicelistversion self.userid_devicelistversion
.increment(user_id.as_bytes())?; .increment(user_id.as_bytes())
.await?;
self.userdeviceid_metadata.insert( self.userdeviceid_metadata.insert(
&userdeviceid, &userdeviceid,
@ -226,7 +229,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Removes a device from a user. /// Removes a device from a user.
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xff);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
@ -248,7 +251,8 @@ impl service::users::Data for KeyValueDatabase {
// TODO: Remove onetimekeys // TODO: Remove onetimekeys
self.userid_devicelistversion self.userid_devicelistversion
.increment(user_id.as_bytes())?; .increment(user_id.as_bytes())
.await?;
self.userdeviceid_metadata.remove(&userdeviceid)?; self.userdeviceid_metadata.remove(&userdeviceid)?;
@ -304,7 +308,7 @@ impl service::users::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn add_one_time_key( async fn add_one_time_key(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -335,7 +339,7 @@ impl service::users::Data for KeyValueDatabase {
self.userid_lastonetimekeyupdate.insert( self.userid_lastonetimekeyupdate.insert(
user_id.as_bytes(), user_id.as_bytes(),
&services().globals.next_count()?.to_be_bytes(), &services().globals.next_count().await?.to_be_bytes(),
)?; )?;
Ok(()) Ok(())
@ -352,7 +356,7 @@ impl service::users::Data for KeyValueDatabase {
.unwrap_or(Ok(0)) .unwrap_or(Ok(0))
} }
fn take_one_time_key( async fn take_one_time_key(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -368,7 +372,7 @@ impl service::users::Data for KeyValueDatabase {
self.userid_lastonetimekeyupdate.insert( self.userid_lastonetimekeyupdate.insert(
user_id.as_bytes(), user_id.as_bytes(),
&services().globals.next_count()?.to_be_bytes(), &services().globals.next_count().await?.to_be_bytes(),
)?; )?;
self.onetimekeyid_onetimekeys self.onetimekeyid_onetimekeys
@ -423,7 +427,7 @@ impl service::users::Data for KeyValueDatabase {
Ok(counts) Ok(counts)
} }
fn add_device_keys( async fn add_device_keys(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -438,12 +442,12 @@ impl service::users::Data for KeyValueDatabase {
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
)?; )?;
self.mark_device_key_update(user_id)?; self.mark_device_key_update(user_id).await?;
Ok(()) Ok(())
} }
fn add_cross_signing_keys( async fn add_cross_signing_keys(
&self, &self,
user_id: &UserId, user_id: &UserId,
master_key: &Raw<CrossSigningKey>, master_key: &Raw<CrossSigningKey>,
@ -532,13 +536,13 @@ impl service::users::Data for KeyValueDatabase {
} }
if notify { if notify {
self.mark_device_key_update(user_id)?; self.mark_device_key_update(user_id).await?;
} }
Ok(()) Ok(())
} }
fn sign_key( async fn sign_key(
&self, &self,
target_id: &UserId, target_id: &UserId,
key_id: &str, key_id: &str,
@ -574,7 +578,7 @@ impl service::users::Data for KeyValueDatabase {
&serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"),
)?; )?;
self.mark_device_key_update(target_id)?; self.mark_device_key_update(target_id).await?;
Ok(()) Ok(())
} }
@ -623,8 +627,8 @@ impl service::users::Data for KeyValueDatabase {
) )
} }
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { async fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
let count = services().globals.next_count()?.to_be_bytes(); let count = services().globals.next_count().await?.to_be_bytes();
for room_id in services() for room_id in services()
.rooms .rooms
.state_cache .state_cache
@ -761,7 +765,7 @@ impl service::users::Data for KeyValueDatabase {
}) })
} }
fn add_to_device_event( async fn add_to_device_event(
&self, &self,
sender: &UserId, sender: &UserId,
target_user_id: &UserId, target_user_id: &UserId,
@ -773,7 +777,7 @@ impl service::users::Data for KeyValueDatabase {
key.push(0xff); key.push(0xff);
key.extend_from_slice(target_device_id.as_bytes()); key.extend_from_slice(target_device_id.as_bytes());
key.push(0xff); key.push(0xff);
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); key.extend_from_slice(&services().globals.next_count().await?.to_be_bytes());
let mut json = serde_json::Map::new(); let mut json = serde_json::Map::new();
json.insert("type".to_owned(), event_type.to_owned().into()); json.insert("type".to_owned(), event_type.to_owned().into());
@ -843,7 +847,7 @@ impl service::users::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
fn update_device_metadata( async fn update_device_metadata(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -857,7 +861,8 @@ impl service::users::Data for KeyValueDatabase {
assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some());
self.userid_devicelistversion self.userid_devicelistversion
.increment(user_id.as_bytes())?; .increment(user_id.as_bytes())
.await?;
self.userdeviceid_metadata.insert( self.userdeviceid_metadata.insert(
&userdeviceid, &userdeviceid,

View file

@ -680,7 +680,7 @@ impl KeyValueDatabase {
if services().globals.database_version()? < 8 { if services().globals.database_version()? < 8 {
// Generate short room ids for all rooms // Generate short room ids for all rooms
for (room_id, _) in db.roomid_shortstatehash.iter() { for (room_id, _) in db.roomid_shortstatehash.iter() {
let shortroomid = services().globals.next_count()?.to_be_bytes(); let shortroomid = services().globals.next_count().await?.to_be_bytes();
db.roomid_shortroomid.insert(&room_id, &shortroomid)?; db.roomid_shortroomid.insert(&room_id, &shortroomid)?;
info!("Migration: 8"); info!("Migration: 8");
} }
@ -799,7 +799,7 @@ impl KeyValueDatabase {
// Force E2EE device list updates so we can send them over federation // Force E2EE device list updates so we can send them over federation
for user_id in services().users.iter().filter_map(|r| r.ok()) { for user_id in services().users.iter().filter_map(|r| r.ok()) {
services().users.mark_device_key_update(&user_id)?; services().users.mark_device_key_update(&user_id).await?;
} }
services().globals.bump_database_version(10)?; services().globals.bump_database_version(10)?;
@ -811,7 +811,8 @@ impl KeyValueDatabase {
db._db db._db
.open_tree("userdevicesessionid_uiaarequest") .open_tree("userdevicesessionid_uiaarequest")
.await? .await?
.clear()?; .clear()
.await?;
services().globals.bump_database_version(11)?; services().globals.bump_database_version(11)?;
warn!("Migration: 10 -> 11 finished"); warn!("Migration: 10 -> 11 finished");
@ -884,12 +885,16 @@ impl KeyValueDatabase {
} }
} }
services().account_data.update( services()
None, .account_data
&user, .update(
GlobalAccountDataEventType::PushRules.to_string().into(), None,
&serde_json::to_value(account_data).expect("to json value always works"), &user,
)?; GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data)
.expect("to json value always works"),
)
.await?;
} }
services().globals.bump_database_version(12)?; services().globals.bump_database_version(12)?;
@ -930,12 +935,16 @@ impl KeyValueDatabase {
.global .global
.update_with_server_default(user_default_rules); .update_with_server_default(user_default_rules);
services().account_data.update( services()
None, .account_data
&user, .update(
GlobalAccountDataEventType::PushRules.to_string().into(), None,
&serde_json::to_value(account_data).expect("to json value always works"), &user,
)?; GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data)
.expect("to json value always works"),
)
.await?;
} }
services().globals.bump_database_version(13)?; services().globals.bump_database_version(13)?;
@ -969,12 +978,12 @@ impl KeyValueDatabase {
} }
// This data is probably outdated // This data is probably outdated
db.presenceid_presence.clear()?; db.presenceid_presence.clear().await?;
services().admin.start_handler(); services().admin.start_handler();
// Set emergency access for the conduit user // Set emergency access for the conduit user
match set_emergency_access() { match set_emergency_access().await {
Ok(pwd_set) => { Ok(pwd_set) => {
if pwd_set { if pwd_set {
warn!("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!"); warn!("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!");
@ -1107,7 +1116,7 @@ impl KeyValueDatabase {
} }
/// Sets the emergency password and push rules for the @conduit account in case emergency password is set /// Sets the emergency password and push rules for the @conduit account in case emergency password is set
fn set_emergency_access() -> Result<bool> { async fn set_emergency_access() -> Result<bool> {
let conduit_user = services().globals.server_user(); let conduit_user = services().globals.server_user();
services().users.set_password( services().users.set_password(
@ -1120,15 +1129,18 @@ fn set_emergency_access() -> Result<bool> {
None => (Ruleset::new(), Ok(false)), None => (Ruleset::new(), Ok(false)),
}; };
services().account_data.update( services()
None, .account_data
conduit_user, .update(
GlobalAccountDataEventType::PushRules.to_string().into(), None,
&serde_json::to_value(&GlobalAccountDataEvent { conduit_user,
content: PushRulesEventContent { global: ruleset }, GlobalAccountDataEventType::PushRules.to_string().into(),
}) &serde_json::to_value(&GlobalAccountDataEvent {
.expect("to json value always works"), content: PushRulesEventContent { global: ruleset },
)?; })
.expect("to json value always works"),
)
.await?;
res res
} }

View file

@ -1,15 +1,17 @@
use std::collections::HashMap; use std::collections::HashMap;
use crate::Result; use crate::Result;
use async_trait::async_trait;
use ruma::{ use ruma::{
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw, serde::Raw,
RoomId, UserId, RoomId, UserId,
}; };
#[async_trait]
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
/// Places one event in the account data of the user and removes the previous entry. /// Places one event in the account data of the user and removes the previous entry.
fn update( async fn update(
&self, &self,
room_id: Option<&RoomId>, room_id: Option<&RoomId>,
user_id: &UserId, user_id: &UserId,

View file

@ -19,14 +19,14 @@ pub struct Service {
impl Service { impl Service {
/// Places one event in the account data of the user and removes the previous entry. /// Places one event in the account data of the user and removes the previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))] #[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
pub fn update( pub async fn update(
&self, &self,
room_id: Option<&RoomId>, room_id: Option<&RoomId>,
user_id: &UserId, user_id: &UserId,
event_type: RoomAccountDataEventType, event_type: RoomAccountDataEventType,
data: &serde_json::Value, data: &serde_json::Value,
) -> Result<()> { ) -> Result<()> {
self.db.update(room_id, user_id, event_type, data) self.db.update(room_id, user_id, event_type, data).await
} }
/// Searches the account data for a specific kind. /// Searches the account data for a specific kind.

View file

@ -639,19 +639,22 @@ impl Service {
.set_displayname(&user_id, Some(displayname))?; .set_displayname(&user_id, Some(displayname))?;
// Initial account data // Initial account data
services().account_data.update( services()
None, .account_data
&user_id, .update(
ruma::events::GlobalAccountDataEventType::PushRules None,
.to_string() &user_id,
.into(), ruma::events::GlobalAccountDataEventType::PushRules
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent { .to_string()
content: ruma::events::push_rules::PushRulesEventContent { .into(),
global: ruma::push::Ruleset::server_default(&user_id), &serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
}, content: ruma::events::push_rules::PushRulesEventContent {
}) global: ruma::push::Ruleset::server_default(&user_id),
.expect("to json value always works"), },
)?; })
.expect("to json value always works"),
)
.await?;
// we dont add a device since we're not the user, just the creator // we dont add a device since we're not the user, just the creator
@ -704,7 +707,7 @@ impl Service {
"Making {user_id} leave all rooms before deactivation..." "Making {user_id} leave all rooms before deactivation..."
)); ));
services().users.deactivate_account(&user_id)?; services().users.deactivate_account(&user_id).await?;
if leave_rooms { if leave_rooms {
leave_all_rooms(&user_id).await?; leave_all_rooms(&user_id).await?;
@ -800,7 +803,7 @@ impl Service {
} }
for &user_id in &user_ids { for &user_id in &user_ids {
if services().users.deactivate_account(user_id).is_ok() { if services().users.deactivate_account(user_id).await.is_ok() {
deactivation_count += 1 deactivation_count += 1
} }
} }
@ -1057,7 +1060,11 @@ impl Service {
pub(crate) async fn create_admin_room(&self) -> Result<()> { pub(crate) async fn create_admin_room(&self) -> Result<()> {
let room_id = RoomId::new(services().globals.server_name()); let room_id = RoomId::new(services().globals.server_name());
services().rooms.short.get_or_create_shortroomid(&room_id)?; services()
.rooms
.short
.get_or_create_shortroomid(&room_id)
.await?;
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
services() services()
@ -1293,7 +1300,8 @@ impl Service {
services() services()
.rooms .rooms
.alias .alias
.set_alias(&alias, &room_id, conduit_user)?; .set_alias(&alias, &room_id, conduit_user)
.await?;
Ok(()) Ok(())
} }

View file

@ -69,7 +69,7 @@ impl From<ServerSigningKeys> for SigningKeys {
#[async_trait] #[async_trait]
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
fn next_count(&self) -> Result<u64>; async fn next_count(&self) -> Result<u64>;
fn current_count(&self) -> Result<u64>; fn current_count(&self) -> Result<u64>;
fn last_check_for_updates_id(&self) -> Result<u64>; fn last_check_for_updates_id(&self) -> Result<u64>;
fn update_check_for_updates_id(&self, id: u64) -> Result<()>; fn update_check_for_updates_id(&self, id: u64) -> Result<()>;

View file

@ -258,8 +258,8 @@ impl Service {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn next_count(&self) -> Result<u64> { pub async fn next_count(&self) -> Result<u64> {
self.db.next_count() self.db.next_count().await
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]

View file

@ -1,14 +1,16 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use crate::Result; use crate::Result;
use async_trait::async_trait;
use ruma::{ use ruma::{
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
serde::Raw, serde::Raw,
OwnedRoomId, RoomId, UserId, OwnedRoomId, RoomId, UserId,
}; };
#[async_trait]
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
fn create_backup( async fn create_backup(
&self, &self,
user_id: &UserId, user_id: &UserId,
backup_metadata: &Raw<BackupAlgorithm>, backup_metadata: &Raw<BackupAlgorithm>,
@ -16,7 +18,7 @@ pub trait Data: Send + Sync {
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>; fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>;
fn update_backup( async fn update_backup(
&self, &self,
user_id: &UserId, user_id: &UserId,
version: &str, version: &str,
@ -30,7 +32,7 @@ pub trait Data: Send + Sync {
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>; fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>;
fn add_key( async fn add_key(
&self, &self,
user_id: &UserId, user_id: &UserId,
version: &str, version: &str,

View file

@ -14,25 +14,27 @@ pub struct Service {
} }
impl Service { impl Service {
pub fn create_backup( pub async fn create_backup(
&self, &self,
user_id: &UserId, user_id: &UserId,
backup_metadata: &Raw<BackupAlgorithm>, backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> { ) -> Result<String> {
self.db.create_backup(user_id, backup_metadata) self.db.create_backup(user_id, backup_metadata).await
} }
pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
self.db.delete_backup(user_id, version) self.db.delete_backup(user_id, version)
} }
pub fn update_backup( pub async fn update_backup(
&self, &self,
user_id: &UserId, user_id: &UserId,
version: &str, version: &str,
backup_metadata: &Raw<BackupAlgorithm>, backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> { ) -> Result<String> {
self.db.update_backup(user_id, version, backup_metadata) self.db
.update_backup(user_id, version, backup_metadata)
.await
} }
pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> { pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
@ -54,7 +56,7 @@ impl Service {
self.db.get_backup(user_id, version) self.db.get_backup(user_id, version)
} }
pub fn add_key( pub async fn add_key(
&self, &self,
user_id: &UserId, user_id: &UserId,
version: &str, version: &str,
@ -64,6 +66,7 @@ impl Service {
) -> Result<()> { ) -> Result<()> {
self.db self.db
.add_key(user_id, version, room_id, session_id, key_data) .add_key(user_id, version, room_id, session_id, key_data)
.await
} }
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {

View file

@ -1,9 +1,16 @@
use crate::Result; use crate::Result;
use async_trait::async_trait;
use ruma::{OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId}; use ruma::{OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId};
#[async_trait]
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
/// Creates or updates the alias to the given room id. /// Creates or updates the alias to the given room id.
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()>; async fn set_alias(
&self,
alias: &RoomAliasId,
room_id: &RoomId,
user_id: &UserId,
) -> Result<()>;
/// Finds the user who assigned the given alias to a room /// Finds the user who assigned the given alias to a room
fn who_created_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedUserId>>; fn who_created_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedUserId>>;

View file

@ -19,7 +19,12 @@ pub struct Service {
impl Service { impl Service {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { pub async fn set_alias(
&self,
alias: &RoomAliasId,
room_id: &RoomId,
user_id: &UserId,
) -> Result<()> {
if alias == services().globals.admin_alias() && user_id != services().globals.server_user() if alias == services().globals.admin_alias() && user_id != services().globals.server_user()
{ {
Err(Error::BadRequest( Err(Error::BadRequest(
@ -27,7 +32,7 @@ impl Service {
"Only the server user can set this alias", "Only the server user can set this alias",
)) ))
} else { } else {
self.db.set_alias(alias, room_id, user_id) self.db.set_alias(alias, room_id, user_id).await
} }
} }

View file

@ -36,7 +36,11 @@ impl Service {
let mut i = 0; let mut i = 0;
for id in starting_events { for id in starting_events {
let short = services().rooms.short.get_or_create_shorteventid(&id)?; let short = services()
.rooms
.short
.get_or_create_shorteventid(&id)
.await?;
let bucket_id = (short % NUM_BUCKETS as u64) as usize; let bucket_id = (short % NUM_BUCKETS as u64) as usize;
buckets[bucket_id].insert((short, id.clone())); buckets[bucket_id].insert((short, id.clone()));
i += 1; i += 1;
@ -80,7 +84,7 @@ impl Service {
chunk_cache.extend(cached.iter().copied()); chunk_cache.extend(cached.iter().copied());
} else { } else {
misses2 += 1; misses2 += 1;
let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?); let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id).await?);
services() services()
.rooms .rooms
.auth_chain .auth_chain
@ -125,7 +129,11 @@ impl Service {
} }
#[tracing::instrument(skip(self, event_id))] #[tracing::instrument(skip(self, event_id))]
fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> { async fn get_auth_chain_inner(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<HashSet<u64>> {
let mut todo = vec![Arc::from(event_id)]; let mut todo = vec![Arc::from(event_id)];
let mut found = HashSet::new(); let mut found = HashSet::new();
@ -142,7 +150,8 @@ impl Service {
let sauthevent = services() let sauthevent = services()
.rooms .rooms
.short .short
.get_or_create_shorteventid(auth_event)?; .get_or_create_shorteventid(auth_event)
.await?;
if !found.contains(&sauthevent) { if !found.contains(&sauthevent) {
found.insert(sauthevent); found.insert(sauthevent);

View file

@ -1,14 +1,16 @@
use std::collections::HashMap; use std::collections::HashMap;
use crate::Result; use crate::Result;
use async_trait::async_trait;
use ruma::{events::presence::PresenceEvent, OwnedUserId, RoomId, UserId}; use ruma::{events::presence::PresenceEvent, OwnedUserId, RoomId, UserId};
#[async_trait]
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
/// Adds a presence event which will be saved until a new event replaces it. /// Adds a presence event which will be saved until a new event replaces it.
/// ///
/// Note: This method takes a RoomId because presence updates are always bound to rooms to /// Note: This method takes a RoomId because presence updates are always bound to rooms to
/// make sure users outside these rooms can't see them. /// make sure users outside these rooms can't see them.
fn update_presence( async fn update_presence(
&self, &self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,

View file

@ -1,9 +1,11 @@
use crate::Result; use crate::Result;
use async_trait::async_trait;
use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId};
#[async_trait]
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
/// Replaces the previous read receipt. /// Replaces the previous read receipt.
fn readreceipt_update( async fn readreceipt_update(
&self, &self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
@ -28,7 +30,7 @@ pub trait Data: Send + Sync {
>; >;
/// Sets a private read marker at `count`. /// Sets a private read marker at `count`.
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>; async fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>;
/// Returns the private read marker. /// Returns the private read marker.
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>>; fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>>;

View file

@ -11,13 +11,13 @@ pub struct Service {
impl Service { impl Service {
/// Replaces the previous read receipt. /// Replaces the previous read receipt.
pub fn readreceipt_update( pub async fn readreceipt_update(
&self, &self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
event: ReceiptEvent, event: ReceiptEvent,
) -> Result<()> { ) -> Result<()> {
self.db.readreceipt_update(user_id, room_id, event) self.db.readreceipt_update(user_id, room_id, event).await
} }
/// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`.
@ -38,8 +38,13 @@ impl Service {
/// Sets a private read marker at `count`. /// Sets a private read marker at `count`.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { pub async fn private_read_set(
self.db.private_read_set(room_id, user_id, count) &self,
room_id: &RoomId,
user_id: &UserId,
count: u64,
) -> Result<()> {
self.db.private_read_set(room_id, user_id, count).await
} }
/// Returns the private read marker. /// Returns the private read marker.

View file

@ -23,7 +23,7 @@ impl Service {
self.last_typing_update self.last_typing_update
.write() .write()
.await .await
.insert(room_id.to_owned(), services().globals.next_count()?); .insert(room_id.to_owned(), services().globals.next_count().await?);
let _ = self.typing_update_sender.send(room_id.to_owned()); let _ = self.typing_update_sender.send(room_id.to_owned());
Ok(()) Ok(())
} }
@ -39,7 +39,7 @@ impl Service {
self.last_typing_update self.last_typing_update
.write() .write()
.await .await
.insert(room_id.to_owned(), services().globals.next_count()?); .insert(room_id.to_owned(), services().globals.next_count().await?);
let _ = self.typing_update_sender.send(room_id.to_owned()); let _ = self.typing_update_sender.send(room_id.to_owned());
Ok(()) Ok(())
} }
@ -80,7 +80,7 @@ impl Service {
self.last_typing_update self.last_typing_update
.write() .write()
.await .await
.insert(room_id.to_owned(), services().globals.next_count()?); .insert(room_id.to_owned(), services().globals.next_count().await?);
let _ = self.typing_update_sender.send(room_id.to_owned()); let _ = self.typing_update_sender.send(room_id.to_owned());
} }
Ok(()) Ok(())

View file

@ -589,10 +589,11 @@ impl Service {
})?; })?;
if let Some(state_key) = &prev_pdu.state_key { if let Some(state_key) = &prev_pdu.state_key {
let shortstatekey = services().rooms.short.get_or_create_shortstatekey( let shortstatekey = services()
&prev_pdu.kind.to_string().into(), .rooms
state_key, .short
)?; .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)
.await?;
state.insert(shortstatekey, Arc::from(prev_event)); state.insert(shortstatekey, Arc::from(prev_event));
// Now it's the state after the pdu // Now it's the state after the pdu
@ -640,10 +641,14 @@ impl Service {
.await?; .await?;
if let Some(state_key) = &prev_event.state_key { if let Some(state_key) = &prev_event.state_key {
let shortstatekey = services().rooms.short.get_or_create_shortstatekey( let shortstatekey = services()
&prev_event.kind.to_string().into(), .rooms
state_key, .short
)?; .get_or_create_shortstatekey(
&prev_event.kind.to_string().into(),
state_key,
)
.await?;
leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id));
// Now it's the state after the pdu // Now it's the state after the pdu
} }
@ -677,34 +682,38 @@ impl Service {
let lock = services().globals.stateres_mutex.lock(); let lock = services().globals.stateres_mutex.lock();
let result = let new_state =
state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| {
let res = services().rooms.timeline.get_pdu(id); let res = services().rooms.timeline.get_pdu(id);
if let Err(e) = &res { if let Err(e) = &res {
error!("LOOK AT ME Failed to fetch event: {}", e); error!("LOOK AT ME Failed to fetch event: {}", e);
} }
res.ok().flatten() res.ok().flatten()
}); })
.map_err(|e| {
warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e);
e
})
.ok();
drop(lock); drop(lock);
state_at_incoming_event = match result { state_at_incoming_event = match new_state {
Ok(new_state) => Some( Some(new_state) => {
new_state let mut state_at_incoming_event = HashMap::with_capacity(new_state.len());
.into_iter() for ((event_type, state_key), event_id) in new_state {
.map(|((event_type, state_key), event_id)| { let shortstatekey = services()
let shortstatekey = .rooms
services().rooms.short.get_or_create_shortstatekey( .short
&event_type.to_string().into(), .get_or_create_shortstatekey(
&state_key, &event_type.to_string().into(),
)?; &state_key,
Ok((shortstatekey, event_id)) )
}) .await?;
.collect::<Result<_>>()?, state_at_incoming_event.insert(shortstatekey, event_id);
), }
Err(e) => { Some(state_at_incoming_event)
warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e);
None
} }
None => None,
} }
} }
} }
@ -748,10 +757,11 @@ impl Service {
Error::bad_database("Found non-state pdu in state events.") Error::bad_database("Found non-state pdu in state events.")
})?; })?;
let shortstatekey = services().rooms.short.get_or_create_shortstatekey( let shortstatekey = services()
&pdu.kind.to_string().into(), .rooms
&state_key, .short
)?; .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)
.await?;
match state.entry(shortstatekey) { match state.entry(shortstatekey) {
hash_map::Entry::Vacant(v) => { hash_map::Entry::Vacant(v) => {
@ -915,17 +925,17 @@ impl Service {
}); });
debug!("Compressing state at event"); debug!("Compressing state at event");
let state_ids_compressed = Arc::new( let mut state_ids_compressed = HashSet::new();
state_at_incoming_event for (shortstatekey, id) in &state_at_incoming_event {
.iter() state_ids_compressed.insert(
.map(|(shortstatekey, id)| { services()
services() .rooms
.rooms .state_compressor
.state_compressor .compress_state_event(*shortstatekey, id)
.compress_state_event(*shortstatekey, id) .await?,
}) );
.collect::<Result<_>>()?, }
); let state_ids_compressed = Arc::new(state_ids_compressed);
if incoming_pdu.state_key.is_some() { if incoming_pdu.state_key.is_some() {
debug!("Preparing for stateres to derive new room state"); debug!("Preparing for stateres to derive new room state");
@ -933,10 +943,11 @@ impl Service {
// We also add state after incoming event to the fork states // We also add state after incoming event to the fork states
let mut state_after = state_at_incoming_event.clone(); let mut state_after = state_at_incoming_event.clone();
if let Some(state_key) = &incoming_pdu.state_key { if let Some(state_key) = &incoming_pdu.state_key {
let shortstatekey = services().rooms.short.get_or_create_shortstatekey( let shortstatekey = services()
&incoming_pdu.kind.to_string().into(), .rooms
state_key, .short
)?; .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)
.await?;
state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id));
} }
@ -951,7 +962,8 @@ impl Service {
let (sstatehash, new, removed) = services() let (sstatehash, new, removed) = services()
.rooms .rooms
.state_compressor .state_compressor
.save_state(room_id, new_room_state)?; .save_state(room_id, new_room_state)
.await?;
services() services()
.rooms .rooms
@ -1078,35 +1090,32 @@ impl Service {
}; };
let lock = services().globals.stateres_mutex.lock(); let lock = services().globals.stateres_mutex.lock();
let state = match state_res::resolve( let state = state_res::resolve(
room_version_id, room_version_id,
&fork_states, &fork_states,
auth_chain_sets, auth_chain_sets,
fetch_event, fetch_event,
) { ).map_err(|_| Error::bad_database("State resolution failed, either an event could not be found or deserialization"))?;
Ok(new_state) => new_state,
Err(_) => {
return Err(Error::bad_database("State resolution failed, either an event could not be found or deserialization"));
}
};
drop(lock); drop(lock);
debug!("State resolution done. Compressing state"); debug!("State resolution done. Compressing state");
let new_room_state = state let mut new_room_state = HashSet::new();
.into_iter() for ((event_type, state_key), event_id) in state {
.map(|((event_type, state_key), event_id)| { let shortstatekey = services()
let shortstatekey = services() .rooms
.rooms .short
.short .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)
.get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; .await?;
new_room_state.insert(
services() services()
.rooms .rooms
.state_compressor .state_compressor
.compress_state_event(shortstatekey, &event_id) .compress_state_event(shortstatekey, &event_id)
}) .await?,
.collect::<Result<_>>()?; );
}
Ok(Arc::new(new_room_state)) Ok(Arc::new(new_room_state))
} }

View file

@ -41,7 +41,7 @@ impl Service {
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn paginate_relations_with_filter( pub async fn paginate_relations_with_filter(
&self, &self,
sender_user: &UserId, sender_user: &UserId,
room_id: &RoomId, room_id: &RoomId,
@ -77,13 +77,11 @@ impl Service {
match dir { match dir {
Direction::Forward => { Direction::Forward => {
let relations_until = &services().rooms.pdu_metadata.relations_until( let relations_until = &services()
sender_user, .rooms
room_id, .pdu_metadata
target, .relations_until(sender_user, room_id, target, from, depth)
from, .await?;
depth,
)?;
let events_after: Vec<_> = relations_until // TODO: should be relations_after let events_after: Vec<_> = relations_until // TODO: should be relations_after
.iter() .iter()
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
@ -125,13 +123,11 @@ impl Service {
}) })
} }
Direction::Backward => { Direction::Backward => {
let relations_until = &services().rooms.pdu_metadata.relations_until( let relations_until = &services()
sender_user, .rooms
room_id, .pdu_metadata
target, .relations_until(sender_user, room_id, target, from, depth)
from, .await?;
depth,
)?;
let events_before: Vec<_> = relations_until let events_before: Vec<_> = relations_until
.iter() .iter()
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
@ -174,7 +170,7 @@ impl Service {
} }
} }
pub fn relations_until<'a>( pub async fn relations_until<'a>(
&'a self, &'a self,
user_id: &'a UserId, user_id: &'a UserId,
room_id: &'a RoomId, room_id: &'a RoomId,
@ -182,7 +178,11 @@ impl Service {
until: PduCount, until: PduCount,
max_depth: u8, max_depth: u8,
) -> Result<Vec<(PduCount, PduEvent)>> { ) -> Result<Vec<(PduCount, PduEvent)>> {
let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; let room_id = services()
.rooms
.short
.get_or_create_shortroomid(room_id)
.await?;
let target = match services().rooms.timeline.get_pdu_count(target)? { let target = match services().rooms.timeline.get_pdu_count(target)? {
Some(PduCount::Normal(c)) => c, Some(PduCount::Normal(c)) => c,
// TODO: Support backfilled relations // TODO: Support backfilled relations

View file

@ -1,10 +1,12 @@
use std::sync::Arc; use std::sync::Arc;
use crate::Result; use crate::Result;
use async_trait::async_trait;
use ruma::{events::StateEventType, EventId, RoomId}; use ruma::{events::StateEventType, EventId, RoomId};
#[async_trait]
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64>; async fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64>;
fn get_shortstatekey( fn get_shortstatekey(
&self, &self,
@ -12,7 +14,7 @@ pub trait Data: Send + Sync {
state_key: &str, state_key: &str,
) -> Result<Option<u64>>; ) -> Result<Option<u64>>;
fn get_or_create_shortstatekey( async fn get_or_create_shortstatekey(
&self, &self,
event_type: &StateEventType, event_type: &StateEventType,
state_key: &str, state_key: &str,
@ -23,9 +25,9 @@ pub trait Data: Send + Sync {
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>; fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>;
/// Returns (shortstatehash, already_existed) /// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>; async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>;
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>>; fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>>;
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64>; async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64>;
} }

View file

@ -11,8 +11,8 @@ pub struct Service {
} }
impl Service { impl Service {
pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> { pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
self.db.get_or_create_shorteventid(event_id) self.db.get_or_create_shorteventid(event_id).await
} }
pub fn get_shortstatekey( pub fn get_shortstatekey(
@ -23,12 +23,14 @@ impl Service {
self.db.get_shortstatekey(event_type, state_key) self.db.get_shortstatekey(event_type, state_key)
} }
pub fn get_or_create_shortstatekey( pub async fn get_or_create_shortstatekey(
&self, &self,
event_type: &StateEventType, event_type: &StateEventType,
state_key: &str, state_key: &str,
) -> Result<u64> { ) -> Result<u64> {
self.db.get_or_create_shortstatekey(event_type, state_key) self.db
.get_or_create_shortstatekey(event_type, state_key)
.await
} }
pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
@ -40,15 +42,15 @@ impl Service {
} }
/// Returns (shortstatehash, already_existed) /// Returns (shortstatehash, already_existed)
pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
self.db.get_or_create_shortstatehash(state_hash) self.db.get_or_create_shortstatehash(state_hash).await
} }
pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> { pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.db.get_shortroomid(room_id) self.db.get_shortroomid(room_id)
} }
pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> { pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
self.db.get_or_create_shortroomid(room_id) self.db.get_or_create_shortroomid(room_id).await
} }
} }

View file

@ -80,14 +80,11 @@ impl Service {
Err(_) => continue, Err(_) => continue,
}; };
services().rooms.state_cache.update_membership( services()
room_id, .rooms
&user_id, .state_cache
membership, .update_membership(room_id, &user_id, membership, &pdu.sender, None, false)
&pdu.sender, .await?;
None,
false,
)?;
} }
TimelineEventType::SpaceChild => { TimelineEventType::SpaceChild => {
services() services()
@ -115,7 +112,7 @@ impl Service {
/// This adds all current state events (not including the incoming event) /// This adds all current state events (not including the incoming event)
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, state_ids_compressed))] #[tracing::instrument(skip(self, state_ids_compressed))]
pub fn set_event_state( pub async fn set_event_state(
&self, &self,
event_id: &EventId, event_id: &EventId,
room_id: &RoomId, room_id: &RoomId,
@ -124,7 +121,8 @@ impl Service {
let shorteventid = services() let shorteventid = services()
.rooms .rooms
.short .short
.get_or_create_shorteventid(event_id)?; .get_or_create_shorteventid(event_id)
.await?;
let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?;
@ -138,7 +136,8 @@ impl Service {
let (shortstatehash, already_existed) = services() let (shortstatehash, already_existed) = services()
.rooms .rooms
.short .short
.get_or_create_shortstatehash(&state_hash)?; .get_or_create_shortstatehash(&state_hash)
.await?;
if !already_existed { if !already_existed {
let states_parents = previous_shortstatehash.map_or_else( let states_parents = previous_shortstatehash.map_or_else(
@ -187,11 +186,12 @@ impl Service {
/// This adds all current state events (not including the incoming event) /// This adds all current state events (not including the incoming event)
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, new_pdu))] #[tracing::instrument(skip(self, new_pdu))]
pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
let shorteventid = services() let shorteventid = services()
.rooms .rooms
.short .short
.get_or_create_shorteventid(&new_pdu.event_id)?; .get_or_create_shorteventid(&new_pdu.event_id)
.await?;
let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?;
@ -213,12 +213,14 @@ impl Service {
let shortstatekey = services() let shortstatekey = services()
.rooms .rooms
.short .short
.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)
.await?;
let new = services() let new = services()
.rooms .rooms
.state_compressor .state_compressor
.compress_state_event(shortstatekey, &new_pdu.event_id)?; .compress_state_event(shortstatekey, &new_pdu.event_id)
.await?;
let replaces = states_parents let replaces = states_parents
.last() .last()
@ -234,7 +236,7 @@ impl Service {
} }
// TODO: statehash with deterministic inputs // TODO: statehash with deterministic inputs
let shortstatehash = services().globals.next_count()?; let shortstatehash = services().globals.next_count().await?;
let mut statediffnew = HashSet::new(); let mut statediffnew = HashSet::new();
statediffnew.insert(new); statediffnew.insert(new);

View file

@ -327,6 +327,7 @@ impl Service {
.rooms .rooms
.timeline .timeline
.create_hash_and_sign_event(new_event, sender, room_id, state_lock) .create_hash_and_sign_event(new_event, sender, room_id, state_lock)
.await
.is_ok()) .is_ok())
} }

View file

@ -1,22 +1,24 @@
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, sync::Arc};
use crate::{service::appservice::RegistrationInfo, Result}; use crate::{service::appservice::RegistrationInfo, Result};
use async_trait::async_trait;
use ruma::{ use ruma::{
events::{AnyStrippedStateEvent, AnySyncStateEvent}, events::{AnyStrippedStateEvent, AnySyncStateEvent},
serde::Raw, serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
}; };
#[async_trait]
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
fn mark_as_invited( async fn mark_as_invited(
&self, &self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
) -> Result<()>; ) -> Result<()>;
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; async fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; fn update_joined_count(&self, room_id: &RoomId) -> Result<()>;

View file

@ -25,7 +25,7 @@ pub struct Service {
impl Service { impl Service {
/// Update current membership data. /// Update current membership data.
#[tracing::instrument(skip(self, last_state))] #[tracing::instrument(skip(self, last_state))]
pub fn update_membership( pub async fn update_membership(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
user_id: &UserId, user_id: &UserId,
@ -103,6 +103,7 @@ impl Service {
RoomAccountDataEventType::Tag, RoomAccountDataEventType::Tag,
&tag_event?, &tag_event?,
) )
.await
.ok(); .ok();
}; };
@ -132,13 +133,16 @@ impl Service {
} }
if room_ids_updated { if room_ids_updated {
services().account_data.update( services()
None, .account_data
user_id, .update(
GlobalAccountDataEventType::Direct.to_string().into(), None,
&serde_json::to_value(&direct_event) user_id,
.expect("to json always works"), GlobalAccountDataEventType::Direct.to_string().into(),
)?; &serde_json::to_value(&direct_event)
.expect("to json always works"),
)
.await?;
} }
}; };
} }
@ -176,10 +180,12 @@ impl Service {
return Ok(()); return Ok(());
} }
self.db.mark_as_invited(user_id, room_id, last_state)?; self.db
.mark_as_invited(user_id, room_id, last_state)
.await?;
} }
MembershipState::Leave | MembershipState::Ban => { MembershipState::Leave | MembershipState::Ban => {
self.db.mark_as_left(user_id, room_id)?; self.db.mark_as_left(user_id, room_id).await?;
} }
_ => {} _ => {}
} }

View file

@ -89,7 +89,7 @@ impl Service {
} }
} }
pub fn compress_state_event( pub async fn compress_state_event(
&self, &self,
shortstatekey: u64, shortstatekey: u64,
event_id: &EventId, event_id: &EventId,
@ -99,7 +99,8 @@ impl Service {
&services() &services()
.rooms .rooms
.short .short
.get_or_create_shorteventid(event_id)? .get_or_create_shorteventid(event_id)
.await?
.to_be_bytes(), .to_be_bytes(),
); );
Ok(v.try_into().expect("we checked the size above")) Ok(v.try_into().expect("we checked the size above"))
@ -257,7 +258,7 @@ impl Service {
/// Returns the new shortstatehash, and the state diff from the previous room state /// Returns the new shortstatehash, and the state diff from the previous room state
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub fn save_state( pub async fn save_state(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
@ -278,7 +279,8 @@ impl Service {
let (new_shortstatehash, already_existed) = services() let (new_shortstatehash, already_existed) = services()
.rooms .rooms
.short .short
.get_or_create_shortstatehash(&state_hash)?; .get_or_create_shortstatehash(&state_hash)
.await?;
if Some(new_shortstatehash) == previous_shortstatehash { if Some(new_shortstatehash) == previous_shortstatehash {
return Ok(( return Ok((

View file

@ -267,20 +267,22 @@ impl Service {
); );
let insert_lock = mutex_insert.lock().await; let insert_lock = mutex_insert.lock().await;
let count1 = services().globals.next_count()?; let count1 = services().globals.next_count().await?;
// Mark as read first so the sending client doesn't get a notification even if appending // Mark as read first so the sending client doesn't get a notification even if appending
// fails // fails
services() services()
.rooms .rooms
.edus .edus
.read_receipt .read_receipt
.private_read_set(&pdu.room_id, &pdu.sender, count1)?; .private_read_set(&pdu.room_id, &pdu.sender, count1)
.await?;
services() services()
.rooms .rooms
.user .user
.reset_notification_counts(&pdu.sender, &pdu.room_id)?; .reset_notification_counts(&pdu.sender, &pdu.room_id)
.await?;
let count2 = services().globals.next_count()?; let count2 = services().globals.next_count().await?;
let mut pdu_id = shortroomid.to_be_bytes().to_vec(); let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.extend_from_slice(&count2.to_be_bytes()); pdu_id.extend_from_slice(&count2.to_be_bytes());
@ -373,7 +375,10 @@ impl Service {
} }
for push_key in services().pusher.get_pushkeys(user) { for push_key in services().pusher.get_pushkeys(user) {
services().sending.send_push_pdu(&pdu_id, user, push_key?)?; services()
.sending
.send_push_pdu(&pdu_id, user, push_key?)
.await?;
} }
} }
@ -460,14 +465,18 @@ impl Service {
// Update our membership info, we do this here incase a user is invited // Update our membership info, we do this here incase a user is invited
// and immediately leaves we need the DB to record the invite event for auth // and immediately leaves we need the DB to record the invite event for auth
services().rooms.state_cache.update_membership( services()
&pdu.room_id, .rooms
&target_user_id, .state_cache
content.membership, .update_membership(
&pdu.sender, &pdu.room_id,
invite_state, &target_user_id,
true, content.membership,
)?; &pdu.sender,
invite_state,
true,
)
.await?;
} }
} }
TimelineEventType::RoomMessage => { TimelineEventType::RoomMessage => {
@ -578,7 +587,8 @@ impl Service {
{ {
services() services()
.sending .sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())
.await?;
continue; continue;
} }
@ -592,10 +602,10 @@ impl Service {
{ {
let appservice_uid = appservice.registration.sender_localpart.as_str(); let appservice_uid = appservice.registration.sender_localpart.as_str();
if state_key_uid == appservice_uid { if state_key_uid == appservice_uid {
services().sending.send_pdu_appservice( services()
appservice.registration.id.clone(), .sending
pdu_id.clone(), .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())
)?; .await?;
continue; continue;
} }
} }
@ -645,14 +655,15 @@ impl Service {
{ {
services() services()
.sending .sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())
.await?;
} }
} }
Ok(pdu_id) Ok(pdu_id)
} }
pub fn create_hash_and_sign_event( pub async fn create_hash_and_sign_event(
&self, &self,
pdu_builder: PduBuilder, pdu_builder: PduBuilder,
sender: &UserId, sender: &UserId,
@ -827,7 +838,8 @@ impl Service {
let _shorteventid = services() let _shorteventid = services()
.rooms .rooms
.short .short
.get_or_create_shorteventid(&pdu.event_id)?; .get_or_create_shorteventid(&pdu.event_id)
.await?;
Ok((pdu, pdu_json)) Ok((pdu, pdu_json))
} }
@ -842,8 +854,9 @@ impl Service {
room_id: &RoomId, room_id: &RoomId,
state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<Arc<EventId>> { ) -> Result<Arc<EventId>> {
let (pdu, pdu_json) = let (pdu, pdu_json) = self
self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; .create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)
.await?;
if let Some(admin_room) = services().admin.get_admin_room()? { if let Some(admin_room) = services().admin.get_admin_room()? {
if admin_room == room_id { if admin_room == room_id {
@ -986,7 +999,7 @@ impl Service {
// We append to state before appending the pdu, so we don't have a moment in time with the // We append to state before appending the pdu, so we don't have a moment in time with the
// pdu without it's state. This is okay because append_pdu can't fail. // pdu without it's state. This is okay because append_pdu can't fail.
let statehashid = services().rooms.state.append_to_state(&pdu)?; let statehashid = services().rooms.state.append_to_state(&pdu).await?;
let pdu_id = self let pdu_id = self
.append_pdu( .append_pdu(
@ -1027,7 +1040,10 @@ impl Service {
// Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above
servers.remove(services().globals.server_name()); servers.remove(services().globals.server_name());
services().sending.send_pdu(servers.into_iter(), &pdu_id)?; services()
.sending
.send_pdu(servers.into_iter(), &pdu_id)
.await?;
Ok(pdu.event_id) Ok(pdu.event_id)
} }
@ -1046,11 +1062,11 @@ impl Service {
) -> Result<Option<Vec<u8>>> { ) -> Result<Option<Vec<u8>>> {
// We append to state before appending the pdu, so we don't have a moment in time with the // We append to state before appending the pdu, so we don't have a moment in time with the
// pdu without it's state. This is okay because append_pdu can't fail. // pdu without it's state. This is okay because append_pdu can't fail.
services().rooms.state.set_event_state( services()
&pdu.event_id, .rooms
&pdu.room_id, .state
state_ids_compressed, .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)
)?; .await?;
if soft_fail { if soft_fail {
services() services()
@ -1264,7 +1280,7 @@ impl Service {
); );
let insert_lock = mutex_insert.lock().await; let insert_lock = mutex_insert.lock().await;
let count = services().globals.next_count()?; let count = services().globals.next_count().await?;
let mut pdu_id = shortroomid.to_be_bytes().to_vec(); let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.extend_from_slice(&0_u64.to_be_bytes()); pdu_id.extend_from_slice(&0_u64.to_be_bytes());
pdu_id.extend_from_slice(&(u64::MAX - count).to_be_bytes()); pdu_id.extend_from_slice(&(u64::MAX - count).to_be_bytes());

View file

@ -1,8 +1,10 @@
use crate::Result; use crate::Result;
use async_trait::async_trait;
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
#[async_trait]
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; async fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>; fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>;

View file

@ -10,8 +10,12 @@ pub struct Service {
} }
impl Service { impl Service {
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { pub async fn reset_notification_counts(
self.db.reset_notification_counts(user_id, room_id) &self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<()> {
self.db.reset_notification_counts(user_id, room_id).await
} }
pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {

View file

@ -1,9 +1,11 @@
use async_trait::async_trait;
use ruma::ServerName; use ruma::ServerName;
use crate::Result; use crate::Result;
use super::{OutgoingKind, SendingEventType}; use super::{OutgoingKind, SendingEventType};
#[async_trait]
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
fn active_requests<'a>( fn active_requests<'a>(
@ -16,7 +18,7 @@ pub trait Data: Send + Sync {
fn delete_active_request(&self, key: Vec<u8>) -> Result<()>; fn delete_active_request(&self, key: Vec<u8>) -> Result<()>;
fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>;
fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>;
fn queue_requests( async fn queue_requests(
&self, &self,
requests: &[(&OutgoingKind, SendingEventType)], requests: &[(&OutgoingKind, SendingEventType)],
) -> Result<Vec<Vec<u8>>>; ) -> Result<Vec<Vec<u8>>>;

View file

@ -370,10 +370,13 @@ impl Service {
} }
#[tracing::instrument(skip(self, pdu_id, user, pushkey))] #[tracing::instrument(skip(self, pdu_id, user, pushkey))]
pub fn send_push_pdu(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { pub async fn send_push_pdu(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> {
let outgoing_kind = OutgoingKind::Push(user.to_owned(), pushkey); let outgoing_kind = OutgoingKind::Push(user.to_owned(), pushkey);
let event = SendingEventType::Pdu(pdu_id.to_owned()); let event = SendingEventType::Pdu(pdu_id.to_owned());
let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; let keys = self
.db
.queue_requests(&[(&outgoing_kind, event.clone())])
.await?;
self.sender self.sender
.send((outgoing_kind, event, keys.into_iter().next().unwrap())) .send((outgoing_kind, event, keys.into_iter().next().unwrap()))
.unwrap(); .unwrap();
@ -382,7 +385,7 @@ impl Service {
} }
#[tracing::instrument(skip(self, servers, pdu_id))] #[tracing::instrument(skip(self, servers, pdu_id))]
pub fn send_pdu<I: Iterator<Item = OwnedServerName>>( pub async fn send_pdu<I: Iterator<Item = OwnedServerName>>(
&self, &self,
servers: I, servers: I,
pdu_id: &[u8], pdu_id: &[u8],
@ -396,12 +399,15 @@ impl Service {
) )
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let keys = self.db.queue_requests( let keys = self
&requests .db
.iter() .queue_requests(
.map(|(o, e)| (o, e.clone())) &requests
.collect::<Vec<_>>(), .iter()
)?; .map(|(o, e)| (o, e.clone()))
.collect::<Vec<_>>(),
)
.await?;
for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) {
self.sender self.sender
.send((outgoing_kind.to_owned(), event, key)) .send((outgoing_kind.to_owned(), event, key))
@ -412,7 +418,7 @@ impl Service {
} }
#[tracing::instrument(skip(self, server, serialized))] #[tracing::instrument(skip(self, server, serialized))]
pub fn send_reliable_edu( pub async fn send_reliable_edu(
&self, &self,
server: &ServerName, server: &ServerName,
serialized: Vec<u8>, serialized: Vec<u8>,
@ -420,7 +426,10 @@ impl Service {
) -> Result<()> { ) -> Result<()> {
let outgoing_kind = OutgoingKind::Normal(server.to_owned()); let outgoing_kind = OutgoingKind::Normal(server.to_owned());
let event = SendingEventType::Edu(serialized); let event = SendingEventType::Edu(serialized);
let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; let keys = self
.db
.queue_requests(&[(&outgoing_kind, event.clone())])
.await?;
self.sender self.sender
.send((outgoing_kind, event, keys.into_iter().next().unwrap())) .send((outgoing_kind, event, keys.into_iter().next().unwrap()))
.unwrap(); .unwrap();
@ -429,10 +438,13 @@ impl Service {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec<u8>) -> Result<()> { pub async fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec<u8>) -> Result<()> {
let outgoing_kind = OutgoingKind::Appservice(appservice_id); let outgoing_kind = OutgoingKind::Appservice(appservice_id);
let event = SendingEventType::Pdu(pdu_id); let event = SendingEventType::Pdu(pdu_id);
let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; let keys = self
.db
.queue_requests(&[(&outgoing_kind, event.clone())])
.await?;
self.sender self.sender
.send((outgoing_kind, event, keys.into_iter().next().unwrap())) .send((outgoing_kind, event, keys.into_iter().next().unwrap()))
.unwrap(); .unwrap();

View file

@ -1,4 +1,5 @@
use crate::Result; use crate::Result;
use async_trait::async_trait;
use ruma::{ use ruma::{
api::client::{device::Device, filter::FilterDefinition}, api::client::{device::Device, filter::FilterDefinition},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
@ -9,6 +10,7 @@ use ruma::{
}; };
use std::collections::BTreeMap; use std::collections::BTreeMap;
#[async_trait]
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
/// Check if a user has an account on this homeserver. /// Check if a user has an account on this homeserver.
fn exists(&self, user_id: &UserId) -> Result<bool>; fn exists(&self, user_id: &UserId) -> Result<bool>;
@ -55,7 +57,7 @@ pub trait Data: Send + Sync {
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()>; fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()>;
/// Adds a new device to a user. /// Adds a new device to a user.
fn create_device( async fn create_device(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -64,7 +66,7 @@ pub trait Data: Send + Sync {
) -> Result<()>; ) -> Result<()>;
/// Removes a device from a user. /// Removes a device from a user.
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>;
/// Returns an iterator over all device ids of this user. /// Returns an iterator over all device ids of this user.
fn all_device_ids<'a>( fn all_device_ids<'a>(
@ -75,7 +77,7 @@ pub trait Data: Send + Sync {
/// Replaces the access token of one device. /// Replaces the access token of one device.
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>;
fn add_one_time_key( async fn add_one_time_key(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -85,7 +87,7 @@ pub trait Data: Send + Sync {
fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64>; fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64>;
fn take_one_time_key( async fn take_one_time_key(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -98,14 +100,14 @@ pub trait Data: Send + Sync {
device_id: &DeviceId, device_id: &DeviceId,
) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>>; ) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>>;
fn add_device_keys( async fn add_device_keys(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
device_keys: &Raw<DeviceKeys>, device_keys: &Raw<DeviceKeys>,
) -> Result<()>; ) -> Result<()>;
fn add_cross_signing_keys( async fn add_cross_signing_keys(
&self, &self,
user_id: &UserId, user_id: &UserId,
master_key: &Raw<CrossSigningKey>, master_key: &Raw<CrossSigningKey>,
@ -114,7 +116,7 @@ pub trait Data: Send + Sync {
notify: bool, notify: bool,
) -> Result<()>; ) -> Result<()>;
fn sign_key( async fn sign_key(
&self, &self,
target_id: &UserId, target_id: &UserId,
key_id: &str, key_id: &str,
@ -129,7 +131,7 @@ pub trait Data: Send + Sync {
to: Option<u64>, to: Option<u64>,
) -> Box<dyn Send + Iterator<Item = Result<OwnedUserId>> + 'a>; ) -> Box<dyn Send + Iterator<Item = Result<OwnedUserId>> + 'a>;
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>; async fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>;
fn get_device_keys( fn get_device_keys(
&self, &self,
@ -167,7 +169,7 @@ pub trait Data: Send + Sync {
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>>; fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>>;
fn add_to_device_event( async fn add_to_device_event(
&self, &self,
sender: &UserId, sender: &UserId,
target_user_id: &UserId, target_user_id: &UserId,
@ -189,7 +191,7 @@ pub trait Data: Send + Sync {
until: u64, until: u64,
) -> Result<()>; ) -> Result<()>;
fn update_device_metadata( async fn update_device_metadata(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,

View file

@ -340,7 +340,7 @@ impl Service {
} }
/// Adds a new device to a user. /// Adds a new device to a user.
pub fn create_device( pub async fn create_device(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -349,18 +349,19 @@ impl Service {
) -> Result<()> { ) -> Result<()> {
self.db self.db
.create_device(user_id, device_id, token, initial_device_display_name) .create_device(user_id, device_id, token, initial_device_display_name)
.await
} }
/// Removes a device from a user. /// Removes a device from a user.
pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
self.db.remove_device(user_id, device_id) self.db.remove_device(user_id, device_id).await
} }
/// Returns an iterator over all device ids of this user. /// Returns an iterator over all device ids of this user.
pub fn all_device_ids<'a>( pub fn all_device_ids<'a>(
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
) -> impl Iterator<Item = Result<OwnedDeviceId>> + 'a { ) -> impl Send + Iterator<Item = Result<OwnedDeviceId>> + 'a {
self.db.all_device_ids(user_id) self.db.all_device_ids(user_id)
} }
@ -369,7 +370,7 @@ impl Service {
self.db.set_token(user_id, device_id, token) self.db.set_token(user_id, device_id, token)
} }
pub fn add_one_time_key( pub async fn add_one_time_key(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -378,19 +379,22 @@ impl Service {
) -> Result<()> { ) -> Result<()> {
self.db self.db
.add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) .add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value)
.await
} }
pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> { pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
self.db.last_one_time_keys_update(user_id) self.db.last_one_time_keys_update(user_id)
} }
pub fn take_one_time_key( pub async fn take_one_time_key(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
key_algorithm: &DeviceKeyAlgorithm, key_algorithm: &DeviceKeyAlgorithm,
) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> { ) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
self.db.take_one_time_key(user_id, device_id, key_algorithm) self.db
.take_one_time_key(user_id, device_id, key_algorithm)
.await
} }
pub fn count_one_time_keys( pub fn count_one_time_keys(
@ -401,16 +405,18 @@ impl Service {
self.db.count_one_time_keys(user_id, device_id) self.db.count_one_time_keys(user_id, device_id)
} }
pub fn add_device_keys( pub async fn add_device_keys(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
device_keys: &Raw<DeviceKeys>, device_keys: &Raw<DeviceKeys>,
) -> Result<()> { ) -> Result<()> {
self.db.add_device_keys(user_id, device_id, device_keys) self.db
.add_device_keys(user_id, device_id, device_keys)
.await
} }
pub fn add_cross_signing_keys( pub async fn add_cross_signing_keys(
&self, &self,
user_id: &UserId, user_id: &UserId,
master_key: &Raw<CrossSigningKey>, master_key: &Raw<CrossSigningKey>,
@ -418,23 +424,27 @@ impl Service {
user_signing_key: &Option<Raw<CrossSigningKey>>, user_signing_key: &Option<Raw<CrossSigningKey>>,
notify: bool, notify: bool,
) -> Result<()> { ) -> Result<()> {
self.db.add_cross_signing_keys( self.db
user_id, .add_cross_signing_keys(
master_key, user_id,
self_signing_key, master_key,
user_signing_key, self_signing_key,
notify, user_signing_key,
) notify,
)
.await
} }
pub fn sign_key( pub async fn sign_key(
&self, &self,
target_id: &UserId, target_id: &UserId,
key_id: &str, key_id: &str,
signature: (String, String), signature: (String, String),
sender_id: &UserId, sender_id: &UserId,
) -> Result<()> { ) -> Result<()> {
self.db.sign_key(target_id, key_id, signature, sender_id) self.db
.sign_key(target_id, key_id, signature, sender_id)
.await
} }
pub fn keys_changed<'a>( pub fn keys_changed<'a>(
@ -446,8 +456,8 @@ impl Service {
self.db.keys_changed(user_or_room_id, from, to) self.db.keys_changed(user_or_room_id, from, to)
} }
pub fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { pub async fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
self.db.mark_device_key_update(user_id) self.db.mark_device_key_update(user_id).await
} }
pub fn get_device_keys( pub fn get_device_keys(
@ -501,7 +511,7 @@ impl Service {
self.db.get_user_signing_key(user_id) self.db.get_user_signing_key(user_id)
} }
pub fn add_to_device_event( pub async fn add_to_device_event(
&self, &self,
sender: &UserId, sender: &UserId,
target_user_id: &UserId, target_user_id: &UserId,
@ -509,13 +519,15 @@ impl Service {
event_type: &str, event_type: &str,
content: serde_json::Value, content: serde_json::Value,
) -> Result<()> { ) -> Result<()> {
self.db.add_to_device_event( self.db
sender, .add_to_device_event(
target_user_id, sender,
target_device_id, target_user_id,
event_type, target_device_id,
content, event_type,
) content,
)
.await
} }
pub fn get_to_device_events( pub fn get_to_device_events(
@ -535,13 +547,15 @@ impl Service {
self.db.remove_to_device_events(user_id, device_id, until) self.db.remove_to_device_events(user_id, device_id, until)
} }
pub fn update_device_metadata( pub async fn update_device_metadata(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
device: &Device, device: &Device,
) -> Result<()> { ) -> Result<()> {
self.db.update_device_metadata(user_id, device_id, device) self.db
.update_device_metadata(user_id, device_id, device)
.await
} }
/// Get device metadata. /// Get device metadata.
@ -565,10 +579,10 @@ impl Service {
} }
/// Deactivate account /// Deactivate account
pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
// Remove all associated devices // Remove all associated devices
for device_id in self.all_device_ids(user_id) { for device_id in self.all_device_ids(user_id) {
self.remove_device(user_id, &device_id?)?; self.remove_device(user_id, &device_id?).await?;
} }
// Set the password to "" to indicate a deactivated account. Hashes will never result in an // Set the password to "" to indicate a deactivated account. Hashes will never result in an