-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun_rag_codegen.py
More file actions
391 lines (322 loc) · 16.2 KB
/
run_rag_codegen.py
File metadata and controls
391 lines (322 loc) · 16.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
```python
import json
import requests
import re
import os
import logging
import time
# ==============================================================================
# --- User Configuration (RAG Version) ---
# ==============================================================================
# 1. Model API key. Taken from the MODEL_API_KEY (or OPENAI_API_KEY) environment
#    variable when set, so the secret does not have to be stored in the script.
#    Otherwise the "KEY" placeholder remains and main() refuses to run.
MODEL_API_KEY = os.environ.get("MODEL_API_KEY") or os.environ.get("OPENAI_API_KEY") or "KEY"
# 2. The model name you want to use (also embedded in the output filenames)
MODEL_NAME = "gpt-4.1"  # <-- Switch here!
# 3. Number of requests per task (the k of the pass@k metric)
NUM_REQUESTS_PER_TASK = 1  # k
# 4. Delay in seconds before retrying after a failed API call or unusable reply
RETRY_DELAY_SECONDS = 5
# 5. Maximum total attempts (API calls) for a single task
MAX_ATTEMPTS_PER_TASK = 20
# 6. Path to the source JSON file containing 'task_instance'
SOURCE_JSON_FILE = 'task_instance.json'
# 7. Select your retriever type: 'bm25', 'dataflow', or 'dense'
RETRIEVER_TYPE = 'dataflow'  # <-- Switch here!
# 8. Top-k snippet limit for bm25/dense (not applied to 'dataflow')
SNIPPET_COUNT_LIMIT = 5  # <-- Set the number of top-k snippets you want to use
# 9. Retrieval result files for each retriever type
BM25_JSON_FILE = './retriever_result/bm25_results.json'
ORACLE_RETRIEVER_JSON_FILE = './retriever_result/retriever_dataflow_results.json'
DENSE_RETRIEVER_JSON_FILE = './retriever_result/rlcoder_dense_retriever_results.json'
# 10. Results are written to a per-retriever subdirectory of the base directory
BASE_OUTPUT_DIRECTORY = './result/rag_result'
OUTPUT_DIRECTORY = os.path.join(BASE_OUTPUT_DIRECTORY, RETRIEVER_TYPE)
# Filename suffix; bm25/dense additionally record the top-k snippet limit.
if RETRIEVER_TYPE in ('bm25', 'dense'):
    filename_suffix = f"rag_{RETRIEVER_TYPE}_top{SNIPPET_COUNT_LIMIT}_k{NUM_REQUESTS_PER_TASK}"
else:  # 'dataflow' (main() rejects any other value before doing work)
    filename_suffix = f"rag_{RETRIEVER_TYPE}_k{NUM_REQUESTS_PER_TASK}"
# 11. Output and log filenames derive from the model name and retriever settings
OUTPUT_JSONL_FILE = os.path.join(OUTPUT_DIRECTORY, f"{MODEL_NAME}_{filename_suffix}_results.jsonl")
LOG_FILE = os.path.join(OUTPUT_DIRECTORY, f"{MODEL_NAME}_{filename_suffix}_run.log")
# Model API endpoint (OpenAI Chat Completions)
API_URL = "https://api.openai.com/v1/chat/completions"
# ==============================================================================
# --- RAG Prompt Construction Function ---
# ==============================================================================
def build_rag_prompt(task_instance, retrieved_snippets):
    """
    Assemble the one-shot RAG prompt for a single Java function.

    The prompt contains, in order: the generation rules, a worked example
    (``isBlank``), the retrieved context snippets, and the task itself.

    Args:
        task_instance: Javadoc + signature of the function to implement;
            inserted verbatim (via str formatting) into the prompt.
        retrieved_snippets: iterable of code snippet strings. Each one is
            stripped and wrapped in its own numbered ```java fence. A falsy
            value (None, empty list) yields a "no snippets" notice instead.

    Returns:
        The complete prompt string; it begins and ends with a newline.
    """
    # One-shot example: the function specification the model is shown...
    demo_signature = """/**
* Checks if a given string is null, empty, or consists only of white-space characters.
*
* @param str the String to check, may be null
* @return {@code true} if the String is null, empty, or whitespace-only
*/
public static boolean isBlank(String str)"""
    # ...and the completion it is expected to produce for it.
    demo_completion = """
