Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions lib/private/TaskProcessing/Manager.php
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class Manager implements IManager {

public const MAX_TASK_AGE_SECONDS = 60 * 60 * 24 * 30 * 4; // 4 months

private const TASK_TYPES_CACHE_KEY = 'available_task_types_v2';
private const TASK_TYPE_IDS_CACHE_KEY = 'available_task_type_ids';

/** @var list<IProvider>|null */
private ?array $providers = null;

Expand All @@ -89,6 +92,9 @@ class Manager implements IManager {
*/
private ?array $availableTaskTypes = null;

/** @var list<string>|null */
private ?array $availableTaskTypeIds = null;

private IAppData $appData;
private ?array $preferences = null;
private ?array $providersById = null;
Expand Down Expand Up @@ -834,7 +840,7 @@ public function getAvailableTaskTypes(bool $showDisabled = false, ?string $userI
return [];
}
if ($this->availableTaskTypes === null) {
$cachedValue = $this->distributedCache->get('available_task_types_v2');
$cachedValue = $this->distributedCache->get(self::TASK_TYPES_CACHE_KEY);
if ($cachedValue !== null) {
$this->availableTaskTypes = unserialize($cachedValue);
}
Expand Down Expand Up @@ -880,12 +886,53 @@ public function getAvailableTaskTypes(bool $showDisabled = false, ?string $userI
}

$this->availableTaskTypes = $availableTaskTypes;
$this->distributedCache->set('available_task_types_v2', serialize($this->availableTaskTypes), 60);
$this->distributedCache->set(self::TASK_TYPES_CACHE_KEY, serialize($this->availableTaskTypes), 60);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙏

}


return $this->availableTaskTypes;
}
public function getAvailableTaskTypeIds(bool $showDisabled = false, ?string $userId = null): array {
// userId will be obtained from the session if left to null
if (!$this->checkGuestAccess($userId)) {
return [];
}
if ($this->availableTaskTypeIds === null) {
$cachedValue = $this->distributedCache->get(self::TASK_TYPE_IDS_CACHE_KEY);
if ($cachedValue !== null) {
$this->availableTaskTypeIds = $cachedValue;
}
}
// Either we have no cache or showDisabled is turned on, which we don't want to cache, ever.
if ($this->availableTaskTypeIds === null || $showDisabled) {
$taskTypes = $this->_getTaskTypes();
$taskTypeSettings = $this->_getTaskTypeSettings();

$availableTaskTypeIds = [];
foreach ($taskTypes as $taskType) {
if ((!$showDisabled) && isset($taskTypeSettings[$taskType->getId()]) && !$taskTypeSettings[$taskType->getId()]) {
continue;
}
try {
$provider = $this->getPreferredProvider($taskType->getId());
} catch (\OCP\TaskProcessing\Exception\Exception $e) {
continue;
}
$availableTaskTypeIds[] = $taskType->getId();
}

if ($showDisabled) {
// Do not cache showDisabled, ever.
return $availableTaskTypeIds;
}

$this->availableTaskTypeIds = $availableTaskTypeIds;
$this->distributedCache->set(self::TASK_TYPE_IDS_CACHE_KEY, $this->availableTaskTypeIds, 60);
}


return $this->availableTaskTypeIds;
}

public function canHandleTask(Task $task): bool {
return isset($this->getAvailableTaskTypes()[$task->getTaskTypeId()]);
Expand Down
10 changes: 9 additions & 1 deletion lib/public/TaskProcessing/IManager.php
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public function getProviders(): array;
public function getPreferredProvider(string $taskTypeId);

/**
* @param bool $showDisabled if false, disabled task types will be filtered
* @param bool $showDisabled if false, disabled task types will be filtered out
* @param ?string $userId to check if the user is a guest. Will be obtained from session if left to default
* @return array<string, array{name: string, description: string, inputShape: ShapeDescriptor[], inputShapeEnumValues: ShapeEnumValue[][], inputShapeDefaults: array<array-key, numeric|string>, optionalInputShape: ShapeDescriptor[], optionalInputShapeEnumValues: ShapeEnumValue[][], optionalInputShapeDefaults: array<array-key, numeric|string>, outputShape: ShapeDescriptor[], outputShapeEnumValues: ShapeEnumValue[][], optionalOutputShape: ShapeDescriptor[], optionalOutputShapeEnumValues: ShapeEnumValue[][]}>
* @since 30.0.0
Expand All @@ -56,6 +56,14 @@ public function getPreferredProvider(string $taskTypeId);
*/
public function getAvailableTaskTypes(bool $showDisabled = false, ?string $userId = null): array;

/**
* @param bool $showDisabled if false, disabled task types will be filtered out
* @param ?string $userId to check if the user is a guest. Will be obtained from session if left to default
* @return list<string>
* @since 32.0.0
*/
public function getAvailableTaskTypeIds(bool $showDisabled = false, ?string $userId = null): array;

/**
* @param Task $task The task to run
* @throws PreConditionNotMetException If no or not the requested provider was registered but this method was still called
Expand Down
20 changes: 20 additions & 0 deletions tests/lib/TaskProcessing/TaskProcessingTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ private function getFile(string $name, string $content): File {
public function testShouldNotHaveAnyProviders(): void {
$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]);
self::assertCount(0, $this->manager->getAvailableTaskTypes());
self::assertCount(0, $this->manager->getAvailableTaskTypeIds());
self::assertFalse($this->manager->hasProviders());
self::expectException(PreConditionNotMetException::class);
$this->manager->scheduleTask(new Task(TextToText::ID, ['input' => 'Hello'], 'test', null));
Expand All @@ -647,6 +648,8 @@ public function testProviderShouldBeRegisteredAndTaskTypeDisabled(): void {
$this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', json_encode($taskProcessingTypeSettings), lazy: true);
self::assertCount(0, $this->manager->getAvailableTaskTypes());
self::assertCount(1, $this->manager->getAvailableTaskTypes(true));
self::assertCount(0, $this->manager->getAvailableTaskTypeIds());
self::assertCount(1, $this->manager->getAvailableTaskTypeIds(true));
self::assertTrue($this->manager->hasProviders());
self::expectException(PreConditionNotMetException::class);
$this->manager->scheduleTask(new Task(TextToText::ID, ['input' => 'Hello'], 'test', null));
Expand All @@ -659,6 +662,7 @@ public function testProviderShouldBeRegisteredAndTaskFailValidation(): void {
new ServiceRegistration('test', BrokenSyncProvider::class)
]);
self::assertCount(1, $this->manager->getAvailableTaskTypes());
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToText::ID, ['wrongInputKey' => 'Hello'], 'test', null);
self::assertNull($task->getId());
Expand All @@ -680,6 +684,7 @@ public function testProviderShouldBeRegisteredAndTaskWithFilesFailValidation():
$this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]);

self::assertCount(1, $this->manager->getAvailableTaskTypes());
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
self::assertTrue($this->manager->hasProviders());

$audioId = $this->getFile('audioInput', 'Hello')->getId();
Expand All @@ -695,6 +700,7 @@ public function testProviderShouldBeRegisteredAndFail(): void {
new ServiceRegistration('test', FailingSyncProvider::class)
]);
self::assertCount(1, $this->manager->getAvailableTaskTypes());
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
self::assertNull($task->getId());
Expand Down Expand Up @@ -723,6 +729,7 @@ public function testProviderShouldBeRegisteredAndFailOutputValidation(): void {
new ServiceRegistration('test', BrokenSyncProvider::class)
]);
self::assertCount(1, $this->manager->getAvailableTaskTypes());
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woop woop, even tests! 😍

self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
self::assertNull($task->getId());
Expand Down Expand Up @@ -751,6 +758,7 @@ public function testProviderShouldBeRegisteredAndRun(): void {
new ServiceRegistration('test', SuccessfulSyncProvider::class)
]);
self::assertCount(1, $this->manager->getAvailableTaskTypes());
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
$taskTypeStruct = $this->manager->getAvailableTaskTypes()[array_keys($this->manager->getAvailableTaskTypes())[0]];
self::assertTrue(isset($taskTypeStruct['inputShape']['input']));
self::assertEquals(EShapeType::Text, $taskTypeStruct['inputShape']['input']->getShapeType());
Expand Down Expand Up @@ -803,6 +811,7 @@ public function testTaskTypeExplicitlyEnabled(): void {
$this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', json_encode($taskProcessingTypeSettings), lazy: true);

self::assertCount(1, $this->manager->getAvailableTaskTypes());
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());

self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
Expand Down Expand Up @@ -843,6 +852,7 @@ public function testAsyncProviderWithFilesShouldBeRegisteredAndRunReturningRawFi
$this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]);

self::assertCount(1, $this->manager->getAvailableTaskTypes());
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());

self::assertTrue($this->manager->hasProviders());
$audioId = $this->getFile('audioInput', 'Hello')->getId();
Expand Down Expand Up @@ -893,6 +903,7 @@ public function testAsyncProviderWithFilesShouldBeRegisteredAndRunReturningFileI
$mount->expects($this->any())->method('getUser')->willReturn($user);
$this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]);
self::assertCount(1, $this->manager->getAvailableTaskTypes());
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());

self::assertTrue($this->manager->hasProviders());
$audioId = $this->getFile('audioInput', 'Hello')->getId();
Expand Down Expand Up @@ -952,6 +963,7 @@ public function testOldTasksShouldBeCleanedUp(): void {
new ServiceRegistration('test', SuccessfulSyncProvider::class)
]);
self::assertCount(1, $this->manager->getAvailableTaskTypes());
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
$this->manager->scheduleTask($task);
Expand Down Expand Up @@ -992,6 +1004,7 @@ public function testShouldTransparentlyHandleTextProcessingProviders(): void {
]);
$taskTypes = $this->manager->getAvailableTaskTypes();
self::assertCount(1, $taskTypes);
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
self::assertTrue(isset($taskTypes[TextToTextSummary::ID]));
self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null);
Expand Down Expand Up @@ -1023,6 +1036,7 @@ public function testShouldTransparentlyHandleFailingTextProcessingProviders(): v
]);
$taskTypes = $this->manager->getAvailableTaskTypes();
self::assertCount(1, $taskTypes);
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
self::assertTrue(isset($taskTypes[TextToTextSummary::ID]));
self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null);
Expand Down Expand Up @@ -1053,6 +1067,7 @@ public function testShouldTransparentlyHandleText2ImageProviders(): void {
]);
$taskTypes = $this->manager->getAvailableTaskTypes();
self::assertCount(1, $taskTypes);
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
self::assertTrue(isset($taskTypes[TextToImage::ID]));
self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null);
Expand Down Expand Up @@ -1089,6 +1104,7 @@ public function testShouldTransparentlyHandleFailingText2ImageProviders(): void
]);
$taskTypes = $this->manager->getAvailableTaskTypes();
self::assertCount(1, $taskTypes);
self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
self::assertTrue(isset($taskTypes[TextToImage::ID]));
self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null);
Expand Down Expand Up @@ -1178,6 +1194,7 @@ public function testGetAvailableTaskTypesIncludesExternalViaEvent() {

// Assert
self::assertArrayHasKey(ExternalTaskType::ID, $availableTypes);
self::assertContains(ExternalTaskType::ID, $this->manager->getAvailableTaskTypeIds());
self::assertEquals(ExternalTaskType::ID, $externalProvider->getTaskTypeId(), 'Test Sanity: Provider must handle the Task Type');
self::assertEquals('External Task Type via Event', $availableTypes[ExternalTaskType::ID]['name']);
// Check if shapes match the external type/provider
Expand Down Expand Up @@ -1230,11 +1247,14 @@ public function testMergeTaskTypesLocalAndEvent() {

// Act
$availableTypes = $this->manager->getAvailableTaskTypes();
$availableTypeIds = $this->manager->getAvailableTaskTypeIds();

// Assert: Both task types should be available
self::assertContains(AudioToImage::ID, $availableTypeIds);
self::assertArrayHasKey(AudioToImage::ID, $availableTypes);
self::assertEquals(AudioToImage::class, $availableTypes[AudioToImage::ID]['name']);

self::assertContains(ExternalTaskType::ID, $availableTypeIds);
self::assertArrayHasKey(ExternalTaskType::ID, $availableTypes);
self::assertEquals('External Task Type via Event', $availableTypes[ExternalTaskType::ID]['name']);

Expand Down
Loading