1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Clock, DeviceCodeGrant, DeviceCodeGrantState, Session};
12use mas_storage::oauth2::{OAuth2DeviceCodeGrantParams, OAuth2DeviceCodeGrantRepository};
13use oauth2_types::scope::Scope;
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, ExecuteExt, errors::DatabaseInconsistencyError};
20
21pub struct PgOAuth2DeviceCodeGrantRepository<'c> {
24 conn: &'c mut PgConnection,
25}
26
27impl<'c> PgOAuth2DeviceCodeGrantRepository<'c> {
28 pub fn new(conn: &'c mut PgConnection) -> Self {
31 Self { conn }
32 }
33}
34
35struct OAuth2DeviceGrantLookup {
36 oauth2_device_code_grant_id: Uuid,
37 oauth2_client_id: Uuid,
38 scope: String,
39 device_code: String,
40 user_code: String,
41 created_at: DateTime<Utc>,
42 expires_at: DateTime<Utc>,
43 fulfilled_at: Option<DateTime<Utc>>,
44 rejected_at: Option<DateTime<Utc>>,
45 exchanged_at: Option<DateTime<Utc>>,
46 user_session_id: Option<Uuid>,
47 oauth2_session_id: Option<Uuid>,
48 ip_address: Option<IpAddr>,
49 user_agent: Option<String>,
50}
51
52impl TryFrom<OAuth2DeviceGrantLookup> for DeviceCodeGrant {
53 type Error = DatabaseInconsistencyError;
54
55 fn try_from(
56 OAuth2DeviceGrantLookup {
57 oauth2_device_code_grant_id,
58 oauth2_client_id,
59 scope,
60 device_code,
61 user_code,
62 created_at,
63 expires_at,
64 fulfilled_at,
65 rejected_at,
66 exchanged_at,
67 user_session_id,
68 oauth2_session_id,
69 ip_address,
70 user_agent,
71 }: OAuth2DeviceGrantLookup,
72 ) -> Result<Self, Self::Error> {
73 let id = Ulid::from(oauth2_device_code_grant_id);
74 let client_id = Ulid::from(oauth2_client_id);
75
76 let scope: Scope = scope.parse().map_err(|e| {
77 DatabaseInconsistencyError::on("oauth2_authorization_grants")
78 .column("scope")
79 .row(id)
80 .source(e)
81 })?;
82
83 let state = match (
84 fulfilled_at,
85 rejected_at,
86 exchanged_at,
87 user_session_id,
88 oauth2_session_id,
89 ) {
90 (None, None, None, None, None) => DeviceCodeGrantState::Pending,
91
92 (Some(fulfilled_at), None, None, Some(user_session_id), None) => {
93 DeviceCodeGrantState::Fulfilled {
94 browser_session_id: Ulid::from(user_session_id),
95 fulfilled_at,
96 }
97 }
98
99 (None, Some(rejected_at), None, Some(user_session_id), None) => {
100 DeviceCodeGrantState::Rejected {
101 browser_session_id: Ulid::from(user_session_id),
102 rejected_at,
103 }
104 }
105
106 (
107 Some(fulfilled_at),
108 None,
109 Some(exchanged_at),
110 Some(user_session_id),
111 Some(oauth2_session_id),
112 ) => DeviceCodeGrantState::Exchanged {
113 browser_session_id: Ulid::from(user_session_id),
114 session_id: Ulid::from(oauth2_session_id),
115 fulfilled_at,
116 exchanged_at,
117 },
118
119 _ => return Err(DatabaseInconsistencyError::on("oauth2_device_code_grant").row(id)),
120 };
121
122 Ok(DeviceCodeGrant {
123 id,
124 state,
125 client_id,
126 scope,
127 user_code,
128 device_code,
129 created_at,
130 expires_at,
131 ip_address,
132 user_agent,
133 })
134 }
135}
136
137#[async_trait]
138impl OAuth2DeviceCodeGrantRepository for PgOAuth2DeviceCodeGrantRepository<'_> {
139 type Error = DatabaseError;
140
141 #[tracing::instrument(
142 name = "db.oauth2_device_code_grant.add",
143 skip_all,
144 fields(
145 db.query.text,
146 oauth2_device_code.id,
147 oauth2_device_code.scope = %params.scope,
148 oauth2_client.id = %params.client.id,
149 ),
150 err,
151 )]
152 async fn add(
153 &mut self,
154 rng: &mut (dyn RngCore + Send),
155 clock: &dyn Clock,
156 params: OAuth2DeviceCodeGrantParams<'_>,
157 ) -> Result<DeviceCodeGrant, Self::Error> {
158 let now = clock.now();
159 let id = Ulid::from_datetime_with_source(now.into(), rng);
160 tracing::Span::current().record("oauth2_device_code.id", tracing::field::display(id));
161
162 let created_at = now;
163 let expires_at = now + params.expires_in;
164 let client_id = params.client.id;
165
166 sqlx::query!(
167 r#"
168 INSERT INTO "oauth2_device_code_grant"
169 ( oauth2_device_code_grant_id
170 , oauth2_client_id
171 , scope
172 , device_code
173 , user_code
174 , created_at
175 , expires_at
176 , ip_address
177 , user_agent
178 )
179 VALUES
180 ($1, $2, $3, $4, $5, $6, $7, $8, $9)
181 "#,
182 Uuid::from(id),
183 Uuid::from(client_id),
184 params.scope.to_string(),
185 ¶ms.device_code,
186 ¶ms.user_code,
187 created_at,
188 expires_at,
189 params.ip_address as Option<IpAddr>,
190 params.user_agent.as_deref(),
191 )
192 .traced()
193 .execute(&mut *self.conn)
194 .await?;
195
196 Ok(DeviceCodeGrant {
197 id,
198 state: DeviceCodeGrantState::Pending,
199 client_id,
200 scope: params.scope,
201 user_code: params.user_code,
202 device_code: params.device_code,
203 created_at,
204 expires_at,
205 ip_address: params.ip_address,
206 user_agent: params.user_agent,
207 })
208 }
209
210 #[tracing::instrument(
211 name = "db.oauth2_device_code_grant.lookup",
212 skip_all,
213 fields(
214 db.query.text,
215 oauth2_device_code.id = %id,
216 ),
217 err,
218 )]
219 async fn lookup(&mut self, id: Ulid) -> Result<Option<DeviceCodeGrant>, Self::Error> {
220 let res = sqlx::query_as!(
221 OAuth2DeviceGrantLookup,
222 r#"
223 SELECT oauth2_device_code_grant_id
224 , oauth2_client_id
225 , scope
226 , device_code
227 , user_code
228 , created_at
229 , expires_at
230 , fulfilled_at
231 , rejected_at
232 , exchanged_at
233 , user_session_id
234 , oauth2_session_id
235 , ip_address as "ip_address: IpAddr"
236 , user_agent
237 FROM
238 oauth2_device_code_grant
239
240 WHERE oauth2_device_code_grant_id = $1
241 "#,
242 Uuid::from(id),
243 )
244 .traced()
245 .fetch_optional(&mut *self.conn)
246 .await?;
247
248 let Some(res) = res else { return Ok(None) };
249
250 Ok(Some(res.try_into()?))
251 }
252
253 #[tracing::instrument(
254 name = "db.oauth2_device_code_grant.find_by_user_code",
255 skip_all,
256 fields(
257 db.query.text,
258 oauth2_device_code.user_code = %user_code,
259 ),
260 err,
261 )]
262 async fn find_by_user_code(
263 &mut self,
264 user_code: &str,
265 ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
266 let res = sqlx::query_as!(
267 OAuth2DeviceGrantLookup,
268 r#"
269 SELECT oauth2_device_code_grant_id
270 , oauth2_client_id
271 , scope
272 , device_code
273 , user_code
274 , created_at
275 , expires_at
276 , fulfilled_at
277 , rejected_at
278 , exchanged_at
279 , user_session_id
280 , oauth2_session_id
281 , ip_address as "ip_address: IpAddr"
282 , user_agent
283 FROM
284 oauth2_device_code_grant
285
286 WHERE user_code = $1
287 "#,
288 user_code,
289 )
290 .traced()
291 .fetch_optional(&mut *self.conn)
292 .await?;
293
294 let Some(res) = res else { return Ok(None) };
295
296 Ok(Some(res.try_into()?))
297 }
298
299 #[tracing::instrument(
300 name = "db.oauth2_device_code_grant.find_by_device_code",
301 skip_all,
302 fields(
303 db.query.text,
304 oauth2_device_code.device_code = %device_code,
305 ),
306 err,
307 )]
308 async fn find_by_device_code(
309 &mut self,
310 device_code: &str,
311 ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
312 let res = sqlx::query_as!(
313 OAuth2DeviceGrantLookup,
314 r#"
315 SELECT oauth2_device_code_grant_id
316 , oauth2_client_id
317 , scope
318 , device_code
319 , user_code
320 , created_at
321 , expires_at
322 , fulfilled_at
323 , rejected_at
324 , exchanged_at
325 , user_session_id
326 , oauth2_session_id
327 , ip_address as "ip_address: IpAddr"
328 , user_agent
329 FROM
330 oauth2_device_code_grant
331
332 WHERE device_code = $1
333 "#,
334 device_code,
335 )
336 .traced()
337 .fetch_optional(&mut *self.conn)
338 .await?;
339
340 let Some(res) = res else { return Ok(None) };
341
342 Ok(Some(res.try_into()?))
343 }
344
345 #[tracing::instrument(
346 name = "db.oauth2_device_code_grant.fulfill",
347 skip_all,
348 fields(
349 db.query.text,
350 oauth2_device_code.id = %device_code_grant.id,
351 oauth2_client.id = %device_code_grant.client_id,
352 browser_session.id = %browser_session.id,
353 user.id = %browser_session.user.id,
354 ),
355 err,
356 )]
357 async fn fulfill(
358 &mut self,
359 clock: &dyn Clock,
360 device_code_grant: DeviceCodeGrant,
361 browser_session: &BrowserSession,
362 ) -> Result<DeviceCodeGrant, Self::Error> {
363 let fulfilled_at = clock.now();
364 let device_code_grant = device_code_grant
365 .fulfill(browser_session, fulfilled_at)
366 .map_err(DatabaseError::to_invalid_operation)?;
367
368 let res = sqlx::query!(
369 r#"
370 UPDATE oauth2_device_code_grant
371 SET fulfilled_at = $1
372 , user_session_id = $2
373 WHERE oauth2_device_code_grant_id = $3
374 "#,
375 fulfilled_at,
376 Uuid::from(browser_session.id),
377 Uuid::from(device_code_grant.id),
378 )
379 .traced()
380 .execute(&mut *self.conn)
381 .await?;
382
383 DatabaseError::ensure_affected_rows(&res, 1)?;
384
385 Ok(device_code_grant)
386 }
387
388 #[tracing::instrument(
389 name = "db.oauth2_device_code_grant.reject",
390 skip_all,
391 fields(
392 db.query.text,
393 oauth2_device_code.id = %device_code_grant.id,
394 oauth2_client.id = %device_code_grant.client_id,
395 browser_session.id = %browser_session.id,
396 user.id = %browser_session.user.id,
397 ),
398 err,
399 )]
400 async fn reject(
401 &mut self,
402 clock: &dyn Clock,
403 device_code_grant: DeviceCodeGrant,
404 browser_session: &BrowserSession,
405 ) -> Result<DeviceCodeGrant, Self::Error> {
406 let fulfilled_at = clock.now();
407 let device_code_grant = device_code_grant
408 .reject(browser_session, fulfilled_at)
409 .map_err(DatabaseError::to_invalid_operation)?;
410
411 let res = sqlx::query!(
412 r#"
413 UPDATE oauth2_device_code_grant
414 SET rejected_at = $1
415 , user_session_id = $2
416 WHERE oauth2_device_code_grant_id = $3
417 "#,
418 fulfilled_at,
419 Uuid::from(browser_session.id),
420 Uuid::from(device_code_grant.id),
421 )
422 .traced()
423 .execute(&mut *self.conn)
424 .await?;
425
426 DatabaseError::ensure_affected_rows(&res, 1)?;
427
428 Ok(device_code_grant)
429 }
430
431 #[tracing::instrument(
432 name = "db.oauth2_device_code_grant.exchange",
433 skip_all,
434 fields(
435 db.query.text,
436 oauth2_device_code.id = %device_code_grant.id,
437 oauth2_client.id = %device_code_grant.client_id,
438 oauth2_session.id = %session.id,
439 ),
440 err,
441 )]
442 async fn exchange(
443 &mut self,
444 clock: &dyn Clock,
445 device_code_grant: DeviceCodeGrant,
446 session: &Session,
447 ) -> Result<DeviceCodeGrant, Self::Error> {
448 let exchanged_at = clock.now();
449 let device_code_grant = device_code_grant
450 .exchange(session, exchanged_at)
451 .map_err(DatabaseError::to_invalid_operation)?;
452
453 let res = sqlx::query!(
454 r#"
455 UPDATE oauth2_device_code_grant
456 SET exchanged_at = $1
457 , oauth2_session_id = $2
458 WHERE oauth2_device_code_grant_id = $3
459 "#,
460 exchanged_at,
461 Uuid::from(session.id),
462 Uuid::from(device_code_grant.id),
463 )
464 .traced()
465 .execute(&mut *self.conn)
466 .await?;
467
468 DatabaseError::ensure_affected_rows(&res, 1)?;
469
470 Ok(device_code_grant)
471 }
472}