diff --git a/lib/private/TaskProcessing/Manager.php b/lib/private/TaskProcessing/Manager.php index e288f2981a817..df804c33c8e57 100644 --- a/lib/private/TaskProcessing/Manager.php +++ b/lib/private/TaskProcessing/Manager.php @@ -81,6 +81,9 @@ class Manager implements IManager { public const MAX_TASK_AGE_SECONDS = 60 * 60 * 24 * 30 * 4; // 4 months + private const TASK_TYPES_CACHE_KEY = 'available_task_types_v2'; + private const TASK_TYPE_IDS_CACHE_KEY = 'available_task_type_ids'; + /** @var list|null */ private ?array $providers = null; @@ -89,6 +92,9 @@ class Manager implements IManager { */ private ?array $availableTaskTypes = null; + /** @var list|null */ + private ?array $availableTaskTypeIds = null; + private IAppData $appData; private ?array $preferences = null; private ?array $providersById = null; @@ -834,7 +840,7 @@ public function getAvailableTaskTypes(bool $showDisabled = false, ?string $userI return []; } if ($this->availableTaskTypes === null) { - $cachedValue = $this->distributedCache->get('available_task_types_v2'); + $cachedValue = $this->distributedCache->get(self::TASK_TYPES_CACHE_KEY); if ($cachedValue !== null) { $this->availableTaskTypes = unserialize($cachedValue); } @@ -880,12 +886,53 @@ public function getAvailableTaskTypes(bool $showDisabled = false, ?string $userI } $this->availableTaskTypes = $availableTaskTypes; - $this->distributedCache->set('available_task_types_v2', serialize($this->availableTaskTypes), 60); + $this->distributedCache->set(self::TASK_TYPES_CACHE_KEY, serialize($this->availableTaskTypes), 60); } return $this->availableTaskTypes; } + public function getAvailableTaskTypeIds(bool $showDisabled = false, ?string $userId = null): array { + // userId will be obtained from the session if left to null + if (!$this->checkGuestAccess($userId)) { + return []; + } + if ($this->availableTaskTypeIds === null) { + $cachedValue = $this->distributedCache->get(self::TASK_TYPE_IDS_CACHE_KEY); + if ($cachedValue !== null) { + $this->availableTaskTypeIds = $cachedValue; + } + } + // Either we have no cache or showDisabled is turned on, which we don't want to cache, ever. + if ($this->availableTaskTypeIds === null || $showDisabled) { + $taskTypes = $this->_getTaskTypes(); + $taskTypeSettings = $this->_getTaskTypeSettings(); + + $availableTaskTypeIds = []; + foreach ($taskTypes as $taskType) { + if ((!$showDisabled) && isset($taskTypeSettings[$taskType->getId()]) && !$taskTypeSettings[$taskType->getId()]) { + continue; + } + try { + $provider = $this->getPreferredProvider($taskType->getId()); + } catch (\OCP\TaskProcessing\Exception\Exception $e) { + continue; + } + $availableTaskTypeIds[] = $taskType->getId(); + } + + if ($showDisabled) { + // Do not cache showDisabled, ever. + return $availableTaskTypeIds; + } + + $this->availableTaskTypeIds = $availableTaskTypeIds; + $this->distributedCache->set(self::TASK_TYPE_IDS_CACHE_KEY, $this->availableTaskTypeIds, 60); + } + + + return $this->availableTaskTypeIds; + } public function canHandleTask(Task $task): bool { return isset($this->getAvailableTaskTypes()[$task->getTaskTypeId()]); diff --git a/lib/public/TaskProcessing/IManager.php b/lib/public/TaskProcessing/IManager.php index 731250d7aa1c2..28beeafc88474 100644 --- a/lib/public/TaskProcessing/IManager.php +++ b/lib/public/TaskProcessing/IManager.php @@ -47,7 +47,7 @@ public function getProviders(): array; public function getPreferredProvider(string $taskTypeId); /** - * @param bool $showDisabled if false, disabled task types will be filtered + * @param bool $showDisabled if false, disabled task types will be filtered out * @param ?string $userId to check if the user is a guest. Will be obtained from session if left to default * @return array, optionalInputShape: ShapeDescriptor[], optionalInputShapeEnumValues: ShapeEnumValue[][], optionalInputShapeDefaults: array, outputShape: ShapeDescriptor[], outputShapeEnumValues: ShapeEnumValue[][], optionalOutputShape: ShapeDescriptor[], optionalOutputShapeEnumValues: ShapeEnumValue[][]}> * @since 30.0.0 @@ -56,6 +56,14 @@ public function getPreferredProvider(string $taskTypeId); */ public function getAvailableTaskTypes(bool $showDisabled = false, ?string $userId = null): array; + /** + * @param bool $showDisabled if false, disabled task types will be filtered out + * @param ?string $userId to check if the user is a guest. Will be obtained from session if left to default + * @return list + * @since 32.0.0 + */ + public function getAvailableTaskTypeIds(bool $showDisabled = false, ?string $userId = null): array; + /** * @param Task $task The task to run * @throws PreConditionNotMetException If no or not the requested provider was registered but this method was still called diff --git a/tests/lib/TaskProcessing/TaskProcessingTest.php b/tests/lib/TaskProcessing/TaskProcessingTest.php index d2f619da3495d..8b7dba22dea8a 100644 --- a/tests/lib/TaskProcessing/TaskProcessingTest.php +++ b/tests/lib/TaskProcessing/TaskProcessingTest.php @@ -632,6 +632,7 @@ private function getFile(string $name, string $content): File { public function testShouldNotHaveAnyProviders(): void { $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]); self::assertCount(0, $this->manager->getAvailableTaskTypes()); + self::assertCount(0, $this->manager->getAvailableTaskTypeIds()); self::assertFalse($this->manager->hasProviders()); self::expectException(PreConditionNotMetException::class); $this->manager->scheduleTask(new Task(TextToText::ID, ['input' => 'Hello'], 'test', null)); @@ -647,6 +648,8 @@ public function testProviderShouldBeRegisteredAndTaskTypeDisabled(): void { $this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', json_encode($taskProcessingTypeSettings), lazy: true); self::assertCount(0, $this->manager->getAvailableTaskTypes()); self::assertCount(1, $this->manager->getAvailableTaskTypes(true)); + self::assertCount(0, $this->manager->getAvailableTaskTypeIds()); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds(true)); self::assertTrue($this->manager->hasProviders()); self::expectException(PreConditionNotMetException::class); $this->manager->scheduleTask(new Task(TextToText::ID, ['input' => 'Hello'], 'test', null)); @@ -659,6 +662,7 @@ public function testProviderShouldBeRegisteredAndTaskFailValidation(): void { new ServiceRegistration('test', BrokenSyncProvider::class) ]); self::assertCount(1, $this->manager->getAvailableTaskTypes()); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); self::assertTrue($this->manager->hasProviders()); $task = new Task(TextToText::ID, ['wrongInputKey' => 'Hello'], 'test', null); self::assertNull($task->getId()); @@ -680,6 +684,7 @@ public function testProviderShouldBeRegisteredAndTaskWithFilesFailValidation(): $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]); self::assertCount(1, $this->manager->getAvailableTaskTypes()); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); self::assertTrue($this->manager->hasProviders()); $audioId = $this->getFile('audioInput', 'Hello')->getId(); @@ -695,6 +700,7 @@ public function testProviderShouldBeRegisteredAndFail(): void { new ServiceRegistration('test', FailingSyncProvider::class) ]); self::assertCount(1, $this->manager->getAvailableTaskTypes()); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); self::assertTrue($this->manager->hasProviders()); $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); self::assertNull($task->getId()); @@ -723,6 +729,7 @@ public function testProviderShouldBeRegisteredAndFailOutputValidation(): void { new ServiceRegistration('test', BrokenSyncProvider::class) ]); self::assertCount(1, $this->manager->getAvailableTaskTypes()); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); self::assertTrue($this->manager->hasProviders()); $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); self::assertNull($task->getId()); @@ -751,6 +758,7 @@ public function testProviderShouldBeRegisteredAndRun(): void { new ServiceRegistration('test', SuccessfulSyncProvider::class) ]); self::assertCount(1, $this->manager->getAvailableTaskTypes()); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); $taskTypeStruct = $this->manager->getAvailableTaskTypes()[array_keys($this->manager->getAvailableTaskTypes())[0]]; self::assertTrue(isset($taskTypeStruct['inputShape']['input'])); self::assertEquals(EShapeType::Text, $taskTypeStruct['inputShape']['input']->getShapeType()); @@ -803,6 +811,7 @@ public function testTaskTypeExplicitlyEnabled(): void { $this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', json_encode($taskProcessingTypeSettings), lazy: true); self::assertCount(1, $this->manager->getAvailableTaskTypes()); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); self::assertTrue($this->manager->hasProviders()); $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); @@ -843,6 +852,7 @@ public function testAsyncProviderWithFilesShouldBeRegisteredAndRunReturningRawFi $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]); self::assertCount(1, $this->manager->getAvailableTaskTypes()); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); self::assertTrue($this->manager->hasProviders()); $audioId = $this->getFile('audioInput', 'Hello')->getId(); @@ -893,6 +903,7 @@ public function testAsyncProviderWithFilesShouldBeRegisteredAndRunReturningFileI $mount->expects($this->any())->method('getUser')->willReturn($user); $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]); self::assertCount(1, $this->manager->getAvailableTaskTypes()); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); self::assertTrue($this->manager->hasProviders()); $audioId = $this->getFile('audioInput', 'Hello')->getId(); @@ -952,6 +963,7 @@ public function testOldTasksShouldBeCleanedUp(): void { new ServiceRegistration('test', SuccessfulSyncProvider::class) ]); self::assertCount(1, $this->manager->getAvailableTaskTypes()); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); self::assertTrue($this->manager->hasProviders()); $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); $this->manager->scheduleTask($task); @@ -992,6 +1004,7 @@ public function testShouldTransparentlyHandleTextProcessingProviders(): void { ]); $taskTypes = $this->manager->getAvailableTaskTypes(); self::assertCount(1, $taskTypes); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); self::assertTrue(isset($taskTypes[TextToTextSummary::ID])); self::assertTrue($this->manager->hasProviders()); $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null); @@ -1023,6 +1036,7 @@ public function testShouldTransparentlyHandleFailingTextProcessingProviders(): v ]); $taskTypes = $this->manager->getAvailableTaskTypes(); self::assertCount(1, $taskTypes); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); self::assertTrue(isset($taskTypes[TextToTextSummary::ID])); self::assertTrue($this->manager->hasProviders()); $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null); @@ -1053,6 +1067,7 @@ public function testShouldTransparentlyHandleText2ImageProviders(): void { ]); $taskTypes = $this->manager->getAvailableTaskTypes(); self::assertCount(1, $taskTypes); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); self::assertTrue(isset($taskTypes[TextToImage::ID])); self::assertTrue($this->manager->hasProviders()); $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null); @@ -1089,6 +1104,7 @@ public function testShouldTransparentlyHandleFailingText2ImageProviders(): void ]); $taskTypes = $this->manager->getAvailableTaskTypes(); self::assertCount(1, $taskTypes); + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); self::assertTrue(isset($taskTypes[TextToImage::ID])); self::assertTrue($this->manager->hasProviders()); $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null); @@ -1178,6 +1194,7 @@ public function testGetAvailableTaskTypesIncludesExternalViaEvent() { // Assert self::assertArrayHasKey(ExternalTaskType::ID, $availableTypes); + self::assertContains(ExternalTaskType::ID, $this->manager->getAvailableTaskTypeIds()); self::assertEquals(ExternalTaskType::ID, $externalProvider->getTaskTypeId(), 'Test Sanity: Provider must handle the Task Type'); self::assertEquals('External Task Type via Event', $availableTypes[ExternalTaskType::ID]['name']); // Check if shapes match the external type/provider @@ -1230,11 +1247,14 @@ public function testMergeTaskTypesLocalAndEvent() { // Act $availableTypes = $this->manager->getAvailableTaskTypes(); + $availableTypeIds = $this->manager->getAvailableTaskTypeIds(); // Assert: Both task types should be available + self::assertContains(AudioToImage::ID, $availableTypeIds); self::assertArrayHasKey(AudioToImage::ID, $availableTypes); self::assertEquals(AudioToImage::class, $availableTypes[AudioToImage::ID]['name']); + self::assertContains(ExternalTaskType::ID, $availableTypeIds); self::assertArrayHasKey(ExternalTaskType::ID, $availableTypes); self::assertEquals('External Task Type via Event', $availableTypes[ExternalTaskType::ID]['name']);