diff --git a/backend/compiler.py b/backend/compiler.py index 804bdfdf..b5d0397f 100644 --- a/backend/compiler.py +++ b/backend/compiler.py @@ -120,7 +120,12 @@ class DICPBackend(BaseBackend): def __init__(self, target: str) -> None: super().__init__(target) self.driver = DICPDriver(target) - if self.driver.target == "dicp": + if self.driver.is_cpu_verify: + from .cpu_backend import CPUBackend + + self._cpu_backend = CPUBackend(target) + self.binary_ext = "obj" + elif self.driver.target == "dicp": self.binary_ext = "ttlinalgdir" elif self.driver.target == "mlu": self.capability = target.arch @@ -136,7 +141,7 @@ def __init__(self, target: str) -> None: @staticmethod def supports_target(target: GPUTarget): - return target.backend in ["dicp", "mlu", "maca", "ascend"] + return target.backend in ["dicp", "mlu", "maca", "ascend", "cpu"] @staticmethod def make_ttir(mod, metadata, opt): @@ -260,13 +265,31 @@ def add_stages(self, stages, options, language=None): ) ) stages["linkedir"] = lambda src, metadata: ttsharedir_to_linkedir( - src, metadata, options, named_ops=True + src, + metadata, + options, + named_ops=True, + cpu_verify=self.driver.is_cpu_verify, ) - stages["npubin"] = ( - lambda src, metadata: linalg_to_bin_enable_npu_compile( - src, metadata, options + if self.driver.is_cpu_verify: + from .cpu_backend import ( + _ttsharedir_to_llir, + _llir_to_bin, + _optimize_llir, + ) + + stages["llir"] = lambda src, metadata: _optimize_llir( + _ttsharedir_to_llir(src, metadata) + ) + stages["obj"] = lambda src, metadata: _llir_to_bin( + src, metadata + ) + else: + stages["npubin"] = ( + lambda src, metadata: linalg_to_bin_enable_npu_compile( + src, metadata, options + ) ) - ) else: raise RuntimeError("backend not supported") @@ -286,6 +309,8 @@ def get_driver(self): # parse add_kernel[(16,)](x, y, output, n_elements, BLOCK_SIZE=1024) def parse_options(self, options: dict) -> Any: + if self.driver.is_cpu_verify: + return self._cpu_backend.parse_options(options) if 
self.target.backend == "ascend": from triton.backends.dicp_triton.npu import NPUOptions @@ -360,7 +385,9 @@ def parse_options(self, options: dict) -> Any: def get_codegen_implementation(self, options=None): codegen_fns = dict() - if self.target.backend == "ascend": + if self.driver.is_cpu_verify: + return self._cpu_backend.get_codegen_implementation(options) + elif self.target.backend == "ascend": from triton.backends.dicp_triton.npu import min_dot_size codegen_fns = {"min_dot_size": min_dot_size(self.target)} @@ -384,6 +411,8 @@ def get_codegen_implementation(self, options=None): return codegen_fns def pack_metadata(self, metadata): + if self.driver.is_cpu_verify: + return self._cpu_backend.pack_metadata(metadata) if self.target.backend == "ascend": from triton.backends.dicp_triton.npu import TRITON_PROFILER_REGISTERED diff --git a/backend/cpu_backend.py b/backend/cpu_backend.py new file mode 100644 index 00000000..c8a1a80b --- /dev/null +++ b/backend/cpu_backend.py @@ -0,0 +1,780 @@ +# CPU Backend for verification +# Merged from triton_shared backend/compiler.py and backend/driver.py + +from triton.backends.compiler import BaseBackend, GPUTarget +from triton._C.libtriton import ir, passes +from dataclasses import dataclass +from typing import Any, Dict, Tuple +from types import ModuleType +import hashlib +import tempfile +import os +import re +import shutil +import subprocess +import functools +import triton +from pathlib import Path +from triton.runtime.cache import get_cache_manager +from triton.backends.driver import DriverBase +import sysconfig +import importlib.util +import sys +import platform +import triton.backends.dicp_triton.utils as dicp_utils + +dump_ir = os.environ.get("DLC_DUMP_IR", "0") == "1" + + +def _get_triton_shared_opt_path() -> str: + path = os.getenv("TRITON_SHARED_OPT_PATH", "") + if path == "": + raise Exception("TRITON_SHARED_OPT_PATH is not set.") + return path + + +def _get_llvm_bin_path(bin_name: str) -> str: + path = 
os.getenv("LLVM_BINARY_DIR", "") + if path == "": + raise Exception("LLVM_BINARY_DIR is not set.") + return os.path.join(path, bin_name) + + +def _dump_ir_if_needed(files): + path = os.getenv("TRITON_SHARED_DUMP_PATH", "") + if not path: + return + for f in files: + shutil.copy(f, os.path.join(path, os.path.basename(f))) + + +def _get_sanitizer_type(): + sanitizer_type = os.getenv("TRITON_SHARED_SANITIZER_TYPE", "") + if sanitizer_type != "" and sanitizer_type != "asan" and sanitizer_type != "tsan": + raise Exception(f"TRITON_SHARED_SANITIZER_TYPE {sanitizer_type} is invalid.") + return sanitizer_type + + +def _ttir_to_ttsharedir(mod, metadata): + ttir_code = str(mod) + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "tt.mlir") + dst_path = os.path.join(tmpdir, "ttshared.mlir") + Path(src_path).write_text(ttir_code) + triton_shared_opt_path = _get_triton_shared_opt_path() + subprocess_args = [ + triton_shared_opt_path, + src_path, + "--triton-to-linalg-experimental", + "--mlir-print-debuginfo", + "-o", + dst_path, + ] + if _get_sanitizer_type() != "": + print("Building with sanitizer support...") + subprocess_args.insert(2, "--add-llvm-debug-info") + subprocess.check_call(subprocess_args) + result = Path(dst_path).read_text() + if dump_ir: + dicp_utils._dump_stage_ir( + result, metadata["hash"], "kernel.ttsharedir.mlir" + ) + return result + + +def _optimize_ttsharedir(ttsharedir: str): + return ttsharedir + + +def _ttsharedir_to_llir(ttsharedir: str, metadata): + with tempfile.TemporaryDirectory() as tmpdir: + ttshared_path = os.path.join(tmpdir, "ttshared.mlir") + llmlir_path = os.path.join(tmpdir, "ll.mlir") + llir_path = os.path.join(tmpdir, "ll.ir") + Path(ttshared_path).write_text(ttsharedir) + mlir_opt_path = _get_llvm_bin_path("mlir-opt") + subprocess.check_call( + [ + mlir_opt_path, + ttshared_path, + "--convert-elementwise-to-linalg", + "--convert-linalg-to-affine-loops", + "--empty-tensor-to-alloc-tensor", + 
"--one-shot-bufferize=allow-return-allocs-from-loops=true", + "--lower-affine", + "--convert-linalg-to-loops", + "--expand-strided-metadata", + "--convert-scf-to-cf", + "--convert-arith-to-llvm", + "--convert-math-to-llvm", + "--convert-complex-to-llvm", + "--convert-vector-to-llvm", + "--convert-index-to-llvm", + "--memref-expand", + "--finalize-memref-to-llvm", + "--convert-func-to-llvm", + "--convert-cf-to-llvm", + "--lower-affine", + "--convert-arith-to-llvm", + "--reconcile-unrealized-casts", + "--mlir-print-debuginfo", + "-o", + llmlir_path, + ] + ) + mlir_translate_path = _get_llvm_bin_path("mlir-translate") + subprocess.check_call( + [mlir_translate_path, llmlir_path, "--mlir-to-llvmir", "-o", llir_path] + ) + result = Path(llir_path).read_text() + if dump_ir: + dicp_utils._dump_stage_ir(result, metadata["hash"], "kernel.llir.mlir") + return result + + +def _optimize_llir(llir: str): + return llir + + +def _llir_to_bin(llir: str, metadata): + pattern = r"define void @(\w+)\(.+" + matches = re.findall(pattern, llir) + assert len(matches) == 1 + metadata["name"] = matches[0] + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "kernel.ll") + dst_path = os.path.join(tmpdir, "kernel.o") + Path(src_path).write_text(llir) + sanitizer_type = _get_sanitizer_type() + if sanitizer_type != "": + instrumented_src_path = os.path.join(tmpdir, "kernel-instrumented.ll") + opt_path = _get_llvm_bin_path("opt") + top_level_triton_path = os.path.dirname(triton.__file__) + sanitizer_attributes_pass_path = str( + next( + Path(top_level_triton_path).rglob("libSanitizerAttributes.so"), None + ) + ) + if not sanitizer_attributes_pass_path: + raise Exception(f"libSanitizerAttributes.so does not exist.") + subprocess.check_call( + [ + opt_path, + "-load-pass-plugin", + sanitizer_attributes_pass_path, + "-passes=sanitizer-attributes", + f"-sanitizer-type={sanitizer_type}", + "-S", + src_path, + "-o", + instrumented_src_path, + ] + ) + clang_path = 
_get_llvm_bin_path("clang++") + subprocess_args = [clang_path, "-c", instrumented_src_path, "-o", dst_path] + if sanitizer_type == "asan": + subprocess_args.extend( + ["-g", "-fsanitize=address", "-mllvm", "-asan-stack=0"] + ) + elif sanitizer_type == "tsan": + subprocess_args.extend(["-g", "-fsanitize=thread"]) + subprocess.check_call(subprocess_args) + else: + llc_path = _get_llvm_bin_path("llc") + subprocess.check_call( + [ + llc_path, + src_path, + "-filetype=obj", + "-relocation-model=pic", + "-o", + dst_path, + ] + ) + return Path(dst_path).read_bytes() + + +# -------------------- Compiler -------------------- + + +@dataclass(frozen=True) +class CPUOptions: + debug: bool = False + arch: str = None + + num_warps: int = -1 + num_ctas: int = -1 + num_stages: int = 2 + + enable_warp_specialization: bool = False + enable_fp_fusion: bool = True + + extern_libs: dict = None + + cluster_dims: tuple = (1, 1, 1) + shared: bool = False + + supported_fp8_dtypes: Tuple[str] = () + allow_fp8e4nv: bool = False + + allowed_dot_input_precisions: Tuple[str] = ("ieee", "hf32") + + sanitize_overflow: bool = True + enable_npu_compile: bool = True + + kernel_name: str = "triton_" + + num_buffers_warp_spec: int = 0 + num_consumer_groups: int = 0 + reg_dec_producer: int = 0 + reg_inc_consumer: int = 0 + + enable_nd2nz_on_vector: bool = False + enable_persistent: bool = False + optimize_epilogue: bool = False + + max_num_imprecise_acc_default: bool = None + multibuffer: bool = True + + inject_barrier_all: bool = False + disable_auto_inject_block_sync: bool = False + unit_flag: bool = False + + disable_auto_cv_work_space_manage: bool = False + enable_auto_bind_sub_block: bool = True + + tile_mix_vector_loop: int = None + tile_mix_cube_loop: int = None + + limit_auto_multi_buffer_only_for_local_buffer: bool = None + set_workspace_multibuffer: int = None + + stream: int = None + + def __post_init__(self): + pass + + def hash(self): + key = "_".join([f"{name}-{val}" for name, val in 
self.__dict__.items()]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class CPUBackend(BaseBackend): + binary_ext = "obj" + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == "cpu" + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + + def parse_options(self, opts) -> Any: + args = {"arch": self.target.arch} + args.update( + {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} + ) + return CPUOptions(**args) + + def get_codegen_implementation(self, options): + codegen_fns = {"min_dot_size": lambda lhsType, rhsType: (1, 1, 1)} + return codegen_fns + + def pack_metadata(self, metadata): + return ( + metadata.num_warps, + metadata.num_ctas, + metadata.shared, + metadata.cluster_dims[0], + metadata.cluster_dims[1], + metadata.cluster_dims[2], + metadata.name, + ) + + def load_dialects(self, ctx): + return + + @staticmethod + def make_ttir(mod, metadata, options): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) + passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_combine(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.ttir.add_triton_licm(pm) + passes.common.add_symbol_dce(pm) + passes.ttir.add_loop_unroll(pm) + passes.common.add_cse(pm) + pm.run(mod) + if dump_ir: + dicp_utils._dump_stage_ir(str(mod), metadata["hash"], "kernel.ttir.mlir") + return mod + + def add_stages(self, stages, options, language): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttsharedir"] = lambda src, metadata: _optimize_ttsharedir( + _ttir_to_ttsharedir(src, metadata) + ) + stages["llir"] = lambda src, metadata: _optimize_llir( + _ttsharedir_to_llir(src, metadata) + ) + stages["obj"] = lambda src, metadata: _llir_to_bin(src, metadata) + + @functools.lru_cache() + def hash(self): 
+ return self.target + + def get_module_map(self) -> Dict[str, ModuleType]: + return {} + + +# -------------------- Driver -------------------- + + +def _ty_to_cpp(ty): + if ty[0] == "*": + return "void*" + if ty == "constexpr": + return "PyObject*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def _extracted_type(ty): + if ty[0] == "*": + return "PyObject*" + if ty == "constexpr": + return "PyObject*" + return _ty_to_cpp(ty) + + +def _format_of(ty): + return { + "PyObject*": "O", + "constexpr": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + + +def _generate_launcher(constants, signature, kernel_name): + arg_decls = ", ".join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + args_format = "".join( + [_format_of(_extracted_type(ty)) for ty in signature.values()] + ) + format = "iiiOOOO" + args_format + args_list = ( + ", " + ", ".join(f"&_arg{i}" for i, ty in signature.items()) + if len(signature) > 0 + else "" + ) + + kernel_arg_decls = ", ".join( + _ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" + for i, ty in signature.items() + if ty != "constexpr" + ) + kernel_arg_decls += ", " if kernel_arg_decls else "" + + kernel_parameters = ", ".join( + f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" + for i, ty in signature.items() + if ty != "constexpr" + ) + kernel_parameters += ", " if kernel_parameters else "" + + return f""" +#include +#include +#include +#include "ExecutionEngine/CRunnerUtils.h" +#include "ExecutionEngine/CRunnerUtils.cpp" + +extern "C" {{ + // Pointer type (=Memref) becomes int64_t + MemRef struct + // FIXME: 
understand what this int64_t is used for. + void {kernel_name}({kernel_arg_decls} + int, int, int, int, int, int); +}} + +static void _launch(int gridX, int gridY, int gridZ, {arg_decls}) {{ + if (gridX*gridY*gridZ > 0) {{ + // Cast "function" to the real function type. + // apply parallelization to the triton grid when using ThreadSanitizer (TSan) + // to help detect potential data races across program instances during kernel execution + {"#pragma omp parallel for collapse(3)" if _get_sanitizer_type() == "tsan" else ""} + for(int x = 0; x < gridX; x++) {{ + for(int y = 0; y < gridY; y++) {{ + for(int z = 0; z < gridZ; z++) {{ + // Use some random type "char" here. + {' '.join(f'StridedMemRefType ptr_arg{i} = {{static_cast(arg{i}), static_cast(arg{i}), 0}};' for i, ty in signature.items() if i not in constants and ty[0] == "*")} + {kernel_name}({kernel_parameters} + gridX, gridY, gridZ, x, y, z); + }} + }} + }} + }} +}} + +typedef struct _DevicePtrInfo {{ + void *dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(obj)); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(ret)); + if(!ptr_info.dev_ptr) + return ptr_info; + Py_DECREF(ret); // Thanks ChatGPT! 
+ return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {args_list})) {{ + return NULL; + }} + + // [CPULauncher-specific]: We don't need the metadata below but just put them + // here anyway to be consistent with others. + // This will make updating the driver easier in the future. + + // int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + // if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + // PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + // return NULL; + // }} + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + _launch(gridX, gridY, gridZ, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); + + if (PyErr_Occurred()) {{ + return NULL; + }} + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + // 
return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_shared_ref_cpu_kernel_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_shared_ref_cpu_kernel_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + + +def compile_module(launcher_src, kernel_placeholder_name): + py_version = sys.version_info + if platform.system() == "Windows": + py_include_dir = os.path.join(sys.base_prefix, "include") + py_lib_dir = os.path.join(sys.base_prefix, "libs") + py_lib = "{name}{major}{minor}.lib".format( + name="python", major=py_version.major, minor=py_version.minor + ) + else: + py_include_dir = os.path.join( + sys.base_prefix, + "include", + f"python{sys.version_info.major}.{sys.version_info.minor}", + ) + py_lib_dir = os.path.join(sys.base_prefix, "lib") + py_lib = "{name}{major}.{minor}".format( + name="python", major=py_version.major, minor=py_version.minor + ) + cpu_backend_path = Path(__file__).resolve().parent + include_dir = os.path.join(cpu_backend_path, "include") + + def launch( + gridX, + gridY, + gridZ, + stream, + cu_function, + kernel_metadata, + launch_metadata, + launch_enter_hook, + launch_exit_hook, + *args, + ): + kernel_obj = cu_function + kernel_name = kernel_metadata[6] + src = launcher_src.replace(kernel_placeholder_name, kernel_name) + key = hashlib.sha256(src.encode("utf-8") + kernel_obj).hexdigest() + cache = get_cache_manager(key) + name = "__triton_shared_ref_cpu_kernel_launcher" + if platform.system() == "Windows": + filename = f"{name}.pyd" + else: + filename = f"{name}.so" + cache_path = cache.get_file(filename) + if cache_path is 
None: + with tempfile.TemporaryDirectory() as tmpdir: + sanitizer_type = _get_sanitizer_type() + if platform.system() == "Windows": + if sanitizer_type != "": + raise Exception( + "Sanitizers are not supported on Windows with triton-shared." + ) + obj_path = os.path.join(tmpdir, "kernel.obj") + launcher_src_path = os.path.join(tmpdir, "main.cxx") + so_path = os.path.join(tmpdir, "kernel.pyd") + Path(obj_path).write_bytes(kernel_obj) + Path(launcher_src_path).write_text(src) + subprocess.check_call( + [ + "cl", + "/LD", + "/std:c++17", + launcher_src_path, + obj_path, + f"-I{py_include_dir}", + f"-I{include_dir}", + "/link", + f"/LIBPATH:{py_lib_dir}", + "/link", + f"{py_lib}", + f"/OUT:{so_path}", + ] + ) + else: + obj_path = os.path.join(tmpdir, "kernel.o") + launcher_src_path = os.path.join(tmpdir, "main.cxx") + so_path = os.path.join(tmpdir, "kernel.so") + Path(obj_path).write_bytes(kernel_obj) + Path(launcher_src_path).write_text(src) + if sanitizer_type != "": + clang_path = _get_llvm_bin_path("clang++") + subprocess_args = [ + clang_path, + "-std=c++17", + launcher_src_path, + obj_path, + f"-I{py_include_dir}", + f"-I{include_dir}", + f"-L{py_lib_dir}", + "-shared", + f"-l{py_lib}", + "-fPIC", + "-o", + so_path, + ] + if sanitizer_type == "asan": + subprocess_args.extend( + ["-g", "-fsanitize=address", "-mllvm", "-asan-stack=0"] + ) + elif sanitizer_type == "tsan": + subprocess_args.extend(["-g", "-fsanitize=thread"]) + subprocess.check_call(subprocess_args) + else: + subprocess.check_call( + [ + "g++", + "-std=c++17", + launcher_src_path, + obj_path, + f"-I{py_include_dir}", + f"-I{include_dir}", + f"-L{py_lib_dir}", + "-shared", + f"-l{py_lib}", + "-fPIC", + "-o", + so_path, + ] + ) + with open(so_path, "rb") as f: + cache_path = cache.put(f.read(), filename, binary=True) + spec = importlib.util.spec_from_file_location(name, cache_path) + if spec is None: + raise RuntimeError(f"Cannot find {name} module in {cache_path}") + mod = 
importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod.launch( + gridX, + gridY, + gridZ, + kernel_metadata, + launch_metadata, + launch_enter_hook, + launch_exit_hook, + *args, + ) + + return launch + + +class CPULauncher(object): + def __init__(self, src, metadata): + kernel_placeholder_name = "KERNEL_NAME_PLACEHOLDER" + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + launcher_src = _generate_launcher(constants, signature, kernel_placeholder_name) + self.launch = compile_module(launcher_src, kernel_placeholder_name) + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class CPUUtils(object): + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(CPUUtils, cls).__new__(cls) + return cls.instance + + @staticmethod + def get_device_properties(device): + return { + "max_shared_mem": 2**20, + "multiprocessor_count": None, + "sm_clock_rate": None, + "mem_clock_rate": None, + "mem_bus_width": None, + } + + @staticmethod + def load_binary(name, kernel_obj, shared, device): + return ( + None, + kernel_obj, + None, + None, + ) + + +class CPUDriver(DriverBase): + def __init__(self): + super().__init__() + self.utils = CPUUtils() + self.launcher_cls = CPULauncher + self.binary_ext = "obj" + + @staticmethod + def is_active(): + return False + + def get_benchmarker(self): + from triton.testing import do_bench + + return do_bench + + def get_device_capability(self): + return ("cpu", 0) + + def get_current_stream(self, device): + return None + + def get_current_device(self): + return "cpu" + + def set_current_device(self, device): + assert device == "cpu" + return + + def get_current_target(self): + return GPUTarget("cpu", 0, 0) + + def 
get_active_torch_device(self): + import torch + + return torch.device("cpu") + + def assemble_tensormap_to_arg(self, tensormaps_info, args): + return args + + def map_python_to_cpp_type(self, ty: str) -> str: + return _ty_to_cpp(ty) diff --git a/backend/driver.py b/backend/driver.py index ae564af7..4bf69b20 100644 --- a/backend/driver.py +++ b/backend/driver.py @@ -129,6 +129,8 @@ def __init__(self, target=None): return self.__initialized = True super().__init__() + self.is_cpu_verify = os.environ.get("DLC_CPU_VERIFY", "0") == "1" + if target == "mlu": from triton.backends.dicp_triton.mlu import BangLauncher, BangUtils @@ -166,6 +168,12 @@ def __init__(self, target=None): self.launcher_cls = CudaLauncher else: self.target = "dicp" + if self.is_cpu_verify: + from .cpu_backend import CPUUtils, CPULauncher, CPUDriver + + self.utils = CPUUtils() + self.launcher_cls = CPULauncher + self._cpu_driver = CPUDriver() def __new__(cls, target=None): if not hasattr(cls, "instance"): @@ -231,6 +239,8 @@ def launch_as_union_task(self, device, grid): ) def get_device_capability(self): + if self.is_cpu_verify: + return self._cpu_driver.get_device_capability() if self.target == "mlu": return ("mlu", 0) elif self.target == "maca": @@ -303,6 +313,8 @@ def set_current_device(self, device): return def get_current_target(self): + if self.is_cpu_verify: + return self._cpu_driver.get_current_target() if self.target == "mlu": device = self.get_current_device() capability = self.utils.get_device_properties(device).get("isa_version") @@ -357,9 +369,10 @@ def get_empty_cache_for_benchmark(self): assert False, f"Not implemented for {self.target}" def get_active_torch_device(self): - # todo: fix it. 
import torch + if self.is_cpu_verify: + return self._cpu_driver.get_active_torch_device() return torch.device("cpu") def map_python_to_cpp_type(self, ty: str) -> str: diff --git a/backend/include/ExecutionEngine/CRunnerUtils.cpp b/backend/include/ExecutionEngine/CRunnerUtils.cpp new file mode 100644 index 00000000..87e47027 --- /dev/null +++ b/backend/include/ExecutionEngine/CRunnerUtils.cpp @@ -0,0 +1,189 @@ +//===- CRunnerUtils.cpp - Utils for MLIR execution ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements basic functions to manipulate structured MLIR types at +// runtime. Entities in this file are meant to be retargetable, including on +// targets without a C++ runtime, and must be kept C compatible. +// +//===----------------------------------------------------------------------===// + +#include "CRunnerUtils.h" +#include "Msan.h" + +#ifndef _WIN32 +#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \ + defined(__DragonFly__) +#include +#else +#include +#endif +#include +#else +#include "malloc.h" +#endif // _WIN32 + +#include +#include +#include +#include +#include +#include + +#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS + +namespace { +template void stdSort(uint64_t n, V *p) { std::sort(p, p + n); } + +} // namespace + +// Small runtime support "lib" for vector.print lowering. +// By providing elementary printing methods only, this +// library can remain fully unaware of low-level implementation +// details of our vectors. Also useful for direct LLVM IR output. 
+extern "C" void printI64(int64_t i) { fprintf(stdout, "%" PRId64, i); } +extern "C" void printU64(uint64_t u) { fprintf(stdout, "%" PRIu64, u); } +extern "C" void printF32(float f) { fprintf(stdout, "%g", f); } +extern "C" void printF64(double d) { fprintf(stdout, "%lg", d); } +extern "C" void printString(char const *s) { fputs(s, stdout); } +extern "C" void printOpen() { fputs("( ", stdout); } +extern "C" void printClose() { fputs(" )", stdout); } +extern "C" void printComma() { fputs(", ", stdout); } +extern "C" void printNewline() { fputc('\n', stdout); } + +extern "C" void memrefCopy(int64_t elemSize, UnrankedMemRefType *srcArg, + UnrankedMemRefType *dstArg) { + DynamicMemRefType src(*srcArg); + DynamicMemRefType dst(*dstArg); + + int64_t rank = src.rank; + MLIR_MSAN_MEMORY_IS_INITIALIZED(src.sizes, rank * sizeof(int64_t)); + + // Handle empty shapes -> nothing to copy. + for (int rankp = 0; rankp < rank; ++rankp) + if (src.sizes[rankp] == 0) + return; + + char *srcPtr = src.data + src.offset * elemSize; + char *dstPtr = dst.data + dst.offset * elemSize; + + if (rank == 0) { + memcpy(dstPtr, srcPtr, elemSize); + return; + } + + int64_t *indices = static_cast(alloca(sizeof(int64_t) * rank)); + int64_t *srcStrides = static_cast(alloca(sizeof(int64_t) * rank)); + int64_t *dstStrides = static_cast(alloca(sizeof(int64_t) * rank)); + + // Initialize index and scale strides. + for (int rankp = 0; rankp < rank; ++rankp) { + indices[rankp] = 0; + srcStrides[rankp] = src.strides[rankp] * elemSize; + dstStrides[rankp] = dst.strides[rankp] * elemSize; + } + + int64_t readIndex = 0, writeIndex = 0; + for (;;) { + // Copy over the element, byte by byte. + memcpy(dstPtr + writeIndex, srcPtr + readIndex, elemSize); + // Advance index and read position. + for (int64_t axis = rank - 1; axis >= 0; --axis) { + // Advance at current axis. 
+ auto newIndex = ++indices[axis]; + readIndex += srcStrides[axis]; + writeIndex += dstStrides[axis]; + // If this is a valid index, we have our next index, so continue copying. + if (src.sizes[axis] != newIndex) + break; + // We reached the end of this axis. If this is axis 0, we are done. + if (axis == 0) + return; + // Else, reset to 0 and undo the advancement of the linear index that + // this axis had. Then continue with the axis one outer. + indices[axis] = 0; + readIndex -= src.sizes[axis] * srcStrides[axis]; + writeIndex -= dst.sizes[axis] * dstStrides[axis]; + } + } +} + +/// Prints GFLOPS rating. +extern "C" void printFlops(double flops) { + fprintf(stderr, "%lf GFLOPS\n", flops / 1.0E9); +} + +/// Returns the number of seconds since Epoch 1970-01-01 00:00:00 +0000 (UTC). +extern "C" double rtclock() { +#ifndef _WIN32 + struct timeval tp; + int stat = gettimeofday(&tp, nullptr); + if (stat != 0) + fprintf(stderr, "Error returning time from gettimeofday: %d\n", stat); + return (tp.tv_sec + tp.tv_usec * 1.0e-6); +#else + fprintf(stderr, "Timing utility not implemented on Windows\n"); + return 0.0; +#endif // _WIN32 +} + +extern "C" void *mlirAlloc(uint64_t size) { return malloc(size); } + +extern "C" void *mlirAlignedAlloc(uint64_t alignment, uint64_t size) { +#ifdef _WIN32 + return _aligned_malloc(size, alignment); +#elif defined(__APPLE__) + // aligned_alloc was added in MacOS 10.15. Fall back to posix_memalign to also + // support older versions. + void *result = nullptr; + (void)::posix_memalign(&result, alignment, size); + return result; +#else + return aligned_alloc(alignment, size); +#endif +} + +extern "C" void mlirFree(void *ptr) { free(ptr); } + +extern "C" void mlirAlignedFree(void *ptr) { +#ifdef _WIN32 + _aligned_free(ptr); +#else + free(ptr); +#endif +} + +extern "C" void *rtsrand(uint64_t s) { + // Standard mersenne_twister_engine seeded with s. 
+ return new std::mt19937(s); +} + +extern "C" uint64_t rtrand(void *g, uint64_t m) { + std::mt19937 *generator = static_cast(g); + std::uniform_int_distribution distrib(0, m); + return distrib(*generator); +} + +extern "C" void rtdrand(void *g) { + std::mt19937 *generator = static_cast(g); + delete generator; +} + +#define IMPL_STDSORT(VNAME, V) \ + extern "C" void _mlir_ciface_stdSort##VNAME(uint64_t n, \ + StridedMemRefType *vref) { \ + assert(vref); \ + assert(vref->strides[0] == 1); \ + V *values = vref->data + vref->offset; \ + stdSort(n, values); \ + } +IMPL_STDSORT(I64, int64_t) +IMPL_STDSORT(F64, double) +IMPL_STDSORT(F32, float) +#undef IMPL_STDSORT + +#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS diff --git a/backend/include/ExecutionEngine/CRunnerUtils.h b/backend/include/ExecutionEngine/CRunnerUtils.h new file mode 100644 index 00000000..1e55ca92 --- /dev/null +++ b/backend/include/ExecutionEngine/CRunnerUtils.h @@ -0,0 +1,482 @@ +//===- CRunnerUtils.h - Utils for debugging MLIR execution ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares basic classes and functions to manipulate structured MLIR +// types at runtime. Entities in this file must be compliant with C++11 and be +// retargetable, including on targets without a C++ runtime. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXECUTIONENGINE_CRUNNERUTILS_H +#define MLIR_EXECUTIONENGINE_CRUNNERUTILS_H + +#ifdef _WIN32 +#ifndef MLIR_CRUNNERUTILS_EXPORT +#ifdef mlir_c_runner_utils_EXPORTS +// We are building this library +#define MLIR_CRUNNERUTILS_EXPORT __declspec(dllexport) +#define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS +#else +// We are using this library +#define MLIR_CRUNNERUTILS_EXPORT __declspec(dllimport) +#endif // mlir_c_runner_utils_EXPORTS +#endif // MLIR_CRUNNERUTILS_EXPORT +#else // _WIN32 +// Non-windows: use visibility attributes. +#define MLIR_CRUNNERUTILS_EXPORT __attribute__((visibility("default"))) +#define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS +#endif // _WIN32 + +#include +#include +#include +#include +#include + +//===----------------------------------------------------------------------===// +// Codegen-compatible structures for Vector type. +//===----------------------------------------------------------------------===// +namespace mlir { +namespace detail { + +constexpr bool isPowerOf2(int n) { return (!(n & (n - 1))); } + +constexpr unsigned nextPowerOf2(int n) { + return (n <= 1) ? 1 : (isPowerOf2(n) ? n : (2 * nextPowerOf2((n + 1) / 2))); +} + +template struct Vector1D; + +template struct Vector1D { + Vector1D() { + static_assert(detail::nextPowerOf2(sizeof(T[Dim])) == sizeof(T[Dim]), + "size error"); + } + inline T &operator[](unsigned i) { return vector[i]; } + inline const T &operator[](unsigned i) const { return vector[i]; } + +private: + T vector[Dim]; +}; + +// 1-D vector, padded to the next power of 2 allocation. +// Specialization occurs to avoid zero size arrays (which fail in -Werror). 
+template struct Vector1D { + Vector1D() { + static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error"); + static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]), + "size error"); + } + inline T &operator[](unsigned i) { return vector[i]; } + inline const T &operator[](unsigned i) const { return vector[i]; } + +private: + T vector[Dim]; + char padding[nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])]; +}; +} // namespace detail +} // namespace mlir + +// N-D vectors recurse down to 1-D. +template struct Vector { + inline Vector &operator[](unsigned i) { return vector[i]; } + inline const Vector &operator[](unsigned i) const { + return vector[i]; + } + +private: + Vector vector[Dim]; +}; + +// 1-D vectors in LLVM are automatically padded to the next power of 2. +// We insert explicit padding in to account for this. +template +struct Vector + : public mlir::detail::Vector1D { +}; + +template using Vector1D = Vector; +template using Vector2D = Vector; +template +using Vector3D = Vector; +template +using Vector4D = Vector; + +template void dropFront(int64_t arr[N], int64_t *res) { + for (unsigned i = 1; i < N; ++i) + *(res + i - 1) = arr[i]; +} + +//===----------------------------------------------------------------------===// +// Codegen-compatible structures for StridedMemRef type. +//===----------------------------------------------------------------------===// +template class StridedMemrefIterator; + +/// StridedMemRef descriptor type with static rank. 
+template struct StridedMemRefType { + T *basePtr; + T *data; + int64_t offset; + int64_t sizes[N]; + int64_t strides[N]; + + template ().begin())> + T &operator[](Range &&indices) { + assert(indices.size() == N && + "indices should match rank in memref subscript"); + int64_t curOffset = offset; + for (int dim = N - 1; dim >= 0; --dim) { + int64_t currentIndex = *(indices.begin() + dim); + assert(currentIndex < sizes[dim] && "Index overflow"); + curOffset += currentIndex * strides[dim]; + } + return data[curOffset]; + } + + StridedMemrefIterator begin() { return {*this, offset}; } + StridedMemrefIterator end() { return {*this, -1}; } + + // This operator[] is extremely slow and only for sugaring purposes. + StridedMemRefType operator[](int64_t idx) { + StridedMemRefType res; + res.basePtr = basePtr; + res.data = data; + res.offset = offset + idx * strides[0]; + dropFront(sizes, res.sizes); + dropFront(strides, res.strides); + return res; + } +}; + +/// StridedMemRef descriptor type specialized for rank 1. +template struct StridedMemRefType { + T *basePtr; + T *data; + int64_t offset; + int64_t sizes[1]; + int64_t strides[1]; + + template ().begin())> + T &operator[](Range indices) { + assert(indices.size() == 1 && + "indices should match rank in memref subscript"); + return (*this)[*indices.begin()]; + } + + StridedMemrefIterator begin() { return {*this, offset}; } + StridedMemrefIterator end() { return {*this, -1}; } + + T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); } +}; + +/// StridedMemRef descriptor type specialized for rank 0. 
+template struct StridedMemRefType { + T *basePtr; + T *data; + int64_t offset; + + template ().begin())> + T &operator[](Range indices) { + assert((indices.size() == 0) && + "Expect empty indices for 0-rank memref subscript"); + return data[offset]; + } + + StridedMemrefIterator begin() { return {*this, offset}; } + StridedMemrefIterator end() { return {*this, offset + 1}; } +}; + +/// Iterate over all elements in a strided memref. +template class StridedMemrefIterator { +public: + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T *; + using reference = T &; + + StridedMemrefIterator(StridedMemRefType &descriptor, + int64_t offset = 0) + : offset(offset), descriptor(&descriptor) {} + StridedMemrefIterator &operator++() { + int dim = Rank - 1; + while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) { + offset -= indices[dim] * descriptor->strides[dim]; + indices[dim] = 0; + --dim; + } + if (dim < 0) { + offset = -1; + return *this; + } + ++indices[dim]; + offset += descriptor->strides[dim]; + return *this; + } + + reference operator*() { return descriptor->data[offset]; } + pointer operator->() { return &descriptor->data[offset]; } + + const std::array &getIndices() { return indices; } + + bool operator==(const StridedMemrefIterator &other) const { + return other.offset == offset && other.descriptor == descriptor; + } + + bool operator!=(const StridedMemrefIterator &other) const { + return !(*this == other); + } + +private: + /// Offset in the buffer. This can be derived from the indices and the + /// descriptor. + int64_t offset = 0; + + /// Array of indices in the multi-dimensional memref. + std::array indices = {}; + + /// Descriptor for the strided memref. + StridedMemRefType *descriptor; +}; + +/// Iterate over all elements in a 0-ranked strided memref. 
+template class StridedMemrefIterator { +public: + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T *; + using reference = T &; + + StridedMemrefIterator(StridedMemRefType &descriptor, int64_t offset = 0) + : elt(descriptor.data + offset) {} + + StridedMemrefIterator &operator++() { + ++elt; + return *this; + } + + reference operator*() { return *elt; } + pointer operator->() { return elt; } + + // There are no indices for a 0-ranked memref, but this API is provided for + // consistency with the general case. + const std::array &getIndices() { + // Since this is a 0-array of indices we can keep a single global const + // copy. + static const std::array indices = {}; + return indices; + } + + bool operator==(const StridedMemrefIterator &other) const { + return other.elt == elt; + } + + bool operator!=(const StridedMemrefIterator &other) const { + return !(*this == other); + } + +private: + /// Pointer to the single element in the zero-ranked memref. + T *elt; +}; + +//===----------------------------------------------------------------------===// +// Codegen-compatible structure for UnrankedMemRef type. +//===----------------------------------------------------------------------===// +// Unranked MemRef +template struct UnrankedMemRefType { + int64_t rank; + void *descriptor; +}; + +//===----------------------------------------------------------------------===// +// DynamicMemRefType type. +//===----------------------------------------------------------------------===// +template class DynamicMemRefIterator; + +// A reference to one of the StridedMemRef types. 
+template class DynamicMemRefType { +public: + int64_t rank; + T *basePtr; + T *data; + int64_t offset; + const int64_t *sizes; + const int64_t *strides; + + explicit DynamicMemRefType(const StridedMemRefType &memRef) + : rank(0), basePtr(memRef.basePtr), data(memRef.data), + offset(memRef.offset), sizes(nullptr), strides(nullptr) {} + template + explicit DynamicMemRefType(const StridedMemRefType &memRef) + : rank(N), basePtr(memRef.basePtr), data(memRef.data), + offset(memRef.offset), sizes(memRef.sizes), strides(memRef.strides) {} + explicit DynamicMemRefType(const ::UnrankedMemRefType &memRef) + : rank(memRef.rank) { + auto *desc = static_cast *>(memRef.descriptor); + basePtr = desc->basePtr; + data = desc->data; + offset = desc->offset; + sizes = rank == 0 ? nullptr : desc->sizes; + strides = sizes + rank; + } + + template ().begin())> + T &operator[](Range &&indices) { + assert(indices.size() == rank && + "indices should match rank in memref subscript"); + if (rank == 0) + return data[offset]; + + int64_t curOffset = offset; + for (int dim = rank - 1; dim >= 0; --dim) { + int64_t currentIndex = *(indices.begin() + dim); + assert(currentIndex < sizes[dim] && "Index overflow"); + curOffset += currentIndex * strides[dim]; + } + return data[curOffset]; + } + + DynamicMemRefIterator begin() { return {*this, offset}; } + DynamicMemRefIterator end() { return {*this, -1}; } + + // This operator[] is extremely slow and only for sugaring purposes. + DynamicMemRefType operator[](int64_t idx) { + assert(rank > 0 && "can't make a subscript of a zero ranked array"); + + DynamicMemRefType res(*this); + --res.rank; + res.offset += idx * res.strides[0]; + ++res.sizes; + ++res.strides; + return res; + } + + // This operator* can be used in conjunction with the previous operator[] in + // order to access the underlying value in case of zero-ranked memref. 
+ T &operator*() { + assert(rank == 0 && "not a zero-ranked memRef"); + return data[offset]; + } +}; + +/// Iterate over all elements in a dynamic memref. +template class DynamicMemRefIterator { +public: + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T *; + using reference = T &; + + DynamicMemRefIterator(DynamicMemRefType &descriptor, int64_t offset = 0) + : offset(offset), descriptor(&descriptor) { + indices.resize(descriptor.rank, 0); + } + + DynamicMemRefIterator &operator++() { + if (descriptor->rank == 0) { + offset = -1; + return *this; + } + + int dim = descriptor->rank - 1; + + while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) { + offset -= indices[dim] * descriptor->strides[dim]; + indices[dim] = 0; + --dim; + } + + if (dim < 0) { + offset = -1; + return *this; + } + + ++indices[dim]; + offset += descriptor->strides[dim]; + return *this; + } + + reference operator*() { return descriptor->data[offset]; } + pointer operator->() { return &descriptor->data[offset]; } + + const std::vector &getIndices() { return indices; } + + bool operator==(const DynamicMemRefIterator &other) const { + return other.offset == offset && other.descriptor == descriptor; + } + + bool operator!=(const DynamicMemRefIterator &other) const { + return !(*this == other); + } + +private: + /// Offset in the buffer. This can be derived from the indices and the + /// descriptor. + int64_t offset = 0; + + /// Array of indices in the multi-dimensional memref. + std::vector indices = {}; + + /// Descriptor for the dynamic memref. + DynamicMemRefType *descriptor; +}; + +//===----------------------------------------------------------------------===// +// Small runtime support library for memref.copy lowering during codegen. 
+//===----------------------------------------------------------------------===// +extern "C" MLIR_CRUNNERUTILS_EXPORT void +memrefCopy(int64_t elemSize, ::UnrankedMemRefType *src, + ::UnrankedMemRefType *dst); + +//===----------------------------------------------------------------------===// +// Small runtime support library for vector.print lowering during codegen. +//===----------------------------------------------------------------------===// +extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printString(char const *s); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen(); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose(); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma(); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline(); + +//===----------------------------------------------------------------------===// +// Small runtime support library for timing execution and printing GFLOPS +//===----------------------------------------------------------------------===// +extern "C" MLIR_CRUNNERUTILS_EXPORT void printFlops(double flops); +extern "C" MLIR_CRUNNERUTILS_EXPORT double rtclock(); + +//===----------------------------------------------------------------------===// +// Runtime support library for random number generation. +//===----------------------------------------------------------------------===// +// Uses a seed to initialize a random generator and returns the generator. +extern "C" MLIR_CRUNNERUTILS_EXPORT void *rtsrand(uint64_t s); +// Returns a random number in the range of [0, m). +extern "C" MLIR_CRUNNERUTILS_EXPORT uint64_t rtrand(void *, uint64_t m); +// Deletes the random number generator. 
+extern "C" MLIR_CRUNNERUTILS_EXPORT void rtdrand(void *); + +//===----------------------------------------------------------------------===// +// Runtime support library to allow the use of std::sort in MLIR program. +//===----------------------------------------------------------------------===// +extern "C" MLIR_CRUNNERUTILS_EXPORT void +_mlir_ciface_stdSortI64(uint64_t n, StridedMemRefType *vref); +extern "C" MLIR_CRUNNERUTILS_EXPORT void +_mlir_ciface_stdSortF64(uint64_t n, StridedMemRefType *vref); +extern "C" MLIR_CRUNNERUTILS_EXPORT void +_mlir_ciface_stdSortF32(uint64_t n, StridedMemRefType *vref); +#endif // MLIR_EXECUTIONENGINE_CRUNNERUTILS_H diff --git a/backend/include/ExecutionEngine/Msan.h b/backend/include/ExecutionEngine/Msan.h new file mode 100644 index 00000000..ee94660a --- /dev/null +++ b/backend/include/ExecutionEngine/Msan.h @@ -0,0 +1,35 @@ +//===- Msan.h - Utils related to the memory sanitizer ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares and defines macros related to msan. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXECUTIONENGINE_MSAN_H +#define MLIR_EXECUTIONENGINE_MSAN_H + +// Memory sanitizer currently can't be enabled for the jit-compiled code, and +// to suppress msan warnings we need to unpoison pointers and pointed-to +// datastructures before they can be accessed. 
+ +#ifndef __has_feature +#define __has_feature(x) 0 +#endif + +#if __has_feature(memory_sanitizer) && !defined(MLIR_MEMORY_SANITIZER) +#define MLIR_MEMORY_SANITIZER +#endif + +#if defined(MLIR_MEMORY_SANITIZER) +#include +#define MLIR_MSAN_MEMORY_IS_INITIALIZED(p, s) __msan_unpoison((p), (s)) +#else // Memory sanitizer: OFF +#define MLIR_MSAN_MEMORY_IS_INITIALIZED(p, s) +#endif // MLIR_MEMORY_SANITIZER + +#endif // MLIR_EXECUTIONENGINE_MSAN_H diff --git a/backend/include/ExecutionEngine/version.txt b/backend/include/ExecutionEngine/version.txt new file mode 100644 index 00000000..c3f15e55 --- /dev/null +++ b/backend/include/ExecutionEngine/version.txt @@ -0,0 +1 @@ +https://github.com/llvm/llvm-project/commit/3be3883e6d67bf908fd12b51219075293ebb3dff diff --git a/backend/npu.py b/backend/npu.py index 91c10110..38de5735 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -497,19 +497,22 @@ def commonir_to_linkedir(commonir, metadata, opt, *, named_ops=False): return content -def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False): +def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False, cpu_verify=False): pm = ir.pass_manager(mod.context) dicp_triton.passes.linked_npu.add_lower_affine(pm) dicp_triton.passes.linked_npu.add_normalize_slice_ops(pm) dicp_triton.passes.linked_npu.add_linalg_if_to_select(pm) dicp_triton.passes.linked_npu.add_linalg_generic_to_scf(pm) dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm) - dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True) + dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, named_ops, True, cpu_verify) dicp_triton.passes.linked_npu.add_linked_to_hivm(pm) - pm.run(mod) - + if cpu_verify: + dicp_triton.passes.linked_npu.add_debug_cpu_verify(pm) # TODO(zmz): 修改test_path 中内容,暂时在python中处理,bishengir-compile后续会支持,去掉这里逻辑。 + pm.run(mod) content = str(mod) + if cpu_verify: + return content # 将"*xfxxx"替换成"?xfxxx" content = content.replace("*xf", "?xf") content = content.replace("*xi", 
"?xi") diff --git a/compiler/include/dicp/Conversion/LinalgToLinked/LinalgToLinked.h b/compiler/include/dicp/Conversion/LinalgToLinked/LinalgToLinked.h index adbeef75..743cb567 100644 --- a/compiler/include/dicp/Conversion/LinalgToLinked/LinalgToLinked.h +++ b/compiler/include/dicp/Conversion/LinalgToLinked/LinalgToLinked.h @@ -6,6 +6,7 @@ namespace mlir::dicp::linked { std::unique_ptr> -createLinalgToLinkedPass(bool globalKernel = true, bool namedOps = true); +createLinalgToLinkedPass(bool globalKernel = true, bool namedOps = true, + bool cpuVerify = false); } // namespace mlir::dicp::linked diff --git a/compiler/include/dicp/Conversion/LinalgToLinked/Passes.h b/compiler/include/dicp/Conversion/LinalgToLinked/Passes.h index 7f39d427..3f351e27 100644 --- a/compiler/include/dicp/Conversion/LinalgToLinked/Passes.h +++ b/compiler/include/dicp/Conversion/LinalgToLinked/Passes.h @@ -4,6 +4,8 @@ namespace mlir::dicp::linked { +std::unique_ptr> createDebugCPUVerifyPass(); + #define GEN_PASS_REGISTRATION #include "dicp/Conversion/LinalgToLinked/Passes.h.inc" diff --git a/compiler/include/dicp/Conversion/LinalgToLinked/Passes.td b/compiler/include/dicp/Conversion/LinalgToLinked/Passes.td index cbe1bed7..a3d59ca7 100644 --- a/compiler/include/dicp/Conversion/LinalgToLinked/Passes.td +++ b/compiler/include/dicp/Conversion/LinalgToLinked/Passes.td @@ -5,15 +5,30 @@ include "mlir/Pass/PassBase.td" def LinalgToLinked : Pass<"linalg-to-linked", "mlir::ModuleOp"> { let summary = "Convert Linalg to Linked dialect"; - let constructor = "linked::createLinalgToLinkedPass(true,true)"; + let constructor = "linked::createLinalgToLinkedPass()"; let options = [ Option<"globalKernel", "global-kernel", "bool", /*default*/"true", - "generate a global kernel">, + "Generate a global kernel">, Option<"namedOps", "named-ops", "bool", /*default*/"true", - "use linalg named ops instead of linalg.generic"> + "Use linalg named ops instead of linalg.generic">, + Option<"cpuVerify", "cpu-verify", + 
"bool", /*default*/"false", + "Skip NPU workspace args for CPU verification"> ]; } +def DebugCPUVerify : Pass<"debug-cpu-verify", "mlir::ModuleOp"> { + let summary = "Verify that only MLIR built-in dialects remain for CPU runner"; + let description = [{ + Verification pass that scans the module for operations belonging to + non-MLIR-upstream (external) dialects. Any such operation is reported + as an error, ensuring the IR has been fully lowered to standard MLIR + constructs before being handed off to a CPU runner for correctness + validation. + }]; + let constructor = "linked::createDebugCPUVerifyPass()"; +} + #endif diff --git a/compiler/lib/Conversion/LinalgToLinked/CMakeLists.txt b/compiler/lib/Conversion/LinalgToLinked/CMakeLists.txt index 649dce58..7bc99c81 100644 --- a/compiler/lib/Conversion/LinalgToLinked/CMakeLists.txt +++ b/compiler/lib/Conversion/LinalgToLinked/CMakeLists.txt @@ -1,5 +1,6 @@ add_triton_library(LinalgToLinked LinalgToLinkedPass.cpp + DebugCPUVerifyPass.cpp VerifyNoLinalgGenericPass.cpp TritonOpConverter.cpp @@ -9,6 +10,7 @@ add_triton_library(LinalgToLinked LINK_LIBS PUBLIC BiShengIRAnnotationDialect + BiShengIRHIVMDialect TritonTilingExtIR MLIRArithDialect MLIRDialectUtils diff --git a/compiler/lib/Conversion/LinalgToLinked/DebugCPUVerifyPass.cpp b/compiler/lib/Conversion/LinalgToLinked/DebugCPUVerifyPass.cpp new file mode 100644 index 00000000..8b45cf09 --- /dev/null +++ b/compiler/lib/Conversion/LinalgToLinked/DebugCPUVerifyPass.cpp @@ -0,0 +1,73 @@ +#include "dicp/Conversion/LinalgToLinked/Passes.h" + +#include "bishengir/Dialect/Annotation/IR/Annotation.h" +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "debug-cpu-verify" + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include 
"dicp/Conversion/LinalgToLinked/Passes.h.inc" + +namespace { + +/// Returns true if the operation belongs to an external (non-MLIR-upstream) +/// dialect that should have been lowered away before CPU verification. +static bool isExternalDialectOp(Operation *op) { + Dialect *dialect = op->getDialect(); + if (!dialect) + return false; + return isa( + dialect); +} + +/// Returns true if the operation is a hivm.hir.sync_block_* op that should be +/// removed before CPU verification. +static bool isSyncBlockOp(Operation *op) { + return isa(op); +} + +class DebugCPUVerifyPass : public DebugCPUVerifyBase { +public: + void runOnOperation() override { + // First pass: remove all hivm.hir.sync_block_* operations + SmallVector syncBlockOpsToErase; + getOperation()->walk([&](Operation *op) { + if (isSyncBlockOp(op)) + syncBlockOpsToErase.push_back(op); + }); + for (Operation *op : syncBlockOpsToErase) + op->erase(); + + // Second pass: verify no external dialect operations remain + bool failed = false; + getOperation()->walk([&](Operation *op) { + if (!isExternalDialectOp(op)) + return; + op->emitError() << "external dialect op '" << op->getName() + << "' must be lowered before CPU verification"; + failed = true; + }); + + if (failed) + return signalPassFailure(); + + LLVM_DEBUG(llvm::dbgs() << "[debug-cpu-verify] PASSED — no external " + "dialect operations found\n"); + } +}; + +} // namespace + +std::unique_ptr> +mlir::dicp::linked::createDebugCPUVerifyPass() { + return std::make_unique(); +} diff --git a/compiler/lib/Conversion/LinalgToLinked/LinalgToLinkedPass.cpp b/compiler/lib/Conversion/LinalgToLinked/LinalgToLinkedPass.cpp index d7492ec3..bf459391 100644 --- a/compiler/lib/Conversion/LinalgToLinked/LinalgToLinkedPass.cpp +++ b/compiler/lib/Conversion/LinalgToLinked/LinalgToLinkedPass.cpp @@ -425,9 +425,11 @@ class ExternElementwiseClOpConverter class LinalgToLinkedPass : public LinalgToLinkedBase { public: - explicit LinalgToLinkedPass(bool globalKernel, bool namedOps) 
{ + explicit LinalgToLinkedPass(bool globalKernel, bool namedOps, + bool cpuVerify) { this->globalKernel = globalKernel; this->namedOps = namedOps; + this->cpuVerify = cpuVerify; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -513,43 +515,10 @@ class LinalgToLinkedPass : public LinalgToLinkedBase { signalPassFailure(); } - // 强制在函数参数开头添加一个参数,代表工作空间的占位参数 - for (auto func : getOperation().getOps()) { - if (!func->hasAttr("global_kernel")) - continue; - - auto context = func.getContext(); - constexpr int64_t syncBlockLockArgIdx = 0; - NamedAttribute syncBlockLockArgAttr( - StringAttr::get(context, "syncBlockLock"), UnitAttr::get(context)); - MemRefType syncBlockLockArgType = - MemRefType::get(SmallVector(1, ShapedType::kDynamic), - IntegerType::get(context, 8)); - if (failed(func.insertArgument(syncBlockLockArgIdx, // argIndex - syncBlockLockArgType, // argType - nullptr, func->getLoc()))) { - signalPassFailure(); - return; - } - func->setAttr("SyncBlockLockArgIdx", - IntegerAttr::get(IntegerType::get(&getContext(), 64), - 0)); // 64: 64位整型 - - constexpr int64_t workspaceArgIdx = 1; - MemRefType workspaceArgType = - MemRefType::get(SmallVector(1, ShapedType::kDynamic), - IntegerType::get(context, 8)); - NamedAttribute workspaceArgAttr(StringAttr::get(context, "workspace"), - UnitAttr::get(context)); - - if (failed(func.insertArgument(/*argIndex*/ workspaceArgIdx, - /*argType*/ workspaceArgType, - /*dicAttr*/ nullptr, func->getLoc()))) { - signalPassFailure(); - return; - } - func->setAttr("WorkspaceArgIdx", - IntegerAttr::get(IntegerType::get(&getContext(), 64), 1)); + // Insert NPU workspace args unless in CPU verify mode + if (failed(insertWorkspaceArgs(moduleOp))) { + signalPassFailure(); + return; } target.addIllegalOp(); @@ -559,11 +528,40 @@ class LinalgToLinkedPass : public LinalgToLinkedBase { signalPassFailure(); } } + +private: + /// Inserts syncBlockLock and workspace args when not in CPU verify mode. 
+ LogicalResult insertWorkspaceArgs(ModuleOp module) { + if (cpuVerify) + return success(); + + MLIRContext *ctx = module.getContext(); + auto memrefType = + MemRefType::get({ShapedType::kDynamic}, IntegerType::get(ctx, 8)); + + for (auto func : module.getOps()) { + if (!func->hasAttr(globalKernelAttr)) + continue; + + if (failed(func.insertArgument(0, memrefType, nullptr, func.getLoc()))) + return func.emitError("failed to insert syncBlockLock"); + func->setAttr("SyncBlockLockArgIdx", + IntegerAttr::get(IntegerType::get(ctx, 64), 0)); + + if (failed(func.insertArgument(1, memrefType, nullptr, func.getLoc()))) + return func.emitError("failed to insert workspace"); + func->setAttr("WorkspaceArgIdx", + IntegerAttr::get(IntegerType::get(ctx, 64), 1)); + } + return success(); + } }; } // namespace std::unique_ptr> -linked::createLinalgToLinkedPass(bool globalKernel, bool namedOps) { - return std::make_unique(globalKernel, namedOps); +linked::createLinalgToLinkedPass(bool globalKernel, bool namedOps, + bool cpuVerify) { + return std::make_unique(globalKernel, namedOps, + cpuVerify); } diff --git a/compiler/lib/Conversion/LinalgToLinked/TritonOpConverter.cpp b/compiler/lib/Conversion/LinalgToLinked/TritonOpConverter.cpp index b62ba68d..f69ea288 100644 --- a/compiler/lib/Conversion/LinalgToLinked/TritonOpConverter.cpp +++ b/compiler/lib/Conversion/LinalgToLinked/TritonOpConverter.cpp @@ -405,12 +405,6 @@ ScanConverter::convertToTargetOp(triton::ScanOp op, Value scanInput = op.getOperand(0); - scanInput.dump(); - - for (Value operand : op->getOperands()) { - operand.dump(); - } - auto srcType = mlir::dyn_cast(scanInput.getType()); if (!srcType) { return rewriter.notifyMatchFailure( diff --git a/test/ascend/cpu_verify/test_bare_matmul.py b/test/ascend/cpu_verify/test_bare_matmul.py new file mode 100644 index 00000000..c2236cb4 --- /dev/null +++ b/test/ascend/cpu_verify/test_bare_matmul.py @@ -0,0 +1,37 @@ +"""Triton bare matmul kernel test on CPU verify mode.""" + +import 
os + +# NOTE: Must set BEFORE importing triton, as triton reads this during import +os.environ.setdefault("DLC_CPU_VERIFY", "1") +os.environ.setdefault( + "LLVM_BINARY_DIR", + "/mnt/data01/kezengxiang/work/third_party/llvm-project/build_064f02dac0c81c19350a74415b3245f42fed09dc/bin", +) + +import torch +import triton +import triton.language as tl + + +@triton.jit +def bare_matmul(X, Y, Z, M, N, K, BLOCK_SIZE: tl.constexpr): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + offs_x = pid_x * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_y = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + offs_x[:, None] * K + offs_y[None, :]) + y = tl.load(Y + offs_x[:, None] * N + offs_y[None, :]) + z = tl.dot(x, y) + tl.store(Z + offs_x[:, None] * N + offs_y[None, :], z) + + +def test_bare_matmul(): + n = 128 + a = torch.randn((n, n), dtype=torch.float32) + b = torch.randn((n, n), dtype=torch.float32) + c = torch.empty((n, n), dtype=torch.float32) + + bare_matmul[(1,)](a, b, c, n, n, n, BLOCK_SIZE=n) + + assert torch.allclose(torch.matmul(a, b), c, atol=1e-2, rtol=0) diff --git a/test/ascend/cpu_verify/test_fa.py b/test/ascend/cpu_verify/test_fa.py new file mode 100644 index 00000000..6de70bb0 --- /dev/null +++ b/test/ascend/cpu_verify/test_fa.py @@ -0,0 +1,724 @@ +""" +Flash Attention Test Suite + +Compares Triton attention implementations against references: +- CPU flash attention (PyTorch reference) +- NPU torch_npu.npu_fusion_attention (NPU reference) + +Usage: + pytest test_fa.py -v + python test_fa.py # Uses DLC_CPU_VERIFY=1 by default +""" + +import os + +os.environ.setdefault("DLC_CPU_VERIFY", "1") +os.environ.setdefault( + "LLVM_BINARY_DIR", + "/mnt/data01/kezengxiang/work/third_party/llvm-project/build_064f02dac0c81c19350a74415b3245f42fed09dc/bin", +) + +import pytest +import torch +import triton +import triton.language as tl +import triton.language.extra.deeplink as dl +import torch_npu + +CPU_VERIFY = os.environ.get("DLC_CPU_VERIFY", "0") == "1" 
+DEVICE = "cpu" if CPU_VERIFY else "npu" +ATOL = 1e-3 +RTOL = 0.0 + +# ============================================================================= +# Triton Kernels +# ============================================================================= + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + qk_scale: tl.constexpr, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, + N_CTX: tl.constexpr, + fp8_v: tl.constexpr, +): + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + else: + lo, hi = 0, N_CTX + + K_block_ptr = tl.advance(K_block_ptr, (lo, 0)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load(K_block_ptr) + trans_k = tl.trans(k) + qk = tl.dot(q, trans_k) + + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + qk = qk * qk_scale + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + + p = tl.math.exp(qk) + p_cast = p.to(tl.float16) + v = tl.load(V_block_ptr) + pv = tl.dot(p_cast, v) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + pv + m_i = m_ij + + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0)) + + return acc, l_i, m_i + + +@triton.jit +def _attn_fwd( + Q, + K, + V, + M, + Out, + sm_scale: tl.constexpr, + stride_qz: tl.constexpr, + stride_qh: tl.constexpr, + stride_qm: tl.constexpr, + stride_qk: tl.constexpr, + stride_kz: tl.constexpr, + stride_kh: tl.constexpr, + stride_kn: tl.constexpr, + stride_kk: tl.constexpr, + stride_vz: 
tl.constexpr, + stride_vh: tl.constexpr, + stride_vn: tl.constexpr, + stride_vk: tl.constexpr, + stride_oz: tl.constexpr, + stride_oh: tl.constexpr, + stride_om: tl.constexpr, + stride_on: tl.constexpr, + Z: tl.constexpr, + H: tl.constexpr, + N_CTX: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + NUM_BLOCKS_PER_CORE: tl.constexpr, + NUM_BLOCKS: tl.constexpr, + NUM_BLOCKS_M: tl.constexpr, +): + pid = tl.program_id(0) + for block_idx in range(pid, NUM_BLOCKS, 24): + task_hz_idx = block_idx // NUM_BLOCKS_M + task_m_idx = block_idx % NUM_BLOCKS_M + off_z = task_hz_idx // H + off_h = task_hz_idx % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + q = tl.load(Q_block_ptr) + + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + task_m_idx, + sm_scale, + BLOCK_M, + 
@triton.jit
def _attn_fwd_split_cv(
    Q,
    K,
    V,
    M,
    Out,
    acc,          # NOTE(review): accepted but never read — the accumulator is tl.zeros below
    sm_scale,
    workspace_1,  # fp32 scratch for Q@K^T tiles (cube -> vector hand-off)
    workspace_2,  # input-dtype scratch for the softmax P tile (vector -> cube)
    workspace_3,  # fp32 scratch for P@V tiles (cube -> vector)
    stride_qz: tl.constexpr,
    stride_qh: tl.constexpr,
    stride_qm: tl.constexpr,
    stride_qk: tl.constexpr,
    stride_kz: tl.constexpr,
    stride_kh: tl.constexpr,
    stride_kn: tl.constexpr,
    stride_kk: tl.constexpr,
    stride_vz: tl.constexpr,
    stride_vh: tl.constexpr,
    stride_vn: tl.constexpr,
    stride_vk: tl.constexpr,
    stride_oz: tl.constexpr,
    stride_oh: tl.constexpr,
    stride_om: tl.constexpr,
    stride_on: tl.constexpr,
    w1_stride_nb: tl.constexpr,
    w1_stride_bm: tl.constexpr,
    w1_stride_bn: tl.constexpr,
    w2_stride_nb: tl.constexpr,
    w2_stride_bm: tl.constexpr,
    w2_stride_bn: tl.constexpr,
    w3_stride_nb: tl.constexpr,
    w3_stride_bm: tl.constexpr,
    w3_stride_dm: tl.constexpr,
    Z: tl.constexpr,
    H: tl.constexpr,
    N_CTX: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_CORES: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Flash-attention forward with cube/vector split execution.

    The two matmuls (Q@K^T and P@V) run in dl.async_task cube scope, the
    softmax arithmetic in vector scope; the sides exchange tiles through the
    workspaces and hand-shake via dl.set_cross_flag / dl.wait_cross_flag
    (C2V = cube-to-vector, V2C = vector-to-cube). NUM_STAGES inner iterations
    are software-pipelined via tl.range(..., num_stages=NUM_STAGES), each
    stage using its own workspace slot (NUM_STAGES * block_idx + i).

    NOTE(review): this variant has no STAGE/causal parameter — it always
    attends over the full [0, N_CTX) range. N_CTX is assumed divisible by
    BLOCK_M and BLOCK_N * NUM_STAGES. Statement placement relative to the
    async_task blocks was reconstructed from a whitespace-mangled source —
    confirm against the original file.
    """
    NUM_BLOCKS_M = N_CTX // BLOCK_M
    NUM_BLOCKS = NUM_BLOCKS_M * Z * H
    pid = tl.program_id(0)

    for block_idx in tl.range(pid, NUM_BLOCKS, NUM_CORES):
        task_hz_idx = block_idx // NUM_BLOCKS_M
        task_m_idx = block_idx % NUM_BLOCKS_M
        off_z = task_hz_idx // H
        off_h = task_hz_idx % H
        qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

        Q_block_ptr = tl.make_block_ptr(
            base=Q + qvk_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_qm, stride_qk),
            offsets=(task_m_idx * BLOCK_M, 0),
            block_shape=(BLOCK_M, HEAD_DIM),
            order=(1, 0),
        )
        V_block_ptr = tl.make_block_ptr(
            base=V + qvk_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_vn, stride_vk),
            offsets=(0, 0),
            block_shape=(BLOCK_N, HEAD_DIM),
            order=(1, 0),
        )
        K_block_ptr = tl.make_block_ptr(
            base=K + qvk_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_kn, stride_kk),
            offsets=(0, 0),
            block_shape=(BLOCK_N, HEAD_DIM),
            order=(1, 0),
        )
        O_block_ptr = tl.make_block_ptr(
            base=Out + qvk_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_om, stride_on),
            offsets=(task_m_idx * BLOCK_M, 0),
            block_shape=(BLOCK_M, HEAD_DIM),
            order=(1, 0),
        )

        q = tl.load(Q_block_ptr)
        offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
        # Online-softmax state (same scheme as _attn_fwd).
        acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0

        for start_n in range(0, N_CTX, BLOCK_N * NUM_STAGES):
            for i in tl.range(0, NUM_STAGES, num_stages=NUM_STAGES):
                # Per-stage workspace slots, linearized as NUM_STAGES*block + i.
                ws1_ptr = tl.make_block_ptr(
                    base=workspace_1
                    + (NUM_STAGES * block_idx.to(tl.int64) + i) * w1_stride_nb,
                    shape=(BLOCK_M, BLOCK_N),
                    strides=(w1_stride_bm, w1_stride_bn),
                    offsets=(0, 0),
                    block_shape=(BLOCK_M, BLOCK_N),
                    order=(1, 0),
                )
                ws2_ptr = tl.make_block_ptr(
                    base=workspace_2
                    + (NUM_STAGES * block_idx.to(tl.int64) + i) * w2_stride_nb,
                    shape=(BLOCK_M, BLOCK_N),
                    strides=(w2_stride_bm, w2_stride_bn),
                    offsets=(0, 0),
                    block_shape=(BLOCK_M, BLOCK_N),
                    order=(1, 0),
                )
                ws3_ptr = tl.make_block_ptr(
                    base=workspace_3
                    + (NUM_STAGES * block_idx.to(tl.int64) + i) * w3_stride_nb,
                    shape=(BLOCK_M, HEAD_DIM),
                    strides=(w3_stride_bm, w3_stride_dm),
                    offsets=(0, 0),
                    block_shape=(BLOCK_M, HEAD_DIM),
                    order=(1, 0),
                )

                # Cube side: S = Q @ K^T, published through workspace_1.
                with dl.async_task(scope=dl.async_task.cube):
                    k = tl.load(K_block_ptr)
                    trans_k = tl.trans(k)
                    qk = tl.dot(q, trans_k)
                    tl.store(ws1_ptr, qk)
                    dl.set_cross_flag(dl.SyncFlag.C2V, 0)

                # Vector side: scaled, max-subtracted exp; P published via workspace_2.
                with dl.async_task(scope=dl.async_task.vector):
                    dl.wait_cross_flag(dl.SyncFlag.C2V, 0)
                    qk = tl.load(ws1_ptr)
                    qk = qk * sm_scale
                    m_ij = tl.maximum(m_i, tl.max(qk, 1))
                    qk = qk - m_ij[:, None]
                    p = tl.math.exp(qk)
                    p_cast = p.to(Q.type.element_ty)
                    tl.store(ws2_ptr, p_cast)
                    dl.set_cross_flag(dl.SyncFlag.V2C, 1)
                dl.wait_cross_flag(dl.SyncFlag.V2C, 1)
                # Cube side: O-tile = P @ V, published through workspace_3.
                with dl.async_task(scope=dl.async_task.cube):
                    p_cast = tl.load(ws2_ptr)
                    v = tl.load(V_block_ptr)
                    acc_l0c = tl.dot(p_cast, v)
                    tl.store(ws3_ptr, acc_l0c)
                    dl.set_cross_flag(dl.SyncFlag.C2V, 2)
                V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
                K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0))
                # Vector side: fold the tile into the running softmax state.
                with dl.async_task(scope=dl.async_task.vector):
                    l_ij = tl.sum(p, 1)
                    alpha = tl.math.exp(m_i - m_ij)
                    l_i = l_i * alpha + l_ij
                    dl.wait_cross_flag(dl.SyncFlag.C2V, 2)
                    acc_ptr = acc_ptr * alpha[:, None]
                    acc_o_ub = tl.load(ws3_ptr)
                    acc_ptr = acc_ptr + acc_o_ub
                    m_i = m_ij

        # Finalize: M gets the log-sum-exp, Out gets the normalized output.
        m_i += tl.math.log(l_i)
        accumulator = acc_ptr / l_i[:, None]
        m_ptrs = M + task_hz_idx * N_CTX + offs_m

        tl.store(m_ptrs, m_i)
        tl.store(O_block_ptr, accumulator.to(Out.type.element_ty))
# =============================================================================
# Python Wrappers
# =============================================================================


class _attention(torch.autograd.Function):
    """Launch wrapper for the baseline _attn_fwd kernel.

    Forward-only: backward is not implemented. The log-sum-exp tensor M is
    saved alongside q/k/v/o, matching the usual flash-attention backward
    contract.
    """

    @staticmethod
    def forward(ctx, q, k, v, sm_scale, BM, BN, causal=False):
        # All three inputs must share a supported head dimension.
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}

        o = torch.empty_like(q)
        # STAGE bits: 3 = off-diagonal + masked diagonal (causal), 1 = full range.
        stage = 3 if causal else 1
        # NOTE(review): 24 must stay in sync with the hard-coded task stride
        # inside _attn_fwd.
        num_cores = 24
        NUM_BLOCKS_M = triton.cdiv(q.shape[2], BM)
        NUM_BLOCKS = NUM_BLOCKS_M * q.shape[0] * q.shape[1]
        NUM_BLOCKS_PER_CORE = triton.cdiv(NUM_BLOCKS, num_cores)

        # Per-row log-sum-exp output, one value per (Z, H, N_CTX) position.
        M = torch.empty(
            (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
        )

        _attn_fwd[(num_cores,)](
            q,
            k,
            v,
            M,
            o,
            sm_scale,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            q.shape[0],
            q.shape[1],
            N_CTX=q.shape[2],
            HEAD_DIM=HEAD_DIM_K,
            BLOCK_M=BM,
            BLOCK_N=BN,
            STAGE=stage,
            NUM_BLOCKS_PER_CORE=NUM_BLOCKS_PER_CORE,
            NUM_BLOCKS=NUM_BLOCKS,
            NUM_BLOCKS_M=NUM_BLOCKS_M,
            # Backend-specific launch options — presumably dicp/NPU pipeline
            # controls; confirm against the backend documentation.
            multibuffer=True,
            unit_flag=True,
            debug=False,
        )
        ctx.save_for_backward(q, k, v, o, M)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.causal = causal
        return o
class AttentionSplitCV(torch.autograd.Function):
    """Launch wrapper for the cube/vector-split _attn_fwd_split_cv kernel.

    Forward-only. Allocates the three inter-task workspaces (one slot per
    pipeline stage per task block) before launching.

    NOTE(review): `causal` is accepted but never forwarded — the split-CV
    kernel always attends over the full context. Also, NUM_BLOCKS_M uses
    floor division (N_CTX // BM) rather than triton.cdiv as in _attention,
    so N_CTX must be an exact multiple of BM.
    """

    @staticmethod
    def forward(ctx, q, k, v, sm_scale, BM, BN, causal=False):
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}

        o = torch.empty_like(q)
        N_CTX = q.shape[2]
        Z, H = q.shape[0], q.shape[1]
        NUM_BLOCKS_M = N_CTX // BM
        NUM_BLOCKS = NUM_BLOCKS_M * Z * H
        DIM = q.shape[-1]
        NUM_CORES = 24
        NUM_STAGES = 4

        # NOTE(review): `acc` is passed to the kernel but the kernel never
        # reads it (it zero-initializes its own accumulator).
        acc = torch.zeros(
            (q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K),
            dtype=torch.float32,
            device=q.device,
        )
        # Per-row log-sum-exp output.
        M = torch.empty(
            (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
        )
        # Cross-task scratch: QK^T tiles (fp32), P tiles (input dtype), PV tiles (fp32).
        workspace_1 = torch.empty(
            (NUM_STAGES, NUM_BLOCKS, BM, BN), device=q.device, dtype=torch.float32
        )
        workspace_2 = torch.empty(
            (NUM_STAGES, NUM_BLOCKS, BM, BN), device=q.device, dtype=q.dtype
        )
        workspace_3 = torch.empty(
            (NUM_STAGES, NUM_BLOCKS, BM, DIM), device=q.device, dtype=torch.float32
        )

        _attn_fwd_split_cv[(NUM_CORES,)](
            q,
            k,
            v,
            M,
            o,
            acc,
            sm_scale,
            workspace_1,
            workspace_2,
            workspace_3,
            stride_qz=q.stride(0),
            stride_qh=q.stride(1),
            stride_qm=q.stride(2),
            stride_qk=q.stride(3),
            stride_kz=k.stride(0),
            stride_kh=k.stride(1),
            stride_kn=k.stride(2),
            stride_kk=k.stride(3),
            stride_vz=v.stride(0),
            stride_vh=v.stride(1),
            stride_vn=v.stride(2),
            stride_vk=v.stride(3),
            stride_oz=o.stride(0),
            stride_oh=o.stride(1),
            stride_om=o.stride(2),
            stride_on=o.stride(3),
            # The "per-block" stride is stride(1); the kernel linearizes the
            # (stage, block) pair itself as NUM_STAGES*block + stage.
            w1_stride_nb=workspace_1.stride(1),
            w1_stride_bm=workspace_1.stride(2),
            w1_stride_bn=workspace_1.stride(3),
            w2_stride_nb=workspace_2.stride(1),
            w2_stride_bm=workspace_2.stride(2),
            w2_stride_bn=workspace_2.stride(3),
            w3_stride_nb=workspace_3.stride(1),
            w3_stride_bm=workspace_3.stride(2),
            w3_stride_dm=workspace_3.stride(3),
            Z=q.shape[0],
            H=q.shape[1],
            N_CTX=q.shape[2],
            HEAD_DIM=HEAD_DIM_K,
            BLOCK_M=BM,
            BLOCK_N=BN,
            NUM_CORES=NUM_CORES,
            NUM_STAGES=NUM_STAGES,
            # Backend-specific flags: the kernel does its own cross-flag
            # synchronization and workspace management — presumably these
            # disable the compiler's automatic injection; confirm.
            disable_auto_inject_block_sync=True,
            disable_auto_cv_work_space_manage=True,
        )

        ctx.save_for_backward(q, k, v, o, M)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        return o


# Public callables used by the tests below.
attention_base = _attention.apply
attention_split_cv = AttentionSplitCV.apply
# =============================================================================
# Reference Implementations
# =============================================================================


def flash_attention_cpu(q, k, v, sm_scale, causal=False):
    """Reference scaled-dot-product attention, computed in fp32 on CPU.

    Inputs are promoted to float32, a numerically stable softmax (row-max
    subtraction) is applied to ``q @ k^T * sm_scale`` — with the strict upper
    triangle masked to -inf when ``causal`` — and the result is cast back to
    the input dtype.
    """
    # Promote so the reference is not limited by fp16 rounding.
    qf, kf, vf = q.float(), k.float(), v.float()
    logits = torch.matmul(qf, kf.transpose(-2, -1)) * sm_scale

    if causal:
        seq = q.shape[-2]
        # Strict upper triangle == "future" positions for each query row.
        future = torch.triu(torch.ones(seq, seq, device=q.device), diagonal=1).bool()
        logits = logits.masked_fill(future, float("-inf"))

    # Stable softmax: subtract each row's max before exponentiating.
    row_max = torch.max(logits, dim=-1, keepdim=True).values
    unnormalized = torch.exp(logits - row_max)
    weights = unnormalized / torch.sum(unnormalized, dim=-1, keepdim=True)

    return torch.matmul(weights, vf).to(q.dtype)
def torch_npu_fusion_attention(q, k, v, num_heads, scale):
    """Wrapper for torch_npu.npu_fusion_attention.

    Returns only the attention output (element 0 of the tuple the fused NPU
    op returns). Layout is BNSD; masking and dropout are disabled.
    """
    return torch_npu.npu_fusion_attention(
        q,
        k,
        v,
        num_heads,
        padding_mask=None,
        atten_mask=None,
        scale=scale,
        keep_prob=1.0,
        input_layout="BNSD",
        pre_tockens=65535,
        next_tockens=65535,
        sparse_mode=0,
    )[0]


# =============================================================================
# Error Metrics
# =============================================================================


def compute_error_metrics(ref, out, name):
    """Compute and print error metrics between a reference and an output tensor.

    Returns (max_abs_diff, mean_abs_diff, max_rel_diff) where the relative
    error is normalized by the reference's max absolute value.
    """
    diff = (ref - out).abs()
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()
    ref_max = ref.abs().max().item()
    if ref_max > 0:
        rel_max = max_diff / ref_max
    else:
        # Fix: an all-zero reference that matches exactly is 0 relative error,
        # not inf (previously inf was reported even for a perfect match).
        rel_max = 0.0 if max_diff == 0 else float("inf")
    print(
        f" [{name}] max_abs={max_diff:.6e}, mean_abs={mean_diff:.6e}, max_rel={rel_max:.6e}"
    )
    return max_diff, mean_diff, rel_max


def compare_results(ref_cpu, ref_npu, base, cv, test_name):
    """
    Compare results according to the rules:
    - base is compared against both ref_cpu and ref_npu
    - cv is only compared against base (not against references)
    """
    print(f"\n[{test_name}] Precision Comparison:")
    print("-" * 80)

    # base vs CPU reference
    compute_error_metrics(ref_cpu, base, "base_vs_cpu_ref")

    # base vs NPU reference
    compute_error_metrics(ref_npu, base, "base_vs_npu_ref")

    # cv vs base (cv is only compared to base, not to references)
    compute_error_metrics(base, cv, "cv_vs_base")


# =============================================================================
# Test Data Generation
# =============================================================================


def generate_test_data(Z, H, N_CTX, HEAD_DIM, dtype=torch.float16):
    """Generate (q, k, v) test tensors on CPU with a fixed seed.

    The seed is reset on every call, so repeated calls with the same arguments
    return identical tensors.
    """
    torch.manual_seed(20)
    shape = (Z, H, N_CTX, HEAD_DIM)
    q = torch.empty(shape, dtype=dtype, device="cpu").normal_(mean=0.0, std=0.5)
    k = torch.empty(shape, dtype=dtype, device="cpu").normal_(mean=0.0, std=0.5)
    v = torch.empty(shape, dtype=dtype, device="cpu").normal_(mean=0.0, std=0.5)
    return q, k, v
# =============================================================================
# Pytest Test Cases
# =============================================================================

# (Z, H, N_CTX, HEAD_DIM, BLOCK_M, BLOCK_N, causal)
ALL_CASES = [
    (2, 2, 1024, 128, 64, 128, False),
    (1, 1, 1024 * 4, 128, 64, 128, False),
]


@pytest.mark.parametrize("Z,H,N_CTX,HEAD_DIM,BM,BN,causal", ALL_CASES)
def test_attention_precision(Z, H, N_CTX, HEAD_DIM, BM, BN, causal):
    """
    Test precision of Triton attention kernels.

    Printed comparison (compare_results) reports:
      - attention_base vs the CPU flash-attention reference
      - attention_base vs the torch_npu.npu_fusion_attention reference
      - attention_split_cv vs attention_base only

    The hard assertions below are stricter than the printed report: BOTH
    Triton kernels must match the NPU reference within ATOL/RTOL, the two
    references must agree, and base/split_cv must agree with each other.

    NOTE(review): the NPU reference is computed unconditionally (tensors are
    moved to "npu"), so this test requires NPU hardware even when CPU_VERIFY
    is set — confirm that is intended.
    """
    q, k, v = generate_test_data(Z, H, N_CTX, HEAD_DIM)
    sm_scale = 0.5

    # CPU reference
    ref_cpu = flash_attention_cpu(q, k, v, sm_scale, causal)

    # NPU reference (runs on NPU, result moved to CPU)
    q_npu, k_npu, v_npu = q.to("npu"), k.to("npu"), v.to("npu")
    ref_npu = torch_npu_fusion_attention(q_npu, k_npu, v_npu, H, sm_scale).cpu()

    # Run Triton kernels on target device (duplicates the module-level DEVICE choice).
    dev = "cpu" if CPU_VERIFY else "npu"
    q_dev, k_dev, v_dev = q.to(dev), k.to(dev), v.to(dev)

    tri_base = attention_base(q_dev, k_dev, v_dev, sm_scale, BM, BN, causal).cpu()
    tri_cv = attention_split_cv(q_dev, k_dev, v_dev, sm_scale, BM, BN, causal).cpu()

    # Compare and print results
    mode_str = "CPU_VERIFY" if CPU_VERIFY else "NPU"
    compare_results(
        ref_cpu, ref_npu, tri_base, tri_cv, f"{mode_str}_Z{Z}_H{H}_N{N_CTX}_D{HEAD_DIM}"
    )

    # Validate assertions
    assert torch.allclose(
        ref_cpu, ref_npu, atol=ATOL, rtol=RTOL
    ), "CPU ref vs NPU ref mismatch!"
    assert torch.allclose(
        ref_npu, tri_base, atol=ATOL, rtol=RTOL
    ), "base vs NPU ref mismatch!"
    assert torch.allclose(
        ref_npu, tri_cv, atol=ATOL, rtol=RTOL
    ), "cv vs NPU ref mismatch!"
    assert torch.allclose(
        tri_cv, tri_base, atol=ATOL, rtol=RTOL
    ), "base vs cv ref mismatch!"


# =============================================================================
# Main Entry Point
# =============================================================================

if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])


# ---------------------------------------------------------------------------
# Separate file in this patch: test/ascend/cpu_verify/test_vec_add.py
# ---------------------------------------------------------------------------
"""Triton vector addition kernel test on CPU verify mode."""

import os

# NOTE: Must set BEFORE importing triton, as triton reads this during import
os.environ.setdefault("DLC_CPU_VERIFY", "1")
# NOTE(review): machine-specific developer path committed as a default — this
# will not exist on other hosts or CI; prefer requiring LLVM_BINARY_DIR to be
# provided by the environment instead of defaulting to a private checkout.
os.environ.setdefault(
    "LLVM_BINARY_DIR",
    "/mnt/data01/kezengxiang/work/third_party/llvm-project/build_064f02dac0c81c19350a74415b3245f42fed09dc/bin",
)

import torch
import triton
import triton.language as tl

BLOCK_SIZE = 1024


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Elementwise out[i] = x[i] + y[i]; each program handles one BLOCK_SIZE
    # chunk, with a mask guarding the ragged tail.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)


def test_vec_add():
    # Runs the Triton kernel through the CPU-verify backend and checks the
    # result against plain torch addition.
    size = 1024
    x = torch.rand(size)
    y = torch.rand(size)

    output = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(size, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, output, size, BLOCK_SIZE=BLOCK_SIZE)

    assert torch.allclose(x + y, output)
a/tools/dicp_triton_opt/dicp_triton_opt.cpp +++ b/tools/dicp_triton_opt/dicp_triton_opt.cpp @@ -100,6 +100,7 @@ inline void registerDICPDialects(mlir::DialectRegistry ®istry) { dicp::linked::registerLinkedToHIVMPass(); dicp::linked::registerTritonToLinalgNPUCoversionPass(); dicp::linked::registerMemRefCopyGatherToTensorInsertPass(); + dicp::linked::registerDebugCPUVerifyPass(); dicp::LinalgExt::registerLinalgIfToSelectPass(); dicp::LinalgExt::registerLinalgGenericToSCFPass(); diff --git a/triton_dicp_triton.cc b/triton_dicp_triton.cc index 3d5c7693..5c39d4f0 100644 --- a/triton_dicp_triton.cc +++ b/triton_dicp_triton.cc @@ -70,13 +70,15 @@ void init_triton_dicp_triton_pass_linked_npu(py::module &&m) { pm.addNestedPass( dicp::LinalgExt::createScalarTo1DTensorPass()); }); - m.def("add_linalg_to_linked", - [](mlir::PassManager &pm, bool globalKernel, bool namedOps) { - pm.addPass(mlir::dicp::linked::createLinalgToLinkedPass(globalKernel, - namedOps)); - }); + m.def("add_linalg_to_linked", [](mlir::PassManager &pm, bool globalKernel, + bool namedOps, bool cpuVerify) { + pm.addPass(mlir::dicp::linked::createLinalgToLinkedPass( + globalKernel, namedOps, cpuVerify)); + }); ADD_PASS_WRAPPER_0("add_linked_to_hivm", dicp::linked::createLinkedToHIVMPass); + ADD_PASS_WRAPPER_0("add_debug_cpu_verify", + dicp::linked::createDebugCPUVerifyPass); m.def("add_vectorize_parallel_loop", [](mlir::PassManager &pm) { pm.addNestedPass( dicp::LinalgExt::createVectorizeParallelLoopPass());