diff --git a/Cargo.lock b/Cargo.lock index 4cca6c5a..a6d9d3a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2190,7 +2190,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.12.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "assign", "js_int", @@ -2210,7 +2210,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.12.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "js_int", "ruma-common", @@ -2222,7 +2222,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.20.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "as_variant", "assign", @@ -2245,7 +2245,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.15.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "as_variant", "base64 0.22.1", @@ -2276,7 +2276,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.30.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "as_variant", "indexmap 2.2.6", @@ -2299,7 +2299,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.11.0" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "bytes", "http 1.1.0", @@ -2317,7 +2317,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.10.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "js_int", "thiserror 2.0.11", @@ -2326,7 +2326,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.15.1" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "cfg-if", "proc-macro-crate", @@ -2341,7 +2341,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.11.0" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "js_int", "ruma-common", @@ -2353,7 +2353,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.5.0" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "headers", "http 1.1.0", @@ -2366,7 +2366,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.17.0" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -2382,7 +2382,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.13.0" -source = "git+https://github.com/ruma/ruma.git#afaf132362fe6195556a872351a70337e97ab755" +source = "git+https://github.com/ruma/ruma.git#9e6099161d4ed295e694fa0d5de2b28a23840a4f" dependencies = [ "js_int", "ruma-common", diff --git a/Cargo.toml b/Cargo.toml index 3e463174..e5fbd41a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ repository = "https://gitlab.com/famedly/conduit" version = "0.10.0-alpha" # See also `rust-toolchain.toml` -rust-version = "1.80.0" +rust-version = "1.81.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/complement/Dockerfile b/complement/Dockerfile index 0bf0cfcd..ce067ec3 100644 --- a/complement/Dockerfile +++ b/complement/Dockerfile @@ -1,4 +1,4 @@ -FROM rust:1.80.0 +FROM rust:1.81.0 WORKDIR /workdir diff --git a/flake.nix b/flake.nix index 4b50a7ec..df05bf86 100644 --- a/flake.nix +++ b/flake.nix @@ -59,7 +59,7 @@ file = ./rust-toolchain.toml; # See also `rust-toolchain.toml` - sha256 = "sha256-6eN/GKzjVSjEhGO9FhWObkRFaE1Jf+uqMSdQnb8lcB4="; + sha256 = "sha256-VZZnlyP69+Y3crrLHQyJirqlHrTtGTsyiSnZB8jEvVo="; }; }); in diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 995b142b..465ffdee 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -9,7 +9,7 @@ # If you're having trouble making the relevant changes, bug a maintainer. [toolchain] -channel = "1.80.0" +channel = "1.81.0" components = [ # For rust-analyzer "rust-src", diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index 06fcc182..6a7e8338 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -1,15 +1,7 @@ use crate::{services, Error, Result, Ruma}; -use rand::seq::SliceRandom; -use ruma::{ - api::{ - appservice, - client::{ - alias::{create_alias, delete_alias, get_alias}, - error::ErrorKind, - }, - federation, - }, - OwnedRoomAliasId, +use ruma::api::client::{ + alias::{create_alias, delete_alias, get_alias}, + error::ErrorKind, }; /// # `PUT /_matrix/client/r0/directory/room/{roomAlias}` @@ -115,75 +107,9 @@ pub async fn delete_alias_route( pub async fn get_alias_route( body: Ruma, ) -> Result { - get_alias_helper(body.body.room_alias).await -} - -pub(crate) async fn get_alias_helper( - room_alias: OwnedRoomAliasId, -) -> Result { - if room_alias.server_name() != services().globals.server_name() { - let response = services() - .sending - .send_federation_request( - room_alias.server_name(), - federation::query::get_room_information::v1::Request { - room_alias: room_alias.to_owned(), - }, - ) - .await?; - - let mut servers = response.servers; - servers.shuffle(&mut rand::thread_rng()); - - return Ok(get_alias::v3::Response::new(response.room_id, servers)); - } - - let mut room_id = None; - match services().rooms.alias.resolve_local_alias(&room_alias)? { - Some(r) => room_id = Some(r), - None => { - for appservice in services().appservice.read().await.values() { - if appservice.aliases.is_match(room_alias.as_str()) - && matches!( - services() - .sending - .send_appservice_request( - appservice.registration.clone(), - appservice::query::query_room_alias::v1::Request { - room_alias: room_alias.clone(), - }, - ) - .await, - Ok(Some(_opt_result)) - ) - { - room_id = Some( - services() - .rooms - .alias - .resolve_local_alias(&room_alias)? - .ok_or_else(|| { - Error::bad_config("Appservice lied to us. Room does not exist.") - })?, - ); - break; - } - } - } - }; - - let room_id = match room_id { - Some(room_id) => room_id, - None => { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Room with alias not found.", - )) - } - }; - - Ok(get_alias::v3::Response::new( - room_id, - vec![services().globals.server_name().to_owned()], - )) + services() + .rooms + .alias + .get_alias_helper(body.body.room_alias) + .await } diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 520bfa00..3cc50274 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -2,44 +2,38 @@ use ruma::{ api::{ client::{ error::ErrorKind, + knock::knock_room, membership::{ ban_user, forget_room, get_member_events, invite_user, join_room_by_id, join_room_by_id_or_alias, joined_members, joined_rooms, kick_user, leave_room, - unban_user, ThirdPartySigned, + unban_user, }, }, federation::{self, membership::create_invite}, }, - canonical_json::to_canonical_value, events::{ room::{ - join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent}, + join_rules::JoinRule, member::{MembershipState, RoomMemberEventContent}, }, StateEventType, TimelineEventType, }, - state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, - OwnedEventId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, RoomVersionId, UserId, + serde::Raw, + CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedServerName, RoomId, UserId, }; -use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; +use serde_json::value::to_raw_value; use std::{ - collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, + collections::{BTreeMap, HashSet}, sync::Arc, - time::{Duration, Instant}, }; use tokio::sync::RwLock; -use tracing::{debug, error, info, warn}; +use tracing::{error, info, warn}; use crate::{ - service::{ - globals::SigningKeys, - pdu::{gen_event_id_canonical_json, PduBuilder}, - }, + service::pdu::{gen_event_id_canonical_json, PduBuilder}, services, utils, Error, PduEvent, Result, Ruma, }; -use super::get_alias_helper; - /// # `POST /_matrix/client/r0/rooms/{roomId}/join` /// /// Tries to join the sender user into a room. @@ -49,38 +43,35 @@ use super::get_alias_helper; pub async fn join_room_by_id_route( body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let Ruma:: { + body, sender_user, .. + } = body; - let mut servers = Vec::new(); // There is no body.server_name for /roomId/join - servers.extend( - services() - .rooms - .state_cache - .invite_state(sender_user, &body.room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); + let join_room_by_id::v3::Request { + room_id, + reason, + third_party_signed, + } = body; - servers.push( - body.room_id - .server_name() - .expect("Room IDs should always have a server name") - .into(), - ); + let sender_user = sender_user.as_ref().expect("user is authenticated"); - join_room_by_id_helper( - body.sender_user.as_deref(), - &body.room_id, - body.reason.clone(), - &servers, - body.third_party_signed.as_ref(), - ) - .await + let (servers, room_id) = services() + .rooms + .state_cache + .get_room_id_and_via_servers(sender_user, room_id.into(), vec![]) + .await?; + + services() + .rooms + .helpers + .join_room_by_id( + sender_user, + &room_id, + reason.clone(), + &servers, + third_party_signed.as_ref(), + ) + .await } /// # `POST /_matrix/client/r0/join/{roomIdOrAlias}` @@ -95,53 +86,188 @@ pub async fn join_room_by_id_or_alias_route( let sender_user = body.sender_user.as_deref().expect("user is authenticated"); let body = body.body; - let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias) { - Ok(room_id) => { - let mut servers = body.via.clone(); - servers.extend( - services() - .rooms - .state_cache - .invite_state(sender_user, &room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); + let (servers, room_id) = services() + .rooms + .state_cache + .get_room_id_and_via_servers(sender_user, body.room_id_or_alias, body.via) + .await?; - servers.push( - room_id - .server_name() - .expect("Room IDs should always have a server name") - .into(), - ); - - (servers, room_id) - } - Err(room_alias) => { - let response = get_alias_helper(room_alias).await?; - - (response.servers, response.room_id) - } - }; - - let join_room_response = join_room_by_id_helper( - Some(sender_user), - &room_id, - body.reason.clone(), - &servers, - body.third_party_signed.as_ref(), - ) - .await?; + let join_room_response = services() + .rooms + .helpers + .join_room_by_id( + sender_user, + &room_id, + body.reason.clone(), + &servers, + body.third_party_signed.as_ref(), + ) + .await?; Ok(join_room_by_id_or_alias::v3::Response { room_id: join_room_response.room_id, }) } +/// # `POST /_matrix/client/v3/knock/{roomIdOrAlias}` +/// +/// Tries to knock on a room. +/// +/// - If the server knowns about this room: creates the knock event and does auth rules locally +/// - If the server does not know about the room: asks other servers over federation +pub async fn knock_room_route( + body: Ruma, +) -> Result { + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); + let body = body.body; + + let (servers, room_id) = services() + .rooms + .state_cache + .get_room_id_and_via_servers(sender_user, body.room_id_or_alias, body.via) + .await?; + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + // Ask a remote server if we are not participating in this room + if !services() + .rooms + .state_cache + .server_in_room(services().globals.server_name(), &room_id)? + { + info!("Knocking on {room_id} over federation."); + + let mut make_knock_response_and_server = Err(Error::BadServerResponse( + "No server available to assist in knocking.", + )); + + for remote_server in servers { + if remote_server == services().globals.server_name() { + continue; + } + info!("Asking {remote_server} for make_knock"); + let make_join_response = services() + .sending + .send_federation_request( + &remote_server, + federation::knock::create_knock_event_template::v1::Request { + room_id: room_id.to_owned(), + user_id: sender_user.to_owned(), + ver: services().globals.supported_room_versions(), + }, + ) + .await; + + if let Ok(make_knock_response) = make_join_response { + make_knock_response_and_server = Ok((make_knock_response, remote_server.clone())); + + break; + } + } + + let (knock_template, remote_server) = make_knock_response_and_server?; + + info!("make_knock finished"); + + let room_version_id = knock_template.room_version; + + let (event_id, knock_event, _) = services().rooms.helpers.populate_membership_template( + &knock_template.event, + sender_user, + body.reason, + &room_version_id, + MembershipState::Knock, + )?; + + info!("Asking {remote_server} for send_knock"); + let send_kock_response = services() + .sending + .send_federation_request( + &remote_server, + federation::knock::send_knock::v1::Request { + room_id: room_id.to_owned(), + event_id: event_id.to_owned(), + pdu: PduEvent::convert_to_outgoing_federation_event(knock_event.clone()), + }, + ) + .await?; + + info!("send_knock finished"); + + let mut stripped_state = send_kock_response.knock_room_state; + // Not sure how useful this is in reality, but spec examples show `/sync` returning the actual knock membership event + stripped_state.push(Raw::from_json(to_raw_value(&knock_event).expect( + "All keys are Strings, and CanonicalJsonValue Serialization never fails", + ))); + + services().rooms.state_cache.update_membership( + &room_id, + sender_user, + MembershipState::Knock, + sender_user, + Some(stripped_state), + false, + )?; + } else { + info!("We can knock locally"); + + match services() + .rooms + .state_accessor + .get_join_rules(&room_id)? + .map(|content| content.join_rule) + { + Some(JoinRule::Knock) | Some(JoinRule::KnockRestricted(_)) => (), + _ => { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "You are not allowed to knock on this room.", + )) + } + }; + + let event = RoomMemberEventContent { + membership: MembershipState::Knock, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason: body.reason.clone(), + join_authorized_via_users_server: None, + }; + + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + timestamp: None, + }, + sender_user, + &room_id, + &state_lock, + ) + .await?; + } + + Ok(knock_room::v3::Response::new(room_id)) +} + /// # `POST /_matrix/client/r0/rooms/{roomId}/leave` /// /// Tries to leave the sender user from a room. @@ -518,712 +644,6 @@ pub async fn joined_members_route( Ok(joined_members::v3::Response { joined }) } -async fn join_room_by_id_helper( - sender_user: Option<&UserId>, - room_id: &RoomId, - reason: Option, - servers: &[OwnedServerName], - _third_party_signed: Option<&ThirdPartySigned>, -) -> Result { - let sender_user = sender_user.expect("user is authenticated"); - - if let Ok(true) = services().rooms.state_cache.is_joined(sender_user, room_id) { - return Ok(join_room_by_id::v3::Response { - room_id: room_id.into(), - }); - } - - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Ask a remote server if we are not participating in this room - if !services() - .rooms - .state_cache - .server_in_room(services().globals.server_name(), room_id)? - { - info!("Joining {room_id} over federation."); - - let (make_join_response, remote_server) = - make_join_request(sender_user, room_id, servers).await?; - - info!("make_join finished"); - - let room_version_id = match make_join_response.room_version { - Some(room_version) - if services() - .globals - .supported_room_versions() - .contains(&room_version) => - { - room_version - } - _ => return Err(Error::BadServerResponse("Room version is not supported")), - }; - - let mut join_event_stub: CanonicalJsonObject = - serde_json::from_str(make_join_response.event.get()).map_err(|_| { - Error::BadServerResponse("Invalid make_join event json received from server.") - })?; - - let join_authorized_via_users_server = join_event_stub - .get("content") - .map(|s| { - s.as_object()? - .get("join_authorised_via_users_server")? - .as_str() - }) - .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); - - // TODO: Is origin needed? - join_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), - ); - join_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - join_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason, - join_authorized_via_users_server, - }) - .expect("event is valid, we just created it"), - ); - - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms - join_event_stub.remove("event_id"); - - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present - ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut join_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); - - // Generate event id - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("Event format validated when event was hashed") - ); - let event_id = <&EventId>::try_from(event_id.as_str()) - .expect("ruma's reference hashes are valid event ids"); - - // Add event_id back - join_event_stub.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - - // It has enough fields to be called a proper event now - let mut join_event = join_event_stub; - - info!("Asking {remote_server} for send_join"); - let send_join_response = services() - .sending - .send_federation_request( - &remote_server, - federation::membership::create_join_event::v2::Request { - room_id: room_id.to_owned(), - event_id: event_id.to_owned(), - pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, - }, - ) - .await?; - - info!("send_join finished"); - - if let Some(signed_raw) = &send_join_response.room_state.event { - info!("There is a signed event. This room is probably using restricted joins. Adding signature to our event"); - let (signed_event_id, signed_value) = - match gen_event_id_canonical_json(signed_raw, &room_version_id) { - Ok(t) => t, - Err(_) => { - // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); - } - }; - - if signed_event_id != event_id { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Server sent event with wrong event id", - )); - } - - match signed_value["signatures"] - .as_object() - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Server sent invalid signatures type", - )) - .and_then(|e| { - e.get(remote_server.as_str()).ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Server did not send its signature", - )) - }) { - Ok(signature) => { - join_event - .get_mut("signatures") - .expect("we created a valid pdu") - .as_object_mut() - .expect("we created a valid pdu") - .insert(remote_server.to_string(), signature.clone()); - } - Err(e) => { - warn!( - "Server {remote_server} sent invalid signature in sendjoin signatures for event {signed_value:?}: {e:?}", - ); - } - } - } - - services().rooms.short.get_or_create_shortroomid(room_id)?; - - info!("Parsing join event"); - let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) - .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; - - let mut state = HashMap::new(); - let pub_key_map = RwLock::new(BTreeMap::new()); - - info!("Fetching join signing keys"); - services() - .rooms - .event_handler - .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) - .await?; - - info!("Going through send_join response room_state"); - for result in send_join_response - .room_state - .state - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) - { - let (event_id, value) = match result.await { - Ok(t) => t, - Err(_) => continue, - }; - - let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { - warn!("Invalid PDU in send_join response: {} {:?}", e, value); - Error::BadServerResponse("Invalid PDU in send_join response.") - })?; - - services() - .rooms - .outlier - .add_pdu_outlier(&event_id, &value)?; - if let Some(state_key) = &pdu.state_key { - let shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; - state.insert(shortstatekey, pdu.event_id.clone()); - } - } - - info!("Going through send_join response auth_chain"); - for result in send_join_response - .room_state - .auth_chain - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) - { - let (event_id, value) = match result.await { - Ok(t) => t, - Err(_) => continue, - }; - - services() - .rooms - .outlier - .add_pdu_outlier(&event_id, &value)?; - } - - info!("Running send_join auth check"); - let authenticated = state_res::event_auth::auth_check( - &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), - &parsed_join_pdu, - None::, // TODO: third party invite - |k, s| { - services() - .rooms - .timeline - .get_pdu( - state.get( - &services() - .rooms - .short - .get_or_create_shortstatekey(&k.to_string().into(), s) - .ok()?, - )?, - ) - .ok()? - }, - ) - .map_err(|e| { - warn!("Auth check failed: {e}"); - Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed") - })?; - - if !authenticated { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Auth check failed", - )); - } - - info!("Saving state from send_join"); - let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state( - room_id, - Arc::new( - state - .into_iter() - .map(|(k, id)| { - services() - .rooms - .state_compressor - .compress_state_event(k, &id) - }) - .collect::>()?, - ), - )?; - - services() - .rooms - .state - .force_state(room_id, statehash_before_join, new, removed, &state_lock) - .await?; - - info!("Updating joined counts for new room"); - services().rooms.state_cache.update_joined_count(room_id)?; - - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. - let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; - - info!("Appending new room join event"); - services() - .rooms - .timeline - .append_pdu( - &parsed_join_pdu, - join_event, - vec![(*parsed_join_pdu.event_id).to_owned()], - &state_lock, - ) - .await?; - - info!("Setting final room state for new room"); - // We set the room state after inserting the pdu, so that we never have a moment in time - // where events in the current room state do not exist - services() - .rooms - .state - .set_room_state(room_id, statehash_after_join, &state_lock)?; - } else { - info!("We can join locally"); - - let join_rules_event = services().rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomJoinRules, - "", - )?; - - let join_rules_event_content: Option = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose()?; - - let restriction_rooms = match join_rules_event_content { - Some(RoomJoinRulesEventContent { - join_rule: JoinRule::Restricted(restricted), - }) - | Some(RoomJoinRulesEventContent { - join_rule: JoinRule::KnockRestricted(restricted), - }) => restricted - .allow - .into_iter() - .filter_map(|a| match a { - AllowRule::RoomMembership(r) => Some(r.room_id), - _ => None, - }) - .collect(), - _ => Vec::new(), - }; - - let authorized_user = if restriction_rooms.iter().any(|restriction_room_id| { - services() - .rooms - .state_cache - .is_joined(sender_user, restriction_room_id) - .unwrap_or(false) - }) { - let mut auth_user = None; - for user in services() - .rooms - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .collect::>() - { - if user.server_name() == services().globals.server_name() - && services() - .rooms - .state_accessor - .user_can_invite(room_id, &user, sender_user, &state_lock) - .unwrap_or(false) - { - auth_user = Some(user); - break; - } - } - auth_user - } else { - None - }; - - let event = RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason: reason.clone(), - join_authorized_via_users_server: authorized_user, - }; - - // Try normal join first - let error = match services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - timestamp: None, - }, - sender_user, - room_id, - &state_lock, - ) - .await - { - Ok(_event_id) => return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())), - Err(e) => e, - }; - - if !restriction_rooms.is_empty() - && servers - .iter() - .any(|s| *s != services().globals.server_name()) - { - info!( - "We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements" - ); - let (make_join_response, remote_server) = - make_join_request(sender_user, room_id, servers).await?; - - let room_version_id = match make_join_response.room_version { - Some(room_version_id) - if services() - .globals - .supported_room_versions() - .contains(&room_version_id) => - { - room_version_id - } - _ => return Err(Error::BadServerResponse("Room version is not supported")), - }; - let mut join_event_stub: CanonicalJsonObject = - serde_json::from_str(make_join_response.event.get()).map_err(|_| { - Error::BadServerResponse("Invalid make_join event json received from server.") - })?; - let join_authorized_via_users_server = join_event_stub - .get("content") - .map(|s| { - s.as_object()? - .get("join_authorised_via_users_server")? - .as_str() - }) - .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); - let restricted_join = join_authorized_via_users_server.is_some(); - - // TODO: Is origin needed? - join_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), - ); - join_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - join_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason, - join_authorized_via_users_server, - }) - .expect("event is valid, we just created it"), - ); - - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms - join_event_stub.remove("event_id"); - - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present - ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut join_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); - - // Generate event id - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - ); - let event_id = OwnedEventId::try_from(event_id) - .expect("ruma's reference hashes are valid event ids"); - - // Add event_id back - join_event_stub.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - - // It has enough fields to be called a proper event now - let join_event = join_event_stub; - - let send_join_response = services() - .sending - .send_federation_request( - &remote_server, - federation::membership::create_join_event::v2::Request { - room_id: room_id.to_owned(), - event_id: event_id.to_owned(), - pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, - }, - ) - .await?; - - let pdu = if let Some(signed_raw) = send_join_response.room_state.event { - let (signed_event_id, signed_pdu) = - gen_event_id_canonical_json(&signed_raw, &room_version_id)?; - - if signed_event_id != event_id { - return Err(Error::BadServerResponse( - "Server sent event with wrong event id", - )); - } - - signed_pdu - } else if restricted_join { - return Err(Error::BadServerResponse( - "No signed event was returned, despite just performing a restricted join", - )); - } else { - join_event - }; - - drop(state_lock); - let pub_key_map = RwLock::new(BTreeMap::new()); - services() - .rooms - .event_handler - .handle_incoming_pdu(&remote_server, &event_id, room_id, pdu, true, &pub_key_map) - .await?; - } else { - return Err(error); - } - } - - Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) -} - -async fn make_join_request( - sender_user: &UserId, - room_id: &RoomId, - servers: &[OwnedServerName], -) -> Result<( - federation::membership::prepare_join_event::v1::Response, - OwnedServerName, -)> { - let mut make_join_response_and_server = Err(Error::BadServerResponse( - "No server available to assist in joining.", - )); - - for remote_server in servers { - if remote_server == services().globals.server_name() { - continue; - } - info!("Asking {remote_server} for make_join"); - let make_join_response = services() - .sending - .send_federation_request( - remote_server, - federation::membership::prepare_join_event::v1::Request { - room_id: room_id.to_owned(), - user_id: sender_user.to_owned(), - ver: services().globals.supported_room_versions(), - }, - ) - .await; - - make_join_response_and_server = make_join_response.map(|r| (r, remote_server.clone())); - - if make_join_response_and_server.is_ok() { - break; - } - } - - make_join_response_and_server -} - -async fn validate_and_add_event_id( - pdu: &RawJsonValue, - room_version: &RoomVersionId, - pub_key_map: &RwLock>, -) -> Result<(OwnedEventId, CanonicalJsonObject)> { - let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&value, room_version) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid PDU format"))? - )) - .expect("ruma's reference hashes are valid event ids"); - - let back_off = |id| async { - match services() - .globals - .bad_event_ratelimiter - .write() - .await - .entry(id) - { - Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - } - }; - - if let Some((time, tries)) = services() - .globals - .bad_event_ratelimiter - .read() - .await - .get(&event_id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", event_id); - return Err(Error::BadServerResponse("bad event, still backing off")); - } - } - - let origin_server_ts = value.get("origin_server_ts").ok_or_else(|| { - error!("Invalid PDU, no origin_server_ts field"); - Error::BadRequest( - ErrorKind::MissingParam, - "Invalid PDU, no origin_server_ts field", - ) - })?; - - let origin_server_ts: MilliSecondsSinceUnixEpoch = { - let ts = origin_server_ts.as_integer().ok_or_else(|| { - Error::BadRequest( - ErrorKind::InvalidParam, - "origin_server_ts must be an integer", - ) - })?; - - MilliSecondsSinceUnixEpoch(i64::from(ts).try_into().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Time must be after the unix epoch") - })?) - }; - - let unfiltered_keys = (*pub_key_map.read().await).clone(); - - let keys = - services() - .globals - .filter_keys_server_map(unfiltered_keys, origin_server_ts, room_version); - - if let Err(e) = ruma::signatures::verify_event(&keys, &value, room_version) { - warn!("Event {} failed verification {:?} {}", event_id, pdu, e); - back_off(event_id).await; - return Err(Error::BadServerResponse("Event failed verification.")); - } - - value.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - - Ok((event_id, value)) -} - pub(crate) async fn invite_helper<'a>( sender_user: &UserId, user_id: &UserId, @@ -1270,7 +690,7 @@ pub(crate) async fn invite_helper<'a>( &state_lock, )?; - let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?; + let invite_room_state = services().rooms.state.stripped_state(&pdu.room_id)?; drop(state_lock); diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index ec6c06b0..39342d5b 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -10,7 +10,8 @@ use ruma::{ self, v3::{ Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, - LeftRoom, Presence, RoomAccountData, RoomSummary, Rooms, State, Timeline, ToDevice, + KnockState, KnockedRoom, LeftRoom, Presence, RoomAccountData, RoomSummary, Rooms, + State, Timeline, ToDevice, }, v4::{SlidingOp, SlidingSyncRoomHero}, DeviceLists, UnreadNotificationsCount, @@ -503,6 +504,50 @@ async fn sync_helper( ); } + let mut knocked_rooms = BTreeMap::new(); + let all_knocked_rooms: Vec<_> = services() + .rooms + .state_cache + .rooms_knocked(&sender_user) + .collect(); + for result in all_knocked_rooms { + let (room_id, knock_state_events) = result?; + + { + // Get and drop the lock to wait for remaining operations to finish + let mutex_insert = Arc::clone( + services() + .globals + .roomid_mutex_insert + .write() + .await + .entry(room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().await; + drop(insert_lock); + } + + let knock_count = services() + .rooms + .state_cache + .get_knock_count(&room_id, &sender_user)?; + + // knock before last sync + if Some(since) >= knock_count { + continue; + } + + knocked_rooms.insert( + room_id.clone(), + KnockedRoom { + knock_state: KnockState { + events: knock_state_events, + }, + }, + ); + } + for user_id in left_encrypted_users { let dont_share_encrypted_room = services() .rooms @@ -538,7 +583,7 @@ async fn sync_helper( leave: left_rooms, join: joined_rooms, invite: invited_rooms, - knock: BTreeMap::new(), // TODO + knock: knocked_rooms, }, presence: Presence { events: presence_updates diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 4da86f4e..f8768a9a 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -31,7 +31,11 @@ use ruma::{ }, event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, keys::{claim_keys, get_keys}, - membership::{create_invite, create_join_event, prepare_join_event}, + knock::{create_knock_event_template, send_knock}, + membership::{ + create_invite, create_join_event, create_leave_event, prepare_join_event, + prepare_leave_event, + }, openid::get_openid_userinfo, query::{get_profile_information, get_room_information}, space::get_hierarchy, @@ -67,7 +71,7 @@ use std::{ sync::Arc, time::{Duration, Instant, SystemTime}, }; -use tokio::sync::RwLock; +use tokio::sync::{Mutex, RwLock}; use tracing::{debug, error, warn}; @@ -637,7 +641,7 @@ async fn request_well_known(destination: &str) -> Option<(String, Instant)> { let response = services() .globals .default_client() - .get(&format!("https://{destination}/.well-known/matrix/server")) + .get(format!("https://{destination}/.well-known/matrix/server")) .send() .await; debug!("Got well known response"); @@ -1494,42 +1498,60 @@ pub async fn get_room_state_ids_route( }) } +/// # `GET /_matrix/federation/v1/make_knock/{roomId}/{userId}` +/// +/// Creates a knock template. +pub async fn create_knock_event_template_route( + body: Ruma, +) -> Result { + let (mutex_state, room_version_id) = + member_shake_preamble(&body.sender_servername, &body.room_id).await?; + let state_lock = mutex_state.lock().await; + + Ok(create_knock_event_template::v1::Response { + room_version: room_version_id, + event: create_membership_template( + &body.user_id, + &body.room_id, + None, + MembershipState::Knock, + state_lock, + )?, + }) +} + +/// # `GET /_matrix/federation/v1/make_leave/{roomId}/{userId}` +/// +/// Creates a leave template. +pub async fn create_leave_event_template_route( + body: Ruma, +) -> Result { + let (mutex_state, room_version_id) = + member_shake_preamble(&body.sender_servername, &body.room_id).await?; + let state_lock = mutex_state.lock().await; + + Ok(prepare_leave_event::v1::Response { + room_version: Some(room_version_id), + event: create_membership_template( + &body.user_id, + &body.room_id, + None, + MembershipState::Leave, + state_lock, + )?, + }) +} + /// # `GET /_matrix/federation/v1/make_join/{roomId}/{userId}` /// /// Creates a join template. pub async fn create_join_event_template_route( body: Ruma, ) -> Result { - if !services().rooms.metadata.exists(&body.room_id)? { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Room is unknown to this server.", - )); - } - - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - services() - .rooms - .event_handler - .acl_check(sender_servername, &body.room_id)?; - - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .await - .entry(body.room_id.to_owned()) - .or_default(), - ); + let (mutex_state, room_version_id) = + member_shake_preamble(&body.sender_servername, &body.room_id).await?; let state_lock = mutex_state.lock().await; - let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; - let join_authorized_via_users_server = if // The following two functions check whether the user can "join" without performing a restricted join !services() @@ -1580,12 +1602,32 @@ pub async fn create_join_event_template_route( )); } + Ok(prepare_join_event::v1::Response { + room_version: Some(room_version_id), + event: create_membership_template( + &body.user_id, + &body.room_id, + join_authorized_via_users_server, + MembershipState::Join, + state_lock, + )?, + }) +} + +/// Creates a template for the given membership state, to return on the `/make_` endpoints +fn create_membership_template( + user_id: &UserId, + room_id: &RoomId, + join_authorized_via_users_server: Option, + membership: MembershipState, + state_lock: tokio::sync::MutexGuard<'_, ()>, +) -> Result, Error> { let content = to_raw_value(&RoomMemberEventContent { avatar_url: None, blurhash: None, displayname: None, is_direct: None, - membership: MembershipState::Join, + membership, third_party_invite: None, reason: None, join_authorized_via_users_server, @@ -1597,12 +1639,12 @@ pub async fn create_join_event_template_route( event_type: TimelineEventType::RoomMember, content, unsigned: None, - state_key: Some(body.user_id.to_string()), + state_key: Some(user_id.to_string()), redacts: None, timestamp: None, }, - &body.user_id, - &body.room_id, + user_id, + room_id, &state_lock, )?; @@ -1610,17 +1652,13 @@ pub async fn create_join_event_template_route( pdu_json.remove("event_id"); - Ok(prepare_join_event::v1::Response { - room_version: Some(room_version_id), - event: to_raw_value(&pdu_json).expect("CanonicalJson can be serialized to JSON"), - }) + let raw_event = to_raw_value(&pdu_json).expect("CanonicalJson can be serialized to JSON"); + + Ok(raw_event) } -async fn create_join_event( - sender_servername: &ServerName, - room_id: &RoomId, - pdu: &RawJsonValue, -) -> Result { +/// checks whether the given room exists, and checks whether the specified server is allowed to send events according to the ACL +fn room_and_acl_check(room_id: &RoomId, sender_servername: &OwnedServerName) -> Result<(), Error> { if !services().rooms.metadata.exists(room_id)? { return Err(Error::BadRequest( ErrorKind::NotFound, @@ -1632,6 +1670,40 @@ async fn create_join_event( .rooms .event_handler .acl_check(sender_servername, room_id)?; + Ok(()) +} + +/// Takes care of common boilerpalte for room membership handshake endpoints. +/// The returned mutex must be locked by the caller. +async fn member_shake_preamble( + sender_servername: &Option, + room_id: &RoomId, +) -> Result<(Arc>, RoomVersionId), Error> { + let sender_servername = sender_servername.as_ref().expect("server is authenticated"); + room_and_acl_check(room_id, sender_servername)?; + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(room_id.to_owned()) + .or_default(), + ); + + let room_version_id = services().rooms.state.get_room_version(room_id)?; + + Ok((mutex_state, room_version_id)) +} + +async fn create_join_event( + sender_servername: &Option, + room_id: &RoomId, + pdu: &RawJsonValue, +) -> Result { + let sender_servername = sender_servername.as_ref().expect("server is authenticated"); + room_and_acl_check(room_id, sender_servername)?; // We need to return the state prior to joining, let's keep a reference to that here let shortstatehash = services() @@ -1643,8 +1715,44 @@ async fn create_join_event( "Pdu state not found.", ))?; + let pdu = append_member_pdu(MembershipState::Join, sender_servername, room_id, pdu).await?; + + let state_ids = services() + .rooms + .state_accessor + .state_full_ids(shortstatehash) + .await?; + let auth_chain_ids = services() + .rooms + .auth_chain + .get_auth_chain(room_id, state_ids.values().cloned().collect()) + .await?; + + Ok(create_join_event::v1::RoomState { + auth_chain: auth_chain_ids + .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok().flatten()) + .map(PduEvent::convert_to_outgoing_federation_event) + .collect(), + state: state_ids + .iter() + .filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) + .map(PduEvent::convert_to_outgoing_federation_event) + .collect(), + event: pdu.map(|pdu| { + to_raw_value(&CanonicalJsonValue::Object(pdu)) + .expect("To raw json should not fail since only change was adding signature") + }), + }) +} + +/// Takes the given membership PDU and attempts to append it to the timeline +async fn append_member_pdu( + membership: MembershipState, + sender_servername: &OwnedServerName, + room_id: &RoomId, + pdu: &RawJsonValue, +) -> Result>, Error> { let pub_key_map = RwLock::new(BTreeMap::new()); - // let mut auth_cache = EventMap::new(); // We do not add the event_id field to the pdu here because of signature and hashes checks let room_version_id = services().rooms.state.get_room_version(room_id)?; @@ -1719,17 +1827,18 @@ async fn create_join_event( ) .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid event content"))?; - if event_content.membership != MembershipState::Join { + if event_content.membership != membership { return Err(Error::BadRequest( ErrorKind::BadJson, "Membership of sent event does not match that of the endpoint", )); } - let sign_join_event = event_content - .join_authorized_via_users_server - .map(|user| user.server_name() == services().globals.server_name()) - .unwrap_or_default() + let sign_join_event = membership == MembershipState::Join + && event_content + .join_authorized_via_users_server + .map(|user| user.server_name() == services().globals.server_name()) + .unwrap_or_default() && user_can_perform_restricted_join(&sender, room_id, &room_version_id).unwrap_or_default(); if sign_join_event { @@ -1779,17 +1888,6 @@ async fn create_join_event( ))?; drop(mutex_lock); - let state_ids = services() - .rooms - .state_accessor - .state_full_ids(shortstatehash) - .await?; - let auth_chain_ids = services() - .rooms - .auth_chain - .get_auth_chain(room_id, state_ids.values().cloned().collect()) - .await?; - let servers = services() .rooms .state_cache @@ -1799,26 +1897,7 @@ async fn create_join_event( services().sending.send_pdu(servers, &pdu_id)?; - Ok(create_join_event::v1::RoomState { - auth_chain: auth_chain_ids - .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok().flatten()) - .map(PduEvent::convert_to_outgoing_federation_event) - .collect(), - state: state_ids - .iter() - .filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) - .map(PduEvent::convert_to_outgoing_federation_event) - .collect(), - // Event field is required if we sign the join event. - event: if sign_join_event { - Some( - to_raw_value(&CanonicalJsonValue::Object(value)) - .expect("To raw json should not fail since only change was adding signature"), - ) - } else { - None - }, - }) + Ok(if sign_join_event { Some(value) } else { None }) } /// # `PUT /_matrix/federation/v1/send_join/{roomId}/{eventId}` @@ -1827,12 +1906,7 @@ async fn create_join_event( pub async fn create_join_event_v1_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; + let room_state = create_join_event(&body.sender_servername, &body.room_id, &body.pdu).await?; Ok(create_join_event::v1::Response { room_state }) } @@ -1843,16 +1917,11 @@ pub async fn create_join_event_v1_route( pub async fn create_join_event_v2_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - let create_join_event::v1::RoomState { auth_chain, state, event, - } = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; + } = create_join_event(&body.sender_servername, &body.room_id, &body.pdu).await?; let room_state = create_join_event::v2::RoomState { members_omitted: false, auth_chain, @@ -1864,6 +1933,54 @@ pub async fn create_join_event_v2_route( Ok(create_join_event::v2::Response { room_state }) } +/// # `PUT /_matrix/federation/v2/send_leave/{roomId}/{eventId}` +/// +/// Submits a signed leave event. +pub async fn create_leave_event_route( + body: Ruma, +) -> Result { + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + room_and_acl_check(&body.room_id, sender_servername)?; + + append_member_pdu( + MembershipState::Leave, + sender_servername, + &body.room_id, + &body.pdu, + ) + .await?; + + Ok(create_leave_event::v2::Response {}) +} + +/// # `PUT /_matrix/federation/v1/send_knock/{roomId}/{eventId}` +/// +/// Submits a signed knock event. +pub async fn create_knock_event_route( + body: Ruma, +) -> Result { + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + room_and_acl_check(&body.room_id, sender_servername)?; + + append_member_pdu( + MembershipState::Knock, + sender_servername, + &body.room_id, + &body.pdu, + ) + .await?; + + Ok(send_knock::v1::Response { + knock_room_state: services().rooms.state.stripped_state(&body.room_id)?, + }) +} + /// Checks whether the given user can join the given room via a restricted join. /// This doesn't check the current user's membership. This should be done externally, /// either by using the state cache or attempting to authorize the event. @@ -1943,44 +2060,54 @@ fn user_can_perform_restricted_join( pub async fn create_invite_route( body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let Ruma:: { + body, + sender_servername, + .. + } = body; + + let create_invite::v2::Request { + room_id, + room_version, + event, + invite_room_state, + .. + } = body; + + let sender_servername = sender_servername.expect("server is authenticated"); services() .rooms .event_handler - .acl_check(sender_servername, &body.room_id)?; - + .acl_check(&sender_servername, &room_id)?; if !services() .globals .supported_room_versions() - .contains(&body.room_version) + .contains(&room_version) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { - room_version: body.room_version.clone(), + room_version: room_version.clone(), }, "Server does not support this room version.", )); } - let mut signed_event = utils::to_canonical_object(&body.event) + let mut signed_event = utils::to_canonical_object(&event) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; ruma::signatures::hash_and_sign_event( services().globals.server_name().as_str(), services().globals.keypair(), &mut signed_event, - &body.room_version, + &room_version, ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; // Generate event id let event_id = EventId::parse(format!( "${}", - ruma::signatures::reference_hash(&signed_event, &body.room_version) + ruma::signatures::reference_hash(&signed_event, &room_version) .expect("Event format validated when event was hashed") )) .expect("ruma's reference hashes are valid event ids"); @@ -2015,9 +2142,9 @@ pub async fn create_invite_route( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "state_key is not a user id."))?; - let mut invite_state = body.invite_room_state.clone(); + let mut invite_state = invite_room_state.clone(); - let mut event: JsonObject = serde_json::from_str(body.event.get()) + let mut event: JsonObject = serde_json::from_str(event.get()) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?; event.insert("event_id".to_owned(), "$dummy".into()); @@ -2033,16 +2160,55 @@ pub async fn create_invite_route( if !services() .rooms .state_cache - .server_in_room(services().globals.server_name(), &body.room_id)? + .server_in_room(services().globals.server_name(), &room_id)? { - services().rooms.state_cache.update_membership( - &body.room_id, - &invited_user, - MembershipState::Invite, - &sender, - Some(invite_state), - true, - )?; + // If the user has already knocked on the room, we take that as the user wanting to join + // the room as soon as their knock is accepted, as recommended by the spec. + // + // https://spec.matrix.org/v1.13/client-server-api/#knocking-on-rooms + if services() + .rooms + .state_cache + .is_knocked(&invited_user, &room_id) + .unwrap_or_default() + { + // We want to try join automatically first, before notifying clients that they were invited. + // We also shouldn't block giving the calling server the response on attempting to join the + // room, since it's not relevant for the caller. + tokio::spawn(async move { + if services().rooms.helpers.join_room_by_id(&invited_user, &room_id, None, &invite_state.iter() .filter_map(|event| event.deserialize().ok()) + .map(|event| event.sender().server_name().to_owned()) + .collect::>() +, None) + .await + .is_err() && + // Checking whether the state has changed since we started this join handshake + services() + .rooms + .state_cache + .is_knocked(&invited_user, &room_id) + .unwrap_or_default() + { + let _ = services().rooms.state_cache.update_membership( + &room_id, + &invited_user, + MembershipState::Invite, + &sender, + Some(invite_state), + true, + ); + } + }); + } else { + services().rooms.state_cache.update_membership( + &room_id, + &invited_user, + MembershipState::Invite, + &sender, + Some(invite_state), + true, + )?; + } } Ok(create_invite::v2::Response { diff --git a/src/database/key_value/rooms/mod.rs b/src/database/key_value/rooms/mod.rs index e7b53d30..317c44d0 100644 --- a/src/database/key_value/rooms/mod.rs +++ b/src/database/key_value/rooms/mod.rs @@ -16,6 +16,30 @@ mod threads; mod timeline; mod user; +use ruma::{RoomId, UserId}; + use crate::{database::KeyValueDatabase, service}; impl service::rooms::Data for KeyValueDatabase {} + +/// Constructs roomuser_id and userroom_id respectively in byte form +fn get_room_and_user_byte_ids(room_id: &RoomId, user_id: &UserId) -> (Vec, Vec) { + ( + get_roomuser_id_bytes(room_id, user_id), + get_userroom_id_bytes(user_id, room_id), + ) +} + +fn get_roomuser_id_bytes(room_id: &RoomId, user_id: &UserId) -> Vec { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.as_bytes()); + roomuser_id +} + +fn get_userroom_id_bytes(user_id: &UserId, room_id: &RoomId) -> Vec { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + userroom_id +} diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 49e3842b..689ff7cd 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -7,32 +7,28 @@ use ruma::{ }; use crate::{ - database::KeyValueDatabase, + database::{abstraction::KvTree, KeyValueDatabase}, service::{self, appservice::RegistrationInfo}, services, utils, Error, Result, }; +use super::{get_room_and_user_byte_ids, get_userroom_id_bytes}; + impl service::rooms::state_cache::Data for KeyValueDatabase { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); self.roomuseroncejoinedids.insert(&userroom_id, &[]) } fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); self.userroomid_joined.insert(&userroom_id, &[])?; self.roomuserid_joined.insert(&roomuser_id, &[])?; self.userroomid_invitestate.remove(&userroom_id)?; self.roomuserid_invitecount.remove(&roomuser_id)?; + self.userroomid_knockstate.remove(&userroom_id)?; + self.roomuserid_knockcount.remove(&roomuser_id)?; self.userroomid_leftstate.remove(&userroom_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?; @@ -45,13 +41,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { room_id: &RoomId, last_state: Option>>, ) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); self.userroomid_invitestate.insert( &userroom_id, @@ -64,6 +54,35 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { )?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_knockstate.remove(&userroom_id)?; + self.roomuserid_knockcount.remove(&roomuser_id)?; + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + Ok(()) + } + + fn mark_as_knocked( + &self, + user_id: &UserId, + room_id: &RoomId, + last_state: Option>>, + ) -> Result<()> { + let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); + + self.userroomid_knockstate.insert( + &userroom_id, + &serde_json::to_vec(&last_state.unwrap_or_default()) + .expect("state to bytes always works"), + )?; + self.roomuserid_knockcount.insert( + &roomuser_id, + &services().globals.next_count()?.to_be_bytes(), + )?; + self.userroomid_joined.remove(&userroom_id)?; + self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; self.userroomid_leftstate.remove(&userroom_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?; @@ -71,13 +90,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { } fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); self.userroomid_leftstate.insert( &userroom_id, @@ -91,6 +104,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.roomuserid_joined.remove(&roomuser_id)?; self.userroomid_invitestate.remove(&userroom_id)?; self.roomuserid_invitecount.remove(&roomuser_id)?; + self.userroomid_knockstate.remove(&userroom_id)?; + self.roomuserid_knockcount.remove(&roomuser_id)?; Ok(()) } @@ -225,13 +240,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { /// Makes a user forget a room. #[tracing::instrument(skip(self))] fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); + let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); self.userroomid_leftstate.remove(&userroom_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?; @@ -413,6 +422,21 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { }) } + #[tracing::instrument(skip(self))] + fn get_knock_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_knockcount + .get(&key)? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid knockcount in db.") + })?)) + }) + } + #[tracing::instrument(skip(self))] fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { let mut key = room_id.as_bytes().to_vec(); @@ -460,34 +484,17 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { &'a self, user_id: &UserId, ) -> Box>)>> + 'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + scan_userroom_id_memberstate_tree(user_id, &self.userroomid_invitestate) + } - Box::new( - self.userroomid_invitestate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid.") - })?; - - let state = serde_json::from_slice(&state).map_err(|_| { - Error::bad_database("Invalid state in userroomid_invitestate.") - })?; - - Ok((room_id, state)) - }), - ) + /// Returns an iterator over all rooms a user has knocked on. + #[allow(clippy::type_complexity)] + #[tracing::instrument(skip(self))] + fn rooms_knocked<'a>( + &'a self, + user_id: &UserId, + ) -> Box>)>> + 'a> { + scan_userroom_id_memberstate_tree(user_id, &self.userroomid_knockstate) } #[tracing::instrument(skip(self))] @@ -511,6 +518,27 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { .transpose() } + #[tracing::instrument(skip(self))] + fn knock_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result>>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(room_id.as_bytes()); + + self.userroomid_knockstate + .get(&key)? + .map(|state| { + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_knockstate."))?; + + Ok(state) + }) + .transpose() + } + #[tracing::instrument(skip(self))] fn left_state( &self, @@ -539,69 +567,80 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { &'a self, user_id: &UserId, ) -> Box>)>> + 'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - Box::new( - self.userroomid_leftstate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid.") - })?; - - let state = serde_json::from_slice(&state).map_err(|_| { - Error::bad_database("Invalid state in userroomid_leftstate.") - })?; - - Ok((room_id, state)) - }), - ) + scan_userroom_id_memberstate_tree(user_id, &self.userroomid_leftstate) } #[tracing::instrument(skip(self))] fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) } #[tracing::instrument(skip(self))] fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) } #[tracing::instrument(skip(self))] fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) } + #[tracing::instrument(skip(self))] + fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let userroom_id = get_userroom_id_bytes(user_id, room_id); + + Ok(self.userroomid_knockstate.get(&userroom_id)?.is_some()) + } + #[tracing::instrument(skip(self))] fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) } } + +/// Scans the given userroom_id_`member`state tree for rooms, returning an iterator of room_ids +/// and a vector of raw state events +#[allow(clippy::type_complexity)] +fn scan_userroom_id_memberstate_tree<'a, T>( + user_id: &UserId, + userroom_id_memberstate_tree: &'a Arc, +) -> Box>)>> + 'a> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + + Box::new( + userroom_id_memberstate_tree + .scan_prefix(prefix) + .map(|(key, state)| { + let room_id = RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database( + "Room ID in userroomid_state is invalid unicode.", + ) + })?, + ) + .map_err(|_| { + Error::bad_database("Room ID in userroomid_state is invalid.") + })?; + + let state = serde_json::from_slice(&state).map_err(|_| { + Error::bad_database("Invalid state in userroomid_state.") + })?; + + Ok((room_id, state)) + }), + ) +} diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 0331a624..e89e2041 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -9,6 +9,8 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEven use service::rooms::timeline::PduCount; +use super::get_userroom_id_bytes; + impl service::rooms::timeline::Data for KeyValueDatabase { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { match self @@ -286,15 +288,11 @@ impl service::rooms::timeline::Data for KeyValueDatabase { let mut notifies_batch = Vec::new(); let mut highlights_batch = Vec::new(); for user in notifies { - let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(&user, room_id); notifies_batch.push(userroom_id); } for user in highlights { - let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(&user, room_id); highlights_batch.push(userroom_id); } diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 4c435720..2ba4240e 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -2,14 +2,11 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use super::{get_room_and_user_byte_ids, get_userroom_id_bytes}; + impl service::rooms::user::Data for KeyValueDatabase { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); + let (roomuser_id, userroom_id) = get_room_and_user_byte_ids(room_id, user_id); self.userroomid_notificationcount .insert(&userroom_id, &0_u64.to_be_bytes())?; @@ -25,9 +22,7 @@ impl service::rooms::user::Data for KeyValueDatabase { } fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); self.userroomid_notificationcount .get(&userroom_id)? @@ -39,9 +34,7 @@ impl service::rooms::user::Data for KeyValueDatabase { } fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + let userroom_id = get_userroom_id_bytes(user_id, room_id); self.userroomid_highlightcount .get(&userroom_id)? diff --git a/src/database/mod.rs b/src/database/mod.rs index 26db8268..44954e48 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -97,8 +97,10 @@ pub struct KeyValueDatabase { pub(super) roomid_joinedcount: Arc, pub(super) roomid_invitedcount: Arc, pub(super) roomuseroncejoinedids: Arc, - pub(super) userroomid_invitestate: Arc, // InviteState = Vec> + pub(super) userroomid_invitestate: Arc, // InviteState = Vec> pub(super) roomuserid_invitecount: Arc, // InviteCount = Count + pub(super) userroomid_knockstate: Arc, // KnockState = Vec> + pub(super) roomuserid_knockcount: Arc, // KnockCount = Count pub(super) userroomid_leftstate: Arc, pub(super) roomuserid_leftcount: Arc, @@ -313,6 +315,8 @@ impl KeyValueDatabase { roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, + userroomid_knockstate: builder.open_tree("userroomid_knockstate")?, + roomuserid_knockcount: builder.open_tree("roomuserid_knockcount")?, userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, diff --git a/src/main.rs b/src/main.rs index 2776c200..5669cc0a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -345,6 +345,7 @@ fn routes(config: &Config) -> Router { .ruma_route(client_server::get_alias_route) .ruma_route(client_server::join_room_by_id_route) .ruma_route(client_server::join_room_by_id_or_alias_route) + .ruma_route(client_server::knock_room_route) .ruma_route(client_server::joined_members_route) .ruma_route(client_server::leave_room_route) .ruma_route(client_server::forget_room_route) @@ -458,6 +459,10 @@ fn routes(config: &Config) -> Router { .ruma_route(server_server::create_join_event_template_route) .ruma_route(server_server::create_join_event_v1_route) .ruma_route(server_server::create_join_event_v2_route) + .ruma_route(server_server::create_leave_event_template_route) + .ruma_route(server_server::create_leave_event_route) + .ruma_route(server_server::create_knock_event_template_route) + .ruma_route(server_server::create_knock_event_route) .ruma_route(server_server::create_invite_route) .ruma_route(server_server::get_devices_route) .ruma_route(server_server::get_content_route) diff --git a/src/service/mod.rs b/src/service/mod.rs index 552c71af..c328bf7e 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -73,6 +73,7 @@ impl Services { }, }, event_handler: rooms::event_handler::Service, + helpers: rooms::helpers::Service, lazy_loading: rooms::lazy_loading::Service { db, lazy_load_waiting: Mutex::new(HashMap::new()), diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 95d52ad3..87c3d4c0 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,11 +1,16 @@ mod data; pub use data::Data; +use rand::seq::SliceRandom; use tracing::error; use crate::{services, Error, Result}; use ruma::{ - api::client::error::ErrorKind, + api::{ + appservice, + client::{alias::get_alias, error::ErrorKind}, + federation, + }, events::{ room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, StateEventType, @@ -98,4 +103,72 @@ impl Service { ) -> Box> + 'a> { self.db.local_aliases_for_room(room_id) } + + /// Resolves an alias to a room id, and a set of servers to join or knock via, either locally or over federation + #[tracing::instrument(skip(self))] + pub async fn get_alias_helper( + &self, + room_alias: OwnedRoomAliasId, + ) -> Result { + if room_alias.server_name() != services().globals.server_name() { + let response = services() + .sending + .send_federation_request( + room_alias.server_name(), + federation::query::get_room_information::v1::Request { + room_alias: room_alias.to_owned(), + }, + ) + .await?; + + let mut servers = response.servers; + servers.shuffle(&mut rand::thread_rng()); + + return Ok(get_alias::v3::Response::new(response.room_id, servers)); + } + + let mut room_id = None; + match services().rooms.alias.resolve_local_alias(&room_alias)? { + Some(r) => room_id = Some(r), + None => { + for appservice in services().appservice.read().await.values() { + if appservice.aliases.is_match(room_alias.as_str()) + && matches!( + services() + .sending + .send_appservice_request( + appservice.registration.clone(), + appservice::query::query_room_alias::v1::Request { + room_alias: room_alias.clone(), + }, + ) + .await, + Ok(Some(_opt_result)) + ) + { + room_id = + Some(self.resolve_local_alias(&room_alias)?.ok_or_else(|| { + Error::bad_config("Appservice lied to us. Room does not exist.") + })?); + break; + } + } + } + }; + + let room_id = match room_id { + Some(room_id) => room_id, + None => { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Room with alias not found.", + )) + } + }; + + Ok(get_alias::v3::Response::new( + room_id, + vec![services().globals.server_name().to_owned()], + )) + } } diff --git a/src/service/rooms/helpers/mod.rs b/src/service/rooms/helpers/mod.rs new file mode 100644 index 00000000..e8d5b68a --- /dev/null +++ b/src/service/rooms/helpers/mod.rs @@ -0,0 +1,699 @@ +use std::{ + collections::{hash_map::Entry, BTreeMap, HashMap}, + sync::Arc, + time::{Duration, Instant}, +}; + +use ruma::{ + api::{ + client::{ + error::ErrorKind, + membership::{join_room_by_id, ThirdPartySigned}, + }, + federation, + }, + canonical_json::to_canonical_value, + events::{ + room::{ + join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + }, + TimelineEventType, + }, + state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, + OwnedEventId, OwnedServerName, OwnedUserId, RoomId, RoomVersionId, UserId, +}; +use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; +use tokio::sync::RwLock; +use tracing::{debug, error, info, warn}; + +use crate::{ + service::{ + globals::SigningKeys, + pdu::{gen_event_id_canonical_json, PduBuilder}, + }, + services, utils, Error, PduEvent, Result, +}; + +pub struct Service; + +impl Service { + /// Attempts to join a room. + /// If the room cannot be joined locally, it attempts to join over federation, soley using the + /// specified servers + #[tracing::instrument(skip(self, reason, servers, _third_party_signed))] + pub async fn join_room_by_id( + &self, + sender_user: &UserId, + room_id: &RoomId, + reason: Option, + servers: &[OwnedServerName], + _third_party_signed: Option<&ThirdPartySigned>, + ) -> Result { + if let Ok(true) = services().rooms.state_cache.is_joined(sender_user, room_id) { + return Ok(join_room_by_id::v3::Response { + room_id: room_id.into(), + }); + } + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + // Ask a remote server if we are not participating in this room + if !services() + .rooms + .state_cache + .server_in_room(services().globals.server_name(), room_id)? + { + info!("Joining {room_id} over federation."); + + let (make_join_response, remote_server) = + make_join_request(sender_user, room_id, servers).await?; + + info!("make_join finished"); + + let room_version_id = match make_join_response.room_version { + Some(room_version) + if services() + .globals + .supported_room_versions() + .contains(&room_version) => + { + room_version + } + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; + + let (event_id, mut join_event, _) = self.populate_membership_template( + &make_join_response.event, + sender_user, + reason, + &room_version_id, + MembershipState::Join, + )?; + + info!("Asking {remote_server} for send_join"); + let send_join_response = services() + .sending + .send_federation_request( + &remote_server, + federation::membership::create_join_event::v2::Request { + room_id: room_id.to_owned(), + event_id: event_id.to_owned(), + pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + omit_members: false, + }, + ) + .await?; + + info!("send_join finished"); + + if let Some(signed_raw) = &send_join_response.room_state.event { + info!("There is a signed event. This room is probably using restricted joins. Adding signature to our event"); + let (signed_event_id, signed_value) = + match gen_event_id_canonical_json(signed_raw, &room_version_id) { + Ok(t) => t, + Err(_) => { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + } + }; + + if signed_event_id != event_id { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent event with wrong event id", + )); + } + + match signed_value["signatures"] + .as_object() + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent invalid signatures type", + )) + .and_then(|e| { + e.get(remote_server.as_str()).ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Server did not send its signature", + )) + }) { + Ok(signature) => { + join_event + .get_mut("signatures") + .expect("we created a valid pdu") + .as_object_mut() + .expect("we created a valid pdu") + .insert(remote_server.to_string(), signature.clone()); + } + Err(e) => { + warn!( + "Server {remote_server} sent invalid signature in sendjoin signatures for event {signed_value:?}: {e:?}", + ); + } + } + } + + services().rooms.short.get_or_create_shortroomid(room_id)?; + + info!("Parsing join event"); + let parsed_join_pdu = PduEvent::from_id_val(&event_id, join_event.clone()) + .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; + + let mut state = HashMap::new(); + let pub_key_map = RwLock::new(BTreeMap::new()); + + info!("Fetching join signing keys"); + services() + .rooms + .event_handler + .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) + .await?; + + info!("Going through send_join response room_state"); + for result in send_join_response + .room_state + .state + .iter() + .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) + { + let (event_id, value) = match result.await { + Ok(t) => t, + Err(_) => continue, + }; + + let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { + warn!("Invalid PDU in send_join response: {} {:?}", e, value); + Error::BadServerResponse("Invalid PDU in send_join response.") + })?; + + services() + .rooms + .outlier + .add_pdu_outlier(&event_id, &value)?; + if let Some(state_key) = &pdu.state_key { + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + state.insert(shortstatekey, pdu.event_id.clone()); + } + } + + info!("Going through send_join response auth_chain"); + for result in send_join_response + .room_state + .auth_chain + .iter() + .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) + { + let (event_id, value) = match result.await { + Ok(t) => t, + Err(_) => continue, + }; + + services() + .rooms + .outlier + .add_pdu_outlier(&event_id, &value)?; + } + + info!("Running send_join auth check"); + let authenticated = state_res::event_auth::auth_check( + &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), + &parsed_join_pdu, + None::, // TODO: third party invite + |k, s| { + services() + .rooms + .timeline + .get_pdu( + state.get( + &services() + .rooms + .short + .get_or_create_shortstatekey(&k.to_string().into(), s) + .ok()?, + )?, + ) + .ok()? + }, + ) + .map_err(|e| { + warn!("Auth check failed: {e}"); + Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed") + })?; + + if !authenticated { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Auth check failed", + )); + } + + info!("Saving state from send_join"); + let (statehash_before_join, new, removed) = + services().rooms.state_compressor.save_state( + room_id, + Arc::new( + state + .into_iter() + .map(|(k, id)| { + services() + .rooms + .state_compressor + .compress_state_event(k, &id) + }) + .collect::>()?, + ), + )?; + + services() + .rooms + .state + .force_state(room_id, statehash_before_join, new, removed, &state_lock) + .await?; + + info!("Updating joined counts for new room"); + services().rooms.state_cache.update_joined_count(room_id)?; + + // We append to state before appending the pdu, so we don't have a moment in time with the + // pdu without it's state. This is okay because append_pdu can't fail. + let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; + + info!("Appending new room join event"); + services() + .rooms + .timeline + .append_pdu( + &parsed_join_pdu, + join_event, + vec![(*parsed_join_pdu.event_id).to_owned()], + &state_lock, + ) + .await?; + + info!("Setting final room state for new room"); + // We set the room state after inserting the pdu, so that we never have a moment in time + // where events in the current room state do not exist + services() + .rooms + .state + .set_room_state(room_id, statehash_after_join, &state_lock)?; + } else { + info!("We can join locally"); + + let join_rules_event_content = + services().rooms.state_accessor.get_join_rules(room_id)?; + + let restriction_rooms = match join_rules_event_content { + Some(RoomJoinRulesEventContent { + join_rule: JoinRule::Restricted(restricted), + }) + | Some(RoomJoinRulesEventContent { + join_rule: JoinRule::KnockRestricted(restricted), + }) => restricted + .allow + .into_iter() + .filter_map(|a| match a { + AllowRule::RoomMembership(r) => Some(r.room_id), + _ => None, + }) + .collect(), + _ => Vec::new(), + }; + + let authorized_user = if restriction_rooms.iter().any(|restriction_room_id| { + services() + .rooms + .state_cache + .is_joined(sender_user, restriction_room_id) + .unwrap_or(false) + }) { + let mut auth_user = None; + for user in services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(Result::ok) + .collect::>() + { + if user.server_name() == services().globals.server_name() + && services() + .rooms + .state_accessor + .user_can_invite(room_id, &user, sender_user, &state_lock) + .unwrap_or(false) + { + auth_user = Some(user); + break; + } + } + auth_user + } else { + None + }; + + let event = RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason: reason.clone(), + join_authorized_via_users_server: authorized_user, + }; + + // Try normal join first + let Err(error) = services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + timestamp: None, + }, + sender_user, + room_id, + &state_lock, + ) + .await + else { + return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())); + }; + + if !restriction_rooms.is_empty() + && servers + .iter() + .any(|s| *s != services().globals.server_name()) + { + info!( + "We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements" + ); + let (make_join_response, remote_server) = + make_join_request(sender_user, room_id, servers).await?; + + let room_version_id = match make_join_response.room_version { + Some(room_version_id) + if services() + .globals + .supported_room_versions() + .contains(&room_version_id) => + { + room_version_id + } + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; + + let (event_id, join_event, restricted_join) = self.populate_membership_template( + &make_join_response.event, + sender_user, + reason, + &room_version_id, + MembershipState::Join, + )?; + + let send_join_response = services() + .sending + .send_federation_request( + &remote_server, + federation::membership::create_join_event::v2::Request { + room_id: room_id.to_owned(), + event_id: event_id.to_owned(), + pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + omit_members: false, + }, + ) + .await?; + + let pdu = if let Some(signed_raw) = send_join_response.room_state.event { + let (signed_event_id, signed_pdu) = + gen_event_id_canonical_json(&signed_raw, &room_version_id)?; + + if signed_event_id != event_id { + return Err(Error::BadServerResponse( + "Server sent event with wrong event id", + )); + } + + signed_pdu + } else if restricted_join { + return Err(Error::BadServerResponse( + "No signed event was returned, despite just performing a restricted join", + )); + } else { + join_event + }; + + drop(state_lock); + let pub_key_map = RwLock::new(BTreeMap::new()); + services() + .rooms + .event_handler + .handle_incoming_pdu( + &remote_server, + &event_id, + room_id, + pdu, + true, + &pub_key_map, + ) + .await?; + } else { + return Err(error); + } + } + + Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) + } + + /// Takes a membership template, as returned from the `/federation/*/make_*` endpoints, and + /// populates them to the point as to where they are a full pdu, ready to be appended to the timeline + /// + /// Returns the event id, the pdu, and whether this event is a restricted join + pub fn populate_membership_template( + &self, + member_template: &RawJsonValue, + sender_user: &UserId, + reason: Option, + room_version_id: &RoomVersionId, + membership: MembershipState, + ) -> Result<(OwnedEventId, BTreeMap, bool), Error> { + let mut member_event_stub: CanonicalJsonObject = + serde_json::from_str(member_template.get()).map_err(|_| { + Error::BadServerResponse("Invalid make_knock event json received from server.") + })?; + + let join_authorized_via_users_server = member_event_stub + .get("content") + .map(|s| { + s.as_object()? + .get("join_authorised_via_users_server")? + .as_str() + }) + .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); + + let restricted_join = join_authorized_via_users_server.is_some(); + + member_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + ); + + member_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + + member_event_stub.insert( + "content".to_owned(), + to_canonical_value(RoomMemberEventContent { + membership, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason: reason.clone(), + join_authorized_via_users_server, + }) + .expect("event is valid, we just created it"), + ); + + member_event_stub.remove("event_id"); + + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut member_event_stub, + room_version_id, + ) + .expect("event is valid, we just created it"); + + let event_id = format!( + "${}", + ruma::signatures::reference_hash(&member_event_stub, room_version_id) + .expect("Event format validated when event was hashed") + ); + + let event_id = ::try_from(event_id) + .expect("ruma's reference hashes are valid event ids"); + + member_event_stub.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); + + Ok((event_id, member_event_stub, restricted_join)) + } +} + +async fn make_join_request( + sender_user: &UserId, + room_id: &RoomId, + servers: &[OwnedServerName], +) -> Result<( + federation::membership::prepare_join_event::v1::Response, + OwnedServerName, +)> { + let mut make_join_response_and_server = Err(Error::BadServerResponse( + "No server available to assist in joining.", + )); + + for remote_server in servers { + if remote_server == services().globals.server_name() { + continue; + } + info!("Asking {remote_server} for make_join"); + let make_join_response = services() + .sending + .send_federation_request( + remote_server, + federation::membership::prepare_join_event::v1::Request { + room_id: room_id.to_owned(), + user_id: sender_user.to_owned(), + ver: services().globals.supported_room_versions(), + }, + ) + .await; + + make_join_response_and_server = make_join_response.map(|r| (r, remote_server.clone())); + + if make_join_response_and_server.is_ok() { + break; + } + } + + make_join_response_and_server +} +async fn validate_and_add_event_id( + pdu: &RawJsonValue, + room_version: &RoomVersionId, + pub_key_map: &RwLock>, +) -> Result<(OwnedEventId, CanonicalJsonObject)> { + let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + let event_id = EventId::parse(format!( + "${}", + ruma::signatures::reference_hash(&value, room_version) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid PDU format"))? + )) + .expect("ruma's reference hashes are valid event ids"); + + let back_off = |id| async { + match services() + .globals + .bad_event_ratelimiter + .write() + .await + .entry(id) + { + Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + } + Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + } + }; + + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .await + .get(&event_id) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + debug!("Backing off from {}", event_id); + return Err(Error::BadServerResponse("bad event, still backing off")); + } + } + + let origin_server_ts = value.get("origin_server_ts").ok_or_else(|| { + error!("Invalid PDU, no origin_server_ts field"); + Error::BadRequest( + ErrorKind::MissingParam, + "Invalid PDU, no origin_server_ts field", + ) + })?; + + let origin_server_ts: MilliSecondsSinceUnixEpoch = { + let ts = origin_server_ts.as_integer().ok_or_else(|| { + Error::BadRequest( + ErrorKind::InvalidParam, + "origin_server_ts must be an integer", + ) + })?; + + MilliSecondsSinceUnixEpoch(i64::from(ts).try_into().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Time must be after the unix epoch") + })?) + }; + + let unfiltered_keys = (*pub_key_map.read().await).clone(); + + let keys = + services() + .globals + .filter_keys_server_map(unfiltered_keys, origin_server_ts, room_version); + + if let Err(e) = ruma::signatures::verify_event(&keys, &value, room_version) { + warn!("Event {} failed verification {:?} {}", event_id, pdu, e); + back_off(event_id).await; + return Err(Error::BadServerResponse("Event failed verification.")); + } + + value.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); + + Ok((event_id, value)) +} diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index f0739841..2f0c3347 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -3,6 +3,7 @@ pub mod auth_chain; pub mod directory; pub mod edus; pub mod event_handler; +pub mod helpers; pub mod lazy_loading; pub mod metadata; pub mod outlier; @@ -45,6 +46,7 @@ pub struct Service { pub directory: directory::Service, pub edus: edus::Service, pub event_handler: event_handler::Service, + pub helpers: helpers::Service, pub lazy_loading: lazy_loading::Service, pub metadata: metadata::Service, pub outlier: outlier::Service, diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index f5bd7e9f..93b46072 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -10,6 +10,7 @@ use ruma::{ events::{ room::{create::RoomCreateEventContent, member::MembershipState}, AnyStrippedStateEvent, StateEventType, TimelineEventType, + RECOMMENDED_STRIPPED_STATE_EVENT_TYPES, }, serde::Raw, state_res::{self, StateMap}, @@ -258,58 +259,22 @@ impl Service { } } - #[tracing::instrument(skip(self, invite_event))] - pub fn calculate_invite_state( - &self, - invite_event: &PduEvent, - ) -> Result>> { - let mut state = Vec::new(); - // Add recommended events - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomCreate, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomJoinRules, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomCanonicalAlias, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomAvatar, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomName, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomMember, - invite_event.sender.as_str(), - )? { - state.push(e.to_stripped_state_event()); - } - - state.push(invite_event.to_stripped_state_event()); - Ok(state) + #[tracing::instrument(skip(self, room_id))] + /// Gets all the [recommended stripped state events] from the given room + /// + /// [recommended stripped state events]: https://spec.matrix.org/v1.13/client-server-api/#stripped-state + pub fn stripped_state(&self, room_id: &RoomId) -> Result>> { + RECOMMENDED_STRIPPED_STATE_EVENT_TYPES + .iter() + .filter_map(|state_event_type| { + services() + .rooms + .state_accessor + .room_state_get(room_id, state_event_type, "") + .transpose() + }) + .map(|e| e.map(|e| e.to_stripped_state_event())) + .collect::>>() } /// Set the state hash to a new version, but does not update state_cache. diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index e1bcd3c3..fcf542c3 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -426,8 +426,8 @@ impl Service { }) } - /// Returns the join rule for a given room - pub fn get_join_rule( + /// Returns the space-room join rule for a given room + pub fn get_space_room_join_rule( &self, current_room: &RoomId, ) -> Result<(SpaceRoomJoinRule, Vec), Error> { @@ -450,6 +450,26 @@ impl Service { .unwrap_or((SpaceRoomJoinRule::Invite, vec![]))) } + /// Returns the join rules event content of a room, if there are any and we are aware of it locally + #[tracing::instrument(skip(self))] + pub fn get_join_rules( + &self, + room_id: &RoomId, + ) -> Result, Error> { + let join_rules_event = self.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; + + join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str::(join_rules_event.content.get()) + .map_err(|e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database("Invalid join rules event in db.") + }) + }) + .transpose() + } + /// Returns an empty vec if not a restricted room pub fn allowed_room_ids(&self, join_rule: JoinRule) -> Vec { let mut room_ids = vec![]; diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index b511919a..3ee73dfa 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -16,6 +16,12 @@ pub trait Data: Send + Sync { room_id: &RoomId, last_state: Option>>, ) -> Result<()>; + fn mark_as_knocked( + &self, + user_id: &UserId, + room_id: &RoomId, + last_state: Option>>, + ) -> Result<()>; fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; @@ -65,6 +71,8 @@ pub trait Data: Send + Sync { fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; + fn get_knock_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; + fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; /// Returns an iterator over all rooms this user joined. @@ -80,12 +88,25 @@ pub trait Data: Send + Sync { user_id: &UserId, ) -> Box>)>> + 'a>; + /// Returns an iterator over all rooms a user has knocked on. + #[allow(clippy::type_complexity)] + fn rooms_knocked<'a>( + &'a self, + user_id: &UserId, + ) -> Box>)>> + 'a>; + fn invite_state( &self, user_id: &UserId, room_id: &RoomId, ) -> Result>>>; + fn knock_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result>>>; + fn left_state( &self, user_id: &UserId, @@ -105,5 +126,7 @@ pub trait Data: Send + Sync { fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result; } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 1604a14a..5cfd9738 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -12,7 +12,7 @@ use ruma::{ RoomAccountDataEventType, StateEventType, }, serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + OwnedRoomId, OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; use tracing::warn; @@ -40,6 +40,13 @@ impl Service { // TODO: displayname, avatar url } + // We don't need to store stripped state on behalf of remote users, since these events are only used on `/sync` + let last_state = if user_id.server_name() == services().globals.server_name() { + last_state + } else { + None + }; + match &membership { MembershipState::Join => { // Check if the user never joined this room @@ -178,6 +185,9 @@ impl Service { self.db.mark_as_invited(user_id, room_id, last_state)?; } + MembershipState::Knock => { + self.db.mark_as_knocked(user_id, room_id, last_state)?; + } MembershipState::Leave | MembershipState::Ban => { self.db.mark_as_left(user_id, room_id)?; } @@ -283,6 +293,11 @@ impl Service { self.db.get_invite_count(room_id, user_id) } + #[tracing::instrument(skip(self))] + pub fn get_knock_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + self.db.get_knock_count(room_id, user_id) + } + #[tracing::instrument(skip(self))] pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { self.db.get_left_count(room_id, user_id) @@ -306,6 +321,15 @@ impl Service { self.db.rooms_invited(user_id) } + /// Returns an iterator over all rooms a user has knocked on. + #[tracing::instrument(skip(self))] + pub fn rooms_knocked<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator>)>> + 'a { + self.db.rooms_knocked(user_id) + } + #[tracing::instrument(skip(self))] pub fn invite_state( &self, @@ -315,6 +339,15 @@ impl Service { self.db.invite_state(user_id, room_id) } + #[tracing::instrument(skip(self))] + pub fn knock_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result>>> { + self.db.knock_state(user_id, room_id) + } + #[tracing::instrument(skip(self))] pub fn left_state( &self, @@ -348,8 +381,60 @@ impl Service { self.db.is_invited(user_id, room_id) } + #[tracing::instrument(skip(self))] + pub fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result { + self.db.is_knocked(user_id, room_id) + } + #[tracing::instrument(skip(self))] pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_left(user_id, room_id) } + + /// Function to assist performing a membership event that may require help from a remote server + /// + /// If a room id is provided, the servers returned will consist of: + /// - the `via` argument, provided by the client + /// - servers of the senders of the stripped state events we are given + /// - the server in the room id + /// + /// Otherwise, the servers returned will come from the response when resolving the alias. + #[tracing::instrument(skip(self))] + pub async fn get_room_id_and_via_servers( + &self, + sender_user: &UserId, + room_id_or_alias: OwnedRoomOrAliasId, + via: Vec, + ) -> Result<(Vec, OwnedRoomId), Error> { + let (servers, room_id) = match OwnedRoomId::try_from(room_id_or_alias) { + Ok(room_id) => { + let mut servers = via; + servers.extend( + self.invite_state(sender_user, &room_id) + .transpose() + .or_else(|| self.knock_state(sender_user, &room_id).transpose()) + .transpose()? + .unwrap_or_default() + .iter() + .filter_map(|event| event.deserialize().ok()) + .map(|event| event.sender().server_name().to_owned()), + ); + + servers.push( + room_id + .server_name() + .expect("Room IDs should always have a server name") + .into(), + ); + + (servers, room_id) + } + Err(room_alias) => { + let response = services().rooms.alias.get_alias_helper(room_alias).await?; + + (response.servers, response.room_id) + } + }; + Ok((servers, room_id)) + } } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 80690663..7615aed1 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -450,22 +450,29 @@ impl Service { let content = serde_json::from_str::(pdu.content.get()) .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - let invite_state = match content.membership { - MembershipState::Invite => { - let state = services().rooms.state.calculate_invite_state(pdu)?; + let stripped_state = match content.membership { + MembershipState::Invite | MembershipState::Knock => { + let mut state = services().rooms.state.stripped_state(&pdu.room_id)?; + // So that clients can get info about who invitied them (not relevant for knocking), the reason, when, etc. + state.push(pdu.to_stripped_state_event()); Some(state) } _ => None, }; - // Update our membership info, we do this here incase a user is invited - // and immediately leaves we need the DB to record the invite event for auth + // Here we don't attempt to join if the previous membership was knock and the + // new one is join, like we do for `/federation/*/invite`, as not only are there + // implementation difficulties due to callers not implementing `Send`, but + // invites we recieve which aren't over `/invite` must have been due to a + // database reset or switching server implementations, which means we probably + // shouldn't be joining automatically anyways, since it may surprise users to + // suddenly join rooms which clients didn't even show as being knocked on before. services().rooms.state_cache.update_membership( &pdu.room_id, &target_user_id, content.membership, &pdu.sender, - invite_state, + stripped_state, true, )?; }