// mas_data_model/oauth2/authorization_grant.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
6
use std::str::FromStr as _;

use chrono::{DateTime, Utc};
use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{
    pkce::{CodeChallengeError, CodeChallengeMethodExt},
    requests::ResponseMode,
    scope::{OPENID, PROFILE, Scope},
};
use rand::{
    RngCore,
    distributions::{Alphanumeric, DistString},
};
use ruma_common::UserId;
use serde::Serialize;
use ulid::Ulid;
use url::Url;

use super::session::Session;
use crate::InvalidTransitionError;
27
28#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
29pub struct Pkce {
30    pub challenge_method: PkceCodeChallengeMethod,
31    pub challenge: String,
32}
33
34impl Pkce {
35    /// Create a new PKCE challenge, with the given method and challenge.
36    #[must_use]
37    pub fn new(challenge_method: PkceCodeChallengeMethod, challenge: String) -> Self {
38        Pkce {
39            challenge_method,
40            challenge,
41        }
42    }
43
44    /// Verify the PKCE challenge.
45    ///
46    /// # Errors
47    ///
48    /// Returns an error if the verifier is invalid.
49    pub fn verify(&self, verifier: &str) -> Result<(), CodeChallengeError> {
50        self.challenge_method.verify(&self.challenge, verifier)
51    }
52}
53
54#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
55pub struct AuthorizationCode {
56    pub code: String,
57    pub pkce: Option<Pkce>,
58}
59
60#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
61#[serde(tag = "stage", rename_all = "lowercase")]
62pub enum AuthorizationGrantStage {
63    #[default]
64    Pending,
65    Fulfilled {
66        session_id: Ulid,
67        fulfilled_at: DateTime<Utc>,
68    },
69    Exchanged {
70        session_id: Ulid,
71        fulfilled_at: DateTime<Utc>,
72        exchanged_at: DateTime<Utc>,
73    },
74    Cancelled {
75        cancelled_at: DateTime<Utc>,
76    },
77}
78
79impl AuthorizationGrantStage {
80    #[must_use]
81    pub fn new() -> Self {
82        Self::Pending
83    }
84
85    fn fulfill(
86        self,
87        fulfilled_at: DateTime<Utc>,
88        session: &Session,
89    ) -> Result<Self, InvalidTransitionError> {
90        match self {
91            Self::Pending => Ok(Self::Fulfilled {
92                fulfilled_at,
93                session_id: session.id,
94            }),
95            _ => Err(InvalidTransitionError),
96        }
97    }
98
99    fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
100        match self {
101            Self::Fulfilled {
102                fulfilled_at,
103                session_id,
104            } => Ok(Self::Exchanged {
105                fulfilled_at,
106                exchanged_at,
107                session_id,
108            }),
109            _ => Err(InvalidTransitionError),
110        }
111    }
112
113    fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
114        match self {
115            Self::Pending => Ok(Self::Cancelled { cancelled_at }),
116            _ => Err(InvalidTransitionError),
117        }
118    }
119
120    /// Returns `true` if the authorization grant stage is [`Pending`].
121    ///
122    /// [`Pending`]: AuthorizationGrantStage::Pending
123    #[must_use]
124    pub fn is_pending(&self) -> bool {
125        matches!(self, Self::Pending)
126    }
127
128    /// Returns `true` if the authorization grant stage is [`Fulfilled`].
129    ///
130    /// [`Fulfilled`]: AuthorizationGrantStage::Fulfilled
131    #[must_use]
132    pub fn is_fulfilled(&self) -> bool {
133        matches!(self, Self::Fulfilled { .. })
134    }
135
136    /// Returns `true` if the authorization grant stage is [`Exchanged`].
137    ///
138    /// [`Exchanged`]: AuthorizationGrantStage::Exchanged
139    #[must_use]
140    pub fn is_exchanged(&self) -> bool {
141        matches!(self, Self::Exchanged { .. })
142    }
143}
144
145pub enum LoginHint<'a> {
146    MXID(&'a UserId),
147    Email(lettre::Address),
148    None,
149}
150
151#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
152pub struct AuthorizationGrant {
153    pub id: Ulid,
154    #[serde(flatten)]
155    pub stage: AuthorizationGrantStage,
156    pub code: Option<AuthorizationCode>,
157    pub client_id: Ulid,
158    pub redirect_uri: Url,
159    pub scope: Scope,
160    pub state: Option<String>,
161    pub nonce: Option<String>,
162    pub response_mode: ResponseMode,
163    pub response_type_id_token: bool,
164    pub created_at: DateTime<Utc>,
165    pub login_hint: Option<String>,
166    pub locale: Option<String>,
167}
168
169impl std::ops::Deref for AuthorizationGrant {
170    type Target = AuthorizationGrantStage;
171
172    fn deref(&self) -> &Self::Target {
173        &self.stage
174    }
175}
176
177impl AuthorizationGrant {
178    /// Parse a `login_hint`
179    ///
180    /// Returns `LoginHint::MXID` for valid mxid 'mxid:@john.doe:example.com'
181    ///
182    /// Returns `LoginHint::Email` for valid email 'john.doe@example.com'
183    ///
184    /// Otherwise returns `LoginHint::None`
185    #[must_use]
186    pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint<'_> {
187        let Some(login_hint) = &self.login_hint else {
188            return LoginHint::None;
189        };
190
191        if let Some(value) = login_hint.strip_prefix("mxid:")
192            && let Ok(mxid) = <&UserId>::try_from(value)
193            && mxid.server_name() == homeserver
194        {
195            LoginHint::MXID(mxid)
196        } else if let Ok(email) = lettre::Address::from_str(login_hint) {
197            LoginHint::Email(email)
198        } else {
199            LoginHint::None
200        }
201    }
202
203    /// Mark the authorization grant as exchanged.
204    ///
205    /// # Errors
206    ///
207    /// Returns an error if the authorization grant is not [`Fulfilled`].
208    ///
209    /// [`Fulfilled`]: AuthorizationGrantStage::Fulfilled
210    pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
211        self.stage = self.stage.exchange(exchanged_at)?;
212        Ok(self)
213    }
214
215    /// Mark the authorization grant as fulfilled.
216    ///
217    /// # Errors
218    ///
219    /// Returns an error if the authorization grant is not [`Pending`].
220    ///
221    /// [`Pending`]: AuthorizationGrantStage::Pending
222    pub fn fulfill(
223        mut self,
224        fulfilled_at: DateTime<Utc>,
225        session: &Session,
226    ) -> Result<Self, InvalidTransitionError> {
227        self.stage = self.stage.fulfill(fulfilled_at, session)?;
228        Ok(self)
229    }
230
231    /// Mark the authorization grant as cancelled.
232    ///
233    /// # Errors
234    ///
235    /// Returns an error if the authorization grant is not [`Pending`].
236    ///
237    /// [`Pending`]: AuthorizationGrantStage::Pending
238    ///
239    /// # TODO
240    ///
241    /// This appears to be unused
242    pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
243        self.stage = self.stage.cancel(canceld_at)?;
244        Ok(self)
245    }
246
247    #[doc(hidden)]
248    pub fn sample(now: DateTime<Utc>, rng: &mut impl RngCore) -> Self {
249        Self {
250            id: Ulid::from_datetime_with_source(now.into(), rng),
251            stage: AuthorizationGrantStage::Pending,
252            code: Some(AuthorizationCode {
253                code: Alphanumeric.sample_string(rng, 10),
254                pkce: None,
255            }),
256            client_id: Ulid::from_datetime_with_source(now.into(), rng),
257            redirect_uri: Url::parse("http://localhost:8080").unwrap(),
258            scope: Scope::from_iter([OPENID, PROFILE]),
259            state: Some(Alphanumeric.sample_string(rng, 10)),
260            nonce: Some(Alphanumeric.sample_string(rng, 10)),
261            response_mode: ResponseMode::Query,
262            response_type_id_token: false,
263            created_at: now,
264            login_hint: Some(String::from("mxid:@example-user:example.com")),
265            locale: Some(String::from("fr")),
266        }
267    }
268}
269
#[cfg(test)]
mod tests {
    use rand::SeedableRng;

    use super::*;
    use crate::clock::{Clock, MockClock};

    /// Build a sample grant (fixed clock, fixed RNG seed) carrying the given
    /// login hint.
    fn grant_with_login_hint(login_hint: Option<&str>) -> AuthorizationGrant {
        let now = MockClock::default().now();
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);

        AuthorizationGrant {
            login_hint: login_hint.map(String::from),
            ..AuthorizationGrant::sample(now, &mut rng)
        }
    }

    #[test]
    fn no_login_hint() {
        let grant = grant_with_login_hint(None);

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:example.com"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user"));
    }

    #[test]
    fn valid_login_hint_with_email() {
        let grant = grant_with_login_hint(Some("example@user"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::Email(email) if email.to_string() == "example@user"));
    }

    #[test]
    fn invalid_login_hint() {
        let grant = grant_with_login_hint(Some("example-user"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint_for_wrong_homeserver() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:matrix.org"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn unknown_login_hint_type() {
        let grant = grant_with_login_hint(Some("something:anything"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }
}