diff --git a/src/api/client/message.rs b/src/api/client/message.rs index f8818ebb..95a135e1 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -1,14 +1,14 @@ use axum::extract::State; use conduwuit::{ - Err, Result, at, + Err, Result, at, err, matrix::{ event::{Event, Matches}, - pdu::PduCount, + pdu::{PduCount, ShortEventId}, }, ref_at, utils::{ IterStream, ReadyExt, - result::{FlatOk, LogErr}, + result::LogErr, stream::{BroadbandExt, TryIgnore, WidebandExt}, }, }; @@ -61,6 +61,39 @@ const IGNORED_MESSAGE_TYPES: &[TimelineEventType] = &[ const LIMIT_MAX: usize = 100; const LIMIT_DEFAULT: usize = 10; +/// Parse a pagination token, trying ShortEventId first, then falling back to +/// PduCount +async fn parse_pagination_token( + _services: &Services, + _room_id: &RoomId, + token: Option<&str>, + default: PduCount, +) -> Result { + let Some(token) = token else { + return Ok(default); + }; + + // Try parsing as ShortEventId first + if let Ok(shorteventid) = token.parse::() { + // ShortEventId maps directly to a PduCount in our database + Ok(PduCount::Normal(shorteventid)) + } else if let Ok(count) = token.parse::() { + // Fallback to PduCount for backwards compatibility + Ok(PduCount::Normal(count)) + } else if let Ok(count) = token.parse::() { + // Also handle negative counts for backfilled events + Ok(PduCount::from_signed(count)) + } else { + Err(err!(Request(InvalidParam("Invalid pagination token")))) + } +} + +/// Convert a PduCount to a token string (using the underlying ShortEventId) +fn count_to_token(count: PduCount) -> String { + // The PduCount's unsigned value IS the ShortEventId + count.into_unsigned().to_string() +} + /// # `GET /_matrix/client/r0/rooms/{roomId}/messages` /// /// Allows paginating through room history. @@ -81,17 +114,18 @@ pub(crate) async fn get_message_events_route( return Err!(Request(Forbidden("Room does not exist to this server"))); } - let from: PduCount = body - .from - .as_deref() - .map(str::parse) - .transpose()? - .unwrap_or_else(|| match body.dir { + let from: PduCount = + parse_pagination_token(&services, room_id, body.from.as_deref(), match body.dir { | Direction::Forward => PduCount::min(), | Direction::Backward => PduCount::max(), - }); + }) + .await?; - let to: Option = body.to.as_deref().map(str::parse).flat_ok(); + let to: Option = if let Some(to_str) = body.to.as_deref() { + Some(parse_pagination_token(&services, room_id, Some(to_str), PduCount::min()).await?) + } else { + None + }; let limit: usize = body .limit @@ -180,8 +214,8 @@ pub(crate) async fn get_message_events_route( .collect(); Ok(get_message_events::v3::Response { - start: from.to_string(), - end: next_token.as_ref().map(ToString::to_string), + start: count_to_token(from), + end: next_token.map(count_to_token), chunk, state, }) diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index 1aa34ada..48bcde20 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -1,7 +1,11 @@ use axum::extract::State; use conduwuit::{ - Result, at, - matrix::{Event, event::RelationTypeEqual, pdu::PduCount}, + Result, at, err, + matrix::{ + Event, + event::RelationTypeEqual, + pdu::{PduCount, ShortEventId}, + }, utils::{IterStream, ReadyExt, result::FlatOk, stream::WidebandExt}, }; use conduwuit_service::Services; @@ -20,6 +24,40 @@ use ruma::{ use crate::Ruma; +/// Parse a pagination token, trying ShortEventId first, then falling back to +/// PduCount +async fn parse_pagination_token( + _services: &Services, + _room_id: &RoomId, + token: Option<&str>, + default: PduCount, +) -> Result { + let Some(token) = token else { + return Ok(default); + }; + + // Try parsing as ShortEventId first + if let Ok(shorteventid) = token.parse::() { + // ShortEventId maps directly to a PduCount in our database + // The shorteventid IS the count value, just need to wrap it + Ok(PduCount::Normal(shorteventid)) + } else if let Ok(count) = token.parse::() { + // Fallback to PduCount for backwards compatibility + Ok(PduCount::Normal(count)) + } else if let Ok(count) = token.parse::() { + // Also handle negative counts for backfilled events + Ok(PduCount::from_signed(count)) + } else { + Err(err!(Request(InvalidParam("Invalid pagination token")))) + } +} + +/// Convert a PduCount to a token string (using the underlying ShortEventId) +fn count_to_token(count: PduCount) -> String { + // The PduCount's unsigned value IS the ShortEventId + count.into_unsigned().to_string() +} + /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( State(services): State, @@ -109,15 +147,17 @@ async fn paginate_relations_with_filter( recurse: bool, dir: Direction, ) -> Result { - let start: PduCount = from - .map(str::parse) - .transpose()? - .unwrap_or_else(|| match dir { - | Direction::Forward => PduCount::min(), - | Direction::Backward => PduCount::max(), - }); + let start: PduCount = parse_pagination_token(services, room_id, from, match dir { + | Direction::Forward => PduCount::min(), + | Direction::Backward => PduCount::max(), + }) + .await?; - let to: Option = to.map(str::parse).flat_ok(); + let to: Option = if let Some(to_str) = to { + Some(parse_pagination_token(services, room_id, Some(to_str), PduCount::min()).await?) + } else { + None + }; // Use limit or else 30, with maximum 100 let limit: usize = limit @@ -129,6 +169,11 @@ async fn paginate_relations_with_filter( // Spec (v1.10) recommends depth of at least 3 let depth: u8 = if recurse { 3 } else { 1 }; + // Check if this is a thread request + let is_thread = filter_rel_type + .as_ref() + .is_some_and(|rel| *rel == RelationType::Thread); + let events: Vec<_> = services .rooms .pdu_metadata @@ -152,23 +197,65 @@ async fn paginate_relations_with_filter( .collect() .await; - let next_batch = match dir { - | Direction::Forward => events.last(), - | Direction::Backward => events.first(), + // For threads, check if we should include the root event + let mut root_event = None; + if is_thread && dir == Direction::Backward { + // Check if we've reached the beginning of the thread + // (fewer events than requested means we've exhausted the thread) + if events.len() < limit { + // Try to get the thread root event + if let Ok(root_pdu) = services.rooms.timeline.get_pdu(target).await { + // Check visibility + if services + .rooms + .state_accessor + .user_can_see_event(sender_user, room_id, target) + .await + { + // Store the root event to add to the response + root_event = Some(root_pdu); + } + } + } } - .map(at!(0)) - .as_ref() - .map(ToString::to_string); + + // Determine if there are more events to fetch + let has_more = if root_event.is_some() { + false // We've included the root, no more events + } else { + // Check if we got a full page of results (might be more) + events.len() >= limit + }; + + let next_batch = if has_more { + match dir { + | Direction::Forward => events.last(), + | Direction::Backward => events.first(), + } + .map(|(count, _)| count_to_token(*count)) + } else { + None + }; + + // Build the response chunk with thread root if needed + let chunk: Vec<_> = if let Some(root) = root_event { + // Add root event at the beginning for backward pagination + std::iter::once(root.into_format()) + .chain(events.into_iter().map(at!(1)).map(Event::into_format)) + .collect() + } else { + events + .into_iter() + .map(at!(1)) + .map(Event::into_format) + .collect() + }; Ok(get_relating_events::v1::Response { next_batch, prev_batch: from.map(Into::into), recursion_depth: recurse.then_some(depth.into()), - chunk: events - .into_iter() - .map(at!(1)) - .map(Event::into_format) - .collect(), + chunk, }) } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index c1376cb0..f9cc80a0 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -61,9 +61,10 @@ impl Data { from: PduCount, dir: Direction, ) -> impl Stream + Send + '_ { + let from_unsigned = from.into_unsigned(); let mut current = ArrayVec::::new(); current.extend(target.to_be_bytes()); - current.extend(from.saturating_inc(dir).into_unsigned().to_be_bytes()); + current.extend(from_unsigned.to_be_bytes()); let current = current.as_slice(); match dir { | Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(), @@ -73,6 +74,17 @@ impl Data { .ready_take_while(move |key| key.starts_with(&target.to_be_bytes())) .map(|to_from| u64_from_u8(&to_from[8..16])) .map(PduCount::from_unsigned) + .ready_filter(move |count| { + if from == PduCount::min() || from == PduCount::max() { + true + } else { + let count_unsigned = count.into_unsigned(); + match dir { + | Direction::Forward => count_unsigned > from_unsigned, + | Direction::Backward => count_unsigned < from_unsigned, + } + } + }) .wide_filter_map(move |shorteventid| async move { let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into();