mas_storage_pg/compat/
mod.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

7//! A module containing PostgreSQL implementation of repositories for the
8//! compatibility layer
9
10mod access_token;
11mod refresh_token;
12mod session;
13mod sso_login;
14
15pub use self::{
16    access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository,
17    session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository,
18};
19
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{Clock, Device, clock::MockClock};
    use mas_storage::{
        Pagination, RepositoryAccess,
        compat::{
            CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionFilter,
            CompatSessionRepository, CompatSsoLoginFilter,
        },
        user::UserRepository,
    };
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;
    use ulid::Ulid;

    use crate::PgRepository;

    /// Exercises the compat session repository: creation, lookup by ID and by
    /// device, user-agent recording, finishing (single and bulk), and the
    /// active/finished and SSO-login/unknown count and list filters.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_session_repository(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        let all = CompatSessionFilter::new().for_user(&user);
        let active = all.active_only();
        let finished = all.finished_only();
        let pagination = Pagination::first(10);

        assert_eq!(repo.compat_session().count(all).await.unwrap(), 0);
        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);

        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
        assert!(full_list.edges.is_empty());
        let active_list = repo
            .compat_session()
            .list(active, pagination)
            .await
            .unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo
            .compat_session()
            .list(finished, pagination)
            .await
            .unwrap();
        assert!(finished_list.edges.is_empty());

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let device_str = device.as_str().to_owned();
        let session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device.clone(), None, false, None)
            .await
            .unwrap();
        assert_eq!(session.user_id, user.id);
        assert_eq!(session.device.as_ref().unwrap().as_str(), device_str);
        assert!(session.is_valid());
        assert!(!session.is_finished());

        assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_session().count(active).await.unwrap(), 1);
        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);

        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(full_list.edges[0].0.id, session.id);
        let active_list = repo
            .compat_session()
            .list(active, pagination)
            .await
            .unwrap();
        assert_eq!(active_list.edges.len(), 1);
        assert_eq!(active_list.edges[0].0.id, session.id);
        let finished_list = repo
            .compat_session()
            .list(finished, pagination)
            .await
            .unwrap();
        assert!(finished_list.edges.is_empty());

        // Lookup the session and check it didn't change
        let session_lookup = repo
            .compat_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("compat session not found");
        assert_eq!(session_lookup.id, session.id);
        assert_eq!(session_lookup.user_id, user.id);
        // Check the device on the *loaded* session, not the in-memory one, so
        // that the persisted device is actually verified
        assert_eq!(session_lookup.device.as_ref().unwrap().as_str(), device_str);
        assert!(session_lookup.is_valid());
        assert!(!session_lookup.is_finished());

        // Record a user-agent for the session
        assert!(session_lookup.user_agent.is_none());
        let session = repo
            .compat_session()
            .record_user_agent(session_lookup, "Mozilla/5.0".to_owned())
            .await
            .unwrap();
        assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));

        // Reload the session and check again
        let session_lookup = repo
            .compat_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("compat session not found");
        assert_eq!(session_lookup.user_agent.as_deref(), Some("Mozilla/5.0"));

        // Look up the session by device
        let list = repo
            .compat_session()
            .list(
                CompatSessionFilter::new()
                    .for_user(&user)
                    .for_device(&device),
                pagination,
            )
            .await
            .unwrap();
        assert_eq!(list.edges.len(), 1);
        let session_lookup = &list.edges[0].0;
        assert_eq!(session_lookup.id, session.id);
        assert_eq!(session_lookup.user_id, user.id);
        // Same as above: verify the device of the session returned by the list
        assert_eq!(session_lookup.device.as_ref().unwrap().as_str(), device_str);
        assert!(session_lookup.is_valid());
        assert!(!session_lookup.is_finished());

        // Finish the session
        let session = repo.compat_session().finish(&clock, session).await.unwrap();
        assert!(!session.is_valid());
        assert!(session.is_finished());

        assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 1);

        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(full_list.edges[0].0.id, session.id);
        let active_list = repo
            .compat_session()
            .list(active, pagination)
            .await
            .unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo
            .compat_session()
            .list(finished, pagination)
            .await
            .unwrap();
        assert_eq!(finished_list.edges.len(), 1);
        assert_eq!(finished_list.edges[0].0.id, session.id);

        // Reload the session and check again
        let session_lookup = repo
            .compat_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("compat session not found");
        assert!(!session_lookup.is_valid());
        assert!(session_lookup.is_finished());

        // Now add another session, with an SSO login this time
        let unknown_session = session;
        // Start a new SSO login
        let login = repo
            .compat_sso_login()
            .add(
                &mut rng,
                &clock,
                "login-token".to_owned(),
                "https://example.com/callback".parse().unwrap(),
            )
            .await
            .unwrap();
        assert!(login.is_pending());

        // Start a browser session for the user
        let browser_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let sso_login_session = repo
            .compat_session()
            .add(
                &mut rng,
                &clock,
                &user,
                device,
                Some(&browser_session),
                false,
                None,
            )
            .await
            .unwrap();

        // Associate the login with the session
        let login = repo
            .compat_sso_login()
            .fulfill(&clock, login, &browser_session)
            .await
            .unwrap();
        assert!(login.is_fulfilled());
        let login = repo
            .compat_sso_login()
            .exchange(&clock, login, &sso_login_session)
            .await
            .unwrap();
        assert!(login.is_exchanged());

        // Now query the session list with both the unknown and SSO login session type
        // filter
        let all = CompatSessionFilter::new().for_user(&user);
        let sso_login = all.sso_login_only();
        let unknown = all.unknown_only();
        assert_eq!(repo.compat_session().count(all).await.unwrap(), 2);
        assert_eq!(repo.compat_session().count(sso_login).await.unwrap(), 1);
        assert_eq!(repo.compat_session().count(unknown).await.unwrap(), 1);

        let list = repo
            .compat_session()
            .list(sso_login, pagination)
            .await
            .unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0].0.id, sso_login_session.id);
        let list = repo
            .compat_session()
            .list(unknown, pagination)
            .await
            .unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0].0.id, unknown_session.id);

        // Check that combining the two filters works
        // At this point, there is one active SSO login session and one finished unknown
        // session
        assert_eq!(
            repo.compat_session()
                .count(all.sso_login_only().active_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.compat_session()
                .count(all.sso_login_only().finished_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.compat_session()
                .count(all.unknown_only().active_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.compat_session()
                .count(all.unknown_only().finished_only())
                .await
                .unwrap(),
            1
        );

        // Check that we can batch finish sessions
        let affected = repo
            .compat_session()
            .finish_bulk(&clock, all.sso_login_only().active_only())
            .await
            .unwrap();
        assert_eq!(affected, 1);
        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 2);
        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
    }

    /// Exercises the compat access token repository: creation, uniqueness of
    /// the token value across transactions, lookup by ID and by value, and
    /// expiration (both time-based and explicit).
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_access_token_repository(pool: PgPool) {
        const FIRST_TOKEN: &str = "first_access_token";
        const SECOND_TOKEN: &str = "second_access_token";
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device, None, false, None)
            .await
            .unwrap();

        // Add an access token to that session
        let token = repo
            .compat_access_token()
            .add(
                &mut rng,
                &clock,
                &session,
                FIRST_TOKEN.to_owned(),
                Some(Duration::try_minutes(1).unwrap()),
            )
            .await
            .unwrap();
        assert_eq!(token.session_id, session.id);
        assert_eq!(token.token, FIRST_TOKEN);

        // Commit the txn and grab a new transaction, to test a conflict
        repo.save().await.unwrap();

        {
            let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
            // Adding the same token a second time should conflict
            assert!(
                repo.compat_access_token()
                    .add(
                        &mut rng,
                        &clock,
                        &session,
                        FIRST_TOKEN.to_owned(),
                        Some(Duration::try_minutes(1).unwrap()),
                    )
                    .await
                    .is_err()
            );
            repo.cancel().await.unwrap();
        }

        // Grab a new repo
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Looking up via ID works
        let token_lookup = repo
            .compat_access_token()
            .lookup(token.id)
            .await
            .unwrap()
            .expect("compat access token not found");
        assert_eq!(token.id, token_lookup.id);
        assert_eq!(token_lookup.session_id, session.id);

        // Looking up via the token value works
        let token_lookup = repo
            .compat_access_token()
            .find_by_token(FIRST_TOKEN)
            .await
            .unwrap()
            .expect("compat access token not found");
        assert_eq!(token.id, token_lookup.id);
        assert_eq!(token_lookup.session_id, session.id);

        // Token is currently valid
        assert!(token.is_valid(clock.now()));

        clock.advance(Duration::try_minutes(1).unwrap());
        // Token should have expired
        assert!(!token.is_valid(clock.now()));

        // Add a second access token, this time without expiration
        let token = repo
            .compat_access_token()
            .add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None)
            .await
            .unwrap();
        assert_eq!(token.session_id, session.id);
        assert_eq!(token.token, SECOND_TOKEN);

        // Token is currently valid
        assert!(token.is_valid(clock.now()));

        // Make it expire
        repo.compat_access_token()
            .expire(&clock, token)
            .await
            .unwrap();

        // Reload it
        let token = repo
            .compat_access_token()
            .find_by_token(SECOND_TOKEN)
            .await
            .unwrap()
            .expect("compat access token not found");

        // Token is not valid anymore
        assert!(!token.is_valid(clock.now()));

        repo.save().await.unwrap();
    }

    /// Exercises the compat refresh token repository: creation, lookup by ID
    /// and by value, and consumption (including rejecting a second
    /// consumption).
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_refresh_token_repository(pool: PgPool) {
        const ACCESS_TOKEN: &str = "access_token";
        const REFRESH_TOKEN: &str = "refresh_token";
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device, None, false, None)
            .await
            .unwrap();

        // Add an access token to that session
        let access_token = repo
            .compat_access_token()
            .add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None)
            .await
            .unwrap();

        let refresh_token = repo
            .compat_refresh_token()
            .add(
                &mut rng,
                &clock,
                &session,
                &access_token,
                REFRESH_TOKEN.to_owned(),
            )
            .await
            .unwrap();
        assert_eq!(refresh_token.session_id, session.id);
        assert_eq!(refresh_token.access_token_id, access_token.id);
        assert_eq!(refresh_token.token, REFRESH_TOKEN);
        assert!(refresh_token.is_valid());
        assert!(!refresh_token.is_consumed());

        // Look it up by ID and check everything matches
        let refresh_token_lookup = repo
            .compat_refresh_token()
            .lookup(refresh_token.id)
            .await
            .unwrap()
            .expect("refresh token not found");
        assert_eq!(refresh_token_lookup.id, refresh_token.id);
        assert_eq!(refresh_token_lookup.session_id, session.id);
        assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
        assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
        assert!(refresh_token_lookup.is_valid());
        assert!(!refresh_token_lookup.is_consumed());

        // Look it up by token and check everything matches
        let refresh_token_lookup = repo
            .compat_refresh_token()
            .find_by_token(REFRESH_TOKEN)
            .await
            .unwrap()
            .expect("refresh token not found");
        assert_eq!(refresh_token_lookup.id, refresh_token.id);
        assert_eq!(refresh_token_lookup.session_id, session.id);
        assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
        assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
        assert!(refresh_token_lookup.is_valid());
        assert!(!refresh_token_lookup.is_consumed());

        // Consume it
        let refresh_token = repo
            .compat_refresh_token()
            .consume(&clock, refresh_token)
            .await
            .unwrap();
        assert!(!refresh_token.is_valid());
        assert!(refresh_token.is_consumed());

        // Reload it and check again
        let refresh_token_lookup = repo
            .compat_refresh_token()
            .find_by_token(REFRESH_TOKEN)
            .await
            .unwrap()
            .expect("refresh token not found");
        assert!(!refresh_token_lookup.is_valid());
        assert!(refresh_token_lookup.is_consumed());

        // Consuming it again should not work
        assert!(
            repo.compat_refresh_token()
                .consume(&clock, refresh_token)
                .await
                .is_err()
        );

        repo.save().await.unwrap();
    }

    /// Exercises the compat SSO login repository: creation, lookup by ID and
    /// by token, the pending -> fulfilled -> exchanged state transitions
    /// (rejecting out-of-order transitions without poisoning the
    /// transaction), and the count/list filters.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_compat_sso_login_repository(pool: PgPool) {
        const LOGIN_TOKEN: &str = "login-token";
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        // Lookup an unknown SSO login
        let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap();
        assert_eq!(login, None);

        let all = CompatSsoLoginFilter::new();
        let for_user = all.for_user(&user);
        let pending = all.pending_only();
        let fulfilled = all.fulfilled_only();
        let exchanged = all.exchanged_only();

        // Check the initial counts
        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);

        // Lookup an unknown login token
        let login = repo
            .compat_sso_login()
            .find_by_token(LOGIN_TOKEN)
            .await
            .unwrap();
        assert_eq!(login, None);

        // Start a new SSO login
        let login = repo
            .compat_sso_login()
            .add(
                &mut rng,
                &clock,
                LOGIN_TOKEN.to_owned(),
                "https://example.com/callback".parse().unwrap(),
            )
            .await
            .unwrap();
        assert!(login.is_pending());

        // Check the counts
        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);

        // Lookup the login by ID
        let login_lookup = repo
            .compat_sso_login()
            .lookup(login.id)
            .await
            .unwrap()
            .expect("login not found");
        assert_eq!(login_lookup, login);

        // Find the login by token
        let login_lookup = repo
            .compat_sso_login()
            .find_by_token(LOGIN_TOKEN)
            .await
            .unwrap()
            .expect("login not found");
        assert_eq!(login_lookup, login);

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let compat_session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device, None, false, None)
            .await
            .unwrap();

        // Exchanging before fulfilling should not work
        // Note: It should also not poison the SQL transaction
        let res = repo
            .compat_sso_login()
            .exchange(&clock, login.clone(), &compat_session)
            .await;
        assert!(res.is_err());

        // Start a browser session for that user
        let browser_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        // Associate the login with the session
        let login = repo
            .compat_sso_login()
            .fulfill(&clock, login, &browser_session)
            .await
            .unwrap();
        assert!(login.is_fulfilled());

        // Check the counts
        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);

        // Fulfilling again should not work
        // Note: It should also not poison the SQL transaction
        let res = repo
            .compat_sso_login()
            .fulfill(&clock, login.clone(), &browser_session)
            .await;
        assert!(res.is_err());

        // Exchange that login
        let login = repo
            .compat_sso_login()
            .exchange(&clock, login, &compat_session)
            .await
            .unwrap();
        assert!(login.is_exchanged());

        // Check the counts
        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 1);

        // Exchange again should not work
        // Note: It should also not poison the SQL transaction
        let res = repo
            .compat_sso_login()
            .exchange(&clock, login.clone(), &compat_session)
            .await;
        assert!(res.is_err());

        // Fulfilling after exchanging should not work
        // Note: It should also not poison the SQL transaction
        let res = repo
            .compat_sso_login()
            .fulfill(&clock, login.clone(), &browser_session)
            .await;
        assert!(res.is_err());

        let pagination = Pagination::first(10);

        // List all logins
        let logins = repo.compat_sso_login().list(all, pagination).await.unwrap();
        assert!(!logins.has_next_page);
        assert_eq!(logins.edges, vec![login.clone()]);

        // List the logins for the user
        let logins = repo
            .compat_sso_login()
            .list(for_user, pagination)
            .await
            .unwrap();
        assert!(!logins.has_next_page);
        assert_eq!(logins.edges, vec![login.clone()]);

        // List only the pending logins for the user
        let logins = repo
            .compat_sso_login()
            .list(for_user.pending_only(), pagination)
            .await
            .unwrap();
        assert!(!logins.has_next_page);
        assert!(logins.edges.is_empty());

        // List only the fulfilled logins for the user
        let logins = repo
            .compat_sso_login()
            .list(for_user.fulfilled_only(), pagination)
            .await
            .unwrap();
        assert!(!logins.has_next_page);
        assert!(logins.edges.is_empty());

        // List only the exchanged logins for the user
        let logins = repo
            .compat_sso_login()
            .list(for_user.exchanged_only(), pagination)
            .await
            .unwrap();
        assert!(!logins.has_next_page);
        assert_eq!(logins.edges, vec![login]);
    }
}