/**
* Checks if a given string is null, empty, or consists only of white-space characters.
*
* @param str the String to check, may be null
* @return {@code true} if the String is null, empty, or whitespace-only
*/
public static boolean isBlank(String str) {
if (str == null || str.isEmpty()) {
return true;
}
for (int i = 0; i < str.length(); i++) {
if (!Character.isWhitespace(str.charAt(i))) {
return false;
}
}
return true;
}
"""

    if retrieved_snippets:
        # Number every snippet and give it its own fence so they stay distinct.
        snippet_section = "".join(
            f"// --- Relative Code Snippet {position} ---\n```java\n{code_text.strip()}\n```\n\n"
            for position, code_text in enumerate(retrieved_snippets, start=1)
        )
    else:
        # Tell the model explicitly that no retrieval context is available.
        snippet_section = "No relevant code snippets were provided."

    # One entry per prompt line; the RELEVANT CODE section sits between the
    # example and the actual task.
    prompt_lines = [
        "You are an expert Java programmer acting as a code generation engine. Your task is to implement the body of a single Java function based on the provided specification.",
        "### INSTRUCTIONS:",
        "1. Your output MUST strictly follow the format and structure of the example below.",
        "2. Generate ONLY ONE complete function block.",
        "3. **DO NOT** define any helper methods, private functions, inner classes, or a `main` method.",
        "4. Your response must be ONLY the Java code, wrapped in ```java. Do not add any explanation.",
        "---",
        "### EXAMPLE",
        "#### Function to Implement (Example):",
        "```java",
        demo_signature,
        "```",
        "#### Expected Output (Example):",
        demo_completion,
        "---",
        "### RELEVANT CODE FOR YOUR TASK",
        "Here are some relevant code snippets that might help you with your task. Use them to understand the context and patterns.",
        snippet_section.strip(),
        "---",
        "### YOUR TASK",
        "#### Function to Implement (Your Task):",
        "```java",
        f"{task_instance}```",
        "Now, generate the output for YOUR TASK.",
    ]
    return "\n" + "\n".join(prompt_lines) + "\n"
# ==============================================================================
# --- Helper Functions ---
# ==============================================================================
def setup_logging(log_file):
    """
    Send root-logger records at INFO and above to both *log_file* and the console.

    Handlers already attached to the root logger are discarded first, so a
    repeated call does not duplicate output. The log file is opened in append
    mode with UTF-8 encoding; both handlers share one timestamped format.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    # The root logger has no parent, so this empties exactly what
    # hasHandlers() would have reported.
    root.handlers.clear()
    shared_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    for sink in (logging.FileHandler(log_file, mode='a', encoding='utf-8'), logging.StreamHandler()):
        sink.setFormatter(shared_format)
        root.addHandler(sink)
def load_json_file(filepath):
    """
    Read and parse a UTF-8 JSON file, logging instead of raising on failure.

    Args:
        filepath: path of the JSON file to read.

    Returns:
        The parsed JSON value, or None when the file does not exist or does not
        contain valid JSON (both cases are logged at ERROR level).
    """
    logging.info(f"Loading file: {filepath}...")
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
    except FileNotFoundError:
        logging.error(f"File not found '{filepath}'. The script will terminate.")
        parsed = None
    except json.JSONDecodeError:
        logging.error(f"JSON file '{filepath}' has an invalid format. The script will terminate.", exc_info=True)
        parsed = None
    return parsed
def _first_snippet_group(snippet_groups):
    """Return the first snippet group of one task as a flat list of strings ([] if none)."""
    if not isinstance(snippet_groups, list) or not snippet_groups:
        return []
    head = snippet_groups[0]
    if isinstance(head, list):
        # Documented layout List[List[str]]: keep only the first group.
        return [snippet for snippet in head if isinstance(snippet, str)]
    # Tolerate an already-flat List[str]; non-string entries would crash
    # build_rag_prompt(), which calls .strip() on every snippet.
    return [snippet for snippet in snippet_groups if isinstance(snippet, str)]


