mas_storage_pg/user/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7//! A module containing the PostgreSQL implementation of the user-related
8//! repositories
9
10use async_trait::async_trait;
11use mas_data_model::{Clock, User};
12use mas_storage::user::{UserFilter, UserRepository};
13use rand::RngCore;
14use sea_query::{Expr, PostgresQueryBuilder, Query};
15use sea_query_binder::SqlxBinder;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{
21    DatabaseError,
22    filter::{Filter, StatementExt},
23    iden::Users,
24    pagination::QueryBuilderExt,
25    tracing::ExecuteExt,
26};
27
28mod email;
29mod password;
30mod recovery;
31mod registration;
32mod registration_token;
33mod session;
34mod terms;
35
36#[cfg(test)]
37mod tests;
38
39pub use self::{
40    email::PgUserEmailRepository, password::PgUserPasswordRepository,
41    recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
42    registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
43    terms::PgUserTermsRepository,
44};
45
46/// An implementation of [`UserRepository`] for a PostgreSQL connection
47pub struct PgUserRepository<'c> {
48    conn: &'c mut PgConnection,
49}
50
51impl<'c> PgUserRepository<'c> {
52    /// Create a new [`PgUserRepository`] from an active PostgreSQL connection
53    pub fn new(conn: &'c mut PgConnection) -> Self {
54        Self { conn }
55    }
56}
57
58mod priv_ {
59    // The enum_def macro generates a public enum, which we don't want, because it
60    // triggers the missing docs warning
61    #![allow(missing_docs)]
62
63    use chrono::{DateTime, Utc};
64    use sea_query::enum_def;
65    use uuid::Uuid;
66
67    #[derive(Debug, Clone, sqlx::FromRow)]
68    #[enum_def]
69    pub(super) struct UserLookup {
70        pub(super) user_id: Uuid,
71        pub(super) username: String,
72        pub(super) created_at: DateTime<Utc>,
73        pub(super) locked_at: Option<DateTime<Utc>>,
74        pub(super) deactivated_at: Option<DateTime<Utc>>,
75        pub(super) can_request_admin: bool,
76    }
77}
78
79use priv_::{UserLookup, UserLookupIden};
80
81impl From<UserLookup> for User {
82    fn from(value: UserLookup) -> Self {
83        let id = value.user_id.into();
84        Self {
85            id,
86            username: value.username,
87            sub: id.to_string(),
88            created_at: value.created_at,
89            locked_at: value.locked_at,
90            deactivated_at: value.deactivated_at,
91            can_request_admin: value.can_request_admin,
92        }
93    }
94}
95
96impl Filter for UserFilter<'_> {
97    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
98        sea_query::Condition::all()
99            .add_option(self.state().map(|state| {
100                match state {
101                    mas_storage::user::UserState::Deactivated => {
102                        Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
103                    }
104                    mas_storage::user::UserState::Locked => {
105                        Expr::col((Users::Table, Users::LockedAt)).is_not_null()
106                    }
107                    mas_storage::user::UserState::Active => {
108                        Expr::col((Users::Table, Users::LockedAt))
109                            .is_null()
110                            .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
111                    }
112                }
113            }))
114            .add_option(self.can_request_admin().map(|can_request_admin| {
115                Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
116            }))
117    }
118}
119
120#[async_trait]
121impl UserRepository for PgUserRepository<'_> {
122    type Error = DatabaseError;
123
124    #[tracing::instrument(
125        name = "db.user.lookup",
126        skip_all,
127        fields(
128            db.query.text,
129            user.id = %id,
130        ),
131        err,
132    )]
133    async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
134        let res = sqlx::query_as!(
135            UserLookup,
136            r#"
137                SELECT user_id
138                     , username
139                     , created_at
140                     , locked_at
141                     , deactivated_at
142                     , can_request_admin
143                FROM users
144                WHERE user_id = $1
145            "#,
146            Uuid::from(id),
147        )
148        .traced()
149        .fetch_optional(&mut *self.conn)
150        .await?;
151
152        let Some(res) = res else { return Ok(None) };
153
154        Ok(Some(res.into()))
155    }
156
157    #[tracing::instrument(
158        name = "db.user.find_by_username",
159        skip_all,
160        fields(
161            db.query.text,
162            user.username = username,
163        ),
164        err,
165    )]
166    async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
167        // We may have multiple users with the same username, but with a different
168        // casing. In this case, we want to return the one which matches the exact
169        // casing
170        let res = sqlx::query_as!(
171            UserLookup,
172            r#"
173                SELECT user_id
174                     , username
175                     , created_at
176                     , locked_at
177                     , deactivated_at
178                     , can_request_admin
179                FROM users
180                WHERE LOWER(username) = LOWER($1)
181            "#,
182            username,
183        )
184        .traced()
185        .fetch_all(&mut *self.conn)
186        .await?;
187
188        match &res[..] {
189            // Happy path: there is only one user matching the username…
190            [user] => Ok(Some(user.clone().into())),
191            // …or none.
192            [] => Ok(None),
193            list => {
194                // If there are multiple users with the same username, we want to
195                // return the one which matches the exact casing
196                if let Some(user) = list.iter().find(|user| user.username == username) {
197                    Ok(Some(user.clone().into()))
198                } else {
199                    // If none match exactly, we prefer to return nothing
200                    Ok(None)
201                }
202            }
203        }
204    }
205
206    #[tracing::instrument(
207        name = "db.user.add",
208        skip_all,
209        fields(
210            db.query.text,
211            user.username = username,
212            user.id,
213        ),
214        err,
215    )]
216    async fn add(
217        &mut self,
218        rng: &mut (dyn RngCore + Send),
219        clock: &dyn Clock,
220        username: String,
221    ) -> Result<User, Self::Error> {
222        let created_at = clock.now();
223        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
224        tracing::Span::current().record("user.id", tracing::field::display(id));
225
226        let res = sqlx::query!(
227            r#"
228                INSERT INTO users (user_id, username, created_at)
229                VALUES ($1, $2, $3)
230                ON CONFLICT (username) DO NOTHING
231            "#,
232            Uuid::from(id),
233            username,
234            created_at,
235        )
236        .traced()
237        .execute(&mut *self.conn)
238        .await?;
239
240        // If the user already exists, want to return an error but not poison the
241        // transaction
242        DatabaseError::ensure_affected_rows(&res, 1)?;
243
244        Ok(User {
245            id,
246            username,
247            sub: id.to_string(),
248            created_at,
249            locked_at: None,
250            deactivated_at: None,
251            can_request_admin: false,
252        })
253    }
254
255    #[tracing::instrument(
256        name = "db.user.exists",
257        skip_all,
258        fields(
259            db.query.text,
260            user.username = username,
261        ),
262        err,
263    )]
264    async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
265        let exists = sqlx::query_scalar!(
266            r#"
267                SELECT EXISTS(
268                    SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
269                ) AS "exists!"
270            "#,
271            username
272        )
273        .traced()
274        .fetch_one(&mut *self.conn)
275        .await?;
276
277        Ok(exists)
278    }
279
280    #[tracing::instrument(
281        name = "db.user.lock",
282        skip_all,
283        fields(
284            db.query.text,
285            %user.id,
286        ),
287        err,
288    )]
289    async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
290        if user.locked_at.is_some() {
291            return Ok(user);
292        }
293
294        let locked_at = clock.now();
295        let res = sqlx::query!(
296            r#"
297                UPDATE users
298                SET locked_at = $1
299                WHERE user_id = $2
300            "#,
301            locked_at,
302            Uuid::from(user.id),
303        )
304        .traced()
305        .execute(&mut *self.conn)
306        .await?;
307
308        DatabaseError::ensure_affected_rows(&res, 1)?;
309
310        user.locked_at = Some(locked_at);
311
312        Ok(user)
313    }
314
315    #[tracing::instrument(
316        name = "db.user.unlock",
317        skip_all,
318        fields(
319            db.query.text,
320            %user.id,
321        ),
322        err,
323    )]
324    async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
325        if user.locked_at.is_none() {
326            return Ok(user);
327        }
328
329        let res = sqlx::query!(
330            r#"
331                UPDATE users
332                SET locked_at = NULL
333                WHERE user_id = $1
334            "#,
335            Uuid::from(user.id),
336        )
337        .traced()
338        .execute(&mut *self.conn)
339        .await?;
340
341        DatabaseError::ensure_affected_rows(&res, 1)?;
342
343        user.locked_at = None;
344
345        Ok(user)
346    }
347
348    #[tracing::instrument(
349        name = "db.user.deactivate",
350        skip_all,
351        fields(
352            db.query.text,
353            %user.id,
354        ),
355        err,
356    )]
357    async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
358        if user.deactivated_at.is_some() {
359            return Ok(user);
360        }
361
362        let deactivated_at = clock.now();
363        let res = sqlx::query!(
364            r#"
365                UPDATE users
366                SET deactivated_at = $2
367                WHERE user_id = $1
368                  AND deactivated_at IS NULL
369            "#,
370            Uuid::from(user.id),
371            deactivated_at,
372        )
373        .traced()
374        .execute(&mut *self.conn)
375        .await?;
376
377        DatabaseError::ensure_affected_rows(&res, 1)?;
378
379        user.deactivated_at = Some(deactivated_at);
380
381        Ok(user)
382    }
383
384    #[tracing::instrument(
385        name = "db.user.reactivate",
386        skip_all,
387        fields(
388            db.query.text,
389            %user.id,
390        ),
391        err,
392    )]
393    async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
394        if user.deactivated_at.is_none() {
395            return Ok(user);
396        }
397
398        let res = sqlx::query!(
399            r#"
400                UPDATE users
401                SET deactivated_at = NULL
402                WHERE user_id = $1
403            "#,
404            Uuid::from(user.id),
405        )
406        .traced()
407        .execute(&mut *self.conn)
408        .await?;
409
410        DatabaseError::ensure_affected_rows(&res, 1)?;
411
412        user.deactivated_at = None;
413
414        Ok(user)
415    }
416
417    #[tracing::instrument(
418        name = "db.user.set_can_request_admin",
419        skip_all,
420        fields(
421            db.query.text,
422            %user.id,
423            user.can_request_admin = can_request_admin,
424        ),
425        err,
426    )]
427    async fn set_can_request_admin(
428        &mut self,
429        mut user: User,
430        can_request_admin: bool,
431    ) -> Result<User, Self::Error> {
432        let res = sqlx::query!(
433            r#"
434                UPDATE users
435                SET can_request_admin = $2
436                WHERE user_id = $1
437            "#,
438            Uuid::from(user.id),
439            can_request_admin,
440        )
441        .traced()
442        .execute(&mut *self.conn)
443        .await?;
444
445        DatabaseError::ensure_affected_rows(&res, 1)?;
446
447        user.can_request_admin = can_request_admin;
448
449        Ok(user)
450    }
451
452    #[tracing::instrument(
453        name = "db.user.list",
454        skip_all,
455        fields(
456            db.query.text,
457        ),
458        err,
459    )]
460    async fn list(
461        &mut self,
462        filter: UserFilter<'_>,
463        pagination: mas_storage::Pagination,
464    ) -> Result<mas_storage::Page<User>, Self::Error> {
465        let (sql, arguments) = Query::select()
466            .expr_as(
467                Expr::col((Users::Table, Users::UserId)),
468                UserLookupIden::UserId,
469            )
470            .expr_as(
471                Expr::col((Users::Table, Users::Username)),
472                UserLookupIden::Username,
473            )
474            .expr_as(
475                Expr::col((Users::Table, Users::CreatedAt)),
476                UserLookupIden::CreatedAt,
477            )
478            .expr_as(
479                Expr::col((Users::Table, Users::LockedAt)),
480                UserLookupIden::LockedAt,
481            )
482            .expr_as(
483                Expr::col((Users::Table, Users::DeactivatedAt)),
484                UserLookupIden::DeactivatedAt,
485            )
486            .expr_as(
487                Expr::col((Users::Table, Users::CanRequestAdmin)),
488                UserLookupIden::CanRequestAdmin,
489            )
490            .from(Users::Table)
491            .apply_filter(filter)
492            .generate_pagination((Users::Table, Users::UserId), pagination)
493            .build_sqlx(PostgresQueryBuilder);
494
495        let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
496            .traced()
497            .fetch_all(&mut *self.conn)
498            .await?;
499
500        let page = pagination.process(edges).map(User::from);
501
502        Ok(page)
503    }
504
505    #[tracing::instrument(
506        name = "db.user.count",
507        skip_all,
508        fields(
509            db.query.text,
510        ),
511        err,
512    )]
513    async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
514        let (sql, arguments) = Query::select()
515            .expr(Expr::col((Users::Table, Users::UserId)).count())
516            .from(Users::Table)
517            .apply_filter(filter)
518            .build_sqlx(PostgresQueryBuilder);
519
520        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
521            .traced()
522            .fetch_one(&mut *self.conn)
523            .await?;
524
525        count
526            .try_into()
527            .map_err(DatabaseError::to_invalid_operation)
528    }
529
530    #[tracing::instrument(
531        name = "db.user.acquire_lock_for_sync",
532        skip_all,
533        fields(
534            db.query.text,
535            user.id = %user.id,
536        ),
537        err,
538    )]
539    async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
540        // XXX: this lock isn't stictly scoped to users, but as we don't use many
541        // postgres advisory locks, it's fine for now. Later on, we could use row-level
542        // locks to make sure we don't get into trouble
543
544        // Convert the user ID to a u128 and grab the lower 64 bits
545        // As this includes 64bit of the random part of the ULID, it should be random
546        // enough to not collide
547        let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
548
549        // Use a PG advisory lock, which will be released when the transaction is
550        // committed or rolled back
551        sqlx::query!(
552            r#"
553                SELECT pg_advisory_xact_lock($1)
554            "#,
555            lock_id,
556        )
557        .traced()
558        .execute(&mut *self.conn)
559        .await?;
560
561        Ok(())
562    }
563}