Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions django_tasks/backends/database/management/commands/db_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import random
import signal
import sys
import threading
import time
from argparse import ArgumentParser, ArgumentTypeError, BooleanOptionalAction
from types import FrameType
Expand All @@ -19,7 +20,7 @@
from django_tasks.backends.database.backend import DatabaseBackend
from django_tasks.backends.database.models import DBTaskResult
from django_tasks.backends.database.utils import exclusive_transaction
from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, TaskContext
from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, MAX_WORKERS, TaskContext
from django_tasks.exceptions import InvalidTaskBackendError
from django_tasks.signals import task_finished, task_started
from django_tasks.utils import get_random_id
Expand All @@ -39,6 +40,7 @@ def __init__(
startup_delay: bool,
max_tasks: int | None,
worker_id: str,
max_workers: int,
):
self.queue_names = queue_names
self.process_all_queues = "*" in queue_names
Expand All @@ -47,6 +49,7 @@ def __init__(
self.backend_name = backend_name
self.startup_delay = startup_delay
self.max_tasks = max_tasks
self.max_workers = max_workers

self.running = True
self.running_task = False
Expand Down Expand Up @@ -105,23 +108,32 @@ def run(self) -> None:
# it be as efficient as possible.
with exclusive_transaction(tasks.db):
try:
task_result = tasks.get_locked()
task_results = list(tasks.get_locked(self.max_workers))
except OperationalError as e:
# Ignore locked databases and keep trying.
# It should unlock eventually.
if "is locked" in e.args[0]:
task_result = None
task_results = None
else:
raise

if task_result is not None:
if task_results is not None and len(task_results) > 0:
# "claim" the task, so it isn't run by another worker process
task_result.claim(self.worker_id)
for task_result in task_results:
task_result.claim(self.worker_id)

if task_result is not None:
self.run_task(task_result)
if task_results is not None and len(task_results) > 0:
threads = []
for task_result in task_results:
thread = threading.Thread(target=self.run_task, args=(task_result,))
thread.start()
threads.append(thread)

if self.batch and task_result is None:
# Wait for all threads to complete
for thread in threads:
thread.join()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: I don't think this approach is ideal. If a worker process is set to run 5 threads, and receives 4 fast tasks and 1 long task, the worker will sit processing the long task and never pick up the extra 4 tasks is has capacity for.


if self.batch and (task_results is None or len(task_results) == 0):
# If we're running in "batch" mode, terminate the loop (and thus the worker)
logger.info(
"No more tasks to run for worker_id=%s - exiting gracefully.",
Expand All @@ -143,7 +155,7 @@ def run(self) -> None:

# If ctrl-c has just interrupted a task, self.running was cleared,
# and we should not sleep, but rather exit immediately.
if self.running and not task_result:
if self.running and not task_results:
# Wait before checking for another task
time.sleep(self.interval)

Expand Down Expand Up @@ -282,6 +294,13 @@ def add_arguments(self, parser: ArgumentParser) -> None:
help="Worker id. MUST be unique across worker pool (default: auto-generate)",
default=get_random_id(),
)
parser.add_argument(
"--max-workers",
nargs="?",
type=valid_max_tasks,
help="Maximum number of worker threads to process tasks concurrently (default: %(default)r)",
default=MAX_WORKERS,
)

def configure_logging(self, verbosity: int) -> None:
if verbosity == 0:
Expand All @@ -308,6 +327,7 @@ def handle(
reload: bool,
max_tasks: int | None,
worker_id: str,
max_workers: int,
**options: dict,
) -> None:
self.configure_logging(verbosity)
Expand All @@ -326,6 +346,7 @@ def handle(
startup_delay=startup_delay,
max_tasks=max_tasks,
worker_id=worker_id,
max_workers=max_workers,
)

if reload:
Expand Down
8 changes: 4 additions & 4 deletions django_tasks/backends/database/models.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import datetime
import logging
import uuid
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import django
from django.conf import settings
from django.core.exceptions import SuspiciousOperation
from django.db import models
from django.db.models import F, Q
from django.db.models import F, Q, QuerySet
from django.db.models.constraints import CheckConstraint
from django.utils import timezone
from django.utils.module_loading import import_string
Expand Down Expand Up @@ -80,11 +80,11 @@ def finished(self) -> "DBTaskResultQuerySet":
return self.failed() | self.succeeded()

@retry()
def get_locked(self) -> Optional["DBTaskResult"]:
def get_locked(self, size: int = 1) -> QuerySet["DBTaskResult"]:
"""
Get a job, locking the row and accounting for deadlocks.
"""
return self.select_for_update(skip_locked=True).first()
return self.select_for_update(skip_locked=True)[:size]


class DBTaskResult(GenericBase[P, T], models.Model):
Expand Down
1 change: 1 addition & 0 deletions django_tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
TASK_MIN_PRIORITY = -100
TASK_MAX_PRIORITY = 100
TASK_DEFAULT_PRIORITY = 0
MAX_WORKERS = 1

TASK_REFRESH_ATTRS = {
"errors",
Expand Down
70 changes: 27 additions & 43 deletions tests/tests/test_database_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def test_run_enqueued_task(self) -> None:

self.assertEqual(result.status, TaskResultStatus.READY)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker()

self.assertEqual(result.status, TaskResultStatus.READY)
Expand All @@ -582,7 +582,7 @@ def test_batch_processes_all_tasks(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 4)

with self.assertNumQueries(27 if connection.vendor == "mysql" else 23):
with self.assertNumQueries(27 if connection.vendor == "mysql" else 19):
self.run_worker()

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
Expand All @@ -603,7 +603,7 @@ def test_doesnt_process_different_queue(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker(queue_name=result.task.queue_name)

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
Expand All @@ -618,7 +618,7 @@ def test_process_all_queues(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker(queue_name="*")

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
Expand All @@ -627,7 +627,7 @@ def test_failing_task(self) -> None:
result = test_tasks.failing_task_value_error.enqueue()
self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker()

self.assertEqual(result.status, TaskResultStatus.READY)
Expand Down Expand Up @@ -656,7 +656,7 @@ def test_complex_exception(self) -> None:
result = test_tasks.complex_exception.enqueue()
self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker()

self.assertEqual(result.status, TaskResultStatus.READY)
Expand Down Expand Up @@ -706,7 +706,7 @@ def test_doesnt_process_different_backend(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker(backend_name=result.backend)

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
Expand Down Expand Up @@ -805,7 +805,7 @@ def test_run_after(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker()

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
Expand Down Expand Up @@ -1055,7 +1055,7 @@ def test_locks_tasks_sqlite(self) -> None:
result = test_tasks.noop_task.enqueue()

with exclusive_transaction():
locked_result = DBTaskResult.objects.get_locked()
locked_result = DBTaskResult.objects.get_locked().first()

self.assertEqual(result.id, str(locked_result.id)) # type:ignore[union-attr]

Expand Down Expand Up @@ -1115,9 +1115,11 @@ def test_locks_tasks_filtered_sqlite(self) -> None:
test_tasks.noop_task.enqueue()

with exclusive_transaction():
locked_result = DBTaskResult.objects.filter(
priority=result.task.priority
).get_locked()
locked_result = (
DBTaskResult.objects.filter(priority=result.task.priority)
.get_locked()
.first()
)

self.assertEqual(result.id, str(locked_result.id))

Expand All @@ -1134,7 +1136,7 @@ def test_locks_tasks_filtered_sqlite(self) -> None:
@exclusive_transaction()
def test_lock_no_rows(self) -> None:
self.assertEqual(DBTaskResult.objects.count(), 0)
self.assertIsNone(DBTaskResult.objects.all().get_locked())
self.assertEqual(DBTaskResult.objects.all().get_locked().count(), 0)

@skipIf(connection.vendor == "sqlite", "SQLite handles locks differently")
def test_get_locked_with_locked_rows(self) -> None:
Expand Down Expand Up @@ -1574,38 +1576,20 @@ def test_interrupt_signals(self) -> None:

@skipIf(sys.platform == "win32", "Cannot emulate CTRL-C on Windows")
def test_repeat_ctrl_c(self) -> None:
result = test_tasks.hang.enqueue()
self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [])

worker_id = get_random_id()

process = self.start_worker(worker_id=worker_id)

# Make sure the task is running by now
time.sleep(self.WORKER_STARTUP_TIME)

result.refresh()
self.assertEqual(result.status, TaskResultStatus.RUNNING)
self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id])

process.send_signal(signal.SIGINT)

time.sleep(0.5)

self.assertIsNone(process.poll())
result.refresh()
self.assertEqual(result.status, TaskResultStatus.RUNNING)
self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id])

process.send_signal(signal.SIGINT)

process.wait(timeout=2)
process = self.start_worker()

self.assertEqual(process.returncode, 0)
try:
process.send_signal(signal.SIGINT)
time.sleep(1)

result.refresh()
self.assertEqual(result.status, TaskResultStatus.FAILED)
self.assertEqual(result.errors[0].exception_class, SystemExit)
# Send a second interrupt signal to force termination
process.send_signal(signal.SIGINT)
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.terminate()
process.wait(timeout=5)
finally:
self.assertEqual(process.poll(), -2)

@skipIf(sys.platform == "win32", "Windows doesn't support SIGKILL")
def test_kill(self) -> None:
Expand Down
Loading