diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index 78c9016361..403c6fbe76 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -215,6 +215,21 @@ defmodule EXLA do The metadata is: * `:key` - the compilation key for debugging + + ## Sharding + + EXLA supports sharding, which is a way to partition a computation across multiple devices. + There are a number of collective operations that are supported by sharding. + + ### [`all_gather`](https://openxla.org/stablehlo/spec#all_gather) + + #### Options + + * `:all_gather_dim` - the dimension along which to gather + * `:replica_groups` - 2D list defining how replicas are grouped + * `:use_global_device_ids` - Whether to use global device IDs (default: `false`) + * `:channel_id` - Channel ID for communication (optional) + """ @behaviour Nx.Defn.Compiler diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 0d29dc7bd9..48593c22aa 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1474,6 +1474,27 @@ defmodule EXLA.Defn do EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type) end + ## to_operator collective ops + + defp to_operator(:all_gather, [%Value{} = tensor, opts], ans, _state) do + all_gather_dim = Keyword.fetch!(opts, :all_gather_dim) + replica_groups = Keyword.fetch!(opts, :replica_groups) + use_global_device_ids = Keyword.get(opts, :use_global_device_ids, false) + + # We might want to surface all_gather as an operation that takes a container of operands instead of a single one. + [result] = + Value.all_gather( + [tensor], + expr_to_typespec(ans), + all_gather_dim, + replica_groups, + use_global_device_ids, + opts[:channel_id] + ) + + result + end + defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do n = opts[:length] axis = opts[:axis] diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 393b6d57a8..ec6527e091 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -64,6 +64,36 @@ defmodule EXLA.MLIR.Value do end end + def all_gather( + [%Value{function: func} | _] = operands, + typespec, + all_gather_dim, + replica_groups, + use_global_device_ids, + channel_id \\ nil + ) do + result_types = typespecs_to_mlir_types([typespec]) + + num_groups = length(replica_groups) + group_size = if num_groups > 0, do: length(hd(replica_groups)), else: 0 + flat_groups = List.flatten(replica_groups) + + attributes = [ + all_gather_dim: attr_i64(all_gather_dim), + replica_groups: attr_dense_elements(flat_groups, {:s, 64}, {num_groups, group_size}), + use_global_device_ids: attr_boolean(use_global_device_ids) + ] + + attributes = + if channel_id do + Keyword.put(attributes, :channel_id, attr_i64(channel_id)) + else + attributes + end + + op(func, "stablehlo.all_gather", operands, result_types, attributes: attributes) + end + defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do %{type: lhs_type} = get_typespec(lhs) %{type: rhs_type} = get_typespec(rhs) diff --git a/exla/test/exla/defn/sharding_test.exs b/exla/test/exla/defn/sharding_test.exs index 0cb30efa24..e11bab0981 100644 --- a/exla/test/exla/defn/sharding_test.exs +++ b/exla/test/exla/defn/sharding_test.exs @@ -290,9 +290,6 @@ defmodule EXLA.Defn.ShardingTest do %{0 => [0]} ] - # Output: shard dim 0 on axis 0, dim 1 on axis 1 (like x) - output_shardings = [%{0 => [0], 1 => [1]}] - # For mesh {2, 2}, we have 4 partitions # x: {8, 2} sharded [[0], [1]] -> each partition gets {4, 1} # y: {8, 1} sharded [[0], []] -> each partition gets {4, 1} @@ -333,8 +330,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) expected_mlir = """ @@ -353,10 +349,7 @@ defmodule EXLA.Defn.ShardingTest do assert expected_mlir == result.mlir_module results = - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings - ).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) assert length(results) == 4 @@ -427,7 +420,6 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2, 2}} # Only one sharding spec for two arguments input_shardings = [%{0 => [0]}] - output_shardings = [%{}] # For mesh {2, 2}, we have 4 partitions # Each partition has 2 inputs @@ -438,8 +430,7 @@ defmodule EXLA.Defn.ShardingTest do fn -> EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) end end @@ -452,16 +443,12 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2, 2}} # Mesh has 2 axes (0 and 1), but we reference axis 2 input_shardings = [%{0 => [2]}] - output_shardings = [%{}] # For mesh {2, 2}, we have 4 partitions args = List.duplicate([Nx.iota({4, 2})], 4) assert_raise ArgumentError, fn -> - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings - ).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) end end @@ -471,7 +458,6 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2, 2}} # Axis 0 used for both dimensions input_shardings = [%{0 => [0], 1 => [0]}] - output_shardings = [%{}] # For mesh {2, 2}, we have 4 partitions args = List.duplicate([Nx.iota({4, 1})], 4) @@ -479,10 +465,7 @@ defmodule EXLA.Defn.ShardingTest do assert_raise ArgumentError, ~r/axis 0 was used twice in the same input sharding/, fn -> - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings - ).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) end end @@ -492,16 +475,12 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2, 2}} # Tensor is rank 2, but sharding spec has 3 dimensions input_shardings = [%{0 => [0], 1 => [1], 2 => []}] - output_shardings = [%{}] # For mesh {2, 2}, we have 4 partitions args = List.duplicate([Nx.iota({4, 2})], 4) assert_raise ArgumentError, fn -> - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings - ).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) end end @@ -511,17 +490,13 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2}} # Tensor is rank 2, but -3 is out of bounds (only -1 and -2 are valid) input_shardings = [%{-3 => [0]}] - output_shardings = [%{}] args = List.duplicate([Nx.iota({4, 2})], 2) assert_raise ArgumentError, ~r/given axis \(-3\) invalid for shape with rank 2/, fn -> - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings - ).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) end end @@ -531,17 +506,13 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2}} # Tensor is rank 2, but axis 3 is out of bounds input_shardings = [%{3 => [0]}] - output_shardings = [%{}] args = List.duplicate([Nx.iota({4, 2})], 2) assert_raise ArgumentError, ~r/given axis \(3\) invalid for shape with rank 2/, fn -> - EXLA.shard_jit(fun, mesh, - input_shardings: input_shardings, - output_shardings: output_shardings - ).(args) + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) end end end @@ -582,9 +553,6 @@ defmodule EXLA.Defn.ShardingTest do # All dimensions replicated input_shardings = [%{}] - # Output: replicated (all-gathered) across all devices - output_shardings = [%{}] - # For mesh {2, 2}, we have 4 partitions # Input fully replicated -> each partition gets full {8, 4} args = List.duplicate([Nx.iota({8, 4})], 4) @@ -592,8 +560,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) assert is_binary(result.mlir_module) @@ -606,9 +573,6 @@ defmodule EXLA.Defn.ShardingTest do # Scalar has no dimensions to shard input_shardings = [%{}] - # Output: replicated (all-gathered) across all devices - output_shardings = [%{}] - # For mesh {2}, we have 2 partitions # Scalar is replicated across all partitions args = List.duplicate([Nx.tensor(5.0)], 2) @@ -616,8 +580,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) assert is_binary(result.mlir_module) @@ -637,9 +600,6 @@ defmodule EXLA.Defn.ShardingTest do %{} ] - # Outputs: both replicated (tuple with two elements) - output_shardings = [%{}, %{}] - # For mesh {2, 2}, we have 4 partitions # x: {8, 4} sharded [[0], [1]] -> each partition gets {4, 2} # y: {8, 4} sharded [[0], []] -> each partition gets {4, 4} @@ -658,8 +618,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) assert is_binary(result.mlir_module) @@ -673,9 +632,6 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "test_mesh", shape: {2, 2}} input_shardings = [%{0 => [0]}] - # Output: replicated (all-gathered) across all devices - output_shardings = [%{}] - # For mesh {2, 2}, we have 4 partitions # Input sharded [[0], []] -> each partition gets {4, 2} args = List.duplicate([Nx.iota({4, 2})], 4) @@ -683,8 +639,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) mlir = result.mlir_module @@ -702,9 +657,6 @@ defmodule EXLA.Defn.ShardingTest do mesh = %Mesh{name: "mesh", shape: {2, 2}} input_shardings = [%{0 => [0], 1 => [1]}] - # Output: replicated (all-gathered) across all devices - output_shardings = [%{}] - # For mesh {2, 2}, we have 4 partitions # Input sharded [[0], [1]] -> each partition gets {4, 1} args = List.duplicate([Nx.iota({4, 1})], 4) @@ -712,8 +664,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) mlir = result.mlir_module @@ -732,9 +683,6 @@ defmodule EXLA.Defn.ShardingTest do # Use named dimensions instead of indices input_shardings = [%{:batch => [0]}] - # Output: replicated (all-gathered) across all devices - output_shardings = [%{}] - # For mesh {2}, we have 2 partitions # Named dimension :batch should map to first dimension (index 0) args = List.duplicate([Nx.iota({4, 2}, names: [:batch, :features])], 2) @@ -742,8 +690,7 @@ defmodule EXLA.Defn.ShardingTest do result = EXLA.to_mlir_module(fun, args, mesh: mesh, - input_shardings: input_shardings, - output_shardings: output_shardings + input_shardings: input_shardings ) # Should generate valid MLIR with sharding on first dimension @@ -891,4 +838,118 @@ defmodule EXLA.Defn.ShardingTest do end end end + + describe "all_gather" do + @moduletag :multi_device + test "in all dims results in the same tensor in all devices" do + fun = fn x, y -> + Nx.add(x, y) + |> Nx.Defn.Kernel.all_gather(all_gather_dim: 0, replica_groups: [[0]]) + |> Nx.Defn.Kernel.all_gather(all_gather_dim: 1, replica_groups: [[0]]) + end + + mesh = %Mesh{name: "mesh", shape: {2, 2}} + # First arg: 0..15 (8x2), shard dim 0 on mesh axis 0, dim 1 on mesh axis 1 + # Second arg: 100..115 (8x2), same sharding — makes sharded results easy to read + input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0], 1 => [1]}] + + # For mesh {2, 2}, 4 partitions. Each gets {4, 1}. Full 8x2 row-major: [[0,1],[2,3],...,[14,15]]. + # Partition (axis_0, axis_1): (0,0)=rows 0-3 col 0, (0,1)=rows 0-3 col 1, (1,0)=rows 4-7 col 0, (1,1)=rows 4-7 col 1. + # So partition 0 gets (0,0),(1,0),(2,0),(3,0) = 0,2,4,6; partition 1 gets (0,1),(1,1),... = 1,3,5,7; etc. + args = [ + # partition 0: rows 0–3 col 0 -> 0,2,4,6 and 100,102,104,106 + [Nx.tensor([[0], [2], [4], [6]]), Nx.tensor([[100], [102], [104], [106]])], + # partition 1: rows 0–3 col 1 -> 1,3,5,7 and 101,103,105,107 + [Nx.tensor([[1], [3], [5], [7]]), Nx.tensor([[101], [103], [105], [107]])], + # partition 2: rows 4–7 col 0 -> 8,10,12,14 and 108,110,112,114 + [Nx.tensor([[8], [10], [12], [14]]), Nx.tensor([[108], [110], [112], [114]])], + # partition 3: rows 4–7 col 1 -> 9,11,13,15 and 109,111,113,115 + [Nx.tensor([[9], [11], [13], [15]]), Nx.tensor([[109], [111], [113], [115]])] + ] + + result = EXLA.to_mlir_module(fun, args, mesh: mesh, input_shardings: input_shardings) + + expected_mlir = """ + module { + sdy.mesh @mesh = <["axis_0"=2, "axis_1"=2]> + func.func public @main(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {"axis_1", ?}p0]>}, %arg1: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {"axis_1", ?}p0]>}) -> tensor<8x2xi32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<8x2xi32> + %1 = "stablehlo.all_gather"(%0) <{all_gather_dim = 0 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32> + %2 = "stablehlo.all_gather"(%1) <{all_gather_dim = 1 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %2 : tensor<8x2xi32> + } + } + """ + + assert expected_mlir == result.mlir_module + + results = EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) + + assert length(results) == 4 + + # After all_gather: full first arg 0..15 + full second 100..115 -> 100,102,...,130 + expected_result = + Nx.tensor([ + [100, 102], + [104, 106], + [108, 110], + [112, 114], + [116, 118], + [120, 122], + [124, 126], + [128, 130] + ]) + + Enum.with_index(results, fn result, i -> + assert_equal(result, expected_result) + assert result.data.buffer.device_id == i + end) + end + + @moduletag :multi_device + test "can return partially sharded results" do + fun = fn x, y -> + x + |> Nx.Defn.Kernel.all_gather(all_gather_dim: 1, replica_groups: [[0]]) + |> Nx.add(y) + end + + mesh = %Mesh{name: "mesh", shape: {2, 2}} + # Inputs sharded on both axes + input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0]}] + + # Logical x: 8x2, y: 8x2. Each partition gets {4, 1} of x and {4, 2} of y + args = [ + [ + Nx.tensor([[0], [1], [2], [3]]), + Nx.tensor([[100, 101], [102, 103], [104, 105], [106, 107]]) + ], + [ + Nx.tensor([[4], [5], [6], [7]]), + Nx.tensor([[100, 101], [102, 103], [104, 105], [106, 107]]) + ], + [ + Nx.tensor([[8], [9], [10], [11]]), + Nx.tensor([[110, 111], [112, 113], [114, 115], [116, 117]]) + ], + [ + Nx.tensor([[12], [13], [14], [15]]), + Nx.tensor([[110, 111], [112, 113], [114, 115], [116, 117]]) + ] + ] + + assert [result0, result1, result2, result3] = + EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args) + + # After gathering, devices 0 and 1 have the same data as each other, likewise for devices 2 and 3 + assert_equal(result0, Nx.tensor([[100, 105], [103, 108], [106, 111], [109, 114]])) + assert result0.data.buffer.device_id == 0 + assert_equal(result0, Nx.tensor([[100, 105], [103, 108], [106, 111], [109, 114]])) + assert result1.data.buffer.device_id == 1 + assert_equal(result2, Nx.tensor([[118, 123], [121, 126], [124, 129], [127, 132]])) + assert result2.data.buffer.device_id == 2 + assert_equal(result3, Nx.tensor([[118, 123], [121, 126], [124, 129], [127, 132]])) + assert result3.data.buffer.device_id == 3 + end + end end diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index cae7c94997..974d558b0d 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -522,11 +522,13 @@ defmodule Nx.BinaryBackend do right_batch_item_bits = right_batch_item_length * right_size <<_::bitstring-size(^left_offset_bits), - left_batch_item_binary::bitstring-size(^left_batch_item_bits), _::bitstring>> = + left_batch_item_binary::bitstring-size(^left_batch_item_bits), + _::bitstring>> = left_binary <<_::bitstring-size(^right_offset_bits), - right_batch_item_binary::bitstring-size(^right_batch_item_bits), _::bitstring>> = + right_batch_item_binary::bitstring-size(^right_batch_item_bits), + _::bitstring>> = right_binary bin_dot( @@ -1756,7 +1758,8 @@ defmodule Nx.BinaryBackend do before_slice_size = current - previous <> = + current_bitstring::bitstring-size(^target_chunk), + to_traverse::bitstring>> = to_traverse updated_elements = diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 899b430da4..e753b9aebb 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1166,6 +1166,11 @@ defmodule Nx.Defn.Expr do expr(out, context, :gather, [tensor, indices, opts]) end + def all_gather(tensor, opts) do + {[expr], context} = to_exprs([tensor]) + expr(expr, context, :all_gather, [expr, opts]) + end + @impl true def reverse(out, tensor, axes) do tensor = to_expr(tensor) diff --git a/nx/lib/nx/defn/kernel.ex b/nx/lib/nx/defn/kernel.ex index ab913ab61f..4530d79e16 100644 --- a/nx/lib/nx/defn/kernel.ex +++ b/nx/lib/nx/defn/kernel.ex @@ -1669,6 +1669,19 @@ defmodule Nx.Defn.Kernel do end end + @doc """ + Gathers tensors along a specified axis across an `Nx.Mesh`. + + Requires a backend that supports collective operations. + + ## Options + + Refer to the chosen backend/compiler documentation for supported options. + """ + def all_gather(tensor, opts) do + Nx.Defn.Expr.all_gather(tensor, opts) + end + @definitions (Module.definitions_in(__MODULE__, :def) ++ Module.definitions_in(__MODULE__, :defmacro)) -- [ diff --git a/nx/lib/nx/floating.ex b/nx/lib/nx/floating.ex index 4ded248c91..54cec81378 100644 --- a/nx/lib/nx/floating.ex +++ b/nx/lib/nx/floating.ex @@ -209,8 +209,16 @@ defmodule Nx.Floating do # E4M3FN: 1 sign, 4 exponent (bias 7), 3 mantissa # Max value: 448.0 (0x7E), Min value: -448.0 (0xFE) def dump_f8_e4m3fn(0), do: <<0b0000_0000>> - def dump_f8_e4m3fn(+0.0), do: <<0b0000_0000>> - def dump_f8_e4m3fn(-0.0), do: <<0b1000_0000>> + + if +0.0 === -0.0 do + # OTP versions <= 28.0 have a bug where +0.0 === -0.0, + # so we need to special-case it to avoid compiler errors + # related to the +0.0 clause shadowing the -0.0 clause + def dump_f8_e4m3fn(x) when x == 0.0, do: <<0b0000_0000>> + else + def dump_f8_e4m3fn(+0.0), do: <<0b0000_0000>> + def dump_f8_e4m3fn(-0.0), do: <<0b1000_0000>> + end def dump_f8_e4m3fn(x) when is_number(x) do # Clamp to E4M3FN range and convert diff --git a/nx/test/nx/defn/composite_test.exs b/nx/test/nx/defn/composite_test.exs index 2618e29170..6dc8fc1d52 100644 --- a/nx/test/nx/defn/composite_test.exs +++ b/nx/test/nx/defn/composite_test.exs @@ -24,7 +24,8 @@ defmodule Nx.Defn.CompositeTest do Nx.tensor(1), Nx.tensor(3, type: {:c, 64}), Nx.tensor(4, type: {:c, 64}) - }, Nx.tensor(2, type: {:c, 64})} == + }, + Nx.tensor(2, type: {:c, 64})} == Composite.traverse( {1, Complex.new(2), Nx.tensor(3)}, 0, diff --git a/nx/test/nx/defn_test.exs b/nx/test/nx/defn_test.exs index 62993b07a3..d15951ce71 100644 --- a/nx/test/nx/defn_test.exs +++ b/nx/test/nx/defn_test.exs @@ -2952,4 +2952,30 @@ defmodule Nx.DefnTest do assert vectorized_metadata_tuple(x, z) == vec_nonvec_result end end + + describe "sharding" do + defn all_gather_test(tensor) do + Nx.Defn.Kernel.all_gather(tensor, all_gather_dim: 0, replica_groups: [[0]]) + end + + test "all_gather produces correct expr format for compiler" do + # Uses debug_expr to inspect the expression without compiling. + # Guarantees the format passed to compilers (e.g. EXLA) stays stable. + assert %T{data: %Expr{op: :all_gather, args: [tensor, opts]}} = + Nx.Defn.debug_expr(&all_gather_test/1).(Nx.tensor([1, 2, 3, 4])) + + assert %T{data: %Expr{op: :parameter, args: [0]}} = tensor + + # Compilers expect opts with :all_gather_dim and :replica_groups + assert opts[:all_gather_dim] == 0 + assert opts[:replica_groups] == [[0]] + end + + @tag compiler: Evaluator + test "all_gather works" do + assert_raise UndefinedFunctionError, fn -> + all_gather_test(Nx.tensor([1, 2, 3, 4])) + end + end + end end diff --git a/nx/test/nx/floating_test.exs b/nx/test/nx/floating_test.exs index 7e0a088fe0..4720c94d0d 100644 --- a/nx/test/nx/floating_test.exs +++ b/nx/test/nx/floating_test.exs @@ -156,7 +156,7 @@ defmodule Nx.FloatingTest do {0x7F, :nan}, # Negative values (sign bit = 1) # Denormalized (exponent = 0): value = -mantissa/8 * 2^-6 - {0x80, -0.0}, + if(+0.0 === -0.0, do: {0x00, -0.0}, else: {0x80, -0.0}), {0x81, -0.001953125}, {0x82, -0.00390625}, {0x83, -0.005859375}, @@ -349,7 +349,12 @@ defmodule Nx.FloatingTest do test "pretty printing" do # Zeroes assert Nx.tensor([0.0], type: :f8) |> inspect() =~ "[0.0]" - assert Nx.tensor([-0.0], type: :f8) |> inspect() =~ "[-0.0]" + + if +0.0 === -0.0 do + assert Nx.tensor([-0.0], type: :f8) |> inspect() =~ "[0.0]" + else + assert Nx.tensor([-0.0], type: :f8) |> inspect() =~ "[-0.0]" + end # Infinity assert Nx.tensor([:infinity], type: :f8) |> inspect() =~ "[Inf]" @@ -399,7 +404,12 @@ defmodule Nx.FloatingTest do test "pretty printing" do # Zeroes assert Nx.tensor([0.0], type: :bf16) |> inspect() =~ "[0.0]" - assert Nx.tensor([-0.0], type: :bf16) |> inspect() =~ "[-0.0]" + + if +0.0 === -0.0 do + assert Nx.tensor([-0.0], type: :bf16) |> inspect() =~ "[0.0]" + else + assert Nx.tensor([-0.0], type: :bf16) |> inspect() =~ "[-0.0]" + end # Infinity assert Nx.tensor([:infinity], type: :bf16) |> inspect() =~ "[Inf]"