Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tokio/src/loom/mocked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ pub(crate) mod sync {
pub(crate) fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
self.0.try_lock().ok()
}

#[inline]
pub(crate) fn get_mut(&mut self) -> &mut T {
self.0.get_mut().unwrap()
}
}
pub(crate) use loom::sync::*;

Expand Down
8 changes: 8 additions & 0 deletions tokio/src/loom/std/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,12 @@ impl<T> Mutex<T> {
Err(TryLockError::WouldBlock) => None,
}
}

#[inline]
pub(crate) fn get_mut(&mut self) -> &mut T {
match self.0.get_mut() {
Ok(val) => val,
Err(p_err) => p_err.into_inner(),
}
}
}
114 changes: 80 additions & 34 deletions tokio/src/runtime/time/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::util::WakeList;

use crate::loom::sync::atomic::AtomicU64;
use std::fmt;
use std::sync::RwLock;
use std::{num::NonZeroU64, ptr::NonNull};

struct AtomicOptionNonZeroU64(AtomicU64);
Expand Down Expand Up @@ -115,7 +116,10 @@ struct Inner {
next_wake: AtomicOptionNonZeroU64,

/// Sharded Timer wheels.
wheels: Box<[Mutex<wheel::Wheel>]>,
wheels: RwLock<ShardedWheel>,

/// Number of entries in the sharded timer wheels.
wheels_len: u32,

/// True if the driver is being shutdown.
pub(super) is_shutdown: AtomicBool,
Expand All @@ -130,6 +134,9 @@ struct Inner {
did_wake: AtomicBool,
}

/// Wrapper around the sharded timer wheels.
struct ShardedWheel(Box<[Mutex<wheel::Wheel>]>);

// ===== impl Driver =====

impl Driver {
Expand All @@ -149,7 +156,8 @@ impl Driver {
time_source,
inner: Inner {
next_wake: AtomicOptionNonZeroU64::new(None),
wheels: wheels.into_boxed_slice(),
wheels: RwLock::new(ShardedWheel(wheels.into_boxed_slice())),
wheels_len: shards,
is_shutdown: AtomicBool::new(false),
#[cfg(feature = "test-util")]
did_wake: AtomicBool::new(false),
Expand Down Expand Up @@ -190,23 +198,28 @@ impl Driver {
assert!(!handle.is_shutdown());

// Finds out the min expiration time to park.
let locks = (0..rt_handle.time().inner.get_shard_size())
.map(|id| rt_handle.time().inner.lock_sharded_wheel(id))
.collect::<Vec<_>>();

let expiration_time = locks
.iter()
.filter_map(|lock| lock.next_expiration_time())
.min();

rt_handle
.time()
.inner
.next_wake
.store(next_wake_time(expiration_time));

// Safety: After updating the `next_wake`, we drop all the locks.
drop(locks);
let expiration_time = {
let mut wheels_lock = rt_handle
.time()
.inner
.wheels
.write()
.expect("Timer wheel shards poisoned");
let expiration_time = (0..rt_handle.time().inner.get_shard_size())
.filter_map(|id| {
let wheel = wheels_lock.get_sharded_wheel_mut(id);
wheel.next_expiration_time()
})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could something like this make more sense?

wheels_lock.0.iter_mut()
    .filter_map(|wheel| wheel.get_mut().next_expiration_time())
    .min();

This way, we don't need to touch indexes at all.

Copy link
Contributor Author

@tglane tglane Aug 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this. At least when calculating the min expiration time we indeed don't need the indexes anymore and we can get rid of one function.
Implemented it in 1610c12

.min();

rt_handle
.time()
.inner
.next_wake
.store(next_wake_time(expiration_time));

expiration_time
};

match expiration_time {
Some(when) => {
Expand Down Expand Up @@ -312,7 +325,12 @@ impl Handle {
// Returns the next wakeup time of this shard.
pub(self) fn process_at_sharded_time(&self, id: u32, mut now: u64) -> Option<u64> {
let mut waker_list = WakeList::new();
let mut lock = self.inner.lock_sharded_wheel(id);
let mut wheels_lock = self
.inner
.wheels
.read()
.expect("Timer wheel shards poisoned");
let mut lock = wheels_lock.lock_sharded_wheel(id);

if now < lock.elapsed() {
// Time went backwards! This normally shouldn't happen as the Rust language
Expand All @@ -334,10 +352,16 @@ impl Handle {
if !waker_list.can_push() {
// Wake a batch of wakers. To avoid deadlock, we must do this with the lock temporarily dropped.
drop(lock);
drop(wheels_lock);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You also need to drop wheels_lock after the loop, since there's also a call to wake_all there.

Copy link
Contributor Author

@tglane tglane Aug 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! Missed that... Implemented in 1610c12.


waker_list.wake_all();

lock = self.inner.lock_sharded_wheel(id);
wheels_lock = self
.inner
.wheels
.read()
.expect("Timer wheel shards poisoned");
lock = wheels_lock.lock_sharded_wheel(id);
}
}
}
Expand All @@ -360,7 +384,12 @@ impl Handle {
/// `add_entry` must not be called concurrently.
pub(self) unsafe fn clear_entry(&self, entry: NonNull<TimerShared>) {
unsafe {
let mut lock = self.inner.lock_sharded_wheel(entry.as_ref().shard_id());
let wheels_lock = self
.inner
.wheels
.read()
.expect("Timer wheel shards poisoned");
let mut lock = wheels_lock.lock_sharded_wheel(entry.as_ref().shard_id());

if entry.as_ref().might_be_registered() {
lock.remove(entry);
Expand All @@ -383,7 +412,13 @@ impl Handle {
entry: NonNull<TimerShared>,
) {
let waker = unsafe {
let mut lock = self.inner.lock_sharded_wheel(entry.as_ref().shard_id());
let wheels_lock = self
.inner
.wheels
.read()
.expect("Timer wheel shards poisoned");

let mut lock = wheels_lock.lock_sharded_wheel(entry.as_ref().shard_id());

// We may have raced with a firing/deregistration, so check before
// deregistering.
Expand Down Expand Up @@ -443,24 +478,14 @@ impl Handle {
// ===== impl Inner =====

impl Inner {
/// Locks the driver's sharded wheel structure.
pub(super) fn lock_sharded_wheel(
&self,
shard_id: u32,
) -> crate::loom::sync::MutexGuard<'_, Wheel> {
let index = shard_id % (self.wheels.len() as u32);
// Safety: This modulo operation ensures that the index is not out of bounds.
unsafe { self.wheels.get_unchecked(index as usize).lock() }
}

// Check whether the driver has been shutdown
pub(super) fn is_shutdown(&self) -> bool {
self.is_shutdown.load(Ordering::SeqCst)
}

// Gets the number of shards.
fn get_shard_size(&self) -> u32 {
self.wheels.len() as u32
self.wheels_len
}
}

Expand All @@ -470,5 +495,26 @@ impl fmt::Debug for Inner {
}
}

// ===== impl ShardedWheel =====

impl ShardedWheel {
/// Locks the driver's sharded wheel structure.
pub(super) fn lock_sharded_wheel(
&self,
shard_id: u32,
) -> crate::loom::sync::MutexGuard<'_, Wheel> {
let index = shard_id % (self.0.len() as u32);
// Safety: This modulo operation ensures that the index is not out of bounds.
unsafe { self.0.get_unchecked(index as usize) }.lock()
}

/// Gets a mutable reference to the sharded wheel with the given id.
pub(super) fn get_sharded_wheel_mut(&mut self, shard_id: u32) -> &mut wheel::Wheel {
let index = shard_id % (self.0.len() as u32);
// Safety: This modulo operation ensures that the index is not out of bounds.
unsafe { self.0.get_unchecked_mut(index as usize) }.get_mut()
}
}

#[cfg(test)]
mod tests;