1use std::{collections::HashMap, sync::Arc};
8
9use mas_context::LogContext;
10use mas_data_model::{
11 UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
12};
13use mas_iana::oauth::PkceCodeChallengeMethod;
14use mas_oidc_client::error::DiscoveryError;
15use mas_storage::{RepositoryAccess, upstream_oauth2::UpstreamOAuthProviderRepository};
16use oauth2_types::oidc::VerifiedProviderMetadata;
17use tokio::sync::RwLock;
18use url::Url;
19
/// Lazily resolves the endpoints and PKCE settings of an upstream OAuth 2.0
/// provider.
///
/// Explicit overrides configured on the provider are used when present; the
/// provider metadata is only looked up through the shared [`MetadataCache`]
/// when a value is not overridden. Once loaded successfully, the metadata is
/// kept on the instance so later lookups do not hit the cache again.
pub struct LazyProviderInfos<'a> {
    /// Shared cache used to look up (or fetch) the provider metadata.
    cache: &'a MetadataCache,
    /// The provider whose endpoints are being resolved.
    provider: &'a UpstreamOAuthProvider,
    /// HTTP client handed to the cache when the metadata must be fetched.
    client: &'a reqwest::Client,
    /// Metadata loaded by an earlier successful call, if any.
    loaded_metadata: Option<Arc<VerifiedProviderMetadata>>,
}
28
29impl<'a> LazyProviderInfos<'a> {
30 pub fn new(
31 cache: &'a MetadataCache,
32 provider: &'a UpstreamOAuthProvider,
33 client: &'a reqwest::Client,
34 ) -> Self {
35 Self {
36 cache,
37 provider,
38 client,
39 loaded_metadata: None,
40 }
41 }
42
43 pub async fn maybe_discover(
46 &mut self,
47 ) -> Result<Option<&VerifiedProviderMetadata>, DiscoveryError> {
48 match self.load().await {
49 Ok(metadata) => Ok(Some(metadata)),
50 Err(DiscoveryError::Disabled) => Ok(None),
51 Err(e) => Err(e),
52 }
53 }
54
55 async fn load(&mut self) -> Result<&VerifiedProviderMetadata, DiscoveryError> {
56 if self.loaded_metadata.is_none() {
57 let verify = match self.provider.discovery_mode {
58 UpstreamOAuthProviderDiscoveryMode::Oidc => true,
59 UpstreamOAuthProviderDiscoveryMode::Insecure => false,
60 UpstreamOAuthProviderDiscoveryMode::Disabled => {
61 return Err(DiscoveryError::Disabled);
62 }
63 };
64
65 let Some(issuer) = &self.provider.issuer else {
66 return Err(DiscoveryError::MissingIssuer);
67 };
68
69 let metadata = self.cache.get(self.client, issuer, verify).await?;
70
71 self.loaded_metadata = Some(metadata);
72 }
73
74 Ok(self.loaded_metadata.as_ref().unwrap())
75 }
76
77 pub async fn jwks_uri(&mut self) -> Result<&Url, DiscoveryError> {
82 if let Some(jwks_uri) = &self.provider.jwks_uri_override {
83 return Ok(jwks_uri);
84 }
85
86 Ok(self.load().await?.jwks_uri())
87 }
88
89 pub async fn authorization_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
94 if let Some(authorization_endpoint) = &self.provider.authorization_endpoint_override {
95 return Ok(authorization_endpoint);
96 }
97
98 Ok(self.load().await?.authorization_endpoint())
99 }
100
101 pub async fn token_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
106 if let Some(token_endpoint) = &self.provider.token_endpoint_override {
107 return Ok(token_endpoint);
108 }
109
110 Ok(self.load().await?.token_endpoint())
111 }
112
113 pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
118 if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
119 return Ok(userinfo_endpoint);
120 }
121
122 Ok(self.load().await?.userinfo_endpoint())
123 }
124
125 pub async fn pkce_methods(
130 &mut self,
131 ) -> Result<Option<Vec<PkceCodeChallengeMethod>>, DiscoveryError> {
132 let methods = match self.provider.pkce_mode {
133 UpstreamOAuthProviderPkceMode::Auto => self
134 .maybe_discover()
135 .await?
136 .and_then(|metadata| metadata.code_challenge_methods_supported.clone()),
137 UpstreamOAuthProviderPkceMode::S256 => Some(vec![PkceCodeChallengeMethod::S256]),
138 UpstreamOAuthProviderPkceMode::Disabled => None,
139 };
140
141 Ok(methods)
142 }
143}
144
/// In-memory cache of upstream provider metadata, keyed by issuer.
///
/// Metadata obtained through `discover` (verified mode) and through
/// `insecure_discover` (insecure mode) is kept in two separate maps. A
/// lookup for one mode therefore never returns an entry fetched in the other
/// mode. The maps sit behind `Arc`s, so clones of the cache share the same
/// underlying storage.
// Opt out of the pedantic `module_name_repetitions` lint: the public type name
// is kept as-is for its callers.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Default)]
pub struct MetadataCache {
    /// Entries fetched with verification (`discover`).
    cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
    /// Entries fetched without verification (`insecure_discover`).
    insecure_cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
}
156
157impl MetadataCache {
158 #[must_use]
159 pub fn new() -> Self {
160 Self::default()
161 }
162
163 #[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all)]
173 pub async fn warm_up_and_run<R: RepositoryAccess>(
174 &self,
175 client: &reqwest::Client,
176 interval: std::time::Duration,
177 repository: &mut R,
178 ) -> Result<tokio::task::JoinHandle<()>, R::Error> {
179 let providers = repository.upstream_oauth_provider().all_enabled().await?;
180
181 for provider in providers {
182 let verify = match provider.discovery_mode {
183 UpstreamOAuthProviderDiscoveryMode::Oidc => true,
184 UpstreamOAuthProviderDiscoveryMode::Insecure => false,
185 UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
186 };
187
188 let Some(issuer) = &provider.issuer else {
189 tracing::error!(%provider.id, "Provider doesn't have an issuer set, but discovery is enabled!");
190 continue;
191 };
192
193 if let Err(e) = self.fetch(client, issuer, verify).await {
194 tracing::error!(%issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
195 }
196 }
197
198 let cache = self.clone();
200 let client = client.clone();
201 Ok(tokio::spawn(async move {
202 loop {
203 tokio::time::sleep(interval).await;
205 LogContext::new("metadata-cache-refresh")
206 .run(|| cache.refresh_all(&client))
207 .await;
208 }
209 }))
210 }
211
212 #[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all)]
213 async fn fetch(
214 &self,
215 client: &reqwest::Client,
216 issuer: &str,
217 verify: bool,
218 ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
219 if verify {
220 let metadata = mas_oidc_client::requests::discovery::discover(client, issuer).await?;
221 let metadata = Arc::new(metadata);
222
223 self.cache
224 .write()
225 .await
226 .insert(issuer.to_owned(), metadata.clone());
227
228 Ok(metadata)
229 } else {
230 let metadata =
231 mas_oidc_client::requests::discovery::insecure_discover(client, issuer).await?;
232 let metadata = Arc::new(metadata);
233
234 self.insecure_cache
235 .write()
236 .await
237 .insert(issuer.to_owned(), metadata.clone());
238
239 Ok(metadata)
240 }
241 }
242
243 #[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all)]
249 pub async fn get(
250 &self,
251 client: &reqwest::Client,
252 issuer: &str,
253 verify: bool,
254 ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
255 let cache = if verify {
256 self.cache.read().await
257 } else {
258 self.insecure_cache.read().await
259 };
260
261 if let Some(metadata) = cache.get(issuer) {
262 return Ok(Arc::clone(metadata));
263 }
264 drop(cache);
266
267 let metadata = self.fetch(client, issuer, verify).await?;
268 Ok(metadata)
269 }
270
271 #[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
272 async fn refresh_all(&self, client: &reqwest::Client) {
273 let keys: Vec<String> = {
275 let cache = self.cache.read().await;
276 cache.keys().cloned().collect()
277 };
278
279 for issuer in keys {
280 if let Err(e) = self.fetch(client, &issuer, true).await {
281 tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
282 }
283 }
284
285 let keys: Vec<String> = {
287 let cache = self.insecure_cache.read().await;
288 cache.keys().cloned().collect()
289 };
290
291 for issuer in keys {
292 if let Err(e) = self.fetch(client, &issuer, false).await {
293 tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
294 }
295 }
296 }
297}
298
#[cfg(test)]
mod tests {
    use mas_data_model::{
        Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderOnBackchannelLogout,
        UpstreamOAuthProviderTokenAuthMethod, clock::MockClock,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use oauth2_types::scope::{OPENID, Scope};
    use ulid::Ulid;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;
    use crate::test_utils::setup;

    /// Counts the discovery requests `MetadataCache` makes across a cache
    /// miss, a cache hit, a failing verified lookup and a full refresh.
    #[tokio::test]
    async fn test_metadata_cache() {
        setup();
        let mock_server = MockServer::start().await;
        let http_client = mas_http::reqwest_client();

        let cache = MetadataCache::new();

        // Nothing is mounted on the mock server yet, so this lookup fails.
        // It happens before the mock below exists, so it is not part of
        // `expected_calls`.
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap_err();

        // `calls` tallies the requests each step below should make. This keeps
        // `expected_calls` (handed to wiremock via `.expect`) in sync with the
        // steps of the test.
        let expected_calls = 3;
        let mut calls = 0;
        let _mock_guard = Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "issuer": mock_server.uri(),
                "authorization_endpoint": "https://example.com/authorize",
                "token_endpoint": "https://example.com/token",
                "jwks_uri": "https://example.com/jwks",
                "userinfo_endpoint": "https://example.com/userinfo",
                "scopes_supported": ["openid"],
                "response_types_supported": ["code"],
                "response_modes_supported": ["query", "fragment"],
                "grant_types_supported": ["authorization_code"],
                "subject_types_supported": ["public"],
                "id_token_signing_alg_values_supported": ["RS256"],
            })))
            .expect(expected_calls)
            .mount(&mock_server)
            .await;

        // Insecure lookup, cache miss: one request, and the result is cached.
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap();
        calls += 1;

        // Same insecure lookup again: served from the cache, no request.
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap();
        calls += 0;

        // Verified lookups use a separate map, so this one misses and makes a
        // request. It is expected to fail; presumably the mock's metadata does
        // not pass verified discovery (e.g. non-HTTPS issuer) — TODO confirm.
        // The failed fetch stores nothing.
        cache
            .get(&http_client, &mock_server.uri(), true)
            .await
            .unwrap_err();
        calls += 1;

        // Refresh re-fetches every cached issuer. Only the insecure entry
        // exists, so exactly one request is made.
        cache.refresh_all(&http_client).await;
        calls += 1;

        assert_eq!(calls, expected_calls);
    }

    /// Checks that `LazyProviderInfos` only performs discovery when it needs
    /// to, and that endpoint overrides take precedence over discovery.
    #[tokio::test]
    async fn test_lazy_provider_infos() {
        setup();

        let mock_server = MockServer::start().await;
        let http_client = mas_http::reqwest_client();

        // `calls` tallies the requests each scenario below should make. This
        // keeps `expected_calls` (handed to wiremock via `.expect`) in sync.
        let expected_calls = 2;
        let mut calls = 0;
        let _mock_guard = Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "issuer": mock_server.uri(),
                "authorization_endpoint": "https://example.com/authorize",
                "token_endpoint": "https://example.com/token",
                "jwks_uri": "https://example.com/jwks",
                "userinfo_endpoint": "https://example.com/userinfo",
                "scopes_supported": ["openid"],
                "response_types_supported": ["code"],
                "response_modes_supported": ["query", "fragment"],
                "grant_types_supported": ["authorization_code"],
                "subject_types_supported": ["public"],
                "id_token_signing_alg_values_supported": ["RS256"],
            })))
            .expect(expected_calls)
            .mount(&mock_server)
            .await;

        // Base provider: insecure discovery against the mock server, with no
        // endpoint overrides. The scenarios below derive variants from it.
        let clock = MockClock::default();
        let provider = UpstreamOAuthProvider {
            id: Ulid::nil(),
            issuer: Some(mock_server.uri()),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
            pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            jwks_uri_override: None,
            authorization_endpoint_override: None,
            scope: Scope::from_iter([OPENID]),
            userinfo_endpoint_override: None,
            token_endpoint_override: None,
            client_id: "client_id".to_owned(),
            encrypted_client_secret: None,
            token_endpoint_signing_alg: None,
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            response_mode: None,
            created_at: clock.now(),
            disabled_at: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
        };

        // Insecure discovery: `maybe_discover` fetches once, and the endpoint
        // lookup that follows reuses the metadata already loaded.
        {
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            lazy_metadata.maybe_discover().await.unwrap();
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize"
            );
            calls += 1;
        }

        // Every endpoint queried here is overridden, so no request is made.
        {
            let provider = UpstreamOAuthProvider {
                jwks_uri_override: Some("https://example.com/jwks_override".parse().unwrap()),
                authorization_endpoint_override: Some(
                    "https://example.com/authorize_override".parse().unwrap(),
                ),
                token_endpoint_override: Some(
                    "https://example.com/token_override".parse().unwrap(),
                ),
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            assert_eq!(
                lazy_metadata.jwks_uri().await.unwrap().as_str(),
                "https://example.com/jwks_override"
            );
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize_override"
            );
            assert_eq!(
                lazy_metadata.token_endpoint().await.unwrap().as_str(),
                "https://example.com/token_override"
            );
            calls += 0;
        }

        // Verified (OIDC) discovery: one request is made, but the lookup
        // fails. Presumably the mock's metadata does not pass verification —
        // TODO confirm.
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            lazy_metadata.authorization_endpoint().await.unwrap_err();
            calls += 1;
        }

        // Discovery disabled: no request is made. Overridden endpoints still
        // resolve, and the others report `DiscoveryError::Disabled`.
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Disabled,
                authorization_endpoint_override: Some(
                    Url::parse("https://example.com/authorize_override").unwrap(),
                ),
                token_endpoint_override: None,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            assert!(lazy_metadata.maybe_discover().await.unwrap().is_none());
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize_override"
            );
            assert!(matches!(
                lazy_metadata.token_endpoint().await,
                Err(DiscoveryError::Disabled),
            ));
            calls += 0;
        }

        assert_eq!(calls, expected_calls);
    }
}