Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lib/op-attrs/src/op-attrs/ops/element_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ ParallelTensorDimDegrees get_output_parallel_dim_degrees(
ElementUnaryAttrs const &attrs,
ParallelTensorDimDegrees const &input_degrees) {
ASSERT(input_degrees.sum_degree.value == 1);
ASSERT(input_degrees.discard_copy_degree.value == 1);

return input_degrees;
}
Expand Down
8 changes: 0 additions & 8 deletions lib/op-attrs/test/src/op-attrs/ops/element_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,5 @@ TEST_SUITE(FF_TEST_SUITE) {
SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p)));
}

SUBCASE("discard copy degree > 1") {
positive_int degree = 2_p;

CHECK_THROWS(get_output_shape(
attrs,
make_input(
SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p)));
}
}
}
19 changes: 11 additions & 8 deletions lib/realm-execution/include/realm-execution/realm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,18 @@ struct RealmContext {
int priority = 0);
///\}

/** \name Data movement */
/** \name Data movement and reduction */
///\{
Realm::Event issue_copy(ParallelTensorShape const &src_shape,
Realm::RegionInstance src_inst,
ParallelTensorShape const &dst_shape,
Realm::RegionInstance dst_inst,
Realm::ProfilingRequestSet const &requests,
Realm::Event wait_on = Realm::Event::NO_EVENT,
int priority = 0);
Realm::Event
issue_copy(ParallelTensorShape const &src_shape,
Realm::RegionInstance src_inst,
ParallelTensorShape const &dst_shape,
Realm::RegionInstance dst_inst,
Realm::ProfilingRequestSet const &requests,
Realm::Event wait_on = Realm::Event::NO_EVENT,
int priority = 0,
std::optional<Realm::ReductionOpID> redop_id = std::nullopt,
bool exclusive = false);
///\}

/** \name Instance management */
Expand Down
154 changes: 154 additions & 0 deletions lib/realm-execution/include/realm-execution/tasks/realm_reduction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H
#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H
#include "op-attrs/datatype.dtg.h"
#include <realm.h>

namespace FlexFlow {

/**
* \brief Realm Sum Reduction for Float
* \see https://legion.stanford.edu/tutorial/realm/reductions.html
*/
struct SumReductionFloat {
  using LHS = float;
  using RHS = float;

  /** \brief Identity element for addition (0.0) */
  static constexpr RHS identity = 0.0f;

  // The non-exclusive paths reinterpret the float's bits as a same-sized
  // integer so they can use __sync_bool_compare_and_swap (a GNU builtin,
  // which only operates on integral types). Guarantee the size assumption
  // the bit-punning relies on instead of silently assuming it.
  static_assert(sizeof(float) == sizeof(int),
                "CAS loop requires float and int to have the same size");

  /**
   * \brief Apply reduction: lhs += rhs
   * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop
   * \param lhs Left-hand side accumulator (modified in place)
   * \param rhs Value to add
   */
  template <bool EXCLUSIVE>
  static void apply(LHS &lhs, RHS rhs) {
    // EXCLUSIVE is a compile-time constant, so branch at compile time.
    if constexpr (EXCLUSIVE) {
      lhs += rhs;
    } else {
      // Atomic float add via CAS loop: snapshot the current value, compute
      // the sum, and retry if another thread updated lhs in the meantime.
      union {
        float f;
        int i;
      } old_val, new_val;
      do {
        old_val.f = lhs;
        new_val.f = old_val.f + rhs;
      } while (
          !__sync_bool_compare_and_swap((int *)&lhs, old_val.i, new_val.i));
    }
  }

  /**
   * \brief Fold two RHS values: rhs1 += rhs2
   * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop
   * \param rhs1 Accumulator (modified in place)
   * \param rhs2 Value to fold in
   */
  template <bool EXCLUSIVE>
  static void fold(RHS &rhs1, RHS rhs2) {
    // EXCLUSIVE is a compile-time constant, so branch at compile time.
    if constexpr (EXCLUSIVE) {
      rhs1 += rhs2;
    } else {
      // Atomic float add via CAS loop (same scheme as apply above).
      union {
        float f;
        int i;
      } old_val, new_val;
      do {
        old_val.f = rhs1;
        new_val.f = old_val.f + rhs2;
      } while (
          !__sync_bool_compare_and_swap((int *)&rhs1, old_val.i, new_val.i));
    }
  }
};

/**
* \brief Realm Sum Reduction for Double
* \see https://legion.stanford.edu/tutorial/realm/reductions.html
*/
struct SumReductionDouble {
  using LHS = double;
  using RHS = double;

  /** \brief Identity element for addition (0.0) */
  static constexpr RHS identity = 0.0;

  // The non-exclusive paths reinterpret the double's bits as a same-sized
  // integer so they can use __sync_bool_compare_and_swap (a GNU builtin,
  // which only operates on integral types). Guarantee the size assumption
  // the bit-punning relies on instead of silently assuming it.
  static_assert(sizeof(double) == sizeof(long long),
                "CAS loop requires double and long long to have the same size");

