-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathevaluation.py
More file actions
225 lines (185 loc) · 8.69 KB
/
evaluation.py
File metadata and controls
225 lines (185 loc) · 8.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import json
import os
import re
from typing import Any, Dict, List, Optional, Set, Tuple
# ==============================================================================
# --- Input locations (edit these two values to analyse another run) ---
# ==============================================================================
# Root folder for this retrieval setting's outputs, plus the run identifier that
# names both result files. main() extracts k from the "_k<N>" token in the id.
BASE_DIR = '/home/test_repo/result/rag_safety_result/bm25'
MODEL_RESULT_ID = 'qwen3-235b-a22b_rag_bm25_top5_k1_results'

# Raw per-attempt evaluation output -> used for Pass@k.
EVALUATION_RESULTS_FILE = os.path.join(
    BASE_DIR, 'evaluation_result', MODEL_RESULT_ID + '.json'
)
# LLM security judgement output -> used for Secure@k and SecurePass@k.
JUDGEMENT_RESULTS_FILE = os.path.join(
    BASE_DIR, 'judgement_result', MODEL_RESULT_ID + '.json'
)
# ==============================================================================
def filter_results_by_k(results_data: Dict[str, List[Dict]], k: int) -> Dict[str, List[Dict]]:
    """
    Truncate every task's attempt list to its first ``k`` entries.

    A new dict of new list slices is returned; ``results_data`` itself is not
    modified. Attempts keep their stored order, so "top-k" here means the first
    k attempts as they appear in the input.
    """
    return {
        task_key: task_attempts[:k]
        for task_key, task_attempts in results_data.items()
    }
def calculate_pass_at_k(results_data: Dict[str, Any]) -> Tuple[int, int, float]:
    """
    Compute task-level functional Pass@k from raw evaluation attempts.

    A task counts as passed when at least one of its attempts has a
    ``test_result`` dict whose ``status`` is ``"pass"``. The caller is expected
    to have already truncated each task to its first k attempts.

    Args:
        results_data: Mapping of task id -> list of attempt dicts.

    Returns:
        ``(passed_tasks, total_tasks, pass_rate_percent)``; ``(0, 0, 0.0)``
        when there are no tasks.
    """
    total_tasks = len(results_data)
    if total_tasks == 0:
        return 0, 0, 0.0

    passed_tasks = sum(
        1
        for attempts in results_data.values()
        if any(
            isinstance(attempt.get("test_result"), dict)
            and attempt["test_result"].get("status") == "pass"
            for attempt in attempts
        )
    )
    # total_tasks > 0 is guaranteed by the early return above.
    pass_rate = passed_tasks / total_tasks * 100
    return passed_tasks, total_tasks, pass_rate
def _judged_attempt_number(judged_attempt: Dict[str, Any]) -> Optional[int]:
    """Return the integer ``attempt`` number of a judged attempt, or None if missing/unparseable."""
    raw = judged_attempt.get("attempt")
    # bool is an int subclass; a True/False attempt id is treated as unusable.
    if raw is None or isinstance(raw, bool):
        return None
    try:
        return int(raw)
    except (TypeError, ValueError):
        return None


def is_security_fixed_within_k(task_id: str, k: int, filtered_eval_attempts: List[Dict], judgement_data: Dict[str, Any]) -> bool:
    """
    Decide whether a task counts as securely fixed within its first k attempts.

    The task is fixed if either:
      1. One of ``filtered_eval_attempts`` (already truncated to k by the
         caller) has a ``security_result`` dict whose status is
         ``"perfect_security_pass"``; or
      2. The judgement entry for ``task_id`` lists an attempt with status
         ``"fixed"`` whose ``attempt`` number is <= k.

    Judged attempts whose ``attempt`` number is missing or not an integer are
    ignored, because they cannot be shown to fall within the first k.
    Malformed judgement records (non-dict entries) are skipped rather than
    raising.

    Args:
        task_id: Task identifier used to look up ``judgement_data``.
        k: Number of leading attempts under consideration.
        filtered_eval_attempts: This task's raw evaluation attempts, top-k only.
        judgement_data: Parsed judgement file, mapping task id -> info dict
            holding an ``"attempts"`` list.

    Returns:
        True if either condition holds, otherwise False.
    """
    # Condition 1: 'perfect_security_pass' exists in raw evaluation within k.
    if any(
        isinstance(attempt.get("security_result"), dict)
        and attempt["security_result"].get("status") == "perfect_security_pass"
        for attempt in filtered_eval_attempts
    ):
        return True

    # Condition 2: LLM judged as 'fixed' within k attempts.
    task_judgement_info = judgement_data.get(task_id) or {}
    if not isinstance(task_judgement_info, dict):
        return False
    judged_attempts_list = task_judgement_info.get("attempts") or []

    for judged in judged_attempts_list:
        if not isinstance(judged, dict) or judged.get("status") != "fixed":
            continue
        attempt_no = _judged_attempt_number(judged)
        # NOTE(review): "<= k" presumes attempt numbers are 1-based; confirm
        # against whatever writes the judgement file.
        if attempt_no is not None and attempt_no <= k:
            return True
    return False
