Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions include/API/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "API/Capabilities.h"
#include "API/CommandBuffer.h"
#include "API/RenderPass.h"
#include "API/ShaderBindingTable.h"
#include "API/Texture.h"

#include "Support/Pipeline.h"
Expand Down Expand Up @@ -83,6 +84,21 @@ struct ShaderContainer {
llvm::SmallVector<SpecializationConstant> SpecializationConstants;
};

struct RayTracingShader {
Stages Stage;
std::string EntryPoint;
};

struct RayTracingPipelineCreateDesc {
// All RT shaders are compiled into a single DXIL library; every entry in
// `Shaders` references this same blob via the backend's library-loading
// path.
const llvm::MemoryBuffer *Library = nullptr;
llvm::SmallVector<RayTracingShader> Shaders;
llvm::SmallVector<HitGroup> HitGroups;
RayTracingPipelineConfig Config;
};

struct TraditionalRasterPipelineCreateDesc {
llvm::SmallVector<InputLayoutDesc> InputLayout;
llvm::SmallVector<Format> RTFormats;
Expand Down Expand Up @@ -120,6 +136,12 @@ struct TraditionalRasterPipelineCreateDesc {
case Stages::Compute:
case Stages::Amplification:
case Stages::Mesh:
case Stages::RayGeneration:
case Stages::Miss:
case Stages::ClosestHit:
case Stages::AnyHit:
case Stages::Intersection:
case Stages::Callable:
llvm_unreachable("Not a traditional raster pipeline stage.");
}
}
Expand Down Expand Up @@ -211,6 +233,14 @@ class Device {
llvm::StringRef Name, const BindingsDesc &BindingsDesc,
const TraditionalRasterPipelineCreateDesc &Desc) = 0;

virtual llvm::Expected<std::unique_ptr<PipelineState>>
createPipelineRT(llvm::StringRef Name, const BindingsDesc &BindingsDesc,
const RayTracingPipelineCreateDesc &Desc) = 0;

virtual llvm::Expected<std::unique_ptr<ShaderBindingTable>>
createShaderBindingTable(const PipelineState &PSO,
const ShaderBindingTableDesc &Desc) = 0;

virtual llvm::Expected<std::unique_ptr<Fence>>
createFence(llvm::StringRef Name) = 0;

Expand Down Expand Up @@ -281,6 +311,21 @@ createBufferWithData(Device &Dev, std::string Name,
size_t SizeInBytes, ComputeEncoder *Encoder,
std::unique_ptr<offloadtest::Buffer> *OutUploadBuffer);

// Builds all BLAS / TLAS objects defined in `P.AccelStructs` using the
// supplied compute encoder. Uploads each BLAS's vertex/index data, creates
// the BLASBuildRequest + AS object via `Dev`, and records the build via
// `Enc.batchBuildAS`. Then resolves TLAS instance references and records the
// TLAS batch with a separate call (so the AS-build-write barrier between
// BLAS and TLAS is automatic).
//
// Built AS objects are pushed to `OutAS` (in declaration order: BLASes first,
// then TLASes). Vertex/index buffers used as build inputs are pushed to
// `OutInputBuffers`; both must outlive command-buffer submission.
llvm::Error buildPipelineAccelerationStructures(
Device &Dev, ComputeEncoder &Enc, Pipeline &P,
llvm::SmallVectorImpl<std::unique_ptr<AccelerationStructure>> &OutAS,
llvm::SmallVectorImpl<std::unique_ptr<Buffer>> &OutInputBuffers);

} // namespace offloadtest

#endif // OFFLOADTEST_API_DEVICE_H
35 changes: 35 additions & 0 deletions include/API/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

#include "API/API.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Error.h"

Expand All @@ -21,6 +23,20 @@ namespace offloadtest {

class Buffer;
class PipelineState;
class AccelerationStructure;
class ShaderBindingTable;
struct BLASBuildRequest;
struct TLASBuildRequest;

/// One acceleration-structure build to record in a batch. The caller is
/// responsible for ensuring no item in a batch has a memory dependency on
/// another (e.g. a TLAS that reads a BLAS being built in the same batch must
/// be in a separate batch — that barrier is inserted between batchBuildAS
/// calls automatically).
struct ASBuildItem {
AccelerationStructure *AS;
llvm::PointerUnion<const BLASBuildRequest *, const TLASBuildRequest *> Req;
};

/// Base class for all command encoders. An encoder records commands into a
/// command buffer. Call endEncoding() when done recording. Barriers are
Expand Down Expand Up @@ -82,6 +98,25 @@ class ComputeEncoder : public CommandEncoder {
virtual llvm::Error copyBufferToBuffer(Buffer &Src, size_t SrcOffset,
Buffer &Dst, size_t DstOffset,
size_t Size) = 0;

/// Build a batch of acceleration structures in a single barrier slot. All
/// items in `Items` must be independent — no item may depend on another's
/// build output. Backends may issue this as one native batch call (Vulkan)
/// or as a sequence of single-AS calls without intermediate barriers (DX12,
/// Metal). A barrier covering AS-build writes is implicitly emitted before
/// any subsequent command that reads from the freshly-built structures.
virtual llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) = 0;

/// Trace rays from a RayTracing pipeline. \p PSO must have been created via
/// Device::createPipelineRT and \p SBT via Device::createShaderBindingTable
/// on that same PSO. \p Width, \p Height, \p Depth are the dispatch
/// dimensions passed through to the backend's DispatchRays equivalent
/// (D3D12 DispatchRays, Vulkan vkCmdTraceRaysKHR, Metal compute dispatch
/// after metal_irconverter lowering).
virtual llvm::Error dispatchRays(const PipelineState &PSO,
const ShaderBindingTable &SBT,
uint32_t Width, uint32_t Height,
uint32_t Depth) = 0;
};

