use std::{ fmt::Debug, hash::Hash, sync::{Arc, TryLockError::WouldBlock}, }; use tokio::sync::OwnedMutexGuard as Omg; use crate::{Result, err}; /// Map of Mutexes pub struct MutexMap { map: Map, } pub struct Guard { map: Map, val: Omg, } type Map = Arc>; type MapMutex = std::sync::Mutex>; type HashMap = std::collections::HashMap>; type Value = Arc>; impl MutexMap where Key: Clone + Eq + Hash + Send, Val: Default + Send, { #[must_use] pub fn new() -> Self { Self { map: Map::new(MapMutex::new(HashMap::new())), } } #[tracing::instrument(level = "trace", skip(self))] pub async fn lock<'a, K>(&'a self, k: &'a K) -> Guard where K: Debug + Send + ?Sized + Sync, Key: TryFrom<&'a K>, >::Error: Debug, { let val = self .map .lock() .expect("locked") .entry(k.try_into().expect("failed to construct key")) .or_default() .clone(); Guard:: { map: Arc::clone(&self.map), val: val.lock_owned().await, } } #[tracing::instrument(level = "trace", skip(self))] pub fn try_lock<'a, K>(&self, k: &'a K) -> Result> where K: Debug + Send + ?Sized + Sync, Key: TryFrom<&'a K>, >::Error: Debug, { let val = self .map .lock() .expect("locked") .entry(k.try_into().expect("failed to construct key")) .or_default() .clone(); Ok(Guard:: { map: Arc::clone(&self.map), val: val.try_lock_owned().map_err(|_| err!("would yield"))?, }) } #[tracing::instrument(level = "trace", skip(self))] pub fn try_try_lock<'a, K>(&self, k: &'a K) -> Result> where K: Debug + Send + ?Sized + Sync, Key: TryFrom<&'a K>, >::Error: Debug, { let val = self .map .try_lock() .map_err(|e| match e { | WouldBlock => err!("would block"), | _ => panic!("{e:?}"), })? .entry(k.try_into().expect("failed to construct key")) .or_default() .clone(); Ok(Guard:: { map: Arc::clone(&self.map), val: val.try_lock_owned().map_err(|_| err!("would yield"))?, }) } #[must_use] pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) } #[must_use] pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() } #[must_use] pub fn len(&self) -> usize { self.map.lock().expect("locked").len() } } impl Default for MutexMap where Key: Clone + Eq + Hash + Send, Val: Default + Send, { fn default() -> Self { Self::new() } } impl Drop for Guard { #[tracing::instrument(name = "unlock", level = "trace", skip_all)] fn drop(&mut self) { if Arc::strong_count(Omg::mutex(&self.val)) <= 2 { self.map.lock().expect("locked").retain(|_, val| { !Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2 }); } } }