diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index 3ce37fa7..7c288b6e 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -1,3 +1,5 @@ +use std::ops::Deref; + use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use crate::{ api::client_server::{self, membership::join_room_by_id_helper}, @@ -15,7 +17,7 @@ use ruma::{ uiaa::{AuthFlow, AuthType, UiaaInfo}, }, events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType}, - push, UserId, + push, RoomId, UserId, }; use tracing::{info, warn}; @@ -291,29 +293,27 @@ pub async fn register_route(body: Ruma) -> Result = default_rooms + .iter() + .map(Deref::deref) + .filter_map(RoomId::server_name) + .map(Into::into) + .collect(); - let _user_id = user_id.clone(); - let servers = [services().globals.server_name().to_owned()]; - - tokio::spawn(async move { - for room_id in default_rooms { - let _ = join_room_by_id_helper( - Some(&_user_id), - room_id, - Some("All men are equal before fish.".to_owned()), - &servers, - None, - ) - .await - .inspect_err(|e| { - tracing::warn!("Failed to join default room: {e}"); - }); + for room_id in default_rooms { + if let Err(e) = join_room_by_id_helper( + Some(&user_id), + room_id, + Some("All men are equal before fish.".to_owned()), + &servers, + None, + ) + .await + { + warn!("Failed to join default room: {}", e); } - }); + } } Ok(register::v3::Response { diff --git a/src/database/mod.rs b/src/database/mod.rs index ca2c7608..b7cda860 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -970,11 +970,11 @@ impl KeyValueDatabase { db.presenceid_presence.clear()?; // Can only return an error during the first call - services().globals.default_rooms().map_err(|e| { - tracing::error!("Invalid room ID or alias in join-by-default rooms: {}", e); + // services().globals.default_rooms().await.map_err(|e| { + // tracing::error!("Invalid room ID or alias in join-by-default rooms: {}", e); - Error::bad_config("Invalid room ID or alias in join-by-default rooms.") - })?; + // Error::bad_config("Invalid room ID or alias in join-by-default rooms.") + // })?; services().admin.start_handler(); diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 94a2480d..87215be6 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -10,7 +10,7 @@ use tokio::sync::OnceCell; use crate::api::server_server::DestinationResponse; use crate::{services, Config, Error, Result}; -use futures_util::FutureExt; +use futures_util::{FutureExt, TryFutureExt}; use hickory_resolver::TokioAsyncResolver; use hyper_util::client::legacy::connect::dns::{GaiResolver, Name as HyperName}; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; @@ -509,41 +509,55 @@ impl Service { self.config.well_known_client() } - pub fn default_rooms(&self) -> Result<&BTreeSet> { - if let Some(set) = self.default_rooms.get() { - return Ok(set); + pub async fn default_rooms(&self) -> Result<&BTreeSet> { + if let Some(default_rooms) = self.default_rooms.get() { + return Ok(default_rooms); } - let server_name = self.config.server_name.as_str(); let mut default_rooms = BTreeSet::new(); - for s in self.config.default_rooms.iter() { - let next = s - .chars() - .next() - .ok_or_else(|| Error::bad_config("Invalid ID for join-by-default rooms."))?; - + for mut alias_or_id in self.config.default_rooms.iter().cloned() { // anything that does not start '!' should be considered an alias - let room_id = match next { - '!' => OwnedRoomId::from_str(&format!( - "{}:{server_name}", - s.split(':').next().unwrap_or(&s) - )) - .map_err(|_| Error::bad_config("Invalid ID for join-by-default rooms."))?, - _ => { - let Some(room_id) = OwnedRoomAliasId::from_str(&format!( - "#{}:{server_name}", - s.split(':').next().unwrap_or(&s).trim_start_matches('#') - )) - .map_err(|_| Error::bad_config("Invalid alias for join-by-default rooms.")) - .and_then(|alias| services().rooms.alias.resolve_local_alias(&alias))? - else { - warn!("Could not resolve Room ID for join-by-default rooms locally."); + // empty strings are ignored + let room_id = if Some('!') == alias_or_id.chars().next() { + if alias_or_id.split_once(':').is_none() { + alias_or_id = format!("{}:{}", alias_or_id, self.config.server_name); + } - continue; - }; + OwnedRoomId::from_str(&alias_or_id).map_err(|e| { + warn!( + "Invalid room ID ({}) for join-by-default rooms: {}", + alias_or_id, e + ); - room_id + Error::bad_config("Invalid room ID for join-by-default rooms.") + })? + } else { + if alias_or_id.split_once(':').is_none() { + alias_or_id = format!("#{}:{}", alias_or_id, self.config.server_name); + } + + let room_alias = OwnedRoomAliasId::from_str(&alias_or_id).map_err(|e| { + warn!( + "Invalid room alias ({}) for join-by-default rooms: {}", + alias_or_id, e + ); + + Error::bad_config("Invalid room alias for join-by-default rooms.") + })?; + + if room_alias.server_name() == &self.config.server_name { + services() + .rooms + .alias + .resolve_local_alias(&room_alias)? + .ok_or_else(|| { + Error::bad_config("Unknown alias for join-by-default rooms.") + })? + } else { + get_alias_helper(room_alias.clone()) + .await + .map(|res| res.room_id)? } }; @@ -552,9 +566,9 @@ impl Service { self.default_rooms .set(default_rooms) - .expect("default_rooms should be set once"); + .expect("default_rooms should not be set already"); - self.default_rooms() + Ok(self.default_rooms.get().unwrap()) } pub fn shutdown(&self) {