1
0
Fork 0
mirror of https://forgejo.ellis.link/continuwuation/continuwuity.git synced 2025-07-28 10:48:30 +00:00

refactor: Replace std RwLock with parking_lot

This commit is contained in:
Jade Ellis 2025-07-19 21:03:17 +01:00
parent 30a8c06fd9
commit a1d616e3e3
No known key found for this signature in database
GPG key ID: 8705A2A3EBF77BD2
8 changed files with 54 additions and 72 deletions

View file

@ -40,7 +40,6 @@ where
self.state self.state
.active .active
.read() .read()
.expect("shared lock")
.iter() .iter()
.filter(|capture| filter(self, capture, event, &ctx)) .filter(|capture| filter(self, capture, event, &ctx))
.for_each(|capture| handle(self, capture, event, &ctx)); .for_each(|capture| handle(self, capture, event, &ctx));

View file

@ -1,10 +1,11 @@
use std::sync::{Arc, RwLock}; use std::sync::Arc;
use super::Capture; use super::Capture;
use crate::SyncRwLock;
/// Capture layer state. /// Capture layer state.
pub struct State { pub struct State {
pub(super) active: RwLock<Vec<Arc<Capture>>>, pub(super) active: SyncRwLock<Vec<Arc<Capture>>>,
} }
impl Default for State { impl Default for State {
@ -13,17 +14,14 @@ impl Default for State {
impl State { impl State {
#[must_use] #[must_use]
pub fn new() -> Self { Self { active: RwLock::new(Vec::new()) } } pub fn new() -> Self { Self { active: SyncRwLock::new(Vec::new()) } }
pub(super) fn add(&self, capture: &Arc<Capture>) { pub(super) fn add(&self, capture: &Arc<Capture>) {
self.active self.active.write().push(capture.clone());
.write()
.expect("locked for writing")
.push(capture.clone());
} }
pub(super) fn del(&self, capture: &Arc<Capture>) { pub(super) fn del(&self, capture: &Arc<Capture>) {
let mut vec = self.active.write().expect("locked for writing"); let mut vec = self.active.write();
if let Some(pos) = vec.iter().position(|v| Arc::ptr_eq(v, capture)) { if let Some(pos) = vec.iter().position(|v| Arc::ptr_eq(v, capture)) {
vec.swap_remove(pos); vec.swap_remove(pos);
} }

View file

@ -2,12 +2,12 @@ use std::{
collections::{HashMap, hash_map}, collections::{HashMap, hash_map},
future::Future, future::Future,
pin::Pin, pin::Pin,
sync::RwLock,
}; };
use conduwuit::SyncRwLock;
use tokio::sync::watch; use tokio::sync::watch;
type Watcher = RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>; type Watcher = SyncRwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;
#[derive(Default)] #[derive(Default)]
pub(crate) struct Watchers { pub(crate) struct Watchers {
@ -19,7 +19,7 @@ impl Watchers {
&'a self, &'a self,
prefix: &[u8], prefix: &[u8],
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) { let mut rx = match self.watchers.write().entry(prefix.to_vec()) {
| hash_map::Entry::Occupied(o) => o.get().1.clone(), | hash_map::Entry::Occupied(o) => o.get().1.clone(),
| hash_map::Entry::Vacant(v) => { | hash_map::Entry::Vacant(v) => {
let (tx, rx) = watch::channel(()); let (tx, rx) = watch::channel(());
@ -35,7 +35,7 @@ impl Watchers {
} }
pub(crate) fn wake(&self, key: &[u8]) { pub(crate) fn wake(&self, key: &[u8]) {
let watchers = self.watchers.read().unwrap(); let watchers = self.watchers.read();
let mut triggered = Vec::new(); let mut triggered = Vec::new();
for length in 0..=key.len() { for length in 0..=key.len() {
if watchers.contains_key(&key[..length]) { if watchers.contains_key(&key[..length]) {
@ -46,7 +46,7 @@ impl Watchers {
drop(watchers); drop(watchers);
if !triggered.is_empty() { if !triggered.is_empty() {
let mut watchers = self.watchers.write().unwrap(); let mut watchers = self.watchers.write();
for prefix in triggered { for prefix in triggered {
if let Some(tx) = watchers.remove(prefix) { if let Some(tx) = watchers.remove(prefix) {
tx.0.send(()).expect("channel should still be open"); tx.0.send(()).expect("channel should still be open");

View file

@ -1,11 +1,11 @@
use std::sync::{Arc, RwLock}; use std::sync::Arc;
use conduwuit::{Result, utils}; use conduwuit::{Result, SyncRwLock, utils};
use database::{Database, Deserialized, Map}; use database::{Database, Deserialized, Map};
pub struct Data { pub struct Data {
global: Arc<Map>, global: Arc<Map>,
counter: RwLock<u64>, counter: SyncRwLock<u64>,
pub(super) db: Arc<Database>, pub(super) db: Arc<Database>,
} }
@ -16,25 +16,21 @@ impl Data {
let db = &args.db; let db = &args.db;
Self { Self {
global: db["global"].clone(), global: db["global"].clone(),
counter: RwLock::new( counter: SyncRwLock::new(Self::stored_count(&db["global"]).unwrap_or_default()),
Self::stored_count(&db["global"]).expect("initialized global counter"),
),
db: args.db.clone(), db: args.db.clone(),
} }
} }
pub fn next_count(&self) -> Result<u64> { pub fn next_count(&self) -> Result<u64> {
let _cork = self.db.cork(); let _cork = self.db.cork();
let mut lock = self.counter.write().expect("locked"); let mut lock = self.counter.write();
let counter: &mut u64 = &mut lock; let counter: &mut u64 = &mut lock;
debug_assert!( debug_assert!(
*counter == Self::stored_count(&self.global).expect("database failure"), *counter == Self::stored_count(&self.global).unwrap_or_default(),
"counter mismatch" "counter mismatch"
); );
*counter = counter *counter = counter.checked_add(1).unwrap_or(*counter);
.checked_add(1)
.expect("counter must not overflow u64");
self.global.insert(COUNTER, counter.to_be_bytes()); self.global.insert(COUNTER, counter.to_be_bytes());
@ -43,10 +39,10 @@ impl Data {
#[inline] #[inline]
pub fn current_count(&self) -> u64 { pub fn current_count(&self) -> u64 {
let lock = self.counter.read().expect("locked"); let lock = self.counter.read();
let counter: &u64 = &lock; let counter: &u64 = &lock;
debug_assert!( debug_assert!(
*counter == Self::stored_count(&self.global).expect("database failure"), *counter == Self::stored_count(&self.global).unwrap_or_default(),
"counter mismatch" "counter mismatch"
); );

View file

@ -58,7 +58,6 @@ impl Manager {
let services: Vec<Arc<dyn Service>> = self let services: Vec<Arc<dyn Service>> = self
.service .service
.read() .read()
.expect("locked for reading")
.values() .values()
.map(|val| val.0.upgrade()) .map(|val| val.0.upgrade())
.map(|arc| arc.expect("services available for manager startup")) .map(|arc| arc.expect("services available for manager startup"))

View file

@ -3,11 +3,13 @@ use std::{
collections::BTreeMap, collections::BTreeMap,
fmt::Write, fmt::Write,
ops::Deref, ops::Deref,
sync::{Arc, OnceLock, RwLock, Weak}, sync::{Arc, OnceLock, Weak},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use conduwuit::{Err, Result, Server, err, error::inspect_log, utils::string::SplitInfallible}; use conduwuit::{
Err, Result, Server, SyncRwLock, err, error::inspect_log, utils::string::SplitInfallible,
};
use database::Database; use database::Database;
/// Abstract interface for a Service /// Abstract interface for a Service
@ -62,7 +64,7 @@ pub(crate) struct Dep<T: Service + Send + Sync> {
name: &'static str, name: &'static str,
} }
pub(crate) type Map = RwLock<MapType>; pub(crate) type Map = SyncRwLock<MapType>;
pub(crate) type MapType = BTreeMap<MapKey, MapVal>; pub(crate) type MapType = BTreeMap<MapKey, MapVal>;
pub(crate) type MapVal = (Weak<dyn Service>, Weak<dyn Any + Send + Sync>); pub(crate) type MapVal = (Weak<dyn Service>, Weak<dyn Any + Send + Sync>);
pub(crate) type MapKey = String; pub(crate) type MapKey = String;
@ -143,10 +145,7 @@ pub(crate) fn get<T>(map: &Map, name: &str) -> Option<Arc<T>>
where where
T: Any + Send + Sync + Sized, T: Any + Send + Sync + Sized,
{ {
map.read() map.read().get(name).map(|(_, s)| {
.expect("locked for reading")
.get(name)
.map(|(_, s)| {
s.upgrade().map(|s| { s.upgrade().map(|s| {
s.downcast::<T>() s.downcast::<T>()
.expect("Service must be correctly downcast.") .expect("Service must be correctly downcast.")
@ -160,10 +159,7 @@ pub(crate) fn try_get<T>(map: &Map, name: &str) -> Result<Arc<T>>
where where
T: Any + Send + Sync + Sized, T: Any + Send + Sync + Sized,
{ {
map.read() map.read().get(name).map_or_else(
.expect("locked for reading")
.get(name)
.map_or_else(
|| Err!("Service {name:?} does not exist or has not been built yet."), || Err!("Service {name:?} does not exist or has not been built yet."),
|(_, s)| { |(_, s)| {
s.upgrade().map_or_else( s.upgrade().map_or_else(

View file

@ -1,10 +1,8 @@
use std::{ use std::{any::Any, collections::BTreeMap, sync::Arc};
any::Any,
collections::BTreeMap,
sync::{Arc, RwLock},
};
use conduwuit::{Result, Server, debug, debug_info, info, trace, utils::stream::IterStream}; use conduwuit::{
Result, Server, SyncRwLock, debug, debug_info, info, trace, utils::stream::IterStream,
};
use database::Database; use database::Database;
use futures::{Stream, StreamExt, TryStreamExt}; use futures::{Stream, StreamExt, TryStreamExt};
use tokio::sync::Mutex; use tokio::sync::Mutex;
@ -52,7 +50,7 @@ impl Services {
#[allow(clippy::cognitive_complexity)] #[allow(clippy::cognitive_complexity)]
pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> { pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> {
let db = Database::open(&server).await?; let db = Database::open(&server).await?;
let service: Arc<Map> = Arc::new(RwLock::new(BTreeMap::new())); let service: Arc<Map> = Arc::new(SyncRwLock::new(BTreeMap::new()));
macro_rules! build { macro_rules! build {
($tyname:ty) => {{ ($tyname:ty) => {{
let built = <$tyname>::build(Args { let built = <$tyname>::build(Args {
@ -193,7 +191,7 @@ impl Services {
fn interrupt(&self) { fn interrupt(&self) {
debug!("Interrupting services..."); debug!("Interrupting services...");
for (name, (service, ..)) in self.service.read().expect("locked for reading").iter() { for (name, (service, ..)) in self.service.read().iter() {
if let Some(service) = service.upgrade() { if let Some(service) = service.upgrade() {
trace!("Interrupting {name}"); trace!("Interrupting {name}");
service.interrupt(); service.interrupt();
@ -205,7 +203,6 @@ impl Services {
fn services(&self) -> impl Stream<Item = Arc<dyn Service>> + Send { fn services(&self) -> impl Stream<Item = Arc<dyn Service>> + Send {
self.service self.service
.read() .read()
.expect("locked for reading")
.values() .values()
.filter_map(|val| val.0.upgrade()) .filter_map(|val| val.0.upgrade())
.collect::<Vec<_>>() .collect::<Vec<_>>()
@ -233,10 +230,9 @@ impl Services {
#[allow(clippy::needless_pass_by_value)] #[allow(clippy::needless_pass_by_value)]
fn add_service(map: &Arc<Map>, s: Arc<dyn Service>, a: Arc<dyn Any + Send + Sync>) { fn add_service(map: &Arc<Map>, s: Arc<dyn Service>, a: Arc<dyn Any + Send + Sync>) {
let name = s.name(); let name = s.name();
let len = map.read().expect("locked for reading").len(); let len = map.read().len();
trace!("built service #{len}: {name:?}"); trace!("built service #{len}: {name:?}");
map.write() map.write()
.expect("locked for writing")
.insert(name.to_owned(), (Arc::downgrade(&s), Arc::downgrade(&a))); .insert(name.to_owned(), (Arc::downgrade(&s), Arc::downgrade(&a)));
} }

View file

@ -1,10 +1,10 @@
use std::{ use std::{
collections::{BTreeMap, HashSet}, collections::{BTreeMap, HashSet},
sync::{Arc, RwLock}, sync::Arc,
}; };
use conduwuit::{ use conduwuit::{
Err, Error, Result, err, error, implement, utils, Err, Error, Result, SyncRwLock, err, error, implement, utils,
utils::{hash, string::EMPTY}, utils::{hash, string::EMPTY},
}; };
use database::{Deserialized, Json, Map}; use database::{Deserialized, Json, Map};
@ -19,7 +19,7 @@ use ruma::{
use crate::{Dep, config, globals, users}; use crate::{Dep, config, globals, users};
pub struct Service { pub struct Service {
userdevicesessionid_uiaarequest: RwLock<RequestMap>, userdevicesessionid_uiaarequest: SyncRwLock<RequestMap>,
db: Data, db: Data,
services: Services, services: Services,
} }
@ -42,7 +42,7 @@ pub const SESSION_ID_LENGTH: usize = 32;
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()), userdevicesessionid_uiaarequest: SyncRwLock::new(RequestMap::new()),
db: Data { db: Data {
userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(), userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(),
}, },
@ -268,7 +268,6 @@ fn set_uiaa_request(
let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned()); let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
self.userdevicesessionid_uiaarequest self.userdevicesessionid_uiaarequest
.write() .write()
.expect("locked for writing")
.insert(key, request.to_owned()); .insert(key, request.to_owned());
} }
@ -287,7 +286,6 @@ pub fn get_uiaa_request(
self.userdevicesessionid_uiaarequest self.userdevicesessionid_uiaarequest
.read() .read()
.expect("locked for reading")
.get(&key) .get(&key)
.cloned() .cloned()
} }