99namespace Test \LanguageModel ;
1010
1111use OC \AppFramework \Bootstrap \Coordinator ;
12+ use OC \AppFramework \Bootstrap \RegistrationContext ;
13+ use OC \AppFramework \Bootstrap \ServiceRegistration ;
14+ use OC \EventDispatcher \EventDispatcher ;
15+ use OC \LanguageModel \Db \Task ;
1216use OC \LanguageModel \Db \TaskMapper ;
1317use OC \LanguageModel \LanguageModelManager ;
1418use OC \LanguageModel \TaskBackgroundJob ;
15- use OCP \BackgroundJob \IJobList ;
19+ use OCP \AppFramework \Db \DoesNotExistException ;
20+ use OCP \AppFramework \Utility \ITimeFactory ;
1621use OCP \Common \Exception \NotFoundException ;
1722use OCP \EventDispatcher \IEventDispatcher ;
1823use OCP \IServerContainer ;
@@ -82,16 +87,69 @@ class LanguageModelManagerTest extends \Test\TestCase {
8287 protected function setUp (): void {
8388 parent ::setUp ();
8489
90+ $ this ->providers = [
91+ TestVanillaLanguageModelProvider::class => new TestVanillaLanguageModelProvider (),
92+ TestFullLanguageModelProvider::class => new TestFullLanguageModelProvider (),
93+ TestFailingLanguageModelProvider::class => new TestFailingLanguageModelProvider (),
94+ ];
95+
96+ $ this ->serverContainer = $ this ->createMock (IServerContainer::class);
97+ $ this ->serverContainer ->expects ($ this ->any ())->method ('get ' )->willReturnCallback (function ($ class ) {
98+ return $ this ->providers [$ class ];
99+ });
100+
101+ $ this ->eventDispatcher = new EventDispatcher (
102+ new \Symfony \Component \EventDispatcher \EventDispatcher (),
103+ $ this ->serverContainer ,
104+ \OC ::$ server ->get (LoggerInterface::class),
105+ );
106+
107+ $ this ->registrationContext = $ this ->createMock (RegistrationContext::class);
108+ $ this ->coordinator = $ this ->createMock (Coordinator::class);
109+ $ this ->coordinator ->expects ($ this ->any ())->method ('getRegistrationContext ' )->willReturn ($ this ->registrationContext );
110+
111+ $ this ->taskMapper = $ this ->createMock (TaskMapper::class);
112+ $ this ->tasksDb = [];
113+ $ this ->taskMapper
114+ ->expects ($ this ->any ())
115+ ->method ('insert ' )
116+ ->willReturnCallback (function (Task $ task ) {
117+ $ task ->setId (count ($ this ->tasksDb ) ? max (array_keys ($ this ->tasksDb )) : 1 );
118+ $ this ->tasksDb [$ task ->getId ()] = $ task ->toRow ();
119+ return $ task ;
120+ });
121+ $ this ->taskMapper
122+ ->expects ($ this ->any ())
123+ ->method ('update ' )
124+ ->willReturnCallback (function (Task $ task ) {
125+ $ this ->tasksDb [$ task ->getId ()] = $ task ->toRow ();
126+ return $ task ;
127+ });
128+ $ this ->taskMapper
129+ ->expects ($ this ->any ())
130+ ->method ('find ' )
131+ ->willReturnCallback (function (int $ id ) {
132+ if (!isset ($ this ->tasksDb [$ id ])) {
133+ throw new DoesNotExistException ('Could not find it ' );
134+ }
135+ return Task::fromRow ($ this ->tasksDb [$ id ]);
136+ });
137+
138+ $ this ->jobList = $ this ->createPartialMock (DummyJobList::class, ['add ' ]);
139+ $ this ->jobList ->expects ($ this ->any ())->method ('add ' )->willReturnCallback (function () {
140+ });
141+
85142 $ this ->languageModelManager = new LanguageModelManager (
86- \ OC :: $ server -> get (IServerContainer::class) ,
87- $ this ->coordinator = \ OC :: $ server -> get (Coordinator::class) ,
143+ $ this -> serverContainer ,
144+ $ this ->coordinator ,
88145 \OC ::$ server ->get (LoggerInterface::class),
89- \ OC :: $ server -> get (IJobList::class) ,
90- \ OC :: $ server -> get (TaskMapper::class) ,
146+ $ this -> jobList ,
147+ $ this -> taskMapper ,
91148 );
92149 }
93150
94151 public function testShouldNotHaveAnyProviders () {
152+ $ this ->registrationContext ->expects ($ this ->any ())->method ('getLanguageModelProviders ' )->willReturn ([]);
95153 $ this ->assertCount (0 , $ this ->languageModelManager ->getAvailableTasks ());
96154 $ this ->assertCount (0 , $ this ->languageModelManager ->getAvailableTaskTypes ());
97155 $ this ->assertFalse ($ this ->languageModelManager ->hasProviders ());
@@ -100,7 +158,9 @@ public function testShouldNotHaveAnyProviders() {
100158 }
101159
102160 public function testProviderShouldBeRegisteredAndRun () {
103- $ this ->coordinator ->getRegistrationContext ()->registerLanguageModelProvider ('test ' , TestVanillaLanguageModelProvider::class);
161+ $ this ->registrationContext ->expects ($ this ->any ())->method ('getLanguageModelProviders ' )->willReturn ([
162+ new ServiceRegistration ('test ' , TestVanillaLanguageModelProvider::class)
163+ ]);
104164 $ this ->assertCount (1 , $ this ->languageModelManager ->getAvailableTasks ());
105165 $ this ->assertCount (1 , $ this ->languageModelManager ->getAvailableTaskTypes ());
106166 $ this ->assertTrue ($ this ->languageModelManager ->hasProviders ());
@@ -113,7 +173,9 @@ public function testProviderShouldBeRegisteredAndRun() {
113173
114174 public function testProviderShouldBeRegisteredAndScheduled () {
115175 // register provider
116- $ this ->coordinator ->getRegistrationContext ()->registerLanguageModelProvider ('test ' , TestVanillaLanguageModelProvider::class);
176+ $ this ->registrationContext ->expects ($ this ->any ())->method ('getLanguageModelProviders ' )->willReturn ([
177+ new ServiceRegistration ('test ' , TestVanillaLanguageModelProvider::class)
178+ ]);
117179 $ this ->assertCount (1 , $ this ->languageModelManager ->getAvailableTasks ());
118180 $ this ->assertCount (1 , $ this ->languageModelManager ->getAvailableTaskTypes ());
119181 $ this ->assertTrue ($ this ->languageModelManager ->hasProviders ());
@@ -139,18 +201,18 @@ public function testProviderShouldBeRegisteredAndScheduled() {
139201 $ this ->assertNull ($ task2 ->getOutput ());
140202 $ this ->assertEquals (ILanguageModelTask::STATUS_SCHEDULED , $ task2 ->getStatus ());
141203
142- /** @var IEventDispatcher $eventDispatcher */
143- $ eventDispatcher = \OC ::$ server ->get (IEventDispatcher::class);
204+ /** @var IEventDispatcher $this-> eventDispatcher */
205+ $ this -> eventDispatcher = \OC ::$ server ->get (IEventDispatcher::class);
144206 $ successfulEventFired = false ;
145- $ eventDispatcher ->addListener (TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $ event ) use (&$ successfulEventFired , $ task ) {
207+ $ this -> eventDispatcher ->addListener (TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $ event ) use (&$ successfulEventFired , $ task ) {
146208 $ successfulEventFired = true ;
147209 $ t = $ event ->getTask ();
148210 $ this ->assertEquals ($ task ->getId (), $ t ->getId ());
149211 $ this ->assertEquals (ILanguageModelTask::STATUS_SUCCESSFUL , $ t ->getStatus ());
150212 $ this ->assertEquals ('Hello Free Prompt ' , $ t ->getOutput ());
151213 });
152214 $ failedEventFired = false ;
153- $ eventDispatcher ->addListener (TaskFailedEvent::class, function (TaskFailedEvent $ event ) use (&$ failedEventFired , $ task ) {
215+ $ this -> eventDispatcher ->addListener (TaskFailedEvent::class, function (TaskFailedEvent $ event ) use (&$ failedEventFired , $ task ) {
154216 $ failedEventFired = true ;
155217 $ t = $ event ->getTask ();
156218 $ this ->assertEquals ($ task ->getId (), $ t ->getId ());
@@ -159,11 +221,14 @@ public function testProviderShouldBeRegisteredAndScheduled() {
159221 });
160222
161223 // run background job
162- /** @var TaskBackgroundJob $bgJob */
163- $ bgJob = \OC ::$ server ->get (TaskBackgroundJob::class);
224+ $ bgJob = new TaskBackgroundJob (
225+ \OC ::$ server ->get (ITimeFactory::class),
226+ $ this ->languageModelManager ,
227+ $ this ->eventDispatcher ,
228+ );
164229 $ bgJob ->setArgument (['taskId ' => $ task ->getId ()]);
165- $ bgJob ->start (new DummyJobList () );
166- $ provider = \ OC :: $ server -> get ( TestVanillaLanguageModelProvider::class) ;
230+ $ bgJob ->start ($ this -> jobList );
231+ $ provider = $ this -> providers [ TestVanillaLanguageModelProvider::class] ;
167232 $ this ->assertTrue ($ provider ->ran );
168233 $ this ->assertTrue ($ successfulEventFired );
169234 $ this ->assertFalse ($ failedEventFired );
@@ -173,12 +238,14 @@ public function testProviderShouldBeRegisteredAndScheduled() {
173238 $ this ->assertEquals ($ task ->getId (), $ task3 ->getId ());
174239 $ this ->assertEquals ('Hello ' , $ task3 ->getInput ());
175240 $ this ->assertEquals ('Hello Free Prompt ' , $ task3 ->getOutput ());
176- $ this ->assertEquals (ILanguageModelTask::STATUS_SUCCESSFUL , $ task2 ->getStatus ());
241+ $ this ->assertEquals (ILanguageModelTask::STATUS_SUCCESSFUL , $ task3 ->getStatus ());
177242 }
178243
179244 public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly () {
180- $ this ->coordinator ->getRegistrationContext ()->registerLanguageModelProvider ('test ' , TestVanillaLanguageModelProvider::class);
181- $ this ->coordinator ->getRegistrationContext ()->registerLanguageModelProvider ('test ' , TestFullLanguageModelProvider::class);
245+ $ this ->registrationContext ->expects ($ this ->any ())->method ('getLanguageModelProviders ' )->willReturn ([
246+ new ServiceRegistration ('test ' , TestVanillaLanguageModelProvider::class),
247+ new ServiceRegistration ('test ' , TestFullLanguageModelProvider::class),
248+ ]);
182249 $ this ->assertCount (3 , $ this ->languageModelManager ->getAvailableTasks ());
183250 $ this ->assertCount (3 , $ this ->languageModelManager ->getAvailableTaskTypes ());
184251 $ this ->assertTrue ($ this ->languageModelManager ->hasProviders ());
@@ -204,7 +271,9 @@ public function testNonexistentTask() {
204271
205272 public function testTaskFailure () {
206273 // register provider
207- $ this ->coordinator ->getRegistrationContext ()->registerLanguageModelProvider ('test ' , TestFailingLanguageModelProvider::class);
274+ $ this ->registrationContext ->expects ($ this ->any ())->method ('getLanguageModelProviders ' )->willReturn ([
275+ new ServiceRegistration ('test ' , TestFailingLanguageModelProvider::class),
276+ ]);
208277 $ this ->assertCount (1 , $ this ->languageModelManager ->getAvailableTasks ());
209278 $ this ->assertCount (1 , $ this ->languageModelManager ->getAvailableTaskTypes ());
210279 $ this ->assertTrue ($ this ->languageModelManager ->hasProviders ());
@@ -230,18 +299,16 @@ public function testTaskFailure() {
230299 $ this ->assertNull ($ task2 ->getOutput ());
231300 $ this ->assertEquals (ILanguageModelTask::STATUS_SCHEDULED , $ task2 ->getStatus ());
232301
233- /** @var IEventDispatcher $eventDispatcher */
234- $ eventDispatcher = \OC ::$ server ->get (IEventDispatcher::class);
235302 $ successfulEventFired = false ;
236- $ eventDispatcher ->addListener (TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $ event ) use (&$ successfulEventFired , $ task ) {
303+ $ this -> eventDispatcher ->addListener (TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $ event ) use (&$ successfulEventFired , $ task ) {
237304 $ successfulEventFired = true ;
238305 $ t = $ event ->getTask ();
239306 $ this ->assertEquals ($ task ->getId (), $ t ->getId ());
240307 $ this ->assertEquals (ILanguageModelTask::STATUS_SUCCESSFUL , $ t ->getStatus ());
241308 $ this ->assertEquals ('Hello Free Prompt ' , $ t ->getOutput ());
242309 });
243310 $ failedEventFired = false ;
244- $ eventDispatcher ->addListener (TaskFailedEvent::class, function (TaskFailedEvent $ event ) use (&$ failedEventFired , $ task ) {
311+ $ this -> eventDispatcher ->addListener (TaskFailedEvent::class, function (TaskFailedEvent $ event ) use (&$ failedEventFired , $ task ) {
245312 $ failedEventFired = true ;
246313 $ t = $ event ->getTask ();
247314 $ this ->assertEquals ($ task ->getId (), $ t ->getId ());
@@ -250,11 +317,14 @@ public function testTaskFailure() {
250317 });
251318
252319 // run background job
253- /** @var TaskBackgroundJob $bgJob */
254- $ bgJob = \OC ::$ server ->get (TaskBackgroundJob::class);
320+ $ bgJob = new TaskBackgroundJob (
321+ \OC ::$ server ->get (ITimeFactory::class),
322+ $ this ->languageModelManager ,
323+ $ this ->eventDispatcher ,
324+ );
255325 $ bgJob ->setArgument (['taskId ' => $ task ->getId ()]);
256- $ bgJob ->start (new DummyJobList () );
257- $ provider = \ OC :: $ server -> get ( TestFailingLanguageModelProvider::class) ;
326+ $ bgJob ->start ($ this -> jobList );
327+ $ provider = $ this -> providers [ TestFailingLanguageModelProvider::class] ;
258328 $ this ->assertTrue ($ provider ->ran );
259329 $ this ->assertTrue ($ failedEventFired );
260330 $ this ->assertFalse ($ successfulEventFired );
@@ -264,6 +334,6 @@ public function testTaskFailure() {
264334 $ this ->assertEquals ($ task ->getId (), $ task3 ->getId ());
265335 $ this ->assertEquals ('Hello ' , $ task3 ->getInput ());
266336 $ this ->assertNull ($ task3 ->getOutput ());
267- $ this ->assertEquals (ILanguageModelTask::STATUS_FAILED , $ task2 ->getStatus ());
337+ $ this ->assertEquals (ILanguageModelTask::STATUS_FAILED , $ task3 ->getStatus ());
268338 }
269339}
0 commit comments