1
0
Fork 0
mirror of https://forgejo.ellis.link/continuwuation/continuwuity.git synced 2025-07-27 18:28:31 +00:00

feat(policy-server): Don't fail-closed & refactor references

This commit is contained in:
nexy7574 2025-07-21 21:07:14 +01:00
parent 9465c5df1f
commit dfda27fadc
No known key found for this signature in database
GPG key ID: 0FA334385D0B689F
3 changed files with 49 additions and 46 deletions

View file

@ -1,5 +1,4 @@
mod acl_check; mod acl_check;
mod call_policyserv;
mod fetch_and_handle_outliers; mod fetch_and_handle_outliers;
mod fetch_prev; mod fetch_prev;
mod fetch_state; mod fetch_state;
@ -7,6 +6,7 @@ mod handle_incoming_pdu;
mod handle_outlier_pdu; mod handle_outlier_pdu;
mod handle_prev_pdu; mod handle_prev_pdu;
mod parse_incoming_pdu; mod parse_incoming_pdu;
mod policy_server;
mod resolve_state; mod resolve_state;
mod state_at_incoming; mod state_at_incoming;
mod upgrade_outlier_pdu; mod upgrade_outlier_pdu;

View file

@ -15,14 +15,14 @@ use ruma::{
/// Returns Ok if the policy server allows the event /// Returns Ok if the policy server allows the event
#[implement(super::Service)] #[implement(super::Service)]
#[tracing::instrument(skip_all, level = "debug")] #[tracing::instrument(skip_all, level = "debug")]
pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result { pub async fn ask_policy_server(&self, pdu: &PduEvent, room_id: &RoomId) -> Result<bool> {
if *pdu.event_type() == StateEventType::RoomPolicy.into() { if *pdu.event_type() == StateEventType::RoomPolicy.into() {
debug!( debug!(
room_id = %room_id, room_id = %room_id,
event_type = ?pdu.event_type(), event_type = ?pdu.event_type(),
"Skipping spam check for policy server meta-event" "Skipping spam check for policy server meta-event"
); );
return Ok(()); return Ok(true);
} }
let Ok(policyserver) = self let Ok(policyserver) = self
.services .services
@ -31,19 +31,19 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result
.await .await
.map(|c: RoomPolicyEventContent| c) .map(|c: RoomPolicyEventContent| c)
else { else {
return Ok(()); return Ok(true);
}; };
let via = match policyserver.via { let via = match policyserver.via {
| Some(ref via) => ServerName::parse(via)?, | Some(ref via) => ServerName::parse(via)?,
| None => { | None => {
debug!("No policy server configured for room {room_id}"); debug!("No policy server configured for room {room_id}");
return Ok(()); return Ok(true);
}, },
}; };
if via.is_empty() { if via.is_empty() {
debug!("Policy server is empty for room {room_id}, skipping spam check"); debug!("Policy server is empty for room {room_id}, skipping spam check");
return Ok(()); return Ok(true);
} }
if !self.services.state_cache.server_in_room(via, room_id).await { if !self.services.state_cache.server_in_room(via, room_id).await {
debug!( debug!(
@ -51,7 +51,7 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result
via = %via, via = %via,
"Policy server is not in the room, skipping spam check" "Policy server is not in the room, skipping spam check"
); );
return Ok(()); return Ok(true);
} }
let outgoing = self let outgoing = self
.services .services
@ -85,7 +85,7 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result
); );
// Network or policy server errors are treated as non-fatal: event is allowed by // Network or policy server errors are treated as non-fatal: event is allowed by
// default. // default.
return Ok(()); return Err(e);
}, },
| Err(_) => { | Err(_) => {
warn!( warn!(
@ -94,7 +94,7 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result
room_id = %room_id, room_id = %room_id,
"Policy server request timed out after 10 seconds" "Policy server request timed out after 10 seconds"
); );
return Ok(()); return Err!("Request to policy server timed out");
}, },
}; };
if response.recommendation == "spam" { if response.recommendation == "spam" {
@ -107,5 +107,5 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result
return Err!(Request(Forbidden("Event was marked as spam by policy server"))); return Err!(Request(Forbidden("Event was marked as spam by policy server")));
} }
Ok(()) Ok(true)
} }

View file

@ -245,25 +245,27 @@ where
.await?; .await?;
} }
if !soft_fail {
// Don't call the below checks on events that have already soft-failed, there's
// no reason to re-calculate that.
// 14-pre. If the event is not a state event, ask the policy server about it // 14-pre. If the event is not a state event, ask the policy server about it
if incoming_pdu.state_key.is_none() { if incoming_pdu.state_key.is_none() {
debug!( debug!(event_id = %incoming_pdu.event_id, "Checking policy server for event");
event_id = %incoming_pdu.event_id,"Checking policy server for event"); match self.ask_policy_server(&incoming_pdu, room_id).await {
let policy = self.policyserv_check(&incoming_pdu, room_id); | Ok(false) => {
if let Err(e) = policy.await {
warn!( warn!(
event_id = %incoming_pdu.event_id, event_id = %incoming_pdu.event_id,
error = ?e, "Event has been marked as spam by policy server"
"Policy server check failed for event"
); );
if !soft_fail {
soft_fail = true; soft_fail = true;
} },
} | _ => {
debug!( debug!(
event_id = %incoming_pdu.event_id, event_id = %incoming_pdu.event_id,
"Policy server check passed for event" "Event has passed policy server check or the policy server was unavailable."
); );
},
};
} }
// Additionally, if this is a redaction for a soft-failed event, we soft-fail it // Additionally, if this is a redaction for a soft-failed event, we soft-fail it
@ -286,6 +288,7 @@ where
soft_fail = true; soft_fail = true;
} }
} }
}
// 14. Check if the event passes auth based on the "current state" of the room, // 14. Check if the event passes auth based on the "current state" of the room,
// if not soft fail it // if not soft fail it