1mod link;
11mod provider;
12mod session;
13
14pub use self::{
15 link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
16 session::PgUpstreamOAuthSessionRepository,
17};
18
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
        UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
        UpstreamOAuthProviderTokenAuthMethod, clock::MockClock,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{
        Pagination, RepositoryAccess,
        upstream_oauth2::{
            UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
            UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
            UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository,
        },
        user::UserRepository,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use rand::SeedableRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Builds the provider parameters shared by every test in this module.
    ///
    /// Only the fields the tests vary are exposed as arguments. Everything
    /// else is a fixed minimal configuration: no client secret, no endpoint
    /// overrides, discovery mode `Oidc`, PKCE mode `Auto`, and no
    /// back-channel logout action.
    fn test_provider_params(
        issuer: Option<&str>,
        client_id: &str,
        scope: Scope,
    ) -> UpstreamOAuthProviderParams {
        UpstreamOAuthProviderParams {
            issuer: issuer.map(str::to_owned),
            human_name: None,
            brand_name: None,
            scope,
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            token_endpoint_signing_alg: None,
            client_id: client_id.to_owned(),
            encrypted_client_secret: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            token_endpoint_override: None,
            authorization_endpoint_override: None,
            userinfo_endpoint_override: None,
            jwks_uri_override: None,
            discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            ui_order: 0,
            on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
        }
    }

    /// Walks through the lifecycle of providers, sessions and links.
    ///
    /// Covers: creating and looking up a provider; completing and consuming a
    /// session; associating a link with a user; filtering and counting;
    /// disabling and finally deleting the provider.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repository(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // The database starts without any provider
        let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert!(all_providers.is_empty());

        // Create a provider and check it can be read back
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                test_provider_params(
                    Some("https://example.com/"),
                    "client-id",
                    Scope::from_iter([OPENID]),
                ),
            )
            .await
            .unwrap();

        let provider = repo
            .upstream_oauth_provider()
            .lookup(provider.id)
            .await
            .unwrap()
            .expect("provider to be found in the database");
        assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(provider.client_id, "client-id");

        let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(providers[0].client_id, "client-id");

        // A freshly created session is pending and not linked yet
        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &clock,
                &provider,
                "some-state".to_owned(),
                None,
                Some("some-nonce".to_owned()),
            )
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert_eq!(session.provider_id, provider.id);
        assert_eq!(session.link_id(), None);
        assert!(session.is_pending());
        assert!(!session.is_completed());
        assert!(!session.is_consumed());

        // Create a link and find it again both by ID and by subject
        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
            .await
            .unwrap();

        let looked_up_link = repo
            .upstream_oauth_link()
            .lookup(link.id)
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(looked_up_link.id, link.id);

        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "a-subject")
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link.subject, "a-subject");
        assert_eq!(link.provider_id, provider.id);

        // Completing the session attaches the link to it
        let session = repo
            .upstream_oauth_session()
            .complete_with_link(&clock, session, &link, None, None, None, None)
            .await
            .unwrap();
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_completed());
        assert!(!session.is_consumed());
        assert_eq!(session.link_id(), Some(link.id));

        // Consuming the session is persisted
        let session = repo
            .upstream_oauth_session()
            .consume(&clock, session)
            .await
            .unwrap();
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_consumed());

        // Associate the link with a user and find it through a link filter
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        repo.upstream_oauth_link()
            .associate_to_user(&link, &user)
            .await
            .unwrap();

        let filter = UpstreamOAuthLinkFilter::new()
            .for_user(&user)
            .for_provider(&provider)
            .for_subject("a-subject")
            .enabled_providers_only();

        let links = repo
            .upstream_oauth_link()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();
        assert!(!links.has_previous_page);
        assert!(!links.has_next_page);
        assert_eq!(links.edges.len(), 1);
        assert_eq!(links.edges[0].id, link.id);
        assert_eq!(links.edges[0].user_id, Some(user.id));

        assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);

        // Exactly one provider exists, and it is enabled
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            0
        );

        // Disabling moves the provider from the enabled to the disabled count
        repo.upstream_oauth_provider()
            .disable(&clock, provider.clone())
            .await
            .unwrap();

        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            1
        );

        // The single session created above is listed under this provider
        let session_filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        let session_count = repo
            .upstream_oauth_session()
            .count(session_filter)
            .await
            .unwrap();
        assert_eq!(session_count, 1);

        let session_page = repo
            .upstream_oauth_session()
            .list(session_filter, Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(session_page.edges.len(), 1);
        assert_eq!(session_page.edges[0].id, session.id);
        assert!(!session_page.has_next_page);
        assert!(!session_page.has_previous_page);

        // Deleting the provider removes it from the count entirely
        repo.upstream_oauth_provider()
            .delete(provider)
            .await
            .unwrap();
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            0
        );
    }

    /// Checks forward and backward cursor pagination over 20 providers.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_provider_repository_pagination(pool: PgPool) {
        let scope = Scope::from_iter([OPENID]);

        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let filter = UpstreamOAuthProviderFilter::new();

        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            0
        );

        let mut ids = Vec::with_capacity(20);
        for idx in 0..20 {
            let client_id = format!("client-{idx}");
            let provider = repo
                .upstream_oauth_provider()
                .add(
                    &mut rng,
                    &clock,
                    test_provider_params(None, &client_id, scope.clone()),
                )
                .await
                .unwrap();
            ids.push(provider.id);
            // Advance the mock clock (10 seconds) between inserts so rows have
            // distinct creation times; the assertions below expect the IDs to
            // sort in insertion order.
            clock.advance(Duration::microseconds(10 * 1000 * 1000));
        }

        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            20
        );

        // First page
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // All providers are enabled, so the enabled-only filter yields the same page
        let other_page = repo
            .upstream_oauth_provider()
            .list(filter.enabled_only(), Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(page, other_page);

        // Forward from the end of the first page
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Last page
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Backward from the start of the last page
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Both cursors set: the window is exclusive on both ends
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[6..8]);

        // No provider was disabled
        assert!(
            repo.upstream_oauth_provider()
                .list(
                    UpstreamOAuthProviderFilter::new().disabled_only(),
                    Pagination::first(1)
                )
                .await
                .unwrap()
                .edges
                .is_empty()
        );
    }

    /// Checks cursor pagination and ID-token claim filters over 20 sessions.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_session_repository_pagination(pool: PgPool) {
        let scope = Scope::from_iter([OPENID]);

        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                test_provider_params(Some("https://example.com/"), "client-id", scope),
            )
            .await
            .unwrap();

        let filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            0
        );

        let mut links = Vec::with_capacity(3);
        for subject in ["alice", "bob", "charlie"] {
            let link = repo
                .upstream_oauth_link()
                .add(&mut rng, &clock, &provider, subject.to_owned(), None)
                .await
                .unwrap();
            links.push(link);
        }

        // Sessions cycle through the 3 links and the 2 `sid` values, so each
        // (sub, sid) pair recurs every 6 sessions.
        let mut ids = Vec::with_capacity(20);
        let sids = ["one", "two"].into_iter().cycle();
        for (idx, (link, sid)) in links.iter().cycle().zip(sids).enumerate().take(20) {
            let state = format!("state-{idx}");
            let session = repo
                .upstream_oauth_session()
                .add(&mut rng, &clock, &provider, state, None, None)
                .await
                .unwrap();
            let id_token_claims = serde_json::json!({
                "sub": link.subject,
                "sid": sid,
                "aud": provider.client_id,
                "iss": "https://example.com/",
            });
            let session = repo
                .upstream_oauth_session()
                .complete_with_link(
                    &clock,
                    session,
                    link,
                    None,
                    Some(id_token_claims),
                    None,
                    None,
                )
                .await
                .unwrap();
            ids.push(session.id);
            // Advance the mock clock (10 seconds) between inserts so rows have
            // distinct creation times; the assertions below expect the IDs to
            // sort in insertion order.
            clock.advance(Duration::microseconds(10 * 1000 * 1000));
        }

        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            20
        );

        // First page
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Forward from the end of the first page
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Last page
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Backward from the start of the last page
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Both cursors set: the window is exclusive on both ends
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[11]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[6..11]);

        // 4 of the 20 sessions match each (sub, sid) pair below (indices 0,6,12,18
        // and 1,7,13,19 respectively).
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("alice").with_sid_claim("one"))
                .await
                .unwrap(),
            4
        );
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("bob").with_sid_claim("two"))
                .await
                .unwrap(),
            4
        );

        let page = repo
            .upstream_oauth_session()
            .list(
                filter.with_sub_claim("alice").with_sid_claim("one"),
                Pagination::first(10),
            )
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 4);
        for edge in page.edges {
            let claims = edge
                .id_token_claims()
                .expect("completed session to have ID token claims");
            assert_eq!(claims.get("sub").unwrap().as_str(), Some("alice"));
            assert_eq!(claims.get("sid").unwrap().as_str(), Some("one"));
        }
    }
}