diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index 07078328..3482d53e 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -1,12 +1,14 @@ +use std::time::Duration; + use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use crate::{services, utils, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, - session::{get_login_types, login, logout, logout_all}, + session::{get_login_types, login, logout, logout_all, refresh_token}, uiaa::UserIdentifier, }, - UserId, + OwnedDeviceId, UserId, }; use serde::Deserialize; use tracing::{info, warn}; @@ -179,7 +181,16 @@ pub async fn login_route(body: Ruma) -> Result (None, None), + _ => services() + .users + .create_refresh_token(&access_token) + .map(Some) + .map(Option::unzip)?, + }; // Determine if device_id was provided and exists in the db for this user let device_exists = body.device_id.as_ref().map_or(false, |device_id| { @@ -190,12 +201,14 @@ pub async fn login_route(body: Ruma) -> Result) -> Result, +) -> Result { + let expires_at = services() + .users + .get_refresh_token_ttl(&body.refresh_token)? + .ok_or_else(|| { + Error::BadRequest( + ErrorKind::UnknownToken { soft_logout: false }, + "Unknown refresh token.", + ) + })?; + if expires_at < utils::millis_since_unix_epoch() { + return Err(Error::BadRequest( + ErrorKind::UnknownToken { soft_logout: false }, + "Expired refresh token.", + )); + } + + let (user_id, device_id) = { + let access_token = services() + .users + .refresh_to_access_token(&body.refresh_token)? + .expect(""); + services().users.find_from_token(&access_token)?.expect("") + }; + + let access_token = utils::random_string(TOKEN_LENGTH); + let (refresh_token, expires_at) = services().users.create_refresh_token(&access_token)?; + + let device_id: OwnedDeviceId = device_id.into(); + services() + .users + .set_token(&user_id, &device_id, &access_token)?; + + Ok(refresh_token::v3::Response { + access_token, + refresh_token: Some(refresh_token), + expires_in_ms: Some(Duration::from_millis( + expires_at + .checked_sub(utils::millis_since_unix_epoch()) + .expect(""), + )), + }) +} diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 2c5da21b..f617600e 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -23,7 +23,7 @@ use serde::Deserialize; use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; -use crate::{service::appservice::RegistrationInfo, services, Error, Result}; +use crate::{service::appservice::RegistrationInfo, services, utils, Error, Result}; enum Token { Appservice(Box), @@ -87,6 +87,17 @@ where if let Some(reg_info) = services().appservice.find_from_token(token).await { Token::Appservice(Box::new(reg_info.clone())) } else if let Some((user_id, device_id)) = services().users.find_from_token(token)? { + if services() + .users + .get_access_token_ttl(token)? + .is_some_and(|expires_at| expires_at < utils::millis_since_unix_epoch()) + { + return Err(Error::BadRequest( + ErrorKind::UnknownToken { soft_logout: true }, + "Expired access token.", + )); + } + Token::User((user_id, OwnedDeviceId::from(device_id))) } else { Token::Invalid diff --git a/src/config/mod.rs b/src/config/mod.rs index 378ab929..b29428ae 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -47,6 +47,8 @@ pub struct Config { #[serde(default = "false_fn")] pub allow_registration: bool, pub registration_token: Option, + #[serde(default, flatten)] + pub refresh_token: RefreshTokenConfig, #[serde(default = "default_openid_token_ttl")] pub openid_token_ttl: u64, #[serde(default = "true_fn")] @@ -101,6 +103,14 @@ pub struct WellKnownConfig { pub server: Option, } +#[derive(Clone, Debug, Deserialize, Default)] +pub struct RefreshTokenConfig { + #[serde(default = "default_refresh_token_ttl")] + pub ttl: u64, + #[serde(default = "default_access_token_ttl")] + pub access_token_ttl: u64, +} + const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; impl Config { @@ -304,6 +314,14 @@ fn default_turn_ttl() -> u64 { 60 * 60 * 24 } +fn default_refresh_token_ttl() -> u64 { + 60 * 60 +} + +fn default_access_token_ttl() -> u64 { + 60 * 5 +} + fn default_openid_token_ttl() -> u64 { 60 * 60 } diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 63321a40..d539df9b 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -945,6 +945,60 @@ impl service::users::Data for KeyValueDatabase { } } + fn create_refresh_token(&self, access_token: &str) -> Result<(String, u64)> { + let crate::config::RefreshTokenConfig { + ttl, + access_token_ttl, + } = services().globals.config.refresh_token; + + let refresh_token = utils::random_string(TOKEN_LENGTH); + + let mut value = refresh_token.as_bytes().to_vec(); + value.extend_from_slice( + &utils::millis_since_unix_epoch() + .checked_add(access_token_ttl * 1000) + .expect("time is valid") + .to_be_bytes(), + ); + + self.accesstoken_refreshtokenttl + .insert(access_token.as_bytes(), &value)?; + + let mut value = access_token.as_bytes().to_vec(); + value.extend_from_slice( + &utils::millis_since_unix_epoch() + .checked_add(ttl * 1000) + .expect("time is valid") + .to_be_bytes(), + ); + + self.refreshtoken_accesstokenttl + .insert(refresh_token.as_bytes(), &value)?; + + Ok((refresh_token, access_token_ttl)) + } + + fn refresh_to_access_token(&self, refresh_token: &str) -> Result> { + Ok(self + .refreshtoken_accesstokenttl + .get(refresh_token.as_bytes())? + .map(|v| utils::string_from_bytes(&v[..TOKEN_LENGTH]).expect(""))) + } + + fn get_access_token_ttl(&self, access_token: &str) -> Result> { + Ok(self + .accesstoken_refreshtokenttl + .get(access_token.as_bytes())? + .map(|v| u64::from_be_bytes(v[TOKEN_LENGTH..].try_into().expect("")))) + } + + fn get_refresh_token_ttl(&self, refresh_token: &str) -> Result> { + Ok(self + .refreshtoken_accesstokenttl + .get(refresh_token.as_bytes())? + .map(|v| u64::from_be_bytes(v[TOKEN_LENGTH..].try_into().expect("")))) + } + // Creates an OpenID token, which can be used to prove that a user has access to an account (primarily for integrations) fn create_openid_token(&self, user_id: &UserId) -> Result<(String, u64)> { let token = utils::random_string(TOKEN_LENGTH); diff --git a/src/database/mod.rs b/src/database/mod.rs index 5171d4bb..5021617d 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -58,6 +58,8 @@ pub struct KeyValueDatabase { pub(super) userid_selfsigningkeyid: Arc, pub(super) userid_usersigningkeyid: Arc, pub(super) openidtoken_expiresatuserid: Arc, // expiresatuserid = expiresat + userid + pub(super) accesstoken_refreshtokenttl: Arc, + pub(super) refreshtoken_accesstokenttl: Arc, pub(super) userfilterid_filter: Arc, // UserFilterId = UserId + FilterId @@ -294,6 +296,8 @@ impl KeyValueDatabase { userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, openidtoken_expiresatuserid: builder.open_tree("openidtoken_expiresatuserid")?, + accesstoken_refreshtokenttl: builder.open_tree("accesstoken_refreshtokenttl")?, + refreshtoken_accesstokenttl: builder.open_tree("refreshtoken_accesstokenttl")?, userfilterid_filter: builder.open_tree("userfilterid_filter")?, todeviceid_events: builder.open_tree("todeviceid_events")?, diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 4566c36d..a38d7b27 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -212,6 +212,14 @@ pub trait Data: Send + Sync { fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result>; + fn create_refresh_token(&self, access_token: &str) -> Result<(String, u64)>; + + fn refresh_to_access_token(&self, refresh_token: &str) -> Result>; + + fn get_access_token_ttl(&self, access_token: &str) -> Result>; + + fn get_refresh_token_ttl(&self, refresh_token: &str) -> Result>; + // Creates an OpenID token, which can be used to prove that a user has access to an account (primarily for integrations) fn create_openid_token(&self, user_id: &UserId) -> Result<(String, u64)>; diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index a5694a10..10aa74cd 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -593,6 +593,22 @@ impl Service { self.db.get_filter(user_id, filter_id) } + pub fn create_refresh_token(&self, access_token: &str) -> Result<(String, u64)> { + self.db.create_refresh_token(access_token) + } + + pub fn refresh_to_access_token(&self, refresh_token: &str) -> Result> { + self.db.refresh_to_access_token(refresh_token) + } + + pub fn get_access_token_ttl(&self, access_token: &str) -> Result> { + self.db.get_access_token_ttl(access_token) + } + + pub fn get_refresh_token_ttl(&self, refresh_token: &str) -> Result> { + self.db.get_refresh_token_ttl(refresh_token) + } + // Creates an OpenID token, which can be used to prove that a user has access to an account (primarily for integrations) pub fn create_openid_token(&self, user_id: &UserId) -> Result<(String, u64)> { self.db.create_openid_token(user_id)