Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions exla/lib/exla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,21 @@ defmodule EXLA do
The metadata is:

* `:key` - the compilation key for debugging

## Sharding

EXLA supports sharding, which is a way to partition a computation across multiple devices.
There are a number of collective operations that are supported by sharding.

### [`all_gather`](https://openxla.org/stablehlo/spec#all_gather)

#### Options

* `:all_gather_dim` - the dimension along which to gather
* `:replica_groups` - 2D list defining how replicas are grouped
* `:use_global_device_ids` - Whether to use global device IDs (default: `false`)
* `:channel_id` - Channel ID for communication (optional)

"""

@behaviour Nx.Defn.Compiler
Expand Down
21 changes: 21 additions & 0 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,27 @@ defmodule EXLA.Defn do
EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type)
end

## to_operator collective ops

defp to_operator(:all_gather, [%Value{} = tensor, opts], ans, _state) do
all_gather_dim = Keyword.fetch!(opts, :all_gather_dim)
replica_groups = Keyword.fetch!(opts, :replica_groups)
use_global_device_ids = Keyword.get(opts, :use_global_device_ids, false)

# We might want to surface all_gather as an operation that takes a container of operands instead of a single one.
[result] =
Value.all_gather(
[tensor],
expr_to_typespec(ans),
all_gather_dim,
replica_groups,
use_global_device_ids,
opts[:channel_id]
)

result
end

defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do
n = opts[:length]
axis = opts[:axis]
Expand Down
30 changes: 30 additions & 0 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,36 @@ defmodule EXLA.MLIR.Value do
end
end

def all_gather(
[%Value{function: func} | _] = operands,
typespec,
all_gather_dim,
replica_groups,
use_global_device_ids,
channel_id \\ nil
) do
result_types = typespecs_to_mlir_types([typespec])

num_groups = length(replica_groups)
group_size = if num_groups > 0, do: length(hd(replica_groups)), else: 0
flat_groups = List.flatten(replica_groups)

attributes = [
all_gather_dim: attr_i64(all_gather_dim),
replica_groups: attr_dense_elements(flat_groups, {:s, 64}, {num_groups, group_size}),
use_global_device_ids: attr_boolean(use_global_device_ids)
]

attributes =
if channel_id do
Keyword.put(attributes, :channel_id, attr_i64(channel_id))
else
attributes
end

op(func, "stablehlo.all_gather", operands, result_types, attributes: attributes)
end

defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do
%{type: lhs_type} = get_typespec(lhs)
%{type: rhs_type} = get_typespec(rhs)
Expand Down
Loading