mas_storage_pg/queue/
worker.rs

1// Copyright 2024, 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4// Please see LICENSE files in the repository root for full details.
5
6//! A module containing the PostgreSQL implementation of the
7//! [`QueueWorkerRepository`].
8
9use async_trait::async_trait;
10use chrono::Duration;
11use mas_data_model::Clock;
12use mas_storage::queue::{QueueWorkerRepository, Worker};
13use rand::RngCore;
14use sqlx::PgConnection;
15use ulid::Ulid;
16use uuid::Uuid;
17
18use crate::{DatabaseError, ExecuteExt};
19
20/// An implementation of [`QueueWorkerRepository`] for a PostgreSQL connection.
21pub struct PgQueueWorkerRepository<'c> {
22    conn: &'c mut PgConnection,
23}
24
25impl<'c> PgQueueWorkerRepository<'c> {
26    /// Create a new [`PgQueueWorkerRepository`] from an active PostgreSQL
27    /// connection.
28    #[must_use]
29    pub fn new(conn: &'c mut PgConnection) -> Self {
30        Self { conn }
31    }
32}
33
34#[async_trait]
35impl QueueWorkerRepository for PgQueueWorkerRepository<'_> {
36    type Error = DatabaseError;
37
38    #[tracing::instrument(
39        name = "db.queue_worker.register",
40        skip_all,
41        fields(
42            worker.id,
43            db.query.text,
44        ),
45        err,
46    )]
47    async fn register(
48        &mut self,
49        rng: &mut (dyn RngCore + Send),
50        clock: &dyn Clock,
51    ) -> Result<Worker, Self::Error> {
52        let now = clock.now();
53        let worker_id = Ulid::from_datetime_with_source(now.into(), rng);
54        tracing::Span::current().record("worker.id", tracing::field::display(worker_id));
55
56        sqlx::query!(
57            r#"
58                INSERT INTO queue_workers (queue_worker_id, registered_at, last_seen_at)
59                VALUES ($1, $2, $2)
60            "#,
61            Uuid::from(worker_id),
62            now,
63        )
64        .traced()
65        .execute(&mut *self.conn)
66        .await?;
67
68        Ok(Worker { id: worker_id })
69    }
70
71    #[tracing::instrument(
72        name = "db.queue_worker.heartbeat",
73        skip_all,
74        fields(
75            %worker.id,
76            db.query.text,
77        ),
78        err,
79    )]
80    async fn heartbeat(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> {
81        let now = clock.now();
82        let res = sqlx::query!(
83            r#"
84                UPDATE queue_workers
85                SET last_seen_at = $2
86                WHERE queue_worker_id = $1 AND shutdown_at IS NULL
87            "#,
88            Uuid::from(worker.id),
89            now,
90        )
91        .traced()
92        .execute(&mut *self.conn)
93        .await?;
94
95        // If no row was updated, the worker was shutdown so we return an error
96        DatabaseError::ensure_affected_rows(&res, 1)?;
97
98        Ok(())
99    }
100
101    #[tracing::instrument(
102        name = "db.queue_worker.shutdown",
103        skip_all,
104        fields(
105            %worker.id,
106            db.query.text,
107        ),
108        err,
109    )]
110    async fn shutdown(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> {
111        let now = clock.now();
112        let res = sqlx::query!(
113            r#"
114                UPDATE queue_workers
115                SET shutdown_at = $2
116                WHERE queue_worker_id = $1
117            "#,
118            Uuid::from(worker.id),
119            now,
120        )
121        .traced()
122        .execute(&mut *self.conn)
123        .await?;
124
125        DatabaseError::ensure_affected_rows(&res, 1)?;
126
127        // Remove the leader lease if we were holding it
128        let res = sqlx::query!(
129            r#"
130                DELETE FROM queue_leader
131                WHERE queue_worker_id = $1
132            "#,
133            Uuid::from(worker.id),
134        )
135        .traced()
136        .execute(&mut *self.conn)
137        .await?;
138
139        // If we were holding the leader lease, notify workers
140        if res.rows_affected() > 0 {
141            sqlx::query!(
142                r#"
143                    NOTIFY queue_leader_stepdown
144                "#,
145            )
146            .traced()
147            .execute(&mut *self.conn)
148            .await?;
149        }
150
151        Ok(())
152    }
153
154    #[tracing::instrument(
155        name = "db.queue_worker.shutdown_dead_workers",
156        skip_all,
157        fields(
158            db.query.text,
159        ),
160        err,
161    )]
162    async fn shutdown_dead_workers(
163        &mut self,
164        clock: &dyn Clock,
165        threshold: Duration,
166    ) -> Result<(), Self::Error> {
167        // Here the threshold is usually set to a few minutes, so we don't need to use
168        // the database time, as we can assume worker clocks have less than a minute
169        // skew between each other, else other things would break
170        let now = clock.now();
171        sqlx::query!(
172            r#"
173                UPDATE queue_workers
174                SET shutdown_at = $1
175                WHERE shutdown_at IS NULL
176                  AND last_seen_at < $2
177            "#,
178            now,
179            now - threshold,
180        )
181        .traced()
182        .execute(&mut *self.conn)
183        .await?;
184
185        Ok(())
186    }
187
188    #[tracing::instrument(
189        name = "db.queue_worker.remove_leader_lease_if_expired",
190        skip_all,
191        fields(
192            db.query.text,
193        ),
194        err,
195    )]
196    async fn remove_leader_lease_if_expired(
197        &mut self,
198        _clock: &dyn Clock,
199    ) -> Result<(), Self::Error> {
200        // `expires_at` is a rare exception where we use the database time, as this
201        // would be very sensitive to clock skew between workers
202        sqlx::query!(
203            r#"
204                DELETE FROM queue_leader
205                WHERE expires_at < NOW()
206            "#,
207        )
208        .traced()
209        .execute(&mut *self.conn)
210        .await?;
211
212        Ok(())
213    }
214
215    #[tracing::instrument(
216        name = "db.queue_worker.try_get_leader_lease",
217        skip_all,
218        fields(
219            %worker.id,
220            db.query.text,
221        ),
222        err,
223    )]
224    async fn try_get_leader_lease(
225        &mut self,
226        clock: &dyn Clock,
227        worker: &Worker,
228    ) -> Result<bool, Self::Error> {
229        let now = clock.now();
230        // The queue_leader table is meant to only have a single row, which conflicts on
231        // the `active` column
232
233        // If there is a conflict, we update the `expires_at` column ONLY IF the current
234        // leader is ourselves.
235
236        // `expires_at` is a rare exception where we use the database time, as this
237        // would be very sensitive to clock skew between workers
238        let res = sqlx::query!(
239            r#"
240                INSERT INTO queue_leader (elected_at, expires_at, queue_worker_id)
241                VALUES ($1, NOW() + INTERVAL '5 seconds', $2)
242                ON CONFLICT (active)
243                DO UPDATE SET expires_at = EXCLUDED.expires_at
244                WHERE queue_leader.queue_worker_id = $2
245            "#,
246            now,
247            Uuid::from(worker.id)
248        )
249        .traced()
250        .execute(&mut *self.conn)
251        .await?;
252
253        // We can then detect whether we are the leader or not by checking how many rows
254        // were affected by the upsert
255        let am_i_the_leader = res.rows_affected() == 1;
256
257        Ok(am_i_the_leader)
258    }
259}