|
1 | 1 | #pragma once |
2 | 2 |
|
3 | 3 | #include <thrust/iterator/iterator_traits.h> |
| 4 | +#include <thrust/iterator/iterator_categories.h> |
| 5 | +#include <type_traits> |
4 | 6 |
|
5 | 7 |
|
6 | 8 | // Wraps an existing iterator into a forward iterator, |
7 | 9 | // thus removing some of its functionality |
8 | 10 | template <typename Iterator> |
9 | 11 | struct forward_iterator_wrapper { |
10 | 12 | // LegacyIterator requirements |
| 13 | + using iterator_system_tag = typename thrust::iterator_system<Iterator>::type; |
11 | 14 | using reference = typename thrust::iterator_traits<Iterator>::reference; |
12 | 15 | using pointer = typename thrust::iterator_traits<Iterator>::pointer; |
13 | 16 | using value_type = typename thrust::iterator_traits<Iterator>::value_type; |
14 | 17 | using difference_type = typename thrust::iterator_traits<Iterator>::difference_type; |
15 | | - using iterator_category = std::forward_iterator_tag; |
| 18 | + using iterator_category = typename std::conditional< |
| 19 | + std::is_convertible<iterator_system_tag, thrust::device_system_tag>::value, |
| 20 | + thrust::forward_device_iterator_tag, |
| 21 | + typename std::conditional< |
| 22 | + std::is_convertible<iterator_system_tag, thrust::host_system_tag>::value, |
| 23 | + thrust::forward_host_iterator_tag, |
| 24 | + std::forward_iterator_tag>::type>::type; |
16 | 25 | using base_iterator_category = typename thrust::iterator_traits<Iterator>::iterator_category; |
17 | 26 | static_assert( |
18 | | - std::is_convertible<base_iterator_category, std::forward_iterator_tag>::value, |
| 27 | + std::is_convertible<base_iterator_category, std::forward_iterator_tag>::value, |
19 | 28 | "Cannot create forward_iterator_wrapper around an iterator that is not itself at least a forward iterator"); |
20 | 29 |
|
21 | 30 | __host__ __device__ reference operator*() const { |
|
0 commit comments