1use async_trait::async_trait;
11use mas_data_model::{Clock, User};
12use mas_storage::user::{UserFilter, UserRepository};
13use rand::RngCore;
14use sea_query::{Expr, PostgresQueryBuilder, Query};
15use sea_query_binder::SqlxBinder;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{
21 DatabaseError,
22 filter::{Filter, StatementExt},
23 iden::Users,
24 pagination::QueryBuilderExt,
25 tracing::ExecuteExt,
26};
27
28mod email;
29mod password;
30mod recovery;
31mod registration;
32mod registration_token;
33mod session;
34mod terms;
35
36#[cfg(test)]
37mod tests;
38
39pub use self::{
40 email::PgUserEmailRepository, password::PgUserPasswordRepository,
41 recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
42 registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
43 terms::PgUserTermsRepository,
44};
45
46pub struct PgUserRepository<'c> {
48 conn: &'c mut PgConnection,
49}
50
51impl<'c> PgUserRepository<'c> {
52 pub fn new(conn: &'c mut PgConnection) -> Self {
54 Self { conn }
55 }
56}
57
/// Holds the row type used to decode users read from the `users` table.
///
/// Kept in its own module, which lets the `missing_docs` lint be relaxed (see the
/// inner `allow`) without affecting the rest of the file.
mod priv_ {
    // NOTE(review): presumably needed because `#[enum_def]` generates the
    // undocumented `UserLookupIden` enum alongside the struct — confirm.
    #![allow(missing_docs)]

    use chrono::{DateTime, Utc};
    use sea_query::enum_def;
    use uuid::Uuid;

