Skip to content

Commit 0d554e6

Browse files
committed
Fix compilation error on Scala 2.11
1 parent fe5827e commit 0d554e6

File tree

1 file changed

+43
-39
lines changed

1 file changed

+43
-39
lines changed

DifferentiableKernel/src/main/scala/com/thoughtworks/deeplearning/DifferentiableKernel.scala

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,48 @@ object DifferentiableKernel {
6969
type OutputDelta >: OutputDelta0
7070
}
7171

72+
object Tapes {
73+
final class Fill[OutputElementData, OutputElementDelta](inputParameterMap: Map[Any, Tape],
74+
override val value: OpenCL.Buffer[OutputElementData])
75+
extends CumulativeTape {
76+
override def isTrainable: Boolean = ???
77+
78+
private var deltaAccumulator: Future.Stateful[Option[Delta]] = {
79+
new Constant(None)
80+
}
81+
82+
override def forceBackward(delta: Delta) = ???
83+
84+
override type Delta = OpenCL.Buffer[OutputElementDelta]
85+
override type Data = OpenCL.Buffer[OutputElementData]
86+
87+
private def rawBackward(deltaFuture: Future.Stateful[Option[Delta]]): Unit = {
88+
implicit def catcher = PartialFunction.empty
89+
for (deltaOption <- deltaFuture) {
90+
deltaOption match {
91+
case None =>
92+
case Some(delta) =>
93+
???
94+
}
95+
}
96+
}
97+
98+
override protected def flush(): Unit = {
99+
rawBackward(synchronized {
100+
val delta = deltaAccumulator
101+
deltaAccumulator = new Constant(None)
102+
delta
103+
})
104+
}
105+
106+
override protected def closeUpstreams(): Unit = {
107+
for (upstream <- inputParameterMap.values) {
108+
upstream.close()
109+
}
110+
}
111+
}
112+
}
113+
72114
trait InputMetadata {
73115

74116
type Data
@@ -213,45 +255,7 @@ object DifferentiableKernel {
213255
} finally {
214256
event.close()
215257
}
216-
new CumulativeTape {
217-
override def isTrainable: Boolean = ???
218-
219-
private var deltaAccumulator: Future.Stateful[Option[Delta]] = {
220-
new Constant(None)
221-
}
222-
223-
override val value = outputBuffer
224-
225-
override def forceBackward(delta: Delta) = ???
226-
227-
override type Delta = OutputDelta
228-
override type Data = OutputData
229-
230-
private def rawBackward(deltaFuture: Future.Stateful[Option[Delta]]): Unit = {
231-
implicit def catcher = PartialFunction.empty
232-
for (deltaOption <- deltaFuture) {
233-
deltaOption match {
234-
case None =>
235-
case Some(delta) =>
236-
???
237-
}
238-
}
239-
}
240-
241-
override protected def flush(): Unit = {
242-
rawBackward(synchronized {
243-
val delta = deltaAccumulator
244-
deltaAccumulator = new Constant(None)
245-
delta
246-
})
247-
}
248-
249-
override protected def closeUpstreams(): Unit = {
250-
for (upstream <- inputParameterMap.values) {
251-
upstream.close()
252-
}
253-
}
254-
}
258+
new Tapes.Fill(inputParameterMap, outputBuffer)
255259
}
256260
}
257261
}

0 commit comments

Comments
 (0)