def load_and_process_dataflow_retriever_file(filepath):
    """
    Load retriever_dataflow_results.json and flatten it to one snippet list per task.

    Each value in the file is a list of snippet groups (List[List[str]]); only
    the first group is kept for each task_id, because build_rag_prompt()
    expects a flat list of snippet strings. A value that is already a flat
    List[str] is accepted as-is. Tasks without usable snippets map to [].

    Args:
        filepath: path of the dataflow ("Oracle") retriever results file.

    Returns:
        A dictionary ``{task_id: [snippet_1, snippet_2, ...]}``, or None when
        the file could not be loaded, is empty, or is not a JSON object.
    """
    logging.info(f"Loading and processing Oracle Retriever file: {filepath}...")
    dataflow_data = load_json_file(filepath)
    if not dataflow_data:
        return None
    if not isinstance(dataflow_data, dict):
        logging.error(f"Oracle file '{filepath}' must contain a JSON object mapping task IDs to snippet groups.")
        return None
    processed_data = {}
    for task_id, snippet_groups in dataflow_data.items():
        snippets = _first_snippet_group(snippet_groups)
        if not snippets:
            logging.warning(f"ID: {task_id} - No valid code snippet groups found in the Oracle file, will use an empty list.")
        processed_data[task_id] = snippets
    logging.info(f"Successfully loaded and processed Oracle data for {len(processed_data)} tasks from '{filepath}'.")
    return processed_data
def load_source_tasks_map(filepath):
    """
    Build a ``{task_id: task_instance}`` map from the source task file.

    The file is read as a list of repository blocks; every entry of a block's
    'analysis_results' contributes its 'id' and its
    'primary_analysis' -> 'task_instance'. Entries where either value is
    missing or empty are skipped; a repeated id keeps the last occurrence.

    Args:
        filepath: path of the source JSON file.

    Returns:
        The id -> task_instance dictionary, or None if the file could not be
        loaded or is empty.
    """
    repo_blocks = load_json_file(filepath)
    if not repo_blocks:
        return None
    id_to_instance = {}
    for repo_entry in repo_blocks:
        for result in repo_entry.get('analysis_results', []):
            identifier = result.get('id')
            instance_text = result.get('primary_analysis', {}).get('task_instance')
            if not (identifier and instance_text):
                continue
            id_to_instance[identifier] = instance_text
    logging.info(f"Successfully loaded source data for {len(id_to_instance)} tasks from '{filepath}'.")
    return id_to_instance
def get_processed_ids(filepath):
    """
    Collect the 'id' of every record already written to a JSONL results file.

    Used to resume a run: tasks whose id is returned are skipped. Lines that
    cannot be used are ignored rather than aborting the resume: blank or
    truncated lines (invalid JSON), objects without an 'id', JSON values that
    are not objects (e.g. a list or a string), and unhashable ids.

    Args:
        filepath: path of the JSONL output file; it need not exist yet.

    Returns:
        A set of ids (empty if the file does not exist).
    """
    processed_ids = set()
    if not os.path.exists(filepath):
        return processed_ids
    logging.info(f"Checking for processed entries in: {filepath}...")
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                processed_ids.add(json.loads(line)['id'])
            except (json.JSONDecodeError, KeyError, TypeError):
                # TypeError covers non-object values (list['id'], str['id'])
                # and unhashable ids, which previously crashed the resume.
                continue
    logging.info(f"Found {len(processed_ids)} processed entries.")
    return processed_ids
def call_model_api(prompt, model_name, temperature=0.7, max_tokens=4096):
    """
    Send a single-turn, non-streaming chat completion request and return the reply text.

    Args:
        prompt: user message content.
        model_name: model identifier sent in the request payload.
        temperature: sampling temperature (default 0.7, the original setting).
        max_tokens: completion token cap (default 4096, the original setting).

    Returns:
        The content of the first choice's message, or None when the HTTP request
        fails or the response body does not have the expected shape. Errors
        are logged, never raised.
    """
    headers = {"Authorization": f"Bearer {MODEL_API_KEY}", "Content-Type": "application/json"}
    payload = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": False,
    }
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logging.error(f"API request failed: {e}", exc_info=True)
        return None
    try:
        # 'choices' is a list; the reply is in its first element. Indexing it
        # with a string (the previous code) raised an uncaught TypeError.
        return response.json()['choices'][0]['message']['content']
    except (ValueError, KeyError, IndexError, TypeError):
        # ValueError covers a non-JSON body from response.json().
        logging.error(f"API response format is incorrect. Response content: {response.text}", exc_info=True)
        return None
# First fenced block in a reply. The optional group consumes an info string
# (e.g. "java", "Java", "python", with surrounding spaces) plus the line break
# after the opening fence, so the tag never leaks into the code; CRLF is accepted.
_CODE_FENCE_RE = re.compile(r"```(?:[ \t]*[A-Za-z0-9_+-]*[ \t]*\r?\n)?(.*?)```", re.DOTALL)

