mas_storage_pg/oauth2/
device_code_grant.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Clock, DeviceCodeGrant, DeviceCodeGrantState, Session};
12use mas_storage::oauth2::{OAuth2DeviceCodeGrantParams, OAuth2DeviceCodeGrantRepository};
13use oauth2_types::scope::Scope;
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, ExecuteExt, errors::DatabaseInconsistencyError};
20
/// An implementation of [`OAuth2DeviceCodeGrantRepository`] for a PostgreSQL
/// connection
///
/// The repository mutably borrows the connection for the lifetime `'c`; every
/// query issued by the trait implementation runs on that connection.
pub struct PgOAuth2DeviceCodeGrantRepository<'c> {
    conn: &'c mut PgConnection,
}
26
impl<'c> PgOAuth2DeviceCodeGrantRepository<'c> {
    /// Create a new [`PgOAuth2DeviceCodeGrantRepository`] from an active
    /// PostgreSQL connection
    ///
    /// The connection is only stored here; no query is issued until one of
    /// the repository methods is called.
    pub fn new(conn: &'c mut PgConnection) -> Self {
        Self { conn }
    }
}
34
/// Raw row shape returned by the `SELECT` queries on the
/// `oauth2_device_code_grant` table.
///
/// Field names match the selected column names, which is what
/// `sqlx::query_as!` relies on to populate the struct. It is converted into a
/// [`DeviceCodeGrant`] through the `TryFrom` implementation below.
struct OAuth2DeviceGrantLookup {
    oauth2_device_code_grant_id: Uuid,
    oauth2_client_id: Uuid,
    // Stored as text; parsed into a `Scope` during conversion.
    scope: String,
    device_code: String,
    user_code: String,
    created_at: DateTime<Utc>,
    expires_at: DateTime<Utc>,
    // The combination of the five optional fields below determines the
    // `DeviceCodeGrantState` of the grant.
    fulfilled_at: Option<DateTime<Utc>>,
    rejected_at: Option<DateTime<Utc>>,
    exchanged_at: Option<DateTime<Utc>>,
    user_session_id: Option<Uuid>,
    oauth2_session_id: Option<Uuid>,
    ip_address: Option<IpAddr>,
    user_agent: Option<String>,
}
51
52impl TryFrom<OAuth2DeviceGrantLookup> for DeviceCodeGrant {
53    type Error = DatabaseInconsistencyError;
54
55    fn try_from(
56        OAuth2DeviceGrantLookup {
57            oauth2_device_code_grant_id,
58            oauth2_client_id,
59            scope,
60            device_code,
61            user_code,
62            created_at,
63            expires_at,
64            fulfilled_at,
65            rejected_at,
66            exchanged_at,
67            user_session_id,
68            oauth2_session_id,
69            ip_address,
70            user_agent,
71        }: OAuth2DeviceGrantLookup,
72    ) -> Result<Self, Self::Error> {
73        let id = Ulid::from(oauth2_device_code_grant_id);
74        let client_id = Ulid::from(oauth2_client_id);
75
76        let scope: Scope = scope.parse().map_err(|e| {
77            DatabaseInconsistencyError::on("oauth2_authorization_grants")
78                .column("scope")
79                .row(id)
80                .source(e)
81        })?;
82
83        let state = match (
84            fulfilled_at,
85            rejected_at,
86            exchanged_at,
87            user_session_id,
88            oauth2_session_id,
89        ) {
90            (None, None, None, None, None) => DeviceCodeGrantState::Pending,
91
92            (Some(fulfilled_at), None, None, Some(user_session_id), None) => {
93                DeviceCodeGrantState::Fulfilled {
94                    browser_session_id: Ulid::from(user_session_id),
95                    fulfilled_at,
96                }
97            }
98
99            (None, Some(rejected_at), None, Some(user_session_id), None) => {
100                DeviceCodeGrantState::Rejected {
101                    browser_session_id: Ulid::from(user_session_id),
102                    rejected_at,
103                }
104            }
105
106            (
107                Some(fulfilled_at),
108                None,
109                Some(exchanged_at),
110                Some(user_session_id),
111                Some(oauth2_session_id),
112            ) => DeviceCodeGrantState::Exchanged {
113                browser_session_id: Ulid::from(user_session_id),
114                session_id: Ulid::from(oauth2_session_id),
115                fulfilled_at,
116                exchanged_at,
117            },
118
119            _ => return Err(DatabaseInconsistencyError::on("oauth2_device_code_grant").row(id)),
120        };
121
122        Ok(DeviceCodeGrant {
123            id,
124            state,
125            client_id,
126            scope,
127            user_code,
128            device_code,
129            created_at,
130            expires_at,
131            ip_address,
132            user_agent,
133        })
134    }
135}
136
137#[async_trait]
138impl OAuth2DeviceCodeGrantRepository for PgOAuth2DeviceCodeGrantRepository<'_> {
139    type Error = DatabaseError;
140
141    #[tracing::instrument(
142        name = "db.oauth2_device_code_grant.add",
143        skip_all,
144        fields(
145            db.query.text,
146            oauth2_device_code.id,
147            oauth2_device_code.scope = %params.scope,
148            oauth2_client.id = %params.client.id,
149        ),
150        err,
151    )]
152    async fn add(
153        &mut self,
154        rng: &mut (dyn RngCore + Send),
155        clock: &dyn Clock,
156        params: OAuth2DeviceCodeGrantParams<'_>,
157    ) -> Result<DeviceCodeGrant, Self::Error> {
158        let now = clock.now();
159        let id = Ulid::from_datetime_with_source(now.into(), rng);
160        tracing::Span::current().record("oauth2_device_code.id", tracing::field::display(id));
161
162        let created_at = now;
163        let expires_at = now + params.expires_in;
164        let client_id = params.client.id;
165
166        sqlx::query!(
167            r#"
168                INSERT INTO "oauth2_device_code_grant"
169                    ( oauth2_device_code_grant_id
170                    , oauth2_client_id
171                    , scope
172                    , device_code
173                    , user_code
174                    , created_at
175                    , expires_at
176                    , ip_address
177                    , user_agent
178                    )
179                VALUES
180                    ($1, $2, $3, $4, $5, $6, $7, $8, $9)
181            "#,
182            Uuid::from(id),
183            Uuid::from(client_id),
184            params.scope.to_string(),
185            &params.device_code,
186            &params.user_code,
187            created_at,
188            expires_at,
189            params.ip_address as Option<IpAddr>,
190            params.user_agent.as_deref(),
191        )
192        .traced()
193        .execute(&mut *self.conn)
194        .await?;
195
196        Ok(DeviceCodeGrant {
197            id,
198            state: DeviceCodeGrantState::Pending,
199            client_id,
200            scope: params.scope,
201            user_code: params.user_code,
202            device_code: params.device_code,
203            created_at,
204            expires_at,
205            ip_address: params.ip_address,
206            user_agent: params.user_agent,
207        })
208    }
209
210    #[tracing::instrument(
211        name = "db.oauth2_device_code_grant.lookup",
212        skip_all,
213        fields(
214            db.query.text,
215            oauth2_device_code.id = %id,
216        ),
217        err,
218    )]
219    async fn lookup(&mut self, id: Ulid) -> Result<Option<DeviceCodeGrant>, Self::Error> {
220        let res = sqlx::query_as!(
221            OAuth2DeviceGrantLookup,
222            r#"
223                SELECT oauth2_device_code_grant_id
224                     , oauth2_client_id
225                     , scope
226                     , device_code
227                     , user_code
228                     , created_at
229                     , expires_at
230                     , fulfilled_at
231                     , rejected_at
232                     , exchanged_at
233                     , user_session_id
234                     , oauth2_session_id
235                     , ip_address as "ip_address: IpAddr"
236                     , user_agent
237                FROM
238                    oauth2_device_code_grant
239
240                WHERE oauth2_device_code_grant_id = $1
241            "#,
242            Uuid::from(id),
243        )
244        .traced()
245        .fetch_optional(&mut *self.conn)
246        .await?;
247
248        let Some(res) = res else { return Ok(None) };
249
250        Ok(Some(res.try_into()?))
251    }
252
253    #[tracing::instrument(
254        name = "db.oauth2_device_code_grant.find_by_user_code",
255        skip_all,
256        fields(
257            db.query.text,
258            oauth2_device_code.user_code = %user_code,
259        ),
260        err,
261    )]
262    async fn find_by_user_code(
263        &mut self,
264        user_code: &str,
265    ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
266        let res = sqlx::query_as!(
267            OAuth2DeviceGrantLookup,
268            r#"
269                SELECT oauth2_device_code_grant_id
270                     , oauth2_client_id
271                     , scope
272                     , device_code
273                     , user_code
274                     , created_at
275                     , expires_at
276                     , fulfilled_at
277                     , rejected_at
278                     , exchanged_at
279                     , user_session_id
280                     , oauth2_session_id
281                     , ip_address as "ip_address: IpAddr"
282                     , user_agent
283                FROM
284                    oauth2_device_code_grant
285
286                WHERE user_code = $1
287            "#,
288            user_code,
289        )
290        .traced()
291        .fetch_optional(&mut *self.conn)
292        .await?;
293
294        let Some(res) = res else { return Ok(None) };
295
296        Ok(Some(res.try_into()?))
297    }
298
299    #[tracing::instrument(
300        name = "db.oauth2_device_code_grant.find_by_device_code",
301        skip_all,
302        fields(
303            db.query.text,
304            oauth2_device_code.device_code = %device_code,
305        ),
306        err,
307    )]
308    async fn find_by_device_code(
309        &mut self,
310        device_code: &str,
311    ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
312        let res = sqlx::query_as!(
313            OAuth2DeviceGrantLookup,
314            r#"
315                SELECT oauth2_device_code_grant_id
316                     , oauth2_client_id
317                     , scope
318                     , device_code
319                     , user_code
320                     , created_at
321                     , expires_at
322                     , fulfilled_at
323                     , rejected_at
324                     , exchanged_at
325                     , user_session_id
326                     , oauth2_session_id
327                     , ip_address as "ip_address: IpAddr"
328                     , user_agent
329                FROM
330                    oauth2_device_code_grant
331
332                WHERE device_code = $1
333            "#,
334            device_code,
335        )
336        .traced()
337        .fetch_optional(&mut *self.conn)
338        .await?;
339
340        let Some(res) = res else { return Ok(None) };
341
342        Ok(Some(res.try_into()?))
343    }
344
345    #[tracing::instrument(
346        name = "db.oauth2_device_code_grant.fulfill",
347        skip_all,
348        fields(
349            db.query.text,
350            oauth2_device_code.id = %device_code_grant.id,
351            oauth2_client.id = %device_code_grant.client_id,
352            browser_session.id = %browser_session.id,
353            user.id = %browser_session.user.id,
354        ),
355        err,
356    )]
357    async fn fulfill(
358        &mut self,
359        clock: &dyn Clock,
360        device_code_grant: DeviceCodeGrant,
361        browser_session: &BrowserSession,
362    ) -> Result<DeviceCodeGrant, Self::Error> {
363        let fulfilled_at = clock.now();
364        let device_code_grant = device_code_grant
365            .fulfill(browser_session, fulfilled_at)
366            .map_err(DatabaseError::to_invalid_operation)?;
367
368        let res = sqlx::query!(
369            r#"
370                UPDATE oauth2_device_code_grant
371                SET fulfilled_at = $1
372                  , user_session_id = $2
373                WHERE oauth2_device_code_grant_id = $3
374            "#,
375            fulfilled_at,
376            Uuid::from(browser_session.id),
377            Uuid::from(device_code_grant.id),
378        )
379        .traced()
380        .execute(&mut *self.conn)
381        .await?;
382
383        DatabaseError::ensure_affected_rows(&res, 1)?;
384
385        Ok(device_code_grant)
386    }
387
388    #[tracing::instrument(
389        name = "db.oauth2_device_code_grant.reject",
390        skip_all,
391        fields(
392            db.query.text,
393            oauth2_device_code.id = %device_code_grant.id,
394            oauth2_client.id = %device_code_grant.client_id,
395            browser_session.id = %browser_session.id,
396            user.id = %browser_session.user.id,
397        ),
398        err,
399    )]
400    async fn reject(
401        &mut self,
402        clock: &dyn Clock,
403        device_code_grant: DeviceCodeGrant,
404        browser_session: &BrowserSession,
405    ) -> Result<DeviceCodeGrant, Self::Error> {
406        let fulfilled_at = clock.now();
407        let device_code_grant = device_code_grant
408            .reject(browser_session, fulfilled_at)
409            .map_err(DatabaseError::to_invalid_operation)?;
410
411        let res = sqlx::query!(
412            r#"
413                UPDATE oauth2_device_code_grant
414                SET rejected_at = $1
415                  , user_session_id = $2
416                WHERE oauth2_device_code_grant_id = $3
417            "#,
418            fulfilled_at,
419            Uuid::from(browser_session.id),
420            Uuid::from(device_code_grant.id),
421        )
422        .traced()
423        .execute(&mut *self.conn)
424        .await?;
425
426        DatabaseError::ensure_affected_rows(&res, 1)?;
427
428        Ok(device_code_grant)
429    }
430
431    #[tracing::instrument(
432        name = "db.oauth2_device_code_grant.exchange",
433        skip_all,
434        fields(
435            db.query.text,
436            oauth2_device_code.id = %device_code_grant.id,
437            oauth2_client.id = %device_code_grant.client_id,
438            oauth2_session.id = %session.id,
439        ),
440        err,
441    )]
442    async fn exchange(
443        &mut self,
444        clock: &dyn Clock,
445        device_code_grant: DeviceCodeGrant,
446        session: &Session,
447    ) -> Result<DeviceCodeGrant, Self::Error> {
448        let exchanged_at = clock.now();
449        let device_code_grant = device_code_grant
450            .exchange(session, exchanged_at)
451            .map_err(DatabaseError::to_invalid_operation)?;
452
453        let res = sqlx::query!(
454            r#"
455                UPDATE oauth2_device_code_grant
456                SET exchanged_at = $1
457                  , oauth2_session_id = $2
458                WHERE oauth2_device_code_grant_id = $3
459            "#,
460            exchanged_at,
461            Uuid::from(session.id),
462            Uuid::from(device_code_grant.id),
463        )
464        .traced()
465        .execute(&mut *self.conn)
466        .await?;
467
468        DatabaseError::ensure_affected_rows(&res, 1)?;
469
470        Ok(device_code_grant)
471    }
472}