Skip to content

Commit def6d07

Browse files
authored
Make Timer safe on multiple clear (#30)
- Change all methods to receive FiniteDuration
1 parent 386823a commit def6d07

File tree

3 files changed

+49
-20
lines changed

3 files changed

+49
-20
lines changed

core/src/main/scala/scala/scalanative/loop/Timer.scala

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import LibUV._, LibUVConstants._
77
import scala.scalanative.unsafe.Ptr
88
import internals.HandleUtils
99

10-
@inline class Timer private (private val ptr: Ptr[Byte]) extends AnyVal {
10+
@inline final class Timer private (private val ptr: Ptr[Byte]) extends AnyVal {
1111
def clear(): Unit = {
1212
uv_timer_stop(ptr)
1313
HandleUtils.close(ptr)
@@ -37,21 +37,30 @@ object Timer {
3737
val timerHandle = stdlib.malloc(uv_handle_size(UV_TIMER_T))
3838
uv_timer_init(EventLoop.loop, timerHandle)
3939
HandleUtils.setData(timerHandle, callback)
40+
val timer = new Timer(timerHandle)
41+
val withClearIfTimeout: () => Unit =
42+
if (repeat == 0L) { () =>
43+
{
44+
callback()
45+
timer.clear()
46+
}
47+
} else callback
4048
uv_timer_start(timerHandle, timeoutCB, timeout, repeat)
41-
new Timer(timerHandle)
49+
timer
4250
}
4351

4452
def delay(duration: FiniteDuration): Future[Unit] = {
4553
val promise = Promise[Unit]()
46-
timeout(duration.toMillis)(() => promise.success(()))
54+
timeout(duration)(() => promise.success(()))
4755
promise.future
4856
}
4957

50-
def timeout(millis: Long)(callback: () => Unit): Timer = {
51-
startTimer(millis, 0L, callback)
58+
def timeout(duration: FiniteDuration)(callback: () => Unit): Timer = {
59+
startTimer(duration.toMillis, 0L, callback)
5260
}
5361

54-
def repeat(millis: Long)(callback: () => Unit): Timer = {
62+
def repeat(duration: FiniteDuration)(callback: () => Unit): Timer = {
63+
val millis = duration.toMillis
5564
startTimer(millis, millis, callback)
5665
}
5766
}

core/src/main/scala/scala/scalanative/loop/internals/HandleUtils.scala

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,38 @@ private[loop] object HandleUtils {
1414
@inline def getData[T <: Object](handle: Ptr[Byte]): T = {
1515
// data is the first member of uv_loop_t
1616
val ptrOfPtr = handle.asInstanceOf[Ptr[Ptr[Byte]]]
17-
val rawptr = toRawPtr(!ptrOfPtr)
18-
castRawPtrToObject(rawptr).asInstanceOf[T]
17+
val dataPtr = !ptrOfPtr
18+
if (dataPtr == null) null.asInstanceOf[T]
19+
else {
20+
val rawptr = toRawPtr(dataPtr)
21+
castRawPtrToObject(rawptr).asInstanceOf[T]
22+
}
1923
}
2024
@inline def setData(handle: Ptr[Byte], obj: Object): Unit = {
21-
if (references.contains(obj)) references(obj) += 1
22-
else references(obj) = 1
23-
2425
// data is the first member of uv_loop_t
2526
val ptrOfPtr = handle.asInstanceOf[Ptr[Ptr[Byte]]]
26-
val rawptr = castObjectToRawPtr(obj)
27-
!ptrOfPtr = fromRawPtr[Byte](rawptr)
27+
if(obj != null) {
28+
if (references.contains(obj)) references(obj) += 1
29+
else references(obj) = 1
30+
val rawptr = castObjectToRawPtr(obj)
31+
!ptrOfPtr = fromRawPtr[Byte](rawptr)
32+
} else {
33+
!ptrOfPtr = null
34+
}
2835
}
2936
private val onCloseCB = new CloseCB {
3037
def apply(handle: UVHandle): Unit = {
3138
stdlib.free(handle)
3239
}
3340
}
3441
@inline def close(handle: Ptr[Byte]): Unit = {
35-
uv_close(handle, onCloseCB)
36-
val data = getData[Object](handle)
37-
val current = references(data)
38-
if (current > 1) references(data) -= 1
39-
else references.remove(data)
42+
if(getData(handle) != null) {
43+
uv_close(handle, onCloseCB)
44+
val data = getData[Object](handle)
45+
val current = references(data)
46+
if (current > 1) references(data) -= 1
47+
else references.remove(data)
48+
setData(handle, null)
49+
}
4050
}
4151
}

core/src/test/scala/scala/scalanative/loop/TimerTests.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ object TimerTests extends LoopTestSuite {
3030
val times = 3
3131
val p = Promise[Unit]()
3232
var timer: Timer = null.asInstanceOf[Timer]
33-
timer = Timer.repeat(d.toMillis) { () =>
33+
timer = Timer.repeat(d) { () =>
3434
if (i == times) {
3535
p.success(())
3636
timer.clear()
@@ -43,13 +43,23 @@ object TimerTests extends LoopTestSuite {
4343
}
4444
}
4545
test("clear timeout") {
46-
val handle = Timer.timeout(d.toMillis) { () =>
46+
val handle = Timer.timeout(d) { () =>
4747
throw new Exception("This timeout should have not triggered")
4848
}
4949
handle.clear()
5050
for {
5151
() <- Timer.delay(d * 2)
5252
} yield ()
5353
}
54+
test("close multiple times") {
55+
val timer = Timer.timeout(10.millis)(() => {})
56+
timer.clear()
57+
timer.clear()
58+
global.execute(new Runnable { def run(): Unit = timer.clear() })
59+
Timer.timeout(50.millis) { () =>
60+
timer.clear()
61+
timer.clear()
62+
}
63+
}
5464
}
5565
}

0 commit comments

Comments
 (0)