diff --git a/src/core/log/capture/layer.rs b/src/core/log/capture/layer.rs index b3235d91..e3fe66df 100644 --- a/src/core/log/capture/layer.rs +++ b/src/core/log/capture/layer.rs @@ -40,7 +40,6 @@ where self.state .active .read() - .expect("shared lock") .iter() .filter(|capture| filter(self, capture, event, &ctx)) .for_each(|capture| handle(self, capture, event, &ctx)); diff --git a/src/core/log/capture/state.rs b/src/core/log/capture/state.rs index dad6c8d8..92a1608f 100644 --- a/src/core/log/capture/state.rs +++ b/src/core/log/capture/state.rs @@ -1,10 +1,11 @@ -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use super::Capture; +use crate::SyncRwLock; /// Capture layer state. pub struct State { - pub(super) active: RwLock>>, + pub(super) active: SyncRwLock>>, } impl Default for State { @@ -13,17 +14,14 @@ impl Default for State { impl State { #[must_use] - pub fn new() -> Self { Self { active: RwLock::new(Vec::new()) } } + pub fn new() -> Self { Self { active: SyncRwLock::new(Vec::new()) } } pub(super) fn add(&self, capture: &Arc) { - self.active - .write() - .expect("locked for writing") - .push(capture.clone()); + self.active.write().push(capture.clone()); } pub(super) fn del(&self, capture: &Arc) { - let mut vec = self.active.write().expect("locked for writing"); + let mut vec = self.active.write(); if let Some(pos) = vec.iter().position(|v| Arc::ptr_eq(v, capture)) { vec.swap_remove(pos); } diff --git a/src/database/watchers.rs b/src/database/watchers.rs index efb939d7..0e911c82 100644 --- a/src/database/watchers.rs +++ b/src/database/watchers.rs @@ -2,12 +2,12 @@ use std::{ collections::{HashMap, hash_map}, future::Future, pin::Pin, - sync::RwLock, }; +use conduwuit::SyncRwLock; use tokio::sync::watch; -type Watcher = RwLock, (watch::Sender<()>, watch::Receiver<()>)>>; +type Watcher = SyncRwLock, (watch::Sender<()>, watch::Receiver<()>)>>; #[derive(Default)] pub(crate) struct Watchers { @@ -19,7 +19,7 @@ impl Watchers { &'a self, prefix: &[u8], ) -> Pin + Send + 'a>> { - let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) { + let mut rx = match self.watchers.write().entry(prefix.to_vec()) { | hash_map::Entry::Occupied(o) => o.get().1.clone(), | hash_map::Entry::Vacant(v) => { let (tx, rx) = watch::channel(()); @@ -35,7 +35,7 @@ impl Watchers { } pub(crate) fn wake(&self, key: &[u8]) { - let watchers = self.watchers.read().unwrap(); + let watchers = self.watchers.read(); let mut triggered = Vec::new(); for length in 0..=key.len() { if watchers.contains_key(&key[..length]) { @@ -46,7 +46,7 @@ impl Watchers { drop(watchers); if !triggered.is_empty() { - let mut watchers = self.watchers.write().unwrap(); + let mut watchers = self.watchers.write(); for prefix in triggered { if let Some(tx) = watchers.remove(prefix) { tx.0.send(()).expect("channel should still be open"); diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 21c09252..07f1de5c 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -1,11 +1,11 @@ -use std::sync::{Arc, RwLock}; +use std::sync::Arc; -use conduwuit::{Result, utils}; +use conduwuit::{Result, SyncRwLock, utils}; use database::{Database, Deserialized, Map}; pub struct Data { global: Arc, - counter: RwLock, + counter: SyncRwLock, pub(super) db: Arc, } @@ -16,25 +16,21 @@ impl Data { let db = &args.db; Self { global: db["global"].clone(), - counter: RwLock::new( - Self::stored_count(&db["global"]).expect("initialized global counter"), - ), + counter: SyncRwLock::new(Self::stored_count(&db["global"]).unwrap_or_default()), db: args.db.clone(), } } pub fn next_count(&self) -> Result { let _cork = self.db.cork(); - let mut lock = self.counter.write().expect("locked"); + let mut lock = self.counter.write(); let counter: &mut u64 = &mut lock; debug_assert!( - *counter == Self::stored_count(&self.global).expect("database failure"), + *counter == Self::stored_count(&self.global).unwrap_or_default(), "counter mismatch" ); - *counter = counter - .checked_add(1) - .expect("counter must not overflow u64"); + *counter = counter.checked_add(1).unwrap_or(*counter); self.global.insert(COUNTER, counter.to_be_bytes()); @@ -43,10 +39,10 @@ impl Data { #[inline] pub fn current_count(&self) -> u64 { - let lock = self.counter.read().expect("locked"); + let lock = self.counter.read(); let counter: &u64 = &lock; debug_assert!( - *counter == Self::stored_count(&self.global).expect("database failure"), + *counter == Self::stored_count(&self.global).unwrap_or_default(), "counter mismatch" ); diff --git a/src/service/manager.rs b/src/service/manager.rs index 3cdf5945..7a2e50d5 100644 --- a/src/service/manager.rs +++ b/src/service/manager.rs @@ -58,7 +58,6 @@ impl Manager { let services: Vec> = self .service .read() - .expect("locked for reading") .values() .map(|val| val.0.upgrade()) .map(|arc| arc.expect("services available for manager startup")) diff --git a/src/service/service.rs b/src/service/service.rs index 574efd8f..3bc61aeb 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -3,11 +3,13 @@ use std::{ collections::BTreeMap, fmt::Write, ops::Deref, - sync::{Arc, OnceLock, RwLock, Weak}, + sync::{Arc, OnceLock, Weak}, }; use async_trait::async_trait; -use conduwuit::{Err, Result, Server, err, error::inspect_log, utils::string::SplitInfallible}; +use conduwuit::{ + Err, Result, Server, SyncRwLock, err, error::inspect_log, utils::string::SplitInfallible, +}; use database::Database; /// Abstract interface for a Service @@ -62,7 +64,7 @@ pub(crate) struct Dep { name: &'static str, } -pub(crate) type Map = RwLock; +pub(crate) type Map = SyncRwLock; pub(crate) type MapType = BTreeMap; pub(crate) type MapVal = (Weak, Weak); pub(crate) type MapKey = String; @@ -143,15 +145,12 @@ pub(crate) fn get(map: &Map, name: &str) -> Option> where T: Any + Send + Sync + Sized, { - map.read() - .expect("locked for reading") - .get(name) - .map(|(_, s)| { - s.upgrade().map(|s| { - s.downcast::() - .expect("Service must be correctly downcast.") - }) - })? + map.read().get(name).map(|(_, s)| { + s.upgrade().map(|s| { + s.downcast::() + .expect("Service must be correctly downcast.") + }) + })? } /// Reference a Service by name. Returns Err if the Service does not exist or @@ -160,21 +159,18 @@ pub(crate) fn try_get(map: &Map, name: &str) -> Result> where T: Any + Send + Sync + Sized, { - map.read() - .expect("locked for reading") - .get(name) - .map_or_else( - || Err!("Service {name:?} does not exist or has not been built yet."), - |(_, s)| { - s.upgrade().map_or_else( - || Err!("Service {name:?} no longer exists."), - |s| { - s.downcast::() - .map_err(|_| err!("Service {name:?} must be correctly downcast.")) - }, - ) - }, - ) + map.read().get(name).map_or_else( + || Err!("Service {name:?} does not exist or has not been built yet."), + |(_, s)| { + s.upgrade().map_or_else( + || Err!("Service {name:?} no longer exists."), + |s| { + s.downcast::() + .map_err(|_| err!("Service {name:?} must be correctly downcast.")) + }, + ) + }, + ) } /// Utility for service implementations; see Service::name() in the trait. diff --git a/src/service/services.rs b/src/service/services.rs index daece245..642f61c7 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -1,10 +1,8 @@ -use std::{ - any::Any, - collections::BTreeMap, - sync::{Arc, RwLock}, -}; +use std::{any::Any, collections::BTreeMap, sync::Arc}; -use conduwuit::{Result, Server, debug, debug_info, info, trace, utils::stream::IterStream}; +use conduwuit::{ + Result, Server, SyncRwLock, debug, debug_info, info, trace, utils::stream::IterStream, +}; use database::Database; use futures::{Stream, StreamExt, TryStreamExt}; use tokio::sync::Mutex; @@ -52,7 +50,7 @@ impl Services { #[allow(clippy::cognitive_complexity)] pub async fn build(server: Arc) -> Result> { let db = Database::open(&server).await?; - let service: Arc = Arc::new(RwLock::new(BTreeMap::new())); + let service: Arc = Arc::new(SyncRwLock::new(BTreeMap::new())); macro_rules! build { ($tyname:ty) => {{ let built = <$tyname>::build(Args { @@ -193,7 +191,7 @@ impl Services { fn interrupt(&self) { debug!("Interrupting services..."); - for (name, (service, ..)) in self.service.read().expect("locked for reading").iter() { + for (name, (service, ..)) in self.service.read().iter() { if let Some(service) = service.upgrade() { trace!("Interrupting {name}"); service.interrupt(); @@ -205,7 +203,6 @@ impl Services { fn services(&self) -> impl Stream> + Send { self.service .read() - .expect("locked for reading") .values() .filter_map(|val| val.0.upgrade()) .collect::>() @@ -233,10 +230,9 @@ impl Services { #[allow(clippy::needless_pass_by_value)] fn add_service(map: &Arc, s: Arc, a: Arc) { let name = s.name(); - let len = map.read().expect("locked for reading").len(); + let len = map.read().len(); trace!("built service #{len}: {name:?}"); map.write() - .expect("locked for writing") .insert(name.to_owned(), (Arc::downgrade(&s), Arc::downgrade(&a))); } diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 7735c87f..acd3dd86 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,10 +1,10 @@ use std::{ collections::{BTreeMap, HashSet}, - sync::{Arc, RwLock}, + sync::Arc, }; use conduwuit::{ - Err, Error, Result, err, error, implement, utils, + Err, Error, Result, SyncRwLock, err, error, implement, utils, utils::{hash, string::EMPTY}, }; use database::{Deserialized, Json, Map}; @@ -19,7 +19,7 @@ use ruma::{ use crate::{Dep, config, globals, users}; pub struct Service { - userdevicesessionid_uiaarequest: RwLock, + userdevicesessionid_uiaarequest: SyncRwLock, db: Data, services: Services, } @@ -42,7 +42,7 @@ pub const SESSION_ID_LENGTH: usize = 32; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()), + userdevicesessionid_uiaarequest: SyncRwLock::new(RequestMap::new()), db: Data { userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(), }, @@ -268,7 +268,6 @@ fn set_uiaa_request( let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned()); self.userdevicesessionid_uiaarequest .write() - .expect("locked for writing") .insert(key, request.to_owned()); } @@ -287,7 +286,6 @@ pub fn get_uiaa_request( self.userdevicesessionid_uiaarequest .read() - .expect("locked for reading") .get(&key) .cloned() }