mas_storage_pg/upstream_oauth2/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7//! A module containing the PostgreSQL implementation of the repositories
8//! related to the upstream OAuth 2.0 providers
9
10mod link;
11mod provider;
12mod session;
13
14pub use self::{
15    link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
16    session::PgUpstreamOAuthSessionRepository,
17};
18
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderOnBackchannelLogout,
        UpstreamOAuthProviderTokenAuthMethod, clock::MockClock,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{
        Pagination, RepositoryAccess,
        upstream_oauth2::{
            UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
            UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
            UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository,
        },
        user::UserRepository,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use rand::SeedableRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Build the parameters of a minimal OIDC provider for these tests.
    ///
    /// Every test in this module needs the same provider configuration
    /// (no client authentication, RS256 ID tokens, OIDC discovery, automatic
    /// PKCE, no endpoint overrides); only the issuer, the client ID and the
    /// requested scope vary between call sites.
    fn provider_params(
        issuer: Option<&str>,
        client_id: String,
        scope: Scope,
    ) -> UpstreamOAuthProviderParams {
        UpstreamOAuthProviderParams {
            issuer: issuer.map(str::to_owned),
            human_name: None,
            brand_name: None,
            scope,
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            token_endpoint_signing_alg: None,
            client_id,
            encrypted_client_secret: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            token_endpoint_override: None,
            authorization_endpoint_override: None,
            userinfo_endpoint_override: None,
            jwks_uri_override: None,
            discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            ui_order: 0,
            on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
        }
    }

    /// Walk a provider, a session and a link through their lifecycle:
    /// creation, lookup, completion, consumption, association to a user,
    /// filtered listing/counting, disabling and finally deletion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repository(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // The provider list should be empty at the start
        let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert!(all_providers.is_empty());

        // Let's add a provider
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                provider_params(
                    Some("https://example.com/"),
                    "client-id".to_owned(),
                    Scope::from_iter([OPENID]),
                ),
            )
            .await
            .unwrap();

        // Look it up in the database
        let provider = repo
            .upstream_oauth_provider()
            .lookup(provider.id)
            .await
            .unwrap()
            .expect("provider to be found in the database");
        assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(provider.client_id, "client-id");

        // It should be in the list of all providers
        let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(providers[0].client_id, "client-id");

        // Start a session
        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &clock,
                &provider,
                "some-state".to_owned(),
                None,
                Some("some-nonce".to_owned()),
            )
            .await
            .unwrap();

        // Look it up in the database: a fresh session is pending and has no link
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert_eq!(session.provider_id, provider.id);
        assert_eq!(session.link_id(), None);
        assert!(session.is_pending());
        assert!(!session.is_completed());
        assert!(!session.is_consumed());

        // Create a link
        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
            .await
            .unwrap();

        // We can look it up by its ID
        repo.upstream_oauth_link()
            .lookup(link.id)
            .await
            .unwrap()
            .expect("link to be found in database");

        // or by its subject
        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "a-subject")
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link.subject, "a-subject");
        assert_eq!(link.provider_id, provider.id);

        // Complete the session with the link, then reload it to check the
        // state was persisted
        let session = repo
            .upstream_oauth_session()
            .complete_with_link(&clock, session, &link, None, None, None, None)
            .await
            .unwrap();
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_completed());
        assert!(!session.is_consumed());
        assert_eq!(session.link_id(), Some(link.id));

        // Consume the session, then reload it to check the state was persisted
        let session = repo
            .upstream_oauth_session()
            .consume(&clock, session)
            .await
            .unwrap();
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_consumed());

        // Associate the link with a freshly created user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        repo.upstream_oauth_link()
            .associate_to_user(&link, &user)
            .await
            .unwrap();

        // XXX: we should also try other combinations of the filter
        let filter = UpstreamOAuthLinkFilter::new()
            .for_user(&user)
            .for_provider(&provider)
            .for_subject("a-subject")
            .enabled_providers_only();

        let links = repo
            .upstream_oauth_link()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();
        assert!(!links.has_previous_page);
        assert!(!links.has_next_page);
        assert_eq!(links.edges.len(), 1);
        assert_eq!(links.edges[0].id, link.id);
        assert_eq!(links.edges[0].user_id, Some(user.id));

        assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);

        // There should be exactly one enabled provider
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            0
        );

        // Disable the provider
        repo.upstream_oauth_provider()
            .disable(&clock, provider.clone())
            .await
            .unwrap();

        // There should be exactly one disabled provider
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            1
        );

        // Test listing and counting sessions
        let session_filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        // Count the sessions for the provider
        let session_count = repo
            .upstream_oauth_session()
            .count(session_filter)
            .await
            .unwrap();
        assert_eq!(session_count, 1);

        // List the sessions for the provider
        let session_page = repo
            .upstream_oauth_session()
            .list(session_filter, Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(session_page.edges.len(), 1);
        assert_eq!(session_page.edges[0].id, session.id);
        assert!(!session_page.has_next_page);
        assert!(!session_page.has_previous_page);

        // Try deleting the provider
        repo.upstream_oauth_provider()
            .delete(provider)
            .await
            .unwrap();
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            0
        );
    }

    /// Test that the pagination works as expected in the upstream OAuth
    /// provider repository
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_provider_repository_pagination(pool: PgPool) {
        let scope = Scope::from_iter([OPENID]);

        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let filter = UpstreamOAuthProviderFilter::new();

        // Count the number of providers before we start
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            0
        );

        let mut ids = Vec::with_capacity(20);
        // Create 20 providers
        for idx in 0..20 {
            let provider = repo
                .upstream_oauth_provider()
                .add(
                    &mut rng,
                    &clock,
                    provider_params(None, format!("client-{idx}"), scope.clone()),
                )
                .await
                .unwrap();
            ids.push(provider.id);
            // Step the clock by 10 seconds so that the time-ordered IDs sort in
            // insertion order, which the slice comparisons below rely on
            clock.advance(Duration::microseconds(10 * 1000 * 1000));
        }

        // Now we have 20 providers
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            20
        );

        // Lookup the first 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        // It returned the first 10 items
        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Getting the same page with the "enabled only" filter should return the same
        // results, as none of the providers are disabled
        let other_page = repo
            .upstream_oauth_provider()
            .list(filter.enabled_only(), Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(page, other_page);

        // Lookup the next 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        // It returned the next 10 items
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the last 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        // It returned the last 10 items
        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the previous 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        // It returned the previous 10 items
        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup the items strictly between two IDs: up to 10 are requested,
        // but only ids[6] and ids[7] fall inside the exclusive (ids[5], ids[8])
        // window
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
            .await
            .unwrap();

        // It returned the 2 items in between
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[6..8]);

        // There should not be any disabled providers
        assert!(
            repo.upstream_oauth_provider()
                .list(
                    UpstreamOAuthProviderFilter::new().disabled_only(),
                    Pagination::first(1)
                )
                .await
                .unwrap()
                .edges
                .is_empty()
        );
    }

    /// Test that the pagination works as expected in the upstream OAuth
    /// session repository, along with the `sub`/`sid` ID token claim filters
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_session_repository_pagination(pool: PgPool) {
        let scope = Scope::from_iter([OPENID]);

        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // Create a provider
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                provider_params(Some("https://example.com/"), "client-id".to_owned(), scope),
            )
            .await
            .unwrap();

        let filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        // Count the number of sessions before we start
        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            0
        );

        let mut links = Vec::with_capacity(3);
        for subject in ["alice", "bob", "charlie"] {
            let link = repo
                .upstream_oauth_link()
                .add(&mut rng, &clock, &provider, subject.to_owned(), None)
                .await
                .unwrap();
            links.push(link);
        }

        let mut ids = Vec::with_capacity(20);
        let sids = ["one", "two"].into_iter().cycle();
        // Create 20 sessions. Session `idx` uses link `idx % 3` and sid `idx % 2`,
        // so each (subject, sid) pair occurs for exactly one residue of `idx % 6`.
        for (idx, (link, sid)) in links.iter().cycle().zip(sids).enumerate().take(20) {
            let state = format!("state-{idx}");
            let session = repo
                .upstream_oauth_session()
                .add(&mut rng, &clock, &provider, state, None, None)
                .await
                .unwrap();
            let id_token_claims = serde_json::json!({
                "sub": link.subject,
                "sid": sid,
                "aud": provider.client_id,
                "iss": "https://example.com/",
            });
            let session = repo
                .upstream_oauth_session()
                .complete_with_link(
                    &clock,
                    session,
                    link,
                    None,
                    Some(id_token_claims),
                    None,
                    None,
                )
                .await
                .unwrap();
            ids.push(session.id);
            // Step the clock by 10 seconds so that the time-ordered IDs sort in
            // insertion order, which the slice comparisons below rely on
            clock.advance(Duration::microseconds(10 * 1000 * 1000));
        }

        // Now we have 20 sessions
        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            20
        );

        // Lookup the first 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        // It returned the first 10 items
        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup the next 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        // It returned the next 10 items
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the last 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        // It returned the last 10 items
        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the previous 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        // It returned the previous 10 items
        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup the items strictly between two IDs: up to 10 are requested,
        // but only the 5 items ids[6..11] fall inside the exclusive
        // (ids[5], ids[11]) window
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[11]))
            .await
            .unwrap();

        // It returned the items in between
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[6..11]);

        // Check the sub/sid filters: ("alice", "one") matches idx % 6 == 0
        // (0, 6, 12, 18) and ("bob", "two") matches idx % 6 == 1 (1, 7, 13, 19)
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("alice").with_sid_claim("one"))
                .await
                .unwrap(),
            4
        );
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("bob").with_sid_claim("two"))
                .await
                .unwrap(),
            4
        );

        let page = repo
            .upstream_oauth_session()
            .list(
                filter.with_sub_claim("alice").with_sid_claim("one"),
                Pagination::first(10),
            )
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 4);
        for edge in page.edges {
            assert_eq!(
                edge.id_token_claims().unwrap().get("sub").unwrap().as_str(),
                Some("alice")
            );
            assert_eq!(
                edge.id_token_claims().unwrap().get("sid").unwrap().as_str(),
                Some("one")
            );
        }
    }
}