From ca76e92abd35bce283a79844094a0961eb8e7575 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Thu, 31 Oct 2024 16:23:54 +0000 Subject: [PATCH 1/5] refactor federation membership handshake endpoints, reducing duplication --- src/api/server_server.rs | 203 +++++++++++++++++++++------------------ 1 file changed, 112 insertions(+), 91 deletions(-) diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 4da86f4e..40f107b8 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -67,7 +67,7 @@ use std::{ sync::Arc, time::{Duration, Instant, SystemTime}, }; -use tokio::sync::RwLock; +use tokio::sync::{Mutex, RwLock}; use tracing::{debug, error, warn}; @@ -1500,36 +1500,10 @@ pub async fn get_room_state_ids_route( pub async fn create_join_event_template_route( body: Ruma, ) -> Result { - if !services().rooms.metadata.exists(&body.room_id)? { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Room is unknown to this server.", - )); - } - - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - services() - .rooms - .event_handler - .acl_check(sender_servername, &body.room_id)?; - - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(body.room_id.to_owned()) - .or_default(), - ); + let (mutex_state, room_version_id) = + member_shake_preamble(&body.sender_servername, &body.room_id).await?; let state_lock = mutex_state.lock().await; - let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; - let join_authorized_via_users_server = if // The following two functions check whether the user can "join" without performing a restricted join !services() @@ -1580,12 +1554,32 @@ pub async fn create_join_event_template_route( )); } + Ok(prepare_join_event::v1::Response { + room_version: Some(room_version_id), + event: create_membership_template( + &body.user_id, + &body.room_id, + join_authorized_via_users_server, + MembershipState::Join, + state_lock, + )?, + }) +} + +/// Creates a template for the given membership state, to return on the `/make_` endpoints +fn create_membership_template( + user_id: &UserId, + room_id: &RoomId, + join_authorized_via_users_server: Option, + membership: MembershipState, + state_lock: tokio::sync::MutexGuard<'_, ()>, +) -> Result, Error> { let content = to_raw_value(&RoomMemberEventContent { avatar_url: None, blurhash: None, displayname: None, is_direct: None, - membership: MembershipState::Join, + membership, third_party_invite: None, reason: None, join_authorized_via_users_server, @@ -1597,12 +1591,12 @@ pub async fn create_join_event_template_route( event_type: TimelineEventType::RoomMember, content, unsigned: None, - state_key: Some(body.user_id.to_string()), + state_key: Some(user_id.to_string()), redacts: None, timestamp: None, }, - &body.user_id, - &body.room_id, + user_id, + room_id, &state_lock, )?; @@ -1610,17 +1604,13 @@ pub async fn create_join_event_template_route( pdu_json.remove("event_id"); - Ok(prepare_join_event::v1::Response { - room_version: Some(room_version_id), - event: to_raw_value(&pdu_json).expect("CanonicalJson can be serialized to JSON"), - }) + let raw_event = to_raw_value(&pdu_json).expect("CanonicalJson can be serialized to JSON"); + + Ok(raw_event) } -async fn create_join_event( - sender_servername: &ServerName, - room_id: &RoomId, - pdu: &RawJsonValue, -) -> Result { +/// checks whether the given room exists, and checks whether the specified server is allowed to send events according to the ACL +fn room_and_acl_check(room_id: &RoomId, sender_servername: &OwnedServerName) -> Result<(), Error> { if !services().rooms.metadata.exists(room_id)? { return Err(Error::BadRequest( ErrorKind::NotFound, @@ -1632,6 +1622,40 @@ async fn create_join_event( .rooms .event_handler .acl_check(sender_servername, room_id)?; + Ok(()) +} + +/// Takes care of common boilerpalte for room membership handshake endpoints. +/// The returned mutex must be locked by the caller. +async fn member_shake_preamble( + sender_servername: &Option, + room_id: &RoomId, +) -> Result<(Arc>, RoomVersionId), Error> { + let sender_servername = sender_servername.as_ref().expect("server is authenticated"); + room_and_acl_check(room_id, sender_servername)?; + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(room_id.to_owned()) + .or_default(), + ); + + let room_version_id = services().rooms.state.get_room_version(room_id)?; + + Ok((mutex_state, room_version_id)) +} + +async fn create_join_event( + sender_servername: &Option, + room_id: &RoomId, + pdu: &RawJsonValue, +) -> Result { + let sender_servername = sender_servername.as_ref().expect("server is authenticated"); + room_and_acl_check(room_id, sender_servername)?; // We need to return the state prior to joining, let's keep a reference to that here let shortstatehash = services() @@ -1643,8 +1667,44 @@ async fn create_join_event( "Pdu state not found.", ))?; + let pdu = append_member_pdu(MembershipState::Join, sender_servername, room_id, pdu).await?; + + let state_ids = services() + .rooms + .state_accessor + .state_full_ids(shortstatehash) + .await?; + let auth_chain_ids = services() + .rooms + .auth_chain + .get_auth_chain(room_id, state_ids.values().cloned().collect()) + .await?; + + Ok(create_join_event::v1::RoomState { + auth_chain: auth_chain_ids + .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok().flatten()) + .map(PduEvent::convert_to_outgoing_federation_event) + .collect(), + state: state_ids + .iter() + .filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) + .map(PduEvent::convert_to_outgoing_federation_event) + .collect(), + event: pdu.map(|pdu| { + to_raw_value(&CanonicalJsonValue::Object(pdu)) + .expect("To raw json should not fail since only change was adding signature") + }), + }) +} + +/// Takes the given membership PDU and attempts to append it to the timeline +async fn append_member_pdu( + membership: MembershipState, + sender_servername: &OwnedServerName, + room_id: &RoomId, + pdu: &RawJsonValue, +) -> Result>, Error> { let pub_key_map = RwLock::new(BTreeMap::new()); - // let mut auth_cache = EventMap::new(); // We do not add the event_id field to the pdu here because of signature and hashes checks let room_version_id = services().rooms.state.get_room_version(room_id)?; @@ -1719,17 +1779,18 @@ async fn create_join_event( ) .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid event content"))?; - if event_content.membership != MembershipState::Join { + if event_content.membership != membership { return Err(Error::BadRequest( ErrorKind::BadJson, "Membership of sent event does not match that of the endpoint", )); } - let sign_join_event = event_content - .join_authorized_via_users_server - .map(|user| user.server_name() == services().globals.server_name()) - .unwrap_or_default() + let sign_join_event = membership == MembershipState::Join + && event_content + .join_authorized_via_users_server + .map(|user| user.server_name() == services().globals.server_name()) + .unwrap_or_default() && user_can_perform_restricted_join(&sender, room_id, &room_version_id).unwrap_or_default(); if sign_join_event { @@ -1779,17 +1840,6 @@ async fn create_join_event( ))?; drop(mutex_lock); - let state_ids = services() - .rooms - .state_accessor - .state_full_ids(shortstatehash) - .await?; - let auth_chain_ids = services() - .rooms - .auth_chain - .get_auth_chain(room_id, state_ids.values().cloned().collect()) - .await?; - let servers = services() .rooms .state_cache @@ -1799,26 +1849,7 @@ async fn create_join_event( services().sending.send_pdu(servers, &pdu_id)?; - Ok(create_join_event::v1::RoomState { - auth_chain: auth_chain_ids - .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok().flatten()) - .map(PduEvent::convert_to_outgoing_federation_event) - .collect(), - state: state_ids - .iter() - .filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) - .map(PduEvent::convert_to_outgoing_federation_event) - .collect(), - // Event field is required if we sign the join event. - event: if sign_join_event { - Some( - to_raw_value(&CanonicalJsonValue::Object(value)) - .expect("To raw json should not fail since only change was adding signature"), - ) - } else { - None - }, - }) + Ok(if sign_join_event { Some(value) } else { None }) } /// # `PUT /_matrix/federation/v1/send_join/{roomId}/{eventId}` @@ -1827,12 +1858,7 @@ async fn create_join_event( pub async fn create_join_event_v1_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; + let room_state = create_join_event(&body.sender_servername, &body.room_id, &body.pdu).await?; Ok(create_join_event::v1::Response { room_state }) } @@ -1843,16 +1869,11 @@ pub async fn create_join_event_v1_route( pub async fn create_join_event_v2_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - let create_join_event::v1::RoomState { auth_chain, state, event, - } = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; + } = create_join_event(&body.sender_servername, &body.room_id, &body.pdu).await?; let room_state = create_join_event::v2::RoomState { members_omitted: false, auth_chain, From d0c1b920aed85d9a5e2b44791ef2a9366a0b646a Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Fri, 28 Feb 2025 14:09:33 +0000 Subject: [PATCH 2/5] feat(federation): implement /make_leave and /send_leave --- src/api/server_server.rs | 50 +++++++++++++++++++++++++++++++++++++++- src/main.rs | 2 ++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 40f107b8..60d37bde 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -31,7 +31,10 @@ use ruma::{ }, event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, keys::{claim_keys, get_keys}, - membership::{create_invite, create_join_event, prepare_join_event}, + membership::{ + create_invite, create_join_event, create_leave_event, prepare_join_event, + prepare_leave_event, + }, openid::get_openid_userinfo, query::{get_profile_information, get_room_information}, space::get_hierarchy, @@ -1494,6 +1497,28 @@ pub async fn get_room_state_ids_route( }) } +/// # `GET /_matrix/federation/v1/make_leave/{roomId}/{userId}` +/// +/// Creates a leave template. +pub async fn create_leave_event_template_route( + body: Ruma, +) -> Result { + let (mutex_state, room_version_id) = + member_shake_preamble(&body.sender_servername, &body.room_id).await?; + let state_lock = mutex_state.lock().await; + + Ok(prepare_leave_event::v1::Response { + room_version: Some(room_version_id), + event: create_membership_template( + &body.user_id, + &body.room_id, + None, + MembershipState::Leave, + state_lock, + )?, + }) +} + /// # `GET /_matrix/federation/v1/make_join/{roomId}/{userId}` /// /// Creates a join template. @@ -1885,6 +1910,29 @@ pub async fn create_join_event_v2_route( Ok(create_join_event::v2::Response { room_state }) } +/// # `PUT /_matrix/federation/v2/send_leave/{roomId}/{eventId}` +/// +/// Submits a signed leave event. +pub async fn create_leave_event_route( + body: Ruma, +) -> Result { + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + room_and_acl_check(&body.room_id, sender_servername)?; + + append_member_pdu( + MembershipState::Leave, + sender_servername, + &body.room_id, + &body.pdu, + ) + .await?; + + Ok(create_leave_event::v2::Response {}) +} + /// Checks whether the given user can join the given room via a restricted join. /// This doesn't check the current user's membership. This should be done externally, /// either by using the state cache or attempting to authorize the event. diff --git a/src/main.rs b/src/main.rs index 2776c200..d37cfb21 100644 --- a/src/main.rs +++ b/src/main.rs @@ -458,6 +458,8 @@ fn routes(config: &Config) -> Router { .ruma_route(server_server::create_join_event_template_route) .ruma_route(server_server::create_join_event_v1_route) .ruma_route(server_server::create_join_event_v2_route) + .ruma_route(server_server::create_leave_event_template_route) + .ruma_route(server_server::create_leave_event_route) .ruma_route(server_server::create_invite_route) .ruma_route(server_server::get_devices_route) .ruma_route(server_server::get_content_route) From 8acacdebc8f3fcc39b96c60c95d5f09325884c83 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Sat, 1 Mar 2025 02:44:24 +0000 Subject: [PATCH 3/5] chore: bump ruma & rust --- Cargo.lock | 24 ++++++++++++------------ Cargo.toml | 2 +- complement/Dockerfile | 2 +- flake.nix | 2 +- rust-toolchain.toml | 2 +- src/api/server_server.rs | 2 +- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4cca6c5a..a6d9d3a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2190,7 +2190,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.12.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "assign", "js_int", @@ -2210,7 +2210,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.12.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "js_int", "ruma-common", @@ -2222,7 +2222,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.20.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "as_variant", "assign", @@ -2245,7 +2245,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.15.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "as_variant", "base64 0.22.1", @@ -2276,7 +2276,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.30.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "as_variant", "indexmap 2.2.6", @@ -2299,7 +2299,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.11.0" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "bytes", "http 1.1.0", @@ -2317,7 +2317,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.10.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "js_int", "thiserror 2.0.11", @@ -2326,7 +2326,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.15.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "cfg-if", "proc-macro-crate", @@ -2341,7 +2341,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.11.0" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "js_int", "ruma-common", @@ -2353,7 +2353,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.5.0" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "headers", "http 1.1.0", @@ -2366,7 +2366,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.17.0" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -2382,7 +2382,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.13.0" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "js_int", "ruma-common", diff --git a/Cargo.toml b/Cargo.toml index 3e463174..e5fbd41a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ repository = "https://gitlab.com/famedly/conduit" version = "0.10.0-alpha" # See also `rust-toolchain.toml` -rust-version = "1.80.0" +rust-version = "1.81.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/complement/Dockerfile b/complement/Dockerfile index 0bf0cfcd..ce067ec3 100644 --- a/complement/Dockerfile +++ b/complement/Dockerfile @@ -1,4 +1,4 @@ -FROM rust:1.80.0 +FROM rust:1.81.0 WORKDIR /workdir diff --git a/flake.nix b/flake.nix index 4b50a7ec..df05bf86 100644 --- a/flake.nix +++ b/flake.nix @@ -59,7 +59,7 @@ file = ./rust-toolchain.toml; # See also `rust-toolchain.toml` - sha256 = "sha256-6eN/GKzjVSjEhGO9FhWObkRFaE1Jf+uqMSdQnb8lcB4="; + sha256 = "sha256-VZZnlyP69+Y3crrLHQyJirqlHrTtGTsyiSnZB8jEvVo="; }; }); in diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 995b142b..465ffdee 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -9,7 +9,7 @@ # If you're having trouble making the relevant changes, bug a maintainer. [toolchain] -channel = "1.80.0" +channel = "1.81.0" components = [ # For rust-analyzer "rust-src", diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 60d37bde..8b8e6d30 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -640,7 +640,7 @@ async fn request_well_known(destination: &str) -> Option<(String, Instant)> { let response = services() .globals .default_client() - .get(&format!("https://{destination}/.well-known/matrix/server")) + .get(format!("https://{destination}/.well-known/matrix/server")) .send() .await; debug!("Got well known response"); From f4d90e99891a16ebde9df8cf1675f464c98c1cf6 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Sat, 1 Mar 2025 19:17:32 +0000 Subject: [PATCH 4/5] refactor: move duplicate code and some other small optimizations --- src/api/client_server/membership.rs | 339 +++++++++----------- src/database/key_value/rooms/mod.rs | 24 ++ src/database/key_value/rooms/state_cache.rs | 161 ++++------ src/database/key_value/rooms/timeline.rs | 10 +- src/database/key_value/rooms/user.rs | 17 +- src/database/mod.rs | 2 +- src/service/rooms/state/mod.rs | 69 +--- src/service/rooms/state_cache/mod.rs | 7 + src/service/rooms/timeline/mod.rs | 8 +- 9 files changed, 283 insertions(+), 354 deletions(-) diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 520bfa00..714a0b7c 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -18,8 +18,10 @@ use ruma::{ }, StateEventType, TimelineEventType, }, + serde::Raw, state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, - OwnedEventId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, RoomVersionId, UserId, + OwnedEventId, OwnedRoomId, OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomId, + RoomVersionId, UserId, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use std::{ @@ -74,7 +76,7 @@ pub async fn join_room_by_id_route( ); join_room_by_id_helper( - body.sender_user.as_deref(), + body.sender_user.as_deref().expect("user is authenticated"), &body.room_id, body.reason.clone(), &servers, @@ -95,9 +97,122 @@ pub async fn join_room_by_id_or_alias_route( let sender_user = body.sender_user.as_deref().expect("user is authenticated"); let body = body.body; - let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias) { + let (servers, room_id) = + get_room_id_and_via_servers(sender_user, body.room_id_or_alias, body.via).await?; + + let join_room_response = join_room_by_id_helper( + sender_user, + &room_id, + body.reason.clone(), + &servers, + body.third_party_signed.as_ref(), + ) + .await?; + + Ok(join_room_by_id_or_alias::v3::Response { + room_id: join_room_response.room_id, + }) +} + +/// Takes a membership template, as returned from the `/federation/*/make_*` endpoints, and +/// populates them to the point as to where they are a full pdu, ready to be appended to the timeline +/// +/// Returns the event id, the pdu, and whether this event is a restricted join +fn populate_membership_template( + member_template: &RawJsonValue, + sender_user: &UserId, + reason: Option, + room_version_id: &RoomVersionId, + membership: MembershipState, +) -> Result<(OwnedEventId, BTreeMap, bool), Error> { + let mut member_event_stub: CanonicalJsonObject = serde_json::from_str(member_template.get()) + .map_err(|_| { + Error::BadServerResponse("Invalid make_knock event json received from server.") + })?; + + let join_authorized_via_users_server = member_event_stub + .get("content") + .map(|s| { + s.as_object()? + .get("join_authorised_via_users_server")? + .as_str() + }) + .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); + + let restricted_join = join_authorized_via_users_server.is_some(); + + member_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + ); + + member_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + + member_event_stub.insert( + "content".to_owned(), + to_canonical_value(RoomMemberEventContent { + membership, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason: reason.clone(), + join_authorized_via_users_server, + }) + .expect("event is valid, we just created it"), + ); + + member_event_stub.remove("event_id"); + + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut member_event_stub, + room_version_id, + ) + .expect("event is valid, we just created it"); + + let event_id = format!( + "${}", + ruma::signatures::reference_hash(&member_event_stub, room_version_id) + .expect("Event format validated when event was hashed") + ); + + let event_id = + ::try_from(event_id).expect("ruma's reference hashes are valid event ids"); + + member_event_stub.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); + + Ok((event_id, member_event_stub, restricted_join)) +} + +/// Function to assist performing a membership event that may require help from a remote server +/// +/// If a room id is provided, the servers returned will consist of: +/// - the `via` argument, provided by the client +/// - servers of the senders of the stripped state events we are given +/// - the server in the room id +/// +/// Otherwise, the servers returned will come from the response when resolving the alias. +async fn get_room_id_and_via_servers( + sender_user: &UserId, + room_id_or_alias: OwnedRoomOrAliasId, + via: Vec, +) -> Result<(Vec, OwnedRoomId), Error> { + let (servers, room_id) = match OwnedRoomId::try_from(room_id_or_alias) { Ok(room_id) => { - let mut servers = body.via.clone(); + let mut servers = via.clone(); servers.extend( services() .rooms @@ -127,19 +242,7 @@ pub async fn join_room_by_id_or_alias_route( (response.servers, response.room_id) } }; - - let join_room_response = join_room_by_id_helper( - Some(sender_user), - &room_id, - body.reason.clone(), - &servers, - body.third_party_signed.as_ref(), - ) - .await?; - - Ok(join_room_by_id_or_alias::v3::Response { - room_id: join_room_response.room_id, - }) + Ok((servers, room_id)) } /// # `POST /_matrix/client/r0/rooms/{roomId}/leave` @@ -519,14 +622,12 @@ pub async fn joined_members_route( } async fn join_room_by_id_helper( - sender_user: Option<&UserId>, + sender_user: &UserId, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], _third_party_signed: Option<&ThirdPartySigned>, ) -> Result { - let sender_user = sender_user.expect("user is authenticated"); - if let Ok(true) = services().rooms.state_cache.is_joined(sender_user, room_id) { return Ok(join_room_by_id::v3::Response { room_id: room_id.into(), @@ -569,77 +670,13 @@ async fn join_room_by_id_helper( _ => return Err(Error::BadServerResponse("Room version is not supported")), }; - let mut join_event_stub: CanonicalJsonObject = - serde_json::from_str(make_join_response.event.get()).map_err(|_| { - Error::BadServerResponse("Invalid make_join event json received from server.") - })?; - - let join_authorized_via_users_server = join_event_stub - .get("content") - .map(|s| { - s.as_object()? - .get("join_authorised_via_users_server")? - .as_str() - }) - .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); - - // TODO: Is origin needed? - join_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), - ); - join_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - join_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason, - join_authorized_via_users_server, - }) - .expect("event is valid, we just created it"), - ); - - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms - join_event_stub.remove("event_id"); - - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present - ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut join_event_stub, + let (event_id, mut join_event, _) = populate_membership_template( + &make_join_response.event, + sender_user, + reason, &room_version_id, - ) - .expect("event is valid, we just created it"); - - // Generate event id - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("Event format validated when event was hashed") - ); - let event_id = <&EventId>::try_from(event_id.as_str()) - .expect("ruma's reference hashes are valid event ids"); - - // Add event_id back - join_event_stub.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - - // It has enough fields to be called a proper event now - let mut join_event = join_event_stub; + MembershipState::Join, + )?; info!("Asking {remote_server} for send_join"); let send_join_response = services() @@ -709,7 +746,7 @@ async fn join_room_by_id_helper( services().rooms.short.get_or_create_shortroomid(room_id)?; info!("Parsing join event"); - let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) + let parsed_join_pdu = PduEvent::from_id_val(&event_id, join_event.clone()) .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; let mut state = HashMap::new(); @@ -854,21 +891,7 @@ async fn join_room_by_id_helper( } else { info!("We can join locally"); - let join_rules_event = services().rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomJoinRules, - "", - )?; - - let join_rules_event_content: Option = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose()?; + let join_rules_event_content = get_join_rules(room_id)?; let restriction_rooms = match join_rules_event_content { Some(RoomJoinRulesEventContent { @@ -930,7 +953,7 @@ async fn join_room_by_id_helper( }; // Try normal join first - let error = match services() + let Err(error) = services() .rooms .timeline .build_and_append_pdu( @@ -947,9 +970,8 @@ async fn join_room_by_id_helper( &state_lock, ) .await - { - Ok(_event_id) => return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())), - Err(e) => e, + else { + return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())); }; if !restriction_rooms.is_empty() @@ -974,77 +996,14 @@ async fn join_room_by_id_helper( } _ => return Err(Error::BadServerResponse("Room version is not supported")), }; - let mut join_event_stub: CanonicalJsonObject = - serde_json::from_str(make_join_response.event.get()).map_err(|_| { - Error::BadServerResponse("Invalid make_join event json received from server.") - })?; - let join_authorized_via_users_server = join_event_stub - .get("content") - .map(|s| { - s.as_object()? - .get("join_authorised_via_users_server")? - .as_str() - }) - .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); - let restricted_join = join_authorized_via_users_server.is_some(); - // TODO: Is origin needed? - join_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), - ); - join_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - join_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason, - join_authorized_via_users_server, - }) - .expect("event is valid, we just created it"), - ); - - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms - join_event_stub.remove("event_id"); - - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present - ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut join_event_stub, + let (event_id, join_event, restricted_join) = populate_membership_template( + &make_join_response.event, + sender_user, + reason, &room_version_id, - ) - .expect("event is valid, we just created it"); - - // Generate event id - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - ); - let event_id = OwnedEventId::try_from(event_id) - .expect("ruma's reference hashes are valid event ids"); - - // Add event_id back - join_event_stub.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - - // It has enough fields to be called a proper event now - let join_event = join_event_stub; + MembershipState::Join, + )?; let send_join_response = services() .sending @@ -1093,6 +1052,26 @@ async fn join_room_by_id_helper( Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) } +/// Returns the join rules event content of a room, if there are any and we are aware of it locally +fn get_join_rules(room_id: &RoomId) -> Result, Error> { + let join_rules_event = services().rooms.state_accessor.room_state_get( + room_id, + &StateEventType::RoomJoinRules, + "", + )?; + + join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str::(join_rules_event.content.get()) + .map_err(|e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database("Invalid join rules event in db.") + }) + }) + .transpose() +} + async fn make_join_request( sender_user: &UserId, room_id: &RoomId, @@ -1270,7 +1249,7 @@ pub(crate) async fn invite_helper<'a>( &state_lock, )?; - let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?; + let invite_room_state = services().rooms.state.stripped_state(&pdu.room_id)?; drop(state_lock); diff --git a/src/database/key_value/rooms/mod.rs b/src/database/key_value/rooms/mod.rs index e7b53d30..317c44d0 100644 --- a/src/database/key_value/rooms/mod.rs +++ b/src/database/key_value/rooms/mod.rs @@ -16,6 +16,30 @@ mod threads; mod timeline; mod user; +use ruma::{RoomId, UserId}; + use crate::{database::KeyValueDatabase, service}; impl service::rooms::Data for KeyValueDatabase {} + +/// Constructs roomuser_id and userroom_id respectively in byte form +fn get_room_and_user_byte_ids(room_id: &RoomId, user_id: &UserId) -> (Vec, Vec) { + ( + get_roomuser_id_bytes(room_id, user_id), + get_userroom_id_bytes(user_id, room_id), + ) +} + +fn get_roomuser_id_bytes(room_id: &RoomId, user_id: &UserId) -> Vec { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.as_bytes()); + roomuser_id +} + +fn get_userroom_id_bytes(user_id: &UserId, room_id: &RoomId) -> Vec { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + userroom_id +} diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 49e3842b..126f4acc 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -7,27 +7,21 @@ use ruma::{ }; use crate::{ - database::KeyValueDatabase, + database::{abstraction::KvTree, KeyValueDatabase}, service::{self, appservice::RegistrationInfo}, services, utils, Error, Result, }; +use super::{get_room_and_user_byte_ids, get_userroom_id_bytes}; + impl service::rooms::state_cache::Data for KeyValueDatabase { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); self.roomuseroncejoinedids.insert(&userroom_id, &[]) } fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); self.userroomid_joined.insert(&userroom_id, &[])?; self.roomuserid_joined.insert(&roomuser_id, &[])?; @@ -45,13 +39,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { room_id: &RoomId, last_state: Option>>, ) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); self.userroomid_invitestate.insert( &userroom_id, @@ -70,14 +58,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { Ok(()) } - fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); self.userroomid_leftstate.insert( &userroom_id, @@ -225,13 +208,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { /// Makes a user forget a room. #[tracing::instrument(skip(self))] fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); + let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); self.userroomid_leftstate.remove(&userroom_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?; @@ -460,34 +437,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { &'a self, user_id: &UserId, ) -> Box>)>> + 'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - Box::new( - self.userroomid_invitestate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid.") - })?; - - let state = serde_json::from_slice(&state).map_err(|_| { - Error::bad_database("Invalid state in userroomid_invitestate.") - })?; - - Ok((room_id, state)) - }), - ) + scan_userroom_id_memberstate_tree(user_id, &self.userroomid_invitestate) } #[tracing::instrument(skip(self))] @@ -539,69 +489,80 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { &'a self, user_id: &UserId, ) -> Box>)>> + 'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - Box::new( - self.userroomid_leftstate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid.") - })?; - - let state = serde_json::from_slice(&state).map_err(|_| { - Error::bad_database("Invalid state in userroomid_leftstate.") - })?; - - Ok((room_id, state)) - }), - ) + scan_userroom_id_memberstate_tree(user_id, &self.userroomid_leftstate) } #[tracing::instrument(skip(self))] fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) } #[tracing::instrument(skip(self))] fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) } #[tracing::instrument(skip(self))] fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) } + #[tracing::instrument(skip(self))] + fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let userroom_id = get_userroom_id_bytes(user_id, room_id); + + Ok(self.userroomid_knockstate.get(&userroom_id)?.is_some()) + } + #[tracing::instrument(skip(self))] fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) } } + +/// Scans the given userroom_id_`member`state tree for rooms, returning an iterator of room_ids +/// and a vector of raw state events +#[allow(clippy::type_complexity)] +fn scan_userroom_id_memberstate_tree<'a, T>( + user_id: &UserId, + userroom_id_memberstate_tree: &'a Arc, +) -> Box>)>> + 'a> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + + Box::new( + userroom_id_memberstate_tree + .scan_prefix(prefix) + .map(|(key, state)| { + let room_id = RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database( + "Room ID in userroomid_state is invalid unicode.", + ) + })?, + ) + .map_err(|_| { + Error::bad_database("Room ID in userroomid_state is invalid.") + })?; + + let state = serde_json::from_slice(&state).map_err(|_| { + Error::bad_database("Invalid state in userroomid_state.") + })?; + + Ok((room_id, state)) + }), + ) +} diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 0331a624..e89e2041 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -9,6 +9,8 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEven use service::rooms::timeline::PduCount; +use super::get_userroom_id_bytes; + impl service::rooms::timeline::Data for KeyValueDatabase { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { match self @@ -286,15 +288,11 @@ impl service::rooms::timeline::Data for KeyValueDatabase { let mut notifies_batch = Vec::new(); let mut highlights_batch = Vec::new(); for user in notifies { - let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(&user, room_id); notifies_batch.push(userroom_id); } for user in highlights { - let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(&user, room_id); highlights_batch.push(userroom_id); } diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 4c435720..2ba4240e 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -2,14 +2,11 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use super::{get_room_and_user_byte_ids, get_userroom_id_bytes}; + impl service::rooms::user::Data for KeyValueDatabase { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); + let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); self.userroomid_notificationcount .insert(&userroom_id, &0_u64.to_be_bytes())?; @@ -25,9 +22,7 @@ impl service::rooms::user::Data for KeyValueDatabase { } fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); self.userroomid_notificationcount .get(&userroom_id)? @@ -39,9 +34,7 @@ impl service::rooms::user::Data for KeyValueDatabase { } fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); self.userroomid_highlightcount .get(&userroom_id)? diff --git a/src/database/mod.rs b/src/database/mod.rs index 26db8268..e2bfc2c9 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -97,7 +97,7 @@ pub struct KeyValueDatabase { pub(super) roomid_joinedcount: Arc, pub(super) roomid_invitedcount: Arc, pub(super) roomuseroncejoinedids: Arc, - pub(super) userroomid_invitestate: Arc, // InviteState = Vec> + pub(super) userroomid_invitestate: Arc, // InviteState = Vec> pub(super) roomuserid_invitecount: Arc, // InviteCount = Count pub(super) userroomid_leftstate: Arc, pub(super) roomuserid_leftcount: Arc, diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index f5bd7e9f..93b46072 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -10,6 +10,7 @@ use ruma::{ events::{ room::{create::RoomCreateEventContent, member::MembershipState}, AnyStrippedStateEvent, StateEventType, TimelineEventType, + RECOMMENDED_STRIPPED_STATE_EVENT_TYPES, }, serde::Raw, state_res::{self, StateMap}, @@ -258,58 +259,22 @@ impl Service { } } - #[tracing::instrument(skip(self, invite_event))] - pub fn calculate_invite_state( - &self, - invite_event: &PduEvent, - ) -> Result>> { - let mut state = Vec::new(); - // Add recommended events - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomCreate, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomJoinRules, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomCanonicalAlias, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomAvatar, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomName, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomMember, - invite_event.sender.as_str(), - )? { - state.push(e.to_stripped_state_event()); - } - - state.push(invite_event.to_stripped_state_event()); - Ok(state) + #[tracing::instrument(skip(self, room_id))] + /// Gets all the [recommended stripped state events] from the given room + /// + /// [recommended stripped state events]: https://spec.matrix.org/v1.13/client-server-api/#stripped-state + pub fn stripped_state(&self, room_id: &RoomId) -> Result>> { + RECOMMENDED_STRIPPED_STATE_EVENT_TYPES + .iter() + .filter_map(|state_event_type| { + services() + .rooms + .state_accessor + .room_state_get(room_id, state_event_type, "") + .transpose() + }) + .map(|e| e.map(|e| e.to_stripped_state_event())) + .collect::>>() } /// Set the state hash to a new version, but does not update state_cache. diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 1604a14a..d8fa73b8 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -40,6 +40,13 @@ impl Service { // TODO: displayname, avatar url } + // We don't need to store stripped state on behalf of remote users, since these events are only used on `/sync` + let last_state = if user_id.server_name() == services().globals.server_name() { + last_state + } else { + None + }; + match &membership { MembershipState::Join => { // Check if the user never joined this room diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 80690663..78dccd0a 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -450,9 +450,11 @@ impl Service { let content = serde_json::from_str::(pdu.content.get()) .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - let invite_state = match content.membership { + let stripped_state = match content.membership { MembershipState::Invite => { - let state = services().rooms.state.calculate_invite_state(pdu)?; + let mut state = services().rooms.state.stripped_state(&pdu.room_id)?; + // So that clients can get info about who invitied them, the reason, when, etc. + state.push(pdu.to_stripped_state_event()); Some(state) } _ => None, @@ -465,7 +467,7 @@ impl Service { &target_user_id, content.membership, &pdu.sender, - invite_state, + stripped_state, true, )?; } From 21af83ea72adb7d4bd179236d20b6fd0a34b6859 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Sat, 1 Mar 2025 19:17:37 +0000 Subject: [PATCH 5/5] feat: knocking You may notice that we do no database migration for populating the state cache for knocking. This is because that in all the places where we use the state cache, it doesn't make a difference: - For local users, the clients wouldn't have been able to knock on rooms, as the `/knock` endpoint wasn't implemented yet, and I am not aware of any client which tries to knock over `/state`, as it would fail if the server is not currently in the room - It is not used for remote users --- src/api/client_server/alias.rs | 90 +- src/api/client_server/membership.rs | 947 ++++---------------- src/api/client_server/sync.rs | 49 +- src/api/server_server.rs | 141 ++- src/database/key_value/rooms/state_cache.rs | 78 ++ src/database/mod.rs | 4 + src/main.rs | 3 + src/service/mod.rs | 1 + src/service/rooms/alias/mod.rs | 75 +- src/service/rooms/helpers/mod.rs | 699 +++++++++++++++ src/service/rooms/mod.rs | 2 + src/service/rooms/state_accessor/mod.rs | 24 +- src/service/rooms/state_cache/data.rs | 23 + src/service/rooms/state_cache/mod.rs | 80 +- src/service/rooms/timeline/mod.rs | 13 +- 15 files changed, 1362 insertions(+), 867 deletions(-) create mode 100644 src/service/rooms/helpers/mod.rs diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index 06fcc182..6a7e8338 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -1,15 +1,7 @@ use crate::{services, Error, Result, Ruma}; -use rand::seq::SliceRandom; -use ruma::{ - api::{ - appservice, - client::{ - alias::{create_alias, delete_alias, get_alias}, - error::ErrorKind, - }, - federation, - }, - OwnedRoomAliasId, +use ruma::api::client::{ + alias::{create_alias, delete_alias, get_alias}, + error::ErrorKind, }; /// # `PUT /_matrix/client/r0/directory/room/{roomAlias}` @@ -115,75 +107,9 @@ pub async fn delete_alias_route( pub async fn get_alias_route( body: Ruma, ) -> Result { - get_alias_helper(body.body.room_alias).await -} - -pub(crate) async fn get_alias_helper( - room_alias: OwnedRoomAliasId, -) -> Result { - if room_alias.server_name() != services().globals.server_name() { - let response = services() - .sending - .send_federation_request( - room_alias.server_name(), - federation::query::get_room_information::v1::Request { - room_alias: room_alias.to_owned(), - }, - ) - .await?; - - let mut servers = response.servers; - servers.shuffle(&mut rand::thread_rng()); - - return Ok(get_alias::v3::Response::new(response.room_id, servers)); - } - - let mut room_id = None; - match services().rooms.alias.resolve_local_alias(&room_alias)? { - Some(r) => room_id = Some(r), - None => { - for appservice in services().appservice.read().await.values() { - if appservice.aliases.is_match(room_alias.as_str()) - && matches!( - services() - .sending - .send_appservice_request( - appservice.registration.clone(), - appservice::query::query_room_alias::v1::Request { - room_alias: room_alias.clone(), - }, - ) - .await, - Ok(Some(_opt_result)) - ) - { - room_id = Some( - services() - .rooms - .alias - .resolve_local_alias(&room_alias)? - .ok_or_else(|| { - Error::bad_config("Appservice lied to us. Room does not exist.") - })?, - ); - break; - } - } - } - }; - - let room_id = match room_id { - Some(room_id) => room_id, - None => { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Room with alias not found.", - )) - } - }; - - Ok(get_alias::v3::Response::new( - room_id, - vec![services().globals.server_name().to_owned()], - )) + services() + .rooms + .alias + .get_alias_helper(body.body.room_alias) + .await } diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 714a0b7c..3cc50274 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -2,46 +2,38 @@ use ruma::{ api::{ client::{ error::ErrorKind, + knock::knock_room, membership::{ ban_user, forget_room, get_member_events, invite_user, join_room_by_id, join_room_by_id_or_alias, joined_members, joined_rooms, kick_user, leave_room, - unban_user, ThirdPartySigned, + unban_user, }, }, federation::{self, membership::create_invite}, }, - canonical_json::to_canonical_value, events::{ room::{ - join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent}, + join_rules::JoinRule, member::{MembershipState, RoomMemberEventContent}, }, StateEventType, TimelineEventType, }, serde::Raw, - state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, - OwnedEventId, OwnedRoomId, OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomId, - RoomVersionId, UserId, + CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedServerName, RoomId, UserId, }; -use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; +use serde_json::value::to_raw_value; use std::{ - collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, + collections::{BTreeMap, HashSet}, sync::Arc, - time::{Duration, Instant}, }; use tokio::sync::RwLock; -use tracing::{debug, error, info, warn}; +use tracing::{error, info, warn}; use crate::{ - service::{ - globals::SigningKeys, - pdu::{gen_event_id_canonical_json, PduBuilder}, - }, + service::pdu::{gen_event_id_canonical_json, PduBuilder}, services, utils, Error, PduEvent, Result, Ruma, }; -use super::get_alias_helper; - /// # `POST /_matrix/client/r0/rooms/{roomId}/join` /// /// Tries to join the sender user into a room. @@ -51,38 +43,35 @@ use super::get_alias_helper; pub async fn join_room_by_id_route( body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let Ruma:: { + body, sender_user, .. + } = body; - let mut servers = Vec::new(); // There is no body.server_name for /roomId/join - servers.extend( - services() - .rooms - .state_cache - .invite_state(sender_user, &body.room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); + let join_room_by_id::v3::Request { + room_id, + reason, + third_party_signed, + } = body; - servers.push( - body.room_id - .server_name() - .expect("Room IDs should always have a server name") - .into(), - ); + let sender_user = sender_user.as_ref().expect("user is authenticated"); - join_room_by_id_helper( - body.sender_user.as_deref().expect("user is authenticated"), - &body.room_id, - body.reason.clone(), - &servers, - body.third_party_signed.as_ref(), - ) - .await + let (servers, room_id) = services() + .rooms + .state_cache + .get_room_id_and_via_servers(sender_user, room_id.into(), vec![]) + .await?; + + services() + .rooms + .helpers + .join_room_by_id( + sender_user, + &room_id, + reason.clone(), + &servers, + third_party_signed.as_ref(), + ) + .await } /// # `POST /_matrix/client/r0/join/{roomIdOrAlias}` @@ -97,152 +86,186 @@ pub async fn join_room_by_id_or_alias_route( let sender_user = body.sender_user.as_deref().expect("user is authenticated"); let body = body.body; - let (servers, room_id) = - get_room_id_and_via_servers(sender_user, body.room_id_or_alias, body.via).await?; + let (servers, room_id) = services() + .rooms + .state_cache + .get_room_id_and_via_servers(sender_user, body.room_id_or_alias, body.via) + .await?; - let join_room_response = join_room_by_id_helper( - sender_user, - &room_id, - body.reason.clone(), - &servers, - body.third_party_signed.as_ref(), - ) - .await?; + let join_room_response = services() + .rooms + .helpers + .join_room_by_id( + sender_user, + &room_id, + body.reason.clone(), + &servers, + body.third_party_signed.as_ref(), + ) + .await?; Ok(join_room_by_id_or_alias::v3::Response { room_id: join_room_response.room_id, }) } -/// Takes a membership template, as returned from the `/federation/*/make_*` endpoints, and -/// populates them to the point as to where they are a full pdu, ready to be appended to the timeline +/// # `POST /_matrix/client/v3/knock/{roomIdOrAlias}` /// -/// Returns the event id, the pdu, and whether this event is a restricted join -fn populate_membership_template( - member_template: &RawJsonValue, - sender_user: &UserId, - reason: Option, - room_version_id: &RoomVersionId, - membership: MembershipState, -) -> Result<(OwnedEventId, BTreeMap, bool), Error> { - let mut member_event_stub: CanonicalJsonObject = serde_json::from_str(member_template.get()) - .map_err(|_| { - Error::BadServerResponse("Invalid make_knock event json received from server.") - })?; +/// Tries to knock on a room. +/// +/// - If the server knowns about this room: creates the knock event and does auth rules locally +/// - If the server does not know about the room: asks other servers over federation +pub async fn knock_room_route( + body: Ruma, +) -> Result { + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); + let body = body.body; - let join_authorized_via_users_server = member_event_stub - .get("content") - .map(|s| { - s.as_object()? - .get("join_authorised_via_users_server")? - .as_str() - }) - .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); + let (servers, room_id) = services() + .rooms + .state_cache + .get_room_id_and_via_servers(sender_user, body.room_id_or_alias, body.via) + .await?; - let restricted_join = join_authorized_via_users_server.is_some(); - - member_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(room_id.to_owned()) + .or_default(), ); + let state_lock = mutex_state.lock().await; - member_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); + // Ask a remote server if we are not participating in this room + if !services() + .rooms + .state_cache + .server_in_room(services().globals.server_name(), &room_id)? + { + info!("Knocking on {room_id} over federation."); - member_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - membership, + let mut make_knock_response_and_server = Err(Error::BadServerResponse( + "No server available to assist in knocking.", + )); + + for remote_server in servers { + if remote_server == services().globals.server_name() { + continue; + } + info!("Asking {remote_server} for make_knock"); + let make_join_response = services() + .sending + .send_federation_request( + &remote_server, + federation::knock::create_knock_event_template::v1::Request { + room_id: room_id.to_owned(), + user_id: sender_user.to_owned(), + ver: services().globals.supported_room_versions(), + }, + ) + .await; + + if let Ok(make_knock_response) = make_join_response { + make_knock_response_and_server = Ok((make_knock_response, remote_server.clone())); + + break; + } + } + + let (knock_template, remote_server) = make_knock_response_and_server?; + + info!("make_knock finished"); + + let room_version_id = knock_template.room_version; + + let (event_id, knock_event, _) = services().rooms.helpers.populate_membership_template( + &knock_template.event, + sender_user, + body.reason, + &room_version_id, + MembershipState::Knock, + )?; + + info!("Asking {remote_server} for send_knock"); + let send_kock_response = services() + .sending + .send_federation_request( + &remote_server, + federation::knock::send_knock::v1::Request { + room_id: room_id.to_owned(), + event_id: event_id.to_owned(), + pdu: PduEvent::convert_to_outgoing_federation_event(knock_event.clone()), + }, + ) + .await?; + + info!("send_knock finished"); + + let mut stripped_state = send_kock_response.knock_room_state; + // Not sure how useful this is in reality, but spec examples show `/sync` returning the actual knock membership event + stripped_state.push(Raw::from_json(to_raw_value(&knock_event).expect( + "All keys are Strings, and CanonicalJsonValue Serialization never fails", + ))); + + services().rooms.state_cache.update_membership( + &room_id, + sender_user, + MembershipState::Knock, + sender_user, + Some(stripped_state), + false, + )?; + } else { + info!("We can knock locally"); + + match services() + .rooms + .state_accessor + .get_join_rules(&room_id)? + .map(|content| content.join_rule) + { + Some(JoinRule::Knock) | Some(JoinRule::KnockRestricted(_)) => (), + _ => { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "You are not allowed to knock on this room.", + )) + } + }; + + let event = RoomMemberEventContent { + membership: MembershipState::Knock, displayname: services().users.displayname(sender_user)?, avatar_url: services().users.avatar_url(sender_user)?, is_direct: None, third_party_invite: None, blurhash: services().users.blurhash(sender_user)?, - reason: reason.clone(), - join_authorized_via_users_server, - }) - .expect("event is valid, we just created it"), - ); + reason: body.reason.clone(), + join_authorized_via_users_server: None, + }; - member_event_stub.remove("event_id"); + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + timestamp: None, + }, + sender_user, + &room_id, + &state_lock, + ) + .await?; + } - ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut member_event_stub, - room_version_id, - ) - .expect("event is valid, we just created it"); - - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&member_event_stub, room_version_id) - .expect("Event format validated when event was hashed") - ); - - let event_id = - ::try_from(event_id).expect("ruma's reference hashes are valid event ids"); - - member_event_stub.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - - Ok((event_id, member_event_stub, restricted_join)) -} - -/// Function to assist performing a membership event that may require help from a remote server -/// -/// If a room id is provided, the servers returned will consist of: -/// - the `via` argument, provided by the client -/// - servers of the senders of the stripped state events we are given -/// - the server in the room id -/// -/// Otherwise, the servers returned will come from the response when resolving the alias. -async fn get_room_id_and_via_servers( - sender_user: &UserId, - room_id_or_alias: OwnedRoomOrAliasId, - via: Vec, -) -> Result<(Vec, OwnedRoomId), Error> { - let (servers, room_id) = match OwnedRoomId::try_from(room_id_or_alias) { - Ok(room_id) => { - let mut servers = via.clone(); - servers.extend( - services() - .rooms - .state_cache - .invite_state(sender_user, &room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); - - servers.push( - room_id - .server_name() - .expect("Room IDs should always have a server name") - .into(), - ); - - (servers, room_id) - } - Err(room_alias) => { - let response = get_alias_helper(room_alias).await?; - - (response.servers, response.room_id) - } - }; - Ok((servers, room_id)) + Ok(knock_room::v3::Response::new(room_id)) } /// # `POST /_matrix/client/r0/rooms/{roomId}/leave` @@ -621,588 +644,6 @@ pub async fn joined_members_route( Ok(joined_members::v3::Response { joined }) } -async fn join_room_by_id_helper( - sender_user: &UserId, - room_id: &RoomId, - reason: Option, - servers: &[OwnedServerName], - _third_party_signed: Option<&ThirdPartySigned>, -) -> Result { - if let Ok(true) = services().rooms.state_cache.is_joined(sender_user, room_id) { - return Ok(join_room_by_id::v3::Response { - room_id: room_id.into(), - }); - } - - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Ask a remote server if we are not participating in this room - if !services() - .rooms - .state_cache - .server_in_room(services().globals.server_name(), room_id)? - { - info!("Joining {room_id} over federation."); - - let (make_join_response, remote_server) = - make_join_request(sender_user, room_id, servers).await?; - - info!("make_join finished"); - - let room_version_id = match make_join_response.room_version { - Some(room_version) - if services() - .globals - .supported_room_versions() - .contains(&room_version) => - { - room_version - } - _ => return Err(Error::BadServerResponse("Room version is not supported")), - }; - - let (event_id, mut join_event, _) = populate_membership_template( - &make_join_response.event, - sender_user, - reason, - &room_version_id, - MembershipState::Join, - )?; - - info!("Asking {remote_server} for send_join"); - let send_join_response = services() - .sending - .send_federation_request( - &remote_server, - federation::membership::create_join_event::v2::Request { - room_id: room_id.to_owned(), - event_id: event_id.to_owned(), - pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, - }, - ) - .await?; - - info!("send_join finished"); - - if let Some(signed_raw) = &send_join_response.room_state.event { - info!("There is a signed event. This room is probably using restricted joins. Adding signature to our event"); - let (signed_event_id, signed_value) = - match gen_event_id_canonical_json(signed_raw, &room_version_id) { - Ok(t) => t, - Err(_) => { - // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); - } - }; - - if signed_event_id != event_id { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Server sent event with wrong event id", - )); - } - - match signed_value["signatures"] - .as_object() - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Server sent invalid signatures type", - )) - .and_then(|e| { - e.get(remote_server.as_str()).ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Server did not send its signature", - )) - }) { - Ok(signature) => { - join_event - .get_mut("signatures") - .expect("we created a valid pdu") - .as_object_mut() - .expect("we created a valid pdu") - .insert(remote_server.to_string(), signature.clone()); - } - Err(e) => { - warn!( - "Server {remote_server} sent invalid signature in sendjoin signatures for event {signed_value:?}: {e:?}", - ); - } - } - } - - services().rooms.short.get_or_create_shortroomid(room_id)?; - - info!("Parsing join event"); - let parsed_join_pdu = PduEvent::from_id_val(&event_id, join_event.clone()) - .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; - - let mut state = HashMap::new(); - let pub_key_map = RwLock::new(BTreeMap::new()); - - info!("Fetching join signing keys"); - services() - .rooms - .event_handler - .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) - .await?; - - info!("Going through send_join response room_state"); - for result in send_join_response - .room_state - .state - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) - { - let (event_id, value) = match result.await { - Ok(t) => t, - Err(_) => continue, - }; - - let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { - warn!("Invalid PDU in send_join response: {} {:?}", e, value); - Error::BadServerResponse("Invalid PDU in send_join response.") - })?; - - services() - .rooms - .outlier - .add_pdu_outlier(&event_id, &value)?; - if let Some(state_key) = &pdu.state_key { - let shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; - state.insert(shortstatekey, pdu.event_id.clone()); - } - } - - info!("Going through send_join response auth_chain"); - for result in send_join_response - .room_state - .auth_chain - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) - { - let (event_id, value) = match result.await { - Ok(t) => t, - Err(_) => continue, - }; - - services() - .rooms - .outlier - .add_pdu_outlier(&event_id, &value)?; - } - - info!("Running send_join auth check"); - let authenticated = state_res::event_auth::auth_check( - &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), - &parsed_join_pdu, - None::, // TODO: third party invite - |k, s| { - services() - .rooms - .timeline - .get_pdu( - state.get( - &services() - .rooms - .short - .get_or_create_shortstatekey(&k.to_string().into(), s) - .ok()?, - )?, - ) - .ok()? - }, - ) - .map_err(|e| { - warn!("Auth check failed: {e}"); - Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed") - })?; - - if !authenticated { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Auth check failed", - )); - } - - info!("Saving state from send_join"); - let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state( - room_id, - Arc::new( - state - .into_iter() - .map(|(k, id)| { - services() - .rooms - .state_compressor - .compress_state_event(k, &id) - }) - .collect::>()?, - ), - )?; - - services() - .rooms - .state - .force_state(room_id, statehash_before_join, new, removed, &state_lock) - .await?; - - info!("Updating joined counts for new room"); - services().rooms.state_cache.update_joined_count(room_id)?; - - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. - let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; - - info!("Appending new room join event"); - services() - .rooms - .timeline - .append_pdu( - &parsed_join_pdu, - join_event, - vec![(*parsed_join_pdu.event_id).to_owned()], - &state_lock, - ) - .await?; - - info!("Setting final room state for new room"); - // We set the room state after inserting the pdu, so that we never have a moment in time - // where events in the current room state do not exist - services() - .rooms - .state - .set_room_state(room_id, statehash_after_join, &state_lock)?; - } else { - info!("We can join locally"); - - let join_rules_event_content = get_join_rules(room_id)?; - - let restriction_rooms = match join_rules_event_content { - Some(RoomJoinRulesEventContent { - join_rule: JoinRule::Restricted(restricted), - }) - | Some(RoomJoinRulesEventContent { - join_rule: JoinRule::KnockRestricted(restricted), - }) => restricted - .allow - .into_iter() - .filter_map(|a| match a { - AllowRule::RoomMembership(r) => Some(r.room_id), - _ => None, - }) - .collect(), - _ => Vec::new(), - }; - - let authorized_user = if restriction_rooms.iter().any(|restriction_room_id| { - services() - .rooms - .state_cache - .is_joined(sender_user, restriction_room_id) - .unwrap_or(false) - }) { - let mut auth_user = None; - for user in services() - .rooms - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .collect::>() - { - if user.server_name() == services().globals.server_name() - && services() - .rooms - .state_accessor - .user_can_invite(room_id, &user, sender_user, &state_lock) - .unwrap_or(false) - { - auth_user = Some(user); - break; - } - } - auth_user - } else { - None - }; - - let event = RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason: reason.clone(), - join_authorized_via_users_server: authorized_user, - }; - - // Try normal join first - let Err(error) = services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - timestamp: None, - }, - sender_user, - room_id, - &state_lock, - ) - .await - else { - return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())); - }; - - if !restriction_rooms.is_empty() - && servers - .iter() - .any(|s| *s != services().globals.server_name()) - { - info!( - "We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements" - ); - let (make_join_response, remote_server) = - make_join_request(sender_user, room_id, servers).await?; - - let room_version_id = match make_join_response.room_version { - Some(room_version_id) - if services() - .globals - .supported_room_versions() - .contains(&room_version_id) => - { - room_version_id - } - _ => return Err(Error::BadServerResponse("Room version is not supported")), - }; - - let (event_id, join_event, restricted_join) = populate_membership_template( - &make_join_response.event, - sender_user, - reason, - &room_version_id, - MembershipState::Join, - )?; - - let send_join_response = services() - .sending - .send_federation_request( - &remote_server, - federation::membership::create_join_event::v2::Request { - room_id: room_id.to_owned(), - event_id: event_id.to_owned(), - pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, - }, - ) - .await?; - - let pdu = if let Some(signed_raw) = send_join_response.room_state.event { - let (signed_event_id, signed_pdu) = - gen_event_id_canonical_json(&signed_raw, &room_version_id)?; - - if signed_event_id != event_id { - return Err(Error::BadServerResponse( - "Server sent event with wrong event id", - )); - } - - signed_pdu - } else if restricted_join { - return Err(Error::BadServerResponse( - "No signed event was returned, despite just performing a restricted join", - )); - } else { - join_event - }; - - drop(state_lock); - let pub_key_map = RwLock::new(BTreeMap::new()); - services() - .rooms - .event_handler - .handle_incoming_pdu(&remote_server, &event_id, room_id, pdu, true, &pub_key_map) - .await?; - } else { - return Err(error); - } - } - - Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) -} - -/// Returns the join rules event content of a room, if there are any and we are aware of it locally -fn get_join_rules(room_id: &RoomId) -> Result, Error> { - let join_rules_event = services().rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomJoinRules, - "", - )?; - - join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str::(join_rules_event.content.get()) - .map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose() -} - -async fn make_join_request( - sender_user: &UserId, - room_id: &RoomId, - servers: &[OwnedServerName], -) -> Result<( - federation::membership::prepare_join_event::v1::Response, - OwnedServerName, -)> { - let mut make_join_response_and_server = Err(Error::BadServerResponse( - "No server available to assist in joining.", - )); - - for remote_server in servers { - if remote_server == services().globals.server_name() { - continue; - } - info!("Asking {remote_server} for make_join"); - let make_join_response = services() - .sending - .send_federation_request( - remote_server, - federation::membership::prepare_join_event::v1::Request { - room_id: room_id.to_owned(), - user_id: sender_user.to_owned(), - ver: services().globals.supported_room_versions(), - }, - ) - .await; - - make_join_response_and_server = make_join_response.map(|r| (r, remote_server.clone())); - - if make_join_response_and_server.is_ok() { - break; - } - } - - make_join_response_and_server -} - -async fn validate_and_add_event_id( - pdu: &RawJsonValue, - room_version: &RoomVersionId, - pub_key_map: &RwLock>, -) -> Result<(OwnedEventId, CanonicalJsonObject)> { - let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&value, room_version) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid PDU format"))? - )) - .expect("ruma's reference hashes are valid event ids"); - - let back_off = |id| async { - match services() - .globals - .bad_event_ratelimiter - .write() - .await - .entry(id) - { - Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - } - }; - - if let Some((time, tries)) = services() - .globals - .bad_event_ratelimiter - .read() - .await - .get(&event_id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", event_id); - return Err(Error::BadServerResponse("bad event, still backing off")); - } - } - - let origin_server_ts = value.get("origin_server_ts").ok_or_else(|| { - error!("Invalid PDU, no origin_server_ts field"); - Error::BadRequest( - ErrorKind::MissingParam, - "Invalid PDU, no origin_server_ts field", - ) - })?; - - let origin_server_ts: MilliSecondsSinceUnixEpoch = { - let ts = origin_server_ts.as_integer().ok_or_else(|| { - Error::BadRequest( - ErrorKind::InvalidParam, - "origin_server_ts must be an integer", - ) - })?; - - MilliSecondsSinceUnixEpoch(i64::from(ts).try_into().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Time must be after the unix epoch") - })?) - }; - - let unfiltered_keys = (*pub_key_map.read().await).clone(); - - let keys = - services() - .globals - .filter_keys_server_map(unfiltered_keys, origin_server_ts, room_version); - - if let Err(e) = ruma::signatures::verify_event(&keys, &value, room_version) { - warn!("Event {} failed verification {:?} {}", event_id, pdu, e); - back_off(event_id).await; - return Err(Error::BadServerResponse("Event failed verification.")); - } - - value.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - - Ok((event_id, value)) -} - pub(crate) async fn invite_helper<'a>( sender_user: &UserId, user_id: &UserId, diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index ec6c06b0..39342d5b 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -10,7 +10,8 @@ use ruma::{ self, v3::{ Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, - LeftRoom, Presence, RoomAccountData, RoomSummary, Rooms, State, Timeline, ToDevice, + KnockState, KnockedRoom, LeftRoom, Presence, RoomAccountData, RoomSummary, Rooms, + State, Timeline, ToDevice, }, v4::{SlidingOp, SlidingSyncRoomHero}, DeviceLists, UnreadNotificationsCount, @@ -503,6 +504,50 @@ async fn sync_helper( ); } + let mut knocked_rooms = BTreeMap::new(); + let all_knocked_rooms: Vec<_> = services() + .rooms + .state_cache + .rooms_knocked(&sender_user) + .collect(); + for result in all_knocked_rooms { + let (room_id, knock_state_events) = result?; + + { + // Get and drop the lock to wait for remaining operations to finish + let mutex_insert = Arc::clone( + services() + .globals + .roomid_mutex_insert + .write() + .await + .entry(room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().await; + drop(insert_lock); + } + + let knock_count = services() + .rooms + .state_cache + .get_knock_count(&room_id, &sender_user)?; + + // knock before last sync + if Some(since) >= knock_count { + continue; + } + + knocked_rooms.insert( + room_id.clone(), + KnockedRoom { + knock_state: KnockState { + events: knock_state_events, + }, + }, + ); + } + for user_id in left_encrypted_users { let dont_share_encrypted_room = services() .rooms @@ -538,7 +583,7 @@ async fn sync_helper( leave: left_rooms, join: joined_rooms, invite: invited_rooms, - knock: BTreeMap::new(), // TODO + knock: knocked_rooms, }, presence: Presence { events: presence_updates diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 8b8e6d30..f8768a9a 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -31,6 +31,7 @@ use ruma::{ }, event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, keys::{claim_keys, get_keys}, + knock::{create_knock_event_template, send_knock}, membership::{ create_invite, create_join_event, create_leave_event, prepare_join_event, prepare_leave_event, @@ -1497,6 +1498,28 @@ pub async fn get_room_state_ids_route( }) } +/// # `GET /_matrix/federation/v1/make_knock/{roomId}/{userId}` +/// +/// Creates a knock template. +pub async fn create_knock_event_template_route( + body: Ruma, +) -> Result { + let (mutex_state, room_version_id) = + member_shake_preamble(&body.sender_servername, &body.room_id).await?; + let state_lock = mutex_state.lock().await; + + Ok(create_knock_event_template::v1::Response { + room_version: room_version_id, + event: create_membership_template( + &body.user_id, + &body.room_id, + None, + MembershipState::Knock, + state_lock, + )?, + }) +} + /// # `GET /_matrix/federation/v1/make_leave/{roomId}/{userId}` /// /// Creates a leave template. @@ -1933,6 +1956,31 @@ pub async fn create_leave_event_route( Ok(create_leave_event::v2::Response {}) } +/// # `PUT /_matrix/federation/v1/send_knock/{roomId}/{eventId}` +/// +/// Submits a signed knock event. +pub async fn create_knock_event_route( + body: Ruma, +) -> Result { + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + room_and_acl_check(&body.room_id, sender_servername)?; + + append_member_pdu( + MembershipState::Knock, + sender_servername, + &body.room_id, + &body.pdu, + ) + .await?; + + Ok(send_knock::v1::Response { + knock_room_state: services().rooms.state.stripped_state(&body.room_id)?, + }) +} + /// Checks whether the given user can join the given room via a restricted join. /// This doesn't check the current user's membership. This should be done externally, /// either by using the state cache or attempting to authorize the event. @@ -2012,44 +2060,54 @@ fn user_can_perform_restricted_join( pub async fn create_invite_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let Ruma:: { + body, + sender_servername, + .. + } = body; + + let create_invite::v2::Request { + room_id, + room_version, + event, + invite_room_state, + .. + } = body; + + let sender_servername = sender_servername.expect("server is authenticated"); services() .rooms .event_handler - .acl_check(sender_servername, &body.room_id)?; - + .acl_check(&sender_servername, &room_id)?; if !services() .globals .supported_room_versions() - .contains(&body.room_version) + .contains(&room_version) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { - room_version: body.room_version.clone(), + room_version: room_version.clone(), }, "Server does not support this room version.", )); } - let mut signed_event = utils::to_canonical_object(&body.event) + let mut signed_event = utils::to_canonical_object(&event) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; ruma::signatures::hash_and_sign_event( services().globals.server_name().as_str(), services().globals.keypair(), &mut signed_event, - &body.room_version, + &room_version, ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; // Generate event id let event_id = EventId::parse(format!( "${}", - ruma::signatures::reference_hash(&signed_event, &body.room_version) + ruma::signatures::reference_hash(&signed_event, &room_version) .expect("Event format validated when event was hashed") )) .expect("ruma's reference hashes are valid event ids"); @@ -2084,9 +2142,9 @@ pub async fn create_invite_route( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "state_key is not a user id."))?; - let mut invite_state = body.invite_room_state.clone(); + let mut invite_state = invite_room_state.clone(); - let mut event: JsonObject = serde_json::from_str(body.event.get()) + let mut event: JsonObject = serde_json::from_str(event.get()) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?; event.insert("event_id".to_owned(), "$dummy".into()); @@ -2102,16 +2160,55 @@ pub async fn create_invite_route( if !services() .rooms .state_cache - .server_in_room(services().globals.server_name(), &body.room_id)? + .server_in_room(services().globals.server_name(), &room_id)? { - services().rooms.state_cache.update_membership( - &body.room_id, - &invited_user, - MembershipState::Invite, - &sender, - Some(invite_state), - true, - )?; + // If the user has already knocked on the room, we take that as the user wanting to join + // the room as soon as their knock is accepted, as recommended by the spec. + // + // https://spec.matrix.org/v1.13/client-server-api/#knocking-on-rooms + if services() + .rooms + .state_cache + .is_knocked(&invited_user, &room_id) + .unwrap_or_default() + { + // We want to try join automatically first, before notifying clients that they were invited. + // We also shouldn't block giving the calling server the response on attempting to join the + // room, since it's not relevant for the caller. + tokio::spawn(async move { + if services().rooms.helpers.join_room_by_id(&invited_user, &room_id, None, &invite_state.iter() .filter_map(|event| event.deserialize().ok()) + .map(|event| event.sender().server_name().to_owned()) + .collect::>() +, None) + .await + .is_err() && + // Checking whether the state has changed since we started this join handshake + services() + .rooms + .state_cache + .is_knocked(&invited_user, &room_id) + .unwrap_or_default() + { + let _ = services().rooms.state_cache.update_membership( + &room_id, + &invited_user, + MembershipState::Invite, + &sender, + Some(invite_state), + true, + ); + } + }); + } else { + services().rooms.state_cache.update_membership( + &room_id, + &invited_user, + MembershipState::Invite, + &sender, + Some(invite_state), + true, + )?; + } } Ok(create_invite::v2::Response { diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 126f4acc..689ff7cd 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -27,6 +27,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.roomuserid_joined.insert(&roomuser_id, &[])?; self.userroomid_invitestate.remove(&userroom_id)?; self.roomuserid_invitecount.remove(&roomuser_id)?; + self.userroomid_knockstate.remove(&userroom_id)?; + self.roomuserid_knockcount.remove(&roomuser_id)?; self.userroomid_leftstate.remove(&userroom_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?; @@ -52,12 +54,40 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { )?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_knockstate.remove(&userroom_id)?; + self.roomuserid_knockcount.remove(&roomuser_id)?; self.userroomid_leftstate.remove(&userroom_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?; Ok(()) } + fn mark_as_knocked( + &self, + user_id: &UserId, + room_id: &RoomId, + last_state: Option>>, + ) -> Result<()> { + let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); + + self.userroomid_knockstate.insert( + &userroom_id, + &serde_json::to_vec(&last_state.unwrap_or_default()) + .expect("state to bytes always works"), + )?; + self.roomuserid_knockcount.insert( + &roomuser_id, + &services().globals.next_count()?.to_be_bytes(), + )?; + self.userroomid_joined.remove(&userroom_id)?; + self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + Ok(()) + } fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); @@ -74,6 +104,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.roomuserid_joined.remove(&roomuser_id)?; self.userroomid_invitestate.remove(&userroom_id)?; self.roomuserid_invitecount.remove(&roomuser_id)?; + self.userroomid_knockstate.remove(&userroom_id)?; + self.roomuserid_knockcount.remove(&roomuser_id)?; Ok(()) } @@ -390,6 +422,21 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { }) } + #[tracing::instrument(skip(self))] + fn get_knock_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_knockcount + .get(&key)? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid knockcount in db.") + })?)) + }) + } + #[tracing::instrument(skip(self))] fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { let mut key = room_id.as_bytes().to_vec(); @@ -440,6 +487,16 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { scan_userroom_id_memberstate_tree(user_id, &self.userroomid_invitestate) } + /// Returns an iterator over all rooms a user has knocked on. + #[allow(clippy::type_complexity)] + #[tracing::instrument(skip(self))] + fn rooms_knocked<'a>( + &'a self, + user_id: &UserId, + ) -> Box>)>> + 'a> { + scan_userroom_id_memberstate_tree(user_id, &self.userroomid_knockstate) + } + #[tracing::instrument(skip(self))] fn invite_state( &self, @@ -461,6 +518,27 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { .transpose() } + #[tracing::instrument(skip(self))] + fn knock_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result>>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(room_id.as_bytes()); + + self.userroomid_knockstate + .get(&key)? + .map(|state| { + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_knockstate."))?; + + Ok(state) + }) + .transpose() + } + #[tracing::instrument(skip(self))] fn left_state( &self, diff --git a/src/database/mod.rs b/src/database/mod.rs index e2bfc2c9..44954e48 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -99,6 +99,8 @@ pub struct KeyValueDatabase { pub(super) roomuseroncejoinedids: Arc, pub(super) userroomid_invitestate: Arc, // InviteState = Vec> pub(super) roomuserid_invitecount: Arc, // InviteCount = Count + pub(super) userroomid_knockstate: Arc, // KnockState = Vec> + pub(super) roomuserid_knockcount: Arc, // KnockCount = Count pub(super) userroomid_leftstate: Arc, pub(super) roomuserid_leftcount: Arc, @@ -313,6 +315,8 @@ impl KeyValueDatabase { roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, + userroomid_knockstate: builder.open_tree("userroomid_knockstate")?, + roomuserid_knockcount: builder.open_tree("roomuserid_knockcount")?, userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, diff --git a/src/main.rs b/src/main.rs index d37cfb21..5669cc0a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -345,6 +345,7 @@ fn routes(config: &Config) -> Router { .ruma_route(client_server::get_alias_route) .ruma_route(client_server::join_room_by_id_route) .ruma_route(client_server::join_room_by_id_or_alias_route) + .ruma_route(client_server::knock_room_route) .ruma_route(client_server::joined_members_route) .ruma_route(client_server::leave_room_route) .ruma_route(client_server::forget_room_route) @@ -460,6 +461,8 @@ fn routes(config: &Config) -> Router { .ruma_route(server_server::create_join_event_v2_route) .ruma_route(server_server::create_leave_event_template_route) .ruma_route(server_server::create_leave_event_route) + .ruma_route(server_server::create_knock_event_template_route) + .ruma_route(server_server::create_knock_event_route) .ruma_route(server_server::create_invite_route) .ruma_route(server_server::get_devices_route) .ruma_route(server_server::get_content_route) diff --git a/src/service/mod.rs b/src/service/mod.rs index 552c71af..c328bf7e 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -73,6 +73,7 @@ impl Services { }, }, event_handler: rooms::event_handler::Service, + helpers: rooms::helpers::Service, lazy_loading: rooms::lazy_loading::Service { db, lazy_load_waiting: Mutex::new(HashMap::new()), diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 95d52ad3..87c3d4c0 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,11 +1,16 @@ mod data; pub use data::Data; +use rand::seq::SliceRandom; use tracing::error; use crate::{services, Error, Result}; use ruma::{ - api::client::error::ErrorKind, + api::{ + appservice, + client::{alias::get_alias, error::ErrorKind}, + federation, + }, events::{ room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, StateEventType, @@ -98,4 +103,72 @@ impl Service { ) -> Box> + 'a> { self.db.local_aliases_for_room(room_id) } + + /// Resolves an alias to a room id, and a set of servers to join or knock via, either locally or over federation + #[tracing::instrument(skip(self))] + pub async fn get_alias_helper( + &self, + room_alias: OwnedRoomAliasId, + ) -> Result { + if room_alias.server_name() != services().globals.server_name() { + let response = services() + .sending + .send_federation_request( + room_alias.server_name(), + federation::query::get_room_information::v1::Request { + room_alias: room_alias.to_owned(), + }, + ) + .await?; + + let mut servers = response.servers; + servers.shuffle(&mut rand::thread_rng()); + + return Ok(get_alias::v3::Response::new(response.room_id, servers)); + } + + let mut room_id = None; + match services().rooms.alias.resolve_local_alias(&room_alias)? { + Some(r) => room_id = Some(r), + None => { + for appservice in services().appservice.read().await.values() { + if appservice.aliases.is_match(room_alias.as_str()) + && matches!( + services() + .sending + .send_appservice_request( + appservice.registration.clone(), + appservice::query::query_room_alias::v1::Request { + room_alias: room_alias.clone(), + }, + ) + .await, + Ok(Some(_opt_result)) + ) + { + room_id = + Some(self.resolve_local_alias(&room_alias)?.ok_or_else(|| { + Error::bad_config("Appservice lied to us. Room does not exist.") + })?); + break; + } + } + } + }; + + let room_id = match room_id { + Some(room_id) => room_id, + None => { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Room with alias not found.", + )) + } + }; + + Ok(get_alias::v3::Response::new( + room_id, + vec![services().globals.server_name().to_owned()], + )) + } } diff --git a/src/service/rooms/helpers/mod.rs b/src/service/rooms/helpers/mod.rs new file mode 100644 index 00000000..e8d5b68a --- /dev/null +++ b/src/service/rooms/helpers/mod.rs @@ -0,0 +1,699 @@ +use std::{ + collections::{hash_map::Entry, BTreeMap, HashMap}, + sync::Arc, + time::{Duration, Instant}, +}; + +use ruma::{ + api::{ + client::{ + error::ErrorKind, + membership::{join_room_by_id, ThirdPartySigned}, + }, + federation, + }, + canonical_json::to_canonical_value, + events::{ + room::{ + join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + }, + TimelineEventType, + }, + state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, + OwnedEventId, OwnedServerName, OwnedUserId, RoomId, RoomVersionId, UserId, +}; +use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; +use tokio::sync::RwLock; +use tracing::{debug, error, info, warn}; + +use crate::{ + service::{ + globals::SigningKeys, + pdu::{gen_event_id_canonical_json, PduBuilder}, + }, + services, utils, Error, PduEvent, Result, +}; + +pub struct Service; + +impl Service { + /// Attempts to join a room. + /// If the room cannot be joined locally, it attempts to join over federation, soley using the + /// specified servers + #[tracing::instrument(skip(self, reason, servers, _third_party_signed))] + pub async fn join_room_by_id( + &self, + sender_user: &UserId, + room_id: &RoomId, + reason: Option, + servers: &[OwnedServerName], + _third_party_signed: Option<&ThirdPartySigned>, + ) -> Result { + if let Ok(true) = services().rooms.state_cache.is_joined(sender_user, room_id) { + return Ok(join_room_by_id::v3::Response { + room_id: room_id.into(), + }); + } + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + // Ask a remote server if we are not participating in this room + if !services() + .rooms + .state_cache + .server_in_room(services().globals.server_name(), room_id)? + { + info!("Joining {room_id} over federation."); + + let (make_join_response, remote_server) = + make_join_request(sender_user, room_id, servers).await?; + + info!("make_join finished"); + + let room_version_id = match make_join_response.room_version { + Some(room_version) + if services() + .globals + .supported_room_versions() + .contains(&room_version) => + { + room_version + } + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; + + let (event_id, mut join_event, _) = self.populate_membership_template( + &make_join_response.event, + sender_user, + reason, + &room_version_id, + MembershipState::Join, + )?; + + info!("Asking {remote_server} for send_join"); + let send_join_response = services() + .sending + .send_federation_request( + &remote_server, + federation::membership::create_join_event::v2::Request { + room_id: room_id.to_owned(), + event_id: event_id.to_owned(), + pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + omit_members: false, + }, + ) + .await?; + + info!("send_join finished"); + + if let Some(signed_raw) = &send_join_response.room_state.event { + info!("There is a signed event. This room is probably using restricted joins. Adding signature to our event"); + let (signed_event_id, signed_value) = + match gen_event_id_canonical_json(signed_raw, &room_version_id) { + Ok(t) => t, + Err(_) => { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + } + }; + + if signed_event_id != event_id { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent event with wrong event id", + )); + } + + match signed_value["signatures"] + .as_object() + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent invalid signatures type", + )) + .and_then(|e| { + e.get(remote_server.as_str()).ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Server did not send its signature", + )) + }) { + Ok(signature) => { + join_event + .get_mut("signatures") + .expect("we created a valid pdu") + .as_object_mut() + .expect("we created a valid pdu") + .insert(remote_server.to_string(), signature.clone()); + } + Err(e) => { + warn!( + "Server {remote_server} sent invalid signature in sendjoin signatures for event {signed_value:?}: {e:?}", + ); + } + } + } + + services().rooms.short.get_or_create_shortroomid(room_id)?; + + info!("Parsing join event"); + let parsed_join_pdu = PduEvent::from_id_val(&event_id, join_event.clone()) + .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; + + let mut state = HashMap::new(); + let pub_key_map = RwLock::new(BTreeMap::new()); + + info!("Fetching join signing keys"); + services() + .rooms + .event_handler + .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) + .await?; + + info!("Going through send_join response room_state"); + for result in send_join_response + .room_state + .state + .iter() + .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) + { + let (event_id, value) = match result.await { + Ok(t) => t, + Err(_) => continue, + }; + + let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { + warn!("Invalid PDU in send_join response: {} {:?}", e, value); + Error::BadServerResponse("Invalid PDU in send_join response.") + })?; + + services() + .rooms + .outlier + .add_pdu_outlier(&event_id, &value)?; + if let Some(state_key) = &pdu.state_key { + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + state.insert(shortstatekey, pdu.event_id.clone()); + } + } + + info!("Going through send_join response auth_chain"); + for result in send_join_response + .room_state + .auth_chain + .iter() + .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) + { + let (event_id, value) = match result.await { + Ok(t) => t, + Err(_) => continue, + }; + + services() + .rooms + .outlier + .add_pdu_outlier(&event_id, &value)?; + } + + info!("Running send_join auth check"); + let authenticated = state_res::event_auth::auth_check( + &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), + &parsed_join_pdu, + None::, // TODO: third party invite + |k, s| { + services() + .rooms + .timeline + .get_pdu( + state.get( + &services() + .rooms + .short + .get_or_create_shortstatekey(&k.to_string().into(), s) + .ok()?, + )?, + ) + .ok()? + }, + ) + .map_err(|e| { + warn!("Auth check failed: {e}"); + Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed") + })?; + + if !authenticated { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Auth check failed", + )); + } + + info!("Saving state from send_join"); + let (statehash_before_join, new, removed) = + services().rooms.state_compressor.save_state( + room_id, + Arc::new( + state + .into_iter() + .map(|(k, id)| { + services() + .rooms + .state_compressor + .compress_state_event(k, &id) + }) + .collect::>()?, + ), + )?; + + services() + .rooms + .state + .force_state(room_id, statehash_before_join, new, removed, &state_lock) + .await?; + + info!("Updating joined counts for new room"); + services().rooms.state_cache.update_joined_count(room_id)?; + + // We append to state before appending the pdu, so we don't have a moment in time with the + // pdu without it's state. This is okay because append_pdu can't fail. + let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; + + info!("Appending new room join event"); + services() + .rooms + .timeline + .append_pdu( + &parsed_join_pdu, + join_event, + vec![(*parsed_join_pdu.event_id).to_owned()], + &state_lock, + ) + .await?; + + info!("Setting final room state for new room"); + // We set the room state after inserting the pdu, so that we never have a moment in time + // where events in the current room state do not exist + services() + .rooms + .state + .set_room_state(room_id, statehash_after_join, &state_lock)?; + } else { + info!("We can join locally"); + + let join_rules_event_content = + services().rooms.state_accessor.get_join_rules(room_id)?; + + let restriction_rooms = match join_rules_event_content { + Some(RoomJoinRulesEventContent { + join_rule: JoinRule::Restricted(restricted), + }) + | Some(RoomJoinRulesEventContent { + join_rule: JoinRule::KnockRestricted(restricted), + }) => restricted + .allow + .into_iter() + .filter_map(|a| match a { + AllowRule::RoomMembership(r) => Some(r.room_id), + _ => None, + }) + .collect(), + _ => Vec::new(), + }; + + let authorized_user = if restriction_rooms.iter().any(|restriction_room_id| { + services() + .rooms + .state_cache + .is_joined(sender_user, restriction_room_id) + .unwrap_or(false) + }) { + let mut auth_user = None; + for user in services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(Result::ok) + .collect::>() + { + if user.server_name() == services().globals.server_name() + && services() + .rooms + .state_accessor + .user_can_invite(room_id, &user, sender_user, &state_lock) + .unwrap_or(false) + { + auth_user = Some(user); + break; + } + } + auth_user + } else { + None + }; + + let event = RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason: reason.clone(), + join_authorized_via_users_server: authorized_user, + }; + + // Try normal join first + let Err(error) = services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + timestamp: None, + }, + sender_user, + room_id, + &state_lock, + ) + .await + else { + return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())); + }; + + if !restriction_rooms.is_empty() + && servers + .iter() + .any(|s| *s != services().globals.server_name()) + { + info!( + "We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements" + ); + let (make_join_response, remote_server) = + make_join_request(sender_user, room_id, servers).await?; + + let room_version_id = match make_join_response.room_version { + Some(room_version_id) + if services() + .globals + .supported_room_versions() + .contains(&room_version_id) => + { + room_version_id + } + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; + + let (event_id, join_event, restricted_join) = self.populate_membership_template( + &make_join_response.event, + sender_user, + reason, + &room_version_id, + MembershipState::Join, + )?; + + let send_join_response = services() + .sending + .send_federation_request( + &remote_server, + federation::membership::create_join_event::v2::Request { + room_id: room_id.to_owned(), + event_id: event_id.to_owned(), + pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + omit_members: false, + }, + ) + .await?; + + let pdu = if let Some(signed_raw) = send_join_response.room_state.event { + let (signed_event_id, signed_pdu) = + gen_event_id_canonical_json(&signed_raw, &room_version_id)?; + + if signed_event_id != event_id { + return Err(Error::BadServerResponse( + "Server sent event with wrong event id", + )); + } + + signed_pdu + } else if restricted_join { + return Err(Error::BadServerResponse( + "No signed event was returned, despite just performing a restricted join", + )); + } else { + join_event + }; + + drop(state_lock); + let pub_key_map = RwLock::new(BTreeMap::new()); + services() + .rooms + .event_handler + .handle_incoming_pdu( + &remote_server, + &event_id, + room_id, + pdu, + true, + &pub_key_map, + ) + .await?; + } else { + return Err(error); + } + } + + Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) + } + + /// Takes a membership template, as returned from the `/federation/*/make_*` endpoints, and + /// populates them to the point as to where they are a full pdu, ready to be appended to the timeline + /// + /// Returns the event id, the pdu, and whether this event is a restricted join + pub fn populate_membership_template( + &self, + member_template: &RawJsonValue, + sender_user: &UserId, + reason: Option, + room_version_id: &RoomVersionId, + membership: MembershipState, + ) -> Result<(OwnedEventId, BTreeMap, bool), Error> { + let mut member_event_stub: CanonicalJsonObject = + serde_json::from_str(member_template.get()).map_err(|_| { + Error::BadServerResponse("Invalid make_knock event json received from server.") + })?; + + let join_authorized_via_users_server = member_event_stub + .get("content") + .map(|s| { + s.as_object()? + .get("join_authorised_via_users_server")? + .as_str() + }) + .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); + + let restricted_join = join_authorized_via_users_server.is_some(); + + member_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + ); + + member_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + + member_event_stub.insert( + "content".to_owned(), + to_canonical_value(RoomMemberEventContent { + membership, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason: reason.clone(), + join_authorized_via_users_server, + }) + .expect("event is valid, we just created it"), + ); + + member_event_stub.remove("event_id"); + + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut member_event_stub, + room_version_id, + ) + .expect("event is valid, we just created it"); + + let event_id = format!( + "${}", + ruma::signatures::reference_hash(&member_event_stub, room_version_id) + .expect("Event format validated when event was hashed") + ); + + let event_id = ::try_from(event_id) + .expect("ruma's reference hashes are valid event ids"); + + member_event_stub.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); + + Ok((event_id, member_event_stub, restricted_join)) + } +} + +async fn make_join_request( + sender_user: &UserId, + room_id: &RoomId, + servers: &[OwnedServerName], +) -> Result<( + federation::membership::prepare_join_event::v1::Response, + OwnedServerName, +)> { + let mut make_join_response_and_server = Err(Error::BadServerResponse( + "No server available to assist in joining.", + )); + + for remote_server in servers { + if remote_server == services().globals.server_name() { + continue; + } + info!("Asking {remote_server} for make_join"); + let make_join_response = services() + .sending + .send_federation_request( + remote_server, + federation::membership::prepare_join_event::v1::Request { + room_id: room_id.to_owned(), + user_id: sender_user.to_owned(), + ver: services().globals.supported_room_versions(), + }, + ) + .await; + + make_join_response_and_server = make_join_response.map(|r| (r, remote_server.clone())); + + if make_join_response_and_server.is_ok() { + break; + } + } + + make_join_response_and_server +} +async fn validate_and_add_event_id( + pdu: &RawJsonValue, + room_version: &RoomVersionId, + pub_key_map: &RwLock>, +) -> Result<(OwnedEventId, CanonicalJsonObject)> { + let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + let event_id = EventId::parse(format!( + "${}", + ruma::signatures::reference_hash(&value, room_version) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid PDU format"))? + )) + .expect("ruma's reference hashes are valid event ids"); + + let back_off = |id| async { + match services() + .globals + .bad_event_ratelimiter + .write() + .await + .entry(id) + { + Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + } + Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + } + }; + + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .await + .get(&event_id) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + debug!("Backing off from {}", event_id); + return Err(Error::BadServerResponse("bad event, still backing off")); + } + } + + let origin_server_ts = value.get("origin_server_ts").ok_or_else(|| { + error!("Invalid PDU, no origin_server_ts field"); + Error::BadRequest( + ErrorKind::MissingParam, + "Invalid PDU, no origin_server_ts field", + ) + })?; + + let origin_server_ts: MilliSecondsSinceUnixEpoch = { + let ts = origin_server_ts.as_integer().ok_or_else(|| { + Error::BadRequest( + ErrorKind::InvalidParam, + "origin_server_ts must be an integer", + ) + })?; + + MilliSecondsSinceUnixEpoch(i64::from(ts).try_into().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Time must be after the unix epoch") + })?) + }; + + let unfiltered_keys = (*pub_key_map.read().await).clone(); + + let keys = + services() + .globals + .filter_keys_server_map(unfiltered_keys, origin_server_ts, room_version); + + if let Err(e) = ruma::signatures::verify_event(&keys, &value, room_version) { + warn!("Event {} failed verification {:?} {}", event_id, pdu, e); + back_off(event_id).await; + return Err(Error::BadServerResponse("Event failed verification.")); + } + + value.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); + + Ok((event_id, value)) +} diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index f0739841..2f0c3347 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -3,6 +3,7 @@ pub mod auth_chain; pub mod directory; pub mod edus; pub mod event_handler; +pub mod helpers; pub mod lazy_loading; pub mod metadata; pub mod outlier; @@ -45,6 +46,7 @@ pub struct Service { pub directory: directory::Service, pub edus: edus::Service, pub event_handler: event_handler::Service, + pub helpers: helpers::Service, pub lazy_loading: lazy_loading::Service, pub metadata: metadata::Service, pub outlier: outlier::Service, diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index e1bcd3c3..fcf542c3 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -426,8 +426,8 @@ impl Service { }) } - /// Returns the join rule for a given room - pub fn get_join_rule( + /// Returns the space-room join rule for a given room + pub fn get_space_room_join_rule( &self, current_room: &RoomId, ) -> Result<(SpaceRoomJoinRule, Vec), Error> { @@ -450,6 +450,26 @@ impl Service { .unwrap_or((SpaceRoomJoinRule::Invite, vec![]))) } + /// Returns the join rules event content of a room, if there are any and we are aware of it locally + #[tracing::instrument(skip(self))] + pub fn get_join_rules( + &self, + room_id: &RoomId, + ) -> Result, Error> { + let join_rules_event = self.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; + + join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str::(join_rules_event.content.get()) + .map_err(|e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database("Invalid join rules event in db.") + }) + }) + .transpose() + } + /// Returns an empty vec if not a restricted room pub fn allowed_room_ids(&self, join_rule: JoinRule) -> Vec { let mut room_ids = vec![]; diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index b511919a..3ee73dfa 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -16,6 +16,12 @@ pub trait Data: Send + Sync { room_id: &RoomId, last_state: Option>>, ) -> Result<()>; + fn mark_as_knocked( + &self, + user_id: &UserId, + room_id: &RoomId, + last_state: Option>>, + ) -> Result<()>; fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; @@ -65,6 +71,8 @@ pub trait Data: Send + Sync { fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; + fn get_knock_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; + fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; /// Returns an iterator over all rooms this user joined. @@ -80,12 +88,25 @@ pub trait Data: Send + Sync { user_id: &UserId, ) -> Box>)>> + 'a>; + /// Returns an iterator over all rooms a user has knocked on. + #[allow(clippy::type_complexity)] + fn rooms_knocked<'a>( + &'a self, + user_id: &UserId, + ) -> Box>)>> + 'a>; + fn invite_state( &self, user_id: &UserId, room_id: &RoomId, ) -> Result>>>; + fn knock_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result>>>; + fn left_state( &self, user_id: &UserId, @@ -105,5 +126,7 @@ pub trait Data: Send + Sync { fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result; } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index d8fa73b8..5cfd9738 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -12,7 +12,7 @@ use ruma::{ RoomAccountDataEventType, StateEventType, }, serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + OwnedRoomId, OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; use tracing::warn; @@ -185,6 +185,9 @@ impl Service { self.db.mark_as_invited(user_id, room_id, last_state)?; } + MembershipState::Knock => { + self.db.mark_as_knocked(user_id, room_id, last_state)?; + } MembershipState::Leave | MembershipState::Ban => { self.db.mark_as_left(user_id, room_id)?; } @@ -290,6 +293,11 @@ impl Service { self.db.get_invite_count(room_id, user_id) } + #[tracing::instrument(skip(self))] + pub fn get_knock_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + self.db.get_knock_count(room_id, user_id) + } + #[tracing::instrument(skip(self))] pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { self.db.get_left_count(room_id, user_id) @@ -313,6 +321,15 @@ impl Service { self.db.rooms_invited(user_id) } + /// Returns an iterator over all rooms a user has knocked on. + #[tracing::instrument(skip(self))] + pub fn rooms_knocked<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator>)>> + 'a { + self.db.rooms_knocked(user_id) + } + #[tracing::instrument(skip(self))] pub fn invite_state( &self, @@ -322,6 +339,15 @@ impl Service { self.db.invite_state(user_id, room_id) } + #[tracing::instrument(skip(self))] + pub fn knock_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result>>> { + self.db.knock_state(user_id, room_id) + } + #[tracing::instrument(skip(self))] pub fn left_state( &self, @@ -355,8 +381,60 @@ impl Service { self.db.is_invited(user_id, room_id) } + #[tracing::instrument(skip(self))] + pub fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result { + self.db.is_knocked(user_id, room_id) + } + #[tracing::instrument(skip(self))] pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_left(user_id, room_id) } + + /// Function to assist performing a membership event that may require help from a remote server + /// + /// If a room id is provided, the servers returned will consist of: + /// - the `via` argument, provided by the client + /// - servers of the senders of the stripped state events we are given + /// - the server in the room id + /// + /// Otherwise, the servers returned will come from the response when resolving the alias. + #[tracing::instrument(skip(self))] + pub async fn get_room_id_and_via_servers( + &self, + sender_user: &UserId, + room_id_or_alias: OwnedRoomOrAliasId, + via: Vec, + ) -> Result<(Vec, OwnedRoomId), Error> { + let (servers, room_id) = match OwnedRoomId::try_from(room_id_or_alias) { + Ok(room_id) => { + let mut servers = via; + servers.extend( + self.invite_state(sender_user, &room_id) + .transpose() + .or_else(|| self.knock_state(sender_user, &room_id).transpose()) + .transpose()? + .unwrap_or_default() + .iter() + .filter_map(|event| event.deserialize().ok()) + .map(|event| event.sender().server_name().to_owned()), + ); + + servers.push( + room_id + .server_name() + .expect("Room IDs should always have a server name") + .into(), + ); + + (servers, room_id) + } + Err(room_alias) => { + let response = services().rooms.alias.get_alias_helper(room_alias).await?; + + (response.servers, response.room_id) + } + }; + Ok((servers, room_id)) + } } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 78dccd0a..7615aed1 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -451,17 +451,22 @@ impl Service { .map_err(|_| Error::bad_database("Invalid content in pdu."))?; let stripped_state = match content.membership { - MembershipState::Invite => { + MembershipState::Invite | MembershipState::Knock => { let mut state = services().rooms.state.stripped_state(&pdu.room_id)?; - // So that clients can get info about who invitied them, the reason, when, etc. + // So that clients can get info about who invitied them (not relevant for knocking), the reason, when, etc. state.push(pdu.to_stripped_state_event()); Some(state) } _ => None, }; - // Update our membership info, we do this here incase a user is invited - // and immediately leaves we need the DB to record the invite event for auth + // Here we don't attempt to join if the previous membership was knock and the + // new one is join, like we do for `/federation/*/invite`, as not only are there + // implementation difficulties due to callers not implementing `Send`, but + // invites we recieve which aren't over `/invite` must have been due to a + // database reset or switching server implementations, which means we probably + // shouldn't be joining automatically anyways, since it may surprise users to + // suddenly join rooms which clients didn't even show as being knocked on before. services().rooms.state_cache.update_membership( &pdu.room_id, &target_user_id,