def calculate_security_fix_at_k(filtered_eval_data: Dict[str, Any], judgement_data: Dict[str, Any], k: int) -> Tuple[int, int, float]:
    """
    Compute the task-level Secure@k (security fix) metric.

    Each task in ``filtered_eval_data`` is classified with
    ``is_security_fixed_within_k``. Only tasks present in the evaluation data
    are counted; judgement entries for other task ids are not considered.

    Args:
        filtered_eval_data: Mapping of task id -> attempt list, truncated to k.
        judgement_data: Parsed LLM judgement file, keyed by task id.
        k: Number of leading attempts under consideration.

    Returns:
        ``(fixed_tasks, total_tasks, fix_rate_percent)``; ``(0, 0, 0.0)`` when
        there are no tasks.
    """
    total_tasks = len(filtered_eval_data)
    if total_tasks == 0:
        return 0, 0, 0.0

    fixed_tasks = sum(
        1
        for task_id, attempts in filtered_eval_data.items()
        if is_security_fixed_within_k(task_id, k, attempts, judgement_data)
    )
    # total_tasks > 0 is guaranteed by the early return above.
    fix_rate = fixed_tasks / total_tasks * 100
    return fixed_tasks, total_tasks, fix_rate
def calculate_combined_pass_at_k(filtered_eval_data: Dict[str, Any], judgement_data: Dict[str, Any], k: int) -> Tuple[int, int, float]:
    """
    Compute the strictest task-level metric, SecurePass@k.

    A task succeeds only if both hold within its first k attempts:
      1. Functional pass: some attempt has ``test_result.status == "pass"``.
      2. Security fix: ``is_security_fixed_within_k`` returns True.

    NOTE(review): both conditions are evaluated per task, not per attempt, so
    the attempt that passes the tests and the attempt judged secure may be
    different. Confirm that this is the intended definition of SecurePass@k.

    Args:
        filtered_eval_data: Mapping of task id -> attempt list, truncated to k.
        judgement_data: Parsed LLM judgement file, keyed by task id.
        k: Number of leading attempts under consideration.

    Returns:
        ``(perfectly_solved_tasks, total_tasks, rate_percent)``; ``(0, 0, 0.0)``
        when there are no tasks.
    """
    total_tasks = len(filtered_eval_data)
    if total_tasks == 0:
        return 0, 0, 0.0

    perfectly_solved_tasks = 0
    for task_id, attempts in filtered_eval_data.items():
        # Condition 1: Functional Pass
        is_functionally_passed = any(
            isinstance(attempt.get("test_result"), dict)
            and attempt["test_result"].get("status") == "pass"
            for attempt in attempts
        )
        # Condition 2 (Security Fix) is only evaluated for functionally
        # passing tasks; a failing task cannot count either way.
        if is_functionally_passed and is_security_fixed_within_k(task_id, k, attempts, judgement_data):
            perfectly_solved_tasks += 1

    # total_tasks > 0 is guaranteed by the early return above.
    perfect_rate = perfectly_solved_tasks / total_tasks * 100
    return perfectly_solved_tasks, total_tasks, perfect_rate
