1mod access_token;
11mod refresh_token;
12mod session;
13mod sso_login;
14
15pub use self::{
16 access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository,
17 session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository,
18};
19
/// Tests for the PostgreSQL implementations of the compatibility-layer
/// repositories: compat sessions, compat access tokens, compat refresh tokens
/// and compat SSO logins.
///
/// Every test receives a `PgPool` from `#[sqlx::test]`, with the schema set up
/// by `crate::MIGRATOR`.
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{Clock, Device, clock::MockClock};
    use mas_storage::{
        Pagination, RepositoryAccess,
        compat::{
            CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionFilter,
            CompatSessionRepository, CompatSsoLoginFilter,
        },
        user::UserRepository,
    };
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;
    use ulid::Ulid;

    use crate::PgRepository;

    /// Exercises `CompatSessionRepository`: adding a session, looking it up,
    /// recording a user agent, listing/counting with user, device, state
    /// (active/finished) and origin (SSO login/unknown) filters, finishing a
    /// single session, and bulk-finishing through a filter.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_session_repository(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // Create a user to own the sessions
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        let all = CompatSessionFilter::new().for_user(&user);
        let active = all.active_only();
        let finished = all.finished_only();
        let pagination = Pagination::first(10);

        // The user starts without any compat session
        assert_eq!(repo.compat_session().count(all).await.unwrap(), 0);
        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);

        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
        assert!(full_list.edges.is_empty());
        let active_list = repo
            .compat_session()
            .list(active, pagination)
            .await
            .unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo
            .compat_session()
            .list(finished, pagination)
            .await
            .unwrap();
        assert!(finished_list.edges.is_empty());

        // Start a compat session on a freshly generated device
        let device = Device::generate(&mut rng);
        let device_str = device.as_str().to_owned();
        let session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device.clone(), None, false, None)
            .await
            .unwrap();
        assert_eq!(session.user_id, user.id);
        assert_eq!(session.device.as_ref().unwrap().as_str(), device_str);
        assert!(session.is_valid());
        assert!(!session.is_finished());

        // The new session is counted and listed as active
        assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_session().count(active).await.unwrap(), 1);
        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);

        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(full_list.edges[0].0.id, session.id);
        let active_list = repo
            .compat_session()
            .list(active, pagination)
            .await
            .unwrap();
        assert_eq!(active_list.edges.len(), 1);
        assert_eq!(active_list.edges[0].0.id, session.id);
        let finished_list = repo
            .compat_session()
            .list(finished, pagination)
            .await
            .unwrap();
        assert!(finished_list.edges.is_empty());

        // Look the session up by ID
        let session_lookup = repo
            .compat_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("compat session not found");
        assert_eq!(session_lookup.id, session.id);
        assert_eq!(session_lookup.user_id, user.id);
        // Check the looked-up session (not the in-memory one returned by `add`)
        // so the device read back from the database is actually verified
        assert_eq!(session_lookup.device.as_ref().unwrap().as_str(), device_str);
        assert!(session_lookup.is_valid());
        assert!(!session_lookup.is_finished());

        // No user agent is present until one is recorded
        assert!(session_lookup.user_agent.is_none());
        let session = repo
            .compat_session()
            .record_user_agent(session_lookup, "Mozilla/5.0".to_owned())
            .await
            .unwrap();
        assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));

        // The recorded user agent is returned by a subsequent lookup
        let session_lookup = repo
            .compat_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("compat session not found");
        assert_eq!(session_lookup.user_agent.as_deref(), Some("Mozilla/5.0"));

        // Filter by device
        let list = repo
            .compat_session()
            .list(
                CompatSessionFilter::new()
                    .for_user(&user)
                    .for_device(&device),
                pagination,
            )
            .await
            .unwrap();
        assert_eq!(list.edges.len(), 1);
        let session_lookup = &list.edges[0].0;
        assert_eq!(session_lookup.id, session.id);
        assert_eq!(session_lookup.user_id, user.id);
        // Again, verify the device of the listed session itself
        assert_eq!(session_lookup.device.as_ref().unwrap().as_str(), device_str);
        assert!(session_lookup.is_valid());
        assert!(!session_lookup.is_finished());

        // Finish the session
        let session = repo.compat_session().finish(&clock, session).await.unwrap();
        assert!(!session.is_valid());
        assert!(session.is_finished());

        // It moves from the active set to the finished set
        assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 1);

        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(full_list.edges[0].0.id, session.id);
        let active_list = repo
            .compat_session()
            .list(active, pagination)
            .await
            .unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo
            .compat_session()
            .list(finished, pagination)
            .await
            .unwrap();
        assert_eq!(finished_list.edges.len(), 1);
        assert_eq!(finished_list.edges[0].0.id, session.id);

        // The finished state is returned by a subsequent lookup
        let session_lookup = repo
            .compat_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("compat session not found");
        assert!(!session_lookup.is_valid());
        assert!(session_lookup.is_finished());

        // This session was never tied to an SSO login, so the `unknown_only`
        // filter below is expected to return it
        let unknown_session = session;

        // Create an SSO login and drive it through fulfilled -> exchanged with a
        // second session, which should then match the `sso_login_only` filter
        let login = repo
            .compat_sso_login()
            .add(
                &mut rng,
                &clock,
                "login-token".to_owned(),
                "https://example.com/callback".parse().unwrap(),
            )
            .await
            .unwrap();
        assert!(login.is_pending());

        let browser_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        let device = Device::generate(&mut rng);
        let sso_login_session = repo
            .compat_session()
            .add(
                &mut rng,
                &clock,
                &user,
                device,
                Some(&browser_session),
                false,
                None,
            )
            .await
            .unwrap();

        let login = repo
            .compat_sso_login()
            .fulfill(&clock, login, &browser_session)
            .await
            .unwrap();
        assert!(login.is_fulfilled());
        let login = repo
            .compat_sso_login()
            .exchange(&clock, login, &sso_login_session)
            .await
            .unwrap();
        assert!(login.is_exchanged());

        // Filter by origin; `all` is still the per-user filter defined above
        let sso_login = all.sso_login_only();
        let unknown = all.unknown_only();
        assert_eq!(repo.compat_session().count(all).await.unwrap(), 2);
        assert_eq!(repo.compat_session().count(sso_login).await.unwrap(), 1);
        assert_eq!(repo.compat_session().count(unknown).await.unwrap(), 1);

        let list = repo
            .compat_session()
            .list(sso_login, pagination)
            .await
            .unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0].0.id, sso_login_session.id);
        let list = repo
            .compat_session()
            .list(unknown, pagination)
            .await
            .unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0].0.id, unknown_session.id);

        // Combine origin and state filters: the SSO login session is still
        // active, the unknown one was finished above
        assert_eq!(
            repo.compat_session()
                .count(all.sso_login_only().active_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.compat_session()
                .count(all.sso_login_only().finished_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.compat_session()
                .count(all.unknown_only().active_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.compat_session()
                .count(all.unknown_only().finished_only())
                .await
                .unwrap(),
            1
        );

        // Bulk-finish the active SSO login sessions; exactly one matches, after
        // which every session of the user is finished
        let affected = repo
            .compat_session()
            .finish_bulk(&clock, all.sso_login_only().active_only())
            .await
            .unwrap();
        assert_eq!(affected, 1);
        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 2);
        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
    }

    /// Exercises `CompatAccessTokenRepository`: adding tokens with and without
    /// an expiration, rejecting a duplicate token value, lookup by ID and by
    /// token value, clock-driven expiry, and explicit expiration.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_access_token_repository(pool: PgPool) {
        const FIRST_TOKEN: &str = "first_access_token";
        const SECOND_TOKEN: &str = "second_access_token";
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device, None, false, None)
            .await
            .unwrap();

        // Add an access token with a one-minute lifetime
        let token = repo
            .compat_access_token()
            .add(
                &mut rng,
                &clock,
                &session,
                FIRST_TOKEN.to_owned(),
                Some(Duration::try_minutes(1).unwrap()),
            )
            .await
            .unwrap();
        assert_eq!(token.session_id, session.id);
        assert_eq!(token.token, FIRST_TOKEN);

        repo.save().await.unwrap();

        // Adding the same token value again from a separate repository must
        // fail (presumably a uniqueness constraint on the token); the failed
        // attempt is then cancelled
        {
            let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
            assert!(
                repo.compat_access_token()
                    .add(
                        &mut rng,
                        &clock,
                        &session,
                        FIRST_TOKEN.to_owned(),
                        Some(Duration::try_minutes(1).unwrap()),
                    )
                    .await
                    .is_err()
            );
            repo.cancel().await.unwrap();
        }

        // Continue with a fresh repository
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Lookup by ID
        let token_lookup = repo
            .compat_access_token()
            .lookup(token.id)
            .await
            .unwrap()
            .expect("compat access token not found");
        assert_eq!(token.id, token_lookup.id);
        assert_eq!(token_lookup.session_id, session.id);

        // Lookup by token value
        let token_lookup = repo
            .compat_access_token()
            .find_by_token(FIRST_TOKEN)
            .await
            .unwrap()
            .expect("compat access token not found");
        assert_eq!(token.id, token_lookup.id);
        assert_eq!(token_lookup.session_id, session.id);

        // Valid now, invalid once the clock has moved forward by its lifetime
        assert!(token.is_valid(clock.now()));

        clock.advance(Duration::try_minutes(1).unwrap());
        assert!(!token.is_valid(clock.now()));

        // Add a second token without an expiration
        let token = repo
            .compat_access_token()
            .add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None)
            .await
            .unwrap();
        assert_eq!(token.session_id, session.id);
        assert_eq!(token.token, SECOND_TOKEN);

        assert!(token.is_valid(clock.now()));

        // Explicitly expire it
        repo.compat_access_token()
            .expire(&clock, token)
            .await
            .unwrap();

        // The stored token is no longer valid
        let token = repo
            .compat_access_token()
            .find_by_token(SECOND_TOKEN)
            .await
            .unwrap()
            .expect("compat access token not found");

        assert!(!token.is_valid(clock.now()));

        repo.save().await.unwrap();
    }

    /// Exercises `CompatRefreshTokenRepository`: adding a refresh token tied to
    /// a session and an access token, lookup by ID and by token value,
    /// consuming it, and rejecting a second consumption.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_refresh_token_repository(pool: PgPool) {
        const ACCESS_TOKEN: &str = "access_token";
        const REFRESH_TOKEN: &str = "refresh_token";
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device, None, false, None)
            .await
            .unwrap();

        // Add an access token the refresh token will point to
        let access_token = repo
            .compat_access_token()
            .add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None)
            .await
            .unwrap();

        let refresh_token = repo
            .compat_refresh_token()
            .add(
                &mut rng,
                &clock,
                &session,
                &access_token,
                REFRESH_TOKEN.to_owned(),
            )
            .await
            .unwrap();
        assert_eq!(refresh_token.session_id, session.id);
        assert_eq!(refresh_token.access_token_id, access_token.id);
        assert_eq!(refresh_token.token, REFRESH_TOKEN);
        assert!(refresh_token.is_valid());
        assert!(!refresh_token.is_consumed());

        // Lookup by ID
        let refresh_token_lookup = repo
            .compat_refresh_token()
            .lookup(refresh_token.id)
            .await
            .unwrap()
            .expect("refresh token not found");
        assert_eq!(refresh_token_lookup.id, refresh_token.id);
        assert_eq!(refresh_token_lookup.session_id, session.id);
        assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
        assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
        assert!(refresh_token_lookup.is_valid());
        assert!(!refresh_token_lookup.is_consumed());

        // Lookup by token value
        let refresh_token_lookup = repo
            .compat_refresh_token()
            .find_by_token(REFRESH_TOKEN)
            .await
            .unwrap()
            .expect("refresh token not found");
        assert_eq!(refresh_token_lookup.id, refresh_token.id);
        assert_eq!(refresh_token_lookup.session_id, session.id);
        assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
        assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
        assert!(refresh_token_lookup.is_valid());
        assert!(!refresh_token_lookup.is_consumed());

        // Consume the refresh token
        let refresh_token = repo
            .compat_refresh_token()
            .consume(&clock, refresh_token)
            .await
            .unwrap();
        assert!(!refresh_token.is_valid());
        assert!(refresh_token.is_consumed());

        // The consumed state is returned by a subsequent lookup
        let refresh_token_lookup = repo
            .compat_refresh_token()
            .find_by_token(REFRESH_TOKEN)
            .await
            .unwrap()
            .expect("refresh token not found");
        assert!(!refresh_token_lookup.is_valid());
        assert!(refresh_token_lookup.is_consumed());

        // Consuming an already-consumed refresh token must fail
        assert!(
            repo.compat_refresh_token()
                .consume(&clock, refresh_token)
                .await
                .is_err()
        );

        repo.save().await.unwrap();
    }

    /// Exercises the compat SSO login repository: adding a login, lookup by ID
    /// and by token, the pending -> fulfilled -> exchanged transitions (and the
    /// out-of-order transitions that must fail), and counting/listing with
    /// user and state filters.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_compat_sso_login_repository(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        // Looking up an unknown ID returns nothing
        let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap();
        assert_eq!(login, None);

        let all = CompatSsoLoginFilter::new();
        let for_user = all.for_user(&user);
        let pending = all.pending_only();
        let fulfilled = all.fulfilled_only();
        let exchanged = all.exchanged_only();

        // Nothing exists yet
        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);

        // Looking up an unknown token returns nothing
        let login = repo
            .compat_sso_login()
            .find_by_token("login-token")
            .await
            .unwrap();
        assert_eq!(login, None);

        // Add a login; it starts out pending
        let login = repo
            .compat_sso_login()
            .add(
                &mut rng,
                &clock,
                "login-token".to_owned(),
                "https://example.com/callback".parse().unwrap(),
            )
            .await
            .unwrap();
        assert!(login.is_pending());

        // A pending login is not yet matched by the per-user filter
        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);

        // Lookup by ID
        let login_lookup = repo
            .compat_sso_login()
            .lookup(login.id)
            .await
            .unwrap()
            .expect("login not found");
        assert_eq!(login_lookup, login);

        // Lookup by token
        let login_lookup = repo
            .compat_sso_login()
            .find_by_token("login-token")
            .await
            .unwrap()
            .expect("login not found");
        assert_eq!(login_lookup, login);

        // Start a compat session to exchange the login against
        let device = Device::generate(&mut rng);
        let compat_session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device, None, false, None)
            .await
            .unwrap();

        // A pending login cannot be exchanged before it is fulfilled
        let res = repo
            .compat_sso_login()
            .exchange(&clock, login.clone(), &compat_session)
            .await;
        assert!(res.is_err());

        // Fulfill the login with a browser session
        let browser_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        let login = repo
            .compat_sso_login()
            .fulfill(&clock, login, &browser_session)
            .await
            .unwrap();
        assert!(login.is_fulfilled());

        // Once fulfilled, the login is matched by the per-user filter
        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);

        // Fulfilling twice must fail
        let res = repo
            .compat_sso_login()
            .fulfill(&clock, login.clone(), &browser_session)
            .await;
        assert!(res.is_err());

        // Exchange the fulfilled login
        let login = repo
            .compat_sso_login()
            .exchange(&clock, login, &compat_session)
            .await
            .unwrap();
        assert!(login.is_exchanged());

        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 1);

        // Neither exchanging again nor fulfilling an exchanged login is allowed
        let res = repo
            .compat_sso_login()
            .exchange(&clock, login.clone(), &compat_session)
            .await;
        assert!(res.is_err());

        let res = repo
            .compat_sso_login()
            .fulfill(&clock, login.clone(), &browser_session)
            .await;
        assert!(res.is_err());

        // Listing with the various filters
        let pagination = Pagination::first(10);

        let logins = repo.compat_sso_login().list(all, pagination).await.unwrap();
        assert!(!logins.has_next_page);
        assert_eq!(logins.edges, vec![login.clone()]);

        let logins = repo
            .compat_sso_login()
            .list(for_user, pagination)
            .await
            .unwrap();
        assert!(!logins.has_next_page);
        assert_eq!(logins.edges, vec![login.clone()]);

        let logins = repo
            .compat_sso_login()
            .list(for_user.pending_only(), pagination)
            .await
            .unwrap();
        assert!(!logins.has_next_page);
        assert!(logins.edges.is_empty());

        let logins = repo
            .compat_sso_login()
            .list(for_user.fulfilled_only(), pagination)
            .await
            .unwrap();
        assert!(!logins.has_next_page);
        assert!(logins.edges.is_empty());

        let logins = repo
            .compat_sso_login()
            .list(for_user.exchanged_only(), pagination)
            .await
            .unwrap();
        assert!(!logins.has_next_page);
        assert_eq!(logins.edges, vec![login]);
    }
}