diff --git a/src/api/client/message.rs b/src/api/client/message.rs index 95a135e1..4d489c2f 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -1,9 +1,9 @@ use axum::extract::State; use conduwuit::{ - Err, Result, at, err, + Err, Result, at, matrix::{ event::{Event, Matches}, - pdu::{PduCount, ShortEventId}, + pdu::PduCount, }, ref_at, utils::{ @@ -35,6 +35,7 @@ use ruma::{ }; use tracing::warn; +use super::utils::{count_to_token, parse_pagination_token as parse_token}; use crate::Ruma; /// list of safe and common non-state events to ignore if the user is ignored @@ -61,39 +62,6 @@ const IGNORED_MESSAGE_TYPES: &[TimelineEventType] = &[ const LIMIT_MAX: usize = 100; const LIMIT_DEFAULT: usize = 10; -/// Parse a pagination token, trying ShortEventId first, then falling back to -/// PduCount -async fn parse_pagination_token( - _services: &Services, - _room_id: &RoomId, - token: Option<&str>, - default: PduCount, -) -> Result { - let Some(token) = token else { - return Ok(default); - }; - - // Try parsing as ShortEventId first - if let Ok(shorteventid) = token.parse::() { - // ShortEventId maps directly to a PduCount in our database - Ok(PduCount::Normal(shorteventid)) - } else if let Ok(count) = token.parse::() { - // Fallback to PduCount for backwards compatibility - Ok(PduCount::Normal(count)) - } else if let Ok(count) = token.parse::() { - // Also handle negative counts for backfilled events - Ok(PduCount::from_signed(count)) - } else { - Err(err!(Request(InvalidParam("Invalid pagination token")))) - } -} - -/// Convert a PduCount to a token string (using the underlying ShortEventId) -fn count_to_token(count: PduCount) -> String { - // The PduCount's unsigned value IS the ShortEventId - count.into_unsigned().to_string() -} - /// # `GET /_matrix/client/r0/rooms/{roomId}/messages` /// /// Allows paginating through room history. @@ -114,18 +82,17 @@ pub(crate) async fn get_message_events_route( return Err!(Request(Forbidden("Room does not exist to this server"))); } - let from: PduCount = - parse_pagination_token(&services, room_id, body.from.as_deref(), match body.dir { + let from: PduCount = body + .from + .as_deref() + .map(parse_token) + .transpose()? + .unwrap_or_else(|| match body.dir { | Direction::Forward => PduCount::min(), | Direction::Backward => PduCount::max(), - }) - .await?; + }); - let to: Option = if let Some(to_str) = body.to.as_deref() { - Some(parse_pagination_token(&services, room_id, Some(to_str), PduCount::min()).await?) - } else { - None - }; + let to: Option = body.to.as_deref().map(parse_token).transpose()?; let limit: usize = body .limit diff --git a/src/api/client/mod.rs b/src/api/client/mod.rs index be54e65f..e4be20b7 100644 --- a/src/api/client/mod.rs +++ b/src/api/client/mod.rs @@ -36,6 +36,7 @@ pub(super) mod typing; pub(super) mod unstable; pub(super) mod unversioned; pub(super) mod user_directory; +pub(super) mod utils; pub(super) mod voip; pub(super) mod well_known; diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index 48bcde20..f6d8fe9e 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -1,11 +1,7 @@ use axum::extract::State; use conduwuit::{ - Result, at, err, - matrix::{ - Event, - event::RelationTypeEqual, - pdu::{PduCount, ShortEventId}, - }, + Result, at, + matrix::{Event, event::RelationTypeEqual, pdu::PduCount}, utils::{IterStream, ReadyExt, result::FlatOk, stream::WidebandExt}, }; use conduwuit_service::Services; @@ -22,42 +18,9 @@ use ruma::{ events::{TimelineEventType, relation::RelationType}, }; +use super::utils::{count_to_token, parse_pagination_token as parse_token}; use crate::Ruma; -/// Parse a pagination token, trying ShortEventId first, then falling back to -/// PduCount -async fn parse_pagination_token( - _services: &Services, - _room_id: &RoomId, - token: Option<&str>, - default: PduCount, -) -> Result { - let Some(token) = token else { - return Ok(default); - }; - - // Try parsing as ShortEventId first - if let Ok(shorteventid) = token.parse::() { - // ShortEventId maps directly to a PduCount in our database - // The shorteventid IS the count value, just need to wrap it - Ok(PduCount::Normal(shorteventid)) - } else if let Ok(count) = token.parse::() { - // Fallback to PduCount for backwards compatibility - Ok(PduCount::Normal(count)) - } else if let Ok(count) = token.parse::() { - // Also handle negative counts for backfilled events - Ok(PduCount::from_signed(count)) - } else { - Err(err!(Request(InvalidParam("Invalid pagination token")))) - } -} - -/// Convert a PduCount to a token string (using the underlying ShortEventId) -fn count_to_token(count: PduCount) -> String { - // The PduCount's unsigned value IS the ShortEventId - count.into_unsigned().to_string() -} - /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( State(services): State, @@ -147,17 +110,15 @@ async fn paginate_relations_with_filter( recurse: bool, dir: Direction, ) -> Result { - let start: PduCount = parse_pagination_token(services, room_id, from, match dir { - | Direction::Forward => PduCount::min(), - | Direction::Backward => PduCount::max(), - }) - .await?; + let start: PduCount = from + .map(parse_token) + .transpose()? + .unwrap_or_else(|| match dir { + | Direction::Forward => PduCount::min(), + | Direction::Backward => PduCount::max(), + }); - let to: Option = if let Some(to_str) = to { - Some(parse_pagination_token(services, room_id, Some(to_str), PduCount::min()).await?) - } else { - None - }; + let to: Option = to.map(parse_token).transpose()?; // Use limit or else 30, with maximum 100 let limit: usize = limit @@ -238,18 +199,11 @@ async fn paginate_relations_with_filter( }; // Build the response chunk with thread root if needed - let chunk: Vec<_> = if let Some(root) = root_event { - // Add root event at the beginning for backward pagination - std::iter::once(root.into_format()) - .chain(events.into_iter().map(at!(1)).map(Event::into_format)) - .collect() - } else { - events - .into_iter() - .map(at!(1)) - .map(Event::into_format) - .collect() - }; + let chunk: Vec<_> = root_event + .into_iter() + .map(Event::into_format) + .chain(events.into_iter().map(at!(1)).map(Event::into_format)) + .collect(); Ok(get_relating_events::v1::Response { next_batch, diff --git a/src/api/client/session.rs b/src/api/client/session.rs index 992073c6..fe07e41d 100644 --- a/src/api/client/session.rs +++ b/src/api/client/session.rs @@ -198,8 +198,8 @@ pub(crate) async fn login_route( .clone() .unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into()); - // Generate a new token for the device - let token = utils::random_string(TOKEN_LENGTH); + // Generate a new token for the device (ensuring no collisions) + let token = services.users.generate_unique_token().await; // Determine if device_id was provided and exists in the db for this user let device_exists = if body.device_id.is_some() { diff --git a/src/api/client/utils.rs b/src/api/client/utils.rs new file mode 100644 index 00000000..cc941b95 --- /dev/null +++ b/src/api/client/utils.rs @@ -0,0 +1,28 @@ +use conduwuit::{ + Result, err, + matrix::pdu::{PduCount, ShortEventId}, +}; + +/// Parse a pagination token, trying ShortEventId first, then falling back to +/// PduCount +pub(crate) fn parse_pagination_token(token: &str) -> Result { + // Try parsing as ShortEventId first + if let Ok(shorteventid) = token.parse::() { + // ShortEventId maps directly to a PduCount in our database + Ok(PduCount::Normal(shorteventid)) + } else if let Ok(count) = token.parse::() { + // Fallback to PduCount for backwards compatibility + Ok(PduCount::Normal(count)) + } else if let Ok(count) = token.parse::() { + // Also handle negative counts for backfilled events + Ok(PduCount::from_signed(count)) + } else { + Err(err!(Request(InvalidParam("Invalid pagination token")))) + } +} + +/// Convert a PduCount to a token string (using the underlying ShortEventId) +pub(crate) fn count_to_token(count: PduCount) -> String { + // The PduCount's unsigned value IS the ShortEventId + count.into_unsigned().to_string() +} diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 5088e699..44afc3ef 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -355,6 +355,7 @@ async fn find_token(services: &Services, token: Option<&str>) -> Result { .map_ok(Token::Appservice); pin_mut!(user_token, appservice_token); + // Returns Ok if either token type succeeds, Err only if both fail match select_ok([Left(user_token), Right(appservice_token)]).await { | Err(e) if !e.is_not_found() => Err(e), | Ok((token, _)) => Ok(token), diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index ad9c4a3f..ebd798f6 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -4,7 +4,7 @@ mod registration_info; use std::{collections::BTreeMap, iter::IntoIterator, sync::Arc}; use async_trait::async_trait; -use conduwuit::{Result, err, utils::stream::IterStream}; +use conduwuit::{Err, Result, err, utils::stream::IterStream}; use database::Map; use futures::{Future, FutureExt, Stream, TryStreamExt}; use ruma::{RoomAliasId, RoomId, UserId, api::appservice::Registration}; @@ -48,36 +48,50 @@ impl crate::Service for Service { } async fn worker(self: Arc) -> Result { - self.iter_db_ids() - .try_for_each(async |appservice| { - let (id, registration) = appservice; + // First, collect all appservices to check for token conflicts + let appservices: Vec<(String, Registration)> = self.iter_db_ids().try_collect().await?; - // During startup, resolve any token collisions in favour of appservices - // by logging out conflicting user devices - if let Ok((user_id, device_id)) = self - .services - .users - .find_from_token(®istration.as_token) - .await - { - conduwuit::warn!( - "Token collision detected during startup: Appservice '{}' token was \ - also used by user '{}' device '{}'. Logging out the user device to \ - resolve conflict.", - id, - user_id.localpart(), - device_id - ); - - self.services - .users - .remove_device(&user_id, &device_id) - .await; + // Check for appservice-to-appservice token conflicts + for i in 0..appservices.len() { + for j in i.saturating_add(1)..appservices.len() { + if appservices[i].1.as_token == appservices[j].1.as_token { + return Err!(Database(error!( + "Token collision detected: Appservices '{}' and '{}' have the same token", + appservices[i].0, appservices[j].0 + ))); } + } + } - self.start_appservice(id, registration).await - }) - .await + // Process each appservice + for (id, registration) in appservices { + // During startup, resolve any token collisions in favour of appservices + // by logging out conflicting user devices + if let Ok((user_id, device_id)) = self + .services + .users + .find_from_token(®istration.as_token) + .await + { + conduwuit::warn!( + "Token collision detected during startup: Appservice '{}' token was also \ + used by user '{}' device '{}'. Logging out the user device to resolve \ + conflict.", + id, + user_id.localpart(), + device_id + ); + + self.services + .users + .remove_device(&user_id, &device_id) + .await; + } + + self.start_appservice(id, registration).await?; + } + + Ok(()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } @@ -125,6 +139,18 @@ impl Service { ) -> Result { //TODO: Check for collisions between exclusive appservice namespaces + // Check for token collision with other appservices (allow re-registration of + // same appservice) + if let Ok(existing) = self.find_from_token(®istration.as_token).await { + if existing.registration.id != registration.id { + return Err(err!(Request(InvalidParam( + "Cannot register appservice: Token is already used by appservice '{}'. \ + Please generate a different token.", + existing.registration.id + )))); + } + } + // Prevent token collision with existing user tokens if self .services @@ -182,6 +208,7 @@ impl Service { .map(|info| info.registration) } + /// Returns Result to match users::find_from_token for select_ok usage pub async fn find_from_token(&self, token: &str) -> Result { self.read() .await diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index f9cc80a0..a746b4cc 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -61,6 +61,8 @@ impl Data { from: PduCount, dir: Direction, ) -> impl Stream + Send + '_ { + // Query from exact position then filter excludes it (saturating_inc could skip + // events at min/max boundaries) let from_unsigned = from.into_unsigned(); let mut current = ArrayVec::::new(); current.extend(target.to_be_bytes()); diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 0aacc0e1..eb54660e 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -393,6 +393,31 @@ impl Service { self.db.userdeviceid_token.qry(&key).await.deserialized() } + /// Generate a unique access token that doesn't collide with existing tokens + pub async fn generate_unique_token(&self) -> String { + loop { + let token = utils::random_string(32); + + // Check for collision with appservice tokens + if self + .services + .appservice + .find_from_token(&token) + .await + .is_ok() + { + continue; + } + + // Check for collision with user tokens + if self.db.token_userdeviceid.get(&token).await.is_ok() { + continue; + } + + return token; + } + } + /// Replaces the access token of one device. pub async fn set_token( &self, @@ -409,25 +434,18 @@ impl Service { ))); } - // Prevent token collisions with appservice tokens - let final_token = if self + // Check for token collision with appservices + if self .services .appservice .find_from_token(token) .await .is_ok() { - let new_token = utils::random_string(32); - conduwuit::debug_warn!( - "Token collision prevented: Generated new token for user '{}' device '{}' \ - (original token conflicted with an appservice)", - user_id.localpart(), - device_id - ); - new_token - } else { - token.to_owned() - }; + return Err!(Request(InvalidParam( + "Token conflicts with an existing appservice token" + ))); + } // Remove old token if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await { @@ -436,8 +454,8 @@ impl Service { } // Assign token to user device combination - self.db.userdeviceid_token.put_raw(key, &final_token); - self.db.token_userdeviceid.raw_put(&final_token, key); + self.db.userdeviceid_token.put_raw(key, token); + self.db.token_userdeviceid.raw_put(token, key); Ok(()) }