1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-09-10 18:51:00 +00:00

refactor: use async-aware RwLocks and Mutexes where possible

This commit is contained in:
Matthias Ahouansou 2024-03-05 14:22:54 +00:00
parent 57575b7c6f
commit becaad677f
No known key found for this signature in database
21 changed files with 1171 additions and 1006 deletions

View file

@ -1,25 +1,24 @@
/// An async function that can recursively call itself.
type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
use ruma::{
api::federation::discovery::{get_remote_server_keys, get_server_keys},
CanonicalJsonObject, CanonicalJsonValue, OwnedServerName, OwnedServerSigningKeyId,
RoomVersionId,
};
use std::{
collections::{hash_map, BTreeMap, HashMap, HashSet},
pin::Pin,
sync::{Arc, RwLock, RwLockWriteGuard},
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use tokio::sync::Semaphore;
use async_recursion::async_recursion;
use futures_util::{stream::FuturesUnordered, Future, StreamExt};
use ruma::{
api::{
client::error::ErrorKind,
federation::{
discovery::get_remote_server_keys_batch::{self, v2::QueryCriteria},
discovery::{
get_remote_server_keys,
get_remote_server_keys_batch::{self, v2::QueryCriteria},
get_server_keys,
},
event::{get_event, get_room_state_ids},
membership::create_join_event,
},
@ -31,9 +30,11 @@ use ruma::{
int,
serde::Base64,
state_res::{self, RoomVersion, StateMap},
uint, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName,
uint, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch,
OwnedServerName, OwnedServerSigningKeyId, RoomId, RoomVersionId, ServerName,
};
use serde_json::value::RawValue as RawJsonValue;
use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore};
use tracing::{debug, error, info, trace, warn};
use crate::{service::*, services, Error, PduEvent, Result};
@ -168,7 +169,7 @@ impl Service {
.globals
.bad_event_ratelimiter
.read()
.unwrap()
.await
.get(&*prev_id)
{
// Exponential backoff
@ -189,7 +190,7 @@ impl Service {
.globals
.bad_event_ratelimiter
.write()
.unwrap()
.await
.entry((*prev_id).to_owned())
{
hash_map::Entry::Vacant(e) => {
@ -213,7 +214,7 @@ impl Service {
.globals
.roomid_federationhandletime
.write()
.unwrap()
.await
.insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time));
if let Err(e) = self
@ -233,7 +234,7 @@ impl Service {
.globals
.bad_event_ratelimiter
.write()
.unwrap()
.await
.entry((*prev_id).to_owned())
{
hash_map::Entry::Vacant(e) => {
@ -249,7 +250,7 @@ impl Service {
.globals
.roomid_federationhandletime
.write()
.unwrap()
.await
.remove(&room_id.to_owned());
debug!(
"Handling prev event {} took {}m{}s",
@ -267,7 +268,7 @@ impl Service {
.globals
.roomid_federationhandletime
.write()
.unwrap()
.await
.insert(room_id.to_owned(), (event_id.to_owned(), start_time));
let r = services()
.rooms
@ -285,7 +286,7 @@ impl Service {
.globals
.roomid_federationhandletime
.write()
.unwrap()
.await
.remove(&room_id.to_owned());
r
@ -326,11 +327,8 @@ impl Service {
let room_version =
RoomVersion::new(room_version_id).expect("room version is supported");
let mut val = match ruma::signatures::verify_event(
&pub_key_map.read().expect("RwLock is poisoned."),
&value,
room_version_id,
) {
let guard = pub_key_map.read().await;
let mut val = match ruma::signatures::verify_event(&guard, &value, room_version_id) {
Err(e) => {
// Drop
warn!("Dropping bad event {}: {}", event_id, e,);
@ -365,6 +363,8 @@ impl Service {
Ok(ruma::signatures::Verified::All) => value,
};
drop(guard);
// Now that we have checked the signature and hashes we can add the eventID and convert
// to our PduEvent type
val.insert(
@ -692,13 +692,15 @@ impl Service {
{
Ok(res) => {
debug!("Fetching state events at event.");
let collect = res
.pdu_ids
.iter()
.map(|x| Arc::from(&**x))
.collect::<Vec<_>>();
let state_vec = self
.fetch_and_handle_outliers(
origin,
&res.pdu_ids
.iter()
.map(|x| Arc::from(&**x))
.collect::<Vec<_>>(),
&collect,
create_event,
room_id,
room_version_id,
@ -805,7 +807,7 @@ impl Service {
.globals
.roomid_mutex_state
.write()
.unwrap()
.await
.entry(room_id.to_owned())
.or_default(),
);
@ -884,14 +886,18 @@ impl Service {
debug!("Starting soft fail auth check");
if soft_fail {
services().rooms.timeline.append_incoming_pdu(
&incoming_pdu,
val,
extremities.iter().map(|e| (**e).to_owned()).collect(),
state_ids_compressed,
soft_fail,
&state_lock,
)?;
services()
.rooms
.timeline
.append_incoming_pdu(
&incoming_pdu,
val,
extremities.iter().map(|e| (**e).to_owned()).collect(),
state_ids_compressed,
soft_fail,
&state_lock,
)
.await?;
// Soft fail, we keep the event as an outlier but don't add it to the timeline
warn!("Event was soft failed: {:?}", incoming_pdu);
@ -912,14 +918,18 @@ impl Service {
// We use the `state_at_event` instead of `state_after` so we accurately
// represent the state for this event.
let pdu_id = services().rooms.timeline.append_incoming_pdu(
&incoming_pdu,
val,
extremities.iter().map(|e| (**e).to_owned()).collect(),
state_ids_compressed,
soft_fail,
&state_lock,
)?;
let pdu_id = services()
.rooms
.timeline
.append_incoming_pdu(
&incoming_pdu,
val,
extremities.iter().map(|e| (**e).to_owned()).collect(),
state_ids_compressed,
soft_fail,
&state_lock,
)
.await?;
debug!("Appended incoming pdu");
@ -1034,7 +1044,8 @@ impl Service {
/// d. TODO: Ask other servers over federation?
#[allow(clippy::type_complexity)]
#[tracing::instrument(skip_all)]
pub(crate) fn fetch_and_handle_outliers<'a>(
#[async_recursion]
pub(crate) async fn fetch_and_handle_outliers<'a>(
&'a self,
origin: &'a ServerName,
events: &'a [Arc<EventId>],
@ -1042,176 +1053,175 @@ impl Service {
room_id: &'a RoomId,
room_version_id: &'a RoomVersionId,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> AsyncRecursiveType<'a, Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)>>
{
Box::pin(async move {
let back_off = |id| match services()
) -> Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)> {
let back_off = |id| async move {
match services()
.globals
.bad_event_ratelimiter
.write()
.unwrap()
.await
.entry(id)
{
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
}
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
};
}
};
let mut pdus = vec![];
for id in events {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) {
trace!("Found {} in db", id);
pdus.push((local_pdu, None));
let mut pdus = vec![];
for id in events {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) {
trace!("Found {} in db", id);
pdus.push((local_pdu, None));
continue;
}
// c. Ask origin server over federation
// We also handle its auth chain here so we don't get a stack overflow in
// handle_outlier_pdu.
let mut todo_auth_events = vec![Arc::clone(id)];
let mut events_in_reverse_order = Vec::new();
let mut events_all = HashSet::new();
let mut i = 0;
while let Some(next_id) = todo_auth_events.pop() {
if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.await
.get(&*next_id)
{
// Exponential backoff
let mut min_elapsed_duration =
Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
info!("Backing off from {}", next_id);
continue;
}
}
if events_all.contains(&next_id) {
continue;
}
// c. Ask origin server over federation
// We also handle its auth chain here so we don't get a stack overflow in
// handle_outlier_pdu.
let mut todo_auth_events = vec![Arc::clone(id)];
let mut events_in_reverse_order = Vec::new();
let mut events_all = HashSet::new();
let mut i = 0;
while let Some(next_id) = todo_auth_events.pop() {
if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.unwrap()
.get(&*next_id)
{
// Exponential backoff
let mut min_elapsed_duration =
Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
info!("Backing off from {}", next_id);
continue;
}
}
if events_all.contains(&next_id) {
continue;
}
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) {
trace!("Found {} in db", next_id);
continue;
}
info!("Fetching {} over federation.", next_id);
match services()
.sending
.send_federation_request(
origin,
get_event::v1::Request {
event_id: (*next_id).to_owned(),
},
)
.await
{
Ok(res) => {
info!("Got {} over federation", next_id);
let (calculated_event_id, value) =
match pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) {
Ok(t) => t,
Err(_) => {
back_off((*next_id).to_owned());
continue;
}
};
if calculated_event_id != *next_id {
warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}",
next_id, calculated_event_id, &res.pdu);
}
if let Some(auth_events) =
value.get("auth_events").and_then(|c| c.as_array())
{
for auth_event in auth_events {
if let Ok(auth_event) =
serde_json::from_value(auth_event.clone().into())
{
let a: Arc<EventId> = auth_event;
todo_auth_events.push(a);
} else {
warn!("Auth event id is not valid");
}
}
} else {
warn!("Auth event list invalid");
}
events_in_reverse_order.push((next_id.clone(), value));
events_all.insert(next_id);
}
Err(_) => {
warn!("Failed to fetch event: {}", next_id);
back_off((*next_id).to_owned());
}
}
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
for (next_id, value) in events_in_reverse_order.iter().rev() {
if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.unwrap()
.get(&**next_id)
{
// Exponential backoff
let mut min_elapsed_duration =
Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) {
trace!("Found {} in db", next_id);
continue;
}
info!("Fetching {} over federation.", next_id);
match services()
.sending
.send_federation_request(
origin,
get_event::v1::Request {
event_id: (*next_id).to_owned(),
},
)
.await
{
Ok(res) => {
info!("Got {} over federation", next_id);
let (calculated_event_id, value) =
match pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) {
Ok(t) => t,
Err(_) => {
back_off((*next_id).to_owned()).await;
continue;
}
};
if calculated_event_id != *next_id {
warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}",
next_id, calculated_event_id, &res.pdu);
}
if time.elapsed() < min_elapsed_duration {
info!("Backing off from {}", next_id);
continue;
}
}
match self
.handle_outlier_pdu(
origin,
create_event,
next_id,
room_id,
value.clone(),
true,
pub_key_map,
)
.await
{
Ok((pdu, json)) => {
if next_id == id {
pdus.push((pdu, Some(json)));
if let Some(auth_events) =
value.get("auth_events").and_then(|c| c.as_array())
{
for auth_event in auth_events {
if let Ok(auth_event) =
serde_json::from_value(auth_event.clone().into())
{
let a: Arc<EventId> = auth_event;
todo_auth_events.push(a);
} else {
warn!("Auth event id is not valid");
}
}
} else {
warn!("Auth event list invalid");
}
Err(e) => {
warn!("Authentication of event {} failed: {:?}", next_id, e);
back_off((**next_id).to_owned());
}
events_in_reverse_order.push((next_id.clone(), value));
events_all.insert(next_id);
}
Err(_) => {
warn!("Failed to fetch event: {}", next_id);
back_off((*next_id).to_owned()).await;
}
}
}
pdus
})
for (next_id, value) in events_in_reverse_order.iter().rev() {
if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.await
.get(&**next_id)
{
// Exponential backoff
let mut min_elapsed_duration =
Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
info!("Backing off from {}", next_id);
continue;
}
}
match self
.handle_outlier_pdu(
origin,
create_event,
next_id,
room_id,
value.clone(),
true,
pub_key_map,
)
.await
{
Ok((pdu, json)) => {
if next_id == id {
pdus.push((pdu, Some(json)));
}
}
Err(e) => {
warn!("Authentication of event {} failed: {:?}", next_id, e);
back_off((**next_id).to_owned()).await;
}
}
}
}
pdus
}
async fn fetch_unknown_prev_events(
@ -1360,7 +1370,7 @@ impl Service {
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.await
.insert(signature_server.clone(), keys);
}
@ -1369,7 +1379,7 @@ impl Service {
// Gets a list of servers for which we don't have the signing key yet. We go over
// the PDUs and either cache the key or add it to the list that needs to be retrieved.
fn get_server_keys_from_cache(
async fn get_server_keys_from_cache(
&self,
pdu: &RawJsonValue,
servers: &mut BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>,
@ -1393,7 +1403,7 @@ impl Service {
.globals
.bad_event_ratelimiter
.read()
.unwrap()
.await
.get(event_id)
{
// Exponential backoff
@ -1469,17 +1479,19 @@ impl Service {
> = BTreeMap::new();
{
let mut pkm = pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?;
let mut pkm = pub_key_map.write().await;
// Try to fetch keys, failure is okay
// Servers we couldn't find in the cache will be added to `servers`
for pdu in &event.room_state.state {
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm);
let _ = self
.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm)
.await;
}
for pdu in &event.room_state.auth_chain {
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm);
let _ = self
.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm)
.await;
}
drop(pkm);
@ -1503,9 +1515,7 @@ impl Service {
.await
{
trace!("Got signing keys: {:?}", keys);
let mut pkm = pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?;
let mut pkm = pub_key_map.write().await;
for k in keys.server_keys {
let k = match k.deserialize() {
Ok(key) => key,
@ -1564,10 +1574,7 @@ impl Service {
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(origin.to_string(), result);
pub_key_map.write().await.insert(origin.to_string(), result);
}
}
info!("Done handling result");
@ -1632,14 +1639,14 @@ impl Service {
.globals
.servername_ratelimiter
.read()
.unwrap()
.await
.get(origin)
.map(|s| Arc::clone(s).acquire_owned());
let permit = match permit {
Some(p) => p,
None => {
let mut write = services().globals.servername_ratelimiter.write().unwrap();
let mut write = services().globals.servername_ratelimiter.write().await;
let s = Arc::clone(
write
.entry(origin.to_owned())
@ -1651,24 +1658,26 @@ impl Service {
}
.await;
let back_off = |id| match services()
.globals
.bad_signature_ratelimiter
.write()
.unwrap()
.entry(id)
{
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
let back_off = |id| async {
match services()
.globals
.bad_signature_ratelimiter
.write()
.await
.entry(id)
{
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
}
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
}
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
};
if let Some((time, tries)) = services()
.globals
.bad_signature_ratelimiter
.read()
.unwrap()
.await
.get(&signature_ids)
{
// Exponential backoff
@ -1775,7 +1784,7 @@ impl Service {
drop(permit);
back_off(signature_ids);
back_off(signature_ids).await;
warn!("Failed to find public key for server: {}", origin);
Err(Error::BadServerResponse(