mas_storage/upstream_oauth2/provider.rs
1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::marker::PhantomData;
8
9use async_trait::async_trait;
10use mas_data_model::{
11 Clock, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports,
12 UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderOnBackchannelLogout,
13 UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderResponseMode,
14 UpstreamOAuthProviderTokenAuthMethod,
15};
16use mas_iana::jose::JsonWebSignatureAlg;
17use oauth2_types::scope::Scope;
18use rand_core::RngCore;
19use ulid::Ulid;
20use url::Url;
21
22use crate::{Pagination, pagination::Page, repository_impl};
23
24/// Structure which holds parameters when inserting or updating an upstream
25/// OAuth 2.0 provider
26pub struct UpstreamOAuthProviderParams {
27 /// The OIDC issuer of the provider
28 pub issuer: Option<String>,
29
30 /// A human-readable name for the provider
31 pub human_name: Option<String>,
32
33 /// A brand identifier, e.g. "apple" or "google"
34 pub brand_name: Option<String>,
35
36 /// The scope to request during the authorization flow
37 pub scope: Scope,
38
39 /// The token endpoint authentication method
40 pub token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod,
41
42 /// The JWT signing algorithm to use when then `client_secret_jwt` or
43 /// `private_key_jwt` authentication methods are used
44 pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
45
46 /// Expected signature for the JWT payload returned by the token
47 /// authentication endpoint.
48 ///
49 /// Defaults to `RS256`.
50 pub id_token_signed_response_alg: JsonWebSignatureAlg,
51
52 /// Whether to fetch the user profile from the userinfo endpoint,
53 /// or to rely on the data returned in the `id_token` from the
54 /// `token_endpoint`.
55 pub fetch_userinfo: bool,
56
57 /// Expected signature for the JWT payload returned by the userinfo
58 /// endpoint.
59 ///
60 /// If not specified, the response is expected to be an unsigned JSON
61 /// payload. Defaults to `None`.
62 pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
63
64 /// The client ID to use when authenticating to the upstream
65 pub client_id: String,
66
67 /// The encrypted client secret to use when authenticating to the upstream
68 pub encrypted_client_secret: Option<String>,
69
70 /// How claims should be imported from the upstream provider
71 pub claims_imports: UpstreamOAuthProviderClaimsImports,
72
73 /// The URL to use as the authorization endpoint. If `None`, the URL will be
74 /// discovered
75 pub authorization_endpoint_override: Option<Url>,
76
77 /// The URL to use as the token endpoint. If `None`, the URL will be
78 /// discovered
79 pub token_endpoint_override: Option<Url>,
80
81 /// The URL to use as the userinfo endpoint. If `None`, the URL will be
82 /// discovered
83 pub userinfo_endpoint_override: Option<Url>,
84
85 /// The URL to use when fetching JWKS. If `None`, the URL will be discovered
86 pub jwks_uri_override: Option<Url>,
87
88 /// How the provider metadata should be discovered
89 pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,
90
91 /// How should PKCE be used
92 pub pkce_mode: UpstreamOAuthProviderPkceMode,
93
94 /// What response mode it should ask
95 pub response_mode: Option<UpstreamOAuthProviderResponseMode>,
96
97 /// Additional parameters to include in the authorization request
98 pub additional_authorization_parameters: Vec<(String, String)>,
99
100 /// Whether to forward the login hint to the upstream provider.
101 pub forward_login_hint: bool,
102
103 /// The position of the provider in the UI
104 pub ui_order: i32,
105
106 /// The behavior when receiving a backchannel logout notification
107 pub on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout,
108}
109
/// Criteria used to narrow down which upstream OAuth 2.0 providers are
/// listed or counted
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct UpstreamOAuthProviderFilter<'a> {
    /// Restrict results by the provider's enabled state
    ///
    /// `Some(true)` keeps enabled providers, `Some(false)` keeps disabled
    /// ones, and `None` applies no restriction.
    enabled: Option<bool>,

    /// Zero-sized marker binding the filter to the `'a` lifetime without
    /// holding any borrowed data
    _lifetime: PhantomData<&'a ()>,
}
120
121impl UpstreamOAuthProviderFilter<'_> {
122 /// Create a new [`UpstreamOAuthProviderFilter`] with default values
123 #[must_use]
124 pub fn new() -> Self {
125 Self::default()
126 }
127
128 /// Return only enabled providers
129 #[must_use]
130 pub const fn enabled_only(mut self) -> Self {
131 self.enabled = Some(true);
132 self
133 }
134
135 /// Return only disabled providers
136 #[must_use]
137 pub const fn disabled_only(mut self) -> Self {
138 self.enabled = Some(false);
139 self
140 }
141
142 /// Get the enabled filter
143 ///
144 /// Returns `None` if the filter is not set
145 #[must_use]
146 pub const fn enabled(&self) -> Option<bool> {
147 self.enabled
148 }
149}
150
151/// An [`UpstreamOAuthProviderRepository`] helps interacting with
152/// [`UpstreamOAuthProvider`] saved in the storage backend
153#[async_trait]
154pub trait UpstreamOAuthProviderRepository: Send + Sync {
155 /// The error type returned by the repository
156 type Error;
157
158 /// Lookup an upstream OAuth provider by its ID
159 ///
160 /// Returns `None` if the provider was not found
161 ///
162 /// # Parameters
163 ///
164 /// * `id`: The ID of the provider to lookup
165 ///
166 /// # Errors
167 ///
168 /// Returns [`Self::Error`] if the underlying repository fails
169 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
170
171 /// Add a new upstream OAuth provider
172 ///
173 /// Returns the newly created provider
174 ///
175 /// # Parameters
176 ///
177 /// * `rng`: A random number generator
178 /// * `clock`: The clock used to generate timestamps
179 /// * `params`: The parameters of the provider to add
180 ///
181 /// # Errors
182 ///
183 /// Returns [`Self::Error`] if the underlying repository fails
184 async fn add(
185 &mut self,
186 rng: &mut (dyn RngCore + Send),
187 clock: &dyn Clock,
188 params: UpstreamOAuthProviderParams,
189 ) -> Result<UpstreamOAuthProvider, Self::Error>;
190
191 /// Delete an upstream OAuth provider
192 ///
193 /// # Parameters
194 ///
195 /// * `provider`: The provider to delete
196 ///
197 /// # Errors
198 ///
199 /// Returns [`Self::Error`] if the underlying repository fails
200 async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
201 self.delete_by_id(provider.id).await
202 }
203
204 /// Delete an upstream OAuth provider by its ID
205 ///
206 /// # Parameters
207 ///
208 /// * `id`: The ID of the provider to delete
209 ///
210 /// # Errors
211 ///
212 /// Returns [`Self::Error`] if the underlying repository fails
213 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
214
215 /// Insert or update an upstream OAuth provider
216 ///
217 /// # Parameters
218 ///
219 /// * `clock`: The clock used to generate timestamps
220 /// * `id`: The ID of the provider to update
221 /// * `params`: The parameters of the provider to update
222 ///
223 /// # Errors
224 ///
225 /// Returns [`Self::Error`] if the underlying repository fails
226 async fn upsert(
227 &mut self,
228 clock: &dyn Clock,
229 id: Ulid,
230 params: UpstreamOAuthProviderParams,
231 ) -> Result<UpstreamOAuthProvider, Self::Error>;
232
233 /// Disable an upstream OAuth provider
234 ///
235 /// Returns the disabled provider
236 ///
237 /// # Parameters
238 ///
239 /// * `clock`: The clock used to generate timestamps
240 /// * `provider`: The provider to disable
241 ///
242 /// # Errors
243 ///
244 /// Returns [`Self::Error`] if the underlying repository fails
245 async fn disable(
246 &mut self,
247 clock: &dyn Clock,
248 provider: UpstreamOAuthProvider,
249 ) -> Result<UpstreamOAuthProvider, Self::Error>;
250
251 /// List [`UpstreamOAuthProvider`] with the given filter and pagination
252 ///
253 /// # Parameters
254 ///
255 /// * `filter`: The filter to apply
256 /// * `pagination`: The pagination parameters
257 ///
258 /// # Errors
259 ///
260 /// Returns [`Self::Error`] if the underlying repository fails
261 async fn list(
262 &mut self,
263 filter: UpstreamOAuthProviderFilter<'_>,
264 pagination: Pagination,
265 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
266
267 /// Count the number of [`UpstreamOAuthProvider`] with the given filter
268 ///
269 /// # Parameters
270 ///
271 /// * `filter`: The filter to apply
272 ///
273 /// # Errors
274 ///
275 /// Returns [`Self::Error`] if the underlying repository fails
276 async fn count(
277 &mut self,
278 filter: UpstreamOAuthProviderFilter<'_>,
279 ) -> Result<usize, Self::Error>;
280
281 /// Get all enabled upstream OAuth providers
282 ///
283 /// # Errors
284 ///
285 /// Returns [`Self::Error`] if the underlying repository fails
286 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
287}
288
// Feed every method signature of `UpstreamOAuthProviderRepository` to the
// crate-level `repository_impl!` macro. The list mirrors the trait one-for-one
// (including `delete`, which has a default body in the trait), so it must be
// updated whenever a trait method is added, removed or changes signature.
// NOTE(review): presumably the macro generates forwarding impls of the trait
// for wrapper/pointer types — confirm against the macro definition.
repository_impl!(UpstreamOAuthProviderRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn upsert(
        &mut self,
        clock: &dyn Clock,
        id: Ulid,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;

    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;

    async fn disable(
        &mut self,
        clock: &dyn Clock,
        provider: UpstreamOAuthProvider
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
        pagination: Pagination
    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;

    async fn count(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>
    ) -> Result<usize, Self::Error>;

    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
);