Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions backends/aoti/slim/core/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ struct DeviceTraits<c10::DeviceType::CUDA> {
/// @param ptr Pointer to device memory to free.
/// Frees device memory previously allocated on the CUDA device using a
/// stream-ordered deallocation (cudaFreeAsync) on the current stream.
///
/// @param ptr Pointer to device memory to free.
static void free(void* ptr) {
  // Get the current stream for the current device.
  // Currently all cuda slimtensors should be on the same device same stream,
  // so we can just use the stream on current device.
  // TODO(gasoonjia): add cuda stream as a member of MaybeOwningStorage to
  // support multiple devices.
  auto stream_result = executorch::backends::cuda::getCurrentCUDAStream(-1);
  // NOTE: the stripped diff left both the old synchronous-free fallback and
  // this hard check in the text; the intended (post-diff) behavior is to
  // treat a missing stream as a fatal error rather than fall back to
  // cudaFree, so the old if/else fallback is removed here.
  ET_CHECK_MSG(stream_result.ok(), "Failed to get current CUDA stream");
  // cudaFreeAsync is stream-ordered: the free is enqueued on the stream and
  // will not run until previously submitted work on that stream completes.
  ET_CUDA_LOG_WARN(cudaFreeAsync(ptr, stream_result.get()));
}

/// Copies memory between CPU and CUDA or CUDA and CUDA.
Expand Down
26 changes: 16 additions & 10 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,17 +423,17 @@ class ET_EXPERIMENTAL CudaBackend final

const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);

// Synchronize CUDA stream to ensure kernel execution is complete
// before accessing output data (either for copy or skip-copy path)
cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
cudaError_t sync_err = cudaStreamSynchronize(cuda_stream);
ET_CHECK_OR_RETURN_ERROR(
sync_err == cudaSuccess,
Internal,
"cudaStreamSynchronize failed: %s",
cudaGetErrorString(sync_err));

if (copy_outputs) {
// Synchronize CUDA stream before D2H copy. This is required because
// cudaMemcpy is not stream-ordered and needs the kernel to complete.
cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
cudaError_t sync_err = cudaStreamSynchronize(cuda_stream);
ET_CHECK_OR_RETURN_ERROR(
sync_err == cudaSuccess,
Internal,
"cudaStreamSynchronize failed: %s",
cudaGetErrorString(sync_err));

// Deep copy GPU SlimTensor results back to CPU ETensors
for (size_t i = 0; i < n_outputs; i++) {
auto* cpu_output_tensor = &(args[i + n_inputs]->toTensor());
Expand All @@ -448,6 +448,12 @@ class ET_EXPERIMENTAL CudaBackend final
// Skip-copy optimization: point ETensor directly to GPU data.
// The caller is responsible for handling GPU data directly.
//
// No cudaStreamSynchronize needed here because:
// 1. All operations (kernel, allocations, frees) are on the same stream
// 2. cudaFreeAsync is stream-ordered, so CUDA guarantees the kernel
// completes before any memory is freed
// 3. The next execution's operations will also be ordered on this stream
//
// Lifetime management: We cache the newly created GPU tensors and delete
// the previous round's tensors, since they are no longer needed.
{
Expand Down
11 changes: 10 additions & 1 deletion extension/asr/runner/runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ AsrRunner::AsrRunner(
}
}

// Out-of-line defaulted destructor. Presumably defined here (rather than
// implicitly in the header) so that members whose types are only
// forward-declared in runner.h can be destroyed where their full definitions
// are visible — TODO(review): confirm against the member list in runner.h.
AsrRunner::~AsrRunner() = default;

bool AsrRunner::is_loaded() const {
return module_ && encoder_method_loaded_ && decoder_method_loaded_ &&
(!sampler_method_present_ || sampler_method_loaded_) && tokenizer_ &&
Expand Down Expand Up @@ -121,13 +123,20 @@ Error AsrRunner::load() {
#ifdef CUDA_AVAILABLE
// Skip copying outputs to CPU. When a sampler exists, keep both encoder and
// decoder outputs on device and pass decoder logits directly into sampler.
executorch::runtime::BackendOptions<1> backend_options;
// The backend will automatically create a shared CUDA stream for all methods
// when skip-copy is enabled to ensure proper ordering.
executorch::runtime::BackendOptions<2> backend_options;
std::string skip_methods = kEncoderMethodName;
if (sampler_method_present_) {
skip_methods.append(",").append(kDecoderMethodName);
}
ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
"skip_copy_output_to_cpu_for_method", skip_methods.c_str()));
// Enable shared CUDA stream for all methods when skip-copy is used.
// This ensures proper ordering between encoder/decoder/sampler outputs.
ET_CHECK_OK_OR_RETURN_ERROR(
backend_options.set_option("use_shared_cuda_stream", true));

const auto opt_err =
executorch::runtime::set_option("CudaBackend", backend_options.view());
if (opt_err != ::executorch::runtime::Error::Ok) {
Expand Down
2 changes: 2 additions & 0 deletions extension/asr/runner/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class ET_EXPERIMENTAL AsrRunner {
std::optional<std::string> data_path,
const std::string& tokenizer_path);

/**
 * Destructor. Declared here and defined out-of-line in runner.cpp,
 * presumably so members with incomplete types can be destroyed where their
 * full definitions are visible — TODO(review): confirm.
 */
~AsrRunner();

/**
* Returns true when the module and tokenizer are ready for inference.
*/
Expand Down