33#include < thrust/sort.h>
44#include < thrust/memory.h>
55#include < thrust/pair.h>
6+ #include < thrust/fill.h>
7+ #include < thrust/logical.h>
8+
9+
10+ struct my_system : thrust::device_system<my_system> {};
611
7- struct my_tag : thrust::device_system<my_tag> {};
812
913template <typename T1, typename T2>
1014bool is_same (const T1 &, const T2 &)
1115{
1216 return false ;
1317}
1418
19+
1520template <typename T>
1621bool is_same (const T &, const T &)
1722{
1823 return true ;
1924}
2025
26+
2127void TestSelectSystemDifferentTypes ()
2228{
2329 using thrust::system::detail::generic::select_system;
2430
25- // select_system(my_tag , device_system_tag) should return device_system_tag (the minimum tag)
26- bool is_device_system_tag = is_same (thrust::device_system_tag (), select_system (my_tag (), thrust::device_system_tag ()));
31+ // select_system(my_system , device_system_tag) should return device_system_tag (the minimum tag)
32+ bool is_device_system_tag = is_same (thrust::device_system_tag (), select_system (my_system (), thrust::device_system_tag ()));
2733 ASSERT_EQUAL (true , is_device_system_tag);
2834
2935 // select_system(device_system_tag, my_tag) should return device_system_tag (the minimum tag)
30- is_device_system_tag = is_same (thrust::device_system_tag (), select_system (thrust::device_system_tag (), my_tag ()));
36+ is_device_system_tag = is_same (thrust::device_system_tag (), select_system (thrust::device_system_tag (), my_system ()));
3137 ASSERT_EQUAL (true , is_device_system_tag);
3238}
3339DECLARE_UNITTEST (TestSelectSystemDifferentTypes);
@@ -45,42 +51,91 @@ void TestSelectSystemSameTypes()
4551 bool is_device_system_tag = is_same (thrust::device_system_tag (), select_system (thrust::device_system_tag (), thrust::device_system_tag ()));
4652 ASSERT_EQUAL (true , is_device_system_tag);
4753
48- // select_system(my_tag, my_tag ) should return my_tag
49- bool is_my_tag = is_same (my_tag (), select_system (my_tag (), my_tag ()));
50- ASSERT_EQUAL (true , is_my_tag );
54+ // select_system(my_system, my_system ) should return my_system
55+ bool is_my_system = is_same (my_system (), select_system (my_system (), my_system ()));
56+ ASSERT_EQUAL (true , is_my_system );
5157}
5258DECLARE_UNITTEST (TestSelectSystemSameTypes);
5359
5460
55- // template<typename T>
56- // thrust::pair<thrust::pointer<T,my_tag>, std::ptrdiff_t>
57- // get_temporary_buffer(my_tag, std::ptrdiff_t n)
58- // {
59- // // communicate that my version of get_temporary_buffer
60- // // was correctly dispatched
61- // throw my_tag();
62- // }
63- //
61+ void TestGetTemporaryBuffer ()
62+ {
63+ const size_t n = 9001 ;
64+
65+ thrust::device_system_tag dev_tag;
66+ typedef thrust::pointer<int , thrust::device_system_tag> pointer;
67+ thrust::pair<pointer, std::ptrdiff_t > ptr_and_sz = thrust::get_temporary_buffer<int >(dev_tag, n);
68+
69+ ASSERT_EQUAL (ptr_and_sz.second , n);
70+
71+ const int ref_val = 13 ;
72+ thrust::device_vector<int > ref (n, ref_val);
73+
74+ thrust::fill_n (ptr_and_sz.first , n, ref_val);
75+
76+ ASSERT_EQUAL (true , thrust::all_of (ptr_and_sz.first , ptr_and_sz.first + n, thrust::placeholders::_1 == ref_val));
77+
78+ thrust::return_temporary_buffer (dev_tag, ptr_and_sz.first );
79+ }
80+ DECLARE_UNITTEST (TestGetTemporaryBuffer);
81+
82+
83+ void TestMalloc ()
84+ {
85+ const size_t n = 9001 ;
86+
87+ thrust::device_system_tag dev_tag;
88+ typedef thrust::pointer<int , thrust::device_system_tag> pointer;
89+ pointer ptr = pointer (static_cast <int *>(thrust::malloc (dev_tag, sizeof (int ) * n).get ()));
90+
91+ const int ref_val = 13 ;
92+ thrust::device_vector<int > ref (n, ref_val);
93+
94+ thrust::fill_n (ptr, n, ref_val);
95+
96+ ASSERT_EQUAL (true , thrust::all_of (ptr, ptr + n, thrust::placeholders::_1 == ref_val));
97+
98+ thrust::free (dev_tag, ptr);
99+ }
100+ DECLARE_UNITTEST (TestMalloc);
101+
102+
103+ static bool g_correctly_dispatched;
104+
105+
106+ template <typename T>
107+ thrust::pair<thrust::pointer<T,my_system>, std::ptrdiff_t >
108+ get_temporary_buffer (my_system sys, std::ptrdiff_t n)
109+ {
110+ // communicate that my version of get_temporary_buffer
111+ // was correctly dispatched
112+ g_correctly_dispatched = true ;
113+
114+ thrust::device_system_tag device_sys;
115+ thrust::pair<thrust::pointer<T, thrust::device_system_tag>, std::ptrdiff_t > result = thrust::get_temporary_buffer<T>(device_sys, n);
116+ return thrust::make_pair (thrust::pointer<T,my_system>(result.first .get ()), result.second );
117+ }
118+
119+
64120void TestGetTemporaryBufferDispatchImplicit ()
65121{
66- KNOWN_FAILURE;
67-
68- // bool correctly_dispatched = false;
69- //
70- // try
71- // {
72- // thrust::device_vector<int> vec(2);
73- //
74- // // call something we know will invoke get_temporary_buffer
75- // thrust::sort(thrust::retag<my_tag>(vec.begin()),
76- // thrust::retag<my_tag>(vec.end()));
77- // }
78- // catch(my_tag)
79- // {
80- // correctly_dispatched = true;
81- // }
82- //
83- // ASSERT_EQUAL(true, correctly_dispatched);
122+ if (is_same (thrust::device_system_tag (), thrust::system::cpp::tag ()))
123+ {
124+ // XXX cpp uses the internal scalar backend, which currently elides user tags
125+ KNOWN_FAILURE;
126+ }
127+ else
128+ {
129+ g_correctly_dispatched = false ;
130+
131+ thrust::device_vector<int > vec (9001 );
132+
133+ // call something we know will invoke get_temporary_buffer
134+ my_system sys;
135+ thrust::sort (sys, vec.begin (), vec.end ());
136+
137+ ASSERT_EQUAL (true , g_correctly_dispatched);
138+ }
84139}
85140DECLARE_UNITTEST (TestGetTemporaryBufferDispatchImplicit);
86141
0 commit comments