From 456749186c97ababe664cafb64c9cb8b633acebe Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Tue, 28 Jan 2025 19:35:13 +0100 Subject: [PATCH] Fix transform iterator for non-copy-constructible types (#3542) Fixes: #3541 Co-authored-by: Michael Schellenberger Costa --- thrust/testing/cuda/transform_iterator.cmake | 12 ++++++++++ thrust/testing/cuda/transform_iterator.cu | 19 ++++++++++++++++ thrust/testing/transform_iterator.cu | 1 + thrust/thrust/iterator/transform_iterator.h | 24 +++++++++++++++++++- 4 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 thrust/testing/cuda/transform_iterator.cmake create mode 100644 thrust/testing/cuda/transform_iterator.cu diff --git a/thrust/testing/cuda/transform_iterator.cmake b/thrust/testing/cuda/transform_iterator.cmake new file mode 100644 index 00000000000..e02e34f5870 --- /dev/null +++ b/thrust/testing/cuda/transform_iterator.cmake @@ -0,0 +1,12 @@ +target_compile_options(${test_target} PRIVATE $<$: --extended-lambda>) + +# this check is actually not correct, because we must check the host compiler, not the CXX compiler. +# We rely on that those are usually the same ;) +if ("Clang" STREQUAL "${CMAKE_CXX_COMPILER_ID}" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13) + # When clang >= 13 is used as host compiler, we get the following warning: + # nvcc_internal_extended_lambda_implementation:312:22: error: definition of implicit copy constructor for '__nv_hdl_wrapper_t, int (const int &)>' is deprecated because it has a user-declared copy assignment operator [-Werror,-Wdeprecated-copy] + # 312 | __nv_hdl_wrapper_t & operator=(const __nv_hdl_wrapper_t &in) = delete; + # | ^ + # Let's suppress it until NVBug 4980157 is resolved. + target_compile_options(${test_target} PRIVATE $<$: -Wno-deprecated-copy>) +endif () diff --git a/thrust/testing/cuda/transform_iterator.cu b/thrust/testing/cuda/transform_iterator.cu new file mode 100644 index 00000000000..2973b5d569f --- /dev/null +++ b/thrust/testing/cuda/transform_iterator.cu @@ -0,0 +1,19 @@ +#include +#include +#include +#include + +#include + +// see also: https://github.com/NVIDIA/cccl/issues/3541 +void TestTransformWithLambda() +{ + auto l = [] __host__ __device__(int v) { return v < 4; }; + thrust::host_vector A{1, 2, 3, 4, 5, 6, 7}; + ASSERT_EQUAL(thrust::any_of(A.begin(), A.end(), l), true); + + thrust::device_vector B{1, 2, 3, 4, 5, 6, 7}; + ASSERT_EQUAL(thrust::any_of(B.begin(), B.end(), l), true); +} + +DECLARE_UNITTEST(TestTransformWithLambda); diff --git a/thrust/testing/transform_iterator.cu b/thrust/testing/transform_iterator.cu index 1c8cbb5b44f..b95012286d5 100644 --- a/thrust/testing/transform_iterator.cu +++ b/thrust/testing/transform_iterator.cu @@ -2,6 +2,7 @@ #include #include #include +#include #include #include diff --git a/thrust/thrust/iterator/transform_iterator.h b/thrust/thrust/iterator/transform_iterator.h index 736678bce12..63df5356214 100644 --- a/thrust/thrust/iterator/transform_iterator.h +++ b/thrust/thrust/iterator/transform_iterator.h @@ -47,6 +47,9 @@ #include #include +#include +#include + THRUST_NAMESPACE_BEGIN /*! \addtogroup iterators @@ -238,7 +241,26 @@ class transform_iterator , m_f(other.functor()) {} - transform_iterator& operator=(const transform_iterator&) = default; + _CCCL_HOST_DEVICE transform_iterator& operator=(transform_iterator const& other) + { + super_t::operator=(other); + if constexpr (_CCCL_TRAIT(::cuda::std::is_copy_assignable, AdaptableUnaryFunction)) + { + m_f = other.m_f; + } + else if constexpr (_CCCL_TRAIT(::cuda::std::is_copy_constructible, AdaptableUnaryFunction)) + { + ::cuda::std::__destroy_at(&m_f); + ::cuda::std::__construct_at(&m_f, other.m_f); + } + else + { + static_assert(_CCCL_TRAIT(::cuda::std::is_copy_constructible, AdaptableUnaryFunction), + "Cannot use thrust::transform_iterator with a functor that is neither copy constructible nor " + "copy assignable"); + } + return *this; + } /*! This method returns a copy of this \p transform_iterator's \c AdaptableUnaryFunction. * \return A copy of this \p transform_iterator's \c AdaptableUnaryFunction.