diff --git a/dpctl_ext/tensor/CMakeLists.txt b/dpctl_ext/tensor/CMakeLists.txt index 056b7c42554..cf55035c23d 100644 --- a/dpctl_ext/tensor/CMakeLists.txt +++ b/dpctl_ext/tensor/CMakeLists.txt @@ -69,6 +69,19 @@ set(_accumulator_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp ) +set(_reduction_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduction_common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/all.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/any.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmax.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmin.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/logsumexp.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/max.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/min.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/prod.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduce_hypot.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp +) set(_sorting_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp @@ -82,6 +95,10 @@ set(_tensor_accumulation_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp ${_accumulator_sources} ) +set(_tensor_reductions_impl_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_reductions.cpp + ${_reduction_sources} +) set(_tensor_sorting_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp ${_sorting_sources} @@ -114,6 +131,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_accumulation_i target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) list(APPEND _py_trgts ${python_module_name}) +set(python_module_name 
_tensor_reductions_impl) +pybind11_add_module(${python_module_name} MODULE ${_tensor_reductions_impl_sources}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_reductions_impl_sources}) +target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) +list(APPEND _py_trgts ${python_module_name}) + set(python_module_name _tensor_sorting_impl) pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_impl_sources}) add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_sources}) @@ -135,7 +158,7 @@ set(_no_fast_math_sources list( APPEND _no_fast_math_sources # ${_elementwise_sources} - # ${_reduction_sources} + ${_reduction_sources} ${_sorting_sources} # ${_linalg_sources} ${_accumulator_sources} diff --git a/dpctl_ext/tensor/__init__.py b/dpctl_ext/tensor/__init__.py index cba7c417d55..ac24151bedf 100644 --- a/dpctl_ext/tensor/__init__.py +++ b/dpctl_ext/tensor/__init__.py @@ -78,6 +78,17 @@ tile, unstack, ) +from ._reduction import ( + argmax, + argmin, + count_nonzero, + logsumexp, + max, + min, + prod, + reduce_hypot, + sum, +) from ._reshape import reshape from ._search_functions import where from ._searchsorted import searchsorted @@ -90,9 +101,14 @@ ) from ._sorting import argsort, sort, top_k from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type +from ._utility_functions import all, any, diff __all__ = [ + "all", + "any", "arange", + "argmax", + "argmin", "argsort", "asarray", "asnumpy", @@ -102,10 +118,12 @@ "can_cast", "concat", "copy", + "count_nonzero", "clip", "cumulative_logsumexp", "cumulative_prod", "cumulative_sum", + "diff", "empty", "empty_like", "extract", @@ -120,15 +138,20 @@ "isdtype", "isin", "linspace", + "logsumexp", + "max", "meshgrid", + "min", "moveaxis", "permute_dims", "nonzero", "ones", "ones_like", "place", + "prod", "put", "put_along_axis", + "reduce_hypot", "repeat", "reshape", "result_type", @@ -137,6 +160,7 @@ "sort", "squeeze", "stack", + "sum", 
"swapaxes", "take", "take_along_axis", diff --git a/dpctl_ext/tensor/_manipulation_functions.py b/dpctl_ext/tensor/_manipulation_functions.py index 08459dcaea7..e2d55c533bc 100644 --- a/dpctl_ext/tensor/_manipulation_functions.py +++ b/dpctl_ext/tensor/_manipulation_functions.py @@ -624,7 +624,7 @@ def repeat(x, repeats, /, *, axis=None): "'repeats' array must be broadcastable to the size of " "the repeated axis" ) - if not dpt.all(repeats >= 0): + if not dpt_ext.all(repeats >= 0): raise ValueError("'repeats' elements must be positive") elif isinstance(repeats, (tuple, list, range)): @@ -646,7 +646,7 @@ def repeat(x, repeats, /, *, axis=None): repeats = dpt_ext.asarray( repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q ) - if not dpt.all(repeats >= 0): + if not dpt_ext.all(repeats >= 0): raise ValueError("`repeats` elements must be positive") else: raise TypeError( diff --git a/dpctl_ext/tensor/_reduction.py b/dpctl_ext/tensor/_reduction.py new file mode 100644 index 00000000000..b8fdcf4f37e --- /dev/null +++ b/dpctl_ext/tensor/_reduction.py @@ -0,0 +1,834 @@ +# ***************************************************************************** +# Copyright (c) 2026, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +import dpctl +import dpctl.tensor as dpt +from dpctl.utils import ExecutionPlacementError, SequentialOrderManager + +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor._tensor_impl as ti +import dpctl_ext.tensor._tensor_reductions_impl as tri + +from ._numpy_helper import normalize_axis_tuple +from ._type_utils import ( + _default_accumulation_dtype, + _default_accumulation_dtype_fp_types, + _to_device_supported_dtype, +) + + +def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn): + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + + nd = x.ndim + if axis is None: + axis = tuple(range(nd)) + perm = list(axis) + x_tmp = x + else: + if not isinstance(axis, (tuple, list)): + axis = (axis,) + axis = normalize_axis_tuple(axis, nd, "axis") + perm = [i for i in range(nd) if i not in axis] + list(axis) + x_tmp = dpt_ext.permute_dims(x, perm) + red_nd = len(axis) + if any([x_tmp.shape[i] == 0 for i in range(-red_nd, 0)]): + raise ValueError("reduction 
cannot be performed over zero-size axes") + res_shape = x_tmp.shape[: nd - red_nd] + exec_q = x.sycl_queue + res_dt = x.dtype + res_usm_type = x.usm_type + + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + if not keepdims: + final_res_shape = res_shape + else: + inp_shape = x.shape + final_res_shape = tuple( + inp_shape[i] if i not in axis else 1 for i in range(nd) + ) + if not out.shape == final_res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. " + f"Expected output shape is {final_res_shape}, got {out.shape}" + ) + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed, got {out.dtype}" + ) + if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + if keepdims: + out = dpt_ext.squeeze(out, axis=axis) + orig_out = out + if ti._array_overlap(x, out): + out = dpt_ext.empty_like(out) + else: + out = dpt_ext.empty( + res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q + ) + + _manager = SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + if red_nd == 0: + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=x_tmp, dst=out, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_e_cpy, cpy_e) + if not (orig_out is None or orig_out is out): + ht_e_cpy2, cpy2_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=exec_q, depends=[cpy_e] + ) + _manager.add_event_pair(ht_e_cpy2, cpy2_e) + out = orig_out + return out + + hev, red_ev = _reduction_fn( + src=x_tmp, + trailing_dims_to_reduce=red_nd, + dst=out, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(hev, red_ev) + if not (orig_out is None or 
orig_out is out): + ht_e_cpy2, cpy2_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=exec_q, depends=[red_ev] + ) + _manager.add_event_pair(ht_e_cpy2, cpy2_e) + out = orig_out + + if keepdims: + res_shape = res_shape + (1,) * red_nd + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + out = dpt_ext.permute_dims(dpt_ext.reshape(out, res_shape), inv_perm) + return out + + +def _reduction_over_axis( + x, + axis, + dtype, + keepdims, + out, + _reduction_fn, + _dtype_supported, + _default_reduction_type_fn, +): + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + nd = x.ndim + if axis is None: + axis = tuple(range(nd)) + perm = list(axis) + arr = x + else: + if not isinstance(axis, (tuple, list)): + axis = (axis,) + axis = normalize_axis_tuple(axis, nd, "axis") + perm = [i for i in range(nd) if i not in axis] + list(axis) + arr = dpt_ext.permute_dims(x, perm) + red_nd = len(axis) + res_shape = arr.shape[: nd - red_nd] + q = x.sycl_queue + inp_dt = x.dtype + if dtype is None: + res_dt = _default_reduction_type_fn(inp_dt, q) + else: + res_dt = dpt.dtype(dtype) + res_dt = _to_device_supported_dtype(res_dt, q.sycl_device) + + res_usm_type = x.usm_type + + implemented_types = _dtype_supported(inp_dt, res_dt, res_usm_type, q) + if dtype is None and not implemented_types: + raise RuntimeError( + "Automatically determined reduction data type does not " + "have direct implementation" + ) + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + if not keepdims: + final_res_shape = res_shape + else: + inp_shape = x.shape + final_res_shape = tuple( + inp_shape[i] if i not in axis else 1 for i in range(nd) + ) + if not out.shape == final_res_shape: + raise ValueError( + "The shape of input and 
output arrays are inconsistent. " + f"Expected output shape is {final_res_shape}, got {out.shape}" + ) + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed, got {out.dtype}" + ) + if dpctl.utils.get_execution_queue((q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + if keepdims: + out = dpt_ext.squeeze(out, axis=axis) + orig_out = out + if ti._array_overlap(x, out) and implemented_types: + out = dpt_ext.empty_like(out) + else: + out = dpt_ext.empty( + res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + + _manager = SequentialOrderManager[q] + dep_evs = _manager.submitted_events + if red_nd == 0: + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=out, sycl_queue=q, depends=dep_evs + ) + _manager.add_event_pair(ht_e_cpy, cpy_e) + if not (orig_out is None or orig_out is out): + ht_e_cpy2, cpy2_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=q, depends=[cpy_e] + ) + _manager.add_event_pair(ht_e_cpy2, cpy2_e) + out = orig_out + return out + + if implemented_types: + ht_e, red_e = _reduction_fn( + src=arr, + trailing_dims_to_reduce=red_nd, + dst=out, + sycl_queue=q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_e, red_e) + if not (orig_out is None or orig_out is out): + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=q, depends=[red_e] + ) + _manager.add_event_pair(ht_e_cpy, cpy_e) + out = orig_out + else: + if _dtype_supported(res_dt, res_dt, res_usm_type, q): + tmp = dpt_ext.empty( + arr.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=tmp, sycl_queue=q, depends=dep_evs + ) + _manager.add_event_pair(ht_e_cpy, cpy_e) + ht_e_red, red_ev = _reduction_fn( + src=tmp, + trailing_dims_to_reduce=red_nd, + dst=out, + sycl_queue=q, + depends=[cpy_e], + ) + 
+ _manager.add_event_pair(ht_e_red, red_ev) + else: + buf_dt = _default_reduction_type_fn(inp_dt, q) + tmp = dpt_ext.empty( + arr.shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=tmp, sycl_queue=q, depends=dep_evs + ) + _manager.add_event_pair(ht_e_cpy, cpy_e) + tmp_res = dpt_ext.empty( + res_shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e_red, r_e = _reduction_fn( + src=tmp, + trailing_dims_to_reduce=red_nd, + dst=tmp_res, + sycl_queue=q, + depends=[cpy_e], + ) + _manager.add_event_pair(ht_e_red, r_e) + ht_e_cpy2, cpy2_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=tmp_res, dst=out, sycl_queue=q, depends=[r_e] + ) + _manager.add_event_pair(ht_e_cpy2, cpy2_e) + + if keepdims: + res_shape = res_shape + (1,) * red_nd + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + out = dpt_ext.permute_dims(dpt_ext.reshape(out, res_shape), inv_perm) + return out
+
+
+def _search_over_axis(x, axis, keepdims, out, _reduction_fn):
+    """
+    Shared driver for the index-search reductions (``argmax``, ``argmin``).
+
+    Args:
+        x (usm_ndarray):
+            input array.
+        axis (Optional[int]):
+            single axis to search along. If ``None``, all axes of ``x``
+            are reduced.
+        keepdims (bool):
+            if ``True``, the reduced axes are re-inserted into the result
+            as singleton dimensions.
+        out (Optional[usm_ndarray]):
+            array to write the result into. It must be writable, have the
+            expected result shape, have the default index data type of the
+            device of ``x``, and be allocated on a queue compatible with
+            that of ``x``. If ``None``, a new array is allocated.
+        _reduction_fn (callable):
+            kernel launcher invoked with keyword arguments ``src``,
+            ``trailing_dims_to_reduce``, ``dst``, ``sycl_queue`` and
+            ``depends``; it must return a pair of events.
+
+    Returns:
+        usm_ndarray:
+            array of indices with the device default index data type.
+
+    Raises:
+        TypeError: if ``x`` or ``out`` is not a ``usm_ndarray``, or if
+            ``axis`` is neither an ``int`` nor ``None``.
+        ValueError: if a reduced axis has zero size, or if ``out`` is
+            read-only or has an unexpected shape or data type.
+        ExecutionPlacementError: if the queues of ``x`` and ``out`` are
+            not compatible.
+    """
+    if not isinstance(x, dpt.usm_ndarray):
+        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
+
+    nd = x.ndim
+    if axis is None:
+        axis = tuple(range(nd))
+        perm = list(axis)
+        x_tmp = x
+    else:
+        if not isinstance(axis, int):
+            raise TypeError(
+                "'axis' argument expected to have type 'int' "
+                "or be `None`, "
+                f"got type {type(axis)}"
+            )
+        # Normalizing once is sufficient: the result is already a tuple of
+        # validated, non-negative axis indices.
+        axis = normalize_axis_tuple((axis,), nd, "axis")
+        # Move the reduced axis to the trailing position, which is where
+        # the kernel expects the dimensions to reduce.
+        perm = [i for i in range(nd) if i not in axis] + list(axis)
+        x_tmp = dpt_ext.permute_dims(x, perm)
+    red_nd = len(axis)
+    if any(x_tmp.shape[i] == 0 for i in range(-red_nd, 0)):
+        raise ValueError("reduction cannot be performed over zero-size axes")
+    res_shape = x_tmp.shape[: nd - red_nd]
+    exec_q = x.sycl_queue
+    res_dt = ti.default_device_index_type(exec_q.sycl_device)
+    res_usm_type = x.usm_type
+
+    orig_out = out
+    if out is not None:
+        if not isinstance(out, dpt.usm_ndarray):
+            raise TypeError(
+                f"output array must be of usm_ndarray type, got {type(out)}"
+            )
+        if not out.flags.writable:
+            raise ValueError("provided `out` array is read-only")
+        if not keepdims:
+            final_res_shape = res_shape
+        else:
+            inp_shape = x.shape
+            final_res_shape = tuple(
+                inp_shape[i] if i not in axis else 1 for i in range(nd)
+            )
+        if not out.shape == final_res_shape:
+            raise ValueError(
+                "The shape of input and output arrays are inconsistent. "
+                f"Expected output shape is {final_res_shape}, got {out.shape}"
+            )
+        if res_dt != out.dtype:
+            raise ValueError(
+                f"Output array of type {res_dt} is needed, got {out.dtype}"
+            )
+        if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
+            raise ExecutionPlacementError(
+                "Input and output allocation queues are not compatible"
+            )
+        if keepdims:
+            # The kernel writes a result without the singleton axes.
+            out = dpt_ext.squeeze(out, axis=axis)
+            orig_out = out
+        if ti._array_overlap(x, out) and red_nd > 0:
+            # Reduce into a scratch array and copy back afterwards so the
+            # kernel never reads from memory it is writing to.
+            out = dpt_ext.empty_like(out)
+    else:
+        out = dpt_ext.empty(
+            res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
+        )
+
+    _manager = SequentialOrderManager[exec_q]
+    dep_evs = _manager.submitted_events
+    if red_nd == 0:
+        # Nothing to reduce (zero-dimensional input): the only valid
+        # index is 0.
+        ht_e_fill, fill_ev = ti._full_usm_ndarray(
+            fill_value=0, dst=out, sycl_queue=exec_q, depends=dep_evs
+        )
+        _manager.add_event_pair(ht_e_fill, fill_ev)
+        return out
+
+    hev, red_ev = _reduction_fn(
+        src=x_tmp,
+        trailing_dims_to_reduce=red_nd,
+        dst=out,
+        sycl_queue=exec_q,
+        depends=dep_evs,
+    )
+    _manager.add_event_pair(hev, red_ev)
+    if not (orig_out is None or orig_out is out):
+        # A scratch array was used; copy the result into the user's array.
+        ht_e_cpy2, cpy2_e = ti._copy_usm_ndarray_into_usm_ndarray(
+            src=out, dst=orig_out, sycl_queue=exec_q, depends=[red_ev]
+        )
+        _manager.add_event_pair(ht_e_cpy2, cpy2_e)
+        out = orig_out
+
+    if keepdims:
+        # Append singleton dimensions for the reduced axes, then undo the
+        # permutation so they land at their original positions.
+        res_shape = res_shape + (1,) * red_nd
+        inv_perm = sorted(range(nd), key=lambda d: perm[d])
+        out = dpt_ext.permute_dims(dpt_ext.reshape(out, res_shape), inv_perm)
+    return out
+
+
+def argmax(x, /, *, axis=None,
keepdims=False, out=None): + """ + Returns the indices of the maximum values of the input array ``x`` along a + specified axis. + + When the maximum value occurs multiple times, the indices corresponding to + the first occurrence are returned. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int]): + axis along which to search. If ``None``, returns the index of the + maximum value of the flattened array. + Default: ``None``. + keepdims (Optional[bool]): + if ``True``, the reduced axes (dimensions) are included in the + result as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included + in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result. + If ``None`` then a new array is returned. Default: ``None``. + + Returns: + usm_ndarray: + an array containing the indices of the first occurrence of the + maximum values. If the entire array was searched, a + zero-dimensional array is returned. The returned array has the + default array index data type for the device of ``x``. + """ + return _search_over_axis(x, axis, keepdims, out, tri._argmax_over_axis) + + +def argmin(x, /, *, axis=None, keepdims=False, out=None): + """ + Returns the indices of the minimum values of the input array ``x`` along a + specified axis. + + When the minimum value occurs multiple times, the indices corresponding to + the first occurrence are returned. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int]): + axis along which to search. If ``None``, returns the index of the + minimum value of the flattened array. + Default: ``None``. 
+ keepdims (Optional[bool]): + if ``True``, the reduced axes (dimensions) are included in the + result as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included + in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result. + If ``None`` then a new array is returned. Default: ``None``. + + Returns: + usm_ndarray: + an array containing the indices of the first occurrence of the + minimum values. If the entire array was searched, a + zero-dimensional array is returned. The returned array has the + default array index data type for the device of ``x``. + """ + return _search_over_axis(x, axis, keepdims, out, tri._argmin_over_axis)
+
+
+def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
+    """
+    Counts the number of elements in the input array ``x`` which are non-zero.
+
+    Args:
+        x (usm_ndarray):
+            input array.
+        axis (Optional[int, Tuple[int, ...]]):
+            axis or axes along which to count. If a tuple of unique integers,
+            the number of non-zero values is computed over multiple axes.
+            If ``None``, the number of non-zero values is computed over the
+            entire array.
+            Default: ``None``.
+        keepdims (Optional[bool]):
+            if ``True``, the reduced axes (dimensions) are included in the
+            result as singleton dimensions, so that the returned array remains
+            compatible with the input arrays according to Array Broadcasting
+            rules. Otherwise, if ``False``, the reduced axes are not included
+            in the returned array. Default: ``False``.
+        out (Optional[usm_ndarray]):
+            the array into which the result is written.
+            The shape and data type of ``out`` must match the expected shape
+            and data type of the result.
+            If ``None`` then a new array is returned. Default: ``None``.
+
+    Returns:
+        usm_ndarray:
+            an array containing the count of non-zero values. If the count
+            was computed over the entire array, a zero-dimensional array is
+            returned. The returned array will have the default array index
+            data type for the device of ``x``.
+
+    Raises:
+        TypeError: if ``x`` is not a ``usm_ndarray``.
+    """
+    # Validate before touching ``x.dtype`` so that a non-array argument
+    # fails with the same TypeError as the other reductions in this module
+    # rather than with an AttributeError.
+    if not isinstance(x, dpt.usm_ndarray):
+        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
+    if x.dtype != dpt.bool:
+        # NOTE(review): the rest of this module goes through ``dpt_ext``;
+        # confirm whether this call should be ``dpt_ext.astype`` as well.
+        x = dpt.astype(x, dpt.bool, copy=False)
+    # ``sum`` is this module's reduction, not the builtin: counting
+    # non-zero elements is summing the boolean mask with an index dtype.
+    return sum(
+        x,
+        axis=axis,
+        dtype=ti.default_device_index_type(x.sycl_device),
+        keepdims=keepdims,
+        out=out,
+    )
+
+
+def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False, out=None): + """ + Calculates the logarithm of the sum of exponentials of elements in the + input array ``x``. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int, Tuple[int, ...]]): + axis or axes along which values must be computed. If a tuple + of unique integers, values are computed over multiple axes. + If ``None``, the result is computed over the entire array. + Default: ``None``. + dtype (Optional[dtype]): + data type of the returned array. If ``None``, the default data + type is inferred from the "kind" of the input array data type. + + * If ``x`` has a real-valued floating-point data type, the + returned array will have the same data type as ``x``. + * If ``x`` has a boolean or integral data type, the returned array + will have the default floating point data type for the device + where input array ``x`` is allocated. + * If ``x`` has a complex-valued floating-point data type, + an error is raised. + + If the data type (either specified or resolved) differs from the + data type of ``x``, the input array elements are cast to the + specified data type before computing the result. + Default: ``None``. + keepdims (Optional[bool]): + if ``True``, the reduced axes (dimensions) are included in the + result as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included + in the returned array. Default: ``False``.
+ out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result or (if provided) ``dtype``. + If ``None`` then a new array is returned. Default: ``None``. + + Returns: + usm_ndarray: + an array containing the results. If the result was computed over + the entire array, a zero-dimensional array is returned. + The returned array has the data type as described in the + ``dtype`` parameter description above. + """ + return _reduction_over_axis( + x, + axis, + dtype, + keepdims, + out, + tri._logsumexp_over_axis, + lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported( + inp_dt, res_dt + ), + _default_accumulation_dtype_fp_types, + ) + + +def max(x, /, *, axis=None, keepdims=False, out=None): + """ + Calculates the maximum value of the input array ``x``. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int, Tuple[int, ...]]): + axis or axes along which maxima must be computed. If a tuple + of unique integers, the maxima are computed over multiple axes. + If ``None``, the max is computed over the entire array. + Default: ``None``. + keepdims (Optional[bool]): + if ``True``, the reduced axes (dimensions) are included in the + result as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included + in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result. + If ``None`` then a new array is returned. Default: ``None``. + + Returns: + usm_ndarray: + an array containing the maxima. If the max was computed over the + entire array, a zero-dimensional array is returned. The returned + array has the same data type as ``x``. 
+ """ + return _comparison_over_axis(x, axis, keepdims, out, tri._max_over_axis) + + +def min(x, /, *, axis=None, keepdims=False, out=None): + """ + Calculates the minimum value of the input array ``x``. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int, Tuple[int, ...]]): + axis or axes along which minima must be computed. If a tuple + of unique integers, the minima are computed over multiple axes. + If ``None``, the min is computed over the entire array. + Default: ``None``. + keepdims (Optional[bool]): + if ``True``, the reduced axes (dimensions) are included in the + result as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included + in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result. + If ``None`` then a new array is returned. Default: ``None``. + + Returns: + usm_ndarray: + an array containing the minima. If the min was computed over the + entire array, a zero-dimensional array is returned. The returned + array has the same data type as ``x``. + """ + return _comparison_over_axis(x, axis, keepdims, out, tri._min_over_axis) + + +def prod(x, /, *, axis=None, dtype=None, keepdims=False, out=None): + """ + Calculates the product of elements in the input array ``x``. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int, Tuple[int, ...]]): + axis or axes along which products must be computed. If a tuple + of unique integers, products are computed over multiple axes. + If ``None``, the product is computed over the entire array. + Default: ``None``. + dtype (Optional[dtype]): + data type of the returned array. If ``None``, the default data + type is inferred from the "kind" of the input array data type. 
+ + * If ``x`` has a real- or complex-valued floating-point data + type, the returned array will have the same data type as + ``x``. + * If ``x`` has signed integral data type, the returned array + will have the default signed integral type for the device + where input array ``x`` is allocated. + * If ``x`` has unsigned integral data type, the returned array + will have the default unsigned integral type for the device + where input array ``x`` is allocated. + * If ``x`` has a boolean data type, the returned array will + have the default signed integral type for the device + where input array ``x`` is allocated. + + If the data type (either specified or resolved) differs from the + data type of ``x``, the input array elements are cast to the + specified data type before computing the product. + Default: ``None``. + keepdims (Optional[bool]): + if ``True``, the reduced axes (dimensions) are included in the + result as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included + in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result or (if provided) ``dtype``. + If ``None`` then a new array is returned. Default: ``None``. + + Returns: + usm_ndarray: + an array containing the products. If the product was computed over + the entire array, a zero-dimensional array is returned. The + returned array has the data type as described in the ``dtype`` + parameter description above. 
+ """ + return _reduction_over_axis( + x, + axis, + dtype, + keepdims, + out, + tri._prod_over_axis, + tri._prod_over_axis_dtype_supported, + _default_accumulation_dtype, + ) + + +def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False, out=None): + """ + Calculates the square root of the sum of squares of elements in the input + array ``x``. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int, Tuple[int, ...]]): + axis or axes along which values must be computed. If a tuple + of unique integers, values are computed over multiple axes. + If ``None``, the result is computed over the entire array. + Default: ``None``. + dtype (Optional[dtype]): + data type of the returned array. If ``None``, the default data + type is inferred from the "kind" of the input array data type. + + * If ``x`` has a real-valued floating-point data type, the + returned array will have the same data type as ``x``. + * If ``x`` has a boolean or integral data type, the returned array + will have the default floating point data type for the device + where input array ``x`` is allocated. + * If ``x`` has a complex-valued floating-point data type, + an error is raised. + + If the data type (either specified or resolved) differs from the + data type of ``x``, the input array elements are cast to the + specified data type before computing the result. Default: ``None``. + keepdims (Optional[bool]): + if ``True``, the reduced axes (dimensions) are included in the + result as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included + in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result or (if provided) ``dtype``. + If ``None`` then a new array is returned. Default: ``None``. 
+ + Returns: + usm_ndarray: + an array containing the results. If the result was computed over + the entire array, a zero-dimensional array is returned. The + returned array has the data type as described in the ``dtype`` + parameter description above. + """ + return _reduction_over_axis( + x, + axis, + dtype, + keepdims, + out, + tri._hypot_over_axis, + lambda inp_dt, res_dt, *_: tri._hypot_over_axis_dtype_supported( + inp_dt, res_dt + ), + _default_accumulation_dtype_fp_types, + )
+
+
+def sum(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
+    """
+    Calculates the sum of elements in the input array ``x``.
+
+    Args:
+        x (usm_ndarray):
+            input array.
+        axis (Optional[int, Tuple[int, ...]]):
+            axis or axes along which sums must be computed. If a tuple
+            of unique integers, sums are computed over multiple axes.
+            If ``None``, the sum is computed over the entire array.
+            Default: ``None``.
+        dtype (Optional[dtype]):
+            data type of the returned array. If ``None``, the default data
+            type is inferred from the "kind" of the input array data type.
+
+            * If ``x`` has a real- or complex-valued floating-point data
+              type, the returned array will have the same data type as
+              ``x``.
+            * If ``x`` has signed integral data type, the returned array
+              will have the default signed integral type for the device
+              where input array ``x`` is allocated.
+            * If ``x`` has unsigned integral data type, the returned array
+              will have the default unsigned integral type for the device
+              where input array ``x`` is allocated.
+            * If ``x`` has a boolean data type, the returned array will
+              have the default signed integral type for the device
+              where input array ``x`` is allocated.
+
+            If the data type (either specified or resolved) differs from the
+            data type of ``x``, the input array elements are cast to the
+            specified data type before computing the sum.
+            Default: ``None``.
+ keepdims (Optional[bool]): + if ``True``, the reduced axes (dimensions) are included in the + result as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included + in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result or (if provided) ``dtype``. + If ``None`` then a new array is returned. Default: ``None``. + + Returns: + usm_ndarray: + an array containing the sums. If the sum was computed over the + entire array, a zero-dimensional array is returned. The returned + array has the data type as described in the ``dtype`` parameter + description above. + """ + return _reduction_over_axis( + x, + axis, + dtype, + keepdims, + out, + tri._sum_over_axis, + tri._sum_over_axis_dtype_supported, + _default_accumulation_dtype, + ) diff --git a/dpctl_ext/tensor/_utility_functions.py b/dpctl_ext/tensor/_utility_functions.py new file mode 100644 index 00000000000..a122ac3d6ce --- /dev/null +++ b/dpctl_ext/tensor/_utility_functions.py @@ -0,0 +1,509 @@ +# ***************************************************************************** +# Copyright (c) 2026, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
def _boolean_reduction(x, axis, keepdims, func):
    """
    Shared driver behind `all` and `any`.

    Args:
        x (usm_ndarray): input array.
        axis (Optional[Union[int, Tuple[int,...]]]): axis or axes to reduce
            over; `None` reduces over every dimension.
        keepdims (bool): if true, reduced axes are kept as singleton
            dimensions in their original positions.
        func (callable): reduction kernel launcher. It is called with the
            keyword arguments `src`, `trailing_dims_to_reduce`, `dst`,
            `sycl_queue` and `depends`, and must return a pair of events.

    Returns:
        usm_ndarray: boolean array holding the reduction result.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

    nd = x.ndim
    if axis is None:
        red_nd = nd
        # case of a scalar
        if red_nd == 0:
            return dpt_ext.astype(x, dpt.bool)
        x_tmp = x
        res_shape = ()
        perm = list(range(nd))
    else:
        if not isinstance(axis, (tuple, list)):
            axis = (axis,)
        axis = normalize_axis_tuple(axis, nd, "axis")

        red_nd = len(axis)
        # check for axis=()
        if red_nd == 0:
            return dpt_ext.astype(x, dpt.bool)
        # move the reduced axes to the end so the kernel can reduce over
        # the trailing `red_nd` dimensions
        perm = [i for i in range(nd) if i not in axis] + list(axis)
        x_tmp = dpt_ext.permute_dims(x, perm)
        res_shape = x_tmp.shape[: nd - red_nd]

    exec_q = x.sycl_queue
    res_usm_type = x.usm_type

    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    # always allocate the temporary as
    # int32 and usm-device to ensure that atomic updates
    # are supported
    res_tmp = dpt_ext.empty(
        res_shape,
        dtype=dpt.int32,
        usm_type="device",
        sycl_queue=exec_q,
    )
    hev0, ev0 = func(
        src=x_tmp,
        trailing_dims_to_reduce=red_nd,
        dst=res_tmp,
        sycl_queue=exec_q,
        depends=dep_evs,
    )
    _manager.add_event_pair(hev0, ev0)

    # copy to boolean result array
    res = dpt_ext.empty(
        res_shape,
        dtype=dpt.bool,
        usm_type=res_usm_type,
        sycl_queue=exec_q,
    )
    hev1, ev1 = ti._copy_usm_ndarray_into_usm_ndarray(
        src=res_tmp, dst=res, sycl_queue=exec_q, depends=[ev0]
    )
    _manager.add_event_pair(hev1, ev1)

    if keepdims:
        # re-insert the reduced axes as size-1 dimensions at the end, then
        # undo the permutation applied above
        res_shape = res_shape + (1,) * red_nd
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt_ext.permute_dims(dpt_ext.reshape(res, res_shape), inv_perm)
    return res


def all(x, /, *, axis=None, keepdims=False):
    """
    all(x, axis=None, keepdims=False)

    Tests whether all input array elements evaluate to True along a given axis.

    Args:
        x (usm_ndarray): Input array.
        axis (Optional[Union[int, Tuple[int,...]]]): Axis (or axes)
            along which to perform a logical AND reduction.
            When `axis` is `None`, a logical AND reduction
            is performed over all dimensions of `x`.
            If `axis` is negative, the axis is counted from
            the last dimension to the first.
            Default: `None`.
        keepdims (bool, optional): If `True`, the reduced axes are included
            in the result as singleton dimensions, and the result is
            broadcastable to the input array shape.
            If `False`, the reduced axes are not included in the result.
            Default: `False`.

    Returns:
        usm_ndarray:
            An array with a data type of `bool`
            containing the results of the logical AND reduction.
    """
    return _boolean_reduction(x, axis, keepdims, tri._all)


def any(x, /, *, axis=None, keepdims=False):
    """
    any(x, axis=None, keepdims=False)

    Tests whether any input array elements evaluate to True along a given axis.

    Args:
        x (usm_ndarray): Input array.
        axis (Optional[Union[int, Tuple[int,...]]]): Axis (or axes)
            along which to perform a logical OR reduction.
            When `axis` is `None`, a logical OR reduction
            is performed over all dimensions of `x`.
            If `axis` is negative, the axis is counted from
            the last dimension to the first.
            Default: `None`.
        keepdims (bool, optional): If `True`, the reduced axes are included
            in the result as singleton dimensions, and the result is
            broadcastable to the input array shape.
            If `False`, the reduced axes are not included in the result.
            Default: `False`.

    Returns:
        usm_ndarray:
            An array with a data type of `bool`
            containing the results of the logical OR reduction.
    """
    return _boolean_reduction(x, axis, keepdims, tri._any)


def _validate_diff_shape(sh1, sh2, axis):
    """
    Utility for validating that two shapes `sh1` and `sh2`
    are possible to concatenate along `axis`.

    An empty `sh2` (a scalar) is always accepted. Otherwise both shapes
    must have the same length and agree in every position except `axis`.

    Returns:
        bool: `True` if the shapes are compatible, `False` otherwise.
    """
    if not sh2:
        # scalars will always be accepted
        return True
    # NOTE: `builtins.all` is spelled out because this module defines its
    # own `all`
    return len(sh1) == len(sh2) and builtins.all(
        d1 == d2 for i, (d1, d2) in enumerate(zip(sh1, sh2)) if i != axis
    )


def _resolve_diff_placement(arr, operands):
    """
    Determines the execution queue and coerced USM type for `arr` combined
    with the extra `operands` (arrays or scalars).

    Operands for which `_get_queue_usm_type` reports no queue do not take
    part in the resolution.

    Returns:
        tuple: `(exec_q, coerced_usm_type)`.

    Raises:
        ExecutionPlacementError: if no common execution queue can be inferred.
    """
    queues = [arr.sycl_queue]
    usm_types = [arr.usm_type]
    for operand in operands:
        q, usm_type = _get_queue_usm_type(operand)
        if q is not None:
            queues.append(q)
            usm_types.append(usm_type)

    if len(queues) == 1:
        # no operand carries a queue: follow `arr`
        exec_q, coerced_usm_type = queues[0], usm_types[0]
    else:
        exec_q = du.get_execution_queue(tuple(queues))
        if exec_q is None:
            raise du.ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        coerced_usm_type = du.get_coerced_usm_type(tuple(usm_types))
    du.validate_usm_type(coerced_usm_type, allow_none=False)
    return exec_q, coerced_usm_type


def _concat_diff_input(arr, axis, prepend, append):
    """
    Concatenates `arr`, `prepend` and, `append` along `axis`,
    where `arr` is an array and `prepend` and `append` are
    any mixture of arrays and scalars.

    Operands that are `None` are skipped; when both are `None`, `arr`
    itself is returned. Operands with an empty shape are broadcast to the
    shape of `arr` with extent 1 along `axis` before concatenation.

    Raises:
        ExecutionPlacementError: if no common execution queue can be inferred.
        TypeError: if the shape of an operand can not be inferred.
        ValueError: if an operand's shape is incompatible with `arr`, or
            its data type is unsupported.
    """
    extras = [
        (name, operand)
        for name, operand in (("prepend", prepend), ("append", append))
        if operand is not None
    ]
    if not extras:
        return arr
    names = [name for name, _ in extras]
    operands = [operand for _, operand in extras]
    both = len(operands) == 2

    exec_q, coerced_usm_type = _resolve_diff_placement(arr, operands)

    # shape checks: every shape must be inferable before any is validated
    arr_shape = arr.shape
    shapes = [_get_shape(operand) for operand in operands]
    if not builtins.all(isinstance(s, (tuple, list)) for s in shapes):
        if both:
            raise TypeError(
                "Shape of arguments can not be inferred. "
                "Arguments are expected to be "
                "lists, tuples, or both"
            )
        raise TypeError(
            "Shape of argument can not be inferred. "
            "Argument is expected to be a "
            "list or tuple"
        )
    for name, shape in zip(names, shapes):
        if not _validate_diff_shape(arr_shape, shape, axis):
            raise ValueError(
                f"`diff` argument `{name}` with shape {shape} is "
                f"invalid for first input with shape {arr_shape}"
            )

    # dtype checks and resolution of the operands' types against `arr`
    sycl_dev = exec_q.sycl_device
    arr_dtype = arr.dtype
    dtypes = [_get_dtype(operand, sycl_dev) for operand in operands]
    if not builtins.all(_validate_dtype(dt) for dt in dtypes):
        if both:
            raise ValueError("Operands have unsupported data types")
        raise ValueError("Operand has unsupported data type")
    if both:
        # the two operands are resolved jointly, not one at a time
        dtypes = _resolve_one_strong_two_weak_types(
            arr_dtype, dtypes[0], dtypes[1], sycl_dev
        )
    else:
        dtypes = (
            _resolve_one_strong_one_weak_types(
                arr_dtype, dtypes[0], sycl_dev
            ),
        )

    # arrays are used as given; everything else is converted
    arrays = [
        (
            operand
            if isinstance(operand, dpt.usm_ndarray)
            else dpt_ext.asarray(
                operand,
                dtype=dt,
                usm_type=coerced_usm_type,
                sycl_queue=exec_q,
            )
        )
        for operand, dt in zip(operands, dtypes)
    ]
    # operands with an empty shape become one slice along `axis`
    unit_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
    arrays = [
        a if shape else dpt_ext.broadcast_to(a, unit_shape)
        for a, shape in zip(arrays, shapes)
    ]

    by_name = dict(zip(names, arrays))
    pieces = tuple(
        piece
        for piece in (by_name.get("prepend"), arr, by_name.get("append"))
        if piece is not None
    )
    return dpt_ext.concat(pieces, axis=axis)


def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
    """
    Calculates the `n`-th discrete forward difference of `x` along `axis`.

    Args:
        x (usm_ndarray):
            input array.
        axis (int):
            axis along which to compute the difference. A valid axis must be on
            the interval `[-N, N)`, where `N` is the rank (number of
            dimensions) of `x`.
            Default: `-1`
        n (int):
            number of times to recursively compute the difference. Must be
            non-negative; for `n == 0` the input (with `prepend` and
            `append` concatenated, if given) is returned unchanged.
            Default: `1`.
        prepend (Union[usm_ndarray, bool, int, float, complex]):
            value or values to prepend to the specified axis before taking the
            difference.
            Must have the same shape as `x` except along `axis`, which can have
            any shape.
            Default: `None`.
        append (Union[usm_ndarray, bool, int, float, complex]):
            value or values to append to the specified axis before taking the
            difference.
            Must have the same shape as `x` except along `axis`, which can have
            any shape.
            Default: `None`.

    Returns:
        usm_ndarray:
            an array containing the `n`-th differences. The array will have the
            same shape as `x`, except along `axis`, which will have shape:
            ``prepend.shape[axis] + x.shape[axis] + append.shape[axis] - n``

            The data type of the returned array is determined by the Type
            Promotion Rules.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
        ValueError: if `n` is negative.
    """

    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x)}"
        )
    x_nd = x.ndim
    axis = normalize_axis_index(operator.index(axis), x_nd)
    n = operator.index(n)
    if n < 0:
        # zero is accepted (see below), so the requirement is non-negativity
        raise ValueError(f"`n` must be non-negative, got {n}")
    arr = _concat_diff_input(x, axis, prepend, append)
    if n == 0:
        return arr
    # form slices and recurse
    sl0 = tuple(
        slice(None) if i != axis else slice(1, None) for i in range(x_nd)
    )
    sl1 = tuple(
        slice(None) if i != axis else slice(None, -1) for i in range(x_nd)
    )

    # NOTE(review): the operator is chosen from the dtype of `x`, not of the
    # concatenated `arr`; numpy.diff appears to choose it after
    # concatenation. Confirm whether a boolean `x` with non-boolean
    # `prepend`/`append` is meant to use `not_equal`.
    diff_op = dpt.not_equal if x.dtype == dpt.bool else dpt.subtract
    if n > 1:
        arr_tmp0 = diff_op(arr[sl0], arr[sl1])
        arr_tmp1 = diff_op(arr_tmp0[sl0], arr_tmp0[sl1])
        n = n - 2
        if n > 0:
            sl3 = tuple(
                slice(None) if i != axis else slice(None, -2)
                for i in range(x_nd)
            )
            # ping-pong between the two temporaries: each pass writes the
            # next difference into a shortened view of the older buffer
            for _ in range(n):
                arr_tmp0_sliced = arr_tmp0[sl3]
                diff_op(arr_tmp1[sl0], arr_tmp1[sl1], out=arr_tmp0_sliced)
                arr_tmp0, arr_tmp1 = arr_tmp1, arr_tmp0_sliced
        arr = arr_tmp1
    else:
        arr = diff_op(arr[sl0], arr[sl1])
    return arr
b/dpctl_ext/tensor/libtensor/include/kernels/reductions.hpp @@ -0,0 +1,3323 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor reduction along axis. 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "dpctl_tensor_types.hpp" +#include "utils/math_utils.hpp" +#include "utils/offset_utils.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::kernels +{ + +using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace reduction_detail +{ + +inline std::size_t get_work_group_size(const sycl::device &d) +{ + // prevents running out of resources on CPU + return std::min( + 2048, d.get_info() / 2); +} + +} // namespace reduction_detail + +template +struct needs_workaround +{ + static constexpr bool value = + (std::is_same_v> && + (std::is_same_v || + std::is_same_v)) || + (__LIBSYCL_MAJOR_VERSION < 7 && std::is_same_v && + std::is_same_v>); +}; + +template +struct can_use_reduce_over_group +{ + static constexpr bool value = + sycl::has_known_identity::value && + !needs_workaround::value; +}; + +template +struct SequentialReduction +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + +public: + SequentialReduction(const argT *inp, + outT *res, + const ReductionOp &reduction_op, + const outT &identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + std::size_t reduction_size) + : inp_(inp), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size) + { + } + + void operator()(sycl::id<1> id) const + { + + auto const &inp_out_iter_offsets_ = 
inp_out_iter_indexer_(id[0]); + const ssize_t &inp_iter_offset = + inp_out_iter_offsets_.get_first_offset(); + const ssize_t &out_iter_offset = + inp_out_iter_offsets_.get_second_offset(); + + outT red_val(identity_); + for (std::size_t m = 0; m < reduction_max_gid_; ++m) { + const ssize_t inp_reduction_offset = inp_reduced_dims_indexer_(m); + const ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val; + if constexpr (su_ns::IsLogicalAnd::value || + su_ns::IsLogicalOr::value) + { + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } + red_val = reduction_op_(red_val, val); + } + + out_[out_iter_offset] = red_val; + } +}; + +/* === Reduction, using sycl::reduce_over_group, and sycl::atomic_ref === */ + +/* + This kernel only works for outT with sizeof(outT) == 4, or sizeof(outT) == 8 + if the device has aspect atomic64 and only with those supported by + sycl::atomic_ref +*/ +template +struct ReductionOverGroupWithAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + std::size_t iter_gws_ = 1; + std::size_t reductions_per_wi = 16; + +public: + ReductionOverGroupWithAtomicFunctor( + const argT *data, + outT *res, + const ReductionOp &reduction_op, + const outT &identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), + 
reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t iter_gid = it.get_group(0) % iter_gws_; + const std::size_t reduction_batch_id = it.get_group(0) / iter_gws_; + + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + const auto &inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + std::size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (std::size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val; + if constexpr (su_ns::IsLogicalAnd::value || + su_ns::IsLogicalOr::value) + { + // handle nans + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } + + local_red_val = reduction_op_(local_red_val, val); + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + outT red_val_over_wg; + if constexpr (su_ns::IsLogicalAnd::value) { + red_val_over_wg = static_cast( + sycl::all_of_group(work_group, local_red_val)); + } + else if constexpr (su_ns::IsLogicalOr::value) { + red_val_over_wg = static_cast( + sycl::any_of_group(work_group, local_red_val)); + } + else { + 
red_val_over_wg = sycl::reduce_over_group(work_group, local_red_val, + identity_, reduction_op_); + } + + if (work_group.leader()) { + sycl::atomic_ref + res_ref(out_[out_iter_offset]); + if constexpr (su_ns::IsPlus::value) { + res_ref += red_val_over_wg; + } + else if constexpr (su_ns::IsMaximum::value) { + res_ref.fetch_max(red_val_over_wg); + } + else if constexpr (su_ns::IsMinimum::value) { + res_ref.fetch_min(red_val_over_wg); + } + else if constexpr (su_ns::IsLogicalAnd::value) { + res_ref.fetch_and(red_val_over_wg); + } + else if constexpr (su_ns::IsLogicalOr::value) { + res_ref.fetch_or(red_val_over_wg); + } + else { + outT read_val = res_ref.load(); + outT new_val{}; + do { + new_val = reduction_op_(read_val, red_val_over_wg); + } while (!res_ref.compare_exchange_strong(read_val, new_val)); + } + } + } +}; + +/* === Reduction, using custom_reduce_over_group, and sycl::atomic_ref === */ + +template +struct CustomReductionOverGroupWithAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + SlmT local_mem_; + std::size_t reduction_max_gid_ = 0; + std::size_t iter_gws_ = 1; + std::size_t reductions_per_wi = 16; + +public: + CustomReductionOverGroupWithAtomicFunctor( + const argT *data, + outT *res, + const ReductionOp &reduction_op, + const outT &identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + SlmT local_mem, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } 
+ + void operator()(sycl::nd_item<1> it) const + { + const std::size_t iter_gid = it.get_group(0) % iter_gws_; + const std::size_t reduction_batch_id = it.get_group(0) / iter_gws_; + + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + const auto &inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + std::size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (std::size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val; + if constexpr (su_ns::IsLogicalAnd::value || + su_ns::IsLogicalOr::value) + { + // handle nans + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } + + local_red_val = reduction_op_(local_red_val, val); + } + + auto work_group = it.get_group(); + outT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + if (work_group.leader()) { + sycl::atomic_ref + res_ref(out_[out_iter_offset]); + // retain these checks in case a reduce_over_group work-around is + // needed + if constexpr (su_ns::IsSyclPlus::value) { + res_ref += red_val_over_wg; + } + else if constexpr (su_ns::IsSyclMaximum::value) { + 
res_ref.fetch_max(red_val_over_wg); + } + else if constexpr (su_ns::IsSyclMinimum::value) { + res_ref.fetch_min(red_val_over_wg); + } + else if constexpr (su_ns::IsSyclLogicalAnd::value) { + res_ref.fetch_and(red_val_over_wg); + } + else if constexpr (su_ns::IsSyclLogicalOr::value) + { + res_ref.fetch_or(red_val_over_wg); + } + else { + outT read_val = res_ref.load(); + outT new_val{}; + do { + new_val = reduction_op_(read_val, red_val_over_wg); + } while (!res_ref.compare_exchange_strong(read_val, new_val)); + } + } + } +}; + +template +struct ReductionOverGroupNoAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + std::size_t iter_gws_ = 1; + std::size_t reductions_per_wi = 16; + +public: + ReductionOverGroupNoAtomicFunctor( + const argT *data, + outT *res, + const ReductionOp &reduction_op, + const outT &identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + const std::size_t iter_gid = it.get_group(0) % iter_gws_; + const std::size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const std::size_t n_reduction_groups = + it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // 
inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + const auto &inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (std::size_t m = 0; m < reductions_per_wi; ++m) { + std::size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val; + if constexpr (su_ns::IsLogicalAnd::value || + su_ns::IsLogicalOr::value) + { + // handle nans + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } + + local_red_val = reduction_op_(local_red_val, val); + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + outT red_val_over_wg; + if constexpr (su_ns::IsLogicalAnd::value) { + red_val_over_wg = sycl::all_of_group(work_group, local_red_val); + } + else if constexpr (su_ns::IsLogicalOr::value) { + red_val_over_wg = sycl::any_of_group(work_group, local_red_val); + } + else { + red_val_over_wg = sycl::reduce_over_group(work_group, local_red_val, + identity_, reduction_op_); + } + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + red_val_over_wg; + } + } +}; + +/* = Reduction, using custom_reduce_over_group and not using atomic_ref*/ + +template +struct CustomReductionOverGroupNoAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT 
identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + SlmT local_mem_; + std::size_t reduction_max_gid_ = 0; + std::size_t iter_gws_ = 1; + std::size_t reductions_per_wi = 16; + +public: + CustomReductionOverGroupNoAtomicFunctor( + const argT *data, + outT *res, + const ReductionOp &reduction_op, + const outT &identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + SlmT local_mem, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + const std::size_t iter_gid = it.get_group(0) % iter_gws_; + const std::size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const std::size_t n_reduction_groups = + it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (std::size_t m = 0; m < reductions_per_wi; ++m) { + std::size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < 
reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val; + if constexpr (std::is_same_v> || + std::is_same_v>) + { + // handle nans + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } + + local_red_val = reduction_op_(local_red_val, val); + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + outT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + red_val_over_wg; + } + } +}; + +template < + typename argTy, + typename resTy, + typename ReductionOpT, + typename InputOutputIterIndexerT, + typename ReductionIndexerT, + template + class kernel_name_token> +sycl::event + sequential_reduction(sycl::queue &exec_q, + const argTy *arg, + resTy *res, + resTy identity_val, + std::size_t iter_nelems, + std::size_t reduction_nelems, + const InputOutputIterIndexerT &in_out_iter_indexer, + const ReductionIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using KernelName = + class kernel_name_token; + + cgh.parallel_for( + sycl::range<1>(iter_nelems), + SequentialReduction( + arg, res, ReductionOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems)); + }); + + return red_ev; +} + +template +class custom_reduction_wrapper; + +template < + typename argTy, + typename resTy, + typename ReductionOpT, + typename InputOutputIterIndexerT, + typename ReductionIndexerT, + template + class kernel_name_token> +sycl::event + submit_atomic_reduction(sycl::queue &exec_q, + const argTy *arg, + 
resTy *res, + resTy identity_val, + std::size_t wg, + std::size_t iter_nelems, + std::size_t reduction_nelems, + std::size_t reductions_per_wi, + std::size_t reduction_groups, + const InputOutputIterIndexerT &in_out_iter_indexer, + const ReductionIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + auto ndRange = sycl::nd_range<1>(globalRange, localRange); + + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class kernel_name_token; + + cgh.parallel_for( + ndRange, + ReductionOverGroupWithAtomicFunctor( + arg, res, ReductionOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems, iter_nelems, + reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + + using KernelName = class custom_reduction_wrapper< + kernel_name_token>; + + cgh.parallel_for( + ndRange, + CustomReductionOverGroupWithAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + arg, res, ReductionOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, reduction_nelems, + iter_nelems, reductions_per_wi)); + } + }); + return red_ev; +} + +template +class reduction_over_group_with_atomics_init_krn; + +template +class reduction_seq_krn; + +template +class reduction_over_group_with_atomics_krn; + +typedef sycl::event (*reduction_strided_impl_fn_ptr)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + int, + const ssize_t *, + ssize_t, + ssize_t, + int, + const ssize_t *, + ssize_t, + const std::vector &); + +using dpctl::tensor::sycl_utils::choose_workgroup_size; + +template +sycl::event reduction_over_group_with_atomics_strided_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of 
reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. number of columns) + const char *arg_cp, + char *res_cp, + int iter_nd, + const ssize_t *iter_shape_and_strides, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + int red_nd, + const ssize_t *reduction_shape_stride, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + static constexpr resTy identity_val = + su_ns::Identity::value; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); + + return comp_ev; + } + else { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const ssize_t *const &res_shape = iter_shape_and_strides; + const ssize_t *const &res_strides = + iter_shape_and_strides + 2 * iter_nd; + const IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, + res_strides); + using InitKernelName = + class reduction_over_group_with_atomics_init_krn; + cgh.depends_on(depends); + + cgh.parallel_for( + sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = 
identity_val; + }); + }); + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + static constexpr std::size_t preferred_reductions_per_wi = 8; + std::size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + sycl::event comp_ev = + submit_atomic_reduction( + exec_q, arg_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {res_init_ev}); + + return comp_ev; + } +} + +// Contig + +typedef sycl::event (*reduction_contig_impl_fn_ptr)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +/* @brief Reduce rows in a matrix */ +template +sycl::event reduction_axis1_over_group_with_atomics_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. 
number of columns) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + static constexpr resTy identity_val = + su_ns::Identity::value; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputIterIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{/* size */ iter_nelems, + /* step */ reduction_nelems}, + NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); + + return comp_ev; + } + else { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(identity_val), iter_nelems, depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using RowsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + RowsIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const RowsIndexerT rows_indexer{/* size */ iter_nelems, + /* step */ reduction_nelems}; + static constexpr NoOpIndexerT result_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{rows_indexer, + result_indexer}; + static constexpr ReductionIndexerT 
reduction_indexer{}; + + static constexpr std::size_t preferred_reductions_per_wi = 8; + std::size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + sycl::event comp_ev = + submit_atomic_reduction( + exec_q, arg_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {res_init_ev}); + + return comp_ev; + } +} + +/* @brief Reduce columns in a matrix */ +template +sycl::event reduction_axis0_over_group_with_atomics_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of cols in a + // matrix when reducing over cols) + std::size_t reduction_nelems, // size of each reduction (length of cols, + // i.e. number of rows) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + static constexpr resTy identity_val = + su_ns::Identity::value; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */
iter_nelems}; + + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); + + return comp_ev; + } + else { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(identity_val), iter_nelems, depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; + + static constexpr NoOpIndexerT columns_indexer{}; + static constexpr NoOpIndexerT result_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + result_indexer}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + static constexpr std::size_t preferred_reductions_per_wi = 8; + std::size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? 
std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + sycl::event comp_ev = + submit_atomic_reduction( + exec_q, arg_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {res_init_ev}); + + return comp_ev; + } +} + +/* = Reduction, using sycl::reduce_over_group, but not using atomic_ref = */ + +template < + typename argTy, + typename resTy, + typename ReductionOpT, + typename InputOutputIterIndexerT, + typename ReductionIndexerT, + template + class kernel_name_token> +sycl::event submit_no_atomic_reduction( + sycl::queue &exec_q, + const argTy *arg, + resTy *res, + resTy identity_val, + std::size_t wg, + std::size_t iter_nelems, + std::size_t reduction_nelems, + std::size_t reductions_per_wi, + std::size_t reduction_groups, + const InputOutputIterIndexerT &in_out_iter_indexer, + const ReductionIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + auto ndRange = sycl::nd_range<1>(globalRange, localRange); + + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class kernel_name_token; + + cgh.parallel_for( + ndRange, + ReductionOverGroupNoAtomicFunctor( + arg, res, ReductionOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems, iter_nelems, + reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = class custom_reduction_wrapper< + kernel_name_token>; + + cgh.parallel_for( + ndRange, + CustomReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, 
SlmT>( + arg, res, ReductionOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, reduction_nelems, + iter_nelems, reductions_per_wi)); + } + }); + return red_ev; +} + +template +class reduction_over_group_temps_krn; + +typedef sycl::event (*reduction_strided_impl_fn_ptr)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + int, + const ssize_t *, + ssize_t, + ssize_t, + int, + const ssize_t *, + ssize_t, + const std::vector &); + +template +class reduction_over_group_temps_empty_krn; + +template +sycl::event reduction_over_group_temps_strided_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. number of columns) + const char *arg_cp, + char *res_cp, + int iter_nd, + const ssize_t *iter_shape_and_strides, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + int red_nd, + const ssize_t *reduction_shape_stride, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + static constexpr resTy identity_val = + su_ns::Identity::value; + + if (reduction_nelems == 0) { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const ssize_t *const &res_shape = iter_shape_and_strides; + const ssize_t *const &res_strides = + iter_shape_and_strides + 2 * iter_nd; + const IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, + res_strides); + using InitKernelName = + class reduction_over_group_temps_empty_krn; + cgh.depends_on(depends); + + cgh.parallel_for( + sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = identity_val; + }); + }); + + return res_init_ev; + } + + const sycl::device &d = 
exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); + + return comp_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 8; + // prevents running out of resources on CPU + std::size_t max_wg = reduction_detail::get_work_group_size(d); + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + // Perform reduction using one work-group per iteration, + // can output directly to res + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event comp_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT,
reduction_over_group_temps_krn>( + exec_q, arg_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, depends); + + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + const std::size_t tmp_alloc_size = + iter_nelems * (reduction_groups + second_iter_reduction_groups_); + auto tmp_owner = dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + ; + + sycl::event first_reduction_ev; + { + using InputIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + // Only 2*iter_nd entries describing shape and strides of + // iterated dimensions of input array from + // iter_shape_and_strides are going to be accessed by + // inp_indexer + const InputIndexerT inp_indexer(iter_nd, iter_arg_offset, + iter_shape_and_strides); + static constexpr ResIndexerT noop_tmp_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + noop_tmp_indexer}; + const ReductionIndexerT reduction_indexer{ + red_nd, reduction_arg_offset, reduction_shape_stride}; + + first_reduction_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( 
+ exec_q, arg_tp, partially_reduced_tmp, identity_val, wg, + iter_nelems, reduction_nelems, preferred_reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, + depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + sycl::event partial_reduction_ev; + { + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{ + inp_indexer, res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + partial_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, temp2_arg, identity_val, wg, iter_nelems, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + } + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = 
dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + const ResIndexerT res_iter_indexer{ + iter_nd, iter_res_offset, + /* shape */ iter_shape_and_strides, + /* strides */ iter_shape_and_strides + 2 * iter_nd}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, res_tp, identity_val, wg, iter_nelems, + remaining_reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {dependent_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {final_reduction_ev}, tmp_owner); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +template +sycl::event reduction_axis1_over_group_temps_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. 
number of columns) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + static constexpr resTy identity_val = + su_ns::Identity::value; + + if (reduction_nelems == 0) { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(identity_val), iter_nelems, depends); + + return res_init_ev; + } + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputIterIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{/* size */ iter_nelems, + /* step */ reduction_nelems}, + NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); + + return comp_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 8; + // prevents running out of resources on CPU + std::size_t max_wg = reduction_detail::get_work_group_size(d); + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + // Perform reduction using one work-group per iteration, + // can output directly to res + + using InputIterIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT =
dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{/* size */ iter_nelems, + /* step */ reduction_nelems}, + NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event comp_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, arg_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, depends); + + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + const std::size_t tmp_alloc_size = + iter_nelems * (reduction_groups + second_iter_reduction_groups_); + auto tmp_owner = dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + + sycl::event first_reduction_ev; + { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using RowsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using 
InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + RowsIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const RowsIndexerT rows_indexer{/* size */ iter_nelems, + /* step */ reduction_nelems}; + static constexpr NoOpIndexerT noop_tmp_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{rows_indexer, + noop_tmp_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + first_reduction_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, arg_tp, partially_reduced_tmp, identity_val, wg, + iter_nelems, reduction_nelems, preferred_reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, + depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event partial_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, 
ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, temp2_arg, identity_val, wg, iter_nelems, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, res_tp, identity_val, wg, iter_nelems, + remaining_reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {dependent_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {final_reduction_ev}, tmp_owner); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +template +sycl::event 
reduction_axis0_over_group_temps_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. number of columns) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + static constexpr resTy identity_val = + su_ns::Identity::value; + + if (reduction_nelems == 0) { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(identity_val), iter_nelems, depends); + + return res_init_ev; + } + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); + + return comp_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 8; + // prevents running out of resources on CPU + std::size_t max_wg = reduction_detail::get_work_group_size(d); + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + // Perform reduction using 
one 1 work-group per iteration, + // can output directly to res + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; + + static constexpr NoOpIndexerT columns_indexer{}; + static constexpr NoOpIndexerT result_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + result_indexer}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event comp_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, arg_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, depends); + + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + const std::size_t tmp_alloc_size = + iter_nelems * (reduction_groups + second_iter_reduction_groups_); + + auto tmp_owner = dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + 
partially_reduced_tmp + reduction_groups * iter_nelems; + + sycl::event first_reduction_ev; + { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; + + static constexpr NoOpIndexerT columns_indexer{}; + static constexpr NoOpIndexerT noop_tmp_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + noop_tmp_indexer}; + const ReductionIndexerT reduction_indexer{ + /* size */ reduction_nelems, + /* step */ iter_nelems}; + + first_reduction_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, arg_tp, partially_reduced_tmp, identity_val, wg, + iter_nelems, reduction_nelems, preferred_reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, + depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT 
res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event partial_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, temp2_arg, identity_val, wg, iter_nelems, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, res_tp, identity_val, wg, iter_nelems, + remaining_reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {dependent_ev}); + + sycl::event cleanup_host_task_event = + 
dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {final_reduction_ev}, tmp_owner); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +// Argmax and Argmin + +/* Sequential search reduction */ + +template +struct SequentialSearchReduction +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + argT identity_; + IdxReductionOp idx_reduction_op_; + outT idx_identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + +public: + SequentialSearchReduction( + const argT *inp, + outT *res, + const ReductionOp &reduction_op, + const argT &identity_val, + const IdxReductionOp &idx_reduction_op, + const outT &idx_identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + std::size_t reduction_size) + : inp_(inp), out_(res), reduction_op_(reduction_op), + identity_(identity_val), idx_reduction_op_(idx_reduction_op), + idx_identity_(idx_identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size) + { + } + + void operator()(sycl::id<1> id) const + { + + auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]); + const ssize_t &inp_iter_offset = + inp_out_iter_offsets_.get_first_offset(); + const ssize_t &out_iter_offset = + inp_out_iter_offsets_.get_second_offset(); + + argT red_val(identity_); + outT idx_val(idx_identity_); + for (std::size_t m = 0; m < reduction_max_gid_; ++m) { + const ssize_t inp_reduction_offset = inp_reduced_dims_indexer_(m); + const ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; + + argT val = inp_[inp_offset]; + if (val == red_val) { + idx_val = idx_reduction_op_(idx_val, static_cast(m)); + } + else { + if constexpr (su_ns::IsMinimum::value) { + using 
dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::less_complex; + // less_complex always returns false for NaNs, so check + if (less_complex(val, red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + red_val = val; + idx_val = static_cast(m); + } + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { + if (val < red_val || std::isnan(val)) { + red_val = val; + idx_val = static_cast(m); + } + } + else { + if (val < red_val) { + red_val = val; + idx_val = static_cast(m); + } + } + } + else if constexpr (su_ns::IsMaximum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::greater_complex; + if (greater_complex(val, red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + red_val = val; + idx_val = static_cast(m); + } + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { + if (val > red_val || std::isnan(val)) { + red_val = val; + idx_val = static_cast(m); + } + } + else { + if (val > red_val) { + red_val = val; + idx_val = static_cast(m); + } + } + } + } + } + out_[out_iter_offset] = idx_val; + } +}; + +/* = Search reduction using reduce_over_group*/ + +template +struct SearchReduction +{ +private: + const argT *inp_ = nullptr; + argT *vals_ = nullptr; + const outT *inds_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + argT identity_; + IdxReductionOp idx_reduction_op_; + outT idx_identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + std::size_t iter_gws_ = 1; + std::size_t reductions_per_wi = 16; + +public: + SearchReduction(const argT *data, + argT *vals, + const outT *inds, + outT *res, + const ReductionOp &reduction_op, + const argT &identity_val, + const IdxReductionOp &idx_reduction_op, + const outT &idx_identity_val, + const 
InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : inp_(data), vals_(vals), inds_(inds), out_(res), + reduction_op_(reduction_op), identity_(identity_val), + idx_reduction_op_(idx_reduction_op), idx_identity_(idx_identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + const std::size_t iter_gid = it.get_group(0) % iter_gws_; + const std::size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const std::size_t n_reduction_groups = + it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + const auto &inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + argT local_red_val(identity_); + outT local_idx(idx_identity_); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (std::size_t m = 0; m < reductions_per_wi; ++m) { + std::size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + argT val = inp_[inp_offset]; + if (val == local_red_val) { + if constexpr (!First) { + local_idx = + idx_reduction_op_(local_idx, 
inds_[inp_offset]); + } + else { + local_idx = idx_reduction_op_( + local_idx, static_cast(arg_reduce_gid)); + } + } + else { + if constexpr (su_ns::IsMinimum::value) { + if (val < local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = static_cast(arg_reduce_gid); + } + } + } + else if constexpr (su_ns::IsMaximum::value) { + if (val > local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = static_cast(arg_reduce_gid); + } + } + } + } + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + argT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, identity_, reduction_op_); + + if constexpr (std::is_integral_v) { + local_idx = + (red_val_over_wg == local_red_val) ? local_idx : idx_identity_; + } + else { + local_idx = + (red_val_over_wg == local_red_val || + std::isnan(red_val_over_wg) || std::isnan(local_red_val)) + ? 
local_idx + : idx_identity_; + } + outT idx_over_wg = sycl::reduce_over_group( + work_group, local_idx, idx_identity_, idx_reduction_op_); + + if (work_group.leader()) { + // each group writes to a different memory location + if constexpr (!Last) { + // if not the final reduction, write value corresponding to + // an index to a temporary + vals_[out_iter_offset * n_reduction_groups + + reduction_batch_id] = red_val_over_wg; + } + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + idx_over_wg; + } + } +}; + +/* = Search reduction using custom_reduce_over_group*/ + +template +struct CustomSearchReduction +{ +private: + const argT *inp_ = nullptr; + argT *vals_ = nullptr; + const outT *inds_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + argT identity_; + IdxReductionOp idx_reduction_op_; + outT idx_identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + SlmT local_mem_; + std::size_t reduction_max_gid_ = 0; + std::size_t iter_gws_ = 1; + std::size_t reductions_per_wi = 16; + +public: + CustomSearchReduction(const argT *data, + argT *vals, + outT *inds, + outT *res, + const ReductionOp &reduction_op, + const argT &identity_val, + const IdxReductionOp &idx_reduction_op, + const outT &idx_identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + SlmT local_mem, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : inp_(data), vals_(vals), inds_(inds), out_(res), + reduction_op_(reduction_op), identity_(identity_val), + idx_reduction_op_(idx_reduction_op), idx_identity_(idx_identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { 
+ const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + const std::size_t iter_gid = it.get_group(0) % iter_gws_; + const std::size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const std::size_t n_reduction_groups = + it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + const auto &inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + argT local_red_val(identity_); + outT local_idx(idx_identity_); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (std::size_t m = 0; m < reductions_per_wi; ++m) { + std::size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + argT val = inp_[inp_offset]; + if (val == local_red_val) { + if constexpr (!First) { + local_idx = + idx_reduction_op_(local_idx, inds_[inp_offset]); + } + else { + local_idx = idx_reduction_op_( + local_idx, static_cast(arg_reduce_gid)); + } + } + else { + if constexpr (su_ns::IsMinimum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::less_complex; + // less_complex always returns false for NaNs, so + // check + if (less_complex(val, local_red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else if constexpr 
(std::is_floating_point_v || + std::is_same_v) + { + if (val < local_red_val || std::isnan(val)) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else { + if (val < local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + } + else if constexpr (su_ns::IsMaximum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::greater_complex; + if (greater_complex(val, local_red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { + if (val > local_red_val || std::isnan(val)) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else { + if (val > local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + } + } + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + argT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + // equality does not hold for NaNs, so check here + local_idx = (red_val_over_wg == local_red_val || + std::isnan(std::real(local_red_val)) || + std::isnan(std::imag(local_red_val))) + ? 
local_idx + : idx_identity_; + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { + // equality does not hold for NaNs, so check here + local_idx = + (red_val_over_wg == local_red_val || std::isnan(local_red_val)) + ? local_idx + : idx_identity_; + } + else { + local_idx = + red_val_over_wg == local_red_val ? local_idx : idx_identity_; + } + outT idx_over_wg = sycl::reduce_over_group( + work_group, local_idx, idx_identity_, idx_reduction_op_); + if (work_group.leader()) { + // each group writes to a different memory location + if constexpr (!Last) { + // if not the final reduction, write value corresponding to + // an index to a temporary + vals_[out_iter_offset * n_reduction_groups + + reduction_batch_id] = red_val_over_wg; + } + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + idx_over_wg; + } + } +}; + +typedef sycl::event (*search_strided_impl_fn_ptr)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + int, + const ssize_t *, + ssize_t, + ssize_t, + int, + const ssize_t *, + ssize_t, + const std::vector &); + +template +class search_seq_strided_krn; + +template +class search_seq_contig_krn; + +template +class search_over_group_krn; + +template +class custom_search_over_group_krn; + +template +class search_empty_krn; + +template +sycl::event + submit_search_reduction(sycl::queue &exec_q, + const argTy *arg, + argTy *arg_tmp, + resTy *res_tmp, + resTy *res, + argTy identity_val, + resTy idx_identity_val, + std::size_t wg, + std::size_t iter_nelems, + std::size_t reduction_nelems, + std::size_t reductions_per_wi, + std::size_t reduction_groups, + const InputOutputIterIndexerT &in_out_iter_indexer, + const ReductionIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + auto ndRange = 
sycl::nd_range<1>(globalRange, localRange); + + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class search_over_group_krn; + cgh.parallel_for( + ndRange, SearchReduction( + arg, arg_tmp, res_tmp, res, ReductionOpT(), + identity_val, IndexOpT(), idx_identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = class custom_search_over_group_krn< + argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT, First, Last>; + cgh.parallel_for( + ndRange, + CustomSearchReduction( + arg, arg_tmp, res_tmp, res, ReductionOpT(), identity_val, + IndexOpT(), idx_identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, reduction_nelems, + iter_nelems, reductions_per_wi)); + } + }); + return red_ev; +} + +template +sycl::event search_over_group_temps_strided_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. 
number of columns) + const char *arg_cp, + char *res_cp, + int iter_nd, + const ssize_t *iter_shape_and_strides, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + int red_nd, + const ssize_t *reduction_shape_stride, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + static constexpr argTy identity_val = + su_ns::Identity::value; + static constexpr resTy idx_identity_val = + su_ns::Identity::value; + + if (reduction_nelems == 0) { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const ssize_t *const &res_shape = iter_shape_and_strides; + const ssize_t *const &res_strides = + iter_shape_and_strides + 2 * iter_nd; + const IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, + res_strides); + using InitKernelName = + class search_empty_krn; + cgh.depends_on(depends); + + cgh.parallel_for( + sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = idx_identity_val; + }); + }); + + return res_init_ev; + } + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.parallel_for>( + sycl::range<1>(iter_nelems), + SequentialSearchReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, 
IndexOpT(), + idx_identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + + return comp_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 4; + // prevents running out of resources on CPU + std::size_t max_wg = reduction_detail::get_work_group_size(d); + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + // Perform reduction using one 1 work-group per iteration, + // can output directly to res + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event comp_ev = + submit_search_reduction( + exec_q, arg_tp, nullptr, nullptr, res_tp, identity_val, + idx_identity_val, wg, iter_nelems, reduction_nelems, + reductions_per_wi, reduction_groups, in_out_iter_indexer, + reduction_indexer, depends); + + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + const std::size_t tmp_alloc_size = + iter_nelems * (reduction_groups + second_iter_reduction_groups_); + auto 
tmp_owner = dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + + auto val_tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + argTy *partially_reduced_vals_tmp = val_tmp_owner.get(); + argTy *partially_reduced_vals_tmp2 = + partially_reduced_vals_tmp + reduction_groups * iter_nelems; + + sycl::event first_reduction_ev; + { + using InputIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + // Only 2*iter_nd entries describing shape and strides of iterated + // dimensions of input array from iter_shape_and_strides are going + // to be accessed by inp_indexer + const InputIndexerT inp_indexer(iter_nd, iter_arg_offset, + iter_shape_and_strides); + static constexpr ResIndexerT noop_tmp_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + noop_tmp_indexer}; + const ReductionIndexerT reduction_indexer{ + red_nd, reduction_arg_offset, reduction_shape_stride}; + + first_reduction_ev = + submit_search_reduction( + exec_q, arg_tp, partially_reduced_vals_tmp, nullptr, + partially_reduced_tmp, identity_val, idx_identity_val, wg, + iter_nelems, reduction_nelems, reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, + depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + + argTy *vals_temp_arg = partially_reduced_vals_tmp; + argTy *vals_temp2_arg = partially_reduced_vals_tmp2; + + sycl::event dependent_ev = first_reduction_ev; + + while 
(remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event partial_reduction_ev = + submit_search_reduction( + exec_q, vals_temp_arg, vals_temp2_arg, temp_arg, temp2_arg, + identity_val, idx_identity_val, wg, iter_nelems, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + std::swap(vals_temp_arg, vals_temp2_arg); + dependent_ev = partial_reduction_ev; + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + const ResIndexerT res_iter_indexer{ + iter_nd, iter_res_offset, + /* shape */ iter_shape_and_strides, + /* strides */ iter_shape_and_strides + 
2 * iter_nd}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = + submit_search_reduction( + exec_q, vals_temp_arg, nullptr, temp_arg, res_tp, identity_val, + idx_identity_val, wg, iter_nelems, remaining_reduction_nelems, + reductions_per_wi, reduction_groups, in_out_iter_indexer, + reduction_indexer, {dependent_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {final_reduction_ev}, tmp_owner, val_tmp_owner); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +typedef sycl::event (*search_contig_impl_fn_ptr)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +template +sycl::event search_axis1_over_group_temps_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. 
number of columns) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + static constexpr argTy identity_val = + su_ns::Identity::value; + static constexpr resTy idx_identity_val = + su_ns::Identity::value; + + if (reduction_nelems == 0) { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(idx_identity_val), iter_nelems, depends); + + return res_init_ev; + } + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputIterIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{/* size */ iter_nelems, + /* step */ reduction_nelems}, + NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.parallel_for>( + sycl::range<1>(iter_nelems), + SequentialSearchReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + + return comp_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 8; + // prevents running out of resources on CPU + std::size_t max_wg = reduction_detail::get_work_group_size(d); + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + 
// Perform reduction using one 1 work-group per iteration, + // can output directly to res + using InputIterIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{/* size */ iter_nelems, + /* step */ reduction_nelems}, + NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event comp_ev = + submit_search_reduction( + exec_q, arg_tp, nullptr, nullptr, res_tp, identity_val, + idx_identity_val, wg, iter_nelems, reduction_nelems, + reductions_per_wi, reduction_groups, in_out_iter_indexer, + reduction_indexer, depends); + + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + const std::size_t tmp_alloc_size = + iter_nelems * (reduction_groups + second_iter_reduction_groups_); + auto tmp_owner = dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + + auto val_tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + 
tmp_alloc_size, exec_q); + argTy *partially_reduced_vals_tmp = val_tmp_owner.get(); + argTy *partially_reduced_vals_tmp2 = + partially_reduced_vals_tmp + reduction_groups * iter_nelems; + + sycl::event first_reduction_ev; + { + using InputIterIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{/* size */ iter_nelems, + /* step */ reduction_nelems}, + NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{}; + + first_reduction_ev = + submit_search_reduction( + exec_q, arg_tp, partially_reduced_vals_tmp, nullptr, + partially_reduced_tmp, identity_val, idx_identity_val, wg, + iter_nelems, reduction_nelems, preferred_reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, + depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + + argTy *vals_temp_arg = partially_reduced_vals_tmp; + argTy *vals_temp2_arg = partially_reduced_vals_tmp2; + + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const 
InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event partial_reduction_ev = + submit_search_reduction( + exec_q, vals_temp_arg, vals_temp2_arg, temp_arg, temp2_arg, + identity_val, idx_identity_val, wg, iter_nelems, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + std::swap(vals_temp_arg, vals_temp2_arg); + dependent_ev = partial_reduction_ev; + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = + submit_search_reduction( + exec_q, vals_temp_arg, nullptr, temp_arg, res_tp, identity_val, + idx_identity_val, wg, iter_nelems, remaining_reduction_nelems, + reductions_per_wi, reduction_groups, in_out_iter_indexer, + reduction_indexer, {dependent_ev}); + + sycl::event 
cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {final_reduction_ev}, tmp_owner, val_tmp_owner); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +template +sycl::event search_axis0_over_group_temps_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. number of columns) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + static constexpr argTy identity_val = + su_ns::Identity::value; + static constexpr resTy idx_identity_val = + su_ns::Identity::value; + + if (reduction_nelems == 0) { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(idx_identity_val), iter_nelems, depends); + + return res_init_ev; + } + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using KernelName = + class search_seq_contig_krn; + + sycl::range<1> 
iter_range{iter_nelems}; + + cgh.parallel_for( + iter_range, + SequentialSearchReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + + return comp_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 8; + // prevents running out of resources on CPU + std::size_t max_wg = reduction_detail::get_work_group_size(d); + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + // Perform reduction using one 1 work-group per iteration, + // can output directly to res + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; + + static constexpr NoOpIndexerT columns_indexer{}; + static constexpr NoOpIndexerT result_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + result_indexer}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event comp_ev = + submit_search_reduction( + exec_q, arg_tp, nullptr, nullptr, res_tp, identity_val, + idx_identity_val, wg, iter_nelems, reduction_nelems, + reductions_per_wi, reduction_groups, in_out_iter_indexer, + reduction_indexer, depends); + + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + 
(preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + const std::size_t tmp_alloc_size = + iter_nelems * (reduction_groups + second_iter_reduction_groups_); + auto tmp_owner = dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + + auto vals_tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + argTy *partially_reduced_vals_tmp = vals_tmp_owner.get(); + argTy *partially_reduced_vals_tmp2 = + partially_reduced_vals_tmp + reduction_groups * iter_nelems; + + sycl::event first_reduction_ev; + { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; + + static constexpr NoOpIndexerT columns_indexer{}; + static constexpr NoOpIndexerT result_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + result_indexer}; + const ReductionIndexerT reduction_indexer{ + /* size */ reduction_nelems, + /* step */ iter_nelems}; + + first_reduction_ev = + submit_search_reduction( + exec_q, arg_tp, partially_reduced_vals_tmp, nullptr, + partially_reduced_tmp, identity_val, idx_identity_val, wg, + iter_nelems, reduction_nelems, preferred_reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, + depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + + argTy *vals_temp_arg = partially_reduced_vals_tmp; + argTy 
*vals_temp2_arg = partially_reduced_vals_tmp2; + + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event partial_reduction_ev = + submit_search_reduction( + exec_q, vals_temp_arg, vals_temp2_arg, temp_arg, temp2_arg, + identity_val, idx_identity_val, wg, iter_nelems, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + std::swap(vals_temp_arg, vals_temp2_arg); + dependent_ev = partial_reduction_ev; + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + static constexpr ResIndexerT 
res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = + submit_search_reduction( + exec_q, vals_temp_arg, nullptr, temp_arg, res_tp, identity_val, + idx_identity_val, wg, iter_nelems, remaining_reduction_nelems, + reductions_per_wi, reduction_groups, in_out_iter_indexer, + reduction_indexer, {dependent_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {final_reduction_ev}, tmp_owner, vals_tmp_owner); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +} // namespace dpctl::tensor::kernels diff --git a/dpctl_ext/tensor/libtensor/source/reductions/all.cpp b/dpctl_ext/tensor/libtensor/source/reductions/all.cpp new file mode 100644 index 00000000000..a901b9e1d9a --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/all.cpp @@ -0,0 +1,164 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. 
+// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + all_reduction_strided_dispatch_vector[td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + all_reduction_axis1_contig_dispatch_vector[td_ns::num_types]; +static reduction_contig_impl_fn_ptr + all_reduction_axis0_contig_dispatch_vector[td_ns::num_types]; + +template +struct AllStridedFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_and; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } +}; + +template +struct AllAxis1ContigFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_and; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl; + } +}; + +template +struct AllAxis0ContigFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_and; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl; + } +}; + +void populate_all_dispatch_vectors(void) +{ + using td_ns::DispatchVectorBuilder; + + DispatchVectorBuilder + all_dvb1; + all_dvb1.populate_dispatch_vector(all_reduction_strided_dispatch_vector); + + DispatchVectorBuilder + all_dvb2; + all_dvb2.populate_dispatch_vector( + all_reduction_axis1_contig_dispatch_vector); + + DispatchVectorBuilder + all_dvb3; + all_dvb3.populate_dispatch_vector( + 
all_reduction_axis0_contig_dispatch_vector); +}; + +using atomic_support::atomic_support_fn_ptr_t; +using atomic_support::check_atomic_support; +static atomic_support_fn_ptr_t all_atomic_support = + check_atomic_support; + +} // namespace impl + +void init_all(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_all_dispatch_vectors(); + using impl::all_reduction_axis0_contig_dispatch_vector; + using impl::all_reduction_axis1_contig_dispatch_vector; + using impl::all_reduction_strided_dispatch_vector; + + using impl::all_atomic_support; + + auto all_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_boolean_reduction( + src, trailing_dims_to_reduce, dst, exec_q, depends, + all_reduction_axis1_contig_dispatch_vector, + all_reduction_axis0_contig_dispatch_vector, + all_reduction_strided_dispatch_vector, all_atomic_support); + }; + m.def("_all", all_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/all.hpp b/dpctl_ext/tensor/libtensor/source/reductions/all.hpp new file mode 100644 index 00000000000..5fb184e37c6 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/all.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_all(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/any.cpp b/dpctl_ext/tensor/libtensor/source/reductions/any.cpp new file mode 100644 index 00000000000..6859e46cbc4 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/any.cpp @@ -0,0 +1,164 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + any_reduction_strided_dispatch_vector[td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + any_reduction_axis1_contig_dispatch_vector[td_ns::num_types]; +static reduction_contig_impl_fn_ptr + any_reduction_axis0_contig_dispatch_vector[td_ns::num_types]; + +template +struct AnyStridedFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_or; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } +}; + +template +struct AnyAxis1ContigFactory +{ + fnT get() const + { + using 
dstTy = std::int32_t; + using ReductionOpT = sycl::logical_or; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl; + } +}; + +template +struct AnyAxis0ContigFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_or; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl; + } +}; + +void populate_any_dispatch_vectors(void) +{ + using td_ns::DispatchVectorBuilder; + + DispatchVectorBuilder + any_dvb1; + any_dvb1.populate_dispatch_vector(any_reduction_strided_dispatch_vector); + + DispatchVectorBuilder + any_dvb2; + any_dvb2.populate_dispatch_vector( + any_reduction_axis1_contig_dispatch_vector); + + DispatchVectorBuilder + any_dvb3; + any_dvb3.populate_dispatch_vector( + any_reduction_axis0_contig_dispatch_vector); +}; + +using atomic_support::atomic_support_fn_ptr_t; +using atomic_support::check_atomic_support; +static atomic_support_fn_ptr_t any_atomic_support = + check_atomic_support; + +} // namespace impl + +void init_any(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_any_dispatch_vectors(); + using impl::any_reduction_axis0_contig_dispatch_vector; + using impl::any_reduction_axis1_contig_dispatch_vector; + using impl::any_reduction_strided_dispatch_vector; + + using impl::any_atomic_support; + + auto any_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_boolean_reduction( + src, trailing_dims_to_reduce, dst, exec_q, depends, + any_reduction_axis1_contig_dispatch_vector, + any_reduction_axis0_contig_dispatch_vector, + any_reduction_strided_dispatch_vector, any_atomic_support); + }; + m.def("_any", any_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace 
dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/any.hpp b/dpctl_ext/tensor/libtensor/source/reductions/any.hpp new file mode 100644 index 00000000000..4e368a67461 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/any.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_any(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/argmax.cpp b/dpctl_ext/tensor/libtensor/source/reductions/argmax.cpp new file mode 100644 index 00000000000..10fc4975916 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/argmax.cpp @@ -0,0 +1,279 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_over_axis.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace impl +{ + +using dpctl::tensor::kernels::search_strided_impl_fn_ptr; +static search_strided_impl_fn_ptr + argmax_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::search_contig_impl_fn_ptr; +static search_contig_impl_fn_ptr + argmax_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +using dpctl::tensor::kernels::search_contig_impl_fn_ptr; +static search_contig_impl_fn_ptr + argmax_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct TypePairSupportForArgmaxReductionTemps +{ + + static constexpr bool is_defined =
std::disjunction< + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct ArgmaxOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgmaxReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgmaxOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgmaxReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + //
op for values + using ReductionOpT = su_ns::Maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgmaxOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgmaxReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +void populate_argmax_over_axis_dispatch_tables(void) +{ + using td_ns::DispatchTableBuilder; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(argmax_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(argmax_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(argmax_over_axis0_contig_temps_dispatch_table); +} + +} // namespace impl + +void init_argmax(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_argmax_over_axis_dispatch_tables; + populate_argmax_over_axis_dispatch_tables(); + using impl::argmax_over_axis0_contig_temps_dispatch_table; + using impl::argmax_over_axis1_contig_temps_dispatch_table; + using impl::argmax_over_axis_strided_temps_dispatch_table; + + auto argmax_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const 
event_vecT &depends = {}) { + return py_search_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + argmax_over_axis_strided_temps_dispatch_table, + argmax_over_axis0_contig_temps_dispatch_table, + argmax_over_axis1_contig_temps_dispatch_table); + }; + m.def("_argmax_over_axis", argmax_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/argmax.hpp b/dpctl_ext/tensor/libtensor/source/reductions/argmax.hpp new file mode 100644 index 00000000000..3274f8c7d0c --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/argmax.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_argmax(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/argmin.cpp b/dpctl_ext/tensor/libtensor/source/reductions/argmin.cpp new file mode 100644 index 00000000000..ec4637b62d4 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/argmin.cpp @@ -0,0 +1,279 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_over_axis.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace impl +{ + +using dpctl::tensor::kernels::search_strided_impl_fn_ptr; +static search_strided_impl_fn_ptr + argmin_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::search_contig_impl_fn_ptr; +static search_contig_impl_fn_ptr + argmin_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +using dpctl::tensor::kernels::search_contig_impl_fn_ptr; +static search_contig_impl_fn_ptr + argmin_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct TypePairSupportForArgminReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; +
+template +struct ArgminOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgminReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgminOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgminReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgminOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgminReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Minimum; + // op for indices + using 
IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +void populate_argmin_over_axis_dispatch_tables(void) +{ + using td_ns::DispatchTableBuilder; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(argmin_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(argmin_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(argmin_over_axis0_contig_temps_dispatch_table); +} + +} // namespace impl + +void init_argmin(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_argmin_over_axis_dispatch_tables; + populate_argmin_over_axis_dispatch_tables(); + using impl::argmin_over_axis0_contig_temps_dispatch_table; + using impl::argmin_over_axis1_contig_temps_dispatch_table; + using impl::argmin_over_axis_strided_temps_dispatch_table; + + auto argmin_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_search_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + argmin_over_axis_strided_temps_dispatch_table, + argmin_over_axis0_contig_temps_dispatch_table, + argmin_over_axis1_contig_temps_dispatch_table); + }; + m.def("_argmin_over_axis", argmin_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/argmin.hpp b/dpctl_ext/tensor/libtensor/source/reductions/argmin.hpp new file mode 100644 index 00000000000..1865c258a52 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/argmin.hpp @@ -0,0 +1,46 @@ 
+//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_argmin(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/logsumexp.cpp b/dpctl_ext/tensor/libtensor/source/reductions/logsumexp.cpp new file mode 100644 index 00000000000..75e4010bfd5 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/logsumexp.cpp @@ -0,0 +1,258 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_over_axis.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + logsumexp_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + logsumexp_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + logsumexp_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct TypePairSupportDataForLogSumExpReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< +#if 1 + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, +#endif + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct LogSumExpOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForLogSumExpReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::LogSumExp; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct LogSumExpOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForLogSumExpReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::LogSumExp; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct LogSumExpOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForLogSumExpReductionTemps< + srcTy, dstTy>::is_defined) + { + using 
ReductionOpT = su_ns::LogSumExp; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +void populate_logsumexp_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table( + logsumexp_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table( + logsumexp_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table( + logsumexp_over_axis0_contig_temps_dispatch_table); +} + +} // namespace impl + +void init_logsumexp(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_logsumexp_over_axis_dispatch_tables; + populate_logsumexp_over_axis_dispatch_tables(); + using impl::logsumexp_over_axis0_contig_temps_dispatch_table; + using impl::logsumexp_over_axis1_contig_temps_dispatch_table; + using impl::logsumexp_over_axis_strided_temps_dispatch_table; + + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + + auto logsumexp_pyapi = [&](const arrayT &src, + int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_tree_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + logsumexp_over_axis_strided_temps_dispatch_table, + logsumexp_over_axis0_contig_temps_dispatch_table, + logsumexp_over_axis1_contig_temps_dispatch_table); + }; + m.def("_logsumexp_over_axis", logsumexp_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto logsumexp_dtype_supported = [&](const py::dtype &input_dtype, + const py::dtype 
&output_dtype) { + return py_tree_reduction_dtype_supported( + input_dtype, output_dtype, + logsumexp_over_axis_strided_temps_dispatch_table); + }; + m.def("_logsumexp_over_axis_dtype_supported", logsumexp_dtype_supported, + "", py::arg("arg_dtype"), py::arg("out_dtype")); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/logsumexp.hpp b/dpctl_ext/tensor/libtensor/source/reductions/logsumexp.hpp new file mode 100644 index 00000000000..2e2c19877db --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/logsumexp.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_logsumexp(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/max.cpp b/dpctl_ext/tensor/libtensor/source/reductions/max.cpp new file mode 100644 index 00000000000..d19ed226d3b --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/max.cpp @@ -0,0 +1,410 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + max_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + max_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + max_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + max_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + max_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + max_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +/* @brief Types supported by max reduction code based on atomic_ref */ +template +struct TypePairSupportDataForMaxReductionAtomic +{ + /* value is true if a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< + // input int32 + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // input float + td_ns::TypePairDefinedEntry, + // input double + td_ns::TypePairDefinedEntry, + // 
fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForMaxReductionTemps +{ + static constexpr bool is_defined = std::disjunction< + // input bool + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct MaxOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + 
reduction_over_group_temps_strided_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v 
&& + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +void populate_max_over_axis_dispatch_tables(void) +{ + using td_ns::DispatchTableBuilder; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(max_over_axis_strided_atomic_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(max_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(max_over_axis1_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(max_over_axis0_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(max_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(max_over_axis0_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t max_atomic_support_vector[td_ns::num_types]; + +void populate_max_atomic_support_dispatch_vector(void) +{ + using td_ns::DispatchVectorBuilder; + + using atomic_support::MaxAtomicSupportFactory; + DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(max_atomic_support_vector); +} + +} // namespace impl + +void init_max(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_max_over_axis_dispatch_tables; + populate_max_over_axis_dispatch_tables(); + using impl::max_over_axis0_contig_atomic_dispatch_table; + using impl::max_over_axis0_contig_temps_dispatch_table; + using impl::max_over_axis1_contig_atomic_dispatch_table; + using impl::max_over_axis1_contig_temps_dispatch_table; + using impl::max_over_axis_strided_atomic_dispatch_table; + using 
impl::max_over_axis_strided_temps_dispatch_table; + + using impl::populate_max_atomic_support_dispatch_vector; + populate_max_atomic_support_dispatch_vector(); + using impl::max_atomic_support_vector; + + auto max_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + max_over_axis_strided_atomic_dispatch_table, + max_over_axis0_contig_atomic_dispatch_table, + max_over_axis1_contig_atomic_dispatch_table, + max_over_axis_strided_temps_dispatch_table, + max_over_axis0_contig_temps_dispatch_table, + max_over_axis1_contig_temps_dispatch_table, + max_atomic_support_vector); + }; + m.def("_max_over_axis", max_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/max.hpp b/dpctl_ext/tensor/libtensor/source/reductions/max.hpp new file mode 100644 index 00000000000..bc242dc8d74 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/max.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. 
+// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_max(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/min.cpp b/dpctl_ext/tensor/libtensor/source/reductions/min.cpp new file mode 100644 index 00000000000..97d3432b13e --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/min.cpp @@ -0,0 +1,412 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + min_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + min_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + min_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + min_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + min_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + min_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +/* @brief Types supported by min reduction code based on atomic_ref */ +template +struct TypePairSupportDataForMinReductionAtomic +{ + /* value is true if a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< + // input int32 + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // input float + td_ns::TypePairDefinedEntry, + // input double + td_ns::TypePairDefinedEntry, + // 
fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForMinReductionTemps +{ + static constexpr bool is_defined = std::disjunction< + // input bool + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct MinOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + 
reduction_over_group_temps_strided_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v 
&& + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +void populate_min_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(min_over_axis_strided_atomic_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(min_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(min_over_axis1_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(min_over_axis0_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(min_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(min_over_axis0_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t min_atomic_support_vector[td_ns::num_types]; + +void populate_min_atomic_support_dispatch_vector(void) +{ + using td_ns::DispatchVectorBuilder; + + using atomic_support::MinAtomicSupportFactory; + DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(min_atomic_support_vector); +} + +} // namespace impl + +void init_min(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_min_over_axis_dispatch_tables; + populate_min_over_axis_dispatch_tables(); + using impl::min_over_axis0_contig_atomic_dispatch_table; + using impl::min_over_axis0_contig_temps_dispatch_table; + using impl::min_over_axis1_contig_atomic_dispatch_table; + using 
impl::min_over_axis1_contig_temps_dispatch_table; + using impl::min_over_axis_strided_atomic_dispatch_table; + using impl::min_over_axis_strided_temps_dispatch_table; + + using impl::populate_min_atomic_support_dispatch_vector; + populate_min_atomic_support_dispatch_vector(); + using impl::min_atomic_support_vector; + + auto min_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + min_over_axis_strided_atomic_dispatch_table, + min_over_axis0_contig_atomic_dispatch_table, + min_over_axis1_contig_atomic_dispatch_table, + min_over_axis_strided_temps_dispatch_table, + min_over_axis0_contig_temps_dispatch_table, + min_over_axis1_contig_temps_dispatch_table, + min_atomic_support_vector); + }; + m.def("_min_over_axis", min_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/min.hpp b/dpctl_ext/tensor/libtensor/source/reductions/min.hpp new file mode 100644 index 00000000000..e054f44539f --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/min.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_min(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/prod.cpp b/dpctl_ext/tensor/libtensor/source/reductions/prod.cpp new file mode 100644 index 00000000000..6cbb21dfe02 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/prod.cpp @@ -0,0 +1,466 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/type_dispatch_building.hpp" + +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + prod_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + prod_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + prod_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + prod_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + 
prod_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + prod_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +/* @brief Types supported by product-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForProductReductionAtomic +{ + + /* value is true if a kernel for this type pair must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForProductReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input double + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; +
+template +struct ProductOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr 
(TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +void populate_prod_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(prod_over_axis_strided_atomic_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(prod_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(prod_over_axis1_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(prod_over_axis0_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(prod_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(prod_over_axis0_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t prod_atomic_support_vector[td_ns::num_types]; + +void populate_prod_atomic_support_dispatch_vector(void) +{ + using td_ns::DispatchVectorBuilder; + + using atomic_support::ProductAtomicSupportFactory; + DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(prod_atomic_support_vector); +} + +} // namespace impl + +void init_prod(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_prod_over_axis_dispatch_tables; + populate_prod_over_axis_dispatch_tables(); + using impl::prod_over_axis0_contig_atomic_dispatch_table; + using impl::prod_over_axis0_contig_temps_dispatch_table; + using impl::prod_over_axis1_contig_atomic_dispatch_table; + using 
impl::prod_over_axis1_contig_temps_dispatch_table; + using impl::prod_over_axis_strided_atomic_dispatch_table; + using impl::prod_over_axis_strided_temps_dispatch_table; + + using impl::populate_prod_atomic_support_dispatch_vector; + populate_prod_atomic_support_dispatch_vector(); + using impl::prod_atomic_support_vector; + + auto prod_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + prod_over_axis_strided_atomic_dispatch_table, + prod_over_axis0_contig_atomic_dispatch_table, + prod_over_axis1_contig_atomic_dispatch_table, + prod_over_axis_strided_temps_dispatch_table, + prod_over_axis0_contig_temps_dispatch_table, + prod_over_axis1_contig_temps_dispatch_table, + prod_atomic_support_vector); + }; + m.def("_prod_over_axis", prod_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto prod_dtype_supported = + [&](const py::dtype &input_dtype, const py::dtype &output_dtype, + const std::string &dst_usm_type, sycl::queue &q) { + return py_reduction_dtype_supported( + input_dtype, output_dtype, dst_usm_type, q, + prod_over_axis_strided_atomic_dispatch_table, + prod_over_axis_strided_temps_dispatch_table, + prod_atomic_support_vector); + }; + m.def("_prod_over_axis_dtype_supported", prod_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype"), + py::arg("dst_usm_type"), py::arg("sycl_queue")); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/prod.hpp b/dpctl_ext/tensor/libtensor/source/reductions/prod.hpp new file mode 100644 index 00000000000..15b1c07e5dd --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/prod.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// 
Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_prod(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/reduce_hypot.cpp b/dpctl_ext/tensor/libtensor/source/reductions/reduce_hypot.cpp new file mode 100644 index 00000000000..5279b4f6c27 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/reduce_hypot.cpp @@ -0,0 +1,254 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +#include "reduction_over_axis.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + hypot_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + hypot_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + hypot_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct TypePairSupportDataForHypotReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct HypotOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForHypotReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::Hypot; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct HypotOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForHypotReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::Hypot; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct HypotOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForHypotReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::Hypot; + return dpctl::tensor::kernels:: + 
reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +void populate_hypot_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(hypot_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(hypot_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(hypot_over_axis0_contig_temps_dispatch_table); +} + +} // namespace impl + +void init_reduce_hypot(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_hypot_over_axis_dispatch_tables; + populate_hypot_over_axis_dispatch_tables(); + using impl::hypot_over_axis0_contig_temps_dispatch_table; + using impl::hypot_over_axis1_contig_temps_dispatch_table; + using impl::hypot_over_axis_strided_temps_dispatch_table; + + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + + auto hypot_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_tree_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + hypot_over_axis_strided_temps_dispatch_table, + hypot_over_axis0_contig_temps_dispatch_table, + hypot_over_axis1_contig_temps_dispatch_table); + }; + m.def("_hypot_over_axis", hypot_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto hypot_dtype_supported = [&](const py::dtype &input_dtype, + const py::dtype &output_dtype) { + return py_tree_reduction_dtype_supported( + input_dtype, output_dtype, + hypot_over_axis_strided_temps_dispatch_table); + }; + 
m.def("_hypot_over_axis_dtype_supported", hypot_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype")); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/reduce_hypot.hpp b/dpctl_ext/tensor/libtensor/source/reductions/reduce_hypot.hpp new file mode 100644 index 00000000000..c0a16345af7 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/reduce_hypot.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_reduce_hypot(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/reduction_atomic_support.hpp b/dpctl_ext/tensor/libtensor/source/reductions/reduction_atomic_support.hpp new file mode 100644 index 00000000000..5f9cc32f120 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/reduction_atomic_support.hpp @@ -0,0 +1,147 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +#include + +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::py_internal::atomic_support +{ + +typedef bool (*atomic_support_fn_ptr_t)(const sycl::queue &, sycl::usm::alloc); + +/*! @brief Function which returns a constant value for atomic support */ +template +bool fixed_decision(const sycl::queue &, sycl::usm::alloc) +{ + return return_value; +} + +/*! 
@brief Template for querying atomic support for a type on a device */ +template +bool check_atomic_support(const sycl::queue &exec_q, + sycl::usm::alloc usm_alloc_type) +{ + static constexpr bool atomic32 = (sizeof(T) == 4); + static constexpr bool atomic64 = (sizeof(T) == 8); + using dpctl::tensor::type_utils::is_complex; + /* types that are neither 4 nor 8 bytes wide, and complex types, get a + * constant answer from fixed_decision without querying the device */ + if constexpr ((!atomic32 && !atomic64) || is_complex::value) { + return fixed_decision(exec_q, usm_alloc_type); + } + else { + bool supports_atomics = false; + const sycl::device &dev = exec_q.get_device(); + /* 8-byte types additionally require the atomic64 device aspect */ + if constexpr (atomic64) { + if (!dev.has(sycl::aspect::atomic64)) { + return false; + } + } + /* device allocations are always accepted; shared and host allocations + * depend on the matching device aspect; any other kind is rejected */ + switch (usm_alloc_type) { + case sycl::usm::alloc::shared: + supports_atomics = + dev.has(sycl::aspect::usm_atomic_shared_allocations); + break; + case sycl::usm::alloc::host: + supports_atomics = + dev.has(sycl::aspect::usm_atomic_host_allocations); + break; + case sycl::usm::alloc::device: + supports_atomics = true; + break; + default: + supports_atomics = false; + } + return supports_atomics; + } +} + +/*! @brief Factory returning the atomic-support query used by arithmetic + * reductions: a constant decision for floating-point and complex types, + * check_atomic_support for all other types */ +template +struct ArithmeticAtomicSupportFactory +{ + fnT get() + { + using dpctl::tensor::type_utils::is_complex; + if constexpr (std::is_floating_point_v || + std::is_same_v || is_complex::value) + { + // for real and complex floating-point types, tree reduction has + // better round-off accumulation properties: its round-off error is + // proportional to log2(reduction_size), while the naive elementwise + // summation used by the atomic implementation has round-off error + // growing proportionally to reduction_size. Hence reductions + // over floating-point types should always use the tree-reduction + // algorithm, even though the atomic implementation may be applicable + return fixed_decision; + } + else { + return check_atomic_support; + } + } +}; + +/*! @brief Factory for min/max reductions; it returns check_atomic_support + * for every type */ +template +struct MinMaxAtomicSupportFactory +{ + fnT get() + { + return check_atomic_support; + } +}; + +/*! @brief Atomic-support factory for max-reduction; inherits + * MinMaxAtomicSupportFactory unchanged */ +template +struct MaxAtomicSupportFactory : public MinMaxAtomicSupportFactory +{ +}; + +template
+struct MinAtomicSupportFactory : public MinMaxAtomicSupportFactory +{ +}; + +template +struct SumAtomicSupportFactory : public ArithmeticAtomicSupportFactory +{ +}; + +template +struct ProductAtomicSupportFactory + : public ArithmeticAtomicSupportFactory +{ +}; + +} // namespace dpctl::tensor::py_internal::atomic_support diff --git a/dpctl_ext/tensor/libtensor/source/reductions/reduction_common.cpp b/dpctl_ext/tensor/libtensor/source/reductions/reduction_common.cpp new file mode 100644 index 00000000000..fca5e09e2fe --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/reduction_common.cpp @@ -0,0 +1,69 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#include + +#include "all.hpp" +#include "any.hpp" +#include "argmax.hpp" +#include "argmin.hpp" +#include "logsumexp.hpp" +#include "max.hpp" +#include "min.hpp" +#include "prod.hpp" +#include "reduce_hypot.hpp" +#include "sum.hpp" + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +/*! @brief Add reduction functions to Python module */ +void init_reduction_functions(py::module_ m) +{ + init_all(m); + init_any(m); + init_argmax(m); + init_argmin(m); + init_logsumexp(m); + init_max(m); + init_min(m); + init_prod(m); + init_reduce_hypot(m); + init_sum(m); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/reduction_common.hpp b/dpctl_ext/tensor/libtensor/source/reductions/reduction_common.hpp new file mode 100644 index 00000000000..4df67c16bc4 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/reduction_common.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_reduction_functions(py::module_); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/reduction_over_axis.hpp b/dpctl_ext/tensor/libtensor/source/reductions/reduction_over_axis.hpp new file mode 100644 index 00000000000..936c8dbe9b5 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/reduction_over_axis.hpp @@ -0,0 +1,1318 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension, specifically functions for reductions. +//===---------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +/* ====================== dtype supported ======================== */ + +/*! 
@brief Template implementing Python API for querying type support by
+ *  reduction which may support atomics.
+ *
+ * A dtype pair is reported as supported when a dispatch table holds a
+ * non-null entry for it. The atomic table is consulted only if
+ * `check_atomic_support[out_typeid](q, kind)` returns true; when it is not
+ * consulted, or its entry is null, the temporaries-based table is used.
+ *
+ * NOTE(review): the template parameter list and the argument of
+ * `std::remove_all_extents` are reconstructed from the names used in the
+ * signature -- confirm the parameter order if any caller passes explicit
+ * template arguments.
+ *
+ * @param input_dtype   dtype of the array being reduced
+ * @param output_dtype  dtype of the reduction result
+ * @param dst_usm_type  USM kind of the destination: "device", "shared" or
+ *                      "host"
+ * @param q             queue handed to the atomic-support predicate
+ * @param atomic_dispatch_table  atomics-based implementations, indexed as
+ *                      [input type id][output type id]
+ * @param temps_dispatch_table   temporaries-based implementations, indexed
+ *                      the same way
+ * @param check_atomic_support   predicates indexed by output type id
+ * @return true if an implementation exists for the dtype pair
+ * @throws py::value_error if the type lookup throws, or `dst_usm_type` is
+ *         not one of the recognized strings
+ * @throws std::runtime_error if a looked-up type id is out of range
+ */
+template <typename fnT, typename CheckAtomicSupportFnT>
+bool py_reduction_dtype_supported(
+    const py::dtype &input_dtype,
+    const py::dtype &output_dtype,
+    const std::string &dst_usm_type,
+    sycl::queue &q,
+    const fnT &atomic_dispatch_table,
+    const fnT &temps_dispatch_table,
+    const CheckAtomicSupportFnT &check_atomic_support)
+{
+    int arg_tn =
+        input_dtype.num(); // NumPy type numbers are the same as in dpctl
+    int out_tn =
+        output_dtype.num(); // NumPy type numbers are the same as in dpctl
+    int arg_typeid = -1;
+    int out_typeid = -1;
+
+    auto array_types = td_ns::usm_ndarray_types();
+
+    // Lookup failures surface to Python as ValueError.
+    try {
+        arg_typeid = array_types.typenum_to_lookup_id(arg_tn);
+        out_typeid = array_types.typenum_to_lookup_id(out_tn);
+    } catch (const std::exception &e) {
+        throw py::value_error(e.what());
+    }
+
+    // Guard the table indexing below.
+    if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 ||
+        out_typeid >= td_ns::num_types)
+    {
+        throw std::runtime_error("Reduction type support check: lookup failed");
+    }
+
+    // remove_all_extents gets underlying type of table
+    using fn_ptrT = typename std::remove_all_extents<fnT>::type;
+    fn_ptrT fn = nullptr;
+
+    sycl::usm::alloc kind = sycl::usm::alloc::unknown;
+
+    if (dst_usm_type == "device") {
+        kind = sycl::usm::alloc::device;
+    }
+    else if (dst_usm_type == "shared") {
+        kind = sycl::usm::alloc::shared;
+    }
+    else if (dst_usm_type == "host") {
+        kind = sycl::usm::alloc::host;
+    }
+    else {
+        throw py::value_error("Unrecognized `dst_usm_type` argument.");
+    }
+
+    bool supports_atomics = check_atomic_support[out_typeid](q, kind);
+
+    if (supports_atomics) {
+        fn = atomic_dispatch_table[arg_typeid][out_typeid];
+    }
+
+    if (fn == nullptr) {
+        // use slower reduction implementation using temporaries
+        fn = temps_dispatch_table[arg_typeid][out_typeid];
+    }
+
+    return (fn != nullptr);
+}
+
+/*! 
@brief Template implementing Python API for querying type support by tree
+ *  reduction.
+ *
+ * A dtype pair is supported when `temps_dispatch_table` holds a non-null
+ * entry at [input type id][output type id].
+ *
+ * NOTE(review): the template parameter list is reconstructed from the name
+ * used in the signature.
+ *
+ * @param input_dtype   dtype of the array being reduced
+ * @param output_dtype  dtype of the reduction result
+ * @param temps_dispatch_table  table of implementations, indexed as
+ *                      [input type id][output type id]
+ * @return true if an implementation exists for the dtype pair
+ * @throws py::value_error if the type lookup throws
+ * @throws std::runtime_error if a looked-up type id is out of range
+ */
+template <typename fnT>
+bool py_tree_reduction_dtype_supported(const py::dtype &input_dtype,
+                                       const py::dtype &output_dtype,
+                                       const fnT &temps_dispatch_table)
+{
+    int arg_tn =
+        input_dtype.num(); // NumPy type numbers are the same as in dpctl
+    int out_tn =
+        output_dtype.num(); // NumPy type numbers are the same as in dpctl
+    int arg_typeid = -1;
+    int out_typeid = -1;
+
+    const auto &array_types = td_ns::usm_ndarray_types();
+
+    // Lookup failures surface to Python as ValueError.
+    try {
+        arg_typeid = array_types.typenum_to_lookup_id(arg_tn);
+        out_typeid = array_types.typenum_to_lookup_id(out_tn);
+    } catch (const std::exception &e) {
+        throw py::value_error(e.what());
+    }
+
+    // Guard the table indexing below.
+    if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 ||
+        out_typeid >= td_ns::num_types)
+    {
+        throw std::runtime_error("Reduction type support check: lookup failed");
+    }
+
+    auto fn = temps_dispatch_table[arg_typeid][out_typeid];
+
+    return (fn != nullptr);
+}
+
+/* ==================== Generic reductions ====================== */
+
+/*! 
@brief Template implementing Python API for reduction over axis which may
+ *  support atomics.
+ *
+ * Reduces the last `trailing_dims_to_reduce` axes of `src` into `dst`.
+ * Kernels are tried in this order:
+ *   1. a contiguous axis-1 or axis-0 kernel, chosen from the contiguity
+ *      flags of `src` and `dst` (all offsets zero);
+ *   2. the same contiguous kernels, when the simplified iteration and
+ *      reduction spaces are both one-dimensional with matching strides;
+ *   3. the generic strided kernel, with shape/stride metadata packed into a
+ *      temporary device allocation.
+ * For the contiguous kernels the atomic table is used when
+ * `check_atomic_support[dst_typeid]` reports support, the temps table
+ * otherwise; a null entry falls through to the next step. The strided step
+ * additionally falls back from a null atomic entry to the temps table.
+ *
+ * NOTE(review): template arguments (parameter list, `std::pair`,
+ * `std::vector`, `static_cast`, `device_allocate_and_pack`) are
+ * reconstructed from usage -- confirm parameter order if a caller passes
+ * explicit template arguments.
+ *
+ * @param src   array being reduced
+ * @param trailing_dims_to_reduce  number of trailing axes of `src` to reduce
+ * @param dst   destination; its shape must equal the leading axes of `src`
+ * @param exec_q   queue the kernels are submitted to
+ * @param depends  events the computation must wait for
+ * @param atomic_dispatch_table        strided kernels using atomics
+ * @param axis0_atomic_dispatch_table  contiguous axis-0 kernels, atomics
+ * @param axis1_atomic_dispatch_table  contiguous axis-1 kernels, atomics
+ * @param temps_dispatch_table         strided kernels using temporaries
+ * @param axis0_temps_dispatch_table   contiguous axis-0 kernels, temporaries
+ * @param axis1_temps_dispatch_table   contiguous axis-1 kernels, temporaries
+ * @param check_atomic_support  predicates indexed by destination type id
+ * @return pair of (event from `keep_args_alive`, computation event); two
+ *         default-constructed events when `dst` has no elements
+ * @throws py::value_error on rank/shape mismatch, incompatible queues or
+ *         overlapping memory
+ * @throws std::runtime_error if no kernel exists for the dtype pair
+ */
+template <typename strided_fnT, typename contig_fnT, typename SupportAtomicFnT>
+std::pair<sycl::event, sycl::event> py_reduction_over_axis(
+    const dpctl::tensor::usm_ndarray &src,
+    int trailing_dims_to_reduce, // comp over this many trailing indexes
+    const dpctl::tensor::usm_ndarray &dst,
+    sycl::queue &exec_q,
+    const std::vector<sycl::event> &depends,
+    const strided_fnT &atomic_dispatch_table,
+    const contig_fnT &axis0_atomic_dispatch_table,
+    const contig_fnT &axis1_atomic_dispatch_table,
+    const strided_fnT &temps_dispatch_table,
+    const contig_fnT &axis0_temps_dispatch_table,
+    const contig_fnT &axis1_temps_dispatch_table,
+    const SupportAtomicFnT &check_atomic_support)
+{
+    int src_nd = src.get_ndim();
+    int iteration_nd = src_nd - trailing_dims_to_reduce;
+    if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) {
+        throw py::value_error("Trailing_dim_to_reduce must be positive, but no "
+                              "greater than rank of the array being reduced");
+    }
+
+    int dst_nd = dst.get_ndim();
+    if (dst_nd != iteration_nd) {
+        throw py::value_error("Destination array rank does not match input "
+                              "array rank and number of reduced dimensions");
+    }
+
+    const py::ssize_t *src_shape_ptr = src.get_shape_raw();
+    const py::ssize_t *dst_shape_ptr = dst.get_shape_raw();
+
+    bool same_shapes = true;
+    for (int i = 0; same_shapes && (i < dst_nd); ++i) {
+        same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]);
+    }
+
+    if (!same_shapes) {
+        throw py::value_error("Destination shape does not match unreduced "
+                              "dimensions of the input shape");
+    }
+
+    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
+
+    std::size_t dst_nelems = dst.get_size();
+
+    if (dst_nelems == 0) {
+        return std::make_pair(sycl::event(), sycl::event());
+    }
+
+    std::size_t reduction_nelems(1);
+    for (int i = dst_nd; i < src_nd; ++i) {
+        reduction_nelems *= static_cast<std::size_t>(src_shape_ptr[i]);
+    }
+
+    // check that dst and src do not overlap
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    if (overlap(src, dst)) {
+        throw py::value_error("Arrays index overlapping segments of memory");
+    }
+
+    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, dst_nelems);
+
+    int src_typenum = src.get_typenum();
+    int dst_typenum = dst.get_typenum();
+
+    namespace td_ns = dpctl::tensor::type_dispatch;
+    const auto &array_types = td_ns::usm_ndarray_types();
+    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
+    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
+
+    void *data_ptr = dst.get_data();
+    const auto &ctx = exec_q.get_context();
+    auto usm_type = sycl::get_pointer_type(data_ptr, ctx);
+
+    bool supports_atomics = check_atomic_support[dst_typeid](exec_q, usm_type);
+
+    // remove_all_extents gets underlying type of table
+    using contig_fn_ptr_T = typename std::remove_all_extents<contig_fnT>::type;
+
+    // Picks the contiguous kernel: atomic table if atomics are supported,
+    // temps table otherwise (no fallback between the two here).
+    auto select_contig =
+        [&](const contig_fnT &atomic_table,
+            const contig_fnT &temps_table) -> contig_fn_ptr_T {
+        return supports_atomics ? atomic_table[src_typeid][dst_typeid]
+                                : temps_table[src_typeid][dst_typeid];
+    };
+
+    // Submits a contiguous kernel and pairs its event with the event that
+    // keeps `src` and `dst` alive.
+    auto submit_contig =
+        [&](contig_fn_ptr_T contig_fn, py::ssize_t iter_src_offset,
+            py::ssize_t iter_dst_offset, py::ssize_t red_src_offset)
+        -> std::pair<sycl::event, sycl::event> {
+        sycl::event contig_ev =
+            contig_fn(exec_q, dst_nelems, reduction_nelems, src.get_data(),
+                      dst.get_data(), iter_src_offset, iter_dst_offset,
+                      red_src_offset, depends);
+
+        sycl::event keep_args_event =
+            dpctl::utils::keep_args_alive(exec_q, {src, dst}, {contig_ev});
+
+        return std::make_pair(keep_args_event, contig_ev);
+    };
+
+    // handle special case when both reduction and iteration are 1D contiguous
+    bool is_src_c_contig = src.is_c_contiguous();
+    bool is_dst_c_contig = dst.is_c_contiguous();
+    bool is_src_f_contig = src.is_f_contiguous();
+
+    static constexpr py::ssize_t zero_offset = 0;
+
+    if ((is_src_c_contig && is_dst_c_contig) ||
+        (is_src_f_contig && dst_nelems == 1))
+    {
+        contig_fn_ptr_T fn = select_contig(axis1_atomic_dispatch_table,
+                                           axis1_temps_dispatch_table);
+        if (fn != nullptr) {
+            return submit_contig(fn, zero_offset, zero_offset, zero_offset);
+        }
+    }
+    else if (is_src_f_contig &&
+             ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous()))
+    {
+        contig_fn_ptr_T fn = select_contig(axis0_atomic_dispatch_table,
+                                           axis0_temps_dispatch_table);
+        if (fn != nullptr) {
+            return submit_contig(fn, zero_offset, zero_offset, zero_offset);
+        }
+    }
+
+    auto const &src_strides_vecs = src.get_strides_vector();
+    auto const &dst_strides_vecs = dst.get_strides_vector();
+
+    int reduction_nd = trailing_dims_to_reduce;
+    const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd;
+    using shT = std::vector<py::ssize_t>;
+    shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd,
+                              std::end(src_strides_vecs));
+
+    shT simplified_reduction_shape;
+    shT simplified_reduction_src_strides;
+    py::ssize_t reduction_src_offset(0);
+
+    simplify_iteration_space_1(
+        reduction_nd, reduction_shape_ptr, reduction_src_strides,
+        // output
+        simplified_reduction_shape, simplified_reduction_src_strides,
+        reduction_src_offset);
+
+    const py::ssize_t *iteration_shape_ptr = src_shape_ptr;
+
+    shT iteration_src_strides(std::begin(src_strides_vecs),
+                              std::begin(src_strides_vecs) + iteration_nd);
+    shT const &iteration_dst_strides = dst_strides_vecs;
+
+    shT simplified_iteration_shape;
+    shT simplified_iteration_src_strides;
+    shT simplified_iteration_dst_strides;
+    py::ssize_t iteration_src_offset(0);
+    py::ssize_t iteration_dst_offset(0);
+
+    if (iteration_nd == 0) {
+        if (dst_nelems != 1) {
+            throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1");
+        }
+        // Full reduction to a 0-d result: use a one-element iteration space.
+        iteration_nd = 1;
+        simplified_iteration_shape.push_back(1);
+        simplified_iteration_src_strides.push_back(0);
+        simplified_iteration_dst_strides.push_back(0);
+    }
+    else {
+        simplify_iteration_space(iteration_nd, iteration_shape_ptr,
+                                 iteration_src_strides, iteration_dst_strides,
+                                 // output
+                                 simplified_iteration_shape,
+                                 simplified_iteration_src_strides,
+                                 simplified_iteration_dst_strides,
+                                 iteration_src_offset, iteration_dst_offset);
+    }
+
+    if ((reduction_nd == 1) && (iteration_nd == 1)) {
+        bool mat_reduce_over_axis1 = false;
+        bool mat_reduce_over_axis0 = false;
+        bool array_reduce_all_elems = false;
+
+        if (simplified_reduction_src_strides[0] == 1) {
+            array_reduce_all_elems = (simplified_iteration_shape[0] == 1);
+            mat_reduce_over_axis1 =
+                (simplified_iteration_dst_strides[0] == 1) &&
+                (static_cast<std::size_t>(
+                     simplified_iteration_src_strides[0]) == reduction_nelems);
+        }
+        else if (static_cast<std::size_t>(
+                     simplified_reduction_src_strides[0]) == dst_nelems)
+        {
+            mat_reduce_over_axis0 =
+                (simplified_iteration_dst_strides[0] == 1) &&
+                (simplified_iteration_src_strides[0] == 1);
+        }
+
+        if (mat_reduce_over_axis1 || array_reduce_all_elems) {
+            contig_fn_ptr_T fn = select_contig(axis1_atomic_dispatch_table,
+                                               axis1_temps_dispatch_table);
+            if (fn != nullptr) {
+                return submit_contig(fn, iteration_src_offset,
+                                     iteration_dst_offset,
+                                     reduction_src_offset);
+            }
+        }
+        else if (mat_reduce_over_axis0) {
+            contig_fn_ptr_T fn = select_contig(axis0_atomic_dispatch_table,
+                                               axis0_temps_dispatch_table);
+            if (fn != nullptr) {
+                return submit_contig(fn, iteration_src_offset,
+                                     iteration_dst_offset,
+                                     reduction_src_offset);
+            }
+        }
+    }
+
+    // remove_all_extents gets underlying type of table
+    using strided_fn_ptr_T =
+        typename std::remove_all_extents<strided_fnT>::type;
+    strided_fn_ptr_T fn = nullptr;
+
+    if (supports_atomics) {
+        fn = atomic_dispatch_table[src_typeid][dst_typeid];
+    }
+
+    if (fn == nullptr) {
+        // use slower reduction implementation using temporaries
+        fn = temps_dispatch_table[src_typeid][dst_typeid];
+        if (fn == nullptr) {
+            throw std::runtime_error("Datatypes are not supported");
+        }
+    }
+
+    std::vector<sycl::event> host_task_events{};
+    using dpctl::tensor::offset_utils::device_allocate_and_pack;
+    auto arrays_metainfo_packing_triple_ =
+        device_allocate_and_pack<py::ssize_t>(
+            exec_q, host_task_events,
+            // iteration metadata
+            simplified_iteration_shape, simplified_iteration_src_strides,
+            simplified_iteration_dst_strides,
+            // reduction metadata
+            simplified_reduction_shape, simplified_reduction_src_strides);
+    auto tmp_alloc_owner =
+        std::move(std::get<0>(arrays_metainfo_packing_triple_));
+    const auto &copy_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_);
+    const py::ssize_t *temp_allocation_ptr = tmp_alloc_owner.get();
+
+    // Packed layout: three iteration vectors first, reduction vectors after.
+    const py::ssize_t *iter_shape_and_strides = temp_allocation_ptr;
+    const py::ssize_t *reduction_shape_stride =
+        temp_allocation_ptr + 3 * simplified_iteration_shape.size();
+
+    // The kernel must also wait for the metadata copy.
+    std::vector<sycl::event> all_deps;
+    all_deps.reserve(depends.size() + 1);
+    all_deps.insert(all_deps.end(), depends.begin(), depends.end());
+    all_deps.push_back(copy_metadata_ev);
+
+    auto reduction_ev =
+        fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), dst.get_data(),
+           iteration_nd, iter_shape_and_strides, iteration_src_offset,
+           iteration_dst_offset,
+           reduction_nd, // number dimensions being reduced
+           reduction_shape_stride, reduction_src_offset, all_deps);
+
+    sycl::event temp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
+        exec_q, {reduction_ev}, tmp_alloc_owner);
+    host_task_events.push_back(temp_cleanup_ev);
+
+    sycl::event keep_args_event =
+        dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events);
+
+    return std::make_pair(keep_args_event, reduction_ev);
+}
+
+/* ================= No atomic reductions ====================== */
+
+/*! 
@brief Template implementing Python API for reduction over axis without
+ *  atomics.
+ *
+ * Reduces the last `trailing_dims_to_reduce` axes of `src` into `dst` using
+ * only the temporaries-based dispatch tables. Kernels are tried in this
+ * order:
+ *   1. a contiguous axis-1 or axis-0 kernel, chosen from the contiguity
+ *      flags of `src` and `dst` (all offsets zero);
+ *   2. the same contiguous kernels, when the simplified iteration and
+ *      reduction spaces are both one-dimensional with matching strides;
+ *   3. the generic strided kernel, with shape/stride metadata packed into a
+ *      temporary device allocation.
+ * A null contiguous entry falls through to the next step.
+ *
+ * NOTE(review): template arguments (parameter list, `std::pair`,
+ * `std::vector`, `static_cast`, `device_allocate_and_pack`) are
+ * reconstructed from usage -- confirm parameter order if a caller passes
+ * explicit template arguments.
+ *
+ * @param src   array being reduced
+ * @param trailing_dims_to_reduce  number of trailing axes of `src` to reduce
+ * @param dst   destination; its shape must equal the leading axes of `src`
+ * @param exec_q   queue the kernels are submitted to
+ * @param depends  events the computation must wait for
+ * @param temps_dispatch_table        strided kernels
+ * @param axis0_temps_dispatch_table  contiguous axis-0 kernels
+ * @param axis1_temps_dispatch_table  contiguous axis-1 kernels
+ * @return pair of (event from `keep_args_alive`, computation event); two
+ *         default-constructed events when `dst` has no elements
+ * @throws py::value_error on rank/shape mismatch, incompatible queues or
+ *         overlapping memory
+ * @throws std::runtime_error if no kernel exists for the dtype pair
+ */
+template <typename strided_fnT, typename contig_fnT>
+std::pair<sycl::event, sycl::event> py_tree_reduction_over_axis(
+    const dpctl::tensor::usm_ndarray &src,
+    int trailing_dims_to_reduce, // comp over this many trailing indexes
+    const dpctl::tensor::usm_ndarray &dst,
+    sycl::queue &exec_q,
+    const std::vector<sycl::event> &depends,
+    const strided_fnT &temps_dispatch_table,
+    const contig_fnT &axis0_temps_dispatch_table,
+    const contig_fnT &axis1_temps_dispatch_table)
+{
+    int src_nd = src.get_ndim();
+    int iteration_nd = src_nd - trailing_dims_to_reduce;
+    if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) {
+        throw py::value_error("Trailing_dim_to_reduce must be positive, but no "
+                              "greater than rank of the array being reduced");
+    }
+
+    int dst_nd = dst.get_ndim();
+    if (dst_nd != iteration_nd) {
+        throw py::value_error("Destination array rank does not match input "
+                              "array rank and number of reduced dimensions");
+    }
+
+    const py::ssize_t *src_shape_ptr = src.get_shape_raw();
+    const py::ssize_t *dst_shape_ptr = dst.get_shape_raw();
+
+    bool same_shapes = true;
+    for (int i = 0; same_shapes && (i < dst_nd); ++i) {
+        same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]);
+    }
+
+    if (!same_shapes) {
+        throw py::value_error("Destination shape does not match unreduced "
+                              "dimensions of the input shape");
+    }
+
+    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
+
+    std::size_t dst_nelems = dst.get_size();
+
+    if (dst_nelems == 0) {
+        return std::make_pair(sycl::event(), sycl::event());
+    }
+
+    std::size_t reduction_nelems(1);
+    for (int i = dst_nd; i < src_nd; ++i) {
+        reduction_nelems *= static_cast<std::size_t>(src_shape_ptr[i]);
+    }
+
+    // check that dst and src do not overlap
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    if (overlap(src, dst)) {
+        throw py::value_error("Arrays index overlapping segments of memory");
+    }
+
+    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, dst_nelems);
+
+    int src_typenum = src.get_typenum();
+    int dst_typenum = dst.get_typenum();
+
+    namespace td_ns = dpctl::tensor::type_dispatch;
+    const auto &array_types = td_ns::usm_ndarray_types();
+    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
+    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
+
+    // Submits a contiguous kernel and pairs its event with the event that
+    // keeps `src` and `dst` alive.
+    auto submit_contig =
+        [&](auto contig_fn, py::ssize_t iter_src_offset,
+            py::ssize_t iter_dst_offset, py::ssize_t red_src_offset)
+        -> std::pair<sycl::event, sycl::event> {
+        sycl::event contig_ev =
+            contig_fn(exec_q, dst_nelems, reduction_nelems, src.get_data(),
+                      dst.get_data(), iter_src_offset, iter_dst_offset,
+                      red_src_offset, depends);
+
+        sycl::event keep_args_event =
+            dpctl::utils::keep_args_alive(exec_q, {src, dst}, {contig_ev});
+
+        return std::make_pair(keep_args_event, contig_ev);
+    };
+
+    // handle special case when both reduction and iteration are 1D contiguous
+    bool is_src_c_contig = src.is_c_contiguous();
+    bool is_dst_c_contig = dst.is_c_contiguous();
+    bool is_src_f_contig = src.is_f_contiguous();
+
+    static constexpr py::ssize_t zero_offset = 0;
+
+    if ((is_src_c_contig && is_dst_c_contig) ||
+        (is_src_f_contig && dst_nelems == 1))
+    {
+        auto fn = axis1_temps_dispatch_table[src_typeid][dst_typeid];
+        if (fn != nullptr) {
+            return submit_contig(fn, zero_offset, zero_offset, zero_offset);
+        }
+    }
+    else if (is_src_f_contig &&
+             ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous()))
+    {
+        auto fn = axis0_temps_dispatch_table[src_typeid][dst_typeid];
+        if (fn != nullptr) {
+            return submit_contig(fn, zero_offset, zero_offset, zero_offset);
+        }
+    }
+
+    auto const &src_strides_vecs = src.get_strides_vector();
+    auto const &dst_strides_vecs = dst.get_strides_vector();
+
+    int reduction_nd = trailing_dims_to_reduce;
+    const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd;
+    using shT = std::vector<py::ssize_t>;
+    shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd,
+                              std::end(src_strides_vecs));
+
+    shT simplified_reduction_shape;
+    shT simplified_reduction_src_strides;
+    py::ssize_t reduction_src_offset(0);
+
+    simplify_iteration_space_1(
+        reduction_nd, reduction_shape_ptr, reduction_src_strides,
+        // output
+        simplified_reduction_shape, simplified_reduction_src_strides,
+        reduction_src_offset);
+
+    const py::ssize_t *iteration_shape_ptr = src_shape_ptr;
+
+    shT iteration_src_strides(std::begin(src_strides_vecs),
+                              std::begin(src_strides_vecs) + iteration_nd);
+    shT const &iteration_dst_strides = dst_strides_vecs;
+
+    shT simplified_iteration_shape;
+    shT simplified_iteration_src_strides;
+    shT simplified_iteration_dst_strides;
+    py::ssize_t iteration_src_offset(0);
+    py::ssize_t iteration_dst_offset(0);
+
+    if (iteration_nd == 0) {
+        if (dst_nelems != 1) {
+            throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1");
+        }
+        // Full reduction to a 0-d result: use a one-element iteration space.
+        iteration_nd = 1;
+        simplified_iteration_shape.push_back(1);
+        simplified_iteration_src_strides.push_back(0);
+        simplified_iteration_dst_strides.push_back(0);
+    }
+    else {
+        simplify_iteration_space(iteration_nd, iteration_shape_ptr,
+                                 iteration_src_strides, iteration_dst_strides,
+                                 // output
+                                 simplified_iteration_shape,
+                                 simplified_iteration_src_strides,
+                                 simplified_iteration_dst_strides,
+                                 iteration_src_offset, iteration_dst_offset);
+    }
+
+    if ((reduction_nd == 1) && (iteration_nd == 1)) {
+        bool mat_reduce_over_axis1 = false;
+        bool mat_reduce_over_axis0 = false;
+        bool array_reduce_all_elems = false;
+
+        if (simplified_reduction_src_strides[0] == 1) {
+            array_reduce_all_elems = (simplified_iteration_shape[0] == 1);
+            mat_reduce_over_axis1 =
+                (simplified_iteration_dst_strides[0] == 1) &&
+                (static_cast<std::size_t>(
+                     simplified_iteration_src_strides[0]) == reduction_nelems);
+        }
+        else if (static_cast<std::size_t>(
+                     simplified_reduction_src_strides[0]) == dst_nelems)
+        {
+            mat_reduce_over_axis0 =
+                (simplified_iteration_dst_strides[0] == 1) &&
+                (simplified_iteration_src_strides[0] == 1);
+        }
+
+        if (mat_reduce_over_axis1 || array_reduce_all_elems) {
+            auto fn = axis1_temps_dispatch_table[src_typeid][dst_typeid];
+            if (fn != nullptr) {
+                return submit_contig(fn, iteration_src_offset,
+                                     iteration_dst_offset,
+                                     reduction_src_offset);
+            }
+        }
+        else if (mat_reduce_over_axis0) {
+            auto fn = axis0_temps_dispatch_table[src_typeid][dst_typeid];
+            if (fn != nullptr) {
+                return submit_contig(fn, iteration_src_offset,
+                                     iteration_dst_offset,
+                                     reduction_src_offset);
+            }
+        }
+    }
+
+    auto fn = temps_dispatch_table[src_typeid][dst_typeid];
+    if (fn == nullptr) {
+        throw std::runtime_error("Datatypes are not supported");
+    }
+
+    std::vector<sycl::event> host_task_events{};
+    using dpctl::tensor::offset_utils::device_allocate_and_pack;
+    auto arrays_metainfo_packing_triple_ =
+        device_allocate_and_pack<py::ssize_t>(
+            exec_q, host_task_events,
+            // iteration metadata
+            simplified_iteration_shape, simplified_iteration_src_strides,
+            simplified_iteration_dst_strides,
+            // reduction metadata
+            simplified_reduction_shape, simplified_reduction_src_strides);
+    auto tmp_owner = std::move(std::get<0>(arrays_metainfo_packing_triple_));
+    const auto &copy_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_);
+    const py::ssize_t *temp_allocation_ptr = tmp_owner.get();
+
+    // Packed layout: three iteration vectors first, reduction vectors after.
+    const py::ssize_t *iter_shape_and_strides = temp_allocation_ptr;
+    const py::ssize_t *reduction_shape_stride =
+        temp_allocation_ptr + 3 * simplified_iteration_shape.size();
+
+    // The kernel must also wait for the metadata copy.
+    std::vector<sycl::event> all_deps;
+    all_deps.reserve(depends.size() + 1);
+    all_deps.insert(all_deps.end(), depends.begin(), depends.end());
+    all_deps.push_back(copy_metadata_ev);
+
+    auto reduction_ev =
+        fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), dst.get_data(),
+           iteration_nd, iter_shape_and_strides, iteration_src_offset,
+           iteration_dst_offset,
+           reduction_nd, // number dimensions being reduced
+           reduction_shape_stride, reduction_src_offset, all_deps);
+
+    sycl::event temp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
+        exec_q, {reduction_ev}, tmp_owner);
+    host_task_events.push_back(temp_cleanup_ev);
+
+    sycl::event keep_args_event =
+        dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events);
+
+    return std::make_pair(keep_args_event, reduction_ev);
+}
+
+/*! 
@brief Template implementing Python API for searching over an axis.
+ *
+ * Runs a search-style reduction over the last `trailing_dims_to_reduce`
+ * axes of `src`, writing into `dst`. Kernels are tried in this order:
+ *   1. the contiguous axis-1 kernel when `src` and `dst` are both
+ *      C-contiguous, or the axis-0 kernel when `src` is F-contiguous and
+ *      `dst` is one-dimensional (all offsets zero);
+ *   2. the same contiguous kernels, when the simplified iteration space and
+ *      the compacted reduction space are both one-dimensional with matching
+ *      strides;
+ *   3. the generic strided kernel, with shape/stride metadata packed into a
+ *      temporary device allocation.
+ * A null contiguous entry falls through to the next step. The reduction
+ * space is compacted with `compact_iteration_space`, which produces no
+ * offset, so `reduction_src_offset` stays zero.
+ *
+ * NOTE(review): template arguments (parameter list, `std::pair`,
+ * `std::vector`, `static_cast`, `device_allocate_and_pack`) are
+ * reconstructed from usage -- confirm parameter order if a caller passes
+ * explicit template arguments.
+ *
+ * @param src   array being searched
+ * @param trailing_dims_to_reduce  number of trailing axes of `src` to reduce
+ * @param dst   destination; its shape must equal the leading axes of `src`
+ * @param exec_q   queue the kernels are submitted to
+ * @param depends  events the computation must wait for
+ * @param strided_dispatch_table       strided kernels
+ * @param axis0_contig_dispatch_table  contiguous axis-0 kernels
+ * @param axis1_contig_dispatch_table  contiguous axis-1 kernels
+ * @return pair of (event from `keep_args_alive`, computation event); two
+ *         default-constructed events when `dst` has no elements
+ * @throws py::value_error on rank/shape mismatch, incompatible queues or
+ *         overlapping memory
+ * @throws std::runtime_error if no kernel exists for the dtype pair
+ */
+template <typename strided_fnT, typename contig_fnT>
+std::pair<sycl::event, sycl::event> py_search_over_axis(
+    const dpctl::tensor::usm_ndarray &src,
+    int trailing_dims_to_reduce, // comp over this many trailing indexes
+    const dpctl::tensor::usm_ndarray &dst,
+    sycl::queue &exec_q,
+    const std::vector<sycl::event> &depends,
+    const strided_fnT &strided_dispatch_table,
+    const contig_fnT &axis0_contig_dispatch_table,
+    const contig_fnT &axis1_contig_dispatch_table)
+{
+    int src_nd = src.get_ndim();
+    int iteration_nd = src_nd - trailing_dims_to_reduce;
+    if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) {
+        throw py::value_error("Trailing_dim_to_reduce must be positive, but no "
+                              "greater than rank of the array being reduced");
+    }
+
+    int dst_nd = dst.get_ndim();
+    if (dst_nd != iteration_nd) {
+        throw py::value_error("Destination array rank does not match input "
+                              "array rank and number of reduced dimensions");
+    }
+
+    const py::ssize_t *src_shape_ptr = src.get_shape_raw();
+    const py::ssize_t *dst_shape_ptr = dst.get_shape_raw();
+
+    bool same_shapes = true;
+    for (int i = 0; same_shapes && (i < dst_nd); ++i) {
+        same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]);
+    }
+
+    if (!same_shapes) {
+        throw py::value_error("Destination shape does not match unreduced "
+                              "dimensions of the input shape");
+    }
+
+    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
+
+    std::size_t dst_nelems = dst.get_size();
+
+    if (dst_nelems == 0) {
+        return std::make_pair(sycl::event(), sycl::event());
+    }
+
+    std::size_t reduction_nelems(1);
+    for (int i = dst_nd; i < src_nd; ++i) {
+        reduction_nelems *= static_cast<std::size_t>(src_shape_ptr[i]);
+    }
+
+    // check that dst and src do not overlap
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    if (overlap(src, dst)) {
+        throw py::value_error("Arrays index overlapping segments of memory");
+    }
+
+    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, dst_nelems);
+
+    int src_typenum = src.get_typenum();
+    int dst_typenum = dst.get_typenum();
+
+    namespace td_ns = dpctl::tensor::type_dispatch;
+    const auto &array_types = td_ns::usm_ndarray_types();
+    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
+    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
+
+    // Submits a contiguous kernel and pairs its event with the event that
+    // keeps `src` and `dst` alive.
+    auto submit_contig =
+        [&](auto contig_fn, py::ssize_t iter_src_offset,
+            py::ssize_t iter_dst_offset, py::ssize_t red_src_offset)
+        -> std::pair<sycl::event, sycl::event> {
+        sycl::event contig_ev =
+            contig_fn(exec_q, dst_nelems, reduction_nelems, src.get_data(),
+                      dst.get_data(), iter_src_offset, iter_dst_offset,
+                      red_src_offset, depends);
+
+        sycl::event keep_args_event =
+            dpctl::utils::keep_args_alive(exec_q, {src, dst}, {contig_ev});
+
+        return std::make_pair(keep_args_event, contig_ev);
+    };
+
+    // handle special case when both reduction and iteration are 1D contiguous
+    bool is_src_c_contig = src.is_c_contiguous();
+    bool is_dst_c_contig = dst.is_c_contiguous();
+    bool is_src_f_contig = src.is_f_contiguous();
+
+    static constexpr py::ssize_t zero_offset = 0;
+
+    if (is_src_c_contig && is_dst_c_contig) {
+        auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid];
+        if (fn != nullptr) {
+            return submit_contig(fn, zero_offset, zero_offset, zero_offset);
+        }
+    }
+    else if (is_src_f_contig && dst_nd == 1) {
+        auto fn = axis0_contig_dispatch_table[src_typeid][dst_typeid];
+        if (fn != nullptr) {
+            return submit_contig(fn, zero_offset, zero_offset, zero_offset);
+        }
+    }
+
+    auto const &src_strides_vecs = src.get_strides_vector();
+    auto const &dst_strides_vecs = dst.get_strides_vector();
+
+    int reduction_nd = trailing_dims_to_reduce;
+    const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd;
+    using shT = std::vector<py::ssize_t>;
+    shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd,
+                              std::end(src_strides_vecs));
+
+    shT compact_reduction_shape;
+    shT compact_reduction_src_strides;
+    py::ssize_t reduction_src_offset(0);
+
+    // NOTE(review): this call carried "TODO: not used anywhere", yet its
+    // outputs feed the kernels below -- presumably the note means the helper
+    // has no other callers; confirm before removing anything.
+    compact_iteration_space(
+        reduction_nd, reduction_shape_ptr, reduction_src_strides,
+        // output
+        compact_reduction_shape, compact_reduction_src_strides);
+
+    const py::ssize_t *iteration_shape_ptr = src_shape_ptr;
+
+    shT iteration_src_strides(std::begin(src_strides_vecs),
+                              std::begin(src_strides_vecs) + iteration_nd);
+    shT const &iteration_dst_strides = dst_strides_vecs;
+
+    shT simplified_iteration_shape;
+    shT simplified_iteration_src_strides;
+    shT simplified_iteration_dst_strides;
+    py::ssize_t iteration_src_offset(0);
+    py::ssize_t iteration_dst_offset(0);
+
+    if (iteration_nd == 0) {
+        if (dst_nelems != 1) {
+            throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1");
+        }
+        // Full reduction to a 0-d result: use a one-element iteration space.
+        iteration_nd = 1;
+        simplified_iteration_shape.push_back(1);
+        simplified_iteration_src_strides.push_back(0);
+        simplified_iteration_dst_strides.push_back(0);
+    }
+    else {
+        simplify_iteration_space(iteration_nd, iteration_shape_ptr,
+                                 iteration_src_strides, iteration_dst_strides,
+                                 // output
+                                 simplified_iteration_shape,
+                                 simplified_iteration_src_strides,
+                                 simplified_iteration_dst_strides,
+                                 iteration_src_offset, iteration_dst_offset);
+    }
+
+    if ((reduction_nd == 1) && (iteration_nd == 1)) {
+        bool mat_reduce_over_axis1 = false;
+        bool mat_reduce_over_axis0 = false;
+
+        if (compact_reduction_src_strides[0] == 1) {
+            mat_reduce_over_axis1 =
+                (simplified_iteration_dst_strides[0] == 1) &&
+                (static_cast<std::size_t>(
+                     simplified_iteration_src_strides[0]) == reduction_nelems);
+        }
+        else if (static_cast<std::size_t>(compact_reduction_src_strides[0]) ==
+                 dst_nelems)
+        {
+            mat_reduce_over_axis0 =
+                (simplified_iteration_dst_strides[0] == 1) &&
+                (simplified_iteration_src_strides[0] == 1);
+        }
+
+        if (mat_reduce_over_axis1) {
+            auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid];
+            if (fn != nullptr) {
+                return submit_contig(fn, iteration_src_offset,
+                                     iteration_dst_offset,
+                                     reduction_src_offset);
+            }
+        }
+        else if (mat_reduce_over_axis0) {
+            auto fn = axis0_contig_dispatch_table[src_typeid][dst_typeid];
+            if (fn != nullptr) {
+                return submit_contig(fn, iteration_src_offset,
+                                     iteration_dst_offset,
+                                     reduction_src_offset);
+            }
+        }
+    }
+
+    auto fn = strided_dispatch_table[src_typeid][dst_typeid];
+    if (fn == nullptr) {
+        throw std::runtime_error("Datatypes are not supported");
+    }
+
+    std::vector<sycl::event> host_task_events{};
+
+    using dpctl::tensor::offset_utils::device_allocate_and_pack;
+
+    auto arrays_metainfo_packing_triple_ =
+        device_allocate_and_pack<py::ssize_t>(
+            exec_q, host_task_events,
+            // iteration metadata
+            simplified_iteration_shape, simplified_iteration_src_strides,
+            simplified_iteration_dst_strides,
+            // reduction metadata
+            compact_reduction_shape, compact_reduction_src_strides);
+    auto tmp_owner = std::move(std::get<0>(arrays_metainfo_packing_triple_));
+    const auto &copy_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_);
+    const py::ssize_t *temp_allocation_ptr = tmp_owner.get();
+
+    // Packed layout: three iteration vectors first, reduction vectors after.
+    const py::ssize_t *iter_shape_and_strides = temp_allocation_ptr;
+    const py::ssize_t *reduction_shape_stride =
+        temp_allocation_ptr + 3 * simplified_iteration_shape.size();
+
+    // The kernel must also wait for the metadata copy.
+    std::vector<sycl::event> all_deps;
+    all_deps.reserve(depends.size() + 1);
+    all_deps.insert(all_deps.end(), depends.begin(), depends.end());
+    all_deps.push_back(copy_metadata_ev);
+
+    auto comp_ev = fn(exec_q, dst_nelems, reduction_nelems, src.get_data(),
+                      dst.get_data(), iteration_nd, iter_shape_and_strides,
+                      iteration_src_offset, iteration_dst_offset,
+                      reduction_nd, // number dimensions being reduced
+                      reduction_shape_stride, reduction_src_offset, all_deps);
+
+    sycl::event temp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
+        exec_q, {comp_ev}, tmp_owner);
+    host_task_events.push_back(temp_cleanup_ev);
+
+    sycl::event keep_args_event =
+        dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events);
+
+    return std::make_pair(keep_args_event, comp_ev);
+}
+
+/* ================= Atomic only reductions ====================== */
+
+/*! 
@brief Template implementing Python API for boolean reductions over an axis + */ +template +std::pair + py_boolean_reduction(const dpctl::tensor::usm_ndarray &src, + int trailing_dims_to_reduce, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const contig_dispatchT &axis1_contig_dispatch_vector, + const contig_dispatchT &axis0_contig_dispatch_vector, + const strided_dispatchT &strided_dispatch_vector, + const atomic_support_fnT check_atomic_support) +{ + int src_nd = src.get_ndim(); + int iter_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iter_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != iter_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); + } + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + std::size_t dst_nelems = dst.get_size(); + + std::size_t red_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + red_nelems *= static_cast(src_shape_ptr[i]); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(dst, src)) { + throw py::value_error("Arrays are expected to have no memory overlap"); + } + + 
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, dst_nelems); + + const char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + static constexpr int int32_typeid = + static_cast(td_ns::typenum_t::INT32); + if (dst_typeid != int32_typeid) { + throw py::value_error( + "Unexpected data type of destination array, expecting 'int32'"); + } + + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + + bool supports_atomics = check_atomic_support(exec_q, usm_type); + if (!supports_atomics) { + throw py::value_error( + "This reduction is not supported for this device and usm_type."); + } + + bool is_src_c_contig = src.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + + // TODO: should be dst_nelems == 0? 
+ if ((is_src_c_contig && is_dst_c_contig) || + (is_src_f_contig && dst_nelems == 0)) + { + auto fn = axis1_contig_dispatch_vector[src_typeid]; + static constexpr py::ssize_t zero_offset = 0; + + sycl::event red_ev = + fn(exec_q, dst_nelems, red_nelems, src_data, dst_data, zero_offset, + zero_offset, zero_offset, depends); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {red_ev}); + + return std::make_pair(keep_args_event, red_ev); + } + else if (is_src_f_contig && + ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous())) + { + auto fn = axis0_contig_dispatch_vector[src_typeid]; + static constexpr py::ssize_t zero_offset = 0; + + sycl::event red_ev = + fn(exec_q, dst_nelems, red_nelems, src_data, dst_data, zero_offset, + zero_offset, zero_offset, depends); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {red_ev}); + + return std::make_pair(keep_args_event, red_ev); + } + + auto src_shape_vecs = src.get_shape_vector(); + auto src_strides_vecs = src.get_strides_vector(); + auto dst_strides_vecs = dst.get_strides_vector(); + + int simplified_red_nd = trailing_dims_to_reduce; + + using shT = std::vector; + shT red_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT simplified_red_shape; + shT simplified_red_src_strides; + py::ssize_t red_src_offset(0); + + simplify_iteration_space_1( + simplified_red_nd, src_shape_ptr + dst_nd, red_src_strides, + // output + simplified_red_shape, simplified_red_src_strides, red_src_offset); + + shT iter_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iter_nd); + shT const &iter_dst_strides = dst_strides_vecs; + + shT simplified_iter_shape; + shT simplified_iter_src_strides; + shT simplified_iter_dst_strides; + py::ssize_t iter_src_offset(0); + py::ssize_t iter_dst_offset(0); + + if (iter_nd == 0) { + if (dst_nelems != 1) { + throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); 
+ } + iter_nd = 1; + simplified_iter_shape.push_back(1); + simplified_iter_src_strides.push_back(0); + simplified_iter_dst_strides.push_back(0); + } + else { + simplify_iteration_space( + iter_nd, src_shape_ptr, iter_src_strides, iter_dst_strides, + // output + simplified_iter_shape, simplified_iter_src_strides, + simplified_iter_dst_strides, iter_src_offset, iter_dst_offset); + } + + if (simplified_red_nd == 1 && iter_nd == 1) { + bool mat_reduce_over_axis1 = false; + bool mat_reduce_over_axis0 = false; + bool array_reduce_all_elems = false; + std::size_t iter_nelems = dst_nelems; + + if (simplified_red_src_strides[0] == 1) { + array_reduce_all_elems = (simplified_iter_shape[0] == 1); + mat_reduce_over_axis1 = + (simplified_iter_dst_strides[0] == 1) && + (static_cast(simplified_iter_src_strides[0]) == + red_nelems); + } + else if (static_cast(simplified_red_src_strides[0]) == + iter_nelems) { + mat_reduce_over_axis0 = (simplified_iter_dst_strides[0] == 1) && + (simplified_iter_src_strides[0] == 1); + } + if (mat_reduce_over_axis1 || array_reduce_all_elems) { + auto fn = axis1_contig_dispatch_vector[src_typeid]; + + sycl::event red_ev = + fn(exec_q, iter_nelems, red_nelems, src_data, dst_data, + iter_src_offset, iter_dst_offset, red_src_offset, depends); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {red_ev}); + + return std::make_pair(keep_args_event, red_ev); + } + else if (mat_reduce_over_axis0) { + auto fn = axis0_contig_dispatch_vector[src_typeid]; + + sycl::event red_ev = + fn(exec_q, iter_nelems, red_nelems, src_data, dst_data, + iter_src_offset, iter_dst_offset, red_src_offset, depends); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {red_ev}); + + return std::make_pair(keep_args_event, red_ev); + } + } + + auto fn = strided_dispatch_vector[src_typeid]; + + std::vector host_task_events{}; + auto iter_red_metadata_packing_triple_ = + 
dpctl::tensor::offset_utils::device_allocate_and_pack( + exec_q, host_task_events, simplified_iter_shape, + simplified_iter_src_strides, simplified_iter_dst_strides, + simplified_red_shape, simplified_red_src_strides); + auto packed_shapes_strides_owner = + std::move(std::get<0>(iter_red_metadata_packing_triple_)); + const auto ©_metadata_ev = + std::get<2>(iter_red_metadata_packing_triple_); + const py::ssize_t *packed_shapes_and_strides = + packed_shapes_strides_owner.get(); + + const py::ssize_t *iter_shape_and_strides = packed_shapes_and_strides; + const py::ssize_t *red_shape_stride = + packed_shapes_and_strides + 3 * simplified_iter_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + auto red_ev = + fn(exec_q, dst_nelems, red_nelems, src_data, dst_data, iter_nd, + iter_shape_and_strides, iter_src_offset, iter_dst_offset, + simplified_red_nd, red_shape_stride, red_src_offset, all_deps); + + sycl::event temp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {red_ev}, packed_shapes_strides_owner); + host_task_events.push_back(temp_cleanup_ev); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(keep_args_event, red_ev); +} + +extern void init_reduction_functions(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/sum.cpp b/dpctl_ext/tensor/libtensor/source/reductions/sum.cpp new file mode 100644 index 00000000000..d7142477750 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/sum.cpp @@ -0,0 +1,463 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/type_dispatch_building.hpp" + +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + sum_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + sum_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + sum_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + sum_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + sum_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + sum_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +/* @brief Types supported by plus-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForSumReductionAtomic +{ + + /* value if true a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input 
int16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForSumReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input double + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct SumOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + using ReductionOpT = + std::conditional_t, + sycl::logical_or, sycl::plus>; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + 
using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + using ReductionOpT = + std::conditional_t, + sycl::logical_or, sycl::plus>; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + using ReductionOpT = + std::conditional_t, + sycl::logical_or, sycl::plus>; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +void populate_sum_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(sum_over_axis_strided_atomic_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(sum_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(sum_over_axis1_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(sum_over_axis0_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb5; + 
dtb5.populate_dispatch_table(sum_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(sum_over_axis0_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t sum_atomic_support_vector[td_ns::num_types]; + +void populate_sum_atomic_support_dispatch_vector(void) +{ + using td_ns::DispatchVectorBuilder; + + using atomic_support::SumAtomicSupportFactory; + DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(sum_atomic_support_vector); +} + +} // namespace impl + +void init_sum(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_sum_over_axis_dispatch_tables; + populate_sum_over_axis_dispatch_tables(); + using impl::sum_over_axis0_contig_atomic_dispatch_table; + using impl::sum_over_axis0_contig_temps_dispatch_table; + using impl::sum_over_axis1_contig_atomic_dispatch_table; + using impl::sum_over_axis1_contig_temps_dispatch_table; + using impl::sum_over_axis_strided_atomic_dispatch_table; + using impl::sum_over_axis_strided_temps_dispatch_table; + + using impl::populate_sum_atomic_support_dispatch_vector; + populate_sum_atomic_support_dispatch_vector(); + using impl::sum_atomic_support_vector; + + auto sum_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + sum_over_axis_strided_atomic_dispatch_table, + sum_over_axis0_contig_atomic_dispatch_table, + sum_over_axis1_contig_atomic_dispatch_table, + sum_over_axis_strided_temps_dispatch_table, + sum_over_axis0_contig_temps_dispatch_table, + sum_over_axis1_contig_temps_dispatch_table, + sum_atomic_support_vector); + }; + m.def("_sum_over_axis", sum_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), 
py::arg("depends") = py::list()); + + auto sum_dtype_supported = + [&](const py::dtype &input_dtype, const py::dtype &output_dtype, + const std::string &dst_usm_type, sycl::queue &q) { + return py_reduction_dtype_supported( + input_dtype, output_dtype, dst_usm_type, q, + sum_over_axis_strided_atomic_dispatch_table, + sum_over_axis_strided_temps_dispatch_table, + sum_atomic_support_vector); + }; + m.def("_sum_over_axis_dtype_supported", sum_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype"), + py::arg("dst_usm_type"), py::arg("sycl_queue")); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/sum.hpp b/dpctl_ext/tensor/libtensor/source/reductions/sum.hpp new file mode 100644 index 00000000000..08add902a04 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/sum.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_sum(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/tensor_reductions.cpp b/dpctl_ext/tensor/libtensor/source/tensor_reductions.cpp new file mode 100644 index 00000000000..6e6a24f7b93 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/tensor_reductions.cpp @@ -0,0 +1,43 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#include + +#include "reductions/reduction_common.hpp" + +PYBIND11_MODULE(_tensor_reductions_impl, m) +{ + dpctl::tensor::py_internal::init_reduction_functions(m); +} diff --git a/dpnp/dpnp_iface_counting.py b/dpnp/dpnp_iface_counting.py index a4b85aa8529..a8ebafbcead 100644 --- a/dpnp/dpnp_iface_counting.py +++ b/dpnp/dpnp_iface_counting.py @@ -39,8 +39,9 @@ """ -import dpctl.tensor as dpt - +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt import dpnp diff --git a/dpnp/dpnp_iface_logic.py b/dpnp/dpnp_iface_logic.py index 3e3501b14c7..a81416a28e4 100644 --- a/dpnp/dpnp_iface_logic.py +++ b/dpnp/dpnp_iface_logic.py @@ -44,14 +44,13 @@ # pylint: disable=no-name-in-module -import dpctl.tensor as dpt import dpctl.tensor._tensor_elementwise_impl as ti import dpctl.utils as dpu import numpy # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor -import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor as dpt import dpnp import dpnp.backend.extensions.ufunc._ufunc_impl as ufi from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc @@ -1276,7 +1275,7 @@ def isin( usm_element = dpnp.get_usm_ndarray(element) usm_test = dpnp.get_usm_ndarray(test_elements) return dpnp_array._create_from_usm_ndarray( - dpt_ext.isin( + dpt.isin( usm_element, usm_test, invert=invert, diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index 7eb44f79ae3..b5afd9523d6 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -428,7 +428,9 @@ def _get_first_nan_index(usm_a): if first_nan is not None: # all NaNs are collapsed, so need to put a count of all NaNs # at the last index - dpt.sum(usm_res.counts[first_nan:], out=usm_res.counts[first_nan]) + dpt_ext.sum( + usm_res.counts[first_nan:], out=usm_res.counts[first_nan] + ) result += 
(usm_res.counts[: first_nan + 1],) else: result += (usm_res.counts,) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 000c343abdb..cdcdd3af92e 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -47,14 +47,13 @@ import builtins import warnings -import dpctl.tensor as dpt import dpctl.tensor._tensor_elementwise_impl as ti import dpctl.utils as dpu import numpy # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor -import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor as dpt import dpctl_ext.tensor._type_utils as dtu import dpnp import dpnp.backend.extensions.ufunc._ufunc_impl as ufi @@ -730,7 +729,7 @@ def clip(a, /, min=None, max=None, *, out=None, order="K", **kwargs): usm_max = None if max is None else dpnp.get_usm_ndarray_or_scalar(max) usm_out = None if out is None else dpnp.get_usm_ndarray(out) - usm_res = dpt_ext.clip(usm_arr, usm_min, usm_max, out=usm_out, order=order) + usm_res = dpt.clip(usm_arr, usm_min, usm_max, out=usm_out, order=order) if out is not None and isinstance(out, dpnp_array): return out return dpnp_array._create_from_usm_ndarray(usm_res) @@ -1126,7 +1125,7 @@ def cumprod(a, axis=None, dtype=None, out=None): return dpnp_wrap_reduction_call( usm_a, out, - dpt_ext.cumulative_prod, + dpt.cumulative_prod, _get_reduction_res_dt(a, dtype), axis=axis, dtype=dtype, @@ -1218,7 +1217,7 @@ def cumsum(a, axis=None, dtype=None, out=None): return dpnp_wrap_reduction_call( usm_a, out, - dpt_ext.cumulative_sum, + dpt.cumulative_sum, _get_reduction_res_dt(a, dtype), axis=axis, dtype=dtype, @@ -1307,7 +1306,7 @@ def cumulative_prod( return dpnp_wrap_reduction_call( dpnp.get_usm_ndarray(x), out, - dpt_ext.cumulative_prod, + dpt.cumulative_prod, _get_reduction_res_dt(x, dtype), axis=axis, dtype=dtype, @@ -1403,7 +1402,7 @@ def cumulative_sum( return dpnp_wrap_reduction_call( dpnp.get_usm_ndarray(x), out, - dpt_ext.cumulative_sum, + dpt.cumulative_sum, 
_get_reduction_res_dt(x, dtype), axis=axis, dtype=dtype, diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py index 055aaa999c3..19279f81286 100644 --- a/dpnp/dpnp_iface_searching.py +++ b/dpnp/dpnp_iface_searching.py @@ -39,12 +39,10 @@ """ -import dpctl.tensor as dpt - # pylint: disable=no-name-in-module # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor -import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor as dpt import dpctl_ext.tensor._tensor_impl as dti import dpnp @@ -376,13 +374,13 @@ def searchsorted(a, v, side="left", sorter=None): usm_a = dpnp.get_usm_ndarray(a) if dpnp.isscalar(v): - usm_v = dpt_ext.asarray(v, sycl_queue=a.sycl_queue, usm_type=a.usm_type) + usm_v = dpt.asarray(v, sycl_queue=a.sycl_queue, usm_type=a.usm_type) else: usm_v = dpnp.get_usm_ndarray(v) usm_sorter = None if sorter is None else dpnp.get_usm_ndarray(sorter) return dpnp_array._create_from_usm_ndarray( - dpt_ext.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter) + dpt.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter) ) @@ -474,7 +472,5 @@ def where(condition, x=None, y=None, /, *, order="K", out=None): usm_condition = dpnp.get_usm_ndarray(condition) usm_out = None if out is None else dpnp.get_usm_ndarray(out) - usm_res = dpt_ext.where( - usm_condition, usm_x, usm_y, order=order, out=usm_out - ) + usm_res = dpt.where(usm_condition, usm_x, usm_y, order=order, out=usm_out) return dpnp.get_result_array(usm_res, out) diff --git a/dpnp/dpnp_iface_statistics.py b/dpnp/dpnp_iface_statistics.py index 9d3ccc40ecf..75fe215837b 100644 --- a/dpnp/dpnp_iface_statistics.py +++ b/dpnp/dpnp_iface_statistics.py @@ -1118,7 +1118,7 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True): return dpnp_wrap_reduction_call( usm_a, out, - dpt.max, + dpt_ext.max, a.dtype, axis=axis, keepdims=keepdims, @@ -1395,7 +1395,7 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True): return 
dpnp_wrap_reduction_call( usm_a, out, - dpt.min, + dpt_ext.min, a.dtype, axis=axis, keepdims=keepdims, diff --git a/dpnp/dpnp_iface_trigonometric.py b/dpnp/dpnp_iface_trigonometric.py index 460a0dc80f0..a17c7dfd9d9 100644 --- a/dpnp/dpnp_iface_trigonometric.py +++ b/dpnp/dpnp_iface_trigonometric.py @@ -42,13 +42,11 @@ # pylint: disable=protected-access # pylint: disable=no-name-in-module - -import dpctl.tensor as dpt import dpctl.tensor._tensor_elementwise_impl as ti # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor -import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor as dpt import dpctl_ext.tensor._type_utils as dtu import dpnp import dpnp.backend.extensions.ufunc._ufunc_impl as ufi @@ -935,7 +933,7 @@ def cumlogsumexp( return dpnp_wrap_reduction_call( usm_x, out, - dpt_ext.cumulative_logsumexp, + dpt.cumulative_logsumexp, _get_accumulation_res_dt(x, dtype), axis=axis, dtype=dtype,