1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-07-22 17:18:35 +00:00

Merge branch 'next' into 'shekohex/fix-media-file-name'

# Conflicts:
#   Cargo.toml
#   src/database/mod.rs
#   src/service/globals/mod.rs
This commit is contained in:
Shady Khalifa 2024-04-18 10:14:28 +00:00
commit a691b263dc
130 changed files with 8761 additions and 5036 deletions

View file

@ -1,13 +1,20 @@
mod data;
pub use data::Data;
use ruma::{
OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId,
serde::Base64, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName,
OwnedServerSigningKeyId, OwnedUserId,
};
use sha2::Digest;
use crate::api::server_server::FedDest;
use crate::{services, Config, Error, Result};
use futures_util::FutureExt;
use hyper::{
client::connect::dns::{GaiResolver, Name},
service::Service as HyperService,
};
use reqwest::dns::{Addrs, Resolve, Resolving};
use ruma::{
api::{
client::sync::sync_events,
@ -17,20 +24,24 @@ use ruma::{
};
use std::{
collections::{BTreeMap, HashMap},
error::Error as StdError,
fs,
future::Future,
future::{self, Future},
iter,
net::{IpAddr, SocketAddr},
path::PathBuf,
sync::{
atomic::{self, AtomicBool},
Arc, Mutex, RwLock,
Arc, Mutex, RwLock as StdRwLock,
},
time::{Duration, Instant},
};
use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore};
use tokio::sync::{broadcast, watch::Receiver, Mutex, RwLock, Semaphore};
use tracing::{error, info};
use trust_dns_resolver::TokioAsyncResolver;
use base64::{engine::general_purpose, Engine as _};
type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>;
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
@ -43,7 +54,7 @@ pub struct Service {
pub db: &'static dyn Data,
pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub tls_name_override: Arc<RwLock<TlsNameMap>>,
pub tls_name_override: Arc<StdRwLock<TlsNameMap>>,
pub config: Config,
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
dns_resolver: TokioAsyncResolver,
@ -54,11 +65,12 @@ pub struct Service {
pub unstable_room_versions: Vec<RoomVersionId>,
pub bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>,
pub bad_signature_ratelimiter: Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>,
pub bad_query_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, RateLimitState>>>,
pub servername_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, Arc<Semaphore>>>>,
pub sync_receivers: RwLock<HashMap<(OwnedUserId, OwnedDeviceId), SyncHandle>>,
pub roomid_mutex_insert: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<TokioMutex<()>>>>,
pub roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<TokioMutex<()>>>>, // this lock will be held longer
pub roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>, // this lock will be held longer
pub roomid_federationhandletime: RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>,
pub stateres_mutex: Arc<Mutex<()>>,
pub rotate: RotationHandler,
@ -96,6 +108,45 @@ impl Default for RotationHandler {
}
}
pub struct Resolver {
inner: GaiResolver,
overrides: Arc<StdRwLock<TlsNameMap>>,
}
impl Resolver {
pub fn new(overrides: Arc<StdRwLock<TlsNameMap>>) -> Self {
Resolver {
inner: GaiResolver::new(),
overrides,
}
}
}
impl Resolve for Resolver {
fn resolve(&self, name: Name) -> Resolving {
self.overrides
.read()
.unwrap()
.get(name.as_str())
.and_then(|(override_name, port)| {
override_name.first().map(|first_name| {
let x: Box<dyn Iterator<Item = SocketAddr> + Send> =
Box::new(iter::once(SocketAddr::new(*first_name, *port)));
let x: Resolving = Box::pin(future::ready(Ok(x)));
x
})
})
.unwrap_or_else(|| {
let this = &mut self.inner.clone();
Box::pin(HyperService::<Name>::call(this, name).map(|result| {
result
.map(|addrs| -> Addrs { Box::new(addrs) })
.map_err(|err| -> Box<dyn StdError + Send + Sync> { Box::new(err) })
}))
})
}
}
impl Service {
pub fn load(db: &'static dyn Data, config: Config) -> Result<Self> {
let keypair = db.load_keypair();
@ -109,7 +160,7 @@ impl Service {
}
};
let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new()));
let tls_name_override = Arc::new(StdRwLock::new(TlsNameMap::new()));
let jwt_decoding_key = config
.jwt_secret
@ -117,14 +168,8 @@ impl Service {
.map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));
let default_client = reqwest_client_builder(&config)?.build()?;
let name_override = Arc::clone(&tls_name_override);
let federation_client = reqwest_client_builder(&config)?
.resolve_fn(move |domain| {
let read_guard = name_override.read().unwrap();
let (override_name, port) = read_guard.get(&domain)?;
let first_name = override_name.get(0)?;
Some(SocketAddr::new(*first_name, *port))
})
.dns_resolver(Arc::new(Resolver::new(tls_name_override.clone())))
.build()?;
// Supported and stable room versions
@ -158,6 +203,7 @@ impl Service {
unstable_room_versions,
bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
bad_signature_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
bad_query_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
servername_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
roomid_mutex_state: RwLock::new(HashMap::new()),
roomid_mutex_insert: RwLock::new(HashMap::new()),
@ -209,6 +255,16 @@ impl Service {
self.db.current_count()
}
#[tracing::instrument(skip(self))]
pub fn last_check_for_updates_id(&self) -> Result<u64> {
self.db.last_check_for_updates_id()
}
#[tracing::instrument(skip(self))]
pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
self.db.update_check_for_updates_id(id)
}
pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
self.db.watch(user_id, device_id).await
}
@ -217,10 +273,6 @@ impl Service {
self.db.cleanup()
}
pub fn memory_usage(&self) -> Result<String> {
self.db.memory_usage()
}
pub fn server_name(&self) -> &ServerName {
self.config.server_name.as_ref()
}
@ -261,6 +313,10 @@ impl Service {
self.config.enable_lightning_bolt
}
pub fn allow_check_for_updates(&self) -> bool {
self.config.allow_check_for_updates
}
pub fn trusted_servers(&self) -> &[OwnedServerName] {
&self.config.trusted_servers
}
@ -323,7 +379,19 @@ impl Service {
&self,
origin: &ServerName,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
self.db.signing_keys_for(origin)
let mut keys = self.db.signing_keys_for(origin)?;
if origin == self.server_name() {
keys.insert(
format!("ed25519:{}", services().globals.keypair().version())
.try_into()
.expect("found invalid server signing keys in DB"),
VerifyKey {
key: Base64::new(self.keypair.public_key().to_vec()),
},
);
}
Ok(keys)
}
pub fn database_version(&self) -> Result<u64> {
@ -345,12 +413,9 @@ impl Service {
let mut r = PathBuf::new();
r.push(self.config.database_path.clone());
r.push("media");
r.push(base64::encode_config(
// Using the hash of the key as the filename
// This is to prevent the total length of the path from exceeding the maximum length
sha2::Sha256::digest(key),
base64::URL_SAFE_NO_PAD,
));
// Using the hash of the key as the filename
// This is to prevent the total length of the path from exceeding the maximum length
r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key));
r
}
@ -363,10 +428,14 @@ impl Service {
let mut r = PathBuf::new();
r.push(self.config.database_path.clone());
r.push("media");
r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD));
r.push(general_purpose::URL_SAFE_NO_PAD.encode(key));
r
}
pub fn well_known_client(&self) -> &Option<String> {
&self.config.well_known_client
}
pub fn shutdown(&self) {
self.shutdown.store(true, atomic::Ordering::Relaxed);
// On shutdown