mas_storage/upstream_oauth2/session.rs
1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use mas_data_model::{
9 Clock, UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider,
10};
11use rand_core::RngCore;
12use ulid::Ulid;
13
14use crate::{Pagination, pagination::Page, repository_impl};
15
16/// Filter parameters for listing upstream OAuth sessions
17#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
18pub struct UpstreamOAuthSessionFilter<'a> {
19 provider: Option<&'a UpstreamOAuthProvider>,
20 sub_claim: Option<&'a str>,
21 sid_claim: Option<&'a str>,
22}
23
24impl<'a> UpstreamOAuthSessionFilter<'a> {
25 /// Create a new [`UpstreamOAuthSessionFilter`] with default values
26 #[must_use]
27 pub fn new() -> Self {
28 Self::default()
29 }
30
31 /// Set the upstream OAuth provider for which to list sessions
32 #[must_use]
33 pub fn for_provider(mut self, provider: &'a UpstreamOAuthProvider) -> Self {
34 self.provider = Some(provider);
35 self
36 }
37
38 /// Get the upstream OAuth provider filter
39 ///
40 /// Returns [`None`] if no filter was set
41 #[must_use]
42 pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
43 self.provider
44 }
45
46 /// Set the `sub` claim to filter by
47 #[must_use]
48 pub fn with_sub_claim(mut self, sub_claim: &'a str) -> Self {
49 self.sub_claim = Some(sub_claim);
50 self
51 }
52
53 /// Get the `sub` claim filter
54 ///
55 /// Returns [`None`] if no filter was set
56 #[must_use]
57 pub fn sub_claim(&self) -> Option<&str> {
58 self.sub_claim
59 }
60
61 /// Set the `sid` claim to filter by
62 #[must_use]
63 pub fn with_sid_claim(mut self, sid_claim: &'a str) -> Self {
64 self.sid_claim = Some(sid_claim);
65 self
66 }
67
68 /// Get the `sid` claim filter
69 ///
70 /// Returns [`None`] if no filter was set
71 #[must_use]
72 pub fn sid_claim(&self) -> Option<&str> {
73 self.sid_claim
74 }
75}
76
/// An [`UpstreamOAuthSessionRepository`] helps interact with
/// [`UpstreamOAuthAuthorizationSession`]s saved in the storage backend
#[async_trait]
pub trait UpstreamOAuthSessionRepository: Send + Sync {
    /// The error type returned by the repository
    type Error;

    /// Lookup a session by its ID
    ///
    /// Returns `None` if the session does not exist
    ///
    /// # Parameters
    ///
    /// * `id`: the ID of the session to lookup
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn lookup(
        &mut self,
        id: Ulid,
    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;

    /// Add a session to the database
    ///
    /// Returns the newly created session
    ///
    /// # Parameters
    ///
    /// * `rng`: the random number generator to use
    /// * `clock`: the clock source
    /// * `upstream_oauth_provider`: the upstream OAuth provider for which to
    ///   create the session
    /// * `state`: the authorization grant `state` parameter sent to the
    ///   upstream OAuth provider
    /// * `code_challenge_verifier`: the code challenge verifier used in this
    ///   session, if PKCE is being used
    /// * `nonce`: the `nonce` used in this session, if in OIDC mode
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        upstream_oauth_provider: &UpstreamOAuthProvider,
        state: String,
        code_challenge_verifier: Option<String>,
        nonce: Option<String>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    /// Mark a session as completed and associate the given link
    ///
    /// The session is taken by value; the updated session is returned
    ///
    /// # Parameters
    ///
    /// * `clock`: the clock source
    /// * `upstream_oauth_authorization_session`: the session to update
    /// * `upstream_oauth_link`: the link to associate with the session
    /// * `id_token`: the ID token returned by the upstream OAuth provider, if
    ///   present
    /// * `id_token_claims`: the claims contained in the ID token, if present
    /// * `extra_callback_parameters`: the extra query parameters returned in
    ///   the callback, if any
    /// * `userinfo`: the user info returned by the upstream OAuth provider, if
    ///   requested
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    // Eight parameters (including `self`) exceed clippy's default argument
    // limit; the expectation documents that this is intentional.
    #[expect(clippy::too_many_arguments)]
    async fn complete_with_link(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
        upstream_oauth_link: &UpstreamOAuthLink,
        id_token: Option<String>,
        id_token_claims: Option<serde_json::Value>,
        extra_callback_parameters: Option<serde_json::Value>,
        userinfo: Option<serde_json::Value>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    /// Mark a session as consumed
    ///
    /// The session is taken by value; the updated session is returned
    ///
    /// # Parameters
    ///
    /// * `clock`: the clock source
    /// * `upstream_oauth_authorization_session`: the session to consume
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn consume(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    /// List [`UpstreamOAuthAuthorizationSession`] with the given filter and
    /// pagination
    ///
    /// Returns a [`Page`] of matching sessions
    ///
    /// # Parameters
    ///
    /// * `filter`: The filter to apply
    /// * `pagination`: The pagination parameters
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn list(
        &mut self,
        filter: UpstreamOAuthSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error>;

    /// Count the number of [`UpstreamOAuthAuthorizationSession`] matching the
    /// given filter
    ///
    /// # Parameters
    ///
    /// * `filter`: The filter to apply
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn count(&mut self, filter: UpstreamOAuthSessionFilter<'_>)
    -> Result<usize, Self::Error>;
}
209
// Feeds the `UpstreamOAuthSessionRepository` method signatures to the
// crate-local `repository_impl!` macro.
// NOTE(review): the macro is defined elsewhere in this crate; presumably it
// generates forwarding implementations of the trait for wrapper repository
// types. Confirm this against the macro definition.
// The signatures below duplicate the trait declaration and must be kept in
// sync with it. Only use `//` comments in here: `///` doc comments would
// become attributes and change the tokens the macro matches on.
repository_impl!(UpstreamOAuthSessionRepository:
    async fn lookup(
        &mut self,
        id: Ulid,
    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        upstream_oauth_provider: &UpstreamOAuthProvider,
        state: String,
        code_challenge_verifier: Option<String>,
        nonce: Option<String>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn complete_with_link(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
        upstream_oauth_link: &UpstreamOAuthLink,
        id_token: Option<String>,
        id_token_claims: Option<serde_json::Value>,
        extra_callback_parameters: Option<serde_json::Value>,
        userinfo: Option<serde_json::Value>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn consume(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error>;

    async fn count(&mut self, filter: UpstreamOAuthSessionFilter<'_>) -> Result<usize, Self::Error>;
);