1use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11 Clock, UserEmailAuthentication, UserRegistration, UserRegistrationPassword,
12 UserRegistrationToken,
13};
14use mas_storage::user::UserRegistrationRepository;
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use url::Url;
19use uuid::Uuid;
20
21use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt as _};
22
23pub struct PgUserRegistrationRepository<'c> {
26 conn: &'c mut PgConnection,
27}
28
29impl<'c> PgUserRegistrationRepository<'c> {
30 pub fn new(conn: &'c mut PgConnection) -> Self {
33 Self { conn }
34 }
35}
36
/// Raw row read from the `user_registrations` table by the lookup query.
///
/// Field names mirror the selected column names so that `sqlx::query_as!`
/// can map the row directly. The row is turned into a [`UserRegistration`]
/// through its `TryFrom` implementation, which validates the columns.
struct UserRegistrationLookup {
    user_registration_id: Uuid,
    // Decoded through the `"ip_address: IpAddr"` type override in the query
    ip_address: Option<IpAddr>,
    user_agent: Option<String>,
    post_auth_action: Option<serde_json::Value>,
    username: String,
    display_name: Option<String>,
    // Stored as text; parsed into a `Url` during conversion, and an
    // unparsable value is reported as a database inconsistency
    terms_url: Option<String>,
    email_authentication_id: Option<Uuid>,
    user_registration_token_id: Option<Uuid>,
    // `hashed_password` and `hashed_password_version` are expected to be
    // either both set or both null; the conversion rejects any other mix
    hashed_password: Option<String>,
    // Must fit in a `u16` during conversion
    hashed_password_version: Option<i32>,
    created_at: DateTime<Utc>,
    // Stays null until the registration is completed; the update queries
    // only touch rows where this is still null
    completed_at: Option<DateTime<Utc>>,
}
52
53impl TryFrom<UserRegistrationLookup> for UserRegistration {
54 type Error = DatabaseInconsistencyError;
55
56 fn try_from(value: UserRegistrationLookup) -> Result<Self, Self::Error> {
57 let id = Ulid::from(value.user_registration_id);
58
59 let password = match (value.hashed_password, value.hashed_password_version) {
60 (Some(hashed_password), Some(version)) => {
61 let version = version.try_into().map_err(|e| {
62 DatabaseInconsistencyError::on("user_registrations")
63 .column("hashed_password_version")
64 .row(id)
65 .source(e)
66 })?;
67
68 Some(UserRegistrationPassword {
69 hashed_password,
70 version,
71 })
72 }
73 (None, None) => None,
74 _ => {
75 return Err(DatabaseInconsistencyError::on("user_registrations")
76 .column("hashed_password")
77 .row(id));
78 }
79 };
80
81 let terms_url = value
82 .terms_url
83 .map(|u| u.parse())
84 .transpose()
85 .map_err(|e| {
86 DatabaseInconsistencyError::on("user_registrations")
87 .column("terms_url")
88 .row(id)
89 .source(e)
90 })?;
91
92 Ok(UserRegistration {
93 id,
94 ip_address: value.ip_address,
95 user_agent: value.user_agent,
96 post_auth_action: value.post_auth_action,
97 username: value.username,
98 display_name: value.display_name,
99 terms_url,
100 email_authentication_id: value.email_authentication_id.map(Ulid::from),
101 user_registration_token_id: value.user_registration_token_id.map(Ulid::from),
102 password,
103 created_at: value.created_at,
104 completed_at: value.completed_at,
105 })
106 }
107}
108
109#[async_trait]
110impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
111 type Error = DatabaseError;
112
113 #[tracing::instrument(
114 name = "db.user_registration.lookup",
115 skip_all,
116 fields(
117 db.query.text,
118 user_registration.id = %id,
119 ),
120 err,
121 )]
122 async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistration>, Self::Error> {
123 let res = sqlx::query_as!(
124 UserRegistrationLookup,
125 r#"
126 SELECT user_registration_id
127 , ip_address as "ip_address: IpAddr"
128 , user_agent
129 , post_auth_action
130 , username
131 , display_name
132 , terms_url
133 , email_authentication_id
134 , user_registration_token_id
135 , hashed_password
136 , hashed_password_version
137 , created_at
138 , completed_at
139 FROM user_registrations
140 WHERE user_registration_id = $1
141 "#,
142 Uuid::from(id),
143 )
144 .traced()
145 .fetch_optional(&mut *self.conn)
146 .await?;
147
148 let Some(res) = res else { return Ok(None) };
149
150 Ok(Some(res.try_into()?))
151 }
152
153 #[tracing::instrument(
154 name = "db.user_registration.add",
155 skip_all,
156 fields(
157 db.query.text,
158 user_registration.id,
159 ),
160 err,
161 )]
162 async fn add(
163 &mut self,
164 rng: &mut (dyn RngCore + Send),
165 clock: &dyn Clock,
166 username: String,
167 ip_address: Option<IpAddr>,
168 user_agent: Option<String>,
169 post_auth_action: Option<serde_json::Value>,
170 ) -> Result<UserRegistration, Self::Error> {
171 let created_at = clock.now();
172 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
173 tracing::Span::current().record("user_registration.id", tracing::field::display(id));
174
175 sqlx::query!(
176 r#"
177 INSERT INTO user_registrations
178 ( user_registration_id
179 , ip_address
180 , user_agent
181 , post_auth_action
182 , username
183 , created_at
184 )
185 VALUES ($1, $2, $3, $4, $5, $6)
186 "#,
187 Uuid::from(id),
188 ip_address as Option<IpAddr>,
189 user_agent.as_deref(),
190 post_auth_action,
191 username,
192 created_at,
193 )
194 .traced()
195 .execute(&mut *self.conn)
196 .await?;
197
198 Ok(UserRegistration {
199 id,
200 ip_address,
201 user_agent,
202 post_auth_action,
203 created_at,
204 completed_at: None,
205 username,
206 display_name: None,
207 terms_url: None,
208 email_authentication_id: None,
209 user_registration_token_id: None,
210 password: None,
211 })
212 }
213
214 #[tracing::instrument(
215 name = "db.user_registration.set_display_name",
216 skip_all,
217 fields(
218 db.query.text,
219 user_registration.id = %user_registration.id,
220 user_registration.display_name = display_name,
221 ),
222 err,
223 )]
224 async fn set_display_name(
225 &mut self,
226 mut user_registration: UserRegistration,
227 display_name: String,
228 ) -> Result<UserRegistration, Self::Error> {
229 let res = sqlx::query!(
230 r#"
231 UPDATE user_registrations
232 SET display_name = $2
233 WHERE user_registration_id = $1 AND completed_at IS NULL
234 "#,
235 Uuid::from(user_registration.id),
236 display_name,
237 )
238 .traced()
239 .execute(&mut *self.conn)
240 .await?;
241
242 DatabaseError::ensure_affected_rows(&res, 1)?;
243
244 user_registration.display_name = Some(display_name);
245
246 Ok(user_registration)
247 }
248
249 #[tracing::instrument(
250 name = "db.user_registration.set_terms_url",
251 skip_all,
252 fields(
253 db.query.text,
254 user_registration.id = %user_registration.id,
255 user_registration.terms_url = %terms_url,
256 ),
257 err,
258 )]
259 async fn set_terms_url(
260 &mut self,
261 mut user_registration: UserRegistration,
262 terms_url: Url,
263 ) -> Result<UserRegistration, Self::Error> {
264 let res = sqlx::query!(
265 r#"
266 UPDATE user_registrations
267 SET terms_url = $2
268 WHERE user_registration_id = $1 AND completed_at IS NULL
269 "#,
270 Uuid::from(user_registration.id),
271 terms_url.as_str(),
272 )
273 .traced()
274 .execute(&mut *self.conn)
275 .await?;
276
277 DatabaseError::ensure_affected_rows(&res, 1)?;
278
279 user_registration.terms_url = Some(terms_url);
280
281 Ok(user_registration)
282 }
283
284 #[tracing::instrument(
285 name = "db.user_registration.set_email_authentication",
286 skip_all,
287 fields(
288 db.query.text,
289 %user_registration.id,
290 %user_email_authentication.id,
291 %user_email_authentication.email,
292 ),
293 err,
294 )]
295 async fn set_email_authentication(
296 &mut self,
297 mut user_registration: UserRegistration,
298 user_email_authentication: &UserEmailAuthentication,
299 ) -> Result<UserRegistration, Self::Error> {
300 let res = sqlx::query!(
301 r#"
302 UPDATE user_registrations
303 SET email_authentication_id = $2
304 WHERE user_registration_id = $1 AND completed_at IS NULL
305 "#,
306 Uuid::from(user_registration.id),
307 Uuid::from(user_email_authentication.id),
308 )
309 .traced()
310 .execute(&mut *self.conn)
311 .await?;
312
313 DatabaseError::ensure_affected_rows(&res, 1)?;
314
315 user_registration.email_authentication_id = Some(user_email_authentication.id);
316
317 Ok(user_registration)
318 }
319
320 #[tracing::instrument(
321 name = "db.user_registration.set_password",
322 skip_all,
323 fields(
324 db.query.text,
325 user_registration.id = %user_registration.id,
326 user_registration.hashed_password = hashed_password,
327 user_registration.hashed_password_version = version,
328 ),
329 err,
330 )]
331 async fn set_password(
332 &mut self,
333 mut user_registration: UserRegistration,
334 hashed_password: String,
335 version: u16,
336 ) -> Result<UserRegistration, Self::Error> {
337 let res = sqlx::query!(
338 r#"
339 UPDATE user_registrations
340 SET hashed_password = $2, hashed_password_version = $3
341 WHERE user_registration_id = $1 AND completed_at IS NULL
342 "#,
343 Uuid::from(user_registration.id),
344 hashed_password,
345 i32::from(version),
346 )
347 .traced()
348 .execute(&mut *self.conn)
349 .await?;
350
351 DatabaseError::ensure_affected_rows(&res, 1)?;
352
353 user_registration.password = Some(UserRegistrationPassword {
354 hashed_password,
355 version,
356 });
357
358 Ok(user_registration)
359 }
360
361 #[tracing::instrument(
362 name = "db.user_registration.set_registration_token",
363 skip_all,
364 fields(
365 db.query.text,
366 %user_registration.id,
367 %user_registration_token.id,
368 ),
369 err,
370 )]
371 async fn set_registration_token(
372 &mut self,
373 mut user_registration: UserRegistration,
374 user_registration_token: &UserRegistrationToken,
375 ) -> Result<UserRegistration, Self::Error> {
376 let res = sqlx::query!(
377 r#"
378 UPDATE user_registrations
379 SET user_registration_token_id = $2
380 WHERE user_registration_id = $1 AND completed_at IS NULL
381 "#,
382 Uuid::from(user_registration.id),
383 Uuid::from(user_registration_token.id),
384 )
385 .traced()
386 .execute(&mut *self.conn)
387 .await?;
388
389 DatabaseError::ensure_affected_rows(&res, 1)?;
390
391 user_registration.user_registration_token_id = Some(user_registration_token.id);
392
393 Ok(user_registration)
394 }
395
396 #[tracing::instrument(
397 name = "db.user_registration.complete",
398 skip_all,
399 fields(
400 db.query.text,
401 user_registration.id = %user_registration.id,
402 ),
403 err,
404 )]
405 async fn complete(
406 &mut self,
407 clock: &dyn Clock,
408 mut user_registration: UserRegistration,
409 ) -> Result<UserRegistration, Self::Error> {
410 let completed_at = clock.now();
411 let res = sqlx::query!(
412 r#"
413 UPDATE user_registrations
414 SET completed_at = $2
415 WHERE user_registration_id = $1 AND completed_at IS NULL
416 "#,
417 Uuid::from(user_registration.id),
418 completed_at,
419 )
420 .traced()
421 .execute(&mut *self.conn)
422 .await?;
423
424 DatabaseError::ensure_affected_rows(&res, 1)?;
425
426 user_registration.completed_at = Some(completed_at);
427
428 Ok(user_registration)
429 }
430}
431
432#[cfg(test)]
433mod tests {
434 use std::net::{IpAddr, Ipv4Addr};
435
436 use mas_data_model::{Clock, UserRegistrationPassword, clock::MockClock};
437 use rand::SeedableRng;
438 use rand_chacha::ChaChaRng;
439 use sqlx::PgPool;
440
441 use crate::PgRepository;
442
443 #[sqlx::test(migrator = "crate::MIGRATOR")]
444 async fn test_create_lookup_complete(pool: PgPool) {
445 let mut rng = ChaChaRng::seed_from_u64(42);
446 let clock = MockClock::default();
447
448 let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
449
450 let registration = repo
451 .user_registration()
452 .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
453 .await
454 .unwrap();
455
456 assert_eq!(registration.created_at, clock.now());
457 assert_eq!(registration.completed_at, None);
458 assert_eq!(registration.username, "alice");
459 assert_eq!(registration.display_name, None);
460 assert_eq!(registration.terms_url, None);
461 assert_eq!(registration.email_authentication_id, None);
462 assert_eq!(registration.password, None);
463 assert_eq!(registration.user_agent, None);
464 assert_eq!(registration.ip_address, None);
465 assert_eq!(registration.post_auth_action, None);
466
467 let lookup = repo
468 .user_registration()
469 .lookup(registration.id)
470 .await
471 .unwrap()
472 .unwrap();
473
474 assert_eq!(lookup.id, registration.id);
475 assert_eq!(lookup.created_at, registration.created_at);
476 assert_eq!(lookup.completed_at, registration.completed_at);
477 assert_eq!(lookup.username, registration.username);
478 assert_eq!(lookup.display_name, registration.display_name);
479 assert_eq!(lookup.terms_url, registration.terms_url);
480 assert_eq!(
481 lookup.email_authentication_id,
482 registration.email_authentication_id
483 );
484 assert_eq!(lookup.password, registration.password);
485 assert_eq!(lookup.user_agent, registration.user_agent);
486 assert_eq!(lookup.ip_address, registration.ip_address);
487 assert_eq!(lookup.post_auth_action, registration.post_auth_action);
488
489 let registration = repo
491 .user_registration()
492 .complete(&clock, registration)
493 .await
494 .unwrap();
495 assert_eq!(registration.completed_at, Some(clock.now()));
496
497 let lookup = repo
499 .user_registration()
500 .lookup(registration.id)
501 .await
502 .unwrap()
503 .unwrap();
504 assert_eq!(lookup.completed_at, registration.completed_at);
505
506 let res = repo
508 .user_registration()
509 .complete(&clock, registration)
510 .await;
511 assert!(res.is_err());
512 }
513
514 #[sqlx::test(migrator = "crate::MIGRATOR")]
515 async fn test_create_useragent_ipaddress(pool: PgPool) {
516 let mut rng = ChaChaRng::seed_from_u64(42);
517 let clock = MockClock::default();
518
519 let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
520
521 let registration = repo
522 .user_registration()
523 .add(
524 &mut rng,
525 &clock,
526 "alice".to_owned(),
527 Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
528 Some("Mozilla/5.0".to_owned()),
529 Some(serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})),
530 )
531 .await
532 .unwrap();
533
534 assert_eq!(registration.user_agent, Some("Mozilla/5.0".to_owned()));
535 assert_eq!(
536 registration.ip_address,
537 Some(IpAddr::V4(Ipv4Addr::LOCALHOST))
538 );
539 assert_eq!(
540 registration.post_auth_action,
541 Some(
542 serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})
543 )
544 );
545
546 let lookup = repo
547 .user_registration()
548 .lookup(registration.id)
549 .await
550 .unwrap()
551 .unwrap();
552
553 assert_eq!(lookup.user_agent, registration.user_agent);
554 assert_eq!(lookup.ip_address, registration.ip_address);
555 assert_eq!(lookup.post_auth_action, registration.post_auth_action);
556 }
557
558 #[sqlx::test(migrator = "crate::MIGRATOR")]
559 async fn test_set_display_name(pool: PgPool) {
560 let mut rng = ChaChaRng::seed_from_u64(42);
561 let clock = MockClock::default();
562
563 let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
564
565 let registration = repo
566 .user_registration()
567 .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
568 .await
569 .unwrap();
570
571 assert_eq!(registration.display_name, None);
572
573 let registration = repo
574 .user_registration()
575 .set_display_name(registration, "Alice".to_owned())
576 .await
577 .unwrap();
578
579 assert_eq!(registration.display_name, Some("Alice".to_owned()));
580
581 let lookup = repo
582 .user_registration()
583 .lookup(registration.id)
584 .await
585 .unwrap()
586 .unwrap();
587
588 assert_eq!(lookup.display_name, registration.display_name);
589
590 let registration = repo
592 .user_registration()
593 .set_display_name(registration, "Bob".to_owned())
594 .await
595 .unwrap();
596
597 assert_eq!(registration.display_name, Some("Bob".to_owned()));
598
599 let lookup = repo
600 .user_registration()
601 .lookup(registration.id)
602 .await
603 .unwrap()
604 .unwrap();
605
606 assert_eq!(lookup.display_name, registration.display_name);
607
608 let registration = repo
610 .user_registration()
611 .complete(&clock, registration)
612 .await
613 .unwrap();
614
615 let res = repo
616 .user_registration()
617 .set_display_name(registration, "Charlie".to_owned())
618 .await;
619 assert!(res.is_err());
620 }
621
622 #[sqlx::test(migrator = "crate::MIGRATOR")]
623 async fn test_set_terms_url(pool: PgPool) {
624 let mut rng = ChaChaRng::seed_from_u64(42);
625 let clock = MockClock::default();
626
627 let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
628
629 let registration = repo
630 .user_registration()
631 .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
632 .await
633 .unwrap();
634
635 assert_eq!(registration.terms_url, None);
636
637 let registration = repo
638 .user_registration()
639 .set_terms_url(registration, "https://example.com/terms".parse().unwrap())
640 .await
641 .unwrap();
642
643 assert_eq!(
644 registration.terms_url,
645 Some("https://example.com/terms".parse().unwrap())
646 );
647
648 let lookup = repo
649 .user_registration()
650 .lookup(registration.id)
651 .await
652 .unwrap()
653 .unwrap();
654
655 assert_eq!(lookup.terms_url, registration.terms_url);
656
657 let registration = repo
659 .user_registration()
660 .set_terms_url(registration, "https://example.com/terms2".parse().unwrap())
661 .await
662 .unwrap();
663
664 assert_eq!(
665 registration.terms_url,
666 Some("https://example.com/terms2".parse().unwrap())
667 );
668
669 let lookup = repo
670 .user_registration()
671 .lookup(registration.id)
672 .await
673 .unwrap()
674 .unwrap();
675
676 assert_eq!(lookup.terms_url, registration.terms_url);
677
678 let registration = repo
680 .user_registration()
681 .complete(&clock, registration)
682 .await
683 .unwrap();
684
685 let res = repo
686 .user_registration()
687 .set_terms_url(registration, "https://example.com/terms3".parse().unwrap())
688 .await;
689 assert!(res.is_err());
690 }
691
692 #[sqlx::test(migrator = "crate::MIGRATOR")]
693 async fn test_set_email_authentication(pool: PgPool) {
694 let mut rng = ChaChaRng::seed_from_u64(42);
695 let clock = MockClock::default();
696
697 let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
698
699 let registration = repo
700 .user_registration()
701 .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
702 .await
703 .unwrap();
704
705 assert_eq!(registration.email_authentication_id, None);
706
707 let authentication = repo
708 .user_email()
709 .add_authentication_for_registration(
710 &mut rng,
711 &clock,
712 "alice@example.com".to_owned(),
713 ®istration,
714 )
715 .await
716 .unwrap();
717
718 let registration = repo
719 .user_registration()
720 .set_email_authentication(registration, &authentication)
721 .await
722 .unwrap();
723
724 assert_eq!(
725 registration.email_authentication_id,
726 Some(authentication.id)
727 );
728
729 let lookup = repo
730 .user_registration()
731 .lookup(registration.id)
732 .await
733 .unwrap()
734 .unwrap();
735
736 assert_eq!(
737 lookup.email_authentication_id,
738 registration.email_authentication_id
739 );
740
741 let registration = repo
743 .user_registration()
744 .set_email_authentication(registration, &authentication)
745 .await
746 .unwrap();
747
748 assert_eq!(
749 registration.email_authentication_id,
750 Some(authentication.id)
751 );
752
753 let lookup = repo
754 .user_registration()
755 .lookup(registration.id)
756 .await
757 .unwrap()
758 .unwrap();
759
760 assert_eq!(
761 lookup.email_authentication_id,
762 registration.email_authentication_id
763 );
764
765 let registration = repo
767 .user_registration()
768 .complete(&clock, registration)
769 .await
770 .unwrap();
771
772 let res = repo
773 .user_registration()
774 .set_email_authentication(registration, &authentication)
775 .await;
776 assert!(res.is_err());
777 }
778
779 #[sqlx::test(migrator = "crate::MIGRATOR")]
780 async fn test_set_password(pool: PgPool) {
781 let mut rng = ChaChaRng::seed_from_u64(42);
782 let clock = MockClock::default();
783
784 let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
785
786 let registration = repo
787 .user_registration()
788 .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
789 .await
790 .unwrap();
791
792 assert_eq!(registration.password, None);
793
794 let registration = repo
795 .user_registration()
796 .set_password(registration, "fakehashedpassword".to_owned(), 1)
797 .await
798 .unwrap();
799
800 assert_eq!(
801 registration.password,
802 Some(UserRegistrationPassword {
803 hashed_password: "fakehashedpassword".to_owned(),
804 version: 1,
805 })
806 );
807
808 let lookup = repo
809 .user_registration()
810 .lookup(registration.id)
811 .await
812 .unwrap()
813 .unwrap();
814
815 assert_eq!(lookup.password, registration.password);
816
817 let registration = repo
819 .user_registration()
820 .set_password(registration, "fakehashedpassword2".to_owned(), 2)
821 .await
822 .unwrap();
823
824 assert_eq!(
825 registration.password,
826 Some(UserRegistrationPassword {
827 hashed_password: "fakehashedpassword2".to_owned(),
828 version: 2,
829 })
830 );
831
832 let lookup = repo
833 .user_registration()
834 .lookup(registration.id)
835 .await
836 .unwrap()
837 .unwrap();
838
839 assert_eq!(lookup.password, registration.password);
840
841 let registration = repo
843 .user_registration()
844 .complete(&clock, registration)
845 .await
846 .unwrap();
847
848 let res = repo
849 .user_registration()
850 .set_password(registration, "fakehashedpassword3".to_owned(), 3)
851 .await;
852 assert!(res.is_err());
853 }
854}