diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index bd47cb42..9e30cbe6 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -255,12 +255,16 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" let ServerSigningKeys { verify_keys, old_verify_keys, + valid_until_ts, .. } = new_keys; prev_keys.verify_keys.extend(verify_keys); prev_keys.old_verify_keys.extend(old_verify_keys); - prev_keys.valid_until_ts = new_keys.valid_until_ts; + + if valid_until_ts > prev_keys.valid_until_ts { + prev_keys.valid_until_ts = valid_until_ts; + } self.server_signingkeys.insert( origin.as_bytes(), diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index f61f59a2..710485a9 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1511,6 +1511,27 @@ impl Service { } } + let origin_server_ts = value.get("origin_server_ts").ok_or_else(|| { + error!("Invalid PDU, no origin_server_ts field"); + Error::BadRequest( + ErrorKind::MissingParam, + "Invalid PDU, no origin_server_ts field", + ) + })?; + + let origin_server_ts: MilliSecondsSinceUnixEpoch = { + let ts = origin_server_ts.as_integer().ok_or_else(|| { + Error::BadRequest( + ErrorKind::InvalidParam, + "origin_server_ts must be an integer", + ) + })?; + + MilliSecondsSinceUnixEpoch(i64::from(ts).try_into().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Time must be after the unix epoch") + })?) + }; + let signatures = value .get("signatures") .ok_or(Error::BadServerResponse( @@ -1530,15 +1551,16 @@ impl Service { let contains_all_ids = |keys: &SigningKeys| { signature_ids.iter().all(|id| { - keys.verify_keys - .keys() - .map(ToString::to_string) - .any(|key_id| id == &key_id) - || keys - .old_verify_keys + (keys.valid_until_ts > origin_server_ts + && keys + .verify_keys .keys() .map(ToString::to_string) - .any(|key_id| id == &key_id) + .any(|key_id| id == &key_id)) + || keys + .old_verify_keys + .iter() + .any(|(key_id, key)| key_id == id && key.expired_ts > origin_server_ts) }) }; @@ -1559,6 +1581,8 @@ impl Service { } pub_key_map.insert(origin.to_string(), result); + } else { + servers.insert(origin.to_owned(), BTreeMap::new()); } }