Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions test/cpu_only/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,164 @@ def _iter_annotation_types(ann):
if errors:
pytest.fail("\n".join(errors), pytrace=False)

def _ast_has_register_decorator(node: "ast.ClassDef") -> bool:
    """Tell whether the class definition *node* is decorated by the registry.

    Both spellings are recognised: the called form
    ``@OPERATOR_REGISTRY.register()`` and the bare attribute form
    ``@OPERATOR_REGISTRY.register``.  Only the plain name
    ``OPERATOR_REGISTRY`` is matched; a dotted prefix such as
    ``pkg.OPERATOR_REGISTRY.register()`` is not.

    Returns
    -------
    bool
        ``True`` if any decorator of *node* matches, ``False`` otherwise.
    """
    import ast

    def _is_registry_register(expr) -> bool:
        # Matches the attribute access ``OPERATOR_REGISTRY.register``.
        return (
            isinstance(expr, ast.Attribute)
            and expr.attr == "register"
            and isinstance(expr.value, ast.Name)
            and expr.value.id == "OPERATOR_REGISTRY"
        )

    # For the called form the attribute access lives in ``.func``; for the
    # bare form the decorator expression is the attribute access itself.
    return any(
        _is_registry_register(
            decorator.func if isinstance(decorator, ast.Call) else decorator
        )
        for decorator in node.decorator_list
    )


def _ast_base_names(node: "ast.ClassDef"):
    """Collect the simple names of the base classes listed on *node*.

    A plain base (``class A(B)``) contributes ``"B"``; a dotted base
    (``class A(pkg.mod.C)``) contributes only its last component, ``"C"``.
    Any other kind of base expression (e.g. a subscript such as
    ``Generic[T]``) is ignored, as are keyword arguments like ``metaclass=``.

    Returns
    -------
    set of str
    """
    import ast

    plain = {base.id for base in node.bases if isinstance(base, ast.Name)}
    dotted_tail = {
        base.attr for base in node.bases if isinstance(base, ast.Attribute)
    }
    return plain | dotted_tail


def _scan_operator_classes(operators_dir):
    """
    Statically discover concrete operator classes under *operators_dir*.

    Every ``*.py`` file below the directory (recursively, in sorted path
    order, ``__init__.py`` excluded) is parsed with ``ast``; files that
    cannot be decoded as UTF-8 or fail to parse are skipped.

    Step 1 — grow the set of operator base names to a fixed point: start
    from ``OperatorABC`` and keep adding any class whose name ends with
    ``ABC`` and that lists an already-known base among its bases.

    Step 2 — record each class whose name does **not** end with ``ABC``
    and that either carries the ``OPERATOR_REGISTRY.register`` decorator
    or lists one of the known base names among its bases.

    Paths are reported relative to the grandparent of *operators_dir*.
    If the same class name occurs more than once, the last occurrence
    encountered wins.

    Returns
    -------
    dict {class_name: (rel_path, has_decorator, has_base)}
    """
    import ast
    from pathlib import Path

    operators_dir = Path(operators_dir)
    project_root = operators_dir.parent.parent

    # (relative path, class definitions in ast.walk order) for each file.
    parsed = []
    for path in sorted(operators_dir.rglob("*.py")):
        if path.name == "__init__.py":
            continue
        try:
            module = ast.parse(path.read_text(encoding="utf-8"))
        except (SyntaxError, UnicodeDecodeError):
            continue
        class_defs = [
            item for item in ast.walk(module) if isinstance(item, ast.ClassDef)
        ]
        parsed.append((path.relative_to(project_root).as_posix(), class_defs))

    # --- step 1: transitive closure over the *ABC classes ---
    abc_candidates = [
        cls
        for _rel, class_defs in parsed
        for cls in class_defs
        if cls.name.endswith("ABC")
    ]
    operator_bases = {"OperatorABC"}
    grew = True
    while grew:
        grew = False
        for cls in abc_candidates:
            if cls.name in operator_bases:
                continue
            if _ast_base_names(cls) & operator_bases:
                operator_bases.add(cls.name)
                grew = True

    # --- step 2: concrete (non-ABC) operator classes ---
    found = {}
    for rel, class_defs in parsed:
        for cls in class_defs:
            if cls.name.endswith("ABC"):
                continue
            decorated = _ast_has_register_decorator(cls)
            derived = not _ast_base_names(cls).isdisjoint(operator_bases)
            if decorated or derived:
                found[cls.name] = (rel, decorated, derived)
    return found


@pytest.mark.cpu
def test_no_operator_missing_from_lazyload():
    """
    Cross-check the static view of ``dataflow/operators/`` against the
    runtime registry.

    ``_scan_operator_classes`` supplies every concrete operator class found
    by AST inspection (decorated with ``@OPERATOR_REGISTRY.register()``
    and/or inheriting from ``OperatorABC`` / an intermediate ABC).  After
    ``OPERATOR_REGISTRY._get_all()`` each of those names must be a key of
    ``OPERATOR_REGISTRY.get_obj_map()``.

    The failure report distinguishes two situations:
      A. The class has the decorator but is absent from the registry —
         reported as missing from the ``__init__.py`` ``TYPE_CHECKING``
         block.
      B. The class inherits from an operator base but lacks the decorator —
         reported as missing both the decorator and the ``TYPE_CHECKING``
         entry.
    """
    from pathlib import Path
    import dataflow

    operators_dir = Path(dataflow.__file__).parent / "operators"
    ast_classes = _scan_operator_classes(operators_dir)

    # An empty scan means the scanner itself is broken, not that all is well.
    assert ast_classes, (
        "AST scan found zero concrete operator classes — check scan logic."
    )
    print(f"\n[AST] Found {len(ast_classes)} concrete operator classes")

    # --- force the lazy loader to import everything, then snapshot names ---
    OPERATOR_REGISTRY._get_all()
    registered = set(OPERATOR_REGISTRY.get_obj_map().keys())
    print(f"[Registry] {len(registered)} operators registered after _get_all()")

    # --- classes seen statically but absent at runtime, ordered by name ---
    absent = sorted(
        (name, info)
        for name, info in ast_classes.items()
        if name not in registered
    )

    if not absent:
        print(f"[PASS] All {len(ast_classes)} concrete operator classes are in the registry.")
        return

    report = []
    for name, (path, has_deco, has_base) in absent:
        # Only "base without decorator" gets the stronger message; every
        # other flag combination shares the TYPE_CHECKING message.
        if has_base and not has_deco:
            reason = ("inherits OperatorABC but MISSING @OPERATOR_REGISTRY.register() "
                      "AND __init__.py TYPE_CHECKING entry")
        else:
            reason = "has @register but missing from __init__.py TYPE_CHECKING"
        report.append(f" - {name} -> {path}\n reason: {reason}")
    detail = "\n".join(report)
    pytest.fail(
        f"\n{len(absent)} operator class(es) defined but NOT in the registry:\n\n"
        f"{detail}\n\n"
        f"Fix: 1) add @OPERATOR_REGISTRY.register() on the class (if missing),\n"
        f" 2) add the import to the corresponding __init__.py "
        f"`if TYPE_CHECKING:` block.",
        pytrace=False,
    )


if __name__ == "__main__":
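# Global table: shows the string name of every registered operator and its corresponding module path
# Get the class-name -> class mapping for all operators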
Expand Down
Loading