struct Viewport {
Expand Down
78 changes: 78 additions & 0 deletions include/API/ShaderBindingTable.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
//===- ShaderBindingTable.h - Offload RT Shader Binding Table -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef OFFLOADTEST_API_SHADERBINDINGTABLE_H
#define OFFLOADTEST_API_SHADERBINDINGTABLE_H

#include "API/API.h"

#include <cstdint>

namespace offloadtest {

struct ShaderBindingTableDesc;

/// Runtime shader binding table built from a RayTracing PipelineState plus a
/// ShaderBindingTableDesc. Concrete subclasses (one per backend) hold the
/// device-side records and any address ranges needed by the backend's
/// DispatchRays call.
class ShaderBindingTable {
GPUAPI API;

public:
virtual ~ShaderBindingTable();
ShaderBindingTable(const ShaderBindingTable &) = delete;
ShaderBindingTable &operator=(const ShaderBindingTable &) = delete;

GPUAPI getAPI() const { return API; }

protected:
explicit ShaderBindingTable(GPUAPI API) : API(API) {}
};

/// Per-region SBT layout numbers.
///
/// Every backend lays an SBT out as four concatenated regions (raygen, miss,
/// hit-group, callable). Within a region every record is
/// `[shader-identifier][LocalRootData][padding-to-stride]`, where stride is
/// `align(identifierSize + max-LocalRootData-in-region, RecordAlign)` and
/// the region itself is aligned to `BaseAlign`. The numbers match the
/// alignment rules of both D3D12
/// (`D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES` / `…_RECORD_BYTE_ALIGNMENT` /
/// `…_TABLE_BYTE_ALIGNMENT`) and Vulkan
/// (`shaderGroupHandleSize` / `shaderGroupHandleAlignment` /
/// `shaderGroupBaseAlignment`); backend-specific constants are passed in.
struct SBTRegionLayout {
uint32_t Stride = 0;
uint32_t Size = 0;
uint32_t Offset = 0; // byte offset from the start of the SBT buffer
};

struct SBTLayout {
SBTRegionLayout RayGen;
SBTRegionLayout Miss;
SBTRegionLayout HitGroup;
SBTRegionLayout Callable;
uint32_t TotalSize = 0;
};

/// Compute the per-region layout for an SBT description.
///
/// \p IdentifierSize is the size of one shader identifier (32 bytes on
/// D3D12; `shaderGroupHandleSize` on Vulkan).
/// \p RecordAlign is the per-record alignment (D3D12: 32 bytes; Vulkan:
/// `shaderGroupHandleAlignment`).
/// \p BaseAlign is the per-region alignment (D3D12: 64 bytes; Vulkan:
/// `shaderGroupBaseAlignment`).
SBTLayout computeSBTLayout(uint32_t IdentifierSize, uint32_t RecordAlign,
uint32_t BaseAlign,
const ShaderBindingTableDesc &Desc);

} // namespace offloadtest

#endif // OFFLOADTEST_API_SHADERBINDINGTABLE_H
117 changes: 113 additions & 4 deletions include/Support/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,53 @@ enum class Stages {

// Mesh Shader Raster
Amplification,
Mesh
Mesh,

// Ray Tracing
RayGeneration,
Miss,
ClosestHit,
AnyHit,
Intersection,
Callable
};
inline constexpr std::array AllStages = {
Stages::Compute, Stages::Vertex, Stages::Hull, Stages::Domain,
Stages::Geometry, Stages::Pixel, Stages::Amplification, Stages::Mesh,
Stages::Compute, Stages::Vertex, Stages::Hull,
Stages::Domain, Stages::Geometry, Stages::Pixel,
Stages::Amplification, Stages::Mesh, Stages::RayGeneration,
Stages::Miss, Stages::ClosestHit, Stages::AnyHit,
Stages::Intersection, Stages::Callable,
};
inline constexpr size_t NumStages = AllStages.size();

enum class ShaderPipelineKind { Compute, TraditionalRaster, MeshShaderRaster };
inline constexpr bool isRayTracingStage(Stages S) {
switch (S) {
case Stages::RayGeneration:
case Stages::Miss:
case Stages::ClosestHit:
case Stages::AnyHit:
case Stages::Intersection:
case Stages::Callable:
return true;
case Stages::Compute:
case Stages::Vertex:
case Stages::Hull:
case Stages::Domain:
case Stages::Geometry:
case Stages::Pixel:
case Stages::Amplification:
case Stages::Mesh:
return false;
}
llvm_unreachable("All stages handled");
}

enum class ShaderPipelineKind {
Compute,
TraditionalRaster,
MeshShaderRaster,
RayTracing
};

enum class Rule { BufferExact, BufferFloatULP, BufferFloatEpsilon };

Expand Down Expand Up @@ -527,6 +565,40 @@ struct AccelerationStructureDescs {
llvm::SmallVector<TLASDesc, 1> TLAS;
};

enum class HitGroupType { Triangles, Procedural };

struct HitGroup {
std::string Name;
HitGroupType Type = HitGroupType::Triangles;
std::string ClosestHit;
std::optional<std::string> AnyHit;
std::optional<std::string> Intersection;
};

struct RayTracingPipelineConfig {
uint32_t MaxTraceRecursionDepth = 1;
uint32_t MaxPayloadSizeInBytes = 0;
uint32_t MaxAttributeSizeInBytes = 8;
std::optional<uint32_t> PipelineFlags;
};

struct SBTEntry {
// For RayGen / Miss / Callable entries: the shader's Entry name.
// For HitGroup entries: the HitGroup's Name.
std::string ShaderName;
// Optional per-record local-root data, laid out as the local root signature
// describes. Not used during PR1 bring-up; reserved here so the schema is
// stable when local root signatures land.
llvm::SmallVector<uint8_t> LocalRootData;
};

struct ShaderBindingTableDesc {
SBTEntry RayGen;
llvm::SmallVector<SBTEntry> Miss;
llvm::SmallVector<SBTEntry> HitGroup;
llvm::SmallVector<SBTEntry> Callable;
};

struct Pipeline {
ShaderPipelineKind Kind;
llvm::SmallVector<Shader> Shaders;
Expand All @@ -540,6 +612,9 @@ struct Pipeline {
llvm::SmallVector<DescriptorSet> Sets;
DispatchParametersSet DispatchParameters;
AccelerationStructureDescs AccelStructs;
std::optional<RayTracingPipelineConfig> RTConfig;
llvm::SmallVector<HitGroup> HitGroups;
std::optional<ShaderBindingTableDesc> SBT;

uint32_t getVertexCount() const {
if (DispatchParameters.VertexCount)
Expand Down Expand Up @@ -607,6 +682,7 @@ struct Pipeline {
bool isRaster() const {
return isTraditionalRaster() || isMeshShaderRaster();
}
bool isRayTracing() const { return Kind == ShaderPipelineKind::RayTracing; }
};
} // namespace offloadtest

Expand All @@ -627,6 +703,8 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::AABBGeometry)
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::BLASDesc)
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::InstanceDesc)
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::TLASDesc)
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::HitGroup)
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::SBTEntry)