# Starts of an unfenced reply that are accepted as Java code. "/**" and "@"
# matter because the prompt's example asks for the Javadoc to be repeated first.
_BARE_CODE_PREFIXES = ('{', 'public', 'protected', 'private', 'static', 'final',
                       'synchronized', 'String', '/**', '@')


def extract_code_from_response(response_text):
    """
    Extract the generated code from a model reply.

    Strategy priority:
    1. The content of the first ``` fenced block (any or no language tag).
    2. Otherwise the whole reply, but only if it starts like Java code
       (see _BARE_CODE_PREFIXES).

    Args:
        response_text: raw reply text; may be None or empty.

    Returns:
        The stripped code string, or None if nothing usable was found
        (including an empty fenced block).
    """
    if not response_text:
        return None
    match = _CODE_FENCE_RE.search(response_text)
    if match:
        return match.group(1).strip() or None
    # Fallback: the model may have returned bare code without a fence.
    cleaned_response = response_text.strip()
    if cleaned_response.startswith(_BARE_CODE_PREFIXES):
        return cleaned_response
    return None
def append_to_jsonl(filepath, data_dict):
    """
    Append *data_dict* as one JSON line (UTF-8, non-ASCII kept as-is) to *filepath*.

    An OS-level failure to open or write the file (e.g. a missing directory) is
    logged and swallowed; serialization errors from json.dumps propagate.
    """
    try:
        with open(filepath, 'a', encoding='utf-8') as sink:
            # print() terminates the record with the newline JSONL requires.
            print(json.dumps(data_dict, ensure_ascii=False), file=sink)
    except OSError:  # IOError is an alias of OSError in Python 3
        logging.error(f"Failed to write to file '{filepath}'.", exc_info=True)
# ==============================================================================
# --- RAG Script Main Execution Function ---
# ==============================================================================
# Values meaning "API key not configured yet". Compared exactly rather than by
# substring, so a genuine key that happens to contain "KEY" is not rejected.
_API_KEY_PLACEHOLDERS = ("", "KEY", "YOUR_MODEL_API_KEY")


def _load_retrieval_data():
    """
    Load the retrieval results for the configured RETRIEVER_TYPE.

    Returns:
        ``(retrieval_data, source_file)``, where retrieval_data maps task_id to
        its snippets ({} if loading failed or the file was empty), or None when
        RETRIEVER_TYPE is not 'bm25', 'dense' or 'dataflow' (logged as an error).
    """
    if RETRIEVER_TYPE in ('bm25', 'dense'):
        source_retriever_file = BM25_JSON_FILE if RETRIEVER_TYPE == 'bm25' else DENSE_RETRIEVER_JSON_FILE
        retrieval_data = load_json_file(source_retriever_file)
    elif RETRIEVER_TYPE == 'dataflow':
        source_retriever_file = ORACLE_RETRIEVER_JSON_FILE
        retrieval_data = load_and_process_dataflow_retriever_file(source_retriever_file)
    else:
        logging.error(f"Invalid RETRIEVER_TYPE: '{RETRIEVER_TYPE}'. Please choose 'bm25', 'dataflow', or 'dense' in the configuration.")
        return None
    if not retrieval_data:
        logging.warning(f"Failed to load retrieval data from '{source_retriever_file}' or the file is empty. Will proceed with all tasks without using any relevant code.")
        retrieval_data = {}
    return retrieval_data, source_retriever_file


