1use async_trait::async_trait;
10use chrono::Duration;
11use mas_data_model::Clock;
12use mas_storage::queue::{QueueWorkerRepository, Worker};
13use rand::RngCore;
14use sqlx::PgConnection;
15use ulid::Ulid;
16use uuid::Uuid;
17
18use crate::{DatabaseError, ExecuteExt};
19
/// A [`QueueWorkerRepository`] implementation backed by a PostgreSQL
/// connection.
///
/// The repository does not own the connection: it holds a mutable borrow of
/// it for its whole lifetime `'c`, and every query it issues runs on that
/// borrowed connection.
pub struct PgQueueWorkerRepository<'c> {
    /// The borrowed connection that all queue-worker queries are executed on.
    conn: &'c mut PgConnection,
}
24
25impl<'c> PgQueueWorkerRepository<'c> {
26 #[must_use]
29 pub fn new(conn: &'c mut PgConnection) -> Self {
30 Self { conn }
31 }
32}
33
34#[async_trait]
35impl QueueWorkerRepository for PgQueueWorkerRepository<'_> {
36 type Error = DatabaseError;
37
38 #[tracing::instrument(
39 name = "db.queue_worker.register",
40 skip_all,
41 fields(
42 worker.id,
43 db.query.text,
44 ),
45 err,
46 )]
47 async fn register(
48 &mut self,
49 rng: &mut (dyn RngCore + Send),
50 clock: &dyn Clock,
51 ) -> Result<Worker, Self::Error> {
52 let now = clock.now();
53 let worker_id = Ulid::from_datetime_with_source(now.into(), rng);
54 tracing::Span::current().record("worker.id", tracing::field::display(worker_id));
55
56 sqlx::query!(
57 r#"
58 INSERT INTO queue_workers (queue_worker_id, registered_at, last_seen_at)
59 VALUES ($1, $2, $2)
60 "#,
61 Uuid::from(worker_id),
62 now,
63 )
64 .traced()
65 .execute(&mut *self.conn)
66 .await?;
67
68 Ok(Worker { id: worker_id })
69 }
70
71 #[tracing::instrument(
72 name = "db.queue_worker.heartbeat",
73 skip_all,
74 fields(
75 %worker.id,
76 db.query.text,
77 ),
78 err,
79 )]
80 async fn heartbeat(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> {
81 let now = clock.now();
82 let res = sqlx::query!(
83 r#"
84 UPDATE queue_workers
85 SET last_seen_at = $2
86 WHERE queue_worker_id = $1 AND shutdown_at IS NULL
87 "#,
88 Uuid::from(worker.id),
89 now,
90 )
91 .traced()
92 .execute(&mut *self.conn)
93 .await?;
94
95 DatabaseError::ensure_affected_rows(&res, 1)?;
97
98 Ok(())
99 }
100
101 #[tracing::instrument(
102 name = "db.queue_worker.shutdown",
103 skip_all,
104 fields(
105 %worker.id,
106 db.query.text,
107 ),
108 err,
109 )]
110 async fn shutdown(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> {
111 let now = clock.now();
112 let res = sqlx::query!(
113 r#"
114 UPDATE queue_workers
115 SET shutdown_at = $2
116 WHERE queue_worker_id = $1
117 "#,
118 Uuid::from(worker.id),
119 now,
120 )
121 .traced()
122 .execute(&mut *self.conn)
123 .await?;
124
125 DatabaseError::ensure_affected_rows(&res, 1)?;
126
127 let res = sqlx::query!(
129 r#"
130 DELETE FROM queue_leader
131 WHERE queue_worker_id = $1
132 "#,
133 Uuid::from(worker.id),
134 )
135 .traced()
136 .execute(&mut *self.conn)
137 .await?;
138
139 if res.rows_affected() > 0 {
141 sqlx::query!(
142 r#"
143 NOTIFY queue_leader_stepdown
144 "#,
145 )
146 .traced()
147 .execute(&mut *self.conn)
148 .await?;
149 }
150
151 Ok(())
152 }
153
154 #[tracing::instrument(
155 name = "db.queue_worker.shutdown_dead_workers",
156 skip_all,
157 fields(
158 db.query.text,
159 ),
160 err,
161 )]
162 async fn shutdown_dead_workers(
163 &mut self,
164 clock: &dyn Clock,
165 threshold: Duration,
166 ) -> Result<(), Self::Error> {
167 let now = clock.now();
171 sqlx::query!(
172 r#"
173 UPDATE queue_workers
174 SET shutdown_at = $1
175 WHERE shutdown_at IS NULL
176 AND last_seen_at < $2
177 "#,
178 now,
179 now - threshold,
180 )
181 .traced()
182 .execute(&mut *self.conn)
183 .await?;
184
185 Ok(())
186 }
187
188 #[tracing::instrument(
189 name = "db.queue_worker.remove_leader_lease_if_expired",
190 skip_all,
191 fields(
192 db.query.text,
193 ),
194 err,
195 )]
196 async fn remove_leader_lease_if_expired(
197 &mut self,
198 _clock: &dyn Clock,
199 ) -> Result<(), Self::Error> {
200 sqlx::query!(
203 r#"
204 DELETE FROM queue_leader
205 WHERE expires_at < NOW()
206 "#,
207 )
208 .traced()
209 .execute(&mut *self.conn)
210 .await?;
211
212 Ok(())
213 }
214
215 #[tracing::instrument(
216 name = "db.queue_worker.try_get_leader_lease",
217 skip_all,
218 fields(
219 %worker.id,
220 db.query.text,
221 ),
222 err,
223 )]
224 async fn try_get_leader_lease(
225 &mut self,
226 clock: &dyn Clock,
227 worker: &Worker,
228 ) -> Result<bool, Self::Error> {
229 let now = clock.now();
230 let res = sqlx::query!(
239 r#"
240 INSERT INTO queue_leader (elected_at, expires_at, queue_worker_id)
241 VALUES ($1, NOW() + INTERVAL '5 seconds', $2)
242 ON CONFLICT (active)
243 DO UPDATE SET expires_at = EXCLUDED.expires_at
244 WHERE queue_leader.queue_worker_id = $2
245 "#,
246 now,
247 Uuid::from(worker.id)
248 )
249 .traced()
250 .execute(&mut *self.conn)
251 .await?;
252
253 let am_i_the_leader = res.rows_affected() == 1;
256
257 Ok(am_i_the_leader)
258 }
259}