diff --git a/tools/README.md b/tools/README.md new file mode 100644 index 00000000..ad86e814 --- /dev/null +++ b/tools/README.md @@ -0,0 +1,123 @@ +# PaddleAPITest Tools + +本目录收纳 PaddleAPITest 的配置整理、日志分析、错误统计和专项辅助脚本。以下命令默认在 `PaddleAPITest/` 根目录执行。 + +## 配置集合工具 + +- `extract_api_set.py`:从配置 `.txt` 文件或目录中提取 API 名集合,默认输出 `tester/api_config/output/api_extracted.txt`。 + ```bash + python tools/extract_api_set.py -i tester/api_config/api_config_tmp.txt -o tester/api_config/output + ``` + +- `merge_config_set.py`:合并、去重、排序 API 配置集合,可按数量分片输出;原地修改时默认创建 `.backup`。 + ```bash + python tools/merge_config_set.py -i tester/api_config/configs/ -o tester/api_config/output + python tools/merge_config_set.py -i config.txt -I + ``` + +- `diff_config_set.py`:对比两个配置集合,打印 left 中移除的配置和 right 中新增的配置。 + ```bash + python tools/diff_config_set.py --left old.txt --right new.txt + ``` + +- `retrieve_config_set.py`:按关键词或完整 API 名从配置文件/目录召回配置,默认输出 `tester/api_config/api_config_retrieved.txt`。 + ```bash + python tools/retrieve_config_set.py -i tester/api_config/5_accuracy -k matmul linear + python tools/retrieve_config_set.py -i tester/api_config/5_accuracy -k paddle.matmul -e + ``` + +- `remove_config_set.py`:按一个或多个配置清单从文件或目录中删除配置;默认仅在文件被修改时,将带时间戳的 `.backup` 和 `.removed_configs` 写入同目录下的 `.rm_config_backups/`。 + ```bash + python tools/remove_config_set.py -i configs/ -r remove_configs.txt + python tools/remove_config_set.py -i configs/ -r remove_a.txt remove_b.txt remove_dir/ + python tools/remove_config_set.py -i configs/ -r remove_configs.txt -n + python tools/remove_config_set.py -i configs/ -m merged_removed.txt + python tools/remove_config_set.py -i configs/ -R + python tools/remove_config_set.py -c configs/.rm_config_backups + ``` + `-R` 回退和 `-c` 精简备份目录会修改已有文件,非 `-n` 模式需要输入 `yes` 二次确认。 + +- `extract_cases_from_csv.py`:从 CSV 中按 API 名过滤 case,输出 `filtered_result_.csv` 和 `error_config_.txt`。 + ```bash + python tools/extract_cases_from_csv.py paddle.matmul --config-path TotalStableFull.csv + python tools/extract_cases_from_csv.py paddle.matmul --only-diff -o output/ + ``` + +- `extract_or_remove_api_cases.py`:从指定配置目录中按 API 提取 case 到临时文件,可选从源文件删除;删除时默认备份。 + ```bash + python tools/extract_or_remove_api_cases.py --config paddle.numel --dst mytmp.txt + python tools/extract_or_remove_api_cases.py --config paddle.numel --remove + ``` + +- `remove_lines_by_keyword.py`:按关键词文件删除匹配文件中的行,适合批量清理包含特定关键词的配置;默认备份。 + ```bash + python tools/remove_lines_by_keyword.py -p 'tester/api_config/monitor_config/accuracy/GPU/monitoring_configs_*.txt' -k kw.txt + ``` + +- `normalize_origin_api_config.py`:历史配置归一化/整理脚本,当前包含较多历史硬编码逻辑;使用前应先确认输入输出路径。 + +- `shrink_large_configs.py`:针对 crash、OOM、timeout、numpy_error 中的大 Tensor 配置缩小 shape,生成缩小后的回归配置。 + ```bash + python tools/shrink_large_configs.py --error-logs test_log --source-configs source.txt --output shrunk.txt --factor 4 + ``` + +## 测试日志工具 + +- `seek_skip_configs.py`:从 `checkpoint.txt` 中扣除已有终态日志,生成 `api_config_skip.txt`,并默认同步更新 checkpoint 与备份。 + ```bash + python tools/seek_skip_configs.py -p tester/api_config/test_log + python tools/seek_skip_configs.py -p tester/api_config/test_log --no-update-checkpoint + ``` + +- `remove_retest_configs.py`:从 checkpoint 中移除指定结果类型对应配置,并删除对应分类文件;用于重测 timeout、OOM、skip 等集合。 + ```bash + python tools/remove_retest_configs.py -p tester/api_config/test_log -r timeout oom skip + ``` + +- `log_digest.py`:解析原始测试日志,按 pass、accuracy_error、paddle_error、torch_error、cuda_error 等类型拆分日志和配置。 + ```bash + python tools/log_digest.py --file run.log --id 1 --ckpt_id 1 + ``` + +- `find_API_config.py`:从日志中提取指定 API 的错误日志和错误配置。 + ```bash + python tools/find_API_config.py --input-log log.log --api paddle.matmul + ``` + +## 错误统计工具 + +- `error_stat/error_stat.py`:整理 `log_inorder.log` 与 `api_config_*.txt`,输出 pass、error、invalid 分类统计目录 `error_stat_result/`。 + ```bash + python tools/error_stat/error_stat.py -i tester/api_config/test_log -o tester/api_config/test_log + python tools/error_stat/error_stat.py -i tester/api_config/test_log --split-errors + ``` + +- `error_stat/csv_stat_tol.py`:整理 `tol*.csv` 精度数据,生成 `tol_full.csv`、`tol_stat.csv`、`tol_stat_api.csv`。 + ```bash + python tools/error_stat/csv_stat_tol.py -i tester/api_config/test_log + ``` + +- `error_stat/csv_stat_stable.py`:整理 `stable*.csv` 稳定性精度数据,生成 `stable_full.csv`、`stable_stat.csv`、`stable_stat_api.csv`。 + ```bash + python tools/error_stat/csv_stat_stable.py -i tester/api_config/test_log --chunk-size 2000000 --max-workers 10 + ``` + +- `error_stat/error_summary.py`:错误统计辅助脚本,用于进一步汇总错误信息;使用前建议查看脚本内参数说明。 + +## API Trace 与专项工具 + +- `api_tracer/api_tracer.py`:通过 monkey-patch 捕获框架 API 调用并序列化为 PaddleAPITest 配置;详见 `tools/api_tracer/README.md`。 + +- `api_tracer/api_alias_tool.py`、`api_tracer/api_map_tool.py`、`api_tracer/api_merge_tool.py`:API trace 配套的 alias、mapping、merge 辅助工具。 + +- `prof/`:单 API 性能 demo 与 profiling 脚本目录,输出结果通常是本地分析产物,不建议提交。 + +- `accuracy_demo.py`、`performance_demo.py`:手工调试 accuracy / performance 路径的示例脚本。 + +- `prof_api_gsb.py`:性能 API 分组/标记辅助脚本,当前主要作为静态数据脚本使用。 + +- `test_signature.py`、`test_tool.py`:开发期辅助测试脚本,使用前先确认当前内容是否符合目标场景。 + +## 注意事项 + +- 会原地修改配置或 checkpoint 的工具默认创建 `.backup`;如确认不需要备份,可使用对应脚本的 `--no-backup`。 diff --git a/tools/diff_config_set.py b/tools/diff_config_set.py new file mode 100644 index 00000000..75f22b23 --- /dev/null +++ b/tools/diff_config_set.py @@ -0,0 +1,63 @@ +# 对比 API 配置集合小工具 +# @author: cangtianhuang +# @date: 2026-06-11 + +from __future__ import annotations + +import argparse +from pathlib import Path + +DEFAULT_LEFT_PATH = Path("tester/api_config/test_log_cinn_filtered/pass_config.txt") +DEFAULT_RIGHT_PATH = Path("tester/api_config/test_log_cinn/api_config_pass.txt") + + +def load_config_set(config_file): + path = Path(config_file) + content = path.read_text(encoding="utf-8") + return {line.strip() for line in content.splitlines() if line.strip()} + + +def diff_config_sets(left_path=DEFAULT_LEFT_PATH, right_path=DEFAULT_RIGHT_PATH): + left_configs = load_config_set(left_path) + right_configs = load_config_set(right_path) + + removed_configs = left_configs - right_configs + added_configs = right_configs - left_configs + + print(f"left configs: {len(left_configs)}") + print(f"right configs: {len(right_configs)}") + print(f"removed configs: {len(removed_configs)}") + for config in sorted(removed_configs): + print(config) + + print(f"added configs: {len(added_configs)}") + for config in sorted(added_configs): + print(config) + + +def parse_args(argv=None): + parser = argparse.ArgumentParser(description="对比两个 API 配置集合") + parser.add_argument( + "--left", + "-l", + type=Path, + default=DEFAULT_LEFT_PATH, + help="基准配置文件路径", + ) + parser.add_argument( + "--right", + "-r", + type=Path, + default=DEFAULT_RIGHT_PATH, + help="对比配置文件路径", + ) + return parser.parse_args(argv) + + +def main(argv=None): + args = parse_args(argv) + diff_config_sets(args.left, args.right) + + +if __name__ == "__main__": + main() diff --git a/tools/error_stat/csv_stat_stable.py b/tools/error_stat/csv_stat_stable.py index a019a3d6..b0648333 100644 --- a/tools/error_stat/csv_stat_stable.py +++ b/tools/error_stat/csv_stat_stable.py @@ -1,9 +1,11 @@ # 整理 stable*.csv 精度统计数据,产出:stable_full.csv、stable_stat.csv、stable_stat_api.csv # @author: cangtianhuang -# @date: 2025-07-30 +# @date: 2026-06-11 + from __future__ import annotations -import glob +import argparse +import re from collections import defaultdict from concurrent.futures import ProcessPoolExecutor from pathlib import Path @@ -11,22 +13,25 @@ import numpy as np import pandas as pd -TEST_LOG_PATH = Path("tester/api_config/test_log") -OUTPUT_PATH = TEST_LOG_PATH -OUTPUT_PATH.mkdir(parents=True, exist_ok=True) +DEFAULT_TEST_LOG_PATH = Path("tester/api_config/test_log") +GENERATED_FILES = {"stable_stat.csv", "stable_stat_api.csv", "stable_full.csv"} +NUMERIC_COLUMNS = ["max_abs_diff", "max_rel_diff"] +CUSTOM_OP_API = "paddle._C_ops._run_custom_op" +CUSTOM_OP_PATTERN = re.compile(rf"^{re.escape(CUSTOM_OP_API)}\(\s*(['\"])(.*?)\1") + + +def get_api_key(api, config=None): + if not isinstance(api, str): + return api -# 查找所有stable*.csv文件 -file_pattern = TEST_LOG_PATH / "stable*.csv" -file_list = glob.glob(str(file_pattern)) -file_list = [ - f - for f in file_list - if Path(f).name not in ["stable_stat.csv", "stable_stat_api.csv", "stable_full.csv"] -] -file_list.sort() -if not file_list: - print(f"No files found matching pattern {file_pattern}") - exit(0) + match = CUSTOM_OP_PATTERN.match(api) + if not match and api == CUSTOM_OP_API and isinstance(config, str): + match = CUSTOM_OP_PATTERN.match(config) + if not match: + return api + + op_name = match.group(2) + return f'{CUSTOM_OP_API}("{op_name}")' def list_defaultdict_factory(): @@ -41,8 +46,17 @@ def nested_int_defaultdict_factory(): return defaultdict(int_defaultdict_factory) -# 读取并处理每个文件 +def collect_csv_files(input_path): + file_pattern = input_path / "stable*.csv" + file_list = sorted(file_pattern.parent.glob(file_pattern.name)) + file_list = [file_path for file_path in file_list if file_path.name not in GENERATED_FILES] + if not file_list: + print(f"No files found matching pattern {file_pattern}") + return file_list + + def process_chunk(chunk): + chunk["API"] = [get_api_key(api, config) for api, config in zip(chunk["API"], chunk["config"])] stats = defaultdict(list_defaultdict_factory) api_stats = defaultdict(nested_int_defaultdict_factory) for _, row in chunk.iterrows(): @@ -61,8 +75,20 @@ def process_chunk(chunk): return stats, api_stats, chunk -# 并行处理 CSV 文件 -def parallel_process_csv(file_path, chunk_size=2000000): +def merge_stats(target_stats, source_stats): + for key in source_stats: + target_stats[key]["abs_diffs"].extend(source_stats[key]["abs_diffs"]) + target_stats[key]["rel_diffs"].extend(source_stats[key]["rel_diffs"]) + + +def merge_api_stats(target_api_stats, source_api_stats): + for api in source_api_stats: + for dtype in source_api_stats[api]: + for comp in source_api_stats[api][dtype]: + target_api_stats[api][dtype][comp] += source_api_stats[api][dtype][comp] + + +def parallel_process_csv(file_path, chunk_size=2000000, max_workers=10): stats = defaultdict(list_defaultdict_factory) api_stats = defaultdict(nested_int_defaultdict_factory) chunks = [] @@ -75,167 +101,212 @@ def parallel_process_csv(file_path, chunk_size=2000000): on_bad_lines="warn", dtype={"max_abs_diff": float, "max_rel_diff": float}, ) - except Exception as e: - print(f"Error reading file {file_path} for merging: {e}") + except Exception as err: + print(f"Error reading file {file_path} for merging: {err}") return stats, api_stats, config_count, chunks - with ProcessPoolExecutor(max_workers=10) as executor: + + with ProcessPoolExecutor(max_workers=max_workers) as executor: futures = [executor.submit(process_chunk, chunk) for chunk in chunks_iterator] for future in futures: chunk_stats, chunk_api_stats, chunk = future.result() chunks.append(chunk) config_count += len(chunk) - for key in chunk_stats: - stats[key]["abs_diffs"].extend(chunk_stats[key]["abs_diffs"]) - stats[key]["rel_diffs"].extend(chunk_stats[key]["rel_diffs"]) - for api in chunk_api_stats: - for dtype in chunk_api_stats[api]: - for comp in chunk_api_stats[api][dtype]: - api_stats[api][dtype][comp] += chunk_api_stats[api][dtype][comp] + merge_stats(stats, chunk_stats) + merge_api_stats(api_stats, chunk_api_stats) print(f"Read {config_count} configs in {file_path}") return stats, api_stats, config_count, chunks -stats = defaultdict(list_defaultdict_factory) -api_stats = defaultdict(nested_int_defaultdict_factory) -config_count = 0 -dfs = [] -for file_path in file_list: - # 并行处理统计数据 - file_stats, file_api_stats, file_config_count, file_chunks = parallel_process_csv(file_path) - dfs.extend(file_chunks) - config_count += file_config_count - for key in file_stats: - stats[key]["abs_diffs"].extend(file_stats[key]["abs_diffs"]) - stats[key]["rel_diffs"].extend(file_stats[key]["rel_diffs"]) - for api in file_api_stats: - for dtype in file_api_stats[api]: - for comp in file_api_stats[api][dtype]: - api_stats[api][dtype][comp] += file_api_stats[api][dtype][comp] - -print(f"\nTotal read {len(stats)} (api, dtype, comp)s, {config_count} configs.") -if not stats: - print("No data to process.") - exit(0) - -# 合并所有DataFrame并保存 -merged_df = pd.concat(dfs, ignore_index=True) -numeric_cols = ["max_abs_diff", "max_rel_diff"] -merged_df = merged_df.groupby(["API", "dtype", "config", "comp"], as_index=False)[ - numeric_cols -].mean() -# merged_df = merged_df.drop_duplicates(subset=["config", "comp"], keep="last") -merged_df = merged_df.sort_values(by=["API", "dtype", "config", "comp"], ignore_index=True) -for col in numeric_cols: - merged_df[col] = merged_df[col].apply(lambda x: f"{float(x):.6e}") -output_file = OUTPUT_PATH / "stable_full.csv" -merged_df.to_csv(output_file, index=False, na_rep="nan") - -# 准备结果数据 -stats_data = [] -for api, dtype, comp in sorted(stats.keys()): - abs_diffs = np.array(stats[(api, dtype, comp)]["abs_diffs"], dtype=np.float64) - rel_diffs = np.array(stats[(api, dtype, comp)]["rel_diffs"], dtype=np.float64) - - count = len(abs_diffs) - - if not np.any(np.isnan(abs_diffs)): - abs_quantile = np.quantile(abs_diffs, 0.99) - filtered_abs = abs_diffs[abs_diffs <= abs_quantile] - abs_diffs = filtered_abs if len(filtered_abs) > 0 else abs_diffs - - if not np.any(np.isnan(rel_diffs)): - rel_quantile = np.quantile(rel_diffs, 0.99) - filtered_rel = rel_diffs[rel_diffs <= rel_quantile] - rel_diffs = filtered_rel if len(filtered_rel) > 0 else rel_diffs - - stats_data.append( - { - "API": api, - "dtype": dtype, - "comp": comp, - "abs_min": f"{np.min(abs_diffs):.6e}", - "abs_max": f"{np.max(abs_diffs):.6e}", - "abs_mean": f"{np.mean(abs_diffs):.6e}", - "rel_min": f"{np.min(rel_diffs):.6e}", - "rel_max": f"{np.max(rel_diffs):.6e}", - "rel_mean": f"{np.mean(rel_diffs):.6e}", - "count": count, - } - ) +def load_stable_data(file_list, chunk_size=2000000, max_workers=10): + stats = defaultdict(list_defaultdict_factory) + api_stats = defaultdict(nested_int_defaultdict_factory) + config_count = 0 + dfs = [] -# 转换为DataFrame并保存 -if stats_data: - df = pd.DataFrame(stats_data) - output_file = OUTPUT_PATH / "stable_stat.csv" - df.to_csv(output_file, index=False, na_rep="nan") - print(f"\nStatistics saved to {output_file}") - print("Sample of the results:") - print(df.head()) -else: - print("No data to process.") - -# 准备统计数据 -api_stats_data = [] -for api in sorted(api_stats.keys()): - api_dtype = api_stats[api] - dtypes = "/".join(sorted(api_dtype.keys())) - total = sum(api_dtype[dtype][comp] for dtype in api_dtype for comp in api_dtype[dtype]) - all_comps = "/".join(sorted({comp for dtype in api_dtype for comp in api_dtype[dtype]})) - - # 统计所有 comp 的模式 - api_stats_data.append( - { - "API": api, - "dtype": "dtypes:" + dtypes, - "comp": f"comps:{all_comps}", - "count": total, - "percentage": 100.0, - } - ) + for file_path in file_list: + file_stats, file_api_stats, file_config_count, file_chunks = parallel_process_csv( + file_path, + chunk_size=chunk_size, + max_workers=max_workers, + ) + dfs.extend(file_chunks) + config_count += file_config_count + merge_stats(stats, file_stats) + merge_api_stats(api_stats, file_api_stats) - # 按 comp 统计 - comp_counts = defaultdict(int) - for dtype in api_dtype: - for comp in api_dtype[dtype]: - comp_counts[comp] += api_dtype[dtype][comp] + print(f"\nTotal read {len(stats)} (api, dtype, comp)s, {config_count} configs.") + return dfs, stats, api_stats - for comp in sorted(comp_counts.keys()): - comp_total = comp_counts[comp] - comp_dtypes = "/".join(sorted(dtype for dtype in api_dtype if comp in api_dtype[dtype])) - api_stats_data.append( + +def write_full_csv(dfs, output_path): + merged_df = pd.concat(dfs, ignore_index=True) + merged_df = merged_df.groupby(["API", "dtype", "config", "comp"], as_index=False)[ + NUMERIC_COLUMNS + ].mean() + merged_df = merged_df.sort_values(by=["API", "dtype", "config", "comp"], ignore_index=True) + for col in NUMERIC_COLUMNS: + merged_df[col] = merged_df[col].apply(lambda x: f"{float(x):.6e}") + + output_file = output_path / "stable_full.csv" + merged_df.to_csv(output_file, index=False, na_rep="nan") + + +def write_stat_csv(stats, output_path): + stats_data = [] + for api, dtype, comp in sorted(stats.keys()): + abs_diffs = np.array(stats[(api, dtype, comp)]["abs_diffs"], dtype=np.float64) + rel_diffs = np.array(stats[(api, dtype, comp)]["rel_diffs"], dtype=np.float64) + + count = len(abs_diffs) + + if not np.any(np.isnan(abs_diffs)): + abs_quantile = np.quantile(abs_diffs, 0.99) + filtered_abs = abs_diffs[abs_diffs <= abs_quantile] + abs_diffs = filtered_abs if len(filtered_abs) > 0 else abs_diffs + + if not np.any(np.isnan(rel_diffs)): + rel_quantile = np.quantile(rel_diffs, 0.99) + filtered_rel = rel_diffs[rel_diffs <= rel_quantile] + rel_diffs = filtered_rel if len(filtered_rel) > 0 else rel_diffs + + stats_data.append( { "API": api, - "dtype": "dtypes:" + comp_dtypes, + "dtype": dtype, "comp": comp, - "count": comp_total, - "percentage": round(comp_total / total * 100, 2), + "abs_min": f"{np.min(abs_diffs):.6e}", + "abs_max": f"{np.max(abs_diffs):.6e}", + "abs_mean": f"{np.mean(abs_diffs):.6e}", + "rel_min": f"{np.min(rel_diffs):.6e}", + "rel_max": f"{np.max(rel_diffs):.6e}", + "rel_mean": f"{np.mean(rel_diffs):.6e}", + "count": count, } ) - # 按 dtype 和 comp 统计 - for dtype in sorted(api_dtype.keys()): - for comp in sorted(api_dtype[dtype].keys()): - count = api_dtype[dtype][comp] - if count > 0: - api_stats_data.append( - { - "API": api, - "dtype": dtype, - "comp": comp, - "count": count, - "percentage": round(count / total * 100, 2), - } - ) - -# 转换为DataFrame并保存 -if api_stats_data: - df = pd.DataFrame(api_stats_data) - output_file = OUTPUT_PATH / "stable_stat_api.csv" - df.to_csv(output_file, index=False, na_rep="nan") - print(f"\nAPI statistics saved to {output_file}") - print("Sample of API statistics:") - print(df.head()) -else: - print("No API statistics to process.") + if stats_data: + df = pd.DataFrame(stats_data) + output_file = output_path / "stable_stat.csv" + df.to_csv(output_file, index=False, na_rep="nan") + print(f"\nStatistics saved to {output_file}") + print("Sample of the results:") + print(df.head()) + else: + print("No data to process.") + + +def write_api_stat_csv(api_stats, output_path): + api_stats_data = [] + for api in sorted(api_stats.keys()): + api_dtype = api_stats[api] + dtypes = "/".join(sorted(api_dtype.keys())) + total = sum(api_dtype[dtype][comp] for dtype in api_dtype for comp in api_dtype[dtype]) + all_comps = "/".join(sorted({comp for dtype in api_dtype for comp in api_dtype[dtype]})) + + api_stats_data.append( + { + "API": api, + "dtype": "dtypes:" + dtypes, + "comp": f"comps:{all_comps}", + "count": total, + "percentage": 100.0, + } + ) + + comp_counts = defaultdict(int) + for dtype in api_dtype: + for comp in api_dtype[dtype]: + comp_counts[comp] += api_dtype[dtype][comp] + + for comp in sorted(comp_counts.keys()): + comp_total = comp_counts[comp] + comp_dtypes = "/".join(sorted(dtype for dtype in api_dtype if comp in api_dtype[dtype])) + api_stats_data.append( + { + "API": api, + "dtype": "dtypes:" + comp_dtypes, + "comp": comp, + "count": comp_total, + "percentage": round(comp_total / total * 100, 2), + } + ) + + for dtype in sorted(api_dtype.keys()): + for comp in sorted(api_dtype[dtype].keys()): + count = api_dtype[dtype][comp] + if count > 0: + api_stats_data.append( + { + "API": api, + "dtype": dtype, + "comp": comp, + "count": count, + "percentage": round(count / total * 100, 2), + } + ) + + if api_stats_data: + df = pd.DataFrame(api_stats_data) + output_file = output_path / "stable_stat_api.csv" + df.to_csv(output_file, index=False, na_rep="nan") + print(f"\nAPI statistics saved to {output_file}") + print("Sample of API statistics:") + print(df.head()) + else: + print("No API statistics to process.") + + +def run_stable_stat( + input_path=DEFAULT_TEST_LOG_PATH, output_path=None, chunk_size=2000000, max_workers=10 +): + input_path = Path(input_path) + output_path = Path(output_path) if output_path is not None else input_path + output_path.mkdir(parents=True, exist_ok=True) + + file_list = collect_csv_files(input_path) + if not file_list: + return + + dfs, stats, api_stats = load_stable_data( + file_list, + chunk_size=chunk_size, + max_workers=max_workers, + ) + if not stats: + print("No data to process.") + return + + write_full_csv(dfs, output_path) + write_stat_csv(stats, output_path) + write_api_stat_csv(api_stats, output_path) + + +def parse_args(argv=None): + parser = argparse.ArgumentParser(description="整理 stable*.csv 精度统计数据") + parser.add_argument( + "--input", + "-i", + type=str, + default=str(DEFAULT_TEST_LOG_PATH), + help="输入路径,包含 stable*.csv 文件", + ) + parser.add_argument("--output", "-o", type=str, default=None, help="输出路径(默认同输入路径)") + parser.add_argument("--chunk-size", type=int, default=2000000, help="CSV 分块读取行数") + parser.add_argument("--max-workers", type=int, default=10, help="并行处理进程数") + return parser.parse_args(argv) + + +def main(argv=None): + args = parse_args(argv) + run_stable_stat( + args.input, + args.output, + chunk_size=args.chunk_size, + max_workers=args.max_workers, + ) + + +if __name__ == "__main__": + main() diff --git a/tools/error_stat/csv_stat_tol.py b/tools/error_stat/csv_stat_tol.py index e2009bc5..5a0ad54b 100644 --- a/tools/error_stat/csv_stat_tol.py +++ b/tools/error_stat/csv_stat_tol.py @@ -1,178 +1,240 @@ # 整理 tol_*.csv 精度统计数据,产出:tol_full.csv、tol_stat.csv、tol_stat_api.csv # @author: cangtianhuang -# @date: 2025-06-21 +# @date: 2026-06-11 + from __future__ import annotations -import glob +import argparse +import re from collections import defaultdict from pathlib import Path import pandas as pd -TEST_LOG_PATH = Path("tester/api_config/test_log") -OUTPUT_PATH = TEST_LOG_PATH -OUTPUT_PATH.mkdir(parents=True, exist_ok=True) - -# 查找所有tol*.csv文件 -file_pattern = TEST_LOG_PATH / "tol*.csv" -file_list = glob.glob(str(file_pattern)) -file_list.sort() -if not file_list: - print(f"No files found matching pattern {file_pattern}") - exit(0) - -# 读取并处理每个文件 -dfs = [] -stats = defaultdict(lambda: defaultdict(list)) -api_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(int))) -config_count = 0 -for file_path in file_list: - file_name = file_path.split("/")[-1] - if file_name in ["tol_stat.csv", "tol_stat_api.csv", "tol_full.csv"]: - continue - try: - df = pd.read_csv(file_path, on_bad_lines="warn") - dfs.append(df) - print(f"Read {len(df)} configs in {file_path}") - config_count += len(df) - for _, row in df.iterrows(): - api = row["API"] - dtype = row["dtype"] - mode = row["mode"] - max_abs_diff = row["max_abs_diff"] - max_rel_diff = row["max_rel_diff"] - stats[(api, dtype, mode)]["abs_diffs"].append(max_abs_diff) - stats[(api, dtype, mode)]["rel_diffs"].append(max_rel_diff) - - api_stats[api][dtype][mode] += 1 - except Exception as e: - print(f"Error processing file {file_path}: {e}") -print(f"\nTotal read {len(stats)} (api, dtype, mode)s, {config_count} configs.") -if not stats: - exit(0) - -# 合并所有DataFrame并保存 -merged_df = pd.concat(dfs, ignore_index=True) -merged_df = merged_df.drop_duplicates(subset=["config", "mode"], keep="last") -merged_df = merged_df.sort_values(by=["API", "dtype", "config", "mode"], ignore_index=True) -numeric_cols = ["max_abs_diff", "max_rel_diff"] -for col in numeric_cols: - merged_df[col] = merged_df[col].apply(lambda x: f"{float(x):.6e}") -output_file = OUTPUT_PATH / "tol_full.csv" -merged_df.to_csv(output_file, index=False, na_rep="nan") - -# 准备结果数据 -stats_data = [] -for api, dtype, mode in sorted(stats.keys()): - values = stats[(api, dtype, mode)] - abs_diffs = values["abs_diffs"] - rel_diffs = values["rel_diffs"] - - abs_min = min(abs_diffs) - abs_max = max(abs_diffs) - abs_mean = sum(abs_diffs) / len(abs_diffs) - rel_min = min(rel_diffs) - rel_max = max(rel_diffs) - rel_mean = sum(rel_diffs) / len(rel_diffs) - count = len(abs_diffs) - - stats_data.append( - { - "API": api, - "dtype": dtype, - "mode": mode, - "abs_min": f"{abs_min:.6e}", - "abs_max": f"{abs_max:.6e}", - "abs_mean": f"{abs_mean:.6e}", - "rel_min": f"{rel_min:.6e}", - "rel_max": f"{rel_max:.6e}", - "rel_mean": f"{rel_mean:.6e}", - "count": count, - } +DEFAULT_TEST_LOG_PATH = Path("tester/api_config/test_log") +GENERATED_FILES = {"tol_stat.csv", "tol_stat_api.csv", "tol_full.csv"} +NUMERIC_COLUMNS = ["max_abs_diff", "max_rel_diff"] +CUSTOM_OP_API = "paddle._C_ops._run_custom_op" +CUSTOM_OP_PATTERN = re.compile(rf"^{re.escape(CUSTOM_OP_API)}\(\s*(['\"])(.*?)\1") + + +def get_api_key(api, config=None): + if not isinstance(api, str): + return api + + match = CUSTOM_OP_PATTERN.match(api) + if not match and api == CUSTOM_OP_API and isinstance(config, str): + match = CUSTOM_OP_PATTERN.match(config) + if not match: + return api + + op_name = match.group(2) + return f'{CUSTOM_OP_API}("{op_name}")' + + +def collect_csv_files(input_path): + file_pattern = input_path / "tol*.csv" + file_list = sorted(file_pattern.parent.glob(file_pattern.name)) + file_list = [file_path for file_path in file_list if file_path.name not in GENERATED_FILES] + if not file_list: + print(f"No files found matching pattern {file_pattern}") + return file_list + + +def load_tol_data(file_list): + dfs = [] + stats = defaultdict(lambda: defaultdict(list)) + api_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(int))) + config_count = 0 + + for file_path in file_list: + try: + df = pd.read_csv(file_path, on_bad_lines="warn") + df["API"] = [get_api_key(api, config) for api, config in zip(df["API"], df["config"])] + dfs.append(df) + print(f"Read {len(df)} configs in {file_path}") + config_count += len(df) + + for _, row in df.iterrows(): + api = row["API"] + dtype = row["dtype"] + mode = row["mode"] + max_abs_diff = row["max_abs_diff"] + max_rel_diff = row["max_rel_diff"] + stats[(api, dtype, mode)]["abs_diffs"].append(max_abs_diff) + stats[(api, dtype, mode)]["rel_diffs"].append(max_rel_diff) + api_stats[api][dtype][mode] += 1 + except Exception as err: + print(f"Error processing file {file_path}: {err}") + + print(f"\nTotal read {len(stats)} (api, dtype, mode)s, {config_count} configs.") + return dfs, stats, api_stats + + +def write_full_csv(dfs, output_path): + merged_df = pd.concat(dfs, ignore_index=True) + merged_df = merged_df.drop_duplicates(subset=["config", "mode"], keep="last") + merged_df = merged_df.sort_values(by=["API", "dtype", "config", "mode"], ignore_index=True) + for col in NUMERIC_COLUMNS: + merged_df[col] = merged_df[col].apply(lambda x: f"{float(x):.6e}") + + output_file = output_path / "tol_full.csv" + merged_df.to_csv(output_file, index=False, na_rep="nan") + + +def write_stat_csv(stats, output_path): + stats_data = [] + for api, dtype, mode in sorted(stats.keys()): + values = stats[(api, dtype, mode)] + abs_diffs = values["abs_diffs"] + rel_diffs = values["rel_diffs"] + + abs_min = min(abs_diffs) + abs_max = max(abs_diffs) + abs_mean = sum(abs_diffs) / len(abs_diffs) + rel_min = min(rel_diffs) + rel_max = max(rel_diffs) + rel_mean = sum(rel_diffs) / len(rel_diffs) + count = len(abs_diffs) + + stats_data.append( + { + "API": api, + "dtype": dtype, + "mode": mode, + "abs_min": f"{abs_min:.6e}", + "abs_max": f"{abs_max:.6e}", + "abs_mean": f"{abs_mean:.6e}", + "rel_min": f"{rel_min:.6e}", + "rel_max": f"{rel_max:.6e}", + "rel_mean": f"{rel_mean:.6e}", + "count": count, + } + ) + + if stats_data: + df = pd.DataFrame(stats_data) + output_file = output_path / "tol_stat.csv" + df.to_csv(output_file, index=False, na_rep="nan") + print(f"\nStatistics saved to {output_file}") + print("Sample of the results:") + print(df.head()) + else: + print("No data to process.") + + +def write_api_stat_csv(api_stats, output_path): + api_stats_data = [] + for api in sorted(api_stats.keys()): + api_dtype = api_stats[api] + dtypes = "/".join(sorted(api_dtype.keys())) + total = sum( + api_dtype[dtype]["forward"] + api_dtype[dtype]["backward"] for dtype in api_dtype + ) + + api_stats_data.append( + { + "API": api, + "dtype": "dtypes:" + dtypes, + "mode": "modes:forward/backward", + "count": total, + "percentage": 100.0, + } + ) + + forward_dtypes = [] + forward_total = 0 + backward_dtypes = [] + backward_total = 0 + for dtype, modes in api_dtype.items(): + if "forward" in modes: + forward_dtypes.append(dtype) + forward_total += modes["forward"] + if "backward" in modes: + backward_dtypes.append(dtype) + backward_total += modes["backward"] + forward_dtypes = "/".join(sorted(forward_dtypes)) + backward_dtypes = "/".join(sorted(backward_dtypes)) + + api_stats_data.append( + { + "API": api, + "dtype": "dtypes:" + forward_dtypes, + "mode": "forward", + "count": forward_total, + "percentage": round(forward_total / total * 100, 2), + } + ) + api_stats_data.append( + { + "API": api, + "dtype": "dtypes:" + backward_dtypes, + "mode": "backward", + "count": backward_total, + "percentage": round(backward_total / total * 100, 2), + } + ) + + for dtype in sorted(api_dtype.keys()): + for mode in ["forward", "backward"]: + count = api_dtype[dtype][mode] + if count > 0: + api_stats_data.append( + { + "API": api, + "dtype": dtype, + "mode": mode, + "count": count, + "percentage": round(count / total * 100, 2), + } + ) + + if api_stats_data: + df = pd.DataFrame(api_stats_data) + output_file = output_path / "tol_stat_api.csv" + df.to_csv(output_file, index=False, na_rep="nan") + print(f"\nAPI statistics saved to {output_file}") + print("Sample of API statistics:") + print(df.head()) + else: + print("No API statistics to process.") + + +def run_tol_stat(input_path=DEFAULT_TEST_LOG_PATH, output_path=None): + input_path = Path(input_path) + output_path = Path(output_path) if output_path is not None else input_path + output_path.mkdir(parents=True, exist_ok=True) + + file_list = collect_csv_files(input_path) + if not file_list: + return + + dfs, stats, api_stats = load_tol_data(file_list) + if not stats: + return + + write_full_csv(dfs, output_path) + write_stat_csv(stats, output_path) + write_api_stat_csv(api_stats, output_path) + + +def parse_args(argv=None): + parser = argparse.ArgumentParser(description="整理 tol_*.csv 精度统计数据") + parser.add_argument( + "--input", + "-i", + type=str, + default=str(DEFAULT_TEST_LOG_PATH), + help="输入路径,包含 tol*.csv 文件", ) + parser.add_argument("--output", "-o", type=str, default=None, help="输出路径(默认同输入路径)") + return parser.parse_args(argv) -# 转换为DataFrame并保存 -if stats_data: - df = pd.DataFrame(stats_data) - output_file = OUTPUT_PATH / "tol_stat.csv" - df.to_csv(output_file, index=False, na_rep="nan") - print(f"\nStatistics saved to {output_file}") - print("Sample of the results:") - print(df.head()) -else: - print("No data to process.") - -# 准备统计数据 -api_stats_data = [] -for api in sorted(api_stats.keys()): - api_dtype = api_stats[api] - dtypes = "/".join(sorted(api_dtype.keys())) - total = sum(api_dtype[dtype]["forward"] + api_dtype[dtype]["backward"] for dtype in api_dtype) - - api_stats_data.append( - { - "API": api, - "dtype": "dtypes:" + dtypes, - "mode": "modes:forward/backward", - "count": total, - "percentage": 100.0, - } - ) - forward_dtypes = [] - forward_total = 0 - backward_dtypes = [] - backward_total = 0 - for dtype, modes in api_dtype.items(): - if "forward" in modes: - forward_dtypes.append(dtype) - forward_total += modes["forward"] - if "backward" in modes: - backward_dtypes.append(dtype) - backward_total += modes["backward"] - forward_dtypes = "/".join(sorted(forward_dtypes)) - backward_dtypes = "/".join(sorted(backward_dtypes)) - - api_stats_data.append( - { - "API": api, - "dtype": "dtypes:" + forward_dtypes, - "mode": "forward", - "count": forward_total, - "percentage": round(forward_total / total * 100, 2), - } - ) - api_stats_data.append( - { - "API": api, - "dtype": "dtypes:" + backward_dtypes, - "mode": "backward", - "count": backward_total, - "percentage": round(backward_total / total * 100, 2), - } - ) +def main(argv=None): + args = parse_args(argv) + run_tol_stat(args.input, args.output) + - for dtype in sorted(api_dtype.keys()): - for mode in ["forward", "backward"]: - count = api_dtype[dtype][mode] - if count > 0: - api_stats_data.append( - { - "API": api, - "dtype": dtype, - "mode": mode, - "count": count, - "percentage": round(count / total * 100, 2), - } - ) - -# 转换为DataFrame并保存 -if api_stats_data: - df = pd.DataFrame(api_stats_data) - output_file = OUTPUT_PATH / "tol_stat_api.csv" - df.to_csv(output_file, index=False, na_rep="nan") - print(f"\nAPI statistics saved to {output_file}") - print("Sample of API statistics:") - print(df.head()) -else: - print("No API statistics to process.") +if __name__ == "__main__": + main() diff --git a/tools/error_stat/error_stat.py b/tools/error_stat/error_stat.py index e0c14771..512bb602 100644 --- a/tools/error_stat/error_stat.py +++ b/tools/error_stat/error_stat.py @@ -1,6 +1,7 @@ # test_log 一键整理小工具 # @author: cangtianhuang -# @date: 2025-11-11 +# @date: 2026-06-11 + # 整理效果:pass + error + invalid (可按类型拆分) from __future__ import annotations @@ -12,7 +13,7 @@ SKIP_ERROR_INFO = [ "(Cannot allocate memory)", "(InvalidArgument)", - "(NotFound)", + # "(NotFound)", "(ResourceExhausted)", "(Unimplemented)", "CUDA out of memory", @@ -24,8 +25,13 @@ "[paddle_to_torch]", "[torch error]", "output type diff error", + "Too large tensor to get cached numpy", + "There is no grad op for inputs:", ] +DEFAULT_TEST_LOG_PATH = Path("tester/api_config/test_log_big_tensor") +RESULT_DIR_NAME = "error_stat_result" + LOG_PREFIXES = { "checkpoint": "checkpoint", "pass": "api_config_pass", @@ -153,7 +159,7 @@ def write_logs_and_meta(output_path, logs_dict, prefix): def error_state(input_path, output_path, split_errors=False): # 写入目标目录下的独立子文件夹,避免与原始日志文件混在同级目录 - output_path = Path(output_path) / "error_stat_result" + output_path = Path(output_path) / RESULT_DIR_NAME if output_path.exists(): shutil.rmtree(output_path) print(f"Cleared existing directory: {output_path}", flush=True) @@ -216,13 +222,13 @@ def error_state(input_path, output_path, split_errors=False): write_logs_and_meta(output_path, invalid_union, "invalid") -def main(): +def parse_args(argv=None): parser = argparse.ArgumentParser(description="test_log 整理工具(可按类型拆分)") parser.add_argument( "--input", "-i", type=str, - default="tester/api_config/test_log_big_tensor", + default=str(DEFAULT_TEST_LOG_PATH), help="输入路径", ) parser.add_argument("--output", "-o", type=str, default=None, help="输出路径(默认同输入路径)") @@ -232,10 +238,13 @@ def main(): action="store_true", help="是否将错误和无效按类型拆分输出", ) - args = parser.parse_args() - if args.output is None: - args.output = args.input - error_state(args.input, args.output, split_errors=args.split_errors) + return parser.parse_args(argv) + + +def main(argv=None): + args = parse_args(argv) + output_path = args.output if args.output is not None else args.input + error_state(args.input, output_path, split_errors=args.split_errors) if __name__ == "__main__": diff --git a/tools/get_api_set.py b/tools/extract_api_set.py similarity index 73% rename from tools/get_api_set.py rename to tools/extract_api_set.py index a1df02f3..d07d6c83 100644 --- a/tools/get_api_set.py +++ b/tools/extract_api_set.py @@ -1,11 +1,16 @@ -# 获取 api 集合小工具 +# 提取 API 名集合小工具 # @author: cangtianhuang -# @date: 2025-09-26 +# @date: 2026-06-11 + from __future__ import annotations import argparse from pathlib import Path +DEFAULT_INPUT_PATHS = ["tester/api_config/api_config_tmp.txt"] +DEFAULT_OUTPUT_DIR = Path("tester/api_config/output") +OUTPUT_FILE_NAME = "api_extracted.txt" + def collect_input_files(input_paths): files = [] @@ -22,7 +27,7 @@ def collect_input_files(input_paths): return files -def extract_apis(input_paths, output_dir): +def extract_apis(input_paths, output_dir=DEFAULT_OUTPUT_DIR): input_files = collect_input_files(input_paths) if not input_files: print("No valid input files found") @@ -58,39 +63,44 @@ def extract_apis(input_paths, output_dir): print(f"Total processed: {total_processed}, Unique APIs: {len(api_names)}") - sorted_apis = sorted(api_names) output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) + sorted_apis = sorted(api_names) - output_file = output_path / "api_extracted.txt" + output_file = output_path / OUTPUT_FILE_NAME output_file.write_text("\n".join(sorted_apis) + "\n", encoding="utf-8") print(f"Wrote {len(sorted_apis)} API names to {output_file}") -def main(): - default_input = ["tester/api_config/api_config_tmp.txt"] - default_output = "tester/api_config/output" - +def parse_args(argv=None): parser = argparse.ArgumentParser( - description="API提取工具", + description="API 提取工具", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 使用示例: - python %(prog)s -i config.txt # 处理单个配置文件 - python %(prog)s -i configs/ # 处理目录下所有.txt文件 - python %(prog)s -i . -o output/ # 当前目录 + python %(prog)s -i config.txt # 处理单个配置文件 + python %(prog)s -i configs/ # 处理目录下所有 .txt 文件 + python %(prog)s -i . -o output/ # 当前目录 """, ) parser.add_argument( "--input", "-i", nargs="+", - default=default_input, + default=DEFAULT_INPUT_PATHS, help="输入路径列表(支持文件或目录)", ) - parser.add_argument("--output-dir", "-o", default=default_output, help="输出目录路径") + parser.add_argument( + "--output-dir", + "-o", + default=str(DEFAULT_OUTPUT_DIR), + help="输出目录路径", + ) + return parser.parse_args(argv) + - args = parser.parse_args() +def main(argv=None): + args = parse_args(argv) extract_apis(args.input, args.output_dir) diff --git a/tools/extract_cases_from_csv.py b/tools/extract_cases_from_csv.py new file mode 100644 index 00000000..bb80fa44 --- /dev/null +++ b/tools/extract_cases_from_csv.py @@ -0,0 +1,100 @@ +# 从 CSV 中按 API 提取 case 小工具 +# @author: cangtianhuang +# @date: 2026-06-11 + +from __future__ import annotations + +import argparse +import csv +from pathlib import Path + +DEFAULT_CONFIG_PATH = Path("TotalStableFull.csv") +DEFAULT_OUTPUT_DIR = Path(".") +DIFF_THRESHOLD = 1e-16 + + +def _output_paths(api_name, output_dir): + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + return ( + output_path / f"filtered_result_{api_name}.csv", + output_path / f"error_config_{api_name}.txt", + ) + + +def extract_cases_for_api( + api_name, only_diff=False, config_path=DEFAULT_CONFIG_PATH, output_dir=DEFAULT_OUTPUT_DIR +): + filtered_file, error_config_file = _output_paths(api_name, output_dir) + + with ( + Path(config_path).open(newline="") as infile, + filtered_file.open("w", newline="") as outfile, + ): + reader = csv.reader(infile) + writer = csv.writer(outfile) + + header = next(reader) + writer.writerow(header) + + for row in reader: + first_col = row[0] + if api_name not in first_col: + continue + + last_col = float(row[-1]) if row[-1].strip() else 0 + second_last_col = float(row[-2]) if row[-2].strip() else 0 + if only_diff and last_col < DIFF_THRESHOLD and second_last_col < DIFF_THRESHOLD: + continue + + writer.writerow(row) + + configs = set() + with filtered_file.open(newline="") as infile: + reader = csv.reader(infile) + next(reader) + for row in reader: + configs.add(row[2].replace('""', '"')) + + error_config_file.write_text("\n".join(configs), encoding="utf-8") + + +def run_extract_cases( + api_names, only_diff=False, config_path=DEFAULT_CONFIG_PATH, output_dir=DEFAULT_OUTPUT_DIR +): + for api_name in api_names: + extract_cases_for_api(api_name, only_diff, config_path, output_dir) + + +def parse_args(argv=None): + parser = argparse.ArgumentParser(description="从 CSV 中按 API 提取 case") + parser.add_argument("api_names", nargs="+", help="需要提取的 API 名称列表") + parser.add_argument( + "--only-diff", + action="store_true", + help="仅保留最后两列误差不同时为 0 的记录", + ) + parser.add_argument( + "--config-path", + "-c", + type=Path, + default=DEFAULT_CONFIG_PATH, + help="输入 CSV 文件路径", + ) + parser.add_argument( + "--output-dir", + "-o", + type=Path, + default=DEFAULT_OUTPUT_DIR, + help="输出目录路径", + ) + return parser.parse_args(argv) + + +def main(argv=None): + args = parse_args(argv) + run_extract_cases(args.api_names, args.only_diff, args.config_path, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/tools/extract_or_remove_api_cases.py b/tools/extract_or_remove_api_cases.py new file mode 100644 index 00000000..73828294 --- /dev/null +++ b/tools/extract_or_remove_api_cases.py @@ -0,0 +1,175 @@ +# 按 API 提取或删除配置 case 小工具 +# @author: cangtianhuang +# @date: 2026-06-11 + +from __future__ import annotations + +import argparse +from pathlib import Path + +DEFAULT_CONFIG_DIR = Path("tester/api_config/5_accuracy") +DEFAULT_FILE_KEYWORD = "accuracy" +DEFAULT_OUTPUT_FILE = Path("mytmp.txt") + + +def collect_target_files(config_dir=DEFAULT_CONFIG_DIR, file_keyword=DEFAULT_FILE_KEYWORD): + path = Path(config_dir) + if not path.exists(): + print(f"{path} not exists") + return [] + return sorted(file_path for file_path in path.iterdir() if file_keyword in file_path.name) + + +def count_unique_apis(config_file): + api_names = set() + with Path(config_file).open(encoding="utf-8") as f: + for line in f: + api_name = line.split("(", 1)[0] + if api_name: + api_names.add(api_name) + return len(api_names) + + +def check_config_clean( + config_prefix, config_dir=DEFAULT_CONFIG_DIR, file_keyword=DEFAULT_FILE_KEYWORD +): + target_files = collect_target_files(config_dir, file_keyword) + total_count = 0 + existing_files = set() + existing_counts = {} + + for target_file in target_files: + count = 0 + with target_file.open(encoding="utf-8") as f: + for line in f: + if config_prefix in line: + total_count += 1 + count += 1 + existing_files.add(target_file.name) + existing_counts[target_file.name] = count + + api_name = config_prefix[:-1] + if total_count: + print(f"{api_name} is still exist, number of times : {total_count}") + print(f"{api_name} is still exist in these files : ") + for file_name in sorted(existing_files): + print(file_name, existing_counts[file_name]) + return False + + print("clean") + return True + + +def append_config_cases( + config_prefix, output_file, config_dir=DEFAULT_CONFIG_DIR, file_keyword=DEFAULT_FILE_KEYWORD +): + target_files = collect_target_files(config_dir, file_keyword) + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + for target_file in target_files: + lines = target_file.read_text(encoding="utf-8").splitlines(keepends=True) + matched_lines = [line for line in lines if config_prefix in line] + + with output_path.open("a", encoding="utf-8") as f: + f.writelines(matched_lines) + + if matched_lines: + print( + "add", + len(matched_lines), + "lines in", + target_file.name, + "to ", + output_path, + "successfully", + ) + + +def backup_file(input_file): + backup_path = input_file.with_suffix(input_file.suffix + ".backup") + backup_path.write_text(input_file.read_text(encoding="utf-8"), encoding="utf-8") + print(f"Created backup: {backup_path}") + + +def remove_config_cases( + config_prefix, config_dir=DEFAULT_CONFIG_DIR, file_keyword=DEFAULT_FILE_KEYWORD, backup=True +): + target_files = collect_target_files(config_dir, file_keyword) + for target_file in target_files: + lines = target_file.read_text(encoding="utf-8").splitlines(keepends=True) + count = sum(config_prefix in line for line in lines) + remaining_lines = [line for line in lines if config_prefix not in line] + + if count and backup: + backup_file(target_file) + + target_file.write_text("".join(remaining_lines), encoding="utf-8") + if count: + print("remove", count, "lines in", target_file.name, "successfully") + + +def process_api_cases( + config, + output_file=DEFAULT_OUTPUT_FILE, + remove=False, + config_dir=DEFAULT_CONFIG_DIR, + file_keyword=DEFAULT_FILE_KEYWORD, + backup=True, +): + print(f"开始处理配置:{config},目标文件:{output_file}") + config_prefix = f"{config}(" + + if not check_config_clean(config_prefix, config_dir, file_keyword): + append_config_cases(config_prefix, output_file, config_dir, file_keyword) + + if remove: + print(f"执行删除配置:{config}") + remove_config_cases(config_prefix, config_dir, file_keyword, backup) + + +def parse_args(argv=None): + parser = argparse.ArgumentParser(description="按 API 提取或删除配置 case") + parser.add_argument("--config", type=str, required=True, help="配置字符串,例如 paddle.numel") + parser.add_argument( + "--dst", + type=Path, + default=DEFAULT_OUTPUT_FILE, + help="临时文件名,默认是 mytmp.txt", + ) + parser.add_argument( + "--config-dir", + type=Path, + default=DEFAULT_CONFIG_DIR, + help="配置目录路径", + ) + parser.add_argument( + "--file-keyword", + type=str, + default=DEFAULT_FILE_KEYWORD, + help="待扫描文件名关键词", + ) + parser.add_argument("--remove", action="store_true", help="是否执行删除配置") + parser.add_argument( + "--no-backup", + action="store_false", + dest="backup", + help="删除时不创建备份", + ) + return parser.parse_args(argv) + + +def main(argv=None): + args = parse_args(argv) + process_api_cases( + args.config, + args.dst, + args.remove, + args.config_dir, + args.file_keyword, + args.backup, + ) + + +if __name__ == "__main__": + main() diff --git a/tools/get_cases_from_csv.py b/tools/get_cases_from_csv.py deleted file mode 100644 index 58a9f4f6..00000000 --- a/tools/get_cases_from_csv.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -import csv -from pathlib import Path - -import typer - -app = typer.Typer() - - -def _get_cases(api_name: str, only_diff: bool, original_csv: str): - with ( - open(original_csv) as infile, - open(f"filtered_result_{api_name}.csv", "w", newline="") as outfile, - ): - reader = csv.reader(infile) - writer = csv.writer(outfile) - - # 写入头部行 - header = next(reader) - writer.writerow(header) - - # 处理数据行 - for row in reader: - first_col = row[0] - - if api_name not in first_col: - continue - - last_col = float(row[-1]) if row[-1].strip() else 0 - second_last_col = float(row[-2]) if row[-2].strip() else 0 - - if only_diff and last_col < 1e-16 and second_last_col < 1e-16: - continue - - writer.writerow(row) - - with ( - open(f"filtered_result_{api_name}.csv") as infile, - open(f"error_config_{api_name}.txt", "w") as outfile, - ): - reader = csv.reader(infile) - - header = next(reader) - - outs = [] - for row in reader: - first_col = row[0] - last_col = float(row[-1]) if row[-1].strip() else 0 - second_last_col = float(row[-2]) if row[-2].strip() else 0 - outs.append(row[2].replace('""', '"')) - outfile.write("\n".join(set(outs))) - - -@app.command() -def get_cases( - api_names: list[str], - only_diff: bool = False, - config_path: Path = Path("TotalStableFull.csv"), -): - for api_name in api_names: - _get_cases(api_name, only_diff, config_path.as_posix()) - - -if __name__ == "__main__": - app() diff --git a/tools/get_diff_config_set.py b/tools/get_diff_config_set.py deleted file mode 100644 index 75d39baa..00000000 --- a/tools/get_diff_config_set.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from pathlib import Path - -PATH1 = Path("tester/api_config/test_log_cinn_filtered/pass_config.txt") -PATH2 = Path("tester/api_config/test_log_cinn/api_config_pass.txt") - -content1 = PATH1.read_text(encoding="utf-8") -config1 = {line.strip() for line in content1.splitlines() if line.strip()} -content2 = PATH2.read_text(encoding="utf-8") -config2 = {line.strip() for line in content2.splitlines() if line.strip()} - -if len(config1) > len(config2): - print(f"len(config1) > len(config2), {len(config1) - len(config2)} lines removed") - for config in sorted(config1 - config2): - print(config) -else: - print(f"len(config1) < len(config2), {len(config2) - len(config1)} lines added") - for config in sorted(config2 - config1): - print(config) diff --git a/tools/get_config_set.py b/tools/merge_config_set.py similarity index 61% rename from tools/get_config_set.py rename to tools/merge_config_set.py index 93f319a5..82484e47 100644 --- a/tools/get_config_set.py +++ b/tools/merge_config_set.py @@ -1,11 +1,16 @@ -# 获取 api 配置集合小工具 +# 合并 API 配置集合小工具 # @author: cangtianhuang -# @date: 2025-10-29 +# @date: 2026-06-11 + from __future__ import annotations import argparse from pathlib import Path +DEFAULT_INPUT_PATHS = ["tester/api_config/api_config_tmp.txt"] +DEFAULT_OUTPUT_DIR = Path("tester/api_config/output") +DEFAULT_MAX_CONFIGS_PER_FILE = 500000 + def collect_input_files(input_paths): files = [] @@ -19,7 +24,19 @@ def collect_input_files(input_paths): return files -def process_api_configs(input_paths, output_dir, max_configs_per_file=500000, inplace=False): +def _backup_file(input_file): + backup_file = input_file.with_suffix(input_file.suffix + ".backup") + backup_file.write_text(input_file.read_text(encoding="utf-8"), encoding="utf-8") + print(f"Created backup: {backup_file}") + + +def process_api_configs( + input_paths, + output_dir=DEFAULT_OUTPUT_DIR, + max_configs_per_file=DEFAULT_MAX_CONFIGS_PER_FILE, + inplace=False, + backup=True, +): input_files = collect_input_files(input_paths) if not input_files: print("No valid input files found") @@ -53,6 +70,8 @@ def process_api_configs(input_paths, output_dir, max_configs_per_file=500000, in merged_content = "\n".join(sorted_configs) + "\n" for input_file in input_files: try: + if backup: + _backup_file(input_file) input_file.write_text(merged_content, encoding="utf-8") print(f"Inplace wrote {len(sorted_configs)} configs to {input_file}") except Exception as err: @@ -75,31 +94,38 @@ def process_api_configs(input_paths, output_dir, max_configs_per_file=500000, in print(f"Wrote {len(chunk)} configs to {output_file}") -def main(): - default_input = ["tester/api_config/api_config_tmp.txt"] - default_output = "tester/api_config/output" - +def parse_args(argv=None): parser = argparse.ArgumentParser( - description="API配置集合整理工具", + description="API 配置集合整理工具", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 使用示例: - python %(prog)s -i file.txt # 处理单个文件 - python %(prog)s -i dir/ # 处理目录下所有.txt文件 - python %(prog)s -i . -o output/ --max-configs 100000 # 当前目录,限制10万条/文件 - python %(prog)s -i file.txt -I # 原地去重排序,覆盖原文件 - python %(prog)s -i dir/ -I # 原地处理目录下所有.txt文件 + python %(prog)s -i file.txt # 处理单个文件 + python %(prog)s -i dir/ # 处理目录下所有 .txt 文件 + python %(prog)s -i . -o output/ --max-configs 100000 # 限制 10 万条/文件 + python %(prog)s -i file.txt -I # 原地去重排序,覆盖原文件 + python %(prog)s -i dir/ -I --no-backup # 原地处理且不创建备份 """, ) parser.add_argument( "--input", "-i", nargs="+", - default=default_input, + default=DEFAULT_INPUT_PATHS, help="输入路径列表(支持文件或目录)", ) - parser.add_argument("--output-dir", "-o", default=default_output, help="输出目录路径") - parser.add_argument("--max-configs", type=int, default=500000, help="单个输出文件最大配置数量") + parser.add_argument( + "--output-dir", + "-o", + default=str(DEFAULT_OUTPUT_DIR), + help="输出目录路径", + ) + parser.add_argument( + "--max-configs", + type=int, + default=DEFAULT_MAX_CONFIGS_PER_FILE, + help="单个输出文件最大配置数量", + ) parser.add_argument( "--inplace", "-I", @@ -107,9 +133,24 @@ def main(): default=False, help="原地修改:将合并去重排序后的结果写回所有输入文件(忽略 --output-dir)", ) + parser.add_argument( + "--no-backup", + action="store_false", + dest="backup", + help="原地修改时不创建备份", + ) + return parser.parse_args(argv) + - args = parser.parse_args() - process_api_configs(args.input, args.output_dir, args.max_configs, args.inplace) +def main(argv=None): + args = parse_args(argv) + process_api_configs( + args.input, + args.output_dir, + args.max_configs, + args.inplace, + args.backup, + ) if __name__ == "__main__": diff --git a/tools/move_config.py b/tools/move_config.py deleted file mode 100644 index ed9bbe3f..00000000 --- a/tools/move_config.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import annotations - -import argparse -import os - -path = "/root/PaddleAPITest/tester/api_config/5_accuracy/" -keyword = "accuracy" - - -def getnum(dst): - s = set() - with open(path + dst) as f: - data = f.readlines() - for j in data: - tmp = j.split("(")[0] - if tmp not in s: - s.add(tmp) - return len(s) - - -def isclean(config): - list = os.listdir(path) - merge_list = [] - for i in list: - if i.find(keyword) != -1: - merge_list.append(i) - - cnt = 0 - exist_list = set() - exist_nums = {} - for i in merge_list: - cnt2 = 0 - with open(path + i, encoding="utf8") as f: - data = f.readlines() - for j in data: - if config in j: - cnt += 1 - cnt2 += 1 - exist_list.add(i) - exist_nums[i] = cnt2 - - if cnt: - print(config[: len(config) - 1] + " is still exist, number of times : ", cnt) - print(config[: len(config) - 1] + " is still exist in these files : ") - for i in exist_list: - print(i, exist_nums[i]) - return 0 - else: - print("clean") - return 1 - - -def add(config, dst): - list = os.listdir(path) - merge_list = [] - for i in list: - if i.find(keyword) != -1: - merge_list.append(i) - - for i in merge_list: - with open(path + i) as f: - lines = f.readlines() - - matched_lines = [line for line in lines if config in line] - - with open(path + dst, "a+") as f: - f.writelines(matched_lines) - - if len(matched_lines): - print("add", len(matched_lines), "lines in", i, "to ", dst, "successfully") - - -def remove(config): - list = os.listdir(path) - merge_list = [] - for i in list: - if i.find(keyword) != -1: - merge_list.append(i) - - for i in merge_list: - with open(path + i) as f: - lines = f.readlines() - - count = sum(config in line for line in lines) - remaining_lines = [line for line in lines if config not in line] - - with open(path + i, "w") as f: - f.writelines(remaining_lines) - - if count: - print("remove", count, "lines in", i, "successfully") - - -def main(): - parser = argparse.ArgumentParser(description="Process config and temporary file.") - parser.add_argument("--config", type=str, required=True, help="配置字符串,例如 paddle.numel") - parser.add_argument("--dst", type=str, default="mytmp.txt", help="临时文件名,默认是 mytmp.txt") - parser.add_argument("--remove", action="store_true", help="是否执行删除配置") - args = parser.parse_args() - - config = args.config - dst = args.dst - print(f"开始处理配置:{config},目标文件:{dst}") - config += "(" - - if not isclean(config): - add(config, dst) # 向指定临时文件以a+方式添加写入,不会改变原有配置 - - if args.remove: - print(f"执行删除配置:{config[: len(config) - 1]}") - remove(config) # 仅当设置了--remove时才执行删除 - - -if __name__ == "__main__": - main() diff --git a/tools/remove_case_by_api.py b/tools/remove_case_by_api.py deleted file mode 100644 index 7e8b8ad6..00000000 --- a/tools/remove_case_by_api.py +++ /dev/null @@ -1,76 +0,0 @@ -from __future__ import annotations - -import glob -import re - - -def delete_lines_with_keywords(file_pattern, keyword_set, case_sensitive=True): - """删除匹配模式文件中包含关键字的行 - :param file_pattern: 文件匹配模式(如"A*") - :param keyword_set: 关键字集合 - :param case_sensitive: 是否区分大小写 - """ - # 获取匹配文件列表 - target_files = glob.glob(file_pattern) - if not target_files: - print(f"警告:未找到匹配 {file_pattern} 的文件") - return - - # 预先编译正则表达式 - flags = 0 if case_sensitive else re.IGNORECASE - patterns = [re.compile(re.escape(kw), flags) for kw in keyword_set] - total_removed = 0 - - for file_path in target_files: - try: - # 读取文件内容 - with open(file_path) as f: - lines = f.readlines() - - # 过滤包含关键字的行 - original_count = len(lines) - new_lines = [ - line for line in lines if not any(pattern.search(line) for pattern in patterns) - ] - removed_count = original_count - len(new_lines) - total_removed += removed_count - - # 写回文件 - with open(file_path, "w") as f: - f.writelines(new_lines) - print( - f"处理 {file_path}: 原始行数 {original_count}, 删除 {removed_count} 行, 保留 {len(new_lines)} 行" - ) - - except Exception as e: - print(f"处理文件 {file_path} 时出错: {e!s}") - - print(f"\n处理完成!共处理 {len(target_files)} 个文件, 总计删除 {total_removed} 行") - - -def load_keywords(keyword_file): - """从文件加载关键字集合""" - try: - with open(keyword_file) as f: - return {line.strip() for line in f if line.strip()} - except FileNotFoundError: - print(f"错误:关键字文件 {keyword_file} 不存在") - exit(1) - - -if __name__ == "__main__": - # 配置参数 - FILE_PATTERN = "tester/api_config/monitor_config/accuracy/GPU/monitoring_configs_*.txt" # 匹配所有以A开头的文件 - KEYWORD_FILE = "kw.txt" # 关键字文件名 - CASE_SENSITIVE = True # 区分大小写(设为False关闭) - - # 执行删除操作 - keywords = load_keywords(KEYWORD_FILE) - if keywords: - print( - f"加载 {len(keywords)} 个关键字: {', '.join(sorted(keywords)[:5])}" - + ("..." if len(keywords) > 5 else "") - ) - delete_lines_with_keywords(FILE_PATTERN, keywords, CASE_SENSITIVE) - else: - print("警告:关键字集为空,未执行任何操作") diff --git a/tools/remove_config_set.py b/tools/remove_config_set.py new file mode 100644 index 00000000..93f92840 --- /dev/null +++ b/tools/remove_config_set.py @@ -0,0 +1,550 @@ +# 移除指定配置集合小工具 +# @author: cangtianhuang +# @date: 2026-06-11 + +from __future__ import annotations + +import argparse +import re +from datetime import datetime +from pathlib import Path + +BACKUP_DIR_NAME = ".rm_config_backups" +LEGACY_BACKUP_DIR_NAME = "remove_config_set_backups" +BACKUP_DIR_NAMES = (BACKUP_DIR_NAME, LEGACY_BACKUP_DIR_NAME) +BACKUP_SUFFIX = ".backup" +REMOVED_CONFIGS_SUFFIX = ".removed_configs" +CONFIRM_TEXT = "yes" +TIMESTAMP_RE = re.compile(r"^\d{8}_\d{6}_\d{6}(?:\.\d+)?$") + + +def print_section(title): + print(f"\n== {title} ==") + + +def print_file_item(index, total, path): + print(f"\n[{index}/{total}] {path}") + + +def confirm_action(message, dry_run=False): + if dry_run: + return True + + answer = input(f"{message}\nType '{CONFIRM_TEXT}' to continue: ").strip() + if answer != CONFIRM_TEXT: + print("Canceled") + return False + return True + + +def collect_input_files(input_paths): + files = [] + seen_files = set() + for input_path in input_paths: + path = Path(input_path) + if path.is_file(): + text_files = [path] + elif path.is_dir(): + text_files = [ + text_file + for text_file in sorted(path.rglob("*.txt")) + if not any( + backup_dir_name in text_file.parts for backup_dir_name in BACKUP_DIR_NAMES + ) + ] + else: + print(f"Ignored invalid path: {path}") + continue + + for text_file in text_files: + resolved_file = text_file.resolve() + if resolved_file not in seen_files: + files.append(text_file) + seen_files.add(resolved_file) + return files + + +def get_timestamped_path(input_file, timestamp, suffix): + backup_dir = input_file.parent / BACKUP_DIR_NAME + backup_file = backup_dir / f"{input_file.name}.{timestamp}.{suffix}" + if not backup_file.exists(): + return backup_file + + index = 1 + while True: + backup_file = backup_dir / f"{input_file.name}.{timestamp}.{index}.{suffix}" + if not backup_file.exists(): + return backup_file + index += 1 + + +def load_configs_to_remove(remove_config_paths): + configs_to_remove = set() + remove_config_files = collect_input_files(remove_config_paths) + if not remove_config_files: + print("No valid remove config files found") + return configs_to_remove + + print_section("Load Remove Configs") + for index, path in enumerate(remove_config_files, start=1): + try: + content = path.read_text(encoding="utf-8") + lines = [line.strip() for line in content.splitlines() if line.strip()] + old_count = len(configs_to_remove) + configs_to_remove.update(lines) + print(f"[{index}/{len(remove_config_files)}] {path}") + print(f" lines: {len(lines)}, new unique: {len(configs_to_remove) - old_count}") + except Exception as err: + print(f"Error reading remove config file {path}: {err}") + raise + + print(f"Total unique configs to remove: {len(configs_to_remove)}") + return configs_to_remove + + +def remove_configs_from_files(input_paths, remove_config_paths, backup=True, dry_run=False): + input_files = collect_input_files(input_paths) + if not input_files: + print("No valid input files found") + return [] + + configs_to_remove = load_configs_to_remove(remove_config_paths) + if not configs_to_remove: + print("No configs to remove found") + return [] + + print_section("Process Input Files") + if dry_run: + print("Mode: dry-run (no files will be modified)") + print(f"Files: {len(input_files)}, configs to remove: {len(configs_to_remove)}") + + total_removed = 0 + files_modified = 0 + all_removed_lines = [] + + for index, input_file in enumerate(input_files, start=1): + print_file_item(index, len(input_files), input_file) + try: + content = input_file.read_text(encoding="utf-8") + original_lines = content.splitlines(keepends=True) + + filtered_lines = [] + removed_lines = [] + + for line in original_lines: + stripped_line = line.strip() + if stripped_line and stripped_line in configs_to_remove: + removed_lines.append(line) + else: + filtered_lines.append(line) + + removed_count = len(removed_lines) + + if removed_count > 0: + files_modified += 1 + total_removed += removed_count + all_removed_lines.extend(removed_lines) + + if backup: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + backup_file = get_timestamped_path(input_file, timestamp, "backup") + removed_backup_file = get_timestamped_path( + input_file, timestamp, "removed_configs" + ) + if dry_run: + print(f" backup: {backup_file} (dry-run)") + print(f" removed: {removed_backup_file} (dry-run)") + else: + backup_file.parent.mkdir(parents=True, exist_ok=True) + backup_file.write_text(content, encoding="utf-8") + removed_backup_file.write_text("".join(removed_lines), encoding="utf-8") + print(f" backup: {backup_file}") + print(f" removed: {removed_backup_file}") + + if dry_run: + print(f" status: would modify") + print(f" remove: {removed_count} configs") + print(f" remain: {len(filtered_lines)} lines") + else: + input_file.write_text("".join(filtered_lines), encoding="utf-8") + print(f" status: modified") + print(f" removed: {removed_count} configs") + print(f" remain: {len(filtered_lines)} lines") + else: + print(" status: unchanged") + print(" removed: 0 configs") + + except Exception as err: + print(f" status: error") + print(f" error: {err}") + continue + + print_section("Summary") + print(f"Files processed: {len(input_files)}") + print(f"Files {'would be modified' if dry_run else 'modified'}: {files_modified}") + print(f"Configs {'would be removed' if dry_run else 'removed'}: {total_removed}") + return all_removed_lines + + +def parse_backup_artifact(artifact_path): + name = artifact_path.name + if name.endswith(BACKUP_SUFFIX): + suffix = BACKUP_SUFFIX + elif name.endswith(REMOVED_CONFIGS_SUFFIX): + suffix = REMOVED_CONFIGS_SUFFIX + else: + return None + + base_name = name[: -len(suffix)] + dot_index = base_name.rfind(".") + if dot_index == -1: + return base_name, suffix, None + + possible_timestamp = base_name[dot_index + 1 :] + if not TIMESTAMP_RE.match(possible_timestamp): + return base_name, suffix, None + return base_name[:dot_index], suffix, possible_timestamp + + +def collect_compact_groups(backup_dir): + groups = {} + for artifact_path in sorted(backup_dir.iterdir()): + if not artifact_path.is_file(): + continue + parsed = parse_backup_artifact(artifact_path) + if parsed is None: + continue + original_name, suffix, timestamp = parsed + group = groups.setdefault(original_name, {"backups": [], "removed": []}) + artifact = {"path": artifact_path, "timestamp": timestamp} + if suffix == BACKUP_SUFFIX: + group["backups"].append(artifact) + else: + group["removed"].append(artifact) + return groups + + +def compact_backup_dir(backup_dir_path, dry_run=False): + backup_dir = Path(backup_dir_path) + if not backup_dir.is_dir(): + print(f"Invalid backup directory: {backup_dir}") + return + + groups = collect_compact_groups(backup_dir) + print_section("Compact Backup Directory") + if dry_run: + print("Mode: dry-run (no files will be modified)") + print(f"Backup dir: {backup_dir}") + print(f"Modified files: {len(groups)}") + if not groups: + return + if not confirm_action("This will rewrite backup artifacts in the backup directory.", dry_run): + return + + groups_compacted = 0 + backups_deleted = 0 + removed_deleted = 0 + for index, (original_name, group) in enumerate(sorted(groups.items()), start=1): + backups = sorted( + group["backups"], + key=lambda artifact: ( + artifact["timestamp"] is None, + artifact["timestamp"] or "", + artifact["path"].name, + ), + ) + removed_artifacts = sorted( + group["removed"], + key=lambda artifact: ( + artifact["timestamp"] is None, + artifact["timestamp"] or "", + artifact["path"].name, + ), + ) + compact_backup = backup_dir / f"{original_name}{BACKUP_SUFFIX}" + compact_removed = backup_dir / f"{original_name}{REMOVED_CONFIGS_SUFFIX}" + + print_file_item(index, len(groups), original_name) + if not backups: + print(" status: skipped") + print(" reason: no backup file") + continue + + first_backup = backups[0]["path"] + merged_removed = [] + seen_removed = set() + for artifact in removed_artifacts: + removed_file = artifact["path"] + try: + lines = removed_file.read_text(encoding="utf-8").splitlines() + except Exception as err: + print(f" warning: failed to read {removed_file}: {err}") + continue + add_unique_configs(merged_removed, seen_removed, lines) + + print(f" keep backup: {first_backup.name} -> {compact_backup.name}") + print(f" merge removed: {len(removed_artifacts)} files -> {compact_removed.name}") + print(f" unique removed configs: {len(merged_removed)}") + + if dry_run: + print(" status: would compact") + groups_compacted += 1 + backups_deleted += sum( + artifact["path"].name != compact_backup.name for artifact in backups[1:] + ) + removed_deleted += sum( + artifact["path"].name != compact_removed.name for artifact in removed_artifacts + ) + continue + + backup_content = first_backup.read_text(encoding="utf-8") + removed_content = "\n".join(merged_removed) + if removed_content: + removed_content += "\n" + + backup_paths = [artifact["path"] for artifact in backups] + removed_paths = [artifact["path"] for artifact in removed_artifacts] + for artifact_path in backup_paths + removed_paths: + if artifact_path.name not in {compact_backup.name, compact_removed.name}: + artifact_path.unlink() + + compact_backup.write_text(backup_content, encoding="utf-8") + compact_removed.write_text(removed_content, encoding="utf-8") + groups_compacted += 1 + backups_deleted += sum(path.name != compact_backup.name for path in backup_paths[1:]) + removed_deleted += sum(path.name != compact_removed.name for path in removed_paths) + print(" status: compacted") + + print_section("Summary") + print(f"Groups {'would be compacted' if dry_run else 'compacted'}: {groups_compacted}") + print(f"Extra backups {'would be removed' if dry_run else 'removed'}: {backups_deleted}") + print(f"Removed-config files {'would be merged' if dry_run else 'merged'}: {removed_deleted}") + + +def collect_backup_files(input_file): + backup_files = [] + for backup_dir_name in BACKUP_DIR_NAMES: + backup_dir = input_file.parent / backup_dir_name + if backup_dir.is_dir(): + backup_files.extend(backup_dir.glob(f"{input_file.name}.*.backup")) + return sorted(backup_files, key=lambda backup_file: backup_file.name) + + +def revert_files(input_paths, dry_run=False): + input_files = collect_input_files(input_paths) + if not input_files: + print("No valid input files found") + return + + print_section("Revert Input Files") + if dry_run: + print("Mode: dry-run (no files will be modified)") + print(f"Files: {len(input_files)}") + if not confirm_action("This will restore each input file from its latest backup.", dry_run): + return + + files_reverted = 0 + files_would_revert = 0 + for index, input_file in enumerate(input_files, start=1): + print_file_item(index, len(input_files), input_file) + backup_files = collect_backup_files(input_file) + if not backup_files: + print(" status: no backup found") + continue + + latest_backup = backup_files[-1] + if dry_run: + files_would_revert += 1 + print(" status: would revert") + print(f" source: {latest_backup}") + else: + input_file.write_text(latest_backup.read_text(encoding="utf-8"), encoding="utf-8") + files_reverted += 1 + print(" status: reverted") + print(f" source: {latest_backup}") + + print_section("Summary") + print(f"Files processed: {len(input_files)}") + print( + f"Files {'would be reverted' if dry_run else 'reverted'}: " + f"{files_would_revert if dry_run else files_reverted}" + ) + + +def collect_backup_dirs(input_paths): + backup_dirs = [] + seen_dirs = set() + for input_path in input_paths: + path = Path(input_path) + if path.is_file(): + candidate_dirs = [path.parent / BACKUP_DIR_NAME] + elif path.is_dir(): + candidate_dirs = [] + for backup_dir_name in BACKUP_DIR_NAMES: + candidate_dirs.extend( + backup_dir + for backup_dir in sorted(path.rglob(backup_dir_name)) + if backup_dir.is_dir() + ) + candidate_dirs.append(path / backup_dir_name) + else: + continue + + for backup_dir in candidate_dirs: + if not backup_dir.is_dir(): + continue + resolved_dir = backup_dir.resolve() + if resolved_dir not in seen_dirs: + backup_dirs.append(backup_dir) + seen_dirs.add(resolved_dir) + return backup_dirs + + +def collect_removed_config_files(input_paths): + removed_config_files = [] + seen_files = set() + for backup_dir in collect_backup_dirs(input_paths): + for removed_config_file in sorted(backup_dir.glob("*.removed_configs")): + resolved_file = removed_config_file.resolve() + if resolved_file not in seen_files: + removed_config_files.append(removed_config_file) + seen_files.add(resolved_file) + return removed_config_files + + +def add_unique_configs(configs, seen_configs, lines): + added_count = 0 + for line in lines: + stripped_line = line.strip() + if stripped_line and stripped_line not in seen_configs: + configs.append(stripped_line) + seen_configs.add(stripped_line) + added_count += 1 + return added_count + + +def merge_removed_configs(input_paths, output_path, extra_removed_lines=None, dry_run=False): + removed_config_files = collect_removed_config_files(input_paths) + merged_configs = [] + seen_configs = set() + + print_section("Merge Removed Configs") + if dry_run: + print("Mode: dry-run (no files will be modified)") + + for index, removed_config_file in enumerate(removed_config_files, start=1): + try: + lines = removed_config_file.read_text(encoding="utf-8").splitlines() + added_count = add_unique_configs(merged_configs, seen_configs, lines) + print(f"[{index}/{len(removed_config_files)}] {removed_config_file}") + print(f" lines: {len(lines)}, new unique: {added_count}") + except Exception as err: + print(f"[{index}/{len(removed_config_files)}] {removed_config_file}") + print(" status: error") + print(f" error: {err}") + continue + + if extra_removed_lines: + added_count = add_unique_configs(merged_configs, seen_configs, extra_removed_lines) + print("[current run]") + print(f" lines: {len(extra_removed_lines)}, new unique: {added_count}") + + output_file = Path(output_path) + output_content = "\n".join(merged_configs) + if output_content: + output_content += "\n" + + print_section("Summary") + if dry_run: + print(f"Removed-config backups: {len(removed_config_files)}") + print(f"Unique configs to merge: {len(merged_configs)}") + print(f"Output: {output_file} (dry-run)") + return + + output_file.parent.mkdir(parents=True, exist_ok=True) + output_file.write_text(output_content, encoding="utf-8") + print(f"Removed-config backups: {len(removed_config_files)}") + print(f"Unique configs merged: {len(merged_configs)}") + print(f"Output: {output_file}") + + +def parse_args(argv=None): + parser = argparse.ArgumentParser( + description="移除指定配置工具", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +使用示例: + python %(prog)s -i input.txt -r remove_configs.txt # 从单文件删除配置 + python %(prog)s -i file1.txt file2.txt -r remove_configs.txt # 从多文件删除配置 + python %(prog)s -i ./configs/ -r remove_configs.txt # 从文件夹删除配置 + python %(prog)s -i input.txt -r remove1.txt remove2.txt remove_dir/ # 合并多个删除配置集 + python %(prog)s -i configs/ -r remove_configs.txt -n # 预览删除与备份 + python %(prog)s -i configs/ -m merged.txt # 合并历史剔除配置 + python %(prog)s -i configs/ -R # 按最新备份回退 + python %(prog)s -c configs/.rm_config_backups # 精简备份目录 + python %(prog)s -i input.txt -r remove_configs.txt --no-backup # 不创建备份 +注意: 所有删除操作会原地修改文件。默认仅在文件被修改时,将带时间戳的 .backup 原文件备份和 .removed_configs 剔除配置备份写入同目录下的 .rm_config_backups/。 + """, + ) + parser.add_argument("-i", "--input", nargs="+", help="待处理的文件或目录") + parser.add_argument("-r", "--remove", nargs="+", help="包含要删除配置的文件或目录") + parser.add_argument( + "-m", + "--merge", + metavar="OUTPUT", + help="合并 input 对应备份目录下多次产生的 .removed_configs 到指定文件", + ) + parser.add_argument("-n", "--dry-run", action="store_true", help="仅预览将要执行的修改和输出") + parser.add_argument( + "-R", "--revert", action="store_true", help="按最新 .backup 回退 input 文件" + ) + parser.add_argument("-c", "--compact-backups", metavar="DIR", help="精简指定备份目录") + parser.add_argument( + "--no-backup", + action="store_false", + dest="backup", + help="不创建备份文件", + ) + args = parser.parse_args(argv) + if args.compact_backups and (args.input or args.remove or args.merge or args.revert): + parser.error("-c/--compact-backups cannot be used with -i, -r, -m, or -R") + if args.revert and (args.remove or args.merge): + parser.error("-R/--revert cannot be used with -r/--remove or -m/--merge") + if not args.compact_backups and not args.input: + parser.error("-i/--input is required unless using -c/--compact-backups") + if not args.remove and not args.merge and not args.revert and not args.compact_backups: + parser.error( + "at least one of -r/--remove, -m/--merge, -R/--revert, " + "or -c/--compact-backups is required" + ) + return args + + +def main(argv=None): + args = parse_args(argv) + if args.compact_backups: + compact_backup_dir(args.compact_backups, dry_run=args.dry_run) + return + if args.revert: + revert_files(args.input, dry_run=args.dry_run) + return + + removed_lines = [] + if args.remove: + removed_lines = remove_configs_from_files( + args.input, + args.remove, + backup=args.backup, + dry_run=args.dry_run, + ) + if args.merge: + merge_removed_configs( + args.input, + args.merge, + extra_removed_lines=removed_lines, + dry_run=args.dry_run, + ) + + +if __name__ == "__main__": + main() diff --git a/tools/remove_configs.py b/tools/remove_configs.py deleted file mode 100644 index 340e7fe9..00000000 --- a/tools/remove_configs.py +++ /dev/null @@ -1,126 +0,0 @@ -# 移除指定配置小工具 -# @author: cangtianhuang -# @date: 2025-09-26 -from __future__ import annotations - -import argparse -from pathlib import Path - - -def collect_input_files(input_paths): - files = [] - for input_path in input_paths: - path = Path(input_path) - if path.is_file(): - files.append(path) - elif path.is_dir(): - text_files = list(path.rglob("*.txt")) - files.extend(text_files) - return files - - -def load_configs_to_remove(remove_config_file): - configs_to_remove = set() - - path = Path(remove_config_file) - try: - content = path.read_text(encoding="utf-8") - lines = [line.strip() for line in content.splitlines() if line.strip()] - configs_to_remove.update(lines) - print(f"Loaded {len(configs_to_remove)} configs to remove from {path}") - except Exception as err: - print(f"Error reading remove config file {path}: {err}") - raise - - return configs_to_remove - - -def remove_configs_from_files(input_paths, remove_config_file, backup=False): - input_files = collect_input_files(input_paths) - if not input_files: - print("No valid input files found") - return - - configs_to_remove = load_configs_to_remove(remove_config_file) - if not configs_to_remove: - print("No configs to remove found") - return - - print(f"Processing {len(input_files)} files...") - print(f"Will remove {len(configs_to_remove)} unique configs") - - total_removed = 0 - files_modified = 0 - - for input_file in input_files: - try: - content = input_file.read_text(encoding="utf-8") - original_lines = content.splitlines() - - filtered_lines = [] - removed_count = 0 - - for line in original_lines: - stripped_line = line.strip() - if stripped_line and stripped_line in configs_to_remove: - removed_count += 1 - else: - filtered_lines.append(line) - - if removed_count > 0: - files_modified += 1 - total_removed += removed_count - - if backup: - backup_file = input_file.with_suffix(input_file.suffix + ".backup") - backup_file.write_text(content, encoding="utf-8") - print(f"Created backup: {backup_file}") - - new_content = "\n".join(filtered_lines) - if new_content and not new_content.endswith("\n"): - new_content += "\n" - - input_file.write_text(new_content, encoding="utf-8") - - print( - f"Modified {input_file}: removed {removed_count} configs, " - f"remaining {len(filtered_lines)} lines" - ) - else: - print(f"No configs to remove in {input_file}") - - except Exception as err: - print(f"Error processing {input_file}: {err}") - continue - - print("\nSummary:") - print(f"Files processed: {len(input_files)}") - print(f"Files modified: {files_modified}") - print(f"Total configs removed: {total_removed}") - - -def main(): - parser = argparse.ArgumentParser( - description="移除指定配置工具", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -使用示例: - python config_remover.py -i input.txt -r remove_configs.txt # 从单文件删除配置 - python config_remover.py -i file1.txt file2.txt -r remove_configs.txt # 从多文件删除配置 - python config_remover.py -i ./configs/ -r remove_configs.txt # 从文件夹删除配置 - python config_remover.py -i input.txt -r remove_configs.txt --backup # 有备份地处理 -注意: 所有操作会原地修改文件。使用 --backup 选项可创建备份文件。 - """, - ) - - parser.add_argument("-i", "--input", nargs="+", required=True, help="待处理的文件或目录") - parser.add_argument("-r", "--remove", required=True, help="包含要删除配置的文件") - parser.add_argument("-b", "--backup", action="store_true", default=False, help="创建备份文件") - - args = parser.parse_args() - - remove_configs_from_files(args.input, args.remove, args.backup) - - -if __name__ == "__main__": - main() diff --git a/tools/remove_lines_by_keyword.py b/tools/remove_lines_by_keyword.py new file mode 100644 index 00000000..f8e5c4d6 --- /dev/null +++ b/tools/remove_lines_by_keyword.py @@ -0,0 +1,117 @@ +# 按关键词删除配置行小工具 +# @author: cangtianhuang +# @date: 2026-06-11 + +from __future__ import annotations + +import argparse +import glob +import re +from pathlib import Path + +DEFAULT_FILE_PATTERN = "tester/api_config/monitor_config/accuracy/GPU/monitoring_configs_*.txt" +DEFAULT_KEYWORD_FILE = Path("kw.txt") + + +def load_keywords(keyword_file): + path = Path(keyword_file) + try: + return { + line.strip() for line in path.read_text(encoding="utf-8").splitlines() if line.strip() + } + except FileNotFoundError: + print(f"错误:关键字文件 {path} 不存在") + raise + + +def backup_file(file_path): + path = Path(file_path) + backup_path = path.with_suffix(path.suffix + ".backup") + backup_path.write_text(path.read_text(encoding="utf-8"), encoding="utf-8") + print(f"创建备份: {backup_path}") + + +def delete_lines_with_keywords(file_pattern, keyword_set, case_sensitive=True, backup=True): + target_files = sorted(glob.glob(file_pattern)) + if not target_files: + print(f"警告:未找到匹配 {file_pattern} 的文件") + return + + flags = 0 if case_sensitive else re.IGNORECASE + patterns = [re.compile(re.escape(keyword), flags) for keyword in keyword_set] + total_removed = 0 + + for file_path in target_files: + try: + path = Path(file_path) + lines = path.read_text(encoding="utf-8").splitlines(keepends=True) + original_count = len(lines) + new_lines = [ + line for line in lines if not any(pattern.search(line) for pattern in patterns) + ] + removed_count = original_count - len(new_lines) + total_removed += removed_count + + if removed_count > 0 and backup: + backup_file(path) + + path.write_text("".join(new_lines), encoding="utf-8") + print( + f"处理 {file_path}: 原始行数 {original_count}, " + f"删除 {removed_count} 行, 保留 {len(new_lines)} 行" + ) + except Exception as err: + print(f"处理文件 {file_path} 时出错: {err!s}") + + print(f"\n处理完成!共处理 {len(target_files)} 个文件, 总计删除 {total_removed} 行") + + +def parse_args(argv=None): + parser = argparse.ArgumentParser(description="按关键词删除配置行工具") + parser.add_argument( + "--file-pattern", + "-p", + default=DEFAULT_FILE_PATTERN, + help="待处理文件 glob 匹配模式", + ) + parser.add_argument( + "--keyword-file", + "-k", + default=str(DEFAULT_KEYWORD_FILE), + help="关键词文件路径,每行一个关键词", + ) + parser.add_argument( + "--ignore-case", + action="store_true", + help="关键词匹配时忽略大小写", + ) + parser.add_argument( + "--no-backup", + action="store_false", + dest="backup", + help="不创建备份", + ) + return parser.parse_args(argv) + + +def main(argv=None): + args = parse_args(argv) + keywords = load_keywords(args.keyword_file) + if not keywords: + print("警告:关键字集为空,未执行任何操作") + return + + print( + f"加载 {len(keywords)} 个关键字: {', '.join(sorted(keywords)[:5])}" + + ("..." if len(keywords) > 5 else "") + ) + delete_lines_with_keywords( + args.file_pattern, + keywords, + case_sensitive=not args.ignore_case, + backup=args.backup, + ) + + +if __name__ == "__main__": + main() diff --git a/tools/retest_remover.py b/tools/remove_retest_configs.py similarity index 67% rename from tools/retest_remover.py rename to tools/remove_retest_configs.py index 59e013f2..f2d8da1c 100644 --- a/tools/retest_remover.py +++ b/tools/remove_retest_configs.py @@ -1,12 +1,13 @@ # 重测配置移除小工具 # @author: cangtianhuang -# @date: 2025-11-11 +# @date: 2026-06-11 + from __future__ import annotations import argparse -import os from pathlib import Path +DEFAULT_LOG_PATH = Path("tester/api_config/test_log") LOG_PREFIXES = { "checkpoint": "checkpoint", "pass": "api_config_pass", @@ -23,42 +24,61 @@ "cuda_error": "api_config_cuda_error", "skip": "api_config_skip", } +DEFAULT_REMOVE_TYPES = ["timeout", "oom", "skip"] + + +def read_config_set(config_file): + with Path(config_file).open("r", encoding="utf-8") as f: + return {line.strip() for line in f if line.strip()} + +def write_config_set(config_file, configs): + with Path(config_file).open("w", encoding="utf-8") as f: + f.writelines(f"{line}\n" for line in sorted(configs)) -def remove_configs(log_path, to_remove): + +def backup_file(config_file): + path = Path(config_file) + backup_path = path.with_suffix(path.suffix + ".backup") + backup_path.write_text(path.read_text(encoding="utf-8"), encoding="utf-8") + print(f"Created backup: {backup_path}", flush=True) + + +def remove_configs(log_path=DEFAULT_LOG_PATH, to_remove=None, backup=True): log_path = Path(log_path) + if to_remove is None: + to_remove = DEFAULT_REMOVE_TYPES if not log_path.exists(): print(f"{log_path} not exists", flush=True) return - checkpoint_configs = set() checkpoint_file = log_path / "checkpoint.txt" if not checkpoint_file.exists(): print("No checkpoint file found", flush=True) return try: - with checkpoint_file.open("r") as f: - checkpoint_configs = {line.strip() for line in f if line.strip()} + checkpoint_configs = read_config_set(checkpoint_file) except Exception as err: print(f"Error reading {checkpoint_file}: {err}", flush=True) return print(f"Read {len(checkpoint_configs)} api configs from checkpoint", flush=True) retest_configs = set() + valid_remove_types = [] for log_type in to_remove: if log_type not in LOG_PREFIXES: print(f"Invalid log type: {log_type}", flush=True) continue + valid_remove_types.append(log_type) prefix = LOG_PREFIXES[log_type] log_file = log_path / f"{prefix}.txt" if not log_file.exists(): continue try: - with log_file.open("r") as f: - lines = {line.strip() for line in f if line.strip()} - retest_configs.update(lines) - print(f"Read {len(lines)} api configs from {log_file}", flush=True) + lines = read_config_set(log_file) + retest_configs.update(lines) + print(f"Read {len(lines)} api configs from {log_file}", flush=True) except Exception as err: print(f"Error reading {log_file}: {err}", flush=True) return @@ -72,31 +92,30 @@ def remove_configs(log_path, to_remove): ) print(f"checkpoint remaining: {len(checkpoint_configs)}", flush=True) try: - with checkpoint_file.open("w") as f: - f.writelines(f"{line}\n" for line in sorted(checkpoint_configs)) + if backup: + backup_file(checkpoint_file) + write_config_set(checkpoint_file, checkpoint_configs) except Exception as err: print(f"Error writing {checkpoint_file}: {err}", flush=True) return else: print("No retest configs found", flush=True) - for log_type in to_remove: - if log_type not in LOG_PREFIXES: - continue + for log_type in valid_remove_types: prefix = LOG_PREFIXES[log_type] log_file = log_path / f"{prefix}.txt" if not log_file.exists(): continue try: - os.remove(log_file) + if backup: + backup_file(log_file) + log_file.unlink() except Exception as err: print(f"Error removing {log_file}: {err}", flush=True) return -def main(): - default_log_path = "tester/api_config/test_log" - +def parse_args(argv=None): parser = argparse.ArgumentParser( description="重测配置移除小工具", formatter_class=argparse.RawDescriptionHelpFormatter, @@ -122,18 +141,29 @@ def main(): parser.add_argument( "--path", "-p", - type=str, - default=default_log_path, + type=Path, + default=DEFAULT_LOG_PATH, help="测试日志目录路径", ) parser.add_argument( "--remove", "-r", nargs="+", + default=DEFAULT_REMOVE_TYPES, help="指定需要移除的配置", ) - args = parser.parse_args() - remove_configs(args.path, args.remove) + parser.add_argument( + "--no-backup", + action="store_false", + dest="backup", + help="不创建备份", + ) + return parser.parse_args(argv) + + +def main(argv=None): + args = parse_args(argv) + remove_configs(args.path, args.remove, args.backup) if __name__ == "__main__": diff --git a/tools/retrieve_configs.py b/tools/retrieve_config_set.py similarity index 65% rename from tools/retrieve_configs.py rename to tools/retrieve_config_set.py index 0c974b52..ff49d30b 100644 --- a/tools/retrieve_configs.py +++ b/tools/retrieve_config_set.py @@ -1,12 +1,16 @@ -# 召回配置小工具 +# 召回配置集合小工具 # @author: cangtianhuang -# @date: 2025-09-26 +# @date: 2026-06-11 + from __future__ import annotations import argparse import re from pathlib import Path +DEFAULT_INPUT_PATHS = ["tester/api_config/5_accuracy"] +DEFAULT_OUTPUT_FILE = Path("tester/api_config/api_config_retrieved.txt") + def collect_input_files(input_paths): files = [] @@ -20,17 +24,19 @@ def collect_input_files(input_paths): return files -def search_files(input_paths, keywords, output_file, exact_match=False): +def build_pattern(keywords, exact_match=False): + if exact_match: + return re.compile("|".join(rf"\b{re.escape(kw)}\b[^(\n]*\(" for kw in keywords)) + return re.compile("|".join(rf"^[^(\n]*{re.escape(kw)}[^(\n]*\(" for kw in keywords)) + + +def search_files(input_paths, keywords, output_file=DEFAULT_OUTPUT_FILE, exact_match=False): input_files = collect_input_files(input_paths) if not input_files: print("No valid input files found") return - if exact_match: - pattern = re.compile("|".join(rf"\b{re.escape(kw)}\b[^(\n]*\(" for kw in keywords)) - else: - pattern = re.compile("|".join(rf"^[^(\n]*{re.escape(kw)}[^(\n]*\(" for kw in keywords)) - + pattern = build_pattern(keywords, exact_match) configs = set() prefixes = set() count = 0 @@ -47,23 +53,21 @@ def search_files(input_paths, keywords, output_file, exact_match=False): paren_pos = line.find("(", match.start()) if paren_pos != -1: prefixes.add(line[:paren_pos].strip()) - except (UnicodeDecodeError, PermissionError) as e: - print(f"Error reading {input_file}: {e}") + except (UnicodeDecodeError, PermissionError) as err: + print(f"Error reading {input_file}: {err}") continue print(f"Retrieved {count} configs") print(f"Get {len(configs)} unique configs") print(f"APIs: {sorted(prefixes)}") - Path(output_file).write_text("\n".join(sorted(configs)) + "\n", encoding="utf-8") - print(f"Saved to {output_file}") - + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text("\n".join(sorted(configs)) + "\n", encoding="utf-8") + print(f"Saved to {output_path}") -def main(): - default_input = ["tester/api_config/5_accuracy"] - default_keywords = [] - default_output = "tester/api_config/api_config_retrieved.txt" +def parse_args(argv=None): parser = argparse.ArgumentParser( description="配置文件召回工具", formatter_class=argparse.RawTextHelpFormatter, @@ -77,7 +81,7 @@ def main(): "--input", "-i", nargs="+", - default=default_input, + default=DEFAULT_INPUT_PATHS, help="输入路径列表(支持文件或目录)", ) parser.add_argument( @@ -85,7 +89,6 @@ def main(): "-k", nargs="+", required=True, - default=default_keywords, help="关键词列表", ) parser.add_argument( @@ -94,10 +97,17 @@ def main(): action="store_true", help="启用精确匹配(匹配完整单词)", ) - parser.add_argument("--output", "-o", default=default_output, help="输出文件路径") + parser.add_argument( + "--output", + "-o", + default=str(DEFAULT_OUTPUT_FILE), + help="输出文件路径", + ) + return parser.parse_args(argv) - args = parser.parse_args() +def main(argv=None): + args = parse_args(argv) search_files(args.input, args.keywords, args.output, args.exact) diff --git a/tools/seek_skip_configs.py b/tools/seek_skip_configs.py new file mode 100644 index 00000000..db62d99a --- /dev/null +++ b/tools/seek_skip_configs.py @@ -0,0 +1,160 @@ +# 筛选 skip 配置小工具 +# @author: cangtianhuang +# @date: 2026-06-11 + +from __future__ import annotations + +import argparse +from pathlib import Path + +DEFAULT_TEST_LOG_PATH = Path("tester/api_config/test_log") +DEFAULT_OUTPUT_FILE_NAME = "api_config_skip.txt" +LOG_PREFIXES = { + "checkpoint": "checkpoint", + "pass": "api_config_pass", + "crash": "api_config_crash", + "oom": "api_config_oom", + "timeout": "api_config_timeout", + "paddle_error": "api_config_paddle_error", + "accuracy_error": "api_config_accuracy_error", + "accuracy_diff": "api_config_accuracy_diff", + "torch_error": "api_config_torch_error", + "paddle_to_torch_failed": "api_config_paddle_to_torch_failed", + "match_error": "api_config_match_error", + "numpy_error": "api_config_numpy_error", + "cuda_error": "api_config_cuda_error", +} + + +def read_config_set(config_file): + with Path(config_file).open("r", encoding="utf-8") as f: + return {line.strip() for line in f if line.strip()} + + +def write_config_set(config_file, configs): + with Path(config_file).open("w", encoding="utf-8") as f: + f.writelines(f"{line}\n" for line in sorted(configs)) + + +def backup_file(config_file): + path = Path(config_file) + backup_path = path.with_suffix(path.suffix + ".backup") + backup_path.write_text(path.read_text(encoding="utf-8"), encoding="utf-8") + print(f"Created backup: {backup_path}", flush=True) + + +def seek_skip_configs( + test_log_path=DEFAULT_TEST_LOG_PATH, output_file=None, update_checkpoint=True, backup=True +): + test_log_path = Path(test_log_path) + output_path = ( + Path(output_file) if output_file is not None else test_log_path / DEFAULT_OUTPUT_FILE_NAME + ) + + log_counts = {} + checkpoint_file = test_log_path / "checkpoint.txt" + if not checkpoint_file.exists(): + print("No checkpoint file found", flush=True) + return + + try: + checkpoint_configs = read_config_set(checkpoint_file) + log_counts["checkpoint"] = len(checkpoint_configs) + except Exception as err: + print(f"Error reading {checkpoint_file}: {err}", flush=True) + return + print(f"Read {len(checkpoint_configs)} api configs from checkpoint", flush=True) + + api_configs = checkpoint_configs.copy() + for log_type, prefix in LOG_PREFIXES.items(): + if log_type == "checkpoint": + continue + log_file = test_log_path / f"{prefix}.txt" + if not log_file.exists(): + continue + try: + lines = read_config_set(log_file) + api_configs -= lines + log_counts[log_type] = len(lines) + except Exception as err: + print(f"Error reading {log_file}: {err}", flush=True) + return + + if api_configs: + log_counts["skip"] = len(api_configs) + else: + print("No skip configs found", flush=True) + + for log_type, count in log_counts.items(): + print(f"{log_type}: {count}", flush=True) + + if not api_configs: + return + + try: + if output_path.exists() and backup: + backup_file(output_path) + write_config_set(output_path, api_configs) + except Exception as err: + print(f"Error writing to {output_path}: {err}", flush=True) + return + print(f"Write {len(api_configs)} skip api configs to {output_path}", flush=True) + + if not update_checkpoint: + return + + checkpoint_count = len(checkpoint_configs) + checkpoint_configs -= api_configs + print(f"checkpoint removed: {checkpoint_count - len(checkpoint_configs)}", flush=True) + print(f"checkpoint remaining: {len(checkpoint_configs)}", flush=True) + try: + if backup: + backup_file(checkpoint_file) + write_config_set(checkpoint_file, checkpoint_configs) + except Exception as err: + print(f"Error writing {checkpoint_file}: {err}", flush=True) + return + print( + f"Write {len(checkpoint_configs)} checkpoint api configs to {checkpoint_file}", + flush=True, + ) + + +def parse_args(argv=None): + parser = argparse.ArgumentParser(description="筛选 skip 配置小工具") + parser.add_argument( + "--path", + "-p", + type=Path, + default=DEFAULT_TEST_LOG_PATH, + help="测试日志目录路径", + ) + parser.add_argument( + "--output", + "-o", + type=Path, + default=None, + help="skip 配置输出文件路径(默认写入日志目录 api_config_skip.txt)", + ) + parser.add_argument( + "--no-update-checkpoint", + action="store_false", + dest="update_checkpoint", + help="只输出 skip 配置,不修改 checkpoint.txt", + ) + parser.add_argument( + "--no-backup", + action="store_false", + dest="backup", + help="不创建备份", + ) + return parser.parse_args(argv) + + +def main(argv=None): + args = parse_args(argv) + seek_skip_configs(args.path, args.output, args.update_checkpoint, args.backup) + + +if __name__ == "__main__": + main() diff --git a/tools/shrink_large_configs.py b/tools/shrink_large_configs.py new file mode 100644 index 00000000..302d5f09 --- /dev/null +++ b/tools/shrink_large_configs.py @@ -0,0 +1,413 @@ +# 缩小大 Tensor 配置小工具 +# @author: cangtianhuang +# @date: 2026-06-11 + +""" +用法: + python shrink_large_configs.py \ + --error-logs [ ...] \ + --source-configs [ ...] \ + --output \ + --factor # 将元素数量缩小到原来的 1/N(如 4、8、16) + [--threshold ] # 只缩小元素数达到此阈值的 Tensor(默认 1048576,即 1M) + [--error-types crash oom timeout numpy_error] # 默认全部四种 + +说明: + 脚本从 error-logs 目录中读取 api_config_crash.txt / api_config_oom.txt / + api_config_timeout.txt / api_config_numpy_error.txt,收集出错的配置行。 + 再从 source-configs 中找到对应行,对其中 Tensor 的 shape 进行等比缩小: + - 每个维度除以 factor^(1/ndim),保持各维度比例尽量一致 + - 若某维度为 0 或 1,则不参与缩放 + - list[...] / tuple(...) 中与 Tensor shape 同步缩放的整数也同步处理 + - strides=[...] 按相同倍率缩放(跳过 0 和 1) + - -1(动态维度)保留不变 + +示例: + python shrink_large_configs.py \ + --error-logs workspace/0601_dsV4/test_log_1_dsv4_1M \ + workspace/0601_dsV4/test_log_5_v2_1M_fix_tofix \ + --source-configs workspace/0601_dsV4/dsv4_1M_tofix.txt \ + workspace/0601_dsV4/v2_1M_fix_tofix.txt \ + --output workspace/0601_dsV4/shrunk_4x.txt \ + --factor 4 +""" + +from __future__ import annotations + +import argparse +import math +import re +import sys +from pathlib import Path + +# --------------------------------------------------------------------------- +# Regex helpers +# --------------------------------------------------------------------------- + +# Match a full Tensor(...) token including optional strides/is_contiguous fields. +# Captures: +# group 1: paddle.Size([...]) content — the comma-sep integers +# group 2: dtype string (without quotes) +# group 3: extra fields like ,is_contiguous=False,strides=[1,2048] (may be empty) +_TENSOR_RE = re.compile(r'Tensor\(paddle\.Size\(\[([^\]]*)\]\),"([^"]+)"((?:,[^)]*)?)\)') + +# Match strides=[...] inside the extra-fields group +_STRIDES_RE = re.compile(r"(strides=\[)([^\]]*?)(\])") + +# Match list[...] or tuple(...) — shape-like integer sequences used as args +# e.g. list[1,1048576,262144,] or tuple(1,1048576,262144,) +_LIST_ARG_RE = re.compile(r"(list\[|tuple\()([^\]\)]*)([\]\)])") + +# Match a bare integer (possibly negative for -1 sentinel) +_INT_RE = re.compile(r"-?\d+") + + +# --------------------------------------------------------------------------- +# Shape / value scaling helpers +# --------------------------------------------------------------------------- + + +def _numel(dims: list[int]) -> int: + n = 1 + for d in dims: + n *= d + return n + + +def _scale_dims(dims: list[int], factor: float, threshold: int) -> list[int]: + """ + Scale a shape so its element count is reduced by approximately `factor`. + + Strategy: + 1. Compute which dims are "large" (> 1 and > threshold ^ (1/ndim)). + 2. Distribute the reduction evenly across large dims. + 3. -1 (dynamic) is kept as-is; 0 and 1 are kept as-is. + """ + if not dims: + return dims + numel = _numel(dims) + if numel < threshold: + return dims # already small enough + + new_dims = list(dims) + # Identify scalable indices (not 0, not 1, not -1) + scalable = [i for i, d in enumerate(dims) if d > 1 and d != -1] + if not scalable: + return new_dims + + # We want product(new[i] for i in scalable) ≈ product(old[i]) / factor + # Distribute reduction: each dim gets divided by factor^(1/len(scalable)) + per_dim_ratio = factor ** (1.0 / len(scalable)) + for i in scalable: + new_val = max(1, round(dims[i] / per_dim_ratio)) + new_dims[i] = new_val + + return new_dims + + +def _scale_single_value(v: int, scale: float) -> int: + """Scale a single integer value, keeping 0, 1, -1 unchanged.""" + if v in (0, 1, -1): + return v + return max(1, round(v / scale)) + + +def _dims_from_str(s: str) -> list[int]: + """Parse comma-separated integer string into list[int], ignoring empty.""" + result = [] + for tok in s.split(","): + tok = tok.strip() + if tok == "": + continue + try: + result.append(int(tok)) + except ValueError: + pass # skip non-integer tokens (shouldn't happen in shape) + return result + + +def _dims_to_str(dims: list[int]) -> str: + return ", ".join(str(d) for d in dims) + + +# --------------------------------------------------------------------------- +# Per-line transform +# --------------------------------------------------------------------------- + + +def _compute_scale_for_line(line: str, factor: float, threshold: int) -> float | None: + """ + Determine the actual scale factor to apply to this config line. + We look at the *largest* tensor numel; if it exceeds threshold, scale by factor. + Returns None if no tensor exceeds threshold (line needs no change). + """ + max_numel = 0 + for m in _TENSOR_RE.finditer(line): + dims = _dims_from_str(m.group(1)) + n = _numel(dims) + if n > max_numel: + max_numel = n + if max_numel < threshold: + return None + return factor + + +def _replace_strides(extra: str, old_dims: list[int], new_dims: list[int]) -> str: + """ + Scale strides= values inside the extra-fields string. + We compute a per-element scale based on old vs new tensor numel. + """ + # Build a mapping: old dim value → new dim value + # (for the simple proportional case used in actual strides) + # actual scale = old_numel / new_numel + old_numel = _numel([d for d in old_dims if d > 0]) + new_numel = _numel([d for d in new_dims if d > 0]) + if old_numel == 0 or new_numel == 0: + return extra + scale = old_numel / new_numel # > 1 (we are shrinking) + + def replace_strides_match(m: re.Match) -> str: + prefix, content, suffix = m.group(1), m.group(2), m.group(3) + vals = _dims_from_str(content) + new_vals = [_scale_single_value(v, scale) for v in vals] + return prefix + ", ".join(str(v) for v in new_vals) + suffix + + return _STRIDES_RE.sub(replace_strides_match, extra) + + +def _transform_line(line: str, factor: float, threshold: int) -> str: + """ + Return a new config line with all large Tensor shapes scaled down by `factor`. + Also scales list[...]/tuple(...) args and strides proportionally. + """ + line = line.rstrip("\n") + + actual_scale = _compute_scale_for_line(line, factor, threshold) + if actual_scale is None: + return line # nothing to do + + # --- Step 1: collect old→new dim mappings per Tensor occurrence ---------- + # We process Tensor(...) tokens and record (old_dims, new_dims) + tensor_replacements: list[tuple[str, str]] = [] # (old_token, new_token) + + for m in _TENSOR_RE.finditer(line): + old_token = m.group(0) + old_dims = _dims_from_str(m.group(1)) + dtype = m.group(2) + extra = m.group(3) # e.g. ",is_contiguous=False,strides=[1,2048]" + + new_dims = _scale_dims(old_dims, actual_scale, threshold) + new_shape_str = ", ".join(str(d) for d in new_dims) + + # Scale strides in extra field + new_extra = _replace_strides(extra, old_dims, new_dims) + + new_token = f'Tensor(paddle.Size([{new_shape_str}]),"{dtype}"{new_extra})' + if old_token != new_token: + tensor_replacements.append((old_token, new_token)) + + # Apply Tensor replacements (use plain string replace to preserve order) + new_line = line + for old_tok, new_tok in tensor_replacements: + new_line = new_line.replace(old_tok, new_tok, 1) + + # --- Step 2: scale list[...] / tuple(...) args that contain large ints --- + # We only scale values that look like shape dimensions (positive integers > 1) + # and are larger than a conservative threshold. + # Use a heuristic: scale any integer > sqrt(threshold) in list/tuple args. + list_threshold = max(4, int(math.sqrt(threshold))) + + def replace_list_arg(m: re.Match) -> str: + prefix, content, suffix = m.group(1), m.group(2), m.group(3) + tokens = content.split(",") + new_tokens = [] + for tok in tokens: + stripped = tok.strip() + if stripped == "" or stripped == "-1": + new_tokens.append(tok) + continue + try: + v = int(stripped) + if v > list_threshold: + new_v = _scale_single_value(v, actual_scale) + new_tokens.append(tok.replace(stripped, str(new_v), 1)) + else: + new_tokens.append(tok) + except ValueError: + new_tokens.append(tok) + return prefix + ",".join(new_tokens) + suffix + + new_line = _LIST_ARG_RE.sub(replace_list_arg, new_line) + + return new_line + + +# --------------------------------------------------------------------------- +# Main pipeline +# --------------------------------------------------------------------------- + +ERROR_FILE_NAMES = { + "crash": "api_config_crash.txt", + "oom": "api_config_oom.txt", + "timeout": "api_config_timeout.txt", + "numpy_error": "api_config_numpy_error.txt", +} + + +def collect_error_lines(log_dirs: list[Path], error_types: list[str]) -> set[str]: + """Collect all lines from the specified error files in each log directory.""" + error_lines: set[str] = set() + for log_dir in log_dirs: + for etype in error_types: + fname = ERROR_FILE_NAMES[etype] + fpath = log_dir / fname + if not fpath.exists(): + print(f"[skip] {fpath} not found", file=sys.stderr) + continue + count = 0 + with open(fpath) as f: + for raw_line in f: + line = raw_line.strip() + if line: + error_lines.add(line) + count += 1 + print(f"[info] {fpath}: {count} lines", file=sys.stderr) + print(f"[info] total unique error lines: {len(error_lines)}", file=sys.stderr) + return error_lines + + +def load_source_lines(source_configs: list[Path]) -> list[str]: + """Load all lines from source config files (preserving order, dedup by line content).""" + seen: set[str] = set() + lines: list[str] = [] + for cfg in source_configs: + if not cfg.exists(): + print(f"[warn] source config not found: {cfg}", file=sys.stderr) + continue + with open(cfg) as f: + for raw_line in f: + line = raw_line.strip() + if line and line not in seen: + seen.add(line) + lines.append(line) + print(f"[info] loaded {cfg}: {len(lines)} unique lines so far", file=sys.stderr) + return lines + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Shrink large Tensor shapes in API configs that caused crash/oom/timeout/numpy_error." + ) + parser.add_argument( + "--error-logs", + nargs="+", + required=True, + metavar="LOG_DIR", + help="One or more test_log_* directories containing api_config_*.txt error files.", + ) + parser.add_argument( + "--source-configs", + nargs="+", + required=True, + metavar="CONFIG_FILE", + help="Source config .txt files (e.g. dsv4_1M_tofix.txt) to read original configs from.", + ) + parser.add_argument( + "--output", + required=True, + metavar="OUTPUT_FILE", + help="Output .txt file with shrunken configs.", + ) + parser.add_argument( + "--factor", + type=float, + default=8.0, + help="Reduce element count by this factor (default: 8). " + "Use 4, 8, 16, etc. for 4×/8×/16× shrinkage.", + ) + parser.add_argument( + "--threshold", + type=int, + default=1048576, + help="Only shrink Tensors whose element count reaches this value (default: 1048576 = 1M).", + ) + parser.add_argument( + "--error-types", + nargs="+", + default=["crash", "oom", "timeout", "numpy_error"], + choices=["crash", "oom", "timeout", "numpy_error"], + help="Which error types to include (default: all four).", + ) + parser.add_argument( + "--keep-unchanged", + action="store_true", + help="Also write lines from source that matched error set but had no large tensors.", + ) + args = parser.parse_args() + + log_dirs = [Path(p) for p in args.error_logs] + source_cfgs = [Path(p) for p in args.source_configs] + output_path = Path(args.output) + + # 1. Collect error lines + error_lines = collect_error_lines(log_dirs, args.error_types) + if not error_lines: + print("[warn] No error lines found — output will be empty.", file=sys.stderr) + + # 2. Load source configs + source_lines = load_source_lines(source_cfgs) + if not source_lines: + print("[error] No source config lines loaded.", file=sys.stderr) + sys.exit(1) + + # 3. Filter source lines to those in error set + matched = [l for l in source_lines if l in error_lines] + print( + f"[info] source lines matching error set: {len(matched)} / {len(source_lines)}", + file=sys.stderr, + ) + + not_in_source = error_lines - set(source_lines) + if not_in_source: + print( + f"[warn] {len(not_in_source)} error lines not found in source configs " + f"(from other run / already shrunk?)", + file=sys.stderr, + ) + + # 4. Transform matched lines + output_lines: list[str] = [] + changed = 0 + unchanged = 0 + for line in matched: + new_line = _transform_line(line, args.factor, args.threshold) + if new_line != line: + output_lines.append(new_line) + changed += 1 + else: + unchanged += 1 + if args.keep_unchanged: + output_lines.append(line) + + print(f"[info] transformed (shape changed): {changed}", file=sys.stderr) + print(f"[info] no large tensors (skipped): {unchanged}", file=sys.stderr) + + # 5. Deduplicate while preserving order + seen_out: set[str] = set() + deduped: list[str] = [] + for l in output_lines: + if l not in seen_out: + seen_out.add(l) + deduped.append(l) + print(f"[info] output lines after dedup: {len(deduped)}", file=sys.stderr) + + # 6. Write output + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + for l in deduped: + f.write(l + "\n") + print(f"[done] written to {output_path}", file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/tools/skip_seeker.py b/tools/skip_seeker.py deleted file mode 100644 index 99fb0814..00000000 --- a/tools/skip_seeker.py +++ /dev/null @@ -1,90 +0,0 @@ -# 筛选 skip 配置小工具 -# @author: cangtianhuang -# @date: 2025-11-11 -from __future__ import annotations - -from pathlib import Path - -TEST_LOG_PATH = Path("tester/api_config/test_log") -OUTPUT_PATH = TEST_LOG_PATH / "api_config_skip.txt" - -LOG_PREFIXES = { - "checkpoint": "checkpoint", - "pass": "api_config_pass", - "crash": "api_config_crash", - "oom": "api_config_oom", - "timeout": "api_config_timeout", - "paddle_error": "api_config_paddle_error", - "accuracy_error": "api_config_accuracy_error", - "accuracy_diff": "api_config_accuracy_diff", - "torch_error": "api_config_torch_error", - "paddle_to_torch_failed": "api_config_paddle_to_torch_failed", - "match_error": "api_config_match_error", - "numpy_error": "api_config_numpy_error", - "cuda_error": "api_config_cuda_error", -} - -log_counts = {} -checkpoint_configs = set() -api_configs = set() -checkpoint_file = TEST_LOG_PATH / "checkpoint.txt" -if not checkpoint_file.exists(): - print("No checkpoint file found", flush=True) - exit(0) -try: - with checkpoint_file.open("r") as f: - checkpoint_configs = {line.strip() for line in f if line.strip()} - log_counts["checkpoint"] = len(checkpoint_configs) -except Exception as err: - print(f"Error reading {checkpoint_file}: {err}", flush=True) - exit(0) -print(f"Read {len(checkpoint_configs)} api configs from checkpoint", flush=True) - -api_configs = checkpoint_configs.copy() -for log_type, prefix in LOG_PREFIXES.items(): - if log_type == "checkpoint": - continue - log_file = TEST_LOG_PATH / f"{prefix}.txt" - if not log_file.exists(): - continue - try: - with log_file.open("r") as f: - lines = {line.strip() for line in f if line.strip()} - api_configs -= lines - log_counts[log_type] = len(lines) - except Exception as err: - print(f"Error reading {log_file}: {err}", flush=True) - exit(0) - -if api_configs: - log_counts["skip"] = len(api_configs) -else: - print("No skip configs found", flush=True) - -for log_type, count in log_counts.items(): - print(f"{log_type}: {count}", flush=True) - -if api_configs: - skip_file = OUTPUT_PATH - try: - with skip_file.open("w") as f: - f.writelines(f"{line}\n" for line in sorted(api_configs)) - except Exception as err: - print(f"Error writing to {skip_file}: {err}", flush=True) - exit(0) - print(f"Write {len(api_configs)} skip api configs to {skip_file}", flush=True) - - checkpoint_count = len(checkpoint_configs) - checkpoint_configs -= api_configs - print(f"checkpoint removed: {checkpoint_count - len(checkpoint_configs)}", flush=True) - print(f"checkpoint remaining: {len(checkpoint_configs)}", flush=True) - try: - with checkpoint_file.open("w") as f: - f.writelines(f"{line}\n" for line in sorted(checkpoint_configs)) - except Exception as err: - print(f"Error writing {checkpoint_file}: {err}", flush=True) - exit(0) - print( - f"Write {len(checkpoint_configs)} checkpoint api configs to {checkpoint_file}", - flush=True, - )