namespace llvm {
namespace yaml {
Expand Down Expand Up @@ -735,6 +813,22 @@ template <> struct MappingTraits<offloadtest::AccelerationStructureDescs> {
static void mapping(IO &I, offloadtest::AccelerationStructureDescs &D);
};

template <> struct MappingTraits<offloadtest::HitGroup> {
static void mapping(IO &I, offloadtest::HitGroup &G);
};

template <> struct MappingTraits<offloadtest::RayTracingPipelineConfig> {
static void mapping(IO &I, offloadtest::RayTracingPipelineConfig &C);
};

template <> struct MappingTraits<offloadtest::SBTEntry> {
static void mapping(IO &I, offloadtest::SBTEntry &E);
};

template <> struct MappingTraits<offloadtest::ShaderBindingTableDesc> {
static void mapping(IO &I, offloadtest::ShaderBindingTableDesc &S);
};

template <> struct ScalarEnumerationTraits<offloadtest::Rule> {
static void enumeration(IO &I, offloadtest::Rule &V) {
#define ENUM_CASE(Val) I.enumCase(V, #Val, offloadtest::Rule::Val)
Expand Down Expand Up @@ -886,6 +980,21 @@ template <> struct ScalarEnumerationTraits<offloadtest::Stages> {
ENUM_CASE(Pixel);
ENUM_CASE(Amplification);
ENUM_CASE(Mesh);
ENUM_CASE(RayGeneration);
ENUM_CASE(Miss);
ENUM_CASE(ClosestHit);
ENUM_CASE(AnyHit);
ENUM_CASE(Intersection);
ENUM_CASE(Callable);
#undef ENUM_CASE
}
};

template <> struct ScalarEnumerationTraits<offloadtest::HitGroupType> {
static void enumeration(IO &I, offloadtest::HitGroupType &V) {
#define ENUM_CASE(Val) I.enumCase(V, #Val, offloadtest::HitGroupType::Val)
ENUM_CASE(Triangles);
ENUM_CASE(Procedural);
#undef ENUM_CASE
}
};
Expand Down
Loading
Loading