    /// One row of the `users` table, as selected by the repository queries.
    ///
    /// `#[enum_def]` also generates `UserLookupIden`, whose variants the `list`
    /// query uses as column aliases so that rows decode back into this struct
    /// through `sqlx::FromRow`.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        pub(super) user_id: Uuid,
        pub(super) username: String,
        pub(super) created_at: DateTime<Utc>,
        // `None` unless the user is locked.
        pub(super) locked_at: Option<DateTime<Utc>>,
        // `None` unless the user is deactivated.
        pub(super) deactivated_at: Option<DateTime<Utc>>,
        pub(super) can_request_admin: bool,
    }
}
78
79use priv_::{UserLookup, UserLookupIden};
80
81impl From<UserLookup> for User {
82 fn from(value: UserLookup) -> Self {
83 let id = value.user_id.into();
84 Self {
85 id,
86 username: value.username,
87 sub: id.to_string(),
88 created_at: value.created_at,
89 locked_at: value.locked_at,
90 deactivated_at: value.deactivated_at,
91 can_request_admin: value.can_request_admin,
92 }
93 }
94}
95
96impl Filter for UserFilter<'_> {
97 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
98 sea_query::Condition::all()
99 .add_option(self.state().map(|state| {
100 match state {
101 mas_storage::user::UserState::Deactivated => {
102 Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
103 }
104 mas_storage::user::UserState::Locked => {
105 Expr::col((Users::Table, Users::LockedAt)).is_not_null()
106 }
107 mas_storage::user::UserState::Active => {
108 Expr::col((Users::Table, Users::LockedAt))
109 .is_null()
110 .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
111 }
112 }
113 }))
114 .add_option(self.can_request_admin().map(|can_request_admin| {
115 Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
116 }))
117 }
118}
119
120#[async_trait]
121impl UserRepository for PgUserRepository<'_> {
122 type Error = DatabaseError;
123
124 #[tracing::instrument(
125 name = "db.user.lookup",
126 skip_all,
127 fields(
128 db.query.text,
129 user.id = %id,
130 ),
131 err,
132 )]
133 async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
134 let res = sqlx::query_as!(
135 UserLookup,
136 r#"
137 SELECT user_id
138 , username
139 , created_at
140 , locked_at
141 , deactivated_at
142 , can_request_admin
143 FROM users
144 WHERE user_id = $1
145 "#,
146 Uuid::from(id),
147 )
148 .traced()
149 .fetch_optional(&mut *self.conn)
150 .await?;
151
152 let Some(res) = res else { return Ok(None) };
153
154 Ok(Some(res.into()))
155 }
156
157 #[tracing::instrument(
158 name = "db.user.find_by_username",
159 skip_all,
160 fields(
161 db.query.text,
162 user.username = username,
163 ),
164 err,
165 )]
166 async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
167 let res = sqlx::query_as!(
171 UserLookup,
172 r#"
173 SELECT user_id
174 , username
175 , created_at
176 , locked_at
177 , deactivated_at
178 , can_request_admin
179 FROM users
180 WHERE LOWER(username) = LOWER($1)
181 "#,
182 username,
183 )
184 .traced()
185 .fetch_all(&mut *self.conn)
186 .await?;
187
188 match &res[..] {
189 [user] => Ok(Some(user.clone().into())),
191 [] => Ok(None),
193 list => {
194 if let Some(user) = list.iter().find(|user| user.username == username) {
197 Ok(Some(user.clone().into()))
198 } else {
199 Ok(None)
201 }
202 }
203 }
204 }
205
206 #[tracing::instrument(
207 name = "db.user.add",
208 skip_all,
209 fields(
210 db.query.text,
211 user.username = username,
212 user.id,
213 ),
214 err,
215 )]
216 async fn add(
217 &mut self,
218 rng: &mut (dyn RngCore + Send),
219 clock: &dyn Clock,
220 username: String,
221 ) -> Result<User, Self::Error> {
222 let created_at = clock.now();
223 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
224 tracing::Span::current().record("user.id", tracing::field::display(id));
225
226 let res = sqlx::query!(
227 r#"
228 INSERT INTO users (user_id, username, created_at)
229 VALUES ($1, $2, $3)
230 ON CONFLICT (username) DO NOTHING
231 "#,
232 Uuid::from(id),
233 username,
234 created_at,
235 )
236 .traced()
237 .execute(&mut *self.conn)
238 .await?;
239
240 DatabaseError::ensure_affected_rows(&res, 1)?;
243
244 Ok(User {
245 id,
246 username,
247 sub: id.to_string(),
248 created_at,
249 locked_at: None,
250 deactivated_at: None,
251 can_request_admin: false,
252 })
253 }
254
255 #[tracing::instrument(
256 name = "db.user.exists",
257 skip_all,
258 fields(
259 db.query.text,
260 user.username = username,
261 ),
262 err,
263 )]
264 async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
265 let exists = sqlx::query_scalar!(
266 r#"
267 SELECT EXISTS(
268 SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
269 ) AS "exists!"
270 "#,
271 username
272 )
273 .traced()
274 .fetch_one(&mut *self.conn)
275 .await?;
276
277 Ok(exists)
278 }
279
280 #[tracing::instrument(
281 name = "db.user.lock",
282 skip_all,
283 fields(
284 db.query.text,
285 %user.id,
286 ),
287 err,
288 )]
289 async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
290 if user.locked_at.is_some() {
291 return Ok(user);
292 }
293
294 let locked_at = clock.now();
295 let res = sqlx::query!(
296 r#"
297 UPDATE users
298 SET locked_at = $1
299 WHERE user_id = $2
300 "#,
301 locked_at,
302 Uuid::from(user.id),
303 )
304 .traced()
305 .execute(&mut *self.conn)
306 .await?;
307
308 DatabaseError::ensure_affected_rows(&res, 1)?;
309
310 user.locked_at = Some(locked_at);
311
312 Ok(user)
313 }
314
315 #[tracing::instrument(
316 name = "db.user.unlock",
317 skip_all,
318 fields(
319 db.query.text,
320 %user.id,
321 ),
322 err,
323 )]
324 async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
325 if user.locked_at.is_none() {
326 return Ok(user);
327 }
328
329 let res = sqlx::query!(
330 r#"
331 UPDATE users
332 SET locked_at = NULL
333 WHERE user_id = $1
334 "#,
335 Uuid::from(user.id),
336 )
337 .traced()
338 .execute(&mut *self.conn)
339 .await?;
340
341 DatabaseError::ensure_affected_rows(&res, 1)?;
342
343 user.locked_at = None;
344
345 Ok(user)
346 }
347
348 #[tracing::instrument(
349 name = "db.user.deactivate",
350 skip_all,
351 fields(
352 db.query.text,
353 %user.id,
354 ),
355 err,
356 )]
357 async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
358 if user.deactivated_at.is_some() {
359 return Ok(user);
360 }
361
362 let deactivated_at = clock.now();
363 let res = sqlx::query!(
364 r#"
365 UPDATE users
366 SET deactivated_at = $2
367 WHERE user_id = $1
368 AND deactivated_at IS NULL
369 "#,
370 Uuid::from(user.id),
371 deactivated_at,
372 )
373 .traced()
374 .execute(&mut *self.conn)
375 .await?;
376
377 DatabaseError::ensure_affected_rows(&res, 1)?;
378
379 user.deactivated_at = Some(deactivated_at);
380
381 Ok(user)
382 }
383
384 #[tracing::instrument(
385 name = "db.user.reactivate",
386 skip_all,
387 fields(
388 db.query.text,
389 %user.id,
390 ),
391 err,
392 )]
393 async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
394 if user.deactivated_at.is_none() {
395 return Ok(user);
396 }
397
398 let res = sqlx::query!(
399 r#"
400 UPDATE users
401 SET deactivated_at = NULL
402 WHERE user_id = $1
403 "#,
404 Uuid::from(user.id),
405 )
406 .traced()
407 .execute(&mut *self.conn)
408 .await?;
409
410 DatabaseError::ensure_affected_rows(&res, 1)?;
411
412 user.deactivated_at = None;
413
414 Ok(user)
415 }
416
417 #[tracing::instrument(
418 name = "db.user.set_can_request_admin",
419 skip_all,
420 fields(
421 db.query.text,
422 %user.id,
423 user.can_request_admin = can_request_admin,
424 ),
425 err,
426 )]
427 async fn set_can_request_admin(
428 &mut self,
429 mut user: User,
430 can_request_admin: bool,
431 ) -> Result<User, Self::Error> {
432 let res = sqlx::query!(
433 r#"
434 UPDATE users
435 SET can_request_admin = $2
436 WHERE user_id = $1
437 "#,
438 Uuid::from(user.id),
439 can_request_admin,
440 )
441 .traced()
442 .execute(&mut *self.conn)
443 .await?;
444
445 DatabaseError::ensure_affected_rows(&res, 1)?;
446
447 user.can_request_admin = can_request_admin;
448
449 Ok(user)
450 }
451
452 #[tracing::instrument(
453 name = "db.user.list",
454 skip_all,
455 fields(
456 db.query.text,
457 ),
458 err,
459 )]
460 async fn list(
461 &mut self,
462 filter: UserFilter<'_>,
463 pagination: mas_storage::Pagination,
464 ) -> Result<mas_storage::Page<User>, Self::Error> {
465 let (sql, arguments) = Query::select()
466 .expr_as(
467 Expr::col((Users::Table, Users::UserId)),
468 UserLookupIden::UserId,
469 )
470 .expr_as(
471 Expr::col((Users::Table, Users::Username)),
472 UserLookupIden::Username,
473 )
474 .expr_as(
475 Expr::col((Users::Table, Users::CreatedAt)),
476 UserLookupIden::CreatedAt,
477 )
478 .expr_as(
479 Expr::col((Users::Table, Users::LockedAt)),
480 UserLookupIden::LockedAt,
481 )
482 .expr_as(
483 Expr::col((Users::Table, Users::DeactivatedAt)),
484 UserLookupIden::DeactivatedAt,
485 )
486 .expr_as(
487 Expr::col((Users::Table, Users::CanRequestAdmin)),
488 UserLookupIden::CanRequestAdmin,
489 )
490 .from(Users::Table)
491 .apply_filter(filter)
492 .generate_pagination((Users::Table, Users::UserId), pagination)
493 .build_sqlx(PostgresQueryBuilder);
494
495 let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
496 .traced()
497 .fetch_all(&mut *self.conn)
498 .await?;
499
500 let page = pagination.process(edges).map(User::from);
501
502 Ok(page)
503 }
504
505 #[tracing::instrument(
506 name = "db.user.count",
507 skip_all,
508 fields(
509 db.query.text,
510 ),
511 err,
512 )]
513 async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
514 let (sql, arguments) = Query::select()
515 .expr(Expr::col((Users::Table, Users::UserId)).count())
516 .from(Users::Table)
517 .apply_filter(filter)
518 .build_sqlx(PostgresQueryBuilder);
519
520 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
521 .traced()
522 .fetch_one(&mut *self.conn)
523 .await?;
524
525 count
526 .try_into()
527 .map_err(DatabaseError::to_invalid_operation)
528 }
529
530 #[tracing::instrument(
531 name = "db.user.acquire_lock_for_sync",
532 skip_all,
533 fields(
534 db.query.text,
535 user.id = %user.id,
536 ),
537 err,
538 )]
539 async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
540 let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
548
549 sqlx::query!(
552 r#"
553 SELECT pg_advisory_xact_lock($1)
554 "#,
555 lock_id,
556 )
557 .traced()
558 .execute(&mut *self.conn)
559 .await?;
560
561 Ok(())
562 }
563}