diff --git a/CMakeLists.txt b/CMakeLists.txt index 91c2b015..8b3e01e7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,31 +92,71 @@ endif() if(AUTO_DETECT_BACKENDS) message(STATUS "Auto-detecting available backends...") + # The Python that scikit-build's build-isolated environment hands + # us does not have `torch` (only `[build-system].requires` is + # installed). Fall back to a list of common system interpreters so + # the auto-detection finds `torch` when it is in the install env + # but not the build env. The first interpreter that successfully + # imports `torch` wins and is reused by the `WITH_TORCH` block + # below for include / library lookups. find_package(Python COMPONENTS Interpreter QUIET) - if(Python_FOUND) + set(_torch_python_candidates "${Python_EXECUTABLE}") + foreach(_candidate + python3 + python + /usr/bin/python3 + /usr/local/bin/python3 + /opt/conda/bin/python + /opt/conda/bin/python3) + find_program(_resolved_${_candidate} ${_candidate}) + if(_resolved_${_candidate} AND + NOT _resolved_${_candidate} STREQUAL "${Python_EXECUTABLE}") + list(APPEND _torch_python_candidates "${_resolved_${_candidate}}") + endif() + endforeach() + + foreach(_py ${_torch_python_candidates}) + if(NOT _py) + continue() + endif() + execute_process( - COMMAND ${Python_EXECUTABLE} -c "import torch" + COMMAND "${_py}" -c "import torch" RESULT_VARIABLE _torch_import_result OUTPUT_QUIET ERROR_QUIET ) if(_torch_import_result EQUAL 0) - set(WITH_TORCH ON) - message(STATUS "Auto-detected PyTorch.") + set(_TORCH_PYTHON "${_py}") + break() endif() + endforeach() + + if(_TORCH_PYTHON) + set(WITH_TORCH ON) + message(STATUS "Auto-detected PyTorch (via ${_TORCH_PYTHON}).") endif() endif() if(WITH_TORCH) find_package(Python COMPONENTS Interpreter REQUIRED) + # Prefer the interpreter that the auto-detect block already + # confirmed has `torch` (this is the system Python on hosts that + # use scikit-build's build-isolation, where the build interpreter + # does not have `torch`). Fall back to `Python_EXECUTABLE` for + # explicit `-DWITH_TORCH=ON` invocations. + if(NOT _TORCH_PYTHON) + set(_TORCH_PYTHON "${Python_EXECUTABLE}") + endif() + # Query `torch` paths directly instead of using `find_package(Torch)`, # which pulls in Caffe2's CMake config and may fail on platforms with # non-standard CUDA toolchains. execute_process( - COMMAND ${Python_EXECUTABLE} -c "from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))" + COMMAND ${_TORCH_PYTHON} -c "from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))" OUTPUT_VARIABLE TORCH_INCLUDE_DIRS OUTPUT_STRIP_TRAILING_WHITESPACE RESULT_VARIABLE _torch_result @@ -127,7 +167,7 @@ if(WITH_TORCH) endif() execute_process( - COMMAND ${Python_EXECUTABLE} -c "from torch.utils.cpp_extension import library_paths; print(';'.join(library_paths()))" + COMMAND ${_TORCH_PYTHON} -c "from torch.utils.cpp_extension import library_paths; print(';'.join(library_paths()))" OUTPUT_VARIABLE _torch_lib_dirs OUTPUT_STRIP_TRAILING_WHITESPACE ) @@ -144,7 +184,7 @@ if(WITH_TORCH) # the bundled `NEEDED` entries (otherwise: `undefined reference to # _gfortran_etime@GFORTRAN_8` etc.). 
execute_process( - COMMAND ${Python_EXECUTABLE} -c "import os, torch; d = os.path.dirname(torch.__file__); p = os.path.join(os.path.dirname(d), 'torch.libs'); print(p if os.path.isdir(p) else '')" + COMMAND ${_TORCH_PYTHON} -c "import os, torch; d = os.path.dirname(torch.__file__); p = os.path.join(os.path.dirname(d), 'torch.libs'); print(p if os.path.isdir(p) else '')" OUTPUT_VARIABLE TORCH_BUNDLED_LIBS_DIR OUTPUT_STRIP_TRAILING_WHITESPACE ) @@ -163,7 +203,7 @@ if(WITH_TORCH) # A mismatch causes linker errors (e.g. undefined reference to # `c10::Device::Device(std::string const&)`). execute_process( - COMMAND ${Python_EXECUTABLE} -c "import torch; print(int(torch.compiled_with_cxx11_abi()))" + COMMAND ${_TORCH_PYTHON} -c "import torch; print(int(torch.compiled_with_cxx11_abi()))" OUTPUT_VARIABLE TORCH_CXX11_ABI OUTPUT_STRIP_TRAILING_WHITESPACE RESULT_VARIABLE _torch_abi_result diff --git a/pyproject.toml b/pyproject.toml index 959699f9..6f6d46c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["scikit-build-core", "pybind11", "libclang"] +requires = ["scikit-build-core", "pybind11", "libclang", "pyyaml", "clang-format"] build-backend = "scikit_build_core.build" [project] diff --git a/scripts/generate_torch_ops.py b/scripts/generate_torch_ops.py new file mode 100644 index 00000000..5823468c --- /dev/null +++ b/scripts/generate_torch_ops.py @@ -0,0 +1,1085 @@ +"""Generate InfiniOps PyTorch wrappers from ATen `native_functions.yaml`. + +For each op listed in `scripts/torch_ops.yaml`, this script finds the `.out` +variant in PyTorch's `native_functions.yaml` (fetched on demand from the +PyTorch GitHub release matching `_PYTORCH_VERSION`), parses its schema, +and emits: + + - `generated/base/.h` — the InfiniOps base class + `class : public Operator<>`, with constructors and pure-virtual + `operator()` overloads mirroring the selected ATen schemas. + - `generated/torch//.h` and `.cc` — the PyTorch backend + `Operator<, kDev, 8>` that calls `at::_out(out, ...)`. + - `generated/torch_ops_metadata.json` — the kind (`unary` / `binary` / + `binary_alpha`) of every successfully-generated op, consumed by the + parametrized test suite. + +Slot 8 is the reserved convention for PyTorch backends; slots 0-7 are +left for native or vendor implementations. (The slot must also be > 0 +to side-step a partial-specialization-after-instantiation conflict with +the primary template `Operator<>` instantiated at index 0.) + +The generated files are not committed; CMake regenerates them at configure +time when `WITH_TORCH=ON`. +""" + +import argparse +import dataclasses +import json +import os +import pathlib +import re +import shutil +import subprocess +import sys +import urllib.request + +import yaml + +_SCRIPTS_DIR = pathlib.Path(__file__).resolve().parent +_REPO_ROOT = _SCRIPTS_DIR.parent +_OPS_YAML_PATH = _SCRIPTS_DIR / "torch_ops.yaml" +_BASE_DIR = _REPO_ROOT / "src" / "base" +_GENERATED_DIR = _REPO_ROOT / "generated" +_GENERATED_BASE_DIR = _GENERATED_DIR / "base" +_GENERATED_TORCH_DIR = _GENERATED_DIR / "torch" +_METADATA_PATH = _GENERATED_DIR / "torch_ops_metadata.json" + +# Reserved slot for PyTorch backends. Native and vendor implementations +# claim slots 0-7; PyTorch wrappers always live at 8. +_PYTORCH_SLOT = 8 + +# ATen uses symbolic names for some `int`/`float` defaults (e.g. +# `reduction=Mean`). Map them to C++ identifiers usable in a call. 
+_ENUM_DEFAULTS = { + "Mean": "at::Reduction::Mean", + "Sum": "at::Reduction::Sum", + "Contiguous": "at::MemoryFormat::Contiguous", +} + +# Default PyTorch release tag whose `native_functions.yaml` defines +# the schemas we generate against. The build picks the actual +# version by passing `--pytorch-version ` (or +# `INFINIOPS_PYTORCH_VERSION=` in the environment) so each +# platform builds against its own installed torch's schema — vendor +# forks (Cambricon's `torch_mlu` 2.1.0, Moore's `torch_musa`, …) lag +# behind upstream and would otherwise hit overload mismatches like +# `at::all_out`'s `int64_t dim` vs v2.4.0's `OptionalIntArrayRef dim`. +_DEFAULT_PYTORCH_VERSION = "v2.4.0" + +# Order matches the device list in existing hand-written torch backends +# (see `src/torch/add/add.cc`). +_DEVICE_TYPES = ( + "kCpu", + "kNvidia", + "kCambricon", + "kAscend", + "kMetax", + "kMoore", + "kIluvatar", + "kKunlun", + "kHygon", + "kQy", +) + +# YAML scalar-type tokens → C++ types. Reference types (e.g. `const Scalar&`) +# are not used so the generated signatures match the existing hand-written +# ones, which pass by value to keep pybind11 binding generation simple. +_SCALAR_TYPE_MAP = { + # `at::Scalar` is implicitly constructible from `double`, so we expose + # scalars as `double` in the base class to keep it torch-independent. + "Scalar": "double", + "int": "int64_t", + "bool": "bool", + "float": "double", + # `SymInt` / `SymInt[]` exist for `torch.compile` internals; at runtime + # they're just `int64`/IntArrayRef. + "SymInt": "int64_t", + # `str` for required string params (e.g. `index_reduce.reduce`). + # `std::string` marshals through pybind11 cleanly and converts + # implicitly to ATen's `c10::string_view`. + "str": "std::string", +} + +# `Dimname` overloads (named-tensor dim) are skipped — passing them +# from Python to ATen requires a wrapper conversion through +# `at::Dimname::fromSymbol(...)` that doesn't fit the cleanly-rendered +# 1:1 arg model, and named tensors remain experimental in PyTorch. +# The int-dim overload is always emitted alongside, so we lose nothing +# user-visible. + +# Optional ATen types we hide from the user-facing API and pass as a +# typed empty optional at the call site. Covers the common "full +# default" case for most reductions and activations. We use a typed +# `c10::optional{}` rather than bare `at::nullopt` so the compiler +# can disambiguate ops with multiple `_out` overloads (e.g. `clamp_out` +# accepts both `optional` and `optional` for `min`/`max`). +_NULLOPT_BY_TYPE = { + "Scalar?": "c10::optional{}", + "int?": "c10::optional{}", + "bool?": "c10::optional{}", + "float?": "c10::optional{}", + "str?": "c10::optional{}", + "ScalarType?": "c10::optional{}", + "MemoryFormat?": "c10::optional{}", + "Layout?": "c10::optional{}", + "Device?": "c10::optional{}", + "Generator?": "c10::optional{}", + "Tensor?": "c10::optional{}", + "Tensor?[]": "c10::List>{}", + "int[]?": "c10::optional{}", + "int[1]?": "c10::optional{}", + "int[2]?": "c10::optional{}", + "int[3]?": "c10::optional{}", + "SymInt?": "c10::optional{}", + "SymInt[]?": "c10::optional{}", + "SymInt[1]?": "c10::optional{}", + "SymInt[2]?": "c10::optional{}", + "SymInt[3]?": "c10::optional{}", + "float[]?": "c10::optional>{}", +} +_HARDCODE_NULLOPT_TYPES = frozenset(_NULLOPT_BY_TYPE) + + +@dataclasses.dataclass +class Param: + name: str + aten_type: str + default: str | None + keyword_only: bool + + @property + def is_tensor(self) -> bool: + # Real tensors only. 
`Tensor?` is optional and falls through to + # the hidden-param path (substituted with `at::nullopt`). + + return self.aten_type == "Tensor" or self.aten_type.startswith("Tensor(") + + @property + def is_out(self) -> bool: + # Mutable tensors carry `!` in their alias annotation, e.g. `Tensor(a!)`. + + return self.is_tensor and "!" in self.aten_type + + @property + def is_hardcoded_nullopt(self) -> bool: + """If `True`, the param is omitted from the user-facing API and + passed as `at::nullopt` to ATen.""" + + return self.aten_type in _HARDCODE_NULLOPT_TYPES + + @property + def is_hidden(self) -> bool: + """True if the param is omitted from the user-facing API. + + Default-valued non-optional params (\\`bool\\`, \\`int\\`, \\`float\\`, + \\`str\\`, \\`int[N]\\`, …) used to be hidden as a convenience, but + reviewers consistently flagged the resulting omissions — + \\`bool upper/transpose/unitriangular\\` on \\`triangular_solve\\`, + \\`int diagonal\\` on \\`triu\\`, \\`str ord\\` on \\`linalg_matrix_norm\\`, + \\`int n\\` on the special chebyshev family, etc. — as missing + semantic controls. They are now exposed and forwarded to ATen. + + Optional ATen types (\\`Tensor?\\`, \\`Scalar?\\`, \\`int?\\`, …) remain + hidden for now — exposing them would require teaching the torch + source to thread \\`std::optional\\` through to ATen, which is a + separate refactor. The same goes for ATen-internal types like + \\`Generator?\\`/\\`Layout?\\` that have no InfiniOps analogue. + """ + + return self.is_hardcoded_nullopt + + def hidden_value(self) -> str: + """C++ literal substituted for a hidden param in the ATen call.""" + + if self.is_hardcoded_nullopt: + return _NULLOPT_BY_TYPE[self.aten_type] + + if self.default == "True": + return "true" + + if self.default == "False": + return "false" + + if self.aten_type.startswith(("int[", "SymInt[")) and self.default is not None: + # `int[N]=[a, b, c]` → `{a, b, c}`; `int[N]=0` (scalar default + # for list type) → `{0, 0, ...}` replicated to size N. + if self.default.startswith("["): + return "{" + self.default[1:-1] + "}" + + size_match = re.search(r"\[(\d+)\]", self.aten_type) + n = int(size_match.group(1)) if size_match else 1 + + return "{" + ", ".join([self.default] * n) + "}" + + if self.aten_type == "str" and self.default is not None: + # YAML uses single-quoted strings (e.g. `'none'`); C++ char + # literals also use single quotes, so swap to doubles. + + return '"' + self.default.strip("'\"") + '"' + + if self.aten_type in {"int", "float", "SymInt"} and self.default is not None: + # Translate known ATen enum defaults to their C++ identifiers. + + return _ENUM_DEFAULTS.get(self.default, self.default) + + raise AssertionError( + f"param {self.name!r} of type {self.aten_type!r} with default " + f"{self.default!r} is not hidden" + ) + + @property + def cpp_type(self) -> str: + if self.is_tensor: + # `Tensor[]` / `Tensor(a!)[]` would need `std::vector` and a + # different ATen call shape — not yet supported, so reject so the + # whole overload gets skipped instead of emitting code that calls + # `at::_out(at::Tensor, ...)` against an `at::TensorList` + # signature. + if self.aten_type.endswith("[]"): + raise NotImplementedError( + f"`Tensor[]` param {self.name!r} not supported yet" + ) + + return "Tensor" + + if self.is_hidden: + # Not exposed — the ATen call substitutes a hardcoded value + # so the `cpp_type` is irrelevant. 
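+            # Returning a placeholder here (rather than raising) keeps the
+            # eager `cpp_type` probe in `main()` from rejecting hidden params.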
+ + return "void" + + bare = self.aten_type.rstrip("?") + # Required `int[N]` / `SymInt[N]` (no default) — pybind11 accepts + # a Python list of ints into `std::vector`, which ATen + # promotes to `IntArrayRef` implicitly. + if bare.startswith(("int[", "SymInt[")) or bare in {"int[]", "SymInt[]"}: + return "std::vector" + + try: + return _SCALAR_TYPE_MAP[bare] + except KeyError as exc: + raise NotImplementedError( + f"unsupported ATen type {self.aten_type!r} for param {self.name!r}" + ) from exc + + +@dataclasses.dataclass +class Op: + aten_name: str + overload: str + params: list[Param] + + @property + def pascal_name(self) -> str: + return _snake_to_pascal(self.infini_name) + + @property + def infini_name(self) -> str: + """InfiniOps op name — always the canonical ATen base name. + + ATen disambiguates `_out` overloads with suffixes like `Tensor_Tensor_out`, + `out_x`, `forward_output`, `grad_input`, but reviewers consistently + flag those suffixes as bad public-API naming when they leak into + InfiniOps class names. Different ATen overloads of the same base op + become overloaded `operator()` methods on a single class instead. When + two overloads collapse to the same visible C++ signature after hidden + defaults, `_dedupe_visible_overloads` keeps only one. + """ + return self.aten_name + + @property + def tensor_params(self) -> list[Param]: + return [p for p in self.params if p.is_tensor] + + @property + def out_params(self) -> list[Param]: + """Mutable tensor outputs. Most ops have one (`Tensor(a!) out`); + multi-output ops like `frexp` or `sort` have several + (`Tensor(a!) values`, `Tensor(b!) indices`).""" + + return [p for p in self.params if p.is_out] + + @property + def out_param(self) -> Param: + """Single-output convenience. Asserts there's exactly one.""" + outs = self.out_params + assert len(outs) == 1, f"op {self.aten_name!r} has {len(outs)} out tensors" + + return outs[0] + + @property + def visible_params(self) -> list[Param]: + """Params the wrapper exposes to the user; hidden ones (hardcoded + optional nullopt, default-`False`/`True` bools) are filtered.""" + + return [p for p in self.params if not p.is_hidden] + + @property + def is_testable(self) -> bool: + """Cheap structural check: at least one out tensor, and the first + constructor parameter is a tensor. The latter is needed because + `Operator::Make(Tensor tensor, Args... args)` dispatches on + `tensor.device()`, so an op like `pow.Scalar_out(Scalar self, + Tensor exponent, *, Tensor(a!) out)` cannot be wired up without + a separate dispatch path. Generators like `arange` / `linspace` + also fall under this rule (no input tensors at all).""" + + if not self.out_params: + return False + + # `params` includes out tensors at the end; check the first + # non-out param. If there are no non-out params (`empty.out`, + # `arange.out`), this op also fails the dispatch precondition. + non_out = [p for p in self.params if not p.is_out] + + if not non_out: + return False + + return non_out[0].is_tensor + + +_FUNC_RE = re.compile( + r"^(?P[a-zA-Z_][a-zA-Z0-9_]*)" + r"(?:\.(?P\w+))?" 
+ r"\((?P.*)\)\s*->\s*.+$" +) + +_ARG_RE = re.compile( + r"^(?P\S+(?:\([^)]*\))?\??)" # type with optional alias and `?` + r"\s+(?P\w+)" + r"(?:\s*=\s*(?P.+))?$" +) + + +def _parse_func(func_str: str) -> Op: + m = _FUNC_RE.match(func_str) + + if not m: + raise ValueError(f"could not parse func: {func_str!r}") + + return Op( + aten_name=m.group("name"), + overload=m.group("overload") or "", + params=_parse_args(m.group("args")), + ) + + +def _parse_args(args_str: str) -> list[Param]: + params: list[Param] = [] + keyword_only = False + + for token in _split_args(args_str): + if token == "*": + keyword_only = True + continue + + params.append(_parse_one_arg(token, keyword_only)) + + return params + + +def _split_args(args_str: str) -> list[str]: + """Split on top-level commas, respecting `(...)` and `[...]`.""" + parts: list[str] = [] + depth = 0 + current: list[str] = [] + + for ch in args_str: + if ch in "([": + depth += 1 + current.append(ch) + elif ch in ")]": + depth -= 1 + current.append(ch) + elif ch == "," and depth == 0: + piece = "".join(current).strip() + + if piece: + parts.append(piece) + + current = [] + else: + current.append(ch) + + tail = "".join(current).strip() + + if tail: + parts.append(tail) + + return parts + + +def _parse_one_arg(token: str, keyword_only: bool) -> Param: + m = _ARG_RE.match(token) + + if not m: + raise ValueError(f"could not parse arg: {token!r}") + + name = m.group("name") + # ATen names the first tensor parameter `self` (matching the + # method-style \`tensor.abs()\` convention). InfiniOps uses + # \`input\` for the primary tensor input across all hand-written + # bases (\`Add\`, \`Gemm\`, …) per \`CONTRIBUTING.md\` §C++. + # Rename at parse time so the generated headers match. + if name == "self": + name = "input" + + return Param( + name=name, + aten_type=m.group("type"), + default=m.group("default"), + keyword_only=keyword_only, + ) + + +def _snake_to_pascal(s: str) -> str: + return "".join(p.capitalize() for p in s.split("_")) + + +def _base_path(op_name: str) -> pathlib.Path: + return _BASE_DIR / f"{op_name}.h" + + +def _candidate_versions(version: str) -> list[str]: + """Return progressively-more-tolerant fallbacks for `version`: + + `v2.10.0a0+b4e4ee8` → [`v2.10.0a0+b4e4ee8`, `v2.10.0a0`, + `v2.10.0`, `v2.4.0`] + + NVIDIA-fork wheels (e.g. `2.10.0a0+b4e4ee8`) and other nightlies + do not have matching pytorch GitHub tags. Fall back to the + latest release we know exists when the version-suffixed tag + returns 404. + """ + + seen = [] + + def add(v: str) -> None: + if v and v not in seen: + seen.append(v) + + add(version) + add(version.split("+", 1)[0]) + add(re.sub(r"[a-z]\d+$", "", version.split("+", 1)[0])) + add(_DEFAULT_PYTORCH_VERSION) + return seen + + +def _load_aten_yaml(version: str) -> str: + """Return the contents of `native_functions.yaml` for `version`, + fetching and caching it on the first call. 
Falls back to + increasingly stable version tags if the requested one is missing + on pytorch GitHub (typical for pre-release / nightly torch builds + like `2.10.0a0+b4e4ee8`).""" + + last_error: Exception | None = None + + for candidate in _candidate_versions(version): + cache_path = ( + _REPO_ROOT / "generated" / ".cache" / f"native_functions-{candidate}.yaml" + ) + url = ( + f"https://raw.githubusercontent.com/pytorch/pytorch/{candidate}" + "/aten/src/ATen/native/native_functions.yaml" + ) + + if cache_path.exists(): + if candidate != version: + print( + f"using cached `native_functions.yaml` ({candidate}) as " + f"fallback for {version}.", + file=sys.stderr, + ) + + return cache_path.read_text() + + cache_path.parent.mkdir(parents=True, exist_ok=True) + print(f"fetching `native_functions.yaml` ({candidate})...", file=sys.stderr) + + try: + with urllib.request.urlopen(url) as response: + cache_path.write_bytes(response.read()) + except urllib.error.HTTPError as exc: + print( + f"`{candidate}` not found on pytorch GitHub ({exc.code}); " + "trying next fallback.", + file=sys.stderr, + ) + last_error = exc + continue + + return cache_path.read_text() + + raise RuntimeError( + f"could not fetch `native_functions.yaml` for any fallback of " + f"{version!r}: {last_error}" + ) + + +def _find_out_entries(entries: list[dict], op_name: str) -> list[dict]: + """Return all out-variant entries for `op_name`, with the bare + `.out(` form first and overload-suffixed variants + (e.g. `pow.Tensor_Tensor_out(`, `kthvalue.values(`) after. An + entry counts as an out-variant when it (a) is named + `.out`, (b) ends in `_out`, or (c) carries a + `Tensor(!)` mutability annotation — that last case covers + multi-output ops named after their output tensors + (`kthvalue.values`, `mode.values`, …).""" + bare_prefix = f"{op_name}.out(" + op_overload = re.compile(rf"^{re.escape(op_name)}\.\w+\(") + mut_tensor = re.compile(r"Tensor\([a-z]!\)") + bare: list[dict] = [] + others: list[dict] = [] + + for entry in entries: + func = entry.get("func", "") + + if func.startswith(bare_prefix): + bare.append(entry) + elif op_overload.match(func) and ( + func.split("(", 1)[0].endswith("_out") or mut_tensor.search(func) + ): + others.append(entry) + + return bare + others + + +def _format_signature(op: Op, *, include_defaults: bool = False) -> str: + parts = [] + + for param in op.visible_params: + prefix = "" if param.is_out else "const " + text = f"{prefix}{param.cpp_type} {param.name}" + + if include_defaults and param.default is not None: + text += f" = {_translate_default(param)}" + + parts.append(text) + + return ", ".join(parts) + + +def _visible_signature_key(op: Op) -> tuple[str, ...]: + """C++ overload identity for the user-facing API. + + Parameter names and top-level `const` do not distinguish C++ overloads, so + only the exposed C++ type sequence participates in duplicate detection. + """ + + return tuple(param.cpp_type for param in op.visible_params) + + +def _canonical_overload_score(index: int, op: Op) -> tuple[bool, int, int, str, int]: + """Sort key for duplicate visible signatures. + + Prefer the canonical unsuffixed InfiniOps name, then the schema that hides + fewer ATen-only defaults, then the shorter deterministic name. 
+ """ + + return ( + op.infini_name != op.aten_name, + sum(param.is_hidden for param in op.params), + len(op.infini_name), + op.infini_name, + index, + ) + + +def _dedupe_visible_overloads(ops: list[Op]) -> tuple[list[Op], list[tuple[Op, Op]]]: + """Drop overloads that collapse to the same visible C++ signature. + + Returns the selected overloads in the original schema order plus a list of + `(skipped, kept)` duplicate pairs for diagnostics. + """ + winners: dict[tuple[str, ...], tuple[int, Op]] = {} + duplicates: list[tuple[Op, tuple[str, ...]]] = [] + + for index, op in enumerate(ops): + key = _visible_signature_key(op) + current = winners.get(key) + + if current is None: + winners[key] = (index, op) + continue + + current_index, current_op = current + + if _canonical_overload_score(index, op) < _canonical_overload_score( + current_index, current_op + ): + duplicates.append((current_op, key)) + winners[key] = (index, op) + else: + duplicates.append((op, key)) + + selected_indices = {index for index, _ in winners.values()} + selected = [op for index, op in enumerate(ops) if index in selected_indices] + duplicate_pairs = [ + (skipped, winners[key][1]) + for skipped, key in duplicates + if winners[key][1] is not skipped + ] + + return selected, duplicate_pairs + + +def _translate_default(param: Param) -> str: + """Translate a YAML default literal to a C++ literal.""" + raw = param.default + + if raw == "True": + return "true" + + if raw == "False": + return "false" + + if raw == "None": + return "{}" + + return raw # numeric literals (`0`, `1`, `1.0`) pass through + + +def _generate_base_header(name: str, ops: list[Op]) -> str: + pascal = _snake_to_pascal(name) + + member_decls = [] + tensor_member_order = [] + seen_tensor_members = set() + scalar_member_order = [] + scalar_member_types = {} + + for op in ops: + for param in op.tensor_params: + if param.name in seen_tensor_members: + continue + + seen_tensor_members.add(param.name) + tensor_member_order.append(param.name) + member_decls.append(f" Tensor::Shape {param.name}_shape_;") + member_decls.append(f" Tensor::Strides {param.name}_strides_;") + member_decls.append(f" DataType {param.name}_type_;") + + # Visible non-tensor params (scalars, strings, vectors) are also + # stored on the base so backends can dispatch on them later — not + # only at the moment `operator()` is invoked. Reviewers flagged + # this on multiple PRs (e.g. `n` on + # `special_chebyshev_polynomial_v_n_scalar`). Same-named params + # across overloads must share a type; if they conflict, the second + # overload's member is dropped (later constructors leave it + # default-initialised). 
+ for param in op.visible_params: + if param.is_tensor or param.name in scalar_member_types: + continue + + scalar_member_order.append(param.name) + scalar_member_types[param.name] = param.cpp_type + member_decls.append(f" {param.cpp_type} {param.name}_{{}};") + + member_decls.append(" int device_index_{0};") + + constructors = [] + calls = [] + + for op in ops: + init_pieces = [] + tensor_params = {param.name: param for param in op.tensor_params} + scalar_params = { + param.name: param + for param in op.visible_params + if not param.is_tensor + and scalar_member_types.get(param.name) == param.cpp_type + } + + for param_name in tensor_member_order: + param = tensor_params.get(param_name) + + if param is None: + continue + + init_pieces.append(f" {param.name}_shape_{{{param.name}.shape()}}") + init_pieces.append( + f" {param.name}_strides_{{{param.name}.strides()}}" + ) + init_pieces.append(f" {param.name}_type_{{{param.name}.dtype()}}") + + for param_name in scalar_member_order: + param = scalar_params.get(param_name) + + if param is None: + continue + + init_pieces.append(f" {param.name}_{{{param.name}}}") + + # All out tensors share a device; use the first one. Keep this last + # so initializer order follows the member declaration order. + init_pieces.append( + f" device_index_{{{op.out_params[0].name}.device().index()}}" + ) + + init_list = ",\n".join(init_pieces).lstrip() + constructors.append( + f" {pascal}({_format_signature(op)})\n : {init_list} {{}}" + ) + calls.append(f" virtual void operator()({_format_signature(op)}) const = 0;") + + return _BASE_TEMPLATE.format( + name_uc=name.upper(), + pascal=pascal, + constructors="\n\n".join(constructors), + op_calls="\n\n".join(calls), + member_decls="\n\n".join(member_decls), + ) + + +def _generate_torch_header(name: str, ops: list[Op]) -> str: + pascal = _snake_to_pascal(name) + op_calls = "\n\n".join( + f" void operator()({_format_signature(op)}) const override;" for op in ops + ) + + return _TORCH_HEADER_TEMPLATE.format( + name_uc=name.upper(), + name=name, + pascal=pascal, + op_calls=op_calls, + slot=_PYTORCH_SLOT, + ) + + +def _generate_torch_method_source(name: str, op: Op) -> str: + pascal = _snake_to_pascal(name) + conversion_lines = [] + + for param in op.tensor_params: + data_expr = ( + f"{param.name}.data()" + if param.is_out + else f"const_cast({param.name}.data())" + ) + conversion_lines.append( + f" auto at_{param.name} = ToAtenTensor(\n" + f" {data_expr}, {param.name}_shape_, {param.name}_strides_,\n" + f" {param.name}_type_, device_index_);" + ) + + # ATen `_out` form puts all out tensors first, then non-out params + # in YAML order. Hardcoded-nullopt params become `at::nullopt`. + arg_order = op.out_params + [p for p in op.params if not p.is_out] + + def _render_arg(p): + if p.is_hidden: + return p.hidden_value() + + if p.is_tensor: + return f"at_{p.name}" + + return p.name + + aten_args = ", ".join(_render_arg(p) for p in arg_order) + + return _TORCH_METHOD_TEMPLATE.format( + pascal=pascal, + op_call_signature=_format_signature(op), + tensor_conversions="\n".join(conversion_lines), + # `at::_out` resolves the right kernel via C++ overload + # resolution from the argument types we pass. 
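+        # e.g. `abs.out(Tensor self, *, Tensor(a!) out)` renders the body
+        # `at::abs_out(at_out, at_input);` (out tensors come first in the
+        # `_out` call convention).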
+ aten_call=f"{op.aten_name}_out({aten_args})", + slot=_PYTORCH_SLOT, + ) + + +def _generate_torch_source(name: str, ops: list[Op]) -> str: + pascal = _snake_to_pascal(name) + methods = "\n\n".join(_generate_torch_method_source(name, op) for op in ops) + # Guard each explicit instantiation by the matching `WITH_` macro + # so a build that only enables a subset of devices does not pay the + # ATen template-instantiation cost (and memory pressure) for the + # devices it does not link against. Each macro is set by + # `target_compile_definitions` in `src/CMakeLists.txt`. + instantiations = "\n".join( + f"#ifdef WITH_{dev.removeprefix('k').upper()}\n" + f"template class Operator<{pascal}, Device::Type::{dev}, {_PYTORCH_SLOT}>;\n" + f"#endif" + for dev in _DEVICE_TYPES + ) + + return _TORCH_SOURCE_TEMPLATE.format( + name=name, + methods=methods, + instantiations=instantiations, + ) + + +_BASE_TEMPLATE = """\ +#ifndef INFINI_OPS_BASE_{name_uc}_H_ +#define INFINI_OPS_BASE_{name_uc}_H_ + +#include "operator.h" + +namespace infini::ops {{ + +class {pascal} : public Operator<{pascal}> {{ + public: +{constructors} + +{op_calls} + + protected: +{member_decls} +}}; + +}} // namespace infini::ops + +#endif +""" + + +_TORCH_HEADER_TEMPLATE = """\ +#ifndef INFINI_OPS_TORCH_{name_uc}_H_ +#define INFINI_OPS_TORCH_{name_uc}_H_ + +#include "base/{name}.h" + +namespace infini::ops {{ + +template +class Operator<{pascal}, kDev, {slot}> : public {pascal} {{ + public: + using {pascal}::{pascal}; + +{op_calls} +}}; + +}} // namespace infini::ops + +#endif +""" + + +_TORCH_METHOD_TEMPLATE = """\ +template +void Operator<{pascal}, kDev, {slot}>::operator()({op_call_signature}) const {{ +{tensor_conversions} + + at::{aten_call}; +}} +""" + + +_TORCH_SOURCE_TEMPLATE = """\ +#include "torch/{name}/{name}.h" + +#include "torch/tensor_.h" + +namespace infini::ops {{ + +{methods} + +{instantiations} + +}} // namespace infini::ops +""" + + +def _find_clang_format() -> str | None: + """Return the path to a `clang-format` binary, or `None` if none is + available. When the system does not provide one, try installing the + `clang-format` PyPI wheel; offline CI containers (no PyPI mirror) end + up returning `None` and the codegen falls through to writing + unformatted output — generated files live under `generated/` (which + is gitignored) so they do not need to satisfy the repo-level + clang-format check, only compile cleanly.""" + + found = shutil.which("clang-format") + + if found: + return found + + print( + "`clang-format` not found on PATH; trying `pip install clang-format`...", + file=sys.stderr, + ) + + try: + subprocess.run( + [sys.executable, "-m", "pip", "install", "--quiet", "clang-format"], + check=True, + ) + except subprocess.CalledProcessError: + print( + "`pip install clang-format` failed (likely offline CI); generated " + "files will be emitted without formatting.", + file=sys.stderr, + ) + + return None + + return shutil.which("clang-format") + + +def _clang_format(text: str, path: pathlib.Path) -> str: + """Pipe `text` through `clang-format` so generated headers / sources + satisfy the same style check (`clang-format` v21) that CI runs. + `path` informs include sorting (the file's own header should come + first in a `.cc`). 
If no `clang-format` binary is available, return + the input unchanged.""" + + if _CLANG_FORMAT is None: + return text + + return subprocess.run( + [_CLANG_FORMAT, f"--assume-filename={path}"], + input=text, + capture_output=True, + text=True, + check=True, + ).stdout + + +def _emit(name: str, ops: list[Op], *, emit_base: bool) -> None: + base_path = _GENERATED_BASE_DIR / f"{name}.h" + torch_dir = _GENERATED_TORCH_DIR / name + torch_header_path = torch_dir / f"{name}.h" + torch_source_path = torch_dir / f"{name}.cc" + + if emit_base: + _GENERATED_BASE_DIR.mkdir(parents=True, exist_ok=True) + base_path.write_text(_clang_format(_generate_base_header(name, ops), base_path)) + + torch_dir.mkdir(parents=True, exist_ok=True) + + torch_header_path.write_text( + _clang_format(_generate_torch_header(name, ops), torch_header_path) + ) + torch_source_path.write_text( + _clang_format(_generate_torch_source(name, ops), torch_source_path) + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--ops", + nargs="*", + help="Override the op allowlist. If omitted, reads `scripts/torch_ops.yaml`.", + ) + parser.add_argument( + "--pytorch-version", + default=os.environ.get("INFINIOPS_PYTORCH_VERSION", _DEFAULT_PYTORCH_VERSION), + help=( + "PyTorch release tag whose `native_functions.yaml` defines the " + "schemas to generate against (e.g. `v2.1.0` for Cambricon's " + "`torch_mlu` 2.1.0 fork). Default: `%(default)s`. Can also be " + "set via the `INFINIOPS_PYTORCH_VERSION` environment variable." + ), + ) + args = parser.parse_args() + + global _CLANG_FORMAT + _CLANG_FORMAT = _find_clang_format() + + op_names = args.ops or yaml.safe_load(_OPS_YAML_PATH.read_text()) + aten_entries = yaml.safe_load(_load_aten_yaml(args.pytorch_version)) + + # Wipe previous outputs so files for ops that have since been removed, + # renamed, or rejected by `cpp_type` don't linger and get picked up by + # the CMake glob. Both `generated/base/` and `generated/torch/` are + # written exclusively by this script. + if _GENERATED_BASE_DIR.exists(): + shutil.rmtree(_GENERATED_BASE_DIR) + + if _GENERATED_TORCH_DIR.exists(): + shutil.rmtree(_GENERATED_TORCH_DIR) + + skipped: list[tuple[str, str]] = [] + metadata: list[dict] = [] + + for name in op_names: + candidates = _find_out_entries(aten_entries, name) + + if not candidates: + skipped.append((name, f"no `.out` variant for `{name}` in YAML")) + continue + + usable: list[Op] = [] + last_reason = "" + + for entry in candidates: + try: + op = _parse_func(entry["func"]) + + for param in op.params: + _ = param.cpp_type # eagerly raise on unsupported types + except (NotImplementedError, ValueError) as exc: + last_reason = str(exc) + continue + + if not op.is_testable: + last_reason = "no testable tensor input/output pair" + continue + + usable.append(op) + + if not usable: + skipped.append((name, last_reason or "no usable overload")) + continue + + usable, duplicate_overloads = _dedupe_visible_overloads(usable) + + for skipped_op, kept_op in duplicate_overloads: + skipped.append( + ( + skipped_op.infini_name, + "duplicate visible C++ signature for " + f"`{name}`; using `{kept_op.infini_name}`", + ) + ) + + # Emit one InfiniOps wrapper per ATen op. Distinct visible overloads + # become overloaded constructors / `operator()` methods on the same + # class (`Pow` exposes both tensor and scalar exponents). Overloads + # that collapse to the same C++ signature after hidden defaults are + # skipped above. 
When a hand-written `src/base/.h` exists, + # skip emitting `generated/base/.h` so the hand-written one + # wins (the generated torch source's `#include "base/.h"` + # resolves through `src/` first). Signature mismatches surface as + # compile errors with a clear message — drop the op from the YAML + # to suppress. + _emit(name, usable, emit_base=not _base_path(name).exists()) + + for op in usable: + metadata.append( + { + "name": name, + "aten_name": op.aten_name, + "overload_name": op.infini_name, + "params": [ + { + "name": p.name, + "type": p.aten_type, + "is_tensor": p.is_tensor, + "is_out": p.is_out, + } + for p in op.visible_params + ], + } + ) + + _GENERATED_DIR.mkdir(parents=True, exist_ok=True) + _METADATA_PATH.write_text(json.dumps({"ops": metadata}, indent=2) + "\n") + + generated_names = sorted({m["name"] for m in metadata}) + print( + f"generated {len(metadata)} overloads across {len(generated_names)} ops: " + f"{generated_names}" + ) + + for name, reason in skipped: + print(f" skipped {name!r}: {reason}", file=sys.stderr) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index effc0787..d0226f0c 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -1,4 +1,5 @@ import argparse +import functools import json import pathlib import re @@ -15,6 +16,11 @@ _GENERATION_DIR = pathlib.Path("generated") +# Base headers emitted by `generate_torch_ops.py` live alongside the +# hand-written ones in `src/base/`, but in a parallel tree under +# `generated/base/` so they are not committed. +_GENERATED_BASE_DIR = _GENERATION_DIR / "base" + _BINDINGS_DIR = _GENERATION_DIR / "bindings" _GENERATED_SRC_DIR = _GENERATION_DIR / "src" @@ -24,37 +30,61 @@ _INDENTATION = " " -class _OperatorExtractor: - def __call__(self, op_name): - def _get_system_include_flags(): - def _get_compilers(): - compilers = [] +@functools.lru_cache(maxsize=1) +def _get_system_include_flags(): + """Probe the system C++ compiler for default include paths so libclang + can resolve standard headers when parsing an op's base header.""" + compilers = [] - for compiler in ("clang++", "g++"): - if shutil.which(compiler) is not None: - compilers.append(compiler) + for compiler in ("clang++", "g++"): + if shutil.which(compiler) is not None: + compilers.append(compiler) - return compilers + system_include_flags = [] - system_include_flags = [] + for compiler in compilers: + for line in subprocess.getoutput( + f"{compiler} -E -x c++ -v /dev/null" + ).splitlines(): + if not line.startswith(" "): + continue - for compiler in _get_compilers(): - for line in subprocess.getoutput( - f"{compiler} -E -x c++ -v /dev/null" - ).splitlines(): - if not line.startswith(" "): - continue + system_include_flags.append("-isystem") + system_include_flags.append(line.strip()) + + return tuple(system_include_flags) - system_include_flags.append("-isystem") - system_include_flags.append(line.strip()) - return system_include_flags +def _find_base_header(op_name): + """Resolve the base header for `op_name`, preferring the hand-written + `src/base/.h` over the auto-generated `generated/base/.h`. 
+ Mirrors the include-path resolution order used at compile time.""" + src_path = _BASE_DIR / f"{op_name}.h" - system_include_flags = _get_system_include_flags() + if src_path.exists(): + return src_path + generated_path = _GENERATED_BASE_DIR / f"{op_name}.h" + + if generated_path.exists(): + return generated_path + + raise FileNotFoundError(f"no base header for op {op_name!r}") + + +class _OperatorExtractor: + def __call__(self, op_name): index = clang.cindex.Index.create() - args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags) - translation_unit = index.parse(f"src/base/{op_name}.h", args=args) + args = ( + "-std=c++17", + "-x", + "c++", + "-I", + "src", + "-I", + str(_GENERATION_DIR), + ) + _get_system_include_flags() + translation_unit = index.parse(str(_find_base_header(op_name)), args=args) nodes = tuple(type(self)._find(translation_unit.cursor, op_name)) @@ -98,7 +128,7 @@ def _find_optional_tensor_params(op_name): headers are not fully available, so we fall back to a regex scan of the source text. """ - source = (_BASE_DIR / f"{op_name}.h").read_text() + source = _find_base_header(op_name).read_text() return set(re.findall(r"std::optional\s+(\w+)", source)) @@ -107,14 +137,31 @@ def _find_vector_tensor_params(op_name): """Return a set of parameter names declared as `std::vector` in the base header. """ - source = (_BASE_DIR / f"{op_name}.h").read_text() + source = _find_base_header(op_name).read_text() return set(re.findall(r"std::vector\s+(\w+)", source)) +def _find_vector_int64_params(op_name): + """Return a set of parameter names declared as `std::vector` in + the base header. + + libclang on systems where the STL headers are not fully indexable + silently falls back to reporting the type as `int` for these params, + which then leaks into the generated bindings as `const int padding` + instead of `const std::vector padding` and breaks the call + to the base operator. Regex-scan the source so the binding's + parameter type comes from the actual declaration. + """ + source = _find_base_header(op_name).read_text() + + return set(re.findall(r"std::vector\s+(\w+)", source)) + + def _generate_pybind11(operator): optional_tensor_params = _find_optional_tensor_params(operator.name) vector_tensor_params = _find_vector_tensor_params(operator.name) + vector_int64_params = _find_vector_int64_params(operator.name) def _is_optional_tensor(arg): if arg.spelling in optional_tensor_params: @@ -131,6 +178,9 @@ def _is_vector_tensor(arg): return "std::vector" in arg.type.spelling and "Tensor" in arg.type.spelling + def _is_vector_int64(arg): + return arg.spelling in vector_int64_params + def _generate_params(node): parts = [] @@ -142,6 +192,8 @@ def _generate_params(node): parts.append(f"std::optional {arg.spelling}") elif _is_vector_tensor(arg): parts.append(f"std::vector {arg.spelling}") + elif _is_vector_int64(arg): + parts.append(f"const std::vector {arg.spelling}") else: param = arg.type.spelling.replace("const Tensor", "py::object").replace( "Tensor", "py::object" @@ -216,16 +268,51 @@ def _generate_call(op_name, call, method=True): f' }}, {py_args_str}py::kw_only(), py::arg("stream") = 0, py::arg("implementation_index") = 0);' ) - return f""" .def("__call__", [](const Self& self, {call_params}) {{ - return static_cast&>(self)({call_args}); + # The first lambda parameter is conventionally named `self`, but + # ATen schemas often have a parameter literally called `self` + # (e.g. 
`pow.Tensor_Scalar_out(Scalar self, Tensor exponent)`), + # so rename to `op` to avoid the collision in the generated code. + return f""" .def("__call__", [](const Self& op, {call_params}) {{ + return static_cast&>(op)({call_args}); }})""" - inits = "\n".join( - _generate_init(constructor) for constructor in operator.constructors - ) - calls = "\n".join(_generate_call(operator.name, call) for call in operator.calls) + def _overload_order_key(node): + """Sort key that places more-specific overloads first. + + Tensor parameters are exposed to pybind as `py::object`, which + accepts any Python value and only fails inside + `TensorFromPybind11Handle`. When a class has both Tensor and + scalar overloads, pybind's overload-resolver tries them in + registration order and stops at the first that does not raise, + so the scalar overload must be registered first; otherwise the + permissive Tensor signature swallows scalar calls and aborts at + runtime. + """ + object_like = 0 + total = 0 + + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + + total += 1 + + if ( + _is_optional_tensor(arg) + or _is_vector_tensor(arg) + or "Tensor" in arg.type.spelling + ): + object_like += 1 + + return (object_like, -total) + + constructors = sorted(operator.constructors, key=_overload_order_key) + operator_calls = sorted(operator.calls, key=_overload_order_key) + + inits = "\n".join(_generate_init(constructor) for constructor in constructors) + calls = "\n".join(_generate_call(operator.name, call) for call in operator_calls) callers = "\n".join( - _generate_call(operator.name, call, method=False) for call in operator.calls + _generate_call(operator.name, call, method=False) for call in operator_calls ) pascal_case_op_name = _snake_to_pascal(op_name) @@ -253,7 +340,11 @@ def _generate_call(op_name, call, method=True): {inits} {calls} .def_static("active_implementation_indices", [](const std::string& device) {{ - return Self::active_implementation_indices(DeviceTypeFromString(device)); + auto dev_type = TryDeviceTypeFromString(device); + if (!dev_type.has_value()) {{ + return std::vector{{}}; + }} + return Self::active_implementation_indices(*dev_type); }}) .def_static("clear_cache", &Self::clear_cache); @@ -268,8 +359,17 @@ def _generate_call(op_name, call, method=True): def _generate_legacy_c(operator, paths): def _generate_source(operator): + def _to_include_path(path): + text = str(path) + + for prefix in ("src/", "generated/"): + if text.startswith(prefix): + return text[len(prefix) :] + + return text + impl_includes = "\n".join( - f'#include "{str(path).removeprefix("src/")}"' for path in paths + f'#include "{_to_include_path(path)}"' for path in paths ) return f"""#include "../../handle.h" @@ -444,6 +544,10 @@ def _snake_to_pascal(snake_str): return "".join(word.capitalize() for word in snake_str.split("_")) +def _matches_scan_dir(impl_path, scan_dirs): + return any(part in scan_dirs for part in impl_path.parts) + + def _get_all_ops(devices, with_torch=False): scan_dirs = set(devices) @@ -452,20 +556,45 @@ def _get_all_ops(devices, with_torch=False): ops = {} - for file_path in _BASE_DIR.iterdir(): - if not file_path.is_file(): - continue + base_dirs = [_BASE_DIR] + + # Only pull in the auto-generated torch op bases when the build is + # actually compiling them (`--with-torch`). 
Otherwise a stale + # `generated/` left over from a previous configure (or rsynced into + # a CI container) would cause `ops.cc` to include base headers for + # ops that have no compiled implementation, breaking the build. + if with_torch and _GENERATED_BASE_DIR.exists(): + base_dirs.append(_GENERATED_BASE_DIR) + + impl_roots = [_SRC_DIR] + + if with_torch and (_GENERATION_DIR / "torch").exists(): + impl_roots.append(_GENERATION_DIR) - op_name = file_path.stem + for base_dir in base_dirs: + for file_path in base_dir.iterdir(): + if not file_path.is_file(): + continue - ops[op_name] = [] + op_name = file_path.stem - for file_path in _SRC_DIR.rglob("*.h"): - if file_path.parent.parent.parent.name not in scan_dirs: + # Hand-written `src/base/` is scanned first; the generated + # tree never overrides an already-known op. + if op_name in ops: continue - if f"class Operator<{_snake_to_pascal(op_name)}" in file_path.read_text(): - ops[op_name].append(file_path) + ops[op_name] = [] + + for impl_root in impl_roots: + for impl_path in impl_root.rglob("*.h"): + if not _matches_scan_dir(impl_path, scan_dirs): + continue + + if ( + f"class Operator<{_snake_to_pascal(op_name)}" + in impl_path.read_text() + ): + ops[op_name].append(impl_path) return ops @@ -489,9 +618,14 @@ def _get_all_ops(devices, with_torch=False): args = parser.parse_args() - _BINDINGS_DIR.mkdir(parents=True, exist_ok=True) - _GENERATED_SRC_DIR.mkdir(parents=True, exist_ok=True) - _INCLUDE_DIR.mkdir(parents=True, exist_ok=True) + # Wipe previous outputs so files for ops that have since been removed + # from the active set (e.g. when toggling `--with-torch`) do not linger + # and get globbed by a later build. + for d in (_BINDINGS_DIR, _GENERATED_SRC_DIR, _INCLUDE_DIR): + if d.exists(): + shutil.rmtree(d) + + d.mkdir(parents=True) ops_json = pathlib.Path("ops.json") diff --git a/scripts/torch_ops.yaml b/scripts/torch_ops.yaml new file mode 100644 index 00000000..37dccc63 --- /dev/null +++ b/scripts/torch_ops.yaml @@ -0,0 +1,470 @@ +# Allowlist of ATen ops to expose as InfiniOps operators. +# +# Auto-discovered: every base op name with at least one parsable +# `.out` overload using the supported type vocabulary. The +# generator emits one InfiniOps wrapper per overload, so this +# file lists ~390 base names but produces 500+ wrappers. +# +# To exclude an op, comment out its line. Ops whose hand-written +# `src/base/.h` signature does not match the ATen-derived one +# (currently `add`, `linear`, `matmul`, `mul` — they pre-date this +# codegen and use a different parameter shape) must stay excluded: +# the generator skips emitting their base, but would still emit a +# torch backend declaring `operator()` with the ATen signature, and +# that override would not compile against the hand-written base. 
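+#
+# For quick local iteration, the generator also accepts an explicit subset
+# on the command line instead of this file (see `--ops` and
+# `--pytorch-version` in `scripts/generate_torch_ops.py`), e.g.:
+#
+#   python scripts/generate_torch_ops.py --ops abs clamp pow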
+ +- abs +- absolute +- acos +- acosh +- adaptive_avg_pool2d +- adaptive_avg_pool3d +- adaptive_avg_pool3d_backward +- adaptive_max_pool2d +- adaptive_max_pool2d_backward +- adaptive_max_pool3d +- adaptive_max_pool3d_backward +- addbmm +- addcdiv +- addcmul +- addmm +- addmv +- addr +- all +- amax +- amin +- aminmax +- angle +- any +- arange +- arccos +- arccosh +- arcsin +- arcsinh +- arctan +- arctan2 +- arctanh +- argmax +- argmin +- asin +- asinh +- atan +- atan2 +- atanh +- avg_pool2d +- avg_pool2d_backward +- avg_pool3d +- avg_pool3d_backward +- baddbmm +- batch_norm_elemt +- bernoulli +- binary_cross_entropy +- binary_cross_entropy_backward +- bitwise_and +- bitwise_left_shift +- bitwise_not +- bitwise_or +- bitwise_right_shift +- bitwise_xor +- bmm +- bucketize +- ceil +- cholesky +- cholesky_inverse +- cholesky_solve +- clamp +- clamp_max +- clamp_min +- clip +- col2im +- complex +- conj_physical +- copysign +- cos +- cosh +- cross +- cudnn_convolution +- cummax +- cummin +- cumprod +- cumsum +- deg2rad +- diag +- diff +- digamma +- div +- divide +- dot +- elu +- elu_backward +- empty +- eq +- erf +- erfc +- erfinv +- exp +- exp2 +- expm1 +- eye +- fft_fft +- fft_fft2 +- fft_fftfreq +- fft_fftn +- fft_hfft +- fft_hfft2 +- fft_hfftn +- fft_ifft +- fft_ifft2 +- fft_ifftn +- fft_ihfft +- fft_ihfft2 +- fft_ihfftn +- fft_irfft +- fft_irfft2 +- fft_irfftn +- fft_rfft +- fft_rfft2 +- fft_rfftfreq +- fft_rfftn +- fix +- float_power +- floor +- floor_divide +- fmax +- fmin +- fmod +- frac +- fractional_max_pool2d +- fractional_max_pool2d_backward +- fractional_max_pool3d +- fractional_max_pool3d_backward +- frexp +- frobenius_norm +- full +- gather +- gcd +- ge +- gelu +- gelu_backward +- geqrf +- ger +- glu +- glu_backward +- greater +- greater_equal +- gt +- hardshrink +- hardshrink_backward +- hardsigmoid +- hardsigmoid_backward +- hardswish +- hardtanh +- hardtanh_backward +- heaviside +- histc +- histogram +- hspmm +- huber_loss +- huber_loss_backward +- hypot +- i0 +- igamma +- igammac +- im2col +- index +- index_add +- index_copy +- index_reduce +- index_select +- inner +- inverse +- isin +- isneginf +- isposinf +- kron +- kthvalue +- lcm +- ldexp +- le +- leaky_relu +- leaky_relu_backward +- lerp +- less +- less_equal +- lgamma +- linalg_cholesky +- linalg_cholesky_ex +- linalg_cond +- linalg_cross +- linalg_det +- linalg_eig +- linalg_eigh +- linalg_eigvals +- linalg_eigvalsh +- linalg_householder_product +- linalg_inv +- linalg_inv_ex +- linalg_ldl_factor +- linalg_ldl_factor_ex +- linalg_ldl_solve +- linalg_lstsq +- linalg_lu +- linalg_lu_factor +- linalg_lu_factor_ex +- linalg_lu_solve +- linalg_matmul +- linalg_matrix_norm +- linalg_matrix_power +- linalg_matrix_rank +- linalg_norm +- linalg_pinv +- linalg_qr +- linalg_slogdet +- linalg_solve +- linalg_solve_ex +- linalg_solve_triangular +- linalg_svd +- linalg_svdvals +- linalg_tensorinv +- linalg_tensorsolve +- linalg_vecdot +- linalg_vector_norm +- linspace +- log +- log10 +- log1p +- log2 +- log_sigmoid +- log_sigmoid_backward +- log_sigmoid_forward +- log_softmax +- logaddexp +- logaddexp2 +- logcumsumexp +- logical_and +- logical_not +- logical_or +- logical_xor +- logit +- logit_backward +- logspace +- logsumexp +- lt +- lu_solve +- lu_unpack +- masked_select +- matrix_power +- max +- max_pool2d_with_indices +- max_pool2d_with_indices_backward +- max_pool3d_with_indices +- max_pool3d_with_indices_backward +- max_unpool2d +- max_unpool3d +- maximum +- mean +- median +- min +- minimum +- mish +- mkldnn_adaptive_avg_pool2d 
+- mm +- mode +- mse_loss +- mse_loss_backward +- msort +- multi_margin_loss +- multi_margin_loss_backward +- multilabel_margin_loss +- multilabel_margin_loss_backward +- multilabel_margin_loss_forward +- multinomial +- multiply +- mv +- mvlgamma +- nan_to_num +- nanmean +- nanmedian +- nanquantile +- nansum +- narrow_copy +- native_batch_norm +- ne +- neg +- negative +- nextafter +- nll_loss +- nll_loss2d +- nll_loss2d_backward +- nll_loss2d_forward +- nll_loss_backward +- nll_loss_forward +- nonzero +- nonzero_static +- norm +- normal +- not_equal +- nuclear_norm +- ones +- orgqr +- ormqr +- outer +- polar +- polygamma +- pow +- prod +- qr +- quantile +- rad2deg +- rand +- randint +- randn +- randperm +- range +- reciprocal +- reflection_pad1d +- reflection_pad1d_backward +- reflection_pad2d +- reflection_pad2d_backward +- reflection_pad3d +- reflection_pad3d_backward +- remainder +- renorm +- replication_pad1d +- replication_pad1d_backward +- replication_pad2d +- replication_pad2d_backward +- replication_pad3d +- replication_pad3d_backward +- round +- rrelu_with_noise +- rsqrt +- scatter +- scatter_add +- scatter_reduce +- searchsorted +- sgn +- sigmoid +- sigmoid_backward +- sign +- signbit +- silu +- silu_backward +- sin +- sinc +- sinh +- slogdet +- slow_conv3d +- slow_conv3d_forward +- slow_conv_transpose2d +- slow_conv_transpose3d +- smooth_l1_loss +- smooth_l1_loss_backward +- soft_margin_loss +- soft_margin_loss_backward +- softmax +- softplus +- softplus_backward +- softshrink +- softshrink_backward +- sort +- sparse_sampled_addmm +- special_airy_ai +- special_bessel_j0 +- special_bessel_j1 +- special_bessel_y0 +- special_bessel_y1 +- special_chebyshev_polynomial_t +- special_chebyshev_polynomial_u +- special_chebyshev_polynomial_v +- special_chebyshev_polynomial_w +- special_digamma +- special_entr +- special_erf +- special_erfc +- special_erfcx +- special_erfinv +- special_exp2 +- special_expit +- special_expm1 +- special_gammainc +- special_gammaincc +- special_gammaln +- special_hermite_polynomial_h +- special_hermite_polynomial_he +- special_i0 +- special_i0e +- special_i1 +- special_i1e +- special_laguerre_polynomial_l +- special_legendre_polynomial_p +- special_log1p +- special_log_ndtr +- special_logit +- special_logsumexp +- special_modified_bessel_i0 +- special_modified_bessel_i1 +- special_modified_bessel_k0 +- special_modified_bessel_k1 +- special_multigammaln +- special_ndtr +- special_ndtri +- special_polygamma +- special_psi +- special_round +- special_scaled_modified_bessel_k0 +- special_scaled_modified_bessel_k1 +- special_shifted_chebyshev_polynomial_t +- special_shifted_chebyshev_polynomial_u +- special_shifted_chebyshev_polynomial_v +- special_shifted_chebyshev_polynomial_w +- special_sinc +- special_spherical_bessel_j0 +- special_xlog1py +- special_xlogy +- special_zeta +- split_copy +- split_with_sizes_copy +- sqrt +- square +- sspaddmm +- std +- sub +- subtract +- sum +- svd +- take +- take_along_dim +- tan +- tanh +- tanh_backward +- tensordot +- thnn_conv2d +- threshold +- threshold_backward +- topk +- triangular_solve +- tril +- triu +- true_divide +- trunc +- unbind_copy +- upsample_bicubic2d +- upsample_bicubic2d_backward +- upsample_bilinear2d +- upsample_bilinear2d_backward +- upsample_linear1d +- upsample_linear1d_backward +- upsample_nearest1d +- upsample_nearest1d_backward +- upsample_nearest2d +- upsample_nearest2d_backward +- upsample_nearest3d +- upsample_nearest3d_backward +- upsample_trilinear3d +- upsample_trilinear3d_backward +- var +- 
vdot +- where +- xlogy +- zeros diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ce888b4b..c185bf2a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -252,11 +252,63 @@ if(WITH_ASCEND) endif() if(WITH_TORCH) - file(GLOB_RECURSE TORCH_SOURCES CONFIGURE_DEPENDS "torch/*.cc" "torch/*.cpp") + # Auto-generate ATen-backed operator wrappers from `scripts/torch_ops.yaml`. + # The script writes into `${PROJECT_SOURCE_DIR}/generated/` (gitignored), + # which we then glob below alongside any hand-written torch sources. + find_package(Python COMPONENTS Interpreter REQUIRED) + + # Pin codegen to the locally installed torch version so vendor + # forks (Cambricon's `torch_mlu` 2.1.0, etc.) get a schema whose + # `at::_out` overloads match the headers they ship. Without + # this, the codegen targets v2.4.0 and the build fails on older + # forks with no-known-conversion errors (e.g. `at::all_out`'s + # `int64_t dim` vs `OptionalIntArrayRef dim`). + execute_process( + COMMAND ${_TORCH_PYTHON} -c + "import torch; print('v' + torch.__version__.split('+')[0])" + OUTPUT_VARIABLE _torch_version_tag + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE _torch_version_result + ) + if(NOT _torch_version_result EQUAL 0 OR NOT _torch_version_tag) + set(_torch_version_tag "v2.4.0") + endif() + message(STATUS "Codegen schema: PyTorch ${_torch_version_tag}") + + execute_process( + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_torch_ops.py + --pytorch-version ${_torch_version_tag} + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + RESULT_VARIABLE _torch_ops_result + ) + if(NOT _torch_ops_result EQUAL 0) + message(FATAL_ERROR "Generating torch op wrappers - failed") + endif() + message(STATUS "Generating torch op wrappers - done") + + file(GLOB_RECURSE TORCH_SOURCES CONFIGURE_DEPENDS + "torch/*.cc" "torch/*.cpp" + "${PROJECT_SOURCE_DIR}/generated/torch/*.cc" + "${PROJECT_SOURCE_DIR}/generated/torch/*.cpp" + ) target_compile_definitions(infiniops PUBLIC WITH_TORCH=1) target_link_libraries(infiniops PUBLIC ${TORCH_LIBRARIES}) - target_include_directories(infiniops PUBLIC ${TORCH_INCLUDE_DIRS}) + target_include_directories(infiniops PUBLIC + ${TORCH_INCLUDE_DIRS} + ${PROJECT_SOURCE_DIR}/generated + ) + + # Each generated `.cc` instantiates `at::_out(...)`, which + # pulls in roughly 0.5-1 GB of ATen template metaprogramming. At + # ninja's default parallelism (one job per CPU), a build with 451 + # ops can blow past 30 GB of RSS and the OOM killer drops + # `cc1plus`. Cap the heavyweight torch sources to 4 concurrent + # compilations via a Ninja job pool; the rest of the build keeps + # full parallelism. + if(CMAKE_GENERATOR MATCHES "Ninja") + set_property(GLOBAL APPEND PROPERTY JOB_POOLS torch_compile=4) + endif() if(WITH_METAX OR WITH_MOORE) # Vendor compilers (`mxcc`/`mcc`) cannot compile vendor-forked `torch` @@ -275,7 +327,13 @@ if(WITH_TORCH) # Vendor-specific defines required by forked `torch` headers. 
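+    # The `WITH_*` defines mirror `target_compile_definitions` for this
+    # path, which compiles outside the `infiniops` target: the generated
+    # torch sources guard their explicit instantiations behind
+    # `#ifdef WITH_<DEVICE>`.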
    set(_torch_extra_flags "")
     if(WITH_METAX)
-      list(APPEND _torch_extra_flags "-DUSE_MACA=1")
+      list(APPEND _torch_extra_flags "-DUSE_MACA=1" "-DWITH_METAX=1")
+    endif()
+    if(WITH_MOORE)
+      list(APPEND _torch_extra_flags "-DWITH_MOORE=1")
+    endif()
+    if(WITH_CPU)
+      list(APPEND _torch_extra_flags "-DWITH_CPU=1")
     endif()
     if(DEFINED TORCH_CXX11_ABI)
       list(APPEND _torch_extra_flags "-D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
@@ -297,11 +355,13 @@ if(WITH_TORCH)
         COMMAND ${SYSTEM_CXX} -std=c++17 -fPIC -O2
           "-I${CMAKE_CURRENT_SOURCE_DIR}"
+          "-I${PROJECT_SOURCE_DIR}/generated"
           ${_torch_include_flags} ${_torch_extra_flags}
           -c "${_src}" -o "${_obj}"
         DEPENDS "${_src}"
         COMMENT "Compiling ${_rel} with system C++ compiler"
+        JOB_POOL torch_compile
       )
       list(APPEND TORCH_OBJECT_FILES "${_obj}")
     endforeach()
@@ -310,7 +370,27 @@ if(WITH_TORCH)
       PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE)
     target_sources(infiniops PRIVATE ${TORCH_OBJECT_FILES})
   else()
-    target_sources(infiniops PRIVATE ${TORCH_SOURCES})
+    # Build the heavy torch sources as their own object library so
+    # the Ninja `torch_compile` job pool throttles only those
+    # compilations and the rest of `infiniops` keeps full
+    # parallelism. Inherit infiniops's compile-time settings via
+    # generator expressions (linking would create a cyclic
+    # dependency since infiniops then absorbs the object files).
+    add_library(infiniops_torch_objs OBJECT ${TORCH_SOURCES})
+    target_include_directories(infiniops_torch_objs PRIVATE
+      $<TARGET_PROPERTY:infiniops,INCLUDE_DIRECTORIES>
+      ${TORCH_INCLUDE_DIRS}
+      ${PROJECT_SOURCE_DIR}/generated)
+    target_compile_definitions(infiniops_torch_objs PRIVATE
+      $<TARGET_PROPERTY:infiniops,COMPILE_DEFINITIONS>)
+    target_compile_options(infiniops_torch_objs PRIVATE
+      $<TARGET_PROPERTY:infiniops,COMPILE_OPTIONS>)
+    if(CMAKE_GENERATOR MATCHES "Ninja")
+      set_target_properties(infiniops_torch_objs
+        PROPERTIES JOB_POOL_COMPILE torch_compile)
+    endif()
+    target_sources(infiniops PRIVATE
+      $<TARGET_OBJECTS:infiniops_torch_objs>)
   endif()
 endif()

@@ -391,4 +471,11 @@ if(GENERATE_PYTHON_BINDINGS)

   file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/__init__.py" "")
   install(FILES "${CMAKE_CURRENT_BINARY_DIR}/__init__.py" DESTINATION .)
+
+  if(WITH_TORCH)
+    # Ship the per-op metadata alongside the bindings so the unified
+    # torch op test can discover what to exercise at runtime.
+    install(FILES "${PROJECT_SOURCE_DIR}/generated/torch_ops_metadata.json"
+      DESTINATION .)
+  endif()
 endif()
diff --git a/src/native/cuda/ops/swiglu/kernel.cuh b/src/native/cuda/ops/swiglu/kernel.cuh
index a782b6f6..174439cf 100644
--- a/src/native/cuda/ops/swiglu/kernel.cuh
+++ b/src/native/cuda/ops/swiglu/kernel.cuh
@@ -7,10 +7,15 @@

 namespace infini::ops {

+namespace detail {
+
 // Optimized sigmoid function with support for FP16 and BF16 types.
 // TODO: The unified FP16/BF16 branch uses `Caster` and scalar float
 // arithmetic instead of native vectorized intrinsics (e.g. `h2rcp`,
 // `__hmul2`). Profile and restore specialized paths if needed.
+//
+// Lives in `detail::` so it does not collide with the auto-generated
+// `infini::ops::Sigmoid` operator class emitted by `generate_torch_ops.py`.
 template
 __device__ __forceinline__ T Sigmoid(const T& x) {
   if constexpr (IsFP16 || IsBFloat16) {
@@ -24,6 +29,8 @@ __device__ __forceinline__ T Sigmoid(const T& x) {
   }
 }

+}  // namespace detail
+
 // SwiGLU(x, gate) = Swish(x) * gate = (x * sigmoid(x)) * gate.
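 // For reference, the same computation as scalar host code (an illustrative
 // sketch only — the kernel below additionally routes FP16/BF16 through
 // `Caster` and scalar float arithmetic, per the TODO above):
 //
 //   float sigmoid(float x) { return 1.0f / (1.0f + expf(-x)); }
 //   out[i] = (gate[i] * sigmoid(gate[i])) * up[i];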
template __global__ void SwigluKernel(T* __restrict__ out, const T* __restrict__ a, @@ -70,9 +77,10 @@ __global__ void SwigluKernel(T* __restrict__ out, const T* __restrict__ a, out[out_idx] = Caster::template Cast( __fmul_rn(__fmul_rn(gatef, sigf), upf)); } else if constexpr (std::is_same_v) { - out[out_idx] = __fmul_rn(__fmul_rn(gate, Sigmoid(gate)), up); + out[out_idx] = + __fmul_rn(__fmul_rn(gate, detail::Sigmoid(gate)), up); } else { - out[out_idx] = gate * Sigmoid(gate) * up; + out[out_idx] = gate * detail::Sigmoid(gate) * up; } } } diff --git a/src/operator.h b/src/operator.h index 83fc4ec2..607fb76a 100644 --- a/src/operator.h +++ b/src/operator.h @@ -67,6 +67,11 @@ std::vector ListToVector(List) { return {static_cast(values)...}; } +template +bool ListContains(ValueType value, List) { + return ((value == static_cast(values)) || ...); +} + } // namespace infini::ops::detail template <> @@ -213,6 +218,10 @@ class Operator : public OperatorBase { static std::vector active_implementation_indices( Device::Type dev_type) { + if (!detail::ListContains(dev_type, ActiveDevices{})) { + return {}; + } + std::vector result; DispatchFunc>( dev_type, diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index f13d3116..0f6332d8 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -41,6 +41,43 @@ inline Device::Type DeviceTypeFromString(const std::string& name) { return Device::TypeFromString(name); } +// Returns `nullopt` rather than aborting when the name does not resolve. +// Used by generated pybind bindings to query implementation indices for +// devices an op may not support, without crashing the process. +template +inline std::optional TryDeviceTypeFromString( + const std::string& name) { + static const auto kTorchNameToTypes{ + detail::BuildTorchNameMap(ActiveDevices{})}; + + auto it{kTorchNameToTypes.find(name)}; + + if (it != kTorchNameToTypes.cend()) { + return it->second; + } + + static const std::unordered_map kPlatformNames{ + {"cpu", Device::Type::kCpu}, + {"nvidia", Device::Type::kNvidia}, + {"cambricon", Device::Type::kCambricon}, + {"ascend", Device::Type::kAscend}, + {"metax", Device::Type::kMetax}, + {"moore", Device::Type::kMoore}, + {"iluvatar", Device::Type::kIluvatar}, + {"kunlun", Device::Type::kKunlun}, + {"hygon", Device::Type::kHygon}, + {"qy", Device::Type::kQy}, + }; + + auto platform_it{kPlatformNames.find(name)}; + + if (platform_it != kPlatformNames.cend()) { + return platform_it->second; + } + + return std::nullopt; +} + inline Tensor TensorFromPybind11Handle(py::handle obj) { auto data{ reinterpret_cast(obj.attr("data_ptr")().cast())}; diff --git a/tests/conftest.py b/tests/conftest.py index 86d01c24..875f33dc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -301,7 +301,16 @@ def pytest_pyfunc_call(pyfuncitem): rtol = payload.rtol atol = payload.atol - assert torch.allclose(output, expected, rtol=rtol, atol=atol) + # `torch.allclose` rejects `bool` dtypes — use `torch.equal` for + # non-floating outputs (bool, int) so comparison ops work. Pass + # `equal_nan=True` so NaN-in-both-positions (common for special + # functions fed out-of-domain inputs) does not fail the test. 
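+    # Illustrative example of what this guards: a comparison op such as `ne`
+    # yields `torch.bool` outputs, which `torch.allclose` rejects outright,
+    # while `torch.equal(output, expected)` compares them exactly.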
+ if output.dtype.is_floating_point: + assert torch.allclose( + output, expected, rtol=rtol, atol=atol, equal_nan=True + ) + else: + assert torch.equal(output, expected) return True diff --git a/tests/test_torch_ops.py b/tests/test_torch_ops.py new file mode 100644 index 00000000..68c321e7 --- /dev/null +++ b/tests/test_torch_ops.py @@ -0,0 +1,467 @@ +"""Unified test for every operator emitted by `generate_torch_ops.py`. + +The generator writes `generated/torch_ops_metadata.json` listing every op +with full per-parameter info (`name`, `type`, `is_tensor`, `is_out`). +A single parametrized test reads that metadata, builds inputs from the +parameter list, calls the InfiniOps wrapper and the torch reference, and +compares each output tensor. Adding an op to `scripts/torch_ops.yaml` +extends coverage with no test changes. +""" + +import json +import pathlib +import re + +import infini.ops +import pytest +import torch + +from tests.utils import randn_strided + +# PyTorch backends are emitted at this slot — see `_PYTORCH_SLOT` in +# `scripts/generate_torch_ops.py`. +_PYTORCH_SLOT = 8 + +_INSTALLED_METADATA_PATH = ( + pathlib.Path(infini.ops.__file__).resolve().with_name("torch_ops_metadata.json") +) +_SOURCE_METADATA_PATH = ( + pathlib.Path(__file__).resolve().parent.parent + / "generated" + / "torch_ops_metadata.json" +) + +_METADATA_PATH = next( + ( + path + for path in (_INSTALLED_METADATA_PATH, _SOURCE_METADATA_PATH) + if path.exists() + ), + _SOURCE_METADATA_PATH, +) +_METADATA = ( + json.loads(_METADATA_PATH.read_text()) if _METADATA_PATH.exists() else {"ops": []} +) + +_SHAPES = ( + (13, 4), + (13, 4, 4), + (4, 4, 5632), +) + +_DTYPES = ( + (torch.float32, 1e-5, 1e-5), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 1e-2, 1e-2), +) + +# Op-specific input shapes for matrix ops (`mm` etc.) which cannot use +# `randn_strided(shape)` for both inputs. The tuple is one shape per +# tensor input, in YAML order. +_TENSOR_SHAPES = { + "mm": ((8, 16), (16, 12)), + "bmm": ((4, 8, 16), (4, 16, 12)), + "matmul": ((8, 16), (16, 12)), + "dot": ((16,), (16,)), + "vdot": ((16,), (16,)), + "mv": ((8, 16), (16,)), + "inner": ((8, 16), (8, 16)), + "outer": ((8,), (12,)), + "ger": ((8,), (12,)), + "kron": ((3, 4), (2, 3)), +} + +# Per-(op, param-name) values for non-tensor inputs. Lookup falls back +# to a type-based default if no entry exists. +_SCALAR_VALUES = { + ("clamp_min", "min"): -0.5, + ("clamp_max", "max"): 0.5, + ("leaky_relu", "negative_slope"): 0.01, + ("hardshrink", "lambd"): 0.5, + ("softshrink", "lambd"): 0.5, + ("mvlgamma", "p"): 2, + ("prod", "dim"): 0, + ("cumsum", "dim"): 0, + ("cumprod", "dim"): 0, + ("logcumsumexp", "dim"): 0, + ("cummax", "dim"): 0, + ("cummin", "dim"): 0, + ("softmax", "dim"): -1, + ("log_softmax", "dim"): -1, + ("threshold", "threshold"): 0.0, + ("threshold", "value"): 0.0, + ("hardtanh", "min_val"): -1.0, + ("hardtanh", "max_val"): 1.0, + ("softplus", "beta"): 1.0, + ("softplus", "threshold"): 20.0, + ("elu", "alpha"): 1.0, + ("elu", "scale"): 1.0, + ("elu", "input_scale"): 1.0, + ("sub", "alpha"): 1.0, + ("addcmul", "value"): 1.0, + ("addcdiv", "value"): 1.0, + # `str reduce` modes accepted by the corresponding ATen kernels. + ("index_reduce", "reduce"): "amax", + ("scatter_reduce", "reduce"): "amax", + ("scatter_reduce_two", "reduce"): "amax", + # `int dim` for ops where 0 is a safe choice for our test shapes. 
+ ("kthvalue_values", "k"): 1, + ("kthvalue_values", "dim"): 0, + ("mode_values", "dim"): 0, +} + +_TYPE_DEFAULTS = {"int": 0, "SymInt": 0, "bool": False, "str": "none"} + +# Mirrors `kStringToDataType` in `src/data_type.h`. Any tensor passed to +# an InfiniOps op must have one of these dtypes; others (`bool`, complex, +# quantised types) abort the process inside `DataTypeFromString`. Some +# vendor torch forks lag behind upstream and lack `uint16` / `uint32` / +# `uint64` (added in PyTorch 2.3); resolve them lazily and keep the +# attributes that actually exist. +_SUPPORTED_DTYPE_NAMES = ( + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "bfloat16", + "float32", + "float64", +) +_SUPPORTED_DTYPES = frozenset( + getattr(torch, name) for name in _SUPPORTED_DTYPE_NAMES if hasattr(torch, name) +) + + +_LIST_SIZE_RE = re.compile(r"\[(\d+)\]") + + +def _list_default(aten_type): + """Default value for a required `int[N]` / `SymInt[N]` param. Most + such params name a `dim` or `kernel_size`; `[0]` works for `dim` and + causes `kernel_size`-style ops to fail their reference call cleanly, + which the test then skips.""" + size_match = _LIST_SIZE_RE.search(aten_type) + n = int(size_match.group(1)) if size_match else 1 + return [0] * n + + +# Errors emitted by upstream PyTorch and vendor-forked variants for +# unsupported (op, dtype, device) combinations. We skip rather than fail +# on these — the gap is in PyTorch, not InfiniOps. +_VENDOR_SKIP_PATTERNS = ( + "not implemented for", # upstream PyTorch + "CNNL_STATUS_BAD_PARAM", # `torch_mlu` (Cambricon) + "MUDNN failed", # `torch_musa` (Moore) + "Could not run", # missing dispatcher entry on this backend + "don't support tensor dtype", # `torch_mlu` dtype check + "result requires dtype", # output dtype mismatch (e.g. `float_power`) + # ATen kernels for some loss ops (`mse_loss`, `huber_loss`, …) use + # the `out` buffer as intermediate scratch and resize it before the + # final reduction. Our `from_blob` outputs are non-resizable, so + # the kernel aborts the call with this message. Skip these — the + # zero-copy wrapper can't drive that codepath. + "Trying to resize storage that is not resizable", +) + +# Random-sampling ops never match a fresh torch reference call — +# they consume RNG state and return different draws. Skip rather +# than try to align the two PRNG streams. +_RANDOM_OPS = frozenset( + { + "bernoulli", + "multinomial", + "normal", + "rand", + "randn", + "randint", + "randperm", + "rrelu_with_noise", + } +) + +# Ops whose vendor kernel hangs indefinitely on at least one platform +# (`mode` on `torch_musa` for MUSA tensors). Skip until the vendor +# fixes the underlying kernel — letting the CI block on a hanging +# kernel costs ~30 min per platform run. +_VENDOR_HANG_OPS = frozenset( + { + "mode", + } +) + +# Ops where the ATen `_out` schema and the Python reference (`torch.`, +# `torch.nn.functional.`) diverge in positional-argument ordering, so +# the harness's purely-positional reference call lands an InfiniOps +# argument on the wrong reference parameter. E.g. ATen +# `binary_cross_entropy_out(self, target, weight=None, reduction=Mean, out)` +# has `weight` between `target` and `reduction`; with `weight` hidden as +# `Tensor?`, our visible signature is `(self, target, reduction, out)`, +# but `torch.nn.functional.binary_cross_entropy(input, target, weight, +# reduction)` reads our `reduction:int` as `weight:Tensor` and crashes +# inside `weight.size()`. 
The InfiniOps wrapper itself is fine; only +# the harness's reference call is wrong. +_REFERENCE_SIGNATURE_MISMATCH_OPS = frozenset( + { + "binary_cross_entropy", + "binary_cross_entropy_backward", + } +) + +# Full reductions with low-precision inputs diverge between the functional +# (`torch.(x)`) and `_out` paths because of intermediate-precision +# choices we cannot align from outside ATen. +_LARGE_REDUCTION_OPS = frozenset( + {"sum", "mean", "nansum", "nanmean", "prod", "std", "var"} +) + +# Ops with input-domain `TORCH_CHECK` macros that fire as device-side +# `assert` on CUDA when our generic random fp32 inputs fall outside the +# expected range. The Python-side `RuntimeError` is catchable, but the +# CUDA context is left poisoned and every subsequent test errors at +# setup. Skip these on cuda; the CPU path raises a clean exception +# that the existing harness already handles. +_DEVICE_ASSERTING_OPS = frozenset( + { + "binary_cross_entropy", # requires inputs in [0, 1] + "multi_margin_loss", + "multilabel_margin_loss", + "nll_loss", + "nll_loss2d", + # cuDNN paths divide by `kernel_size`/`stride` and SIGFPE on the + # `[0, 0]` defaults our harness substitutes for required `int[N]` + # parameters. + "cudnn_convolution", + "slow_conv3d", + "slow_conv_transpose2d", + "slow_conv_transpose3d", + "thnn_conv2d", + "im2col", + "col2im", + "max_unpool2d", + "max_unpool3d", + "reflection_pad1d", + "reflection_pad2d", + "reflection_pad3d", + "replication_pad1d", + "replication_pad2d", + "replication_pad3d", + "upsample_bicubic2d", + "upsample_bilinear2d", + "upsample_linear1d", + "upsample_nearest1d", + "upsample_nearest2d", + "upsample_nearest3d", + "upsample_trilinear3d", + "avg_pool2d", + "avg_pool3d", + "max_pool2d_with_indices", + "max_pool3d_with_indices", + "adaptive_max_pool2d", + "adaptive_max_pool3d", + "adaptive_avg_pool2d", + "adaptive_avg_pool3d", + } +) + + +def _torch_func(op_name): + """Resolve the reference function across `torch`, `torch.special`, + and `torch.nn.functional`. 
`special_` falls through to + `torch.special.` with the prefix stripped.""" + candidates = [ + (torch, op_name), + (torch.special, op_name), + (torch.nn.functional, op_name), + ] + if op_name.startswith("special_"): + candidates.append((torch.special, op_name.removeprefix("special_"))) + for namespace, attr in candidates: + func = getattr(namespace, attr, None) + if func is not None: + return func + pytest.skip(f"no reference function for `{op_name}` in PyTorch") + + +def _pascal(snake_name): + return "".join(part.capitalize() for part in snake_name.split("_")) + + +def _skip_if_not_active(op_name, device): + op_class = getattr(infini.ops, _pascal(op_name), None) + if op_class is None: + pytest.skip(f"`{op_name}` class not exposed on this build") + if _PYTORCH_SLOT not in op_class.active_implementation_indices(device): + pytest.skip(f"`{op_name}` slot {_PYTORCH_SLOT} not active on `{device}`") + + +def _skip_low_precision_reduction(op_name, dtype, device): + if op_name in _LARGE_REDUCTION_OPS: + if dtype in (torch.float16, torch.bfloat16): + pytest.skip(f"`{op_name}` precision diverges on fp16/bf16") + if device == "musa": + pytest.skip(f"`{op_name}` on `torch_musa` diverges from CPU reference") + + +def _build_input_value(op_name, param, shape, dtype, device, tensor_idx): + """Build the value passed to a non-out parameter.""" + if param["is_tensor"]: + per_op = _TENSOR_SHAPES.get(op_name) + tshape = per_op[tensor_idx] if per_op is not None else shape + return randn_strided(tshape, None, dtype=dtype, device=device) + key = (op_name, param["name"]) + if key in _SCALAR_VALUES: + return _SCALAR_VALUES[key] + t = param["type"] + if t.startswith(("int[", "SymInt[")) or t in {"int[]", "SymInt[]"}: + return _list_default(t) + return _TYPE_DEFAULTS.get(t, 0.5) + + +def _call_infini(op_name, *args): + try: + getattr(infini.ops, op_name)(*args, implementation_index=_PYTORCH_SLOT) + except RuntimeError as exc: + if any(p in str(exc) for p in _VENDOR_SKIP_PATTERNS): + pytest.skip(f"`{op_name}` unsupported by torch on this device/dtype") + raise + + +def _assert_close(actual, expected, rtol, atol): + if actual.dtype.is_floating_point: + assert torch.allclose(actual, expected, rtol=rtol, atol=atol, equal_nan=True) + else: + assert torch.equal(actual, expected) + + +def _testable_ops(): + """Filter the metadata down to ops the harness can drive. + + When multiple ATen overloads share the same `aten_name` they all + end up under one InfiniOps class (e.g., `std.dim` and + `std.correction` both map to `Std`), but each has a distinct ATen + `_out` signature. The reference call we synthesize from + `op_meta['params']` only exercises one signature; the secondary + overloads either rely on hidden defaults whose ATen interpretation + differs from the Python wrapper's (`std.correction(self, dim=None, + correction=None, ...)` defaults to a different correction than + `torch.std(self)`), or expose a positional shape that the Python + reference does not accept (e.g., `binary_cross_entropy_out`'s + `reduction:int` lands on the reference's `weight:Tensor?`). Keep + only the first overload of each `aten_name`.""" + seen = set() + keep = [] + + for op in _METADATA.get("ops", []): + if op["aten_name"] in seen: + continue + + seen.add(op["aten_name"]) + keep.append(op) + + return keep + + +def _op_meta_id(op_meta): + if not isinstance(op_meta, dict): + return "empty" + + # Multiple ATen overloads now share a single class name (`scatter` covers + # `scatter.src`, `scatter.value`, `scatter.reduce`, ...) 
— disambiguate + # parametrize ids by appending the visible parameter type signature so + # pytest does not collapse them into duplicate ids. + return op_meta["overload_name"] + + +@pytest.mark.parametrize("op_meta", _testable_ops(), ids=_op_meta_id) +@pytest.mark.parametrize("shape", _SHAPES, ids=lambda s: "x".join(map(str, s))) +@pytest.mark.parametrize(("dtype", "rtol", "atol"), _DTYPES) +def test_op(op_meta, shape, dtype, device, rtol, atol): + op_name = op_meta["name"] + aten_name = op_meta.get("aten_name", op_name) + _skip_if_not_active(op_name, device) + _skip_low_precision_reduction(aten_name, dtype, device) + if aten_name in _RANDOM_OPS: + pytest.skip(f"`{aten_name}` is non-deterministic (independent draws diverge)") + if aten_name in _REFERENCE_SIGNATURE_MISMATCH_OPS: + pytest.skip( + f"`{aten_name}`'s ATen `_out` and Python reference signatures " + "have different positional ordering" + ) + if aten_name in _VENDOR_HANG_OPS: + pytest.skip(f"`{aten_name}` hangs on at least one vendor kernel") + if device == "cuda" and aten_name in _DEVICE_ASSERTING_OPS: + pytest.skip( + f"`{aten_name}` triggers a CUDA device-side assert on random inputs" + ) + + in_params = [p for p in op_meta["params"] if not p["is_out"]] + out_params = [p for p in op_meta["params"] if p["is_out"]] + + # Build inputs in YAML order. + inputs = [] + tensor_idx = 0 + for p in in_params: + inputs.append( + _build_input_value(aten_name, p, shape, dtype, device, tensor_idx) + ) + if p["is_tensor"]: + tensor_idx += 1 + + # Run the reference to discover output shape(s)/dtype(s). + # An op may reject our generic `randn(shape)` input with any of these + # exception types — the gap is in our test harness's input synthesis, + # not in the InfiniOps wrapper. + try: + ref = _torch_func(aten_name)(*inputs) + except ( + RuntimeError, + TypeError, + ValueError, + IndexError, + NotImplementedError, + ) as exc: + pytest.skip(f"`torch.{aten_name}` rejects these inputs: {exc}") + + ref_outs = ref if isinstance(ref, tuple) else (ref,) + if len(ref_outs) != len(out_params): + # The Python-facing function (e.g. `F.adaptive_max_pool2d`) often + # exposes a subset of the ATen `_out` schema's outputs (returning + # only `out`, hiding `indices` behind a `return_indices=True` + # kwarg). Without a per-op map of how to coax the full tuple + # out, skip — the InfiniOps wrapper itself is fine. + pytest.skip( + f"`{aten_name}` reference produced {len(ref_outs)} output(s); " + f"schema declares {len(out_params)}" + ) + + # InfiniOps `DataType` enumerates only int{8,16,32,64}, uint{8,16,32,64}, + # float{16,32,64}, and bfloat16. Tensors with any other torch dtype + # (`bool`, `complex64`, `complex128`, …) abort on `DataTypeFromString`, + # so skip the test rather than crash the process. + tensors = [*ref_outs, *(x for x in inputs if isinstance(x, torch.Tensor))] + unsupported = next( + (t.dtype for t in tensors if t.dtype not in _SUPPORTED_DTYPES), None + ) + if unsupported is not None: + pytest.skip( + f"`{op_name}` uses dtype {unsupported} — not in InfiniOps `DataType`" + ) + + # On CUDA, `torch.empty_like` of a 0-element tensor gives a tensor + # whose `data_ptr()` is unregistered with the device; passing it + # through to the wrapper trips "pointer resides on host memory". 
+ if any(t.numel() == 0 for t in ref_outs): + pytest.skip( + f"`{op_name}` produced 0-element output (unregistered data_ptr on cuda)" + ) + + outs = [torch.empty_like(t) for t in ref_outs] + _call_infini(op_name, *inputs, *outs) + + for actual, expected in zip(outs, ref_outs): + _assert_close(actual, expected, rtol, atol)
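To make the metadata contract concrete, here is a sketch of one entry and the calls the harness above effectively performs for it. The field names follow what `test_op` reads; the concrete values (overload name, parameter order, device) are illustrative assumptions, not copied from a generated `torch_ops_metadata.json`:

    # Hypothetical metadata entry (illustrative values only).
    entry = {
        "name": "multiply",
        "aten_name": "multiply",
        "overload_name": "multiply.out",
        "params": [
            {"name": "self",  "type": "Tensor", "is_tensor": True,  "is_out": False},
            {"name": "other", "type": "Tensor", "is_tensor": True,  "is_out": False},
            {"name": "out",   "type": "Tensor", "is_tensor": True,  "is_out": True},
        ],
    }

    # What `test_op` effectively does with such an entry on a float32 run,
    # assuming a `cpu` device for the sketch:
    a = randn_strided((13, 4), None, dtype=torch.float32, device="cpu")
    b = randn_strided((13, 4), None, dtype=torch.float32, device="cpu")
    expected = torch.multiply(a, b)           # torch reference
    out = torch.empty_like(expected)          # InfiniOps writes here
    infini.ops.multiply(a, b, out, implementation_index=_PYTORCH_SLOT)
    assert torch.allclose(out, expected, rtol=1e-5, atol=1e-5, equal_nan=True)

Adding a new op to `scripts/torch_ops.yaml` produces another entry of this shape, and the parametrized test picks it up with no further test changes.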