mas_storage_pg/queue/
job.rs

// Copyright 2024, 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

//! A module containing the PostgreSQL implementation of the
//! [`QueueJobRepository`].
8
9use async_trait::async_trait;
10use chrono::{DateTime, Duration, Utc};
11use mas_data_model::Clock;
12use mas_storage::queue::{Job, QueueJobRepository, Worker};
13use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
14use rand::RngCore;
15use sqlx::PgConnection;
16use tracing::Instrument;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt};
21
22/// An implementation of [`QueueJobRepository`] for a PostgreSQL connection.
23pub struct PgQueueJobRepository<'c> {
24    conn: &'c mut PgConnection,
25}
26
27impl<'c> PgQueueJobRepository<'c> {
28    /// Create a new [`PgQueueJobRepository`] from an active PostgreSQL
29    /// connection.
30    #[must_use]
31    pub fn new(conn: &'c mut PgConnection) -> Self {
32        Self { conn }
33    }
34}
35
36struct JobReservationResult {
37    queue_job_id: Uuid,
38    queue_name: String,
39    payload: serde_json::Value,
40    metadata: serde_json::Value,
41    attempt: i32,
42}
43
44impl TryFrom<JobReservationResult> for Job {
45    type Error = DatabaseInconsistencyError;
46
47    fn try_from(value: JobReservationResult) -> Result<Self, Self::Error> {
48        let id = value.queue_job_id.into();
49        let queue_name = value.queue_name;
50        let payload = value.payload;
51
52        let metadata = serde_json::from_value(value.metadata).map_err(|e| {
53            DatabaseInconsistencyError::on("queue_jobs")
54                .column("metadata")
55                .row(id)
56                .source(e)
57        })?;
58
59        let attempt = value.attempt.try_into().map_err(|e| {
60            DatabaseInconsistencyError::on("queue_jobs")
61                .column("attempt")
62                .row(id)
63                .source(e)
64        })?;
65
66        Ok(Self {
67            id,
68            queue_name,
69            payload,
70            metadata,
71            attempt,
72        })
73    }
74}
75
76#[async_trait]
77impl QueueJobRepository for PgQueueJobRepository<'_> {
78    type Error = DatabaseError;
79
80    #[tracing::instrument(
81        name = "db.queue_job.schedule",
82        fields(
83            queue_job.id,
84            queue_job.queue_name = queue_name,
85            db.query.text,
86        ),
87        skip_all,
88        err,
89    )]
90    async fn schedule(
91        &mut self,
92        rng: &mut (dyn RngCore + Send),
93        clock: &dyn Clock,
94        queue_name: &str,
95        payload: serde_json::Value,
96        metadata: serde_json::Value,
97    ) -> Result<(), Self::Error> {
98        let created_at = clock.now();
99        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
100        tracing::Span::current().record("queue_job.id", tracing::field::display(id));
101
102        sqlx::query!(
103            r#"
104                INSERT INTO queue_jobs
105                    (queue_job_id, queue_name, payload, metadata, created_at)
106                VALUES ($1, $2, $3, $4, $5)
107            "#,
108            Uuid::from(id),
109            queue_name,
110            payload,
111            metadata,
112            created_at,
113        )
114        .traced()
115        .execute(&mut *self.conn)
116        .await?;
117
118        Ok(())
119    }
120
121    #[tracing::instrument(
122        name = "db.queue_job.schedule_later",
123        fields(
124            queue_job.id,
125            queue_job.queue_name = queue_name,
126            queue_job.scheduled_at = %scheduled_at,
127            db.query.text,
128        ),
129        skip_all,
130        err,
131    )]
132    async fn schedule_later(
133        &mut self,
134        rng: &mut (dyn RngCore + Send),
135        clock: &dyn Clock,
136        queue_name: &str,
137        payload: serde_json::Value,
138        metadata: serde_json::Value,
139        scheduled_at: DateTime<Utc>,
140        schedule_name: Option<&str>,
141    ) -> Result<(), Self::Error> {
142        let created_at = clock.now();
143        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
144        tracing::Span::current().record("queue_job.id", tracing::field::display(id));
145
146        sqlx::query!(
147            r#"
148                INSERT INTO queue_jobs
149                    (queue_job_id, queue_name, payload, metadata, created_at, scheduled_at, schedule_name, status)
150                VALUES ($1, $2, $3, $4, $5, $6, $7, 'scheduled')
151            "#,
152            Uuid::from(id),
153            queue_name,
154            payload,
155            metadata,
156            created_at,
157            scheduled_at,
158            schedule_name,
159        )
160        .traced()
161        .execute(&mut *self.conn)
162        .await?;
163
164        // If there was a schedule name supplied, update the queue_schedules table
165        if let Some(schedule_name) = schedule_name {
166            let span = tracing::info_span!(
167                "db.queue_job.schedule_later.update_schedules",
168                { DB_QUERY_TEXT } = tracing::field::Empty,
169            );
170
171            let res = sqlx::query!(
172                r#"
173                    UPDATE queue_schedules
174                    SET last_scheduled_at = $1,
175                        last_scheduled_job_id = $2
176                    WHERE schedule_name = $3
177                "#,
178                scheduled_at,
179                Uuid::from(id),
180                schedule_name,
181            )
182            .record(&span)
183            .execute(&mut *self.conn)
184            .instrument(span)
185            .await?;
186
187            DatabaseError::ensure_affected_rows(&res, 1)?;
188        }
189
190        Ok(())
191    }
192
193    #[tracing::instrument(
194        name = "db.queue_job.reserve",
195        skip_all,
196        fields(
197            db.query.text,
198        ),
199        err,
200    )]
201    async fn reserve(
202        &mut self,
203        clock: &dyn Clock,
204        worker: &Worker,
205        queues: &[&str],
206        count: usize,
207    ) -> Result<Vec<Job>, Self::Error> {
208        let now = clock.now();
209        let max_count = i64::try_from(count).unwrap_or(i64::MAX);
210        let queues: Vec<String> = queues.iter().map(|&s| s.to_owned()).collect();
211        let results = sqlx::query_as!(
212            JobReservationResult,
213            r#"
214                -- We first grab a few jobs that are available,
215                -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently
216                -- and we don't get multiple workers grabbing the same jobs
217                WITH locked_jobs AS (
218                    SELECT queue_job_id
219                    FROM queue_jobs
220                    WHERE
221                        status = 'available'
222                        AND queue_name = ANY($1)
223                    ORDER BY queue_job_id ASC
224                    LIMIT $2
225                    FOR UPDATE
226                    SKIP LOCKED
227                )
228                -- then we update the status of those jobs to 'running', returning the job details
229                UPDATE queue_jobs
230                SET status = 'running', started_at = $3, started_by = $4
231                FROM locked_jobs
232                WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id
233                RETURNING
234                    queue_jobs.queue_job_id,
235                    queue_jobs.queue_name,
236                    queue_jobs.payload,
237                    queue_jobs.metadata,
238                    queue_jobs.attempt
239            "#,
240            &queues,
241            max_count,
242            now,
243            Uuid::from(worker.id),
244        )
245        .traced()
246        .fetch_all(&mut *self.conn)
247        .await?;
248
249        let jobs = results
250            .into_iter()
251            .map(TryFrom::try_from)
252            .collect::<Result<Vec<_>, _>>()?;
253
254        Ok(jobs)
255    }
256
257    #[tracing::instrument(
258        name = "db.queue_job.mark_as_completed",
259        skip_all,
260        fields(
261            db.query.text,
262            job.id = %id,
263        ),
264        err,
265    )]
266    async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error> {
267        let now = clock.now();
268        let res = sqlx::query!(
269            r#"
270                UPDATE queue_jobs
271                SET status = 'completed', completed_at = $1
272                WHERE queue_job_id = $2 AND status = 'running'
273            "#,
274            now,
275            Uuid::from(id),
276        )
277        .traced()
278        .execute(&mut *self.conn)
279        .await?;
280
281        DatabaseError::ensure_affected_rows(&res, 1)?;
282
283        Ok(())
284    }
285
286    #[tracing::instrument(
287        name = "db.queue_job.mark_as_failed",
288        skip_all,
289        fields(
290            db.query.text,
291            job.id = %id,
292        ),
293        err
294    )]
295    async fn mark_as_failed(
296        &mut self,
297        clock: &dyn Clock,
298        id: Ulid,
299        reason: &str,
300    ) -> Result<(), Self::Error> {
301        let now = clock.now();
302        let res = sqlx::query!(
303            r#"
304                UPDATE queue_jobs
305                SET
306                    status = 'failed',
307                    failed_at = $1,
308                    failed_reason = $2
309                WHERE
310                    queue_job_id = $3
311                    AND status = 'running'
312            "#,
313            now,
314            reason,
315            Uuid::from(id),
316        )
317        .traced()
318        .execute(&mut *self.conn)
319        .await?;
320
321        DatabaseError::ensure_affected_rows(&res, 1)?;
322
323        Ok(())
324    }
325
326    #[tracing::instrument(
327        name = "db.queue_job.retry",
328        skip_all,
329        fields(
330            db.query.text,
331            job.id = %id,
332        ),
333        err
334    )]
335    async fn retry(
336        &mut self,
337        rng: &mut (dyn RngCore + Send),
338        clock: &dyn Clock,
339        id: Ulid,
340        delay: Duration,
341    ) -> Result<(), Self::Error> {
342        let now = clock.now();
343        let scheduled_at = now + delay;
344        let new_id = Ulid::from_datetime_with_source(now.into(), rng);
345
346        let span = tracing::info_span!(
347            "db.queue_job.retry.insert_job",
348            { DB_QUERY_TEXT } = tracing::field::Empty
349        );
350        // Create a new job with the same payload and metadata, but a new ID and
351        // increment the attempt
352        // We make sure we do this only for 'failed' jobs
353        let res = sqlx::query!(
354            r#"
355                INSERT INTO queue_jobs
356                    (queue_job_id, queue_name, payload, metadata, created_at,
357                     attempt, scheduled_at, schedule_name, status)
358                SELECT $1, queue_name, payload, metadata, $2, attempt + 1, $3, schedule_name, 'scheduled'
359                FROM queue_jobs
360                WHERE queue_job_id = $4
361                  AND status = 'failed'
362            "#,
363            Uuid::from(new_id),
364            now,
365            scheduled_at,
366            Uuid::from(id),
367        )
368        .record(&span)
369        .execute(&mut *self.conn)
370        .instrument(span)
371        .await?;
372
373        DatabaseError::ensure_affected_rows(&res, 1)?;
374
375        // If that job was referenced by a schedule, update the schedule
376        let span = tracing::info_span!(
377            "db.queue_job.retry.update_schedule",
378            { DB_QUERY_TEXT } = tracing::field::Empty
379        );
380        sqlx::query!(
381            r#"
382                UPDATE queue_schedules
383                SET last_scheduled_at = $1,
384                    last_scheduled_job_id = $2
385                WHERE last_scheduled_job_id = $3
386            "#,
387            scheduled_at,
388            Uuid::from(new_id),
389            Uuid::from(id),
390        )
391        .record(&span)
392        .execute(&mut *self.conn)
393        .instrument(span)
394        .await?;
395
396        // Update the old job to point to the new attempt
397        let span = tracing::info_span!(
398            "db.queue_job.retry.update_old_job",
399            { DB_QUERY_TEXT } = tracing::field::Empty
400        );
401        let res = sqlx::query!(
402            r#"
403                UPDATE queue_jobs
404                SET next_attempt_id = $1
405                WHERE queue_job_id = $2
406            "#,
407            Uuid::from(new_id),
408            Uuid::from(id),
409        )
410        .record(&span)
411        .execute(&mut *self.conn)
412        .instrument(span)
413        .await?;
414
415        DatabaseError::ensure_affected_rows(&res, 1)?;
416
417        Ok(())
418    }
419
420    #[tracing::instrument(
421        name = "db.queue_job.schedule_available_jobs",
422        skip_all,
423        fields(
424            db.query.text,
425        ),
426        err
427    )]
428    async fn schedule_available_jobs(&mut self, clock: &dyn Clock) -> Result<usize, Self::Error> {
429        let now = clock.now();
430        let res = sqlx::query!(
431            r#"
432                UPDATE queue_jobs
433                SET status = 'available'
434                WHERE
435                    status = 'scheduled'
436                    AND scheduled_at <= $1
437            "#,
438            now,
439        )
440        .traced()
441        .execute(&mut *self.conn)
442        .await?;
443
444        let count = res.rows_affected();
445        Ok(usize::try_from(count).unwrap_or(usize::MAX))
446    }
447}