  /**
   * \brief Apply reduction: lhs += rhs
   * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop
   * \param lhs Left-hand side accumulator (modified in place)
   * \param rhs Value to add
   */
  template <bool EXCLUSIVE>
  static void apply(LHS &lhs, RHS rhs) {
    // EXCLUSIVE is a compile-time constant, so branch at compile time.
    if constexpr (EXCLUSIVE) {
      lhs += rhs;
    } else {
      // Atomic double add via CAS loop: snapshot the current value, compute
      // the sum, and retry if another thread updated lhs in the meantime.
      union {
        double d;
        long long i;
      } old_val, new_val;
      do {
        old_val.d = lhs;
        new_val.d = old_val.d + rhs;
      } while (!__sync_bool_compare_and_swap(
          (long long *)&lhs, old_val.i, new_val.i));
    }
  }

  /**
   * \brief Fold two RHS values: rhs1 += rhs2
   * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop
   * \param rhs1 Accumulator (modified in place)
   * \param rhs2 Value to fold in
   */
  template <bool EXCLUSIVE>
  static void fold(RHS &rhs1, RHS rhs2) {
    // EXCLUSIVE is a compile-time constant, so branch at compile time.
    if constexpr (EXCLUSIVE) {
      rhs1 += rhs2;
    } else {
      // Atomic double add via CAS loop (same scheme as apply above).
      union {
        double d;
        long long i;
      } old_val, new_val;
      do {
        old_val.d = rhs1;
        new_val.d = old_val.d + rhs2;
      } while (!__sync_bool_compare_and_swap(
          (long long *)&rhs1, old_val.i, new_val.i));
    }
  }
};

/**
* \brief Reduction op IDs for sum reductions
* \warning These IDs must not conflict with other registered reduction ops
*/
/**
 * \brief Reduction op IDs used when registering the built-in sum reductions
 * \warning These IDs must not conflict with other registered reduction ops
 */
enum SumReductionOpIDs {
  /** Sum reduction op ID for float */
  REDOP_SUM_FLOAT = 1,
  /** Sum reduction op ID for double */
  REDOP_SUM_DOUBLE = 2,
};

/**
* \brief Returns the Realm reduction op ID for a sum reduction over the given datatype
* \param dtype The datatype to look up
* \return The corresponding Realm::ReductionOpID
* \throws PANIC if no sum reduction is registered for the given datatype
*/
/**
 * \brief Returns the Realm reduction op ID for a sum reduction over the
 *        given datatype
 * \param dtype The datatype to look up
 * \return The corresponding Realm::ReductionOpID
 * \throws PANIC if no sum reduction is registered for the given datatype
 */
inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) {
  if (dtype == DataType::FLOAT) {
    return REDOP_SUM_FLOAT;
  }
  if (dtype == DataType::DOUBLE) {
    return REDOP_SUM_DOUBLE;
  }
  PANIC("no sum reduction registered for datatype {}", dtype);
}
} // namespace FlexFlow
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization(
std::unordered_map<DynamicNodeInvocation,
DeviceSpecificPtr<PerDeviceOpState> *>
device_state_map;
std::vector<Realm::Event> completion_events;
for (DynamicNodeInvocation const &invocation : dg.invocations) {
Realm::Processor target_proc = ctx.map_device_coord_to_processor(
assert_unwrap(invocation.node_attrs.device_coord));
Expand All @@ -56,14 +57,17 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization(
precondition);

if (completion_event.has_value()) {
completion_events.push_back(completion_event.value());
device_state_map.insert(std::pair{invocation, device_state_ptr});
} else {
// Task doesn't require initialization, clean up and don't store result
delete device_state_ptr;
}
}

ctx.get_outstanding_events().wait();
// wait for all init tasks — direct write to *result_ptr happens
// before each init task event fires so result is ready after this
Realm::Event::merge_events(completion_events).wait();

auto deref = [](DeviceSpecificPtr<PerDeviceOpState> *const &p) { return *p; };
std::unordered_map<DynamicNodeInvocation, DeviceSpecificPtr<PerDeviceOpState>>
Expand Down
49 changes: 49 additions & 0 deletions lib/realm-execution/src/realm-execution/pcg_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "realm-execution/instance_allocation.h"
#include "realm-execution/realm_context.h"
#include "realm-execution/tasks/impl/op_task.h"
#include "realm-execution/tasks/realm_reduction.h"
#include "realm-execution/tensor_instance_backing.h"
#include "task-spec/dynamic_graph/copy_insertion.h"
#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h"
Expand Down Expand Up @@ -215,13 +216,61 @@ static Realm::Event spawn_dynamic_node_invocation(
precondition);
};

// issue_replicate_bwd lambda
auto issue_replicate_bwd = [&]() {
std::optional<DynamicValueAttrs> output_grad_opt;
for (auto const &[slot, value] : invocation.inputs) {
if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) {
output_grad_opt = value;
}
}
DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt);
DynamicValueAttrs input_grad = get_only(invocation.outputs).second;
Realm::RegionInstance dst_inst =
tensor_instance_backing.backing.at(input_grad).first;