def _generate_codes_for_task(task_id, prompt):
    """
    Query the model until NUM_REQUESTS_PER_TASK code samples are collected.

    At most MAX_ATTEMPTS_PER_TASK API calls are made. A failed call or a reply
    without extractable code is retried after RETRY_DELAY_SECONDS.

    Returns:
        The list of extracted code strings; it is shorter than
        NUM_REQUESTS_PER_TASK when the attempt budget ran out.
    """
    generated_codes_list = []
    total_attempts = 0
    while len(generated_codes_list) < NUM_REQUESTS_PER_TASK:
        if total_attempts >= MAX_ATTEMPTS_PER_TASK:
            logging.critical(f"ID: {task_id} - [ABORTING] The maximum number of attempts ({MAX_ATTEMPTS_PER_TASK}) for the task has been reached.")
            break
        total_attempts += 1
        current_progress = len(generated_codes_list) + 1
        logging.info(f"ID: {task_id} - Getting result {current_progress}/{NUM_REQUESTS_PER_TASK} (Total attempts: {total_attempts})...")
        model_response = call_model_api(prompt, MODEL_NAME)
        if not model_response:
            logging.error(f"ID: {task_id} - API call failed. Retrying in {RETRY_DELAY_SECONDS} seconds...")
            time.sleep(RETRY_DELAY_SECONDS)
            continue
        generated_code = extract_code_from_response(model_response)
        if not generated_code:
            logging.warning(f"ID: {task_id} - Could not find a code block in the response. Retrying in {RETRY_DELAY_SECONDS} seconds...")
            time.sleep(RETRY_DELAY_SECONDS)
            continue
        generated_codes_list.append(generated_code)
        logging.info(f"ID: {task_id} - Successfully obtained result {current_progress}/{NUM_REQUESTS_PER_TASK}.")
    return generated_codes_list


def main():
    """
    Run RAG code generation for every task in SOURCE_JSON_FILE.

    Tasks whose id already appears in OUTPUT_JSONL_FILE are skipped, so an
    interrupted run can be resumed. A task's record is appended only when all
    NUM_REQUESTS_PER_TASK samples were obtained.
    """
    os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)
    setup_logging(LOG_FILE)
    if MODEL_API_KEY.strip() in _API_KEY_PLACEHOLDERS:
        logging.error("Please configure your MODEL_API_KEY at the top of the script.")
        return
    task_instance_map = load_source_tasks_map(SOURCE_JSON_FILE)
    if not task_instance_map:
        return
    loaded = _load_retrieval_data()
    if loaded is None:
        return
    retrieval_data, source_retriever_file = loaded

    tasks_to_process = list(task_instance_map.items())
    total_tasks = len(tasks_to_process)
    logging.info(f"Will process all {total_tasks} tasks from '{SOURCE_JSON_FILE}'.")
    logging.info(f"Using '{source_retriever_file}' (type: {RETRIEVER_TYPE}) to provide relevant code snippets.")
    if RETRIEVER_TYPE in ('bm25', 'dense'):
        logging.info(f"A maximum of {SNIPPET_COUNT_LIMIT} code snippets will be used per task.")
    logging.info(f"Will ensure {NUM_REQUESTS_PER_TASK} results are generated for each task.")
    logging.info(f"Results will be saved to: {OUTPUT_JSONL_FILE}")
    processed_ids = get_processed_ids(OUTPUT_JSONL_FILE)

    for i, (task_id, task_instance) in enumerate(tasks_to_process, 1):
        logging.info(f"--- Processing task {i}/{total_tasks} (ID: {task_id}) ---")
        if task_id in processed_ids:
            logging.info(f"ID: {task_id} - [SKIPPING] This task has already been processed.")
            continue
        snippets = retrieval_data.get(task_id, [])
        # Only bm25/dense results are truncated to the top-k snippets.
        if RETRIEVER_TYPE in ('bm25', 'dense'):
            original_snippet_count = len(snippets)
            if original_snippet_count > 0:
                snippets = snippets[:SNIPPET_COUNT_LIMIT]
                logging.info(f"ID: {task_id} - Applying snippet limit: using {len(snippets)}/{original_snippet_count} snippets.")
        if not snippets:
            logging.info(f"ID: {task_id} - No relevant code snippets found in '{source_retriever_file}' or none left after limit, will proceed with generation.")
        prompt = build_rag_prompt(task_instance, snippets)
        generated_codes_list = _generate_codes_for_task(task_id, prompt)
        if len(generated_codes_list) == NUM_REQUESTS_PER_TASK:
            output_record = {
                "id": task_id,
                "model": MODEL_NAME,
                "generated_codes": generated_codes_list,
            }
            append_to_jsonl(OUTPUT_JSONL_FILE, output_record)
            logging.info(f"ID: {task_id} - [SUCCESS] Task complete, {len(generated_codes_list)} codes have been saved.")
        else:
            logging.error(f"ID: {task_id} - [FAILURE] Task terminated, ultimately only obtained {len(generated_codes_list)}/{NUM_REQUESTS_PER_TASK} results.")
    logging.info(f"All RAG tasks (type: {RETRIEVER_TYPE}) have been processed!")


if __name__ == "__main__":
    main()
```