mas_storage/upstream_oauth2/
session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use mas_data_model::{
9    Clock, UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider,
10};
11use rand_core::RngCore;
12use ulid::Ulid;
13
14use crate::{Pagination, pagination::Page, repository_impl};
15
16/// Filter parameters for listing upstream OAuth sessions
17#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
18pub struct UpstreamOAuthSessionFilter<'a> {
19    provider: Option<&'a UpstreamOAuthProvider>,
20    sub_claim: Option<&'a str>,
21    sid_claim: Option<&'a str>,
22}
23
24impl<'a> UpstreamOAuthSessionFilter<'a> {
25    /// Create a new [`UpstreamOAuthSessionFilter`] with default values
26    #[must_use]
27    pub fn new() -> Self {
28        Self::default()
29    }
30
31    /// Set the upstream OAuth provider for which to list sessions
32    #[must_use]
33    pub fn for_provider(mut self, provider: &'a UpstreamOAuthProvider) -> Self {
34        self.provider = Some(provider);
35        self
36    }
37
38    /// Get the upstream OAuth provider filter
39    ///
40    /// Returns [`None`] if no filter was set
41    #[must_use]
42    pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
43        self.provider
44    }
45
46    /// Set the `sub` claim to filter by
47    #[must_use]
48    pub fn with_sub_claim(mut self, sub_claim: &'a str) -> Self {
49        self.sub_claim = Some(sub_claim);
50        self
51    }
52
53    /// Get the `sub` claim filter
54    ///
55    /// Returns [`None`] if no filter was set
56    #[must_use]
57    pub fn sub_claim(&self) -> Option<&str> {
58        self.sub_claim
59    }
60
61    /// Set the `sid` claim to filter by
62    #[must_use]
63    pub fn with_sid_claim(mut self, sid_claim: &'a str) -> Self {
64        self.sid_claim = Some(sid_claim);
65        self
66    }
67
68    /// Get the `sid` claim filter
69    ///
70    /// Returns [`None`] if no filter was set
71    #[must_use]
72    pub fn sid_claim(&self) -> Option<&str> {
73        self.sid_claim
74    }
75}
76
/// An [`UpstreamOAuthSessionRepository`] helps interact with
/// [`UpstreamOAuthAuthorizationSession`]s saved in the storage backend
///
/// It provides operations to:
/// - create sessions ([`Self::add`]);
/// - mark them as completed ([`Self::complete_with_link`]);
/// - mark them as consumed ([`Self::consume`]);
/// - look them up, list them and count them.
#[async_trait]
pub trait UpstreamOAuthSessionRepository: Send + Sync {
    /// The error type returned by the repository
    type Error;

    /// Lookup a session by its ID
    ///
    /// Returns `None` if the session does not exist
    ///
    /// # Parameters
    ///
    /// * `id`: the ID of the session to lookup
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn lookup(
        &mut self,
        id: Ulid,
    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;

    /// Add a session to the storage backend
    ///
    /// Returns the newly created session
    ///
    /// # Parameters
    ///
    /// * `rng`: the random number generator to use
    /// * `clock`: the clock source
    /// * `upstream_oauth_provider`: the upstream OAuth provider for which to
    ///   create the session
    /// * `state`: the authorization grant `state` parameter sent to the
    ///   upstream OAuth provider
    /// * `code_challenge_verifier`: the code challenge verifier used in this
    ///   session, if PKCE is being used
    /// * `nonce`: the `nonce` used in this session if in OIDC mode
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        upstream_oauth_provider: &UpstreamOAuthProvider,
        state: String,
        code_challenge_verifier: Option<String>,
        nonce: Option<String>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    /// Mark a session as completed and associate the given link
    ///
    /// The session is taken by value; callers should use the returned,
    /// updated session from then on.
    ///
    /// # Parameters
    ///
    /// * `clock`: the clock source
    /// * `upstream_oauth_authorization_session`: the session to update
    /// * `upstream_oauth_link`: the link to associate with the session
    /// * `id_token`: the ID token returned by the upstream OAuth provider, if
    ///   present
    /// * `id_token_claims`: the claims contained in the ID token, if present
    /// * `extra_callback_parameters`: the extra query parameters returned in
    ///   the callback, if any
    /// * `userinfo`: the user info returned by the upstream OAuth provider, if
    ///   requested
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    // `expect` rather than `allow`: the compiler reports an unfulfilled
    // expectation if this lint suppression ever becomes unnecessary.
    #[expect(clippy::too_many_arguments)]
    async fn complete_with_link(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
        upstream_oauth_link: &UpstreamOAuthLink,
        id_token: Option<String>,
        id_token_claims: Option<serde_json::Value>,
        extra_callback_parameters: Option<serde_json::Value>,
        userinfo: Option<serde_json::Value>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    /// Mark a session as consumed
    ///
    /// The session is taken by value; callers should use the returned,
    /// updated session from then on.
    ///
    /// # Parameters
    ///
    /// * `clock`: the clock source
    /// * `upstream_oauth_authorization_session`: the session to consume
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn consume(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    /// List [`UpstreamOAuthAuthorizationSession`] with the given filter and
    /// pagination
    ///
    /// Returns a single [`Page`] of matching sessions.
    ///
    /// # Parameters
    ///
    /// * `filter`: The filter to apply
    /// * `pagination`: The pagination parameters
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn list(
        &mut self,
        filter: UpstreamOAuthSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error>;

    /// Count the number of [`UpstreamOAuthAuthorizationSession`] with the given
    /// filter
    ///
    /// # Parameters
    ///
    /// * `filter`: The filter to apply
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn count(&mut self, filter: UpstreamOAuthSessionFilter<'_>)
    -> Result<usize, Self::Error>;
}
209
// Invoke the crate's `repository_impl!` macro with the method signatures of
// `UpstreamOAuthSessionRepository`. These signatures mirror the trait
// declaration; keep both in sync whenever a method is added or changed.
// NOTE(review): presumably the macro generates forwarding implementations of
// the trait for wrapper/pointer repository types — confirm against the macro
// definition in the crate root.
repository_impl!(UpstreamOAuthSessionRepository:
    async fn lookup(
        &mut self,
        id: Ulid,
    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        upstream_oauth_provider: &UpstreamOAuthProvider,
        state: String,
        code_challenge_verifier: Option<String>,
        nonce: Option<String>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn complete_with_link(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
        upstream_oauth_link: &UpstreamOAuthLink,
        id_token: Option<String>,
        id_token_claims: Option<serde_json::Value>,
        extra_callback_parameters: Option<serde_json::Value>,
        userinfo: Option<serde_json::Value>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn consume(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error>;

    async fn count(&mut self, filter: UpstreamOAuthSessionFilter<'_>) -> Result<usize, Self::Error>;
);