Realm::ReductionOpID redop_id = get_sum_reduction_op_id(
assert_unwrap(output_grad.parallel_tensor_shape).data_type);

// chain reductions sequentially to avoid write races on dst
Realm::Event e = precondition;
for (auto const &[p, m] : assert_unwrap(output_grad.mapping)) {
DynamicValueAttrs replica_key = output_grad;
replica_key.mapping =
bidict<ParallelTensorSpaceCoordinate, MachineSpaceCoordinate>{{p, m}};
replica_key.shard_coord = p;

Realm::RegionInstance src_inst =
tensor_instance_backing.backing.at(replica_key).first;

e = ctx.issue_copy(assert_unwrap(output_grad.parallel_tensor_shape),
src_inst,
assert_unwrap(input_grad.parallel_tensor_shape),
dst_inst,
Realm::ProfilingRequestSet{},
e,
0,
redop_id,
false);
}
return e;
};

TrainingOperationAttrs op_attrs =
assert_unwrap(invocation.node_attrs.op_attrs);
return op_attrs.visit<Realm::Event>(overload{
[&](PCGOperatorAttrs const &pcg_op_attrs) {
return pcg_op_attrs.visit<Realm::Event>(overload{
[&](InputAttrs const &) { return Realm::Event::NO_EVENT; },
[&](WeightAttrs const &) { return Realm::Event::NO_EVENT; },
[&](ReplicateAttrs const &) {
if (invocation.node_attrs.task_type.has_value() &&
invocation.node_attrs.task_type.value() ==
DynamicTaskType::BWD) {
return issue_replicate_bwd();
}
return issue_copy(); // forward
},
[&](auto const &) { return spawn_task(); },
});
},
Expand Down
9 changes: 8 additions & 1 deletion lib/realm-execution/src/realm-execution/realm_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ Realm::Event
Realm::RegionInstance dst_inst,
Realm::ProfilingRequestSet const &requests,
Realm::Event wait_on,
int priority) {
int priority,
std::optional<Realm::ReductionOpID> redop_id,
bool exclusive) {
TensorShape src_piece_shape = get_piece_shape(src_shape);
TensorShape dst_piece_shape = get_piece_shape(dst_shape);
ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match
Expand All @@ -183,6 +185,11 @@ Realm::Event
size_of_datatype(src_piece_shape.data_type).int_from_positive_int()),
/*subfield_offset=*/0);

// set reduction op on dst field if provided
if (redop_id.has_value()) {
dst_field.set_redop(redop_id.value(), /*is_fold=*/false, exclusive);
}

Realm::Event result;
switch (src_piece_shape.dims.ff_ordered.num_dims()) {
#if REALM_MAX_DIM >= 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,17 @@ void per_device_op_state_init_task_body(void const *args,
result_state, ctx.get_current_device_idx())};
DeviceSpecificPtr<PerDeviceOpState> result_device_specific{
ctx.get_current_device_idx(), result_state_ptr};
spawn_per_device_op_state_init_return_task(ctx,
task_args.origin_proc,
result_device_specific,
task_args.origin_result_ptr,
Realm::Event::NO_EVENT);

// replace spawn_per_device_op_state_init_return_task with:
// NOTE: SM/TODO: direct write assumes single-node shared address space
// For multi-node, replace with UserEvent trigger pattern
*task_args.origin_result_ptr = result_device_specific;

// spawn_per_device_op_state_init_return_task(ctx,
// task_args.origin_proc,
// result_device_specific,
// task_args.origin_result_ptr,
// Realm::Event::NO_EVENT);
}

std::optional<Realm::Event> spawn_per_device_op_state_init_task(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "realm-execution/tasks/impl/op_task.h"
#include "realm-execution/tasks/impl/per_device_op_state_init_return_task.h"
#include "realm-execution/tasks/impl/per_device_op_state_init_task.h"
#include "realm-execution/tasks/realm_reduction.h"
#include "realm-execution/tasks/task_id_t.h"
#include "utils/exception.h"

Expand All @@ -30,9 +31,18 @@ Realm::Event register_task(Realm::Processor::Kind target_kind,
Realm::ProfilingRequestSet());
}

/// \brief Registers the built-in sum reduction ops with the Realm runtime.
///
/// Must run before any copy that references REDOP_SUM_FLOAT /
/// REDOP_SUM_DOUBLE. Realm's register_reduction completes synchronously,
/// so there is no event to return or chain.
static void register_reductions() {
  Realm::Runtime::get_runtime().register_reduction<SumReductionFloat>(
      REDOP_SUM_FLOAT);
  Realm::Runtime::get_runtime().register_reduction<SumReductionDouble>(
      REDOP_SUM_DOUBLE);
}

Realm::Event register_all_tasks() {
std::vector<Realm::Event> pending_registrations;

register_reductions();
std::vector<task_id_t> init_task_ids = {
// Init tasks
task_id_t::BATCHNORM_INIT_TASK_ID,
Expand Down
Loading