mas_storage/upstream_oauth2/provider.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::marker::PhantomData;
8
9use async_trait::async_trait;
10use mas_data_model::{
11    Clock, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports,
12    UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderOnBackchannelLogout,
13    UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderResponseMode,
14    UpstreamOAuthProviderTokenAuthMethod,
15};
16use mas_iana::jose::JsonWebSignatureAlg;
17use oauth2_types::scope::Scope;
18use rand_core::RngCore;
19use ulid::Ulid;
20use url::Url;
21
22use crate::{Pagination, pagination::Page, repository_impl};
23
24/// Structure which holds parameters when inserting or updating an upstream
25/// OAuth 2.0 provider
26pub struct UpstreamOAuthProviderParams {
27    /// The OIDC issuer of the provider
28    pub issuer: Option<String>,
29
30    /// A human-readable name for the provider
31    pub human_name: Option<String>,
32
33    /// A brand identifier, e.g. "apple" or "google"
34    pub brand_name: Option<String>,
35
36    /// The scope to request during the authorization flow
37    pub scope: Scope,
38
39    /// The token endpoint authentication method
40    pub token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod,
41
42    /// The JWT signing algorithm to use when then `client_secret_jwt` or
43    /// `private_key_jwt` authentication methods are used
44    pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
45
46    /// Expected signature for the JWT payload returned by the token
47    /// authentication endpoint.
48    ///
49    /// Defaults to `RS256`.
50    pub id_token_signed_response_alg: JsonWebSignatureAlg,
51
52    /// Whether to fetch the user profile from the userinfo endpoint,
53    /// or to rely on the data returned in the `id_token` from the
54    /// `token_endpoint`.
55    pub fetch_userinfo: bool,
56
57    /// Expected signature for the JWT payload returned by the userinfo
58    /// endpoint.
59    ///
60    /// If not specified, the response is expected to be an unsigned JSON
61    /// payload. Defaults to `None`.
62    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
63
64    /// The client ID to use when authenticating to the upstream
65    pub client_id: String,
66
67    /// The encrypted client secret to use when authenticating to the upstream
68    pub encrypted_client_secret: Option<String>,
69
70    /// How claims should be imported from the upstream provider
71    pub claims_imports: UpstreamOAuthProviderClaimsImports,
72
73    /// The URL to use as the authorization endpoint. If `None`, the URL will be
74    /// discovered
75    pub authorization_endpoint_override: Option<Url>,
76
77    /// The URL to use as the token endpoint. If `None`, the URL will be
78    /// discovered
79    pub token_endpoint_override: Option<Url>,
80
81    /// The URL to use as the userinfo endpoint. If `None`, the URL will be
82    /// discovered
83    pub userinfo_endpoint_override: Option<Url>,
84
85    /// The URL to use when fetching JWKS. If `None`, the URL will be discovered
86    pub jwks_uri_override: Option<Url>,
87
88    /// How the provider metadata should be discovered
89    pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,
90
91    /// How should PKCE be used
92    pub pkce_mode: UpstreamOAuthProviderPkceMode,
93
94    /// What response mode it should ask
95    pub response_mode: Option<UpstreamOAuthProviderResponseMode>,
96
97    /// Additional parameters to include in the authorization request
98    pub additional_authorization_parameters: Vec<(String, String)>,
99
100    /// Whether to forward the login hint to the upstream provider.
101    pub forward_login_hint: bool,
102
103    /// The position of the provider in the UI
104    pub ui_order: i32,
105
106    /// The behavior when receiving a backchannel logout notification
107    pub on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout,
108}
109
/// Criteria used to narrow down a listing of upstream OAuth 2.0 providers
///
/// Start from [`UpstreamOAuthProviderFilter::new`], which matches everything,
/// and refine it with the chainable builder methods. The lifetime parameter is
/// currently only carried by a `PhantomData` marker.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct UpstreamOAuthProviderFilter<'a> {
    /// Restriction on the enabled state of the providers
    ///
    /// `Some(true)` keeps only enabled providers, `Some(false)` keeps only
    /// disabled ones, and `None` keeps every provider
    enabled: Option<bool>,

    _lifetime: PhantomData<&'a ()>,
}

