@@ -69,6 +69,48 @@ object DifferentiableKernel {
6969 type OutputDelta >: OutputDelta0
7070 }
7171
72+ object Tapes {
73+ final class Fill [OutputElementData , OutputElementDelta ](inputParameterMap : Map [Any , Tape ],
74+ override val value : OpenCL .Buffer [OutputElementData ])
75+ extends CumulativeTape {
76+ override def isTrainable : Boolean = ???
77+
78+ private var deltaAccumulator : Future .Stateful [Option [Delta ]] = {
79+ new Constant (None )
80+ }
81+
82+ override def forceBackward (delta : Delta ) = ???
83+
84+ override type Delta = OpenCL .Buffer [OutputElementDelta ]
85+ override type Data = OpenCL .Buffer [OutputElementData ]
86+
87+ private def rawBackward (deltaFuture : Future .Stateful [Option [Delta ]]): Unit = {
88+ implicit def catcher = PartialFunction .empty
89+ for (deltaOption <- deltaFuture) {
90+ deltaOption match {
91+ case None =>
92+ case Some (delta) =>
93+ ???
94+ }
95+ }
96+ }
97+
98+ override protected def flush (): Unit = {
99+ rawBackward(synchronized {
100+ val delta = deltaAccumulator
101+ deltaAccumulator = new Constant (None )
102+ delta
103+ })
104+ }
105+
106+ override protected def closeUpstreams (): Unit = {
107+ for (upstream <- inputParameterMap.values) {
108+ upstream.close()
109+ }
110+ }
111+ }
112+ }
113+
72114 trait InputMetadata {
73115
74116 type Data
@@ -213,45 +255,7 @@ object DifferentiableKernel {
213255 } finally {
214256 event.close()
215257 }
216- new CumulativeTape {
217- override def isTrainable : Boolean = ???
218-
219- private var deltaAccumulator : Future .Stateful [Option [Delta ]] = {
220- new Constant (None )
221- }
222-
223- override val value = outputBuffer
224-
225- override def forceBackward (delta : Delta ) = ???
226-
227- override type Delta = OutputDelta
228- override type Data = OutputData
229-
230- private def rawBackward (deltaFuture : Future .Stateful [Option [Delta ]]): Unit = {
231- implicit def catcher = PartialFunction .empty
232- for (deltaOption <- deltaFuture) {
233- deltaOption match {
234- case None =>
235- case Some (delta) =>
236- ???
237- }
238- }
239- }
240-
241- override protected def flush (): Unit = {
242- rawBackward(synchronized {
243- val delta = deltaAccumulator
244- deltaAccumulator = new Constant (None )
245- delta
246- })
247- }
248-
249- override protected def closeUpstreams (): Unit = {
250- for (upstream <- inputParameterMap.values) {
251- upstream.close()
252- }
253- }
254- }
258+ new Tapes .Fill (inputParameterMap, outputBuffer)
255259 }
256260 }
257261 }
0 commit comments