1use std::str::FromStr as _;
8
9use chrono::{DateTime, Utc};
10use mas_iana::oauth::PkceCodeChallengeMethod;
11use oauth2_types::{
12 pkce::{CodeChallengeError, CodeChallengeMethodExt},
13 requests::ResponseMode,
14 scope::{OPENID, PROFILE, Scope},
15};
16use rand::{
17 RngCore,
18 distributions::{Alphanumeric, DistString},
19};
20use ruma_common::UserId;
21use serde::Serialize;
22use ulid::Ulid;
23use url::Url;
24
25use super::session::Session;
26use crate::InvalidTransitionError;
27
28#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
29pub struct Pkce {
30 pub challenge_method: PkceCodeChallengeMethod,
31 pub challenge: String,
32}
33
34impl Pkce {
35 #[must_use]
37 pub fn new(challenge_method: PkceCodeChallengeMethod, challenge: String) -> Self {
38 Pkce {
39 challenge_method,
40 challenge,
41 }
42 }
43
44 pub fn verify(&self, verifier: &str) -> Result<(), CodeChallengeError> {
50 self.challenge_method.verify(&self.challenge, verifier)
51 }
52}
53
/// An authorization code associated with an [`AuthorizationGrant`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthorizationCode {
    /// The code string itself.
    pub code: String,
    /// The PKCE challenge bound to this code, if the client supplied one.
    pub pkce: Option<Pkce>,
}
59
/// The lifecycle stage of an [`AuthorizationGrant`].
///
/// Serialized as an internally tagged enum: a `stage` field carries the
/// lowercase variant name (`pending`, `fulfilled`, `exchanged`, `cancelled`),
/// next to the variant's own fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
#[serde(tag = "stage", rename_all = "lowercase")]
pub enum AuthorizationGrantStage {
    /// Initial stage; also the [`Default`].
    #[default]
    Pending,
    /// The grant was fulfilled with the session identified by `session_id`.
    Fulfilled {
        session_id: Ulid,
        fulfilled_at: DateTime<Utc>,
    },
    /// The fulfilled grant was exchanged; keeps the fulfillment data.
    Exchanged {
        session_id: Ulid,
        fulfilled_at: DateTime<Utc>,
        exchanged_at: DateTime<Utc>,
    },
    /// The grant was cancelled before being fulfilled.
    Cancelled {
        cancelled_at: DateTime<Utc>,
    },
}
78
79impl AuthorizationGrantStage {
80 #[must_use]
81 pub fn new() -> Self {
82 Self::Pending
83 }
84
85 fn fulfill(
86 self,
87 fulfilled_at: DateTime<Utc>,
88 session: &Session,
89 ) -> Result<Self, InvalidTransitionError> {
90 match self {
91 Self::Pending => Ok(Self::Fulfilled {
92 fulfilled_at,
93 session_id: session.id,
94 }),
95 _ => Err(InvalidTransitionError),
96 }
97 }
98
99 fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
100 match self {
101 Self::Fulfilled {
102 fulfilled_at,
103 session_id,
104 } => Ok(Self::Exchanged {
105 fulfilled_at,
106 exchanged_at,
107 session_id,
108 }),
109 _ => Err(InvalidTransitionError),
110 }
111 }
112
113 fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
114 match self {
115 Self::Pending => Ok(Self::Cancelled { cancelled_at }),
116 _ => Err(InvalidTransitionError),
117 }
118 }
119
120 #[must_use]
124 pub fn is_pending(&self) -> bool {
125 matches!(self, Self::Pending)
126 }
127
128 #[must_use]
132 pub fn is_fulfilled(&self) -> bool {
133 matches!(self, Self::Fulfilled { .. })
134 }
135
136 #[must_use]
140 pub fn is_exchanged(&self) -> bool {
141 matches!(self, Self::Exchanged { .. })
142 }
143}
144
145pub enum LoginHint<'a> {
146 MXID(&'a UserId),
147 Email(lettre::Address),
148 None,
149}
150
/// An authorization grant, together with its current lifecycle stage.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthorizationGrant {
    /// Identifier of this grant.
    pub id: Ulid,
    /// Current stage. Flattened on serialization, so the stage's `stage` tag
    /// and fields appear alongside the grant's own fields.
    #[serde(flatten)]
    pub stage: AuthorizationGrantStage,
    /// The authorization code associated with this grant, if any.
    pub code: Option<AuthorizationCode>,
    /// NOTE(review): presumably the ID of the client that requested this
    /// grant — confirm.
    pub client_id: Ulid,
    pub redirect_uri: Url,
    pub scope: Scope,
    pub state: Option<String>,
    pub nonce: Option<String>,
    pub response_mode: ResponseMode,
    /// NOTE(review): presumably whether an ID token was among the requested
    /// response types — confirm.
    pub response_type_id_token: bool,
    pub created_at: DateTime<Utc>,
    /// Raw login hint; interpreted by [`AuthorizationGrant::parse_login_hint`].
    pub login_hint: Option<String>,
    /// NOTE(review): presumably the UI locale requested for this grant — confirm.
    pub locale: Option<String>,
}
168
169impl std::ops::Deref for AuthorizationGrant {
170 type Target = AuthorizationGrantStage;
171
172 fn deref(&self) -> &Self::Target {
173 &self.stage
174 }
175}
176
177impl AuthorizationGrant {
178 #[must_use]
186 pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint<'_> {
187 let Some(login_hint) = &self.login_hint else {
188 return LoginHint::None;
189 };
190
191 if let Some(value) = login_hint.strip_prefix("mxid:")
192 && let Ok(mxid) = <&UserId>::try_from(value)
193 && mxid.server_name() == homeserver
194 {
195 LoginHint::MXID(mxid)
196 } else if let Ok(email) = lettre::Address::from_str(login_hint) {
197 LoginHint::Email(email)
198 } else {
199 LoginHint::None
200 }
201 }
202
203 pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
211 self.stage = self.stage.exchange(exchanged_at)?;
212 Ok(self)
213 }
214
215 pub fn fulfill(
223 mut self,
224 fulfilled_at: DateTime<Utc>,
225 session: &Session,
226 ) -> Result<Self, InvalidTransitionError> {
227 self.stage = self.stage.fulfill(fulfilled_at, session)?;
228 Ok(self)
229 }
230
231 pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
243 self.stage = self.stage.cancel(canceld_at)?;
244 Ok(self)
245 }
246
247 #[doc(hidden)]
248 pub fn sample(now: DateTime<Utc>, rng: &mut impl RngCore) -> Self {
249 Self {
250 id: Ulid::from_datetime_with_source(now.into(), rng),
251 stage: AuthorizationGrantStage::Pending,
252 code: Some(AuthorizationCode {
253 code: Alphanumeric.sample_string(rng, 10),
254 pkce: None,
255 }),
256 client_id: Ulid::from_datetime_with_source(now.into(), rng),
257 redirect_uri: Url::parse("http://localhost:8080").unwrap(),
258 scope: Scope::from_iter([OPENID, PROFILE]),
259 state: Some(Alphanumeric.sample_string(rng, 10)),
260 nonce: Some(Alphanumeric.sample_string(rng, 10)),
261 response_mode: ResponseMode::Query,
262 response_type_id_token: false,
263 created_at: now,
264 login_hint: Some(String::from("mxid:@example-user:example.com")),
265 locale: Some(String::from("fr")),
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use rand::SeedableRng;

    use super::*;
    use crate::clock::{Clock, MockClock};

    /// Builds the deterministic sample grant (mock clock, seeded RNG) with
    /// its login hint replaced by `login_hint`.
    fn grant_with_login_hint(login_hint: Option<&str>) -> AuthorizationGrant {
        let now = MockClock::default().now();
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);

        AuthorizationGrant {
            login_hint: login_hint.map(String::from),
            ..AuthorizationGrant::sample(now, &mut rng)
        }
    }

    #[test]
    fn no_login_hint() {
        let grant = grant_with_login_hint(None);
        let hint = grant.parse_login_hint("example.com");
        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:example.com"));
        let hint = grant.parse_login_hint("example.com");
        assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user"));
    }

    #[test]
    fn valid_login_hint_with_email() {
        let grant = grant_with_login_hint(Some("example@user"));
        let hint = grant.parse_login_hint("example.com");
        assert!(matches!(hint, LoginHint::Email(email) if email.to_string() == "example@user"));
    }

    #[test]
    fn invalid_login_hint() {
        let grant = grant_with_login_hint(Some("example-user"));
        let hint = grant.parse_login_hint("example.com");
        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint_for_wrong_homeserver() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:matrix.org"));
        let hint = grant.parse_login_hint("example.com");
        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn unknown_login_hint_type() {
        let grant = grant_with_login_hint(Some("something:anything"));
        let hint = grant.parse_login_hint("example.com");
        assert!(matches!(hint, LoginHint::None));
    }
}