mas_storage_pg/user/
registration.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4// Please see LICENSE files in the repository root for full details.
5
6use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11    Clock, UserEmailAuthentication, UserRegistration, UserRegistrationPassword,
12    UserRegistrationToken,
13};
14use mas_storage::user::UserRegistrationRepository;
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use url::Url;
19use uuid::Uuid;
20
21use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt as _};
22
23/// An implementation of [`UserRegistrationRepository`] for a PostgreSQL
24/// connection
25pub struct PgUserRegistrationRepository<'c> {
26    conn: &'c mut PgConnection,
27}
28
29impl<'c> PgUserRegistrationRepository<'c> {
30    /// Create a new [`PgUserRegistrationRepository`] from an active PostgreSQL
31    /// connection
32    pub fn new(conn: &'c mut PgConnection) -> Self {
33        Self { conn }
34    }
35}
36
37struct UserRegistrationLookup {
38    user_registration_id: Uuid,
39    ip_address: Option<IpAddr>,
40    user_agent: Option<String>,
41    post_auth_action: Option<serde_json::Value>,
42    username: String,
43    display_name: Option<String>,
44    terms_url: Option<String>,
45    email_authentication_id: Option<Uuid>,
46    user_registration_token_id: Option<Uuid>,
47    hashed_password: Option<String>,
48    hashed_password_version: Option<i32>,
49    created_at: DateTime<Utc>,
50    completed_at: Option<DateTime<Utc>>,
51}
52
53impl TryFrom<UserRegistrationLookup> for UserRegistration {
54    type Error = DatabaseInconsistencyError;
55
56    fn try_from(value: UserRegistrationLookup) -> Result<Self, Self::Error> {
57        let id = Ulid::from(value.user_registration_id);
58
59        let password = match (value.hashed_password, value.hashed_password_version) {
60            (Some(hashed_password), Some(version)) => {
61                let version = version.try_into().map_err(|e| {
62                    DatabaseInconsistencyError::on("user_registrations")
63                        .column("hashed_password_version")
64                        .row(id)
65                        .source(e)
66                })?;
67
68                Some(UserRegistrationPassword {
69                    hashed_password,
70                    version,
71                })
72            }
73            (None, None) => None,
74            _ => {
75                return Err(DatabaseInconsistencyError::on("user_registrations")
76                    .column("hashed_password")
77                    .row(id));
78            }
79        };
80
81        let terms_url = value
82            .terms_url
83            .map(|u| u.parse())
84            .transpose()
85            .map_err(|e| {
86                DatabaseInconsistencyError::on("user_registrations")
87                    .column("terms_url")
88                    .row(id)
89                    .source(e)
90            })?;
91
92        Ok(UserRegistration {
93            id,
94            ip_address: value.ip_address,
95            user_agent: value.user_agent,
96            post_auth_action: value.post_auth_action,
97            username: value.username,
98            display_name: value.display_name,
99            terms_url,
100            email_authentication_id: value.email_authentication_id.map(Ulid::from),
101            user_registration_token_id: value.user_registration_token_id.map(Ulid::from),
102            password,
103            created_at: value.created_at,
104            completed_at: value.completed_at,
105        })
106    }
107}
108
109#[async_trait]
110impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
111    type Error = DatabaseError;
112
113    #[tracing::instrument(
114        name = "db.user_registration.lookup",
115        skip_all,
116        fields(
117            db.query.text,
118            user_registration.id = %id,
119        ),
120        err,
121    )]
122    async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistration>, Self::Error> {
123        let res = sqlx::query_as!(
124            UserRegistrationLookup,
125            r#"
126                SELECT user_registration_id
127                     , ip_address as "ip_address: IpAddr"
128                     , user_agent
129                     , post_auth_action
130                     , username
131                     , display_name
132                     , terms_url
133                     , email_authentication_id
134                     , user_registration_token_id
135                     , hashed_password
136                     , hashed_password_version
137                     , created_at
138                     , completed_at
139                FROM user_registrations
140                WHERE user_registration_id = $1
141            "#,
142            Uuid::from(id),
143        )
144        .traced()
145        .fetch_optional(&mut *self.conn)
146        .await?;
147
148        let Some(res) = res else { return Ok(None) };
149
150        Ok(Some(res.try_into()?))
151    }
152
153    #[tracing::instrument(
154        name = "db.user_registration.add",
155        skip_all,
156        fields(
157            db.query.text,
158            user_registration.id,
159        ),
160        err,
161    )]
162    async fn add(
163        &mut self,
164        rng: &mut (dyn RngCore + Send),
165        clock: &dyn Clock,
166        username: String,
167        ip_address: Option<IpAddr>,
168        user_agent: Option<String>,
169        post_auth_action: Option<serde_json::Value>,
170    ) -> Result<UserRegistration, Self::Error> {
171        let created_at = clock.now();
172        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
173        tracing::Span::current().record("user_registration.id", tracing::field::display(id));
174
175        sqlx::query!(
176            r#"
177                INSERT INTO user_registrations
178                  ( user_registration_id
179                  , ip_address
180                  , user_agent
181                  , post_auth_action
182                  , username
183                  , created_at
184                  )
185                VALUES ($1, $2, $3, $4, $5, $6)
186            "#,
187            Uuid::from(id),
188            ip_address as Option<IpAddr>,
189            user_agent.as_deref(),
190            post_auth_action,
191            username,
192            created_at,
193        )
194        .traced()
195        .execute(&mut *self.conn)
196        .await?;
197
198        Ok(UserRegistration {
199            id,
200            ip_address,
201            user_agent,
202            post_auth_action,
203            created_at,
204            completed_at: None,
205            username,
206            display_name: None,
207            terms_url: None,
208            email_authentication_id: None,
209            user_registration_token_id: None,
210            password: None,
211        })
212    }
213
214    #[tracing::instrument(
215        name = "db.user_registration.set_display_name",
216        skip_all,
217        fields(
218            db.query.text,
219            user_registration.id = %user_registration.id,
220            user_registration.display_name = display_name,
221        ),
222        err,
223    )]
224    async fn set_display_name(
225        &mut self,
226        mut user_registration: UserRegistration,
227        display_name: String,
228    ) -> Result<UserRegistration, Self::Error> {
229        let res = sqlx::query!(
230            r#"
231                UPDATE user_registrations
232                SET display_name = $2
233                WHERE user_registration_id = $1 AND completed_at IS NULL
234            "#,
235            Uuid::from(user_registration.id),
236            display_name,
237        )
238        .traced()
239        .execute(&mut *self.conn)
240        .await?;
241
242        DatabaseError::ensure_affected_rows(&res, 1)?;
243
244        user_registration.display_name = Some(display_name);
245
246        Ok(user_registration)
247    }
248
249    #[tracing::instrument(
250        name = "db.user_registration.set_terms_url",
251        skip_all,
252        fields(
253            db.query.text,
254            user_registration.id = %user_registration.id,
255            user_registration.terms_url = %terms_url,
256        ),
257        err,
258    )]
259    async fn set_terms_url(
260        &mut self,
261        mut user_registration: UserRegistration,
262        terms_url: Url,
263    ) -> Result<UserRegistration, Self::Error> {
264        let res = sqlx::query!(
265            r#"
266                UPDATE user_registrations
267                SET terms_url = $2
268                WHERE user_registration_id = $1 AND completed_at IS NULL
269            "#,
270            Uuid::from(user_registration.id),
271            terms_url.as_str(),
272        )
273        .traced()
274        .execute(&mut *self.conn)
275        .await?;
276
277        DatabaseError::ensure_affected_rows(&res, 1)?;
278
279        user_registration.terms_url = Some(terms_url);
280
281        Ok(user_registration)
282    }
283
284    #[tracing::instrument(
285        name = "db.user_registration.set_email_authentication",
286        skip_all,
287        fields(
288            db.query.text,
289            %user_registration.id,
290            %user_email_authentication.id,
291            %user_email_authentication.email,
292        ),
293        err,
294    )]
295    async fn set_email_authentication(
296        &mut self,
297        mut user_registration: UserRegistration,
298        user_email_authentication: &UserEmailAuthentication,
299    ) -> Result<UserRegistration, Self::Error> {
300        let res = sqlx::query!(
301            r#"
302                UPDATE user_registrations
303                SET email_authentication_id = $2
304                WHERE user_registration_id = $1 AND completed_at IS NULL
305            "#,
306            Uuid::from(user_registration.id),
307            Uuid::from(user_email_authentication.id),
308        )
309        .traced()
310        .execute(&mut *self.conn)
311        .await?;
312
313        DatabaseError::ensure_affected_rows(&res, 1)?;
314
315        user_registration.email_authentication_id = Some(user_email_authentication.id);
316
317        Ok(user_registration)
318    }
319
320    #[tracing::instrument(
321        name = "db.user_registration.set_password",
322        skip_all,
323        fields(
324            db.query.text,
325            user_registration.id = %user_registration.id,
326            user_registration.hashed_password = hashed_password,
327            user_registration.hashed_password_version = version,
328        ),
329        err,
330    )]
331    async fn set_password(
332        &mut self,
333        mut user_registration: UserRegistration,
334        hashed_password: String,
335        version: u16,
336    ) -> Result<UserRegistration, Self::Error> {
337        let res = sqlx::query!(
338            r#"
339                UPDATE user_registrations
340                SET hashed_password = $2, hashed_password_version = $3
341                WHERE user_registration_id = $1 AND completed_at IS NULL
342            "#,
343            Uuid::from(user_registration.id),
344            hashed_password,
345            i32::from(version),
346        )
347        .traced()
348        .execute(&mut *self.conn)
349        .await?;
350
351        DatabaseError::ensure_affected_rows(&res, 1)?;
352
353        user_registration.password = Some(UserRegistrationPassword {
354            hashed_password,
355            version,
356        });
357
358        Ok(user_registration)
359    }
360
361    #[tracing::instrument(
362        name = "db.user_registration.set_registration_token",
363        skip_all,
364        fields(
365            db.query.text,
366            %user_registration.id,
367            %user_registration_token.id,
368        ),
369        err,
370    )]
371    async fn set_registration_token(
372        &mut self,
373        mut user_registration: UserRegistration,
374        user_registration_token: &UserRegistrationToken,
375    ) -> Result<UserRegistration, Self::Error> {
376        let res = sqlx::query!(
377            r#"
378                UPDATE user_registrations
379                SET user_registration_token_id = $2
380                WHERE user_registration_id = $1 AND completed_at IS NULL
381            "#,
382            Uuid::from(user_registration.id),
383            Uuid::from(user_registration_token.id),
384        )
385        .traced()
386        .execute(&mut *self.conn)
387        .await?;
388
389        DatabaseError::ensure_affected_rows(&res, 1)?;
390
391        user_registration.user_registration_token_id = Some(user_registration_token.id);
392
393        Ok(user_registration)
394    }
395
396    #[tracing::instrument(
397        name = "db.user_registration.complete",
398        skip_all,
399        fields(
400            db.query.text,
401            user_registration.id = %user_registration.id,
402        ),
403        err,
404    )]
405    async fn complete(
406        &mut self,
407        clock: &dyn Clock,
408        mut user_registration: UserRegistration,
409    ) -> Result<UserRegistration, Self::Error> {
410        let completed_at = clock.now();
411        let res = sqlx::query!(
412            r#"
413                UPDATE user_registrations
414                SET completed_at = $2
415                WHERE user_registration_id = $1 AND completed_at IS NULL
416            "#,
417            Uuid::from(user_registration.id),
418            completed_at,
419        )
420        .traced()
421        .execute(&mut *self.conn)
422        .await?;
423
424        DatabaseError::ensure_affected_rows(&res, 1)?;
425
426        user_registration.completed_at = Some(completed_at);
427
428        Ok(user_registration)
429    }
430}
431
432#[cfg(test)]
433mod tests {
434    use std::net::{IpAddr, Ipv4Addr};
435
436    use mas_data_model::{Clock, UserRegistrationPassword, clock::MockClock};
437    use rand::SeedableRng;
438    use rand_chacha::ChaChaRng;
439    use sqlx::PgPool;
440
441    use crate::PgRepository;
442
443    #[sqlx::test(migrator = "crate::MIGRATOR")]
444    async fn test_create_lookup_complete(pool: PgPool) {
445        let mut rng = ChaChaRng::seed_from_u64(42);
446        let clock = MockClock::default();
447
448        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
449
450        let registration = repo
451            .user_registration()
452            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
453            .await
454            .unwrap();
455
456        assert_eq!(registration.created_at, clock.now());
457        assert_eq!(registration.completed_at, None);
458        assert_eq!(registration.username, "alice");
459        assert_eq!(registration.display_name, None);
460        assert_eq!(registration.terms_url, None);
461        assert_eq!(registration.email_authentication_id, None);
462        assert_eq!(registration.password, None);
463        assert_eq!(registration.user_agent, None);
464        assert_eq!(registration.ip_address, None);
465        assert_eq!(registration.post_auth_action, None);
466
467        let lookup = repo
468            .user_registration()
469            .lookup(registration.id)
470            .await
471            .unwrap()
472            .unwrap();
473
474        assert_eq!(lookup.id, registration.id);
475        assert_eq!(lookup.created_at, registration.created_at);
476        assert_eq!(lookup.completed_at, registration.completed_at);
477        assert_eq!(lookup.username, registration.username);
478        assert_eq!(lookup.display_name, registration.display_name);
479        assert_eq!(lookup.terms_url, registration.terms_url);
480        assert_eq!(
481            lookup.email_authentication_id,
482            registration.email_authentication_id
483        );
484        assert_eq!(lookup.password, registration.password);
485        assert_eq!(lookup.user_agent, registration.user_agent);
486        assert_eq!(lookup.ip_address, registration.ip_address);
487        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
488
489        // Mark the registration as completed
490        let registration = repo
491            .user_registration()
492            .complete(&clock, registration)
493            .await
494            .unwrap();
495        assert_eq!(registration.completed_at, Some(clock.now()));
496
497        // Lookup the registration again
498        let lookup = repo
499            .user_registration()
500            .lookup(registration.id)
501            .await
502            .unwrap()
503            .unwrap();
504        assert_eq!(lookup.completed_at, registration.completed_at);
505
506        // Do it again, it should fail
507        let res = repo
508            .user_registration()
509            .complete(&clock, registration)
510            .await;
511        assert!(res.is_err());
512    }
513
514    #[sqlx::test(migrator = "crate::MIGRATOR")]
515    async fn test_create_useragent_ipaddress(pool: PgPool) {
516        let mut rng = ChaChaRng::seed_from_u64(42);
517        let clock = MockClock::default();
518
519        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
520
521        let registration = repo
522            .user_registration()
523            .add(
524                &mut rng,
525                &clock,
526                "alice".to_owned(),
527                Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
528                Some("Mozilla/5.0".to_owned()),
529                Some(serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})),
530            )
531            .await
532            .unwrap();
533
534        assert_eq!(registration.user_agent, Some("Mozilla/5.0".to_owned()));
535        assert_eq!(
536            registration.ip_address,
537            Some(IpAddr::V4(Ipv4Addr::LOCALHOST))
538        );
539        assert_eq!(
540            registration.post_auth_action,
541            Some(
542                serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})
543            )
544        );
545
546        let lookup = repo
547            .user_registration()
548            .lookup(registration.id)
549            .await
550            .unwrap()
551            .unwrap();
552
553        assert_eq!(lookup.user_agent, registration.user_agent);
554        assert_eq!(lookup.ip_address, registration.ip_address);
555        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
556    }
557
558    #[sqlx::test(migrator = "crate::MIGRATOR")]
559    async fn test_set_display_name(pool: PgPool) {
560        let mut rng = ChaChaRng::seed_from_u64(42);
561        let clock = MockClock::default();
562
563        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
564
565        let registration = repo
566            .user_registration()
567            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
568            .await
569            .unwrap();
570
571        assert_eq!(registration.display_name, None);
572
573        let registration = repo
574            .user_registration()
575            .set_display_name(registration, "Alice".to_owned())
576            .await
577            .unwrap();
578
579        assert_eq!(registration.display_name, Some("Alice".to_owned()));
580
581        let lookup = repo
582            .user_registration()
583            .lookup(registration.id)
584            .await
585            .unwrap()
586            .unwrap();
587
588        assert_eq!(lookup.display_name, registration.display_name);
589
590        // Setting it again should work
591        let registration = repo
592            .user_registration()
593            .set_display_name(registration, "Bob".to_owned())
594            .await
595            .unwrap();
596
597        assert_eq!(registration.display_name, Some("Bob".to_owned()));
598
599        let lookup = repo
600            .user_registration()
601            .lookup(registration.id)
602            .await
603            .unwrap()
604            .unwrap();
605
606        assert_eq!(lookup.display_name, registration.display_name);
607
608        // Can't set it once completed
609        let registration = repo
610            .user_registration()
611            .complete(&clock, registration)
612            .await
613            .unwrap();
614
615        let res = repo
616            .user_registration()
617            .set_display_name(registration, "Charlie".to_owned())
618            .await;
619        assert!(res.is_err());
620    }
621
622    #[sqlx::test(migrator = "crate::MIGRATOR")]
623    async fn test_set_terms_url(pool: PgPool) {
624        let mut rng = ChaChaRng::seed_from_u64(42);
625        let clock = MockClock::default();
626
627        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
628
629        let registration = repo
630            .user_registration()
631            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
632            .await
633            .unwrap();
634
635        assert_eq!(registration.terms_url, None);
636
637        let registration = repo
638            .user_registration()
639            .set_terms_url(registration, "https://example.com/terms".parse().unwrap())
640            .await
641            .unwrap();
642
643        assert_eq!(
644            registration.terms_url,
645            Some("https://example.com/terms".parse().unwrap())
646        );
647
648        let lookup = repo
649            .user_registration()
650            .lookup(registration.id)
651            .await
652            .unwrap()
653            .unwrap();
654
655        assert_eq!(lookup.terms_url, registration.terms_url);
656
657        // Setting it again should work
658        let registration = repo
659            .user_registration()
660            .set_terms_url(registration, "https://example.com/terms2".parse().unwrap())
661            .await
662            .unwrap();
663
664        assert_eq!(
665            registration.terms_url,
666            Some("https://example.com/terms2".parse().unwrap())
667        );
668
669        let lookup = repo
670            .user_registration()
671            .lookup(registration.id)
672            .await
673            .unwrap()
674            .unwrap();
675
676        assert_eq!(lookup.terms_url, registration.terms_url);
677
678        // Can't set it once completed
679        let registration = repo
680            .user_registration()
681            .complete(&clock, registration)
682            .await
683            .unwrap();
684
685        let res = repo
686            .user_registration()
687            .set_terms_url(registration, "https://example.com/terms3".parse().unwrap())
688            .await;
689        assert!(res.is_err());
690    }
691
692    #[sqlx::test(migrator = "crate::MIGRATOR")]
693    async fn test_set_email_authentication(pool: PgPool) {
694        let mut rng = ChaChaRng::seed_from_u64(42);
695        let clock = MockClock::default();
696
697        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
698
699        let registration = repo
700            .user_registration()
701            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
702            .await
703            .unwrap();
704
705        assert_eq!(registration.email_authentication_id, None);
706
707        let authentication = repo
708            .user_email()
709            .add_authentication_for_registration(
710                &mut rng,
711                &clock,
712                "alice@example.com".to_owned(),
713                &registration,
714            )
715            .await
716            .unwrap();
717
718        let registration = repo
719            .user_registration()
720            .set_email_authentication(registration, &authentication)
721            .await
722            .unwrap();
723
724        assert_eq!(
725            registration.email_authentication_id,
726            Some(authentication.id)
727        );
728
729        let lookup = repo
730            .user_registration()
731            .lookup(registration.id)
732            .await
733            .unwrap()
734            .unwrap();
735
736        assert_eq!(
737            lookup.email_authentication_id,
738            registration.email_authentication_id
739        );
740
741        // Setting it again should work
742        let registration = repo
743            .user_registration()
744            .set_email_authentication(registration, &authentication)
745            .await
746            .unwrap();
747
748        assert_eq!(
749            registration.email_authentication_id,
750            Some(authentication.id)
751        );
752
753        let lookup = repo
754            .user_registration()
755            .lookup(registration.id)
756            .await
757            .unwrap()
758            .unwrap();
759
760        assert_eq!(
761            lookup.email_authentication_id,
762            registration.email_authentication_id
763        );
764
765        // Can't set it once completed
766        let registration = repo
767            .user_registration()
768            .complete(&clock, registration)
769            .await
770            .unwrap();
771
772        let res = repo
773            .user_registration()
774            .set_email_authentication(registration, &authentication)
775            .await;
776        assert!(res.is_err());
777    }
778
779    #[sqlx::test(migrator = "crate::MIGRATOR")]
780    async fn test_set_password(pool: PgPool) {
781        let mut rng = ChaChaRng::seed_from_u64(42);
782        let clock = MockClock::default();
783
784        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
785
786        let registration = repo
787            .user_registration()
788            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
789            .await
790            .unwrap();
791
792        assert_eq!(registration.password, None);
793
794        let registration = repo
795            .user_registration()
796            .set_password(registration, "fakehashedpassword".to_owned(), 1)
797            .await
798            .unwrap();
799
800        assert_eq!(
801            registration.password,
802            Some(UserRegistrationPassword {
803                hashed_password: "fakehashedpassword".to_owned(),
804                version: 1,
805            })
806        );
807
808        let lookup = repo
809            .user_registration()
810            .lookup(registration.id)
811            .await
812            .unwrap()
813            .unwrap();
814
815        assert_eq!(lookup.password, registration.password);
816
817        // Setting it again should work
818        let registration = repo
819            .user_registration()
820            .set_password(registration, "fakehashedpassword2".to_owned(), 2)
821            .await
822            .unwrap();
823
824        assert_eq!(
825            registration.password,
826            Some(UserRegistrationPassword {
827                hashed_password: "fakehashedpassword2".to_owned(),
828                version: 2,
829            })
830        );
831
832        let lookup = repo
833            .user_registration()
834            .lookup(registration.id)
835            .await
836            .unwrap()
837            .unwrap();
838
839        assert_eq!(lookup.password, registration.password);
840
841        // Can't set it once completed
842        let registration = repo
843            .user_registration()
844            .complete(&clock, registration)
845            .await
846            .unwrap();
847
848        let res = repo
849            .user_registration()
850            .set_password(registration, "fakehashedpassword3".to_owned(), 3)
851            .await;
852        assert!(res.is_err());
853    }
854}