mas_tasks/new_queue.rs

// Copyright 2024, 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

use std::{collections::HashMap, sync::Arc};

use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use cron::Schedule;
use mas_context::LogContext;
use mas_data_model::Clock;
use mas_storage::{
    RepositoryAccess, RepositoryError,
    queue::{InsertableJob, Job, JobMetadata, Worker},
};
use mas_storage_pg::{DatabaseError, PgRepository};
use opentelemetry::{
    KeyValue,
    metrics::{Counter, Histogram, UpDownCounter},
};
use rand::{Rng, RngCore, distributions::Uniform};
use serde::de::DeserializeOwned;
use sqlx::{
    Acquire, Either,
    postgres::{PgAdvisoryLock, PgListener},
};
use thiserror::Error;
use tokio::{task::JoinSet, time::Instant};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument as _, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt as _;
use ulid::Ulid;

use crate::{METER, State};

type JobPayload = serde_json::Value;

#[derive(Clone)]
pub struct JobContext {
    pub id: Ulid,
    pub metadata: JobMetadata,
    pub queue_name: String,
    pub attempt: usize,
    pub start: Instant,

    #[expect(
        dead_code,
        reason = "we're not yet using this, but will be in the future"
    )]
    pub cancellation_token: CancellationToken,
}

impl JobContext {
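    /// Create a tracing span for this job run, linked to the span recorded in
    /// the job metadata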
    pub fn span(&self) -> Span {
        let span = tracing::info_span!(
            parent: Span::none(),
            "job.run",
            job.id = %self.id,
            job.queue.name = self.queue_name,
            job.attempt = self.attempt,
        );

        span.add_link(self.metadata.span_context());

        span
    }
}

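/// What the queue should do with a job that failed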
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JobErrorDecision {
    Retry,

    #[default]
    Fail,
}

impl std::fmt::Display for JobErrorDecision {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Retry => f.write_str("retry"),
            Self::Fail => f.write_str("fail"),
        }
    }
}

#[derive(Debug, Error)]
#[error("Job failed to run, will {decision}")]
pub struct JobError {
    decision: JobErrorDecision,
    #[source]
    error: anyhow::Error,
}

impl JobError {
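    /// Wrap an error, asking for the job to be retried later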
    pub fn retry<T: Into<anyhow::Error>>(error: T) -> Self {
        Self {
            decision: JobErrorDecision::Retry,
            error: error.into(),
        }
    }

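    /// Wrap an error, marking the job as permanently failed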
    pub fn fail<T: Into<anyhow::Error>>(error: T) -> Self {
        Self {
            decision: JobErrorDecision::Fail,
            error: error.into(),
        }
    }
}

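/// Build a job from its raw JSON payload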
pub trait FromJob {
    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error>
    where
        Self: Sized;
}

impl<T> FromJob for T
where
    T: DeserializeOwned,
{
    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error> {
        serde_json::from_value(payload).map_err(Into::into)
    }
}

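/// A job which can be run by the queue worker
///
/// An illustrative sketch of an implementation; the `HypotheticalJob` type is
/// made up for documentation purposes and does not exist in this crate:
///
/// ```rust,ignore
/// #[derive(serde::Deserialize)]
/// struct HypotheticalJob;
///
/// #[async_trait]
/// impl RunnableJob for HypotheticalJob {
///     async fn run(&self, _state: &State, _context: JobContext) -> Result<(), JobError> {
///         // Do the actual work here, mapping errors to `JobError::retry` for
///         // transient failures or `JobError::fail` for permanent ones
///         Ok(())
///     }
/// }
/// ```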
#[async_trait]
pub trait RunnableJob: FromJob + Send + 'static {
    async fn run(&self, state: &State, context: JobContext) -> Result<(), JobError>;
}

fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
    Box::new(job)
}

#[derive(Debug, Error)]
pub enum QueueRunnerError {
    #[error("Failed to setup listener")]
    SetupListener(#[source] sqlx::Error),

    #[error("Failed to start transaction")]
    StartTransaction(#[source] sqlx::Error),

    #[error("Failed to commit transaction")]
    CommitTransaction(#[source] sqlx::Error),

    #[error("Failed to acquire leader lock")]
    LeaderLock(#[source] sqlx::Error),

    #[error(transparent)]
    Repository(#[from] RepositoryError),

    #[error(transparent)]
    Database(#[from] DatabaseError),

    #[error("Invalid schedule expression")]
    InvalidSchedule(#[from] cron::error::Error),

    #[error("Worker is not the leader")]
    NotLeader,
}

// When the worker waits for a notification, we still want to wake it up every
// second. Because we don't want all the workers to wake up at the same time, we
// add a random jitter to the sleep duration, so they effectively sleep between
// 0.9 and 1.1 seconds.
const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);

// How many jobs can we run concurrently
const MAX_CONCURRENT_JOBS: usize = 10;

// How many jobs can we fetch at once
const MAX_JOBS_TO_FETCH: usize = 5;

// Maximum number of attempts for a job before it is abandoned
const MAX_ATTEMPTS: usize = 10;

/// Returns the delay to wait before retrying a job
///
/// Uses an exponential backoff: 5s, 10s, 20s, 40s, 1m20s, 2m40s, 5m20s, 10m40s,
/// 21m20s, 42m40s
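///
/// For example, under that formula the first delays work out to:
///
/// ```rust,ignore
/// assert_eq!(retry_delay(0).num_seconds(), 5);
/// assert_eq!(retry_delay(3).num_seconds(), 40);
/// assert_eq!(retry_delay(4).num_seconds(), 80); // 1m20s
/// ```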
fn retry_delay(attempt: usize) -> Duration {
    let attempt = u32::try_from(attempt).unwrap_or(u32::MAX);
    Duration::milliseconds(2_i64.saturating_pow(attempt) * 5_000)
}

type JobResult = (std::time::Duration, Result<(), JobError>);
type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;

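/// A recurring job definition: which queue it goes on, with what payload, and
/// on which cron schedule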
struct ScheduleDefinition {
    schedule_name: &'static str,
    expression: Schedule,
    queue_name: &'static str,
    payload: serde_json::Value,
}

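/// The queue worker: it reserves jobs from the database, runs them through a
/// `JobTracker`, and performs the leader duties (scheduling recurring jobs,
/// reaping dead workers) whenever it holds the leader lease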
pub struct QueueWorker {
    listener: PgListener,
    registration: Worker,
    am_i_leader: bool,
    last_heartbeat: DateTime<Utc>,
    cancellation_token: CancellationToken,
    #[expect(dead_code, reason = "This is used on Drop")]
    cancellation_guard: tokio_util::sync::DropGuard,
    state: State,
    schedules: Vec<ScheduleDefinition>,
    tracker: JobTracker,
    wakeup_reason: Counter<u64>,
    tick_time: Histogram<u64>,
}

impl QueueWorker {
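    /// Create and register a new worker: connect a dedicated `PgListener`,
    /// subscribe to the `queue_leader_stepdown` and `queue_available`
    /// notification channels, and record the worker in the database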
    #[tracing::instrument(
        name = "worker.init",
        skip_all,
        fields(worker.id)
    )]
    pub(crate) async fn new(
        state: State,
        cancellation_token: CancellationToken,
    ) -> Result<Self, QueueRunnerError> {
        let mut rng = state.rng();
        let clock = state.clock();

        let mut listener = PgListener::connect_with(&state.pool())
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        // We get notifications of leader stepping down on this channel
        listener
            .listen("queue_leader_stepdown")
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        // We get notifications when a job is available on this channel
        listener
            .listen("queue_available")
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        let txn = listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        let registration = repo.queue_worker().register(&mut rng, clock).await?;
        tracing::Span::current().record("worker.id", tracing::field::display(registration.id));
        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        tracing::info!(worker.id = %registration.id, "Registered worker");
        let now = clock.now();

        let wakeup_reason = METER
            .u64_counter("job.worker.wakeups")
            .with_description("Counts how many times the worker has been woken up, and for which reason")
            .build();

        // Pre-create the reasons on the counter
        wakeup_reason.add(0, &[KeyValue::new("reason", "sleep")]);
        wakeup_reason.add(0, &[KeyValue::new("reason", "task")]);
        wakeup_reason.add(0, &[KeyValue::new("reason", "notification")]);

        let tick_time = METER
            .u64_histogram("job.worker.tick_duration")
            .with_description(
                "How much time the worker took to tick, including performing leader duties",
            )
            .build();

        // We put a cancellation drop guard in the structure, so that when it gets
        // dropped, we're sure to cancel the token
        let cancellation_guard = cancellation_token.clone().drop_guard();

        Ok(Self {
            listener,
            registration,
            am_i_leader: false,
            last_heartbeat: now,
            cancellation_token,
            cancellation_guard,
            state,
            schedules: Vec::new(),
            tracker: JobTracker::new(),
            wakeup_reason,
            tick_time,
        })
    }

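    /// Register a handler for a job type, so that jobs on its queue get picked
    /// up and deserialized by this worker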
    pub(crate) fn register_handler<T: RunnableJob + InsertableJob>(&mut self) -> &mut Self {
        // There is a potential panic here, which is fine as it's going to be caught
        // within the job task
        let factory = |payload: JobPayload| {
            box_runnable_job(T::from_job(payload).expect("Failed to deserialize job"))
        };

        self.tracker
            .factories
            .insert(T::QUEUE_NAME, Arc::new(factory));
        self
    }

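    /// Register a recurring job, inserted on its queue according to the given
    /// cron expression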
    pub(crate) fn add_schedule<T: InsertableJob>(
        &mut self,
        schedule_name: &'static str,
        expression: Schedule,
        job: T,
    ) -> &mut Self {
        let payload = serde_json::to_value(job).expect("failed to serialize job payload");

        self.schedules.push(ScheduleDefinition {
            schedule_name,
            expression,
            queue_name: T::QUEUE_NAME,
            payload,
        });

        self
    }

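    /// Run the worker loop until the cancellation token is triggered, logging
    /// any error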
    pub(crate) async fn run(mut self) {
        if let Err(e) = self.run_inner().await {
            tracing::error!(
                error = &e as &dyn std::error::Error,
                "Failed to run new queue"
            );
        }
    }

    async fn run_inner(&mut self) -> Result<(), QueueRunnerError> {
        self.setup_schedules().await?;

        while !self.cancellation_token.is_cancelled() {
            LogContext::new("worker-run-loop")
                .run(|| self.run_loop())
                .await?;
        }

        self.shutdown().await?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.setup_schedules", skip_all)]
    pub(crate) async fn setup_schedules(&mut self) -> Result<(), QueueRunnerError> {
        let schedules: Vec<_> = self.schedules.iter().map(|s| s.schedule_name).collect();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        let mut repo = PgRepository::from_conn(txn);

        // Setup the entries in the queue_schedules table
        repo.queue_schedule().setup(&schedules).await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.run_loop", skip_all)]
    async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
        self.wait_until_wakeup().await?;

        if self.cancellation_token.is_cancelled() {
            return Ok(());
        }

        let start = Instant::now();
        self.tick().await?;

        if self.am_i_leader {
            self.perform_leader_duties().await?;
        }

        let elapsed = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
        self.tick_time.record(elapsed, &[]);

        Ok(())
    }

    #[tracing::instrument(name = "worker.shutdown", skip_all)]
    async fn shutdown(&mut self) -> Result<(), QueueRunnerError> {
        tracing::info!("Shutting down worker");

        let clock = self.state.clock();
        let mut rng = self.state.rng();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        let mut repo = PgRepository::from_conn(txn);

        // Log about any job still running
        match self.tracker.running_jobs() {
            0 => {}
            1 => tracing::warn!("There is one job still running, waiting for it to finish"),
            n => tracing::warn!("There are {n} jobs still running, waiting for them to finish"),
        }

        // TODO: we may want to introduce a timeout here, and abort the tasks if they
        // take too long. It's fine for now, as we don't have long-running
        // tasks, most of them are idempotent, and the only effect might be that
        // the worker would shut down 'dirtily', meaning that its tasks would be
        // considered lost and later retried by another worker

        // Wait for all the jobs to finish
        self.tracker
            .process_jobs(&mut rng, clock, &mut repo, true)
            .await?;

        // Tell the other workers we're shutting down
        // This also releases the leader election lease
        repo.queue_worker()
            .shutdown(clock, &self.registration)
            .await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.wait_until_wakeup", skip_all)]
    async fn wait_until_wakeup(&mut self) -> Result<(), QueueRunnerError> {
        let mut rng = self.state.rng();

        // This is to make sure we wake up every second to do the maintenance tasks
        // We add a little bit of random jitter to the duration, so that we don't get
        // fully synced workers waking up at the same time after each notification
        let sleep_duration = rng.sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION));
        let wakeup_sleep = tokio::time::sleep(sleep_duration);

        tokio::select! {
            () = self.cancellation_token.cancelled() => {
                tracing::debug!("Woke up from cancellation");
            },

            () = wakeup_sleep => {
                tracing::debug!("Woke up from sleep");
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "sleep")]);
            },

            () = self.tracker.collect_next_job(), if self.tracker.has_jobs() => {
                tracing::debug!("Joined job task");
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "task")]);
            },

            notification = self.listener.recv() => {
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "notification")]);
                match notification {
                    Ok(notification) => {
                        tracing::debug!(
                            notification.channel = notification.channel(),
                            notification.payload = notification.payload(),
                            "Woke up from notification"
                        );
                    },
                    Err(e) => {
                        tracing::error!(error = &e as &dyn std::error::Error, "Failed to receive notification");
                    },
                }
            },
        }

        Ok(())
    }

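    /// A single tick of the worker: send a heartbeat if needed, refresh the
    /// leader lease, collect finished job tasks and reserve new jobs to run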
    #[tracing::instrument(
        name = "worker.tick",
        skip_all,
        fields(worker.id = %self.registration.id),
    )]
    async fn tick(&mut self) -> Result<(), QueueRunnerError> {
        tracing::debug!("Tick");
        let clock = self.state.clock();
        let mut rng = self.state.rng();
        let now = clock.now();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        // We send a heartbeat every minute, to avoid writing to the database too often
        // on a logged table
        if now - self.last_heartbeat >= chrono::Duration::minutes(1) {
            tracing::info!("Sending heartbeat");
            repo.queue_worker()
                .heartbeat(clock, &self.registration)
                .await?;
            self.last_heartbeat = now;
        }

        // Remove any dead worker leader leases
        repo.queue_worker()
            .remove_leader_lease_if_expired(clock)
            .await?;

        // Try to become (or stay) the leader
        let leader = repo
            .queue_worker()
            .try_get_leader_lease(clock, &self.registration)
            .await?;

        // Process any job task which finished
        self.tracker
            .process_jobs(&mut rng, clock, &mut repo, false)
            .await?;

        // Compute how many jobs we should fetch at most: the number of available
        // slots, capped at MAX_JOBS_TO_FETCH
        let max_jobs_to_fetch = MAX_CONCURRENT_JOBS
            .saturating_sub(self.tracker.running_jobs())
            .min(MAX_JOBS_TO_FETCH);

        if max_jobs_to_fetch == 0 {
            tracing::warn!("Internal job queue is full, not fetching any new jobs");
        } else {
            // Grab a few jobs in the queue
            let queues = self.tracker.queues();
            let jobs = repo
                .queue_job()
                .reserve(clock, &self.registration, &queues, max_jobs_to_fetch)
                .await?;

            for Job {
                id,
                queue_name,
                payload,
                metadata,
                attempt,
            } in jobs
            {
                let cancellation_token = self.cancellation_token.child_token();
                let start = Instant::now();
                let context = JobContext {
                    id,
                    metadata,
                    queue_name,
                    attempt,
                    start,
                    cancellation_token,
                };

                self.tracker.spawn_job(self.state.clone(), context, payload);
            }
        }

        // After this point, we are locking the leader table, so it's important that we
        // commit as soon as possible to not block the other workers for too long
        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        // Save the new leader state to log any change
        if leader != self.am_i_leader {
            // If we flipped state, log it
            self.am_i_leader = leader;
            if self.am_i_leader {
                tracing::info!("I'm the leader now");
            } else {
                tracing::warn!("I am no longer the leader");
            }
        }

        Ok(())
    }

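    /// Duties performed only by the elected leader, under a Postgres advisory
    /// lock: schedule recurring jobs, shut down dead workers, and mark
    /// scheduled jobs as available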
    #[tracing::instrument(name = "worker.perform_leader_duties", skip_all)]
    async fn perform_leader_duties(&mut self) -> Result<(), QueueRunnerError> {
        // This should have been checked by the caller, but better safe than sorry
        if !self.am_i_leader {
            return Err(QueueRunnerError::NotLeader);
        }

        let clock = self.state.clock();
        let mut rng = self.state.rng();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        // The thing with the leader election is that it locks the table during the
        // election, preventing other workers from going through the loop.
        //
        // Ideally, we would do the leader duties in the same transaction so that we
        // make sure only one worker is doing the leader duties, but that
        // would mean we would lock all the workers for the duration of the
        // duties, which is not ideal.
        //
        // So we do the duties in a separate transaction, in which we take an advisory
        // lock, so that in the very rare case where two workers think they are the
        // leader, we still don't have two workers doing the duties at the same time.
        let lock = PgAdvisoryLock::new("leader-duties");

        let locked = lock
            .try_acquire(txn)
            .await
            .map_err(QueueRunnerError::LeaderLock)?;

        let locked = match locked {
            Either::Left(locked) => locked,
            Either::Right(txn) => {
                tracing::error!("Another worker has the leader lock, aborting");
                txn.rollback()
                    .await
                    .map_err(QueueRunnerError::CommitTransaction)?;
                return Ok(());
            }
        };

        let mut repo = PgRepository::from_conn(locked);

        // Look at the state of schedules in the database
        let schedules_status = repo.queue_schedule().list().await?;

        let now = clock.now();
        for schedule in &self.schedules {
            // Find the schedule status from the database
            let Some(status) = schedules_status
                .iter()
                .find(|s| s.schedule_name == schedule.schedule_name)
            else {
                tracing::error!(
                    "Schedule {} was not found in the database",
                    schedule.schedule_name
                );
                continue;
            };

            // Figure out if we should schedule a new job
            if let Some(next_time) = status.last_scheduled_at {
                if next_time > now {
                    // We already have a job scheduled in the future, skip
                    continue;
                }

                if status.last_scheduled_job_completed == Some(false) {
                    // The last scheduled job has not completed yet, skip
                    continue;
                }
            }

            let next_tick = schedule.expression.after(&now).next().unwrap();

            tracing::info!(
                "Scheduling job for {}, next run at {}",
                schedule.schedule_name,
                next_tick
            );

            repo.queue_job()
                .schedule_later(
                    &mut rng,
                    clock,
                    schedule.queue_name,
                    schedule.payload.clone(),
                    serde_json::json!({}),
                    next_tick,
                    Some(schedule.schedule_name),
                )
                .await?;
        }

        // We also shut down all the dead workers, i.e. those that haven't checked
        // in during the last two minutes
        repo.queue_worker()
            .shutdown_dead_workers(clock, Duration::minutes(2))
            .await?;

        // TODO: mark tasks those workers had as lost

        // Mark all the scheduled jobs as available
        let scheduled = repo.queue_job().schedule_available_jobs(clock).await?;
        match scheduled {
            0 => {}
            1 => tracing::info!("One scheduled job marked as available"),
            n => tracing::info!("{n} scheduled jobs marked as available"),
        }

        // Release the leader lock
        let txn = repo
            .into_inner()
            .release_now()
            .await
            .map_err(QueueRunnerError::LeaderLock)?;

        txn.commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    /// Process all the pending jobs in the queue.
    /// This should only be called in tests!
    ///
    /// # Errors
    ///
    /// This function can fail if the database connection fails.
    pub async fn process_all_jobs_in_tests(&mut self) -> Result<(), QueueRunnerError> {
        // I swear, I'm the leader!
        self.am_i_leader = true;

        // First, perform the leader duties. This will make sure that we schedule
        // recurring jobs.
        self.perform_leader_duties().await?;

        let clock = self.state.clock();
        let mut rng = self.state.rng();

        // Grab the connection from the PgListener
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        // Spawn all the jobs in the database
        let queues = self.tracker.queues();
        let jobs = repo
            .queue_job()
            // I really hope that we don't spawn more than 10k jobs in tests
            .reserve(clock, &self.registration, &queues, 10_000)
            .await?;

        for Job {
            id,
            queue_name,
            payload,
            metadata,
            attempt,
        } in jobs
        {
            let cancellation_token = self.cancellation_token.child_token();
            let start = Instant::now();
            let context = JobContext {
                id,
                metadata,
                queue_name,
                attempt,
                start,
                cancellation_token,
            };

            self.tracker.spawn_job(self.state.clone(), context, payload);
        }

        self.tracker
            .process_jobs(&mut rng, clock, &mut repo, true)
            .await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }
}

/// Tracks running jobs
///
/// This is a separate structure to be able to borrow it mutably at the same
/// time as the connection to the database is borrowed
struct JobTracker {
    /// Stores a mapping from the job queue name to the job factory
    factories: HashMap<&'static str, JobFactory>,

    /// A join set of all the currently running jobs
    running_jobs: JoinSet<JobResult>,

    /// Stores a mapping from the Tokio task ID to the job context
    job_contexts: HashMap<tokio::task::Id, JobContext>,

    /// Stores the last `join_next_with_id` result for processing, in case we
    /// got woken up in `collect_next_job`
    last_join_result: Option<Result<(tokio::task::Id, JobResult), tokio::task::JoinError>>,

    /// A histogram which records the time it takes to process a job
    job_processing_time: Histogram<u64>,

    /// A counter which records the number of jobs currently in flight
    in_flight_jobs: UpDownCounter<i64>,
}

impl JobTracker {
    fn new() -> Self {
        let job_processing_time = METER
            .u64_histogram("job.process.duration")
            .with_description("The time it takes to process a job in milliseconds")
            .with_unit("ms")
            .build();

        let in_flight_jobs = METER
            .i64_up_down_counter("job.active_tasks")
            .with_description("The number of jobs currently in flight")
            .with_unit("{job}")
            .build();

        Self {
            factories: HashMap::new(),
            running_jobs: JoinSet::new(),
            job_contexts: HashMap::new(),
            last_join_result: None,
            job_processing_time,
            in_flight_jobs,
        }
    }

    /// Returns the queue names that are currently being tracked
    fn queues(&self) -> Vec<&'static str> {
        self.factories.keys().copied().collect()
    }

    /// Spawn a job on the job tracker
    fn spawn_job(&mut self, state: State, context: JobContext, payload: JobPayload) {
        let factory = self.factories.get(context.queue_name.as_str()).cloned();
        let task = {
            let log_context = LogContext::new(format!("job-{}", context.queue_name));
            let context = context.clone();
            let span = context.span();
            log_context
                .run(async move || {
                    // We should never crash, but in case we do, we do that in the task and
                    // don't crash the worker
                    let job = factory.expect("unknown job factory")(payload);
                    tracing::info!(
                        job.id = %context.id,
                        job.queue.name = %context.queue_name,
                        job.attempt = %context.attempt,
                        "Running job"
                    );
                    let result = job.run(&state, context.clone()).await;

                    let Some(context_stats) =
                        LogContext::maybe_with(mas_context::LogContext::stats)
                    else {
                        // This should never happen, but if it does it's fine: we're recovering fine
                        // from panics in those tasks
                        panic!("Missing log context, this should never happen");
                    };

                    // We log the result here so that it's attached to the right span & log context
                    match &result {
                        Ok(()) => {
                            tracing::info!(
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job completed [{context_stats}]"
                            );
                        }

                        Err(JobError {
                            decision: JobErrorDecision::Fail,
                            error,
                        }) => {
                            tracing::error!(
                                error = &**error as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job failed, not retrying [{context_stats}]"
                            );
                        }

                        Err(JobError {
                            decision: JobErrorDecision::Retry,
                            error,
                        }) if context.attempt < MAX_ATTEMPTS => {
                            let delay = retry_delay(context.attempt);
                            tracing::warn!(
                                error = &**error as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job failed, will retry in {}s [{context_stats}]",
                                delay.num_seconds()
                            );
                        }

                        Err(JobError {
                            decision: JobErrorDecision::Retry,
                            error,
                        }) => {
                            tracing::error!(
                                error = &**error as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
916                                "Job failed too many times, abandonning [{context_stats}]"
                            );
                        }
                    }

                    (context_stats.elapsed, result)
                })
                .instrument(span)
        };

        self.in_flight_jobs.add(
            1,
            &[KeyValue::new("job.queue.name", context.queue_name.clone())],
        );

        let handle = self.running_jobs.spawn(task);
        self.job_contexts.insert(handle.id(), context);
    }

    /// Returns `true` if there are currently running jobs
    fn has_jobs(&self) -> bool {
        !self.running_jobs.is_empty()
    }

    /// Returns the number of currently running jobs
    ///
    /// This also includes the job result which may be stored for processing
    fn running_jobs(&self) -> usize {
        self.running_jobs.len() + usize::from(self.last_join_result.is_some())
    }

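    /// Wait for the next job task to finish and stash its result, to be
    /// processed by `process_jobs`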
    async fn collect_next_job(&mut self) {
        // Double-check that we don't have a job result stored
        if self.last_join_result.is_some() {
            tracing::error!(
                "Job tracker already had a job result stored, this should never happen!"
            );
            return;
        }

        self.last_join_result = self.running_jobs.join_next_with_id().await;
    }

    /// Process all the jobs which are currently running
    ///
    /// If `blocking` is `true`, this function will block until all the jobs
    /// are finished. Otherwise, it will return as soon as it processed the
    /// already finished jobs.
    async fn process_jobs<E: std::error::Error + Send + Sync + 'static>(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        repo: &mut dyn RepositoryAccess<Error = E>,
        blocking: bool,
    ) -> Result<(), E> {
        if self.last_join_result.is_none() {
            if blocking {
                self.last_join_result = self.running_jobs.join_next_with_id().await;
            } else {
                self.last_join_result = self.running_jobs.try_join_next_with_id();
            }
        }

        while let Some(result) = self.last_join_result.take() {
            match result {
                // The job succeeded. The logging and time measurement is already done in the task
                Ok((id, (elapsed, Ok(())))) => {
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
                    self.job_processing_time.record(
                        elapsed_ms,
                        &[
                            KeyValue::new("job.queue.name", context.queue_name),
                            KeyValue::new("job.result", "success"),
                        ],
                    );

                    repo.queue_job()
                        .mark_as_completed(clock, context.id)
                        .await?;
                }

                // The job failed. The logging and time measurement is already done in the task
                Ok((id, (elapsed, Err(e)))) => {
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    let reason = format!("{:?}", e.error);
                    repo.queue_job()
                        .mark_as_failed(clock, context.id, &reason)
                        .await?;

                    let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
                    match e.decision {
                        JobErrorDecision::Fail => {
                            self.job_processing_time.record(
                                elapsed_ms,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "fail"),
                                ],
                            );
                        }

                        JobErrorDecision::Retry if context.attempt < MAX_ATTEMPTS => {
                            self.job_processing_time.record(
                                elapsed_ms,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "retry"),
                                ],
                            );

                            let delay = retry_delay(context.attempt);
                            repo.queue_job()
                                .retry(&mut *rng, clock, context.id, delay)
                                .await?;
                        }

                        JobErrorDecision::Retry => {
                            self.job_processing_time.record(
                                elapsed_ms,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "abandon"),
                                ],
                            );
                        }
                    }
                }

                // The job crashed (or was aborted)
                Err(e) => {
                    let id = e.id();
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    // This measurement is not accurate as it includes the time processing the jobs,
                    // but it's fine, it's only for panicked tasks
                    let elapsed = context
                        .start
                        .elapsed()
                        .as_millis()
                        .try_into()
                        .unwrap_or(u64::MAX);

                    let reason = e.to_string();
                    repo.queue_job()
                        .mark_as_failed(clock, context.id, &reason)
                        .await?;

                    if context.attempt < MAX_ATTEMPTS {
                        let delay = retry_delay(context.attempt);
                        tracing::error!(
                            error = &e as &dyn std::error::Error,
                            job.id = %context.id,
                            job.queue.name = %context.queue_name,
                            job.attempt = %context.attempt,
                            job.elapsed = format!("{elapsed}ms"),
                            "Job crashed, will retry in {}s",
                            delay.num_seconds()
                        );

                        self.job_processing_time.record(
                            elapsed,
                            &[
                                KeyValue::new("job.queue.name", context.queue_name),
                                KeyValue::new("job.result", "crashed"),
                                KeyValue::new("job.decision", "retry"),
                            ],
                        );

                        repo.queue_job()
                            .retry(&mut *rng, clock, context.id, delay)
                            .await?;
                    } else {
                        tracing::error!(
                            error = &e as &dyn std::error::Error,
                            job.id = %context.id,
                            job.queue.name = %context.queue_name,
                            job.attempt = %context.attempt,
                            job.elapsed = format!("{elapsed}ms"),
1124                            "Job crashed too many times, abandonning"
                        );

                        self.job_processing_time.record(
                            elapsed,
                            &[
                                KeyValue::new("job.queue.name", context.queue_name),
                                KeyValue::new("job.result", "crashed"),
                                KeyValue::new("job.decision", "abandon"),
                            ],
                        );
                    }
                }
            }

            if blocking {
                self.last_join_result = self.running_jobs.join_next_with_id().await;
            } else {
                self.last_join_result = self.running_jobs.try_join_next_with_id();
            }
        }

        Ok(())
    }
}