1
0
Fork 0
mirror of https://forgejo.ellis.link/continuwuation/continuwuity.git synced 2025-09-03 16:50:56 +00:00

refactor: address code review feedback for auth and pagination improvements

- Extract duplicated thread/message pagination functions to shared utils module
- Refactor pagination token parsing to use Option combinators instead of defaults
- Split access token generation from assignment for clearer error handling
- Add appservice token collision detection at startup and registration
- Allow appservice re-registration with same token (for config updates)
- Simplify thread relation chunk building using iterator chaining
- Fix saturating_inc edge case in relation queries with explicit filtering
- Add concise comments explaining non-obvious behaviour choices
This commit is contained in:
Tom Foster 2025-08-11 06:24:29 +01:00
parent 9286838d23
commit 583cb924f1
9 changed files with 149 additions and 151 deletions

View file

@ -1,9 +1,9 @@
use axum::extract::State; use axum::extract::State;
use conduwuit::{ use conduwuit::{
Err, Result, at, err, Err, Result, at,
matrix::{ matrix::{
event::{Event, Matches}, event::{Event, Matches},
pdu::{PduCount, ShortEventId}, pdu::PduCount,
}, },
ref_at, ref_at,
utils::{ utils::{
@ -35,6 +35,7 @@ use ruma::{
}; };
use tracing::warn; use tracing::warn;
use super::utils::{count_to_token, parse_pagination_token as parse_token};
use crate::Ruma; use crate::Ruma;
/// list of safe and common non-state events to ignore if the user is ignored /// list of safe and common non-state events to ignore if the user is ignored
@ -61,39 +62,6 @@ const IGNORED_MESSAGE_TYPES: &[TimelineEventType] = &[
const LIMIT_MAX: usize = 100; const LIMIT_MAX: usize = 100;
const LIMIT_DEFAULT: usize = 10; const LIMIT_DEFAULT: usize = 10;
/// Parse a pagination token, trying ShortEventId first, then falling back to
/// PduCount
async fn parse_pagination_token(
_services: &Services,
_room_id: &RoomId,
token: Option<&str>,
default: PduCount,
) -> Result<PduCount> {
let Some(token) = token else {
return Ok(default);
};
// Try parsing as ShortEventId first
if let Ok(shorteventid) = token.parse::<ShortEventId>() {
// ShortEventId maps directly to a PduCount in our database
Ok(PduCount::Normal(shorteventid))
} else if let Ok(count) = token.parse::<u64>() {
// Fallback to PduCount for backwards compatibility
Ok(PduCount::Normal(count))
} else if let Ok(count) = token.parse::<i64>() {
// Also handle negative counts for backfilled events
Ok(PduCount::from_signed(count))
} else {
Err(err!(Request(InvalidParam("Invalid pagination token"))))
}
}
/// Convert a PduCount to a token string (using the underlying ShortEventId)
fn count_to_token(count: PduCount) -> String {
// The PduCount's unsigned value IS the ShortEventId
count.into_unsigned().to_string()
}
/// # `GET /_matrix/client/r0/rooms/{roomId}/messages` /// # `GET /_matrix/client/r0/rooms/{roomId}/messages`
/// ///
/// Allows paginating through room history. /// Allows paginating through room history.
@ -114,18 +82,17 @@ pub(crate) async fn get_message_events_route(
return Err!(Request(Forbidden("Room does not exist to this server"))); return Err!(Request(Forbidden("Room does not exist to this server")));
} }
let from: PduCount = let from: PduCount = body
parse_pagination_token(&services, room_id, body.from.as_deref(), match body.dir { .from
.as_deref()
.map(parse_token)
.transpose()?
.unwrap_or_else(|| match body.dir {
| Direction::Forward => PduCount::min(), | Direction::Forward => PduCount::min(),
| Direction::Backward => PduCount::max(), | Direction::Backward => PduCount::max(),
}) });
.await?;
let to: Option<PduCount> = if let Some(to_str) = body.to.as_deref() { let to: Option<PduCount> = body.to.as_deref().map(parse_token).transpose()?;
Some(parse_pagination_token(&services, room_id, Some(to_str), PduCount::min()).await?)
} else {
None
};
let limit: usize = body let limit: usize = body
.limit .limit

View file

@ -36,6 +36,7 @@ pub(super) mod typing;
pub(super) mod unstable; pub(super) mod unstable;
pub(super) mod unversioned; pub(super) mod unversioned;
pub(super) mod user_directory; pub(super) mod user_directory;
pub(super) mod utils;
pub(super) mod voip; pub(super) mod voip;
pub(super) mod well_known; pub(super) mod well_known;

View file

@ -1,11 +1,7 @@
use axum::extract::State; use axum::extract::State;
use conduwuit::{ use conduwuit::{
Result, at, err, Result, at,
matrix::{ matrix::{Event, event::RelationTypeEqual, pdu::PduCount},
Event,
event::RelationTypeEqual,
pdu::{PduCount, ShortEventId},
},
utils::{IterStream, ReadyExt, result::FlatOk, stream::WidebandExt}, utils::{IterStream, ReadyExt, result::FlatOk, stream::WidebandExt},
}; };
use conduwuit_service::Services; use conduwuit_service::Services;
@ -22,42 +18,9 @@ use ruma::{
events::{TimelineEventType, relation::RelationType}, events::{TimelineEventType, relation::RelationType},
}; };
use super::utils::{count_to_token, parse_pagination_token as parse_token};
use crate::Ruma; use crate::Ruma;
/// Parse a pagination token, trying ShortEventId first, then falling back to
/// PduCount
async fn parse_pagination_token(
_services: &Services,
_room_id: &RoomId,
token: Option<&str>,
default: PduCount,
) -> Result<PduCount> {
let Some(token) = token else {
return Ok(default);
};
// Try parsing as ShortEventId first
if let Ok(shorteventid) = token.parse::<ShortEventId>() {
// ShortEventId maps directly to a PduCount in our database
// The shorteventid IS the count value, just need to wrap it
Ok(PduCount::Normal(shorteventid))
} else if let Ok(count) = token.parse::<u64>() {
// Fallback to PduCount for backwards compatibility
Ok(PduCount::Normal(count))
} else if let Ok(count) = token.parse::<i64>() {
// Also handle negative counts for backfilled events
Ok(PduCount::from_signed(count))
} else {
Err(err!(Request(InvalidParam("Invalid pagination token"))))
}
}
/// Convert a PduCount to a token string (using the underlying ShortEventId)
fn count_to_token(count: PduCount) -> String {
// The PduCount's unsigned value IS the ShortEventId
count.into_unsigned().to_string()
}
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}`
pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
State(services): State<crate::State>, State(services): State<crate::State>,
@ -147,17 +110,15 @@ async fn paginate_relations_with_filter(
recurse: bool, recurse: bool,
dir: Direction, dir: Direction,
) -> Result<get_relating_events::v1::Response> { ) -> Result<get_relating_events::v1::Response> {
let start: PduCount = parse_pagination_token(services, room_id, from, match dir { let start: PduCount = from
| Direction::Forward => PduCount::min(), .map(parse_token)
| Direction::Backward => PduCount::max(), .transpose()?
}) .unwrap_or_else(|| match dir {
.await?; | Direction::Forward => PduCount::min(),
| Direction::Backward => PduCount::max(),
});
let to: Option<PduCount> = if let Some(to_str) = to { let to: Option<PduCount> = to.map(parse_token).transpose()?;
Some(parse_pagination_token(services, room_id, Some(to_str), PduCount::min()).await?)
} else {
None
};
// Use limit or else 30, with maximum 100 // Use limit or else 30, with maximum 100
let limit: usize = limit let limit: usize = limit
@ -238,18 +199,11 @@ async fn paginate_relations_with_filter(
}; };
// Build the response chunk with thread root if needed // Build the response chunk with thread root if needed
let chunk: Vec<_> = if let Some(root) = root_event { let chunk: Vec<_> = root_event
// Add root event at the beginning for backward pagination .into_iter()
std::iter::once(root.into_format()) .map(Event::into_format)
.chain(events.into_iter().map(at!(1)).map(Event::into_format)) .chain(events.into_iter().map(at!(1)).map(Event::into_format))
.collect() .collect();
} else {
events
.into_iter()
.map(at!(1))
.map(Event::into_format)
.collect()
};
Ok(get_relating_events::v1::Response { Ok(get_relating_events::v1::Response {
next_batch, next_batch,

View file

@ -198,8 +198,8 @@ pub(crate) async fn login_route(
.clone() .clone()
.unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into()); .unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
// Generate a new token for the device // Generate a new token for the device (ensuring no collisions)
let token = utils::random_string(TOKEN_LENGTH); let token = services.users.generate_unique_token().await;
// Determine if device_id was provided and exists in the db for this user // Determine if device_id was provided and exists in the db for this user
let device_exists = if body.device_id.is_some() { let device_exists = if body.device_id.is_some() {

28
src/api/client/utils.rs Normal file
View file

@ -0,0 +1,28 @@
use conduwuit::{
Result, err,
matrix::pdu::{PduCount, ShortEventId},
};
/// Parse a pagination token, trying ShortEventId first, then falling back to
/// PduCount
pub(crate) fn parse_pagination_token(token: &str) -> Result<PduCount> {
// Try parsing as ShortEventId first
if let Ok(shorteventid) = token.parse::<ShortEventId>() {
// ShortEventId maps directly to a PduCount in our database
Ok(PduCount::Normal(shorteventid))
} else if let Ok(count) = token.parse::<u64>() {
// Fallback to PduCount for backwards compatibility
Ok(PduCount::Normal(count))
} else if let Ok(count) = token.parse::<i64>() {
// Also handle negative counts for backfilled events
Ok(PduCount::from_signed(count))
} else {
Err(err!(Request(InvalidParam("Invalid pagination token"))))
}
}
/// Convert a PduCount to a token string (using the underlying ShortEventId)
pub(crate) fn count_to_token(count: PduCount) -> String {
// The PduCount's unsigned value IS the ShortEventId
count.into_unsigned().to_string()
}

View file

@ -355,6 +355,7 @@ async fn find_token(services: &Services, token: Option<&str>) -> Result<Token> {
.map_ok(Token::Appservice); .map_ok(Token::Appservice);
pin_mut!(user_token, appservice_token); pin_mut!(user_token, appservice_token);
// Returns Ok if either token type succeeds, Err only if both fail
match select_ok([Left(user_token), Right(appservice_token)]).await { match select_ok([Left(user_token), Right(appservice_token)]).await {
| Err(e) if !e.is_not_found() => Err(e), | Err(e) if !e.is_not_found() => Err(e),
| Ok((token, _)) => Ok(token), | Ok((token, _)) => Ok(token),

View file

@ -4,7 +4,7 @@ mod registration_info;
use std::{collections::BTreeMap, iter::IntoIterator, sync::Arc}; use std::{collections::BTreeMap, iter::IntoIterator, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use conduwuit::{Result, err, utils::stream::IterStream}; use conduwuit::{Err, Result, err, utils::stream::IterStream};
use database::Map; use database::Map;
use futures::{Future, FutureExt, Stream, TryStreamExt}; use futures::{Future, FutureExt, Stream, TryStreamExt};
use ruma::{RoomAliasId, RoomId, UserId, api::appservice::Registration}; use ruma::{RoomAliasId, RoomId, UserId, api::appservice::Registration};
@ -48,36 +48,50 @@ impl crate::Service for Service {
} }
async fn worker(self: Arc<Self>) -> Result { async fn worker(self: Arc<Self>) -> Result {
self.iter_db_ids() // First, collect all appservices to check for token conflicts
.try_for_each(async |appservice| { let appservices: Vec<(String, Registration)> = self.iter_db_ids().try_collect().await?;
let (id, registration) = appservice;
// During startup, resolve any token collisions in favour of appservices // Check for appservice-to-appservice token conflicts
// by logging out conflicting user devices for i in 0..appservices.len() {
if let Ok((user_id, device_id)) = self for j in i.saturating_add(1)..appservices.len() {
.services if appservices[i].1.as_token == appservices[j].1.as_token {
.users return Err!(Database(error!(
.find_from_token(&registration.as_token) "Token collision detected: Appservices '{}' and '{}' have the same token",
.await appservices[i].0, appservices[j].0
{ )));
conduwuit::warn!(
"Token collision detected during startup: Appservice '{}' token was \
also used by user '{}' device '{}'. Logging out the user device to \
resolve conflict.",
id,
user_id.localpart(),
device_id
);
self.services
.users
.remove_device(&user_id, &device_id)
.await;
} }
}
}
self.start_appservice(id, registration).await // Process each appservice
}) for (id, registration) in appservices {
.await // During startup, resolve any token collisions in favour of appservices
// by logging out conflicting user devices
if let Ok((user_id, device_id)) = self
.services
.users
.find_from_token(&registration.as_token)
.await
{
conduwuit::warn!(
"Token collision detected during startup: Appservice '{}' token was also \
used by user '{}' device '{}'. Logging out the user device to resolve \
conflict.",
id,
user_id.localpart(),
device_id
);
self.services
.users
.remove_device(&user_id, &device_id)
.await;
}
self.start_appservice(id, registration).await?;
}
Ok(())
} }
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
@ -125,6 +139,18 @@ impl Service {
) -> Result { ) -> Result {
//TODO: Check for collisions between exclusive appservice namespaces //TODO: Check for collisions between exclusive appservice namespaces
// Check for token collision with other appservices (allow re-registration of
// same appservice)
if let Ok(existing) = self.find_from_token(&registration.as_token).await {
if existing.registration.id != registration.id {
return Err(err!(Request(InvalidParam(
"Cannot register appservice: Token is already used by appservice '{}'. \
Please generate a different token.",
existing.registration.id
))));
}
}
// Prevent token collision with existing user tokens // Prevent token collision with existing user tokens
if self if self
.services .services
@ -182,6 +208,7 @@ impl Service {
.map(|info| info.registration) .map(|info| info.registration)
} }
/// Returns Result to match users::find_from_token for select_ok usage
pub async fn find_from_token(&self, token: &str) -> Result<RegistrationInfo> { pub async fn find_from_token(&self, token: &str) -> Result<RegistrationInfo> {
self.read() self.read()
.await .await

View file

@ -61,6 +61,8 @@ impl Data {
from: PduCount, from: PduCount,
dir: Direction, dir: Direction,
) -> impl Stream<Item = (PduCount, impl Event)> + Send + '_ { ) -> impl Stream<Item = (PduCount, impl Event)> + Send + '_ {
// Query from exact position then filter excludes it (saturating_inc could skip
// events at min/max boundaries)
let from_unsigned = from.into_unsigned(); let from_unsigned = from.into_unsigned();
let mut current = ArrayVec::<u8, 16>::new(); let mut current = ArrayVec::<u8, 16>::new();
current.extend(target.to_be_bytes()); current.extend(target.to_be_bytes());

View file

@ -393,6 +393,31 @@ impl Service {
self.db.userdeviceid_token.qry(&key).await.deserialized() self.db.userdeviceid_token.qry(&key).await.deserialized()
} }
/// Generate a unique access token that doesn't collide with existing tokens
pub async fn generate_unique_token(&self) -> String {
loop {
let token = utils::random_string(32);
// Check for collision with appservice tokens
if self
.services
.appservice
.find_from_token(&token)
.await
.is_ok()
{
continue;
}
// Check for collision with user tokens
if self.db.token_userdeviceid.get(&token).await.is_ok() {
continue;
}
return token;
}
}
/// Replaces the access token of one device. /// Replaces the access token of one device.
pub async fn set_token( pub async fn set_token(
&self, &self,
@ -409,25 +434,18 @@ impl Service {
))); )));
} }
// Prevent token collisions with appservice tokens // Check for token collision with appservices
let final_token = if self if self
.services .services
.appservice .appservice
.find_from_token(token) .find_from_token(token)
.await .await
.is_ok() .is_ok()
{ {
let new_token = utils::random_string(32); return Err!(Request(InvalidParam(
conduwuit::debug_warn!( "Token conflicts with an existing appservice token"
"Token collision prevented: Generated new token for user '{}' device '{}' \ )));
(original token conflicted with an appservice)", }
user_id.localpart(),
device_id
);
new_token
} else {
token.to_owned()
};
// Remove old token // Remove old token
if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await { if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await {
@ -436,8 +454,8 @@ impl Service {
} }
// Assign token to user device combination // Assign token to user device combination
self.db.userdeviceid_token.put_raw(key, &final_token); self.db.userdeviceid_token.put_raw(key, token);
self.db.token_userdeviceid.raw_put(&final_token, key); self.db.token_userdeviceid.raw_put(token, key);
Ok(()) Ok(())
} }