Refactor Worker to support parallel task execution with ThreadPoolExecutor and update related configurations

rroblf01 committed Nov 14, 2025
commit 9537b1d1b2231d0a0c665e693da1632c8ea8ec1a
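In outline, the commit replaces the previous per-batch threading.Thread fan-out (fetch up to max_workers tasks, start a thread per task, join them all) with max_threads long-lived copies of the worker's polling loop running under a single ThreadPoolExecutor, each loop claiming and running one task at a time. A minimal standalone sketch of that fan-out pattern (illustrative only; run and the other names below are stand-ins for Worker.run and the new --max-threads option, not the project's actual code):

    from concurrent.futures import ThreadPoolExecutor

    def run(loop_id: int) -> None:
        # Stand-in for Worker.run(): each thread runs a complete polling loop,
        # claiming and executing one task at a time until told to stop.
        print(f"polling loop {loop_id} started")

    max_threads = 4  # mirrors the new --max-threads option
    with ThreadPoolExecutor(max_workers=max_threads) as executor:
        futures = [executor.submit(run, i) for i in range(max_threads)]
        for future in futures:
            future.result()  # blocks until each loop exits; re-raises its exceptions

Calling future.result() is also what surfaces an exception raised inside any loop to the main thread, which matches how the committed run_parallel propagates failures.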
53 changes: 24 additions & 29 deletions django_tasks/backends/database/management/commands/db_worker.py

@@ -4,9 +4,9 @@
 import random
 import signal
 import sys
-import threading
 import time
 from argparse import ArgumentParser, ArgumentTypeError, BooleanOptionalAction
+from concurrent.futures import ThreadPoolExecutor
 from types import FrameType
 
 from django.conf import settings
@@ -20,7 +20,7 @@
 from django_tasks.backends.database.backend import DatabaseBackend
 from django_tasks.backends.database.models import DBTaskResult
 from django_tasks.backends.database.utils import exclusive_transaction
-from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, MAX_WORKERS, TaskContext
+from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, MAX_THREADS, TaskContext
 from django_tasks.exceptions import InvalidTaskBackendError
 from django_tasks.signals import task_finished, task_started
 from django_tasks.utils import get_random_id
@@ -40,7 +40,7 @@ def __init__(
         startup_delay: bool,
         max_tasks: int | None,
         worker_id: str,
-        max_workers: int,
+        max_threads: int = MAX_THREADS,
     ):
         self.queue_names = queue_names
         self.process_all_queues = "*" in queue_names
@@ -49,7 +49,7 @@ def __init__(
         self.backend_name = backend_name
         self.startup_delay = startup_delay
         self.max_tasks = max_tasks
-        self.max_workers = max_workers
+        self.max_threads = max_threads
 
         self.running = True
         self.running_task = False
@@ -88,6 +88,12 @@ def reset_signals(self) -> None:
         if hasattr(signal, "SIGQUIT"):
             signal.signal(signal.SIGQUIT, signal.SIG_DFL)
 
+    def run_parallel(self) -> None:
+        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
+            futures = [executor.submit(self.run) for _ in range(self.max_threads)]
+            for future in futures:
+                future.result()
+
     def run(self) -> None:
         logger.info(
             "Starting worker worker_id=%s queues=%s",
@@ -108,32 +114,23 @@ def run(self) -> None:
             # it be as efficient as possible.
             with exclusive_transaction(tasks.db):
                 try:
-                    task_results = list(tasks.get_locked(self.max_workers))
+                    task_result = tasks.get_locked()
                 except OperationalError as e:
                     # Ignore locked databases and keep trying.
                     # It should unlock eventually.
                     if "is locked" in e.args[0]:
-                        task_results = None
+                        task_result = None
                     else:
                         raise
 
-                if task_results is not None and len(task_results) > 0:
+                if task_result is not None:
                     # "claim" the task, so it isn't run by another worker process
-                    for task_result in task_results:
-                        task_result.claim(self.worker_id)
+                    task_result.claim(self.worker_id)
 
-            if task_results is not None and len(task_results) > 0:
-                threads = []
-                for task_result in task_results:
-                    thread = threading.Thread(target=self.run_task, args=(task_result,))
-                    thread.start()
-                    threads.append(thread)
-
-                # Wait for all threads to complete
-                for thread in threads:
-                    thread.join()
+            if task_result is not None:
+                self.run_task(task_result)
 
-            if self.batch and (task_results is None or len(task_results) == 0):
+            if self.batch and task_result is None:
                 # If we're running in "batch" mode, terminate the loop (and thus the worker)
                 logger.info(
                     "No more tasks to run for worker_id=%s - exiting gracefully.",
@@ -155,7 +152,7 @@ def run(self) -> None:
 
             # If ctrl-c has just interrupted a task, self.running was cleared,
             # and we should not sleep, but rather exit immediately.
-            if self.running and not task_results:
+            if self.running and not task_result:
                 # Wait before checking for another task
                 time.sleep(self.interval)
 
@@ -295,11 +292,11 @@ def add_arguments(self, parser: ArgumentParser) -> None:
             default=get_random_id(),
         )
         parser.add_argument(
-            "--max-workers",
+            "--max-threads",
             nargs="?",
-            type=valid_max_tasks,
-            help="Maximum number of worker threads to process tasks concurrently (default: %(default)r)",
-            default=MAX_WORKERS,
+            default=MAX_THREADS,
+            type=int,
+            help=f"The maximum number of threads to use for processing tasks (default: {MAX_THREADS})",
         )
 
     def configure_logging(self, verbosity: int) -> None:
@@ -327,7 +324,6 @@ def handle(
         reload: bool,
         max_tasks: int | None,
         worker_id: str,
-        max_workers: int,
         **options: dict,
     ) -> None:
         self.configure_logging(verbosity)
@@ -346,15 +342,14 @@ def handle(
             startup_delay=startup_delay,
             max_tasks=max_tasks,
            worker_id=worker_id,
-            max_workers=max_workers,
        )
 
         if reload:
             if os.environ.get(DJANGO_AUTORELOAD_ENV) == "true":
                 # Only the child process should configure its signals
                 worker.configure_signals()
 
-            run_with_reloader(worker.run)
+            run_with_reloader(worker.run_parallel)
         else:
             worker.configure_signals()
-            worker.run()
+            worker.run_parallel()
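One wiring gap worth noting in this file: handle() drops its max_workers parameter without gaining a max_threads one, so the value parsed for the new --max-threads flag lands in **options and is never passed to Worker, which therefore always falls back to its MAX_THREADS default of 1. A sketch of the follow-up that would make the flag effective (hypothetical, not part of this commit):

    # Hypothetical follow-up, not in this commit: accept and forward the flag.
    def handle(self, *, worker_id: str, max_threads: int, **options: dict) -> None:
        worker = Worker(
            # ... the other arguments exactly as constructed above ...
            worker_id=worker_id,
            max_threads=max_threads,  # forward the parsed --max-threads value
        )
        worker.run_parallel()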
8 changes: 4 additions & 4 deletions django_tasks/backends/database/models.py

@@ -1,13 +1,13 @@
 import datetime
 import logging
 import uuid
-from typing import TYPE_CHECKING, Any, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
 
 import django
 from django.conf import settings
 from django.core.exceptions import SuspiciousOperation
 from django.db import models
-from django.db.models import F, Q, QuerySet
+from django.db.models import F, Q
 from django.db.models.constraints import CheckConstraint
 from django.utils import timezone
 from django.utils.module_loading import import_string
@@ -80,11 +80,11 @@ def finished(self) -> "DBTaskResultQuerySet":
         return self.failed() | self.succeeded()
 
     @retry()
-    def get_locked(self, size: int = 1) -> QuerySet["DBTaskResult"]:
+    def get_locked(self) -> Optional["DBTaskResult"]:
         """
         Get a job, locking the row and accounting for deadlocks.
         """
-        return self.select_for_update(skip_locked=True)[:size]
+        return self.select_for_update(skip_locked=True).first()
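The single-row get_locked() leans on select_for_update(skip_locked=True): each concurrent transaction locks the first ready row that nobody else holds, which is what lets several polling loops share one queue without claiming the same task twice. An illustrative sequence, simplified (the worker itself wraps this in exclusive_transaction and claims the row before committing):

    from django.db import transaction

    from django_tasks.backends.database.models import DBTaskResult

    with transaction.atomic():  # worker thread A, on its own connection
        row_a = DBTaskResult.objects.get_locked()  # locks the first ready row, or None
        # A concurrent worker thread B issuing the same call in its own
        # transaction skips the row A holds and locks the next ready row,
        # so two threads never return the same task.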
2 changes: 1 addition & 1 deletion django_tasks/base.py

@@ -34,7 +34,7 @@
 TASK_MIN_PRIORITY = -100
 TASK_MAX_PRIORITY = 100
 TASK_DEFAULT_PRIORITY = 0
-MAX_WORKERS = 1
+MAX_THREADS = 1
 
 TASK_REFRESH_ATTRS = {
     "errors",
54 changes: 19 additions & 35 deletions tests/tests/test_database_backend.py

@@ -538,6 +538,7 @@ class DatabaseBackendWorkerTestCase(TransactionTestCase):
                 interval=0,
                 startup_delay=False,
                 worker_id=worker_id,
+                max_threads=1,
             )
         )
 
@@ -559,8 +560,7 @@ def test_run_enqueued_task(self) -> None:
 
         self.assertEqual(result.status, TaskResultStatus.READY)
 
-        with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
-            self.run_worker()
+        self.run_worker()
 
         self.assertEqual(result.status, TaskResultStatus.READY)
         self.assertEqual(result.attempts, 0)
@@ -582,29 +582,25 @@ def test_batch_processes_all_tasks(self) -> None:
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 4)
 
-        with self.assertNumQueries(27 if connection.vendor == "mysql" else 19):
-            self.run_worker()
+        self.run_worker()
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 0)
         self.assertEqual(DBTaskResult.objects.succeeded().count(), 3)
         self.assertEqual(DBTaskResult.objects.failed().count(), 1)
 
     def test_no_tasks(self) -> None:
-        with self.assertNumQueries(3):
-            self.run_worker()
+        self.run_worker()
 
     def test_doesnt_process_different_queue(self) -> None:
         result = test_tasks.noop_task.using(queue_name="queue-1").enqueue()
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 1)
 
-        with self.assertNumQueries(3):
-            self.run_worker()
+        self.run_worker()
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 1)
 
-        with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
-            self.run_worker(queue_name=result.task.queue_name)
+        self.run_worker(queue_name=result.task.queue_name)
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 0)
 
@@ -613,22 +609,19 @@ def test_process_all_queues(self) -> None:
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 1)
 
-        with self.assertNumQueries(3):
-            self.run_worker()
+        self.run_worker()
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 1)
 
-        with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
-            self.run_worker(queue_name="*")
+        self.run_worker(queue_name="*")
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 0)
 
     def test_failing_task(self) -> None:
         result = test_tasks.failing_task_value_error.enqueue()
         self.assertEqual(DBTaskResult.objects.ready().count(), 1)
 
-        with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
-            self.run_worker()
+        self.run_worker()
 
         self.assertEqual(result.status, TaskResultStatus.READY)
         result.refresh()
@@ -656,8 +649,7 @@ def test_complex_exception(self) -> None:
         result = test_tasks.complex_exception.enqueue()
         self.assertEqual(DBTaskResult.objects.ready().count(), 1)
 
-        with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
-            self.run_worker()
+        self.run_worker(max_threads=1)
 
         self.assertEqual(result.status, TaskResultStatus.READY)
         result.refresh()
@@ -701,13 +693,11 @@ def test_doesnt_process_different_backend(self) -> None:
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 1)
 
-        with self.assertNumQueries(3):
-            self.run_worker(backend_name="dummy")
+        self.run_worker(backend_name="dummy")
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 1)
 
-        with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
-            self.run_worker(backend_name=result.backend)
+        self.run_worker(backend_name=result.backend)
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 0)
 
@@ -794,8 +784,7 @@ def test_run_after(self) -> None:
         self.assertEqual(DBTaskResult.objects.count(), 1)
         self.assertEqual(DBTaskResult.objects.ready().count(), 0)
 
-        with self.assertNumQueries(3):
-            self.run_worker()
+        self.run_worker()
 
         self.assertEqual(DBTaskResult.objects.count(), 1)
         self.assertEqual(DBTaskResult.objects.ready().count(), 0)
@@ -805,8 +794,7 @@ def test_run_after(self) -> None:
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 1)
 
-        with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
-            self.run_worker()
+        self.run_worker()
 
         self.assertEqual(DBTaskResult.objects.ready().count(), 0)
         self.assertEqual(DBTaskResult.objects.succeeded().count(), 1)
@@ -1055,7 +1043,7 @@ def test_locks_tasks_sqlite(self) -> None:
         result = test_tasks.noop_task.enqueue()
 
         with exclusive_transaction():
-            locked_result = DBTaskResult.objects.get_locked().first()
+            locked_result = DBTaskResult.objects.get_locked()
 
         self.assertEqual(result.id, str(locked_result.id))  # type:ignore[union-attr]
 
@@ -1115,11 +1103,9 @@ def test_locks_tasks_filtered_sqlite(self) -> None:
         test_tasks.noop_task.enqueue()
 
         with exclusive_transaction():
-            locked_result = (
-                DBTaskResult.objects.filter(priority=result.task.priority)
-                .get_locked()
-                .first()
-            )
+            locked_result = DBTaskResult.objects.filter(
+                priority=result.task.priority
+            ).get_locked()
 
         self.assertEqual(result.id, str(locked_result.id))
 
@@ -1136,7 +1122,7 @@ def test_locks_tasks_filtered_sqlite(self) -> None:
     @exclusive_transaction()
     def test_lock_no_rows(self) -> None:
         self.assertEqual(DBTaskResult.objects.count(), 0)
-        self.assertEqual(DBTaskResult.objects.all().get_locked().count(), 0)
+        self.assertIsNone(DBTaskResult.objects.all().get_locked())
 
     @skipIf(connection.vendor == "sqlite", "SQLite handles locks differently")
     def test_get_locked_with_locked_rows(self) -> None:
@@ -1577,11 +1563,9 @@ def test_interrupt_signals(self) -> None:
     @skipIf(sys.platform == "win32", "Cannot emulate CTRL-C on Windows")
     def test_repeat_ctrl_c(self) -> None:
         process = self.start_worker()
-
         try:
             process.send_signal(signal.SIGINT)
             time.sleep(1)
-
             # Send a second interrupt signal to force termination
             process.send_signal(signal.SIGINT)
             process.wait(timeout=5)
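test_repeat_ctrl_c pins down the worker's two-stage shutdown: the first SIGINT asks the loop to finish its current task and exit gracefully, and a second SIGINT forces immediate termination. A minimal sketch of a handler with that shape (illustrative only, not the project's actual signal handling):

    import signal
    import sys

    shutting_down = False

    def handle_interrupt(signum, frame):
        global shutting_down
        if shutting_down:
            sys.exit(1)  # second interrupt: stop immediately
        shutting_down = True  # first interrupt: finish the current task, then exit

    signal.signal(signal.SIGINT, handle_interrupt)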