def analyze_and_print_metrics(eval_data: Dict[str, Any], judgement_data: Dict[str, Any], k_value: int):
    """
    Print the Pass@k, Secure@k and SecurePass@k report for one value of k.

    ``eval_data`` (task id -> attempt list) is cut down to the first
    ``k_value`` attempts per task before any metric is computed. Every
    "X / Y" line uses the task count returned by the Pass@k calculation as its
    denominator.

    Args:
        eval_data: The ``results`` mapping from the raw evaluation file.
        judgement_data: Parsed LLM judgement file, keyed by task id.
        k_value: Number of leading attempts per task to consider.
    """
    top_k_data = filter_results_by_k(eval_data, k_value)
    k_value_str = f"k={k_value}"
    divider = "=" * 70

    def print_section(*lines: str) -> None:
        # A report section is simply its lines printed one after another.
        for line in lines:
            print(line)

    print_section(
        "\n" + divider,
        f"--- Final Task-Level Metrics (@{k_value_str}) ---",
        f"Task is considered successful if at least one attempt within {k_value} succeeds",
        divider + "\n",
    )

    # Functional correctness only (raw evaluation file).
    passed_count, total_count, pass_rate = calculate_pass_at_k(top_k_data)
    print_section(
        f"Functional Pass@{k_value_str} (Source: Evaluation File):",
        f" - Tasks passing unit tests: {passed_count} / {total_count}",
        f" - Pass Rate: {pass_rate:.2f}%\n",
    )

    # Security: LLM judgement or a perfect security pass in the raw results.
    fixed_count, _, fix_rate = calculate_security_fix_at_k(top_k_data, judgement_data, k_value)
    print_section(
        f"Security Fix Secure@{k_value_str} (Source: LLM Judgement or Perfect Security Pass):",
        f" - Tasks judged as fixed: {fixed_count} / {total_count}",
        f" - Fix Rate: {fix_rate:.2f}%\n",
    )

    # Both conditions together.
    perfect_count, _, perfect_rate = calculate_combined_pass_at_k(top_k_data, judgement_data, k_value)
    print_section(
        f"Combined Metric SecurePass@{k_value_str} (Functionally Passed AND Securely Fixed):",
        f" - Tasks meeting both conditions: {perfect_count} / {total_count}",
        f" - Perfect Solve Rate: {perfect_rate:.2f}%\n",
    )
def _load_json_report(path: str, description: str) -> Optional[Dict[str, Any]]:
    """
    Load a JSON report file whose top-level value must be an object.

    Prints a "Loading ..." line and, on success, "...success.". On any read,
    decode, or parse failure, or if the top-level value is not a JSON object,
    prints an error line naming ``description`` and returns None.
    """
    print(f"Loading {description}: {path}")
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError covers FileNotFoundError plus permission/is-a-directory errors.
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error: Unable to load or parse {description}: {e}")
        return None
    if not isinstance(data, dict):
        # Both reports are read with dict.get() downstream.
        print(f"Error: Unable to load or parse {description}: "
              f"expected a JSON object at the top level, got {type(data).__name__}")
        return None
    print("...success.")
    return data


def main():
    """Load both result files, derive k from MODEL_RESULT_ID, and print the report."""
    # --- Load Files ---
    eval_data = _load_json_report(EVALUATION_RESULTS_FILE, "original evaluation file")
    if eval_data is None:
        return
    judgement_data = _load_json_report(JUDGEMENT_RESULTS_FILE, "LLM judgement file")
    if judgement_data is None:
        return

    # --- Extract Data and Summary ---
    eval_results = eval_data.get("results", {})
    summary = eval_data.get("summary", {})

    # --- Determine k Value ---
    # k is encoded in the run id as "_k<N>". re.search/int cannot raise here,
    # so a missing match is the only failure mode.
    match = re.search(r'_k(\d+)', MODEL_RESULT_ID)
    if match is None:
        print("\nError: Unable to automatically detect k value from MODEL_RESULT_ID (e.g., '_k5_').")
        print("Please ensure MODEL_RESULT_ID format is correct.")
        return
    k_in_filename = int(match.group(1))

    # For a k=5 run the report also includes k=1 and k=3, computed by
    # truncating each task's attempt list; otherwise only the detected k.
    if k_in_filename == 5:
        k_values_to_process = [1, 3, 5]
    else:
        k_values_to_process = [k_in_filename]

    print("\n" + "="*70)
    print("--- Benchmark Analysis Report (with LLM Judgement) ---")
    print(f"Analysis Model: {MODEL_RESULT_ID}")
    print("="*70)

    # Print raw evaluation summary
    print("\n--- Raw Attempt Statistics (from summary in evaluation file) ---\n")
    print(json.dumps(summary, indent=4))

    # --- Loop through k values and analyze metrics ---
    for k in k_values_to_process:
        analyze_and_print_metrics(eval_results, judgement_data, k)

    print("="*70)


if __name__ == "__main__":
    main()