impl UpstreamOAuthProviderFilter<'_> {
    /// Build a filter with no restriction, matching every provider
    #[must_use]
    pub fn new() -> Self {
        Self {
            enabled: None,
            _lifetime: PhantomData,
        }
    }

    /// Narrow the filter down to providers which are enabled
    #[must_use]
    pub const fn enabled_only(self) -> Self {
        Self {
            enabled: Some(true),
            ..self
        }
    }

    /// Narrow the filter down to providers which are disabled
    #[must_use]
    pub const fn disabled_only(self) -> Self {
        Self {
            enabled: Some(false),
            ..self
        }
    }

    /// Read back the enabled-state restriction
    ///
    /// `None` means no restriction on the enabled state was requested
    #[must_use]
    pub const fn enabled(&self) -> Option<bool> {
        let Self { enabled, .. } = self;
        *enabled
    }
}
150
151/// An [`UpstreamOAuthProviderRepository`] helps interacting with
152/// [`UpstreamOAuthProvider`] saved in the storage backend
153#[async_trait]
154pub trait UpstreamOAuthProviderRepository: Send + Sync {
155    /// The error type returned by the repository
156    type Error;
157
158    /// Lookup an upstream OAuth provider by its ID
159    ///
160    /// Returns `None` if the provider was not found
161    ///
162    /// # Parameters
163    ///
164    /// * `id`: The ID of the provider to lookup
165    ///
166    /// # Errors
167    ///
168    /// Returns [`Self::Error`] if the underlying repository fails
169    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
170
171    /// Add a new upstream OAuth provider
172    ///
173    /// Returns the newly created provider
174    ///
175    /// # Parameters
176    ///
177    /// * `rng`: A random number generator
178    /// * `clock`: The clock used to generate timestamps
179    /// * `params`: The parameters of the provider to add
180    ///
181    /// # Errors
182    ///
183    /// Returns [`Self::Error`] if the underlying repository fails
184    async fn add(
185        &mut self,
186        rng: &mut (dyn RngCore + Send),
187        clock: &dyn Clock,
188        params: UpstreamOAuthProviderParams,
189    ) -> Result<UpstreamOAuthProvider, Self::Error>;
190
191    /// Delete an upstream OAuth provider
192    ///
193    /// # Parameters
194    ///
195    /// * `provider`: The provider to delete
196    ///
197    /// # Errors
198    ///
199    /// Returns [`Self::Error`] if the underlying repository fails
200    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
201        self.delete_by_id(provider.id).await
202    }
203
204    /// Delete an upstream OAuth provider by its ID
205    ///
206    /// # Parameters
207    ///
208    /// * `id`: The ID of the provider to delete
209    ///
210    /// # Errors
211    ///
212    /// Returns [`Self::Error`] if the underlying repository fails
213    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
214
215    /// Insert or update an upstream OAuth provider
216    ///
217    /// # Parameters
218    ///
219    /// * `clock`: The clock used to generate timestamps
220    /// * `id`: The ID of the provider to update
221    /// * `params`: The parameters of the provider to update
222    ///
223    /// # Errors
224    ///
225    /// Returns [`Self::Error`] if the underlying repository fails
226    async fn upsert(
227        &mut self,
228        clock: &dyn Clock,
229        id: Ulid,
230        params: UpstreamOAuthProviderParams,
231    ) -> Result<UpstreamOAuthProvider, Self::Error>;
232
233    /// Disable an upstream OAuth provider
234    ///
235    /// Returns the disabled provider
236    ///
237    /// # Parameters
238    ///
239    /// * `clock`: The clock used to generate timestamps
240    /// * `provider`: The provider to disable
241    ///
242    /// # Errors
243    ///
244    /// Returns [`Self::Error`] if the underlying repository fails
245    async fn disable(
246        &mut self,
247        clock: &dyn Clock,
248        provider: UpstreamOAuthProvider,
249    ) -> Result<UpstreamOAuthProvider, Self::Error>;
250
251    /// List [`UpstreamOAuthProvider`] with the given filter and pagination
252    ///
253    /// # Parameters
254    ///
255    /// * `filter`: The filter to apply
256    /// * `pagination`: The pagination parameters
257    ///
258    /// # Errors
259    ///
260    /// Returns [`Self::Error`] if the underlying repository fails
261    async fn list(
262        &mut self,
263        filter: UpstreamOAuthProviderFilter<'_>,
264        pagination: Pagination,
265    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
266
267    /// Count the number of [`UpstreamOAuthProvider`] with the given filter
268    ///
269    /// # Parameters
270    ///
271    /// * `filter`: The filter to apply
272    ///
273    /// # Errors
274    ///
275    /// Returns [`Self::Error`] if the underlying repository fails
276    async fn count(
277        &mut self,
278        filter: UpstreamOAuthProviderFilter<'_>,
279    ) -> Result<usize, Self::Error>;
280
281    /// Get all enabled upstream OAuth providers
282    ///
283    /// # Errors
284    ///
285    /// Returns [`Self::Error`] if the underlying repository fails
286    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
287}
288
// NOTE(review): `repository_impl!` is defined elsewhere in this crate (it is
// imported from `crate`); it presumably generates forwarding implementations
// of `UpstreamOAuthProviderRepository` for wrapper repository types. Confirm
// against the macro definition.
//
// The list below mirrors every method of the trait, including `delete`, which
// has a default body in the trait. Keep it in sync whenever a trait method is
// added or its signature changes. Parameter lists here omit the trailing comma
// used in the trait declaration; presumably the macro's matcher requires this,
// so do not "normalize" it without checking the macro.
repository_impl!(UpstreamOAuthProviderRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn upsert(
        &mut self,
        clock: &dyn Clock,
        id: Ulid,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;

    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;

    async fn disable(
        &mut self,
        clock: &dyn Clock,
        provider: UpstreamOAuthProvider
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
        pagination: Pagination
    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;

    async fn count(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>
    ) -> Result<usize, Self::Error>;

    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
);