diff --git a/test/cpu_only/test_register.py b/test/cpu_only/test_register.py index eb41e1a1..c0f5bc09 100644 --- a/test/cpu_only/test_register.py +++ b/test/cpu_only/test_register.py @@ -337,6 +337,164 @@ def _iter_annotation_types(ann): if errors: pytest.fail("\n".join(errors), pytrace=False) +def _ast_has_register_decorator(node: "ast.ClassDef") -> bool: + """Check whether *node* carries ``@OPERATOR_REGISTRY.register()``.""" + import ast + for deco in node.decorator_list: + # @OPERATOR_REGISTRY.register() + if (isinstance(deco, ast.Call) + and isinstance(deco.func, ast.Attribute) + and deco.func.attr == "register" + and isinstance(deco.func.value, ast.Name) + and deco.func.value.id == "OPERATOR_REGISTRY"): + return True + # @OPERATOR_REGISTRY.register (no parentheses) + if (isinstance(deco, ast.Attribute) + and deco.attr == "register" + and isinstance(deco.value, ast.Name) + and deco.value.id == "OPERATOR_REGISTRY"): + return True + return False + + +def _ast_base_names(node: "ast.ClassDef"): + """Return the set of simple base-class names for *node*.""" + import ast + names = set() + for base in node.bases: + if isinstance(base, ast.Name): + names.add(base.id) + elif isinstance(base, ast.Attribute): + names.add(base.attr) + return names + + +def _scan_operator_classes(operators_dir): + """ + Two-pass AST scan of ``dataflow/operators/``. + + Pass 1 — collect intermediate ABC names (class names ending with ``ABC`` + that inherit from ``OperatorABC`` or another intermediate ABC). + + Pass 2 — collect every *concrete* operator class, i.e. a class that + either carries ``@OPERATOR_REGISTRY.register()`` **or** inherits from + ``OperatorABC`` / an intermediate ABC, while its own name does **not** + end with ``ABC``. + + Returns + ------- + dict {class_name: (rel_path, has_decorator, has_base)} + """ + import ast + from pathlib import Path + + operators_dir = Path(operators_dir) + project_root = operators_dir.parent.parent + + file_trees = [] + for py_file in sorted(operators_dir.rglob("*.py")): + if py_file.name == "__init__.py": + continue + try: + source = py_file.read_text(encoding="utf-8") + tree = ast.parse(source) + except (SyntaxError, UnicodeDecodeError): + continue + rel = py_file.relative_to(project_root).as_posix() + file_trees.append((rel, tree)) + + # --- pass 1: intermediate ABCs --- + operator_bases = {"OperatorABC"} + changed = True + while changed: + changed = False + for _rel, tree in file_trees: + for node in ast.walk(tree): + if not isinstance(node, ast.ClassDef): + continue + if not node.name.endswith("ABC"): + continue + if node.name in operator_bases: + continue + if _ast_base_names(node) & operator_bases: + operator_bases.add(node.name) + changed = True + + # --- pass 2: concrete operator classes --- + result = {} + for rel, tree in file_trees: + for node in ast.walk(tree): + if not isinstance(node, ast.ClassDef): + continue + if node.name.endswith("ABC"): + continue + has_deco = _ast_has_register_decorator(node) + has_base = bool(_ast_base_names(node) & operator_bases) + if has_deco or has_base: + result[node.name] = (rel, has_deco, has_base) + return result + + +@pytest.mark.cpu +def test_no_operator_missing_from_lazyload(): + """ + AST-scan ``dataflow/operators/`` for concrete operator classes (identified + by ``@OPERATOR_REGISTRY.register()`` decorator **or** inheritance from + ``OperatorABC`` / intermediate ABCs), then verify every one of them is + present in the registry after ``_get_all()``. + + Catches two failure modes: + A. Decorator present, but class not listed in ``__init__.py`` + ``TYPE_CHECKING`` block → LazyLoad never loads the file. + B. Inherits from ``OperatorABC`` but has **neither** the decorator + **nor** a LazyLoad entry → completely invisible to the framework. + """ + from pathlib import Path + import dataflow + + operators_dir = Path(dataflow.__file__).parent / "operators" + ast_classes = _scan_operator_classes(operators_dir) + + assert ast_classes, ( + "AST scan found zero concrete operator classes — check scan logic." + ) + print(f"\n[AST] Found {len(ast_classes)} concrete operator classes") + + # --- trigger full LazyLoad, snapshot registry --- + OPERATOR_REGISTRY._get_all() + registered = set(OPERATOR_REGISTRY.get_obj_map().keys()) + print(f"[Registry] {len(registered)} operators registered after _get_all()") + + # --- diff --- + missing = { + name: info for name, info in ast_classes.items() + if name not in registered + } + + if missing: + lines = [] + for name, (path, has_deco, has_base) in sorted(missing.items()): + if has_deco and not has_base: + reason = "has @register but missing from __init__.py TYPE_CHECKING" + elif has_base and not has_deco: + reason = ("inherits OperatorABC but MISSING @OPERATOR_REGISTRY.register() " + "AND __init__.py TYPE_CHECKING entry") + else: + reason = "has @register but missing from __init__.py TYPE_CHECKING" + lines.append(f" - {name} -> {path}\n reason: {reason}") + detail = "\n".join(lines) + pytest.fail( + f"\n{len(missing)} operator class(es) defined but NOT in the registry:\n\n" + f"{detail}\n\n" + f"Fix: 1) add @OPERATOR_REGISTRY.register() on the class (if missing),\n" + f" 2) add the import to the corresponding __init__.py " + f"`if TYPE_CHECKING:` block.", + pytrace=False, + ) + + print(f"[PASS] All {len(ast_classes)} concrete operator classes are in the registry.") + + if __name__ == "__main__": # 全局table,看所有注册的算子的str名称和对应的module路径 # 获得所有算子的类名2class映射