def _prepare_valid_calls(
    calls: list[dict[str, Any]],
    body_bytes: bytes,
    body_lines: list[str],
    line_byte_starts: list[int],
) -> list[dict[str, Any]]:
    """Filter collected calls and pre-compute the offsets used during replacement.

    Calls flagged ``in_lambda`` or ``in_complex`` are dropped. The remaining calls are
    sorted by ``start_byte`` and each call dict is annotated **in place** with:

    - ``_counter``: 1-based position in source order.
    - ``_call_start_char`` / ``_call_end_char``: character offsets of the call.
    - ``_es_start_char`` / ``_es_end_char``: character offsets of the enclosing statement,
      only when ``parent_type`` is ``"expression_statement"``.
    - ``_line_idx``: index into ``body_lines`` of the line holding the call start.
    - ``_line_char_start``: character offset at which that line begins.
    - ``_line_indent``: one space per leading whitespace character of that line.

    Args:
        calls: Call descriptors; each must provide ``in_lambda``, ``start_byte``,
            ``end_byte`` and ``parent_type``.
        body_bytes: UTF-8 encoded method body the byte offsets refer to.
        body_lines: The same body split into lines.
        line_byte_starts: Byte offset at which each entry of ``body_lines`` starts.

    Returns:
        A new list holding the surviving (and now annotated) call dicts, or an empty
        list when nothing is left to instrument.
    """
    valid_calls = [call for call in calls if not call["in_lambda"] and not call.get("in_complex", False)]
    if not valid_calls:
        return []

    def char_offset(byte_offset: int) -> int:
        # Offsets arrive as UTF-8 byte positions; the rewrite slices a ``str``, so convert
        # by decoding the prefix and counting its characters.
        return len(body_bytes[:byte_offset].decode("utf8"))

    valid_calls.sort(key=lambda call: call["start_byte"])
    for counter, call in enumerate(valid_calls, 1):
        call["_counter"] = counter
        call["_call_start_char"] = char_offset(call["start_byte"])
        call["_call_end_char"] = char_offset(call["end_byte"])
        if call["parent_type"] == "expression_statement":
            call["_es_start_char"] = char_offset(call["es_start_byte"])
            call["_es_end_char"] = char_offset(call["es_end_byte"])
        line_idx = _byte_to_line_index(call["start_byte"], line_byte_starts)
        source_line = body_lines[line_idx]
        call["_line_idx"] = line_idx
        call["_line_char_start"] = char_offset(line_byte_starts[line_idx])
        call["_line_indent"] = " " * (len(source_line) - len(source_line.lstrip()))
    return valid_calls


def _build_call_statements(
    call: dict[str, Any],
    iter_id: int,
    call_counter: int,
    precise_call_timing: bool,
    target_return_type: str,
) -> dict[str, str]:
    """Build the per-call Java statements shared by expression and embedded instrumentation.

    Args:
        call: Call descriptor. Must provide ``full_call`` and ``_source_line`` (the text of
            the line holding the call); ``receiver`` and ``arg_texts`` are optional.
        iter_id: Identifier embedded in every generated Java variable name.
        call_counter: Position of the call; embedded in the result variable name and, when
            ``precise_call_timing`` is set, in the timing/serialization variable names.
        precise_call_timing: Use per-call timing variables instead of the per-``iter_id`` ones.
        target_return_type: Java return type of the target; ``"void"`` switches to
            post-call state serialization.

    Returns:
        A string-to-string mapping. Always present: ``is_void`` (``"True"``/``"False"`` —
        string-encoded because every value in the mapping is a ``str``), ``var_name``,
        ``var_with_cast``, ``start_stmt``, ``end_stmt`` and ``serialize_stmt``. Void targets
        add ``bare_call_stmt``; all others add ``capture_stmt_with_decl`` and
        ``capture_stmt_assign``.
    """
    is_void = target_return_type == "void"
    var_name = f"_cf_result{iter_id}_{call_counter}"
    cast_type = _infer_array_cast_type(call["_source_line"])
    if not cast_type and target_return_type and not is_void:
        cast_type = target_return_type
    var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name

    # Precise timing gives every call its own timing/serialization variables; otherwise
    # all calls of one iteration share the same set.
    slot = f"{iter_id}_{call_counter}" if precise_call_timing else f"{iter_id}"

    statements: dict[str, str] = {
        "is_void": str(is_void),
        "var_name": var_name,
        "var_with_cast": var_with_cast,
        "start_stmt": f"_cf_start{slot} = System.nanoTime();",
        "end_stmt": f"_cf_end{slot} = System.nanoTime();",
    }

    if is_void:
        # For void methods, serialize the post-call state to capture side effects.
        # We always serialize the arguments (which are mutated in place).
        # For instance methods, we also include the receiver to capture object state changes.
        # For static methods, the receiver is a class name (not a value), so args only.
        receiver = call.get("receiver", "this")
        arg_texts: list[str] = call.get("arg_texts", [])
        is_static_call = receiver != "this" and receiver[:1].isupper()
        serialize_parts: list[str] = [] if is_static_call else [receiver]
        serialize_parts.extend(arg_texts)
        # An empty part list renders as "new Object[]{}", so no special case is needed.
        serialize_target = f"new Object[]{{{', '.join(serialize_parts)}}}"
        statements["bare_call_stmt"] = f"{call['full_call']};"
        statements["serialize_stmt"] = (
            f"_cf_serializedResult{slot} = com.codeflash.Serializer.serialize({serialize_target});"
        )
        return statements

    statements["capture_stmt_with_decl"] = f"var {var_name} = {call['full_call']};"
    statements["capture_stmt_assign"] = f"{var_name} = {call['full_call']};"
    statements["serialize_stmt"] = (
        f"_cf_serializedResult{slot} = com.codeflash.Serializer.serialize((Object) {var_name});"
    )
    return statements


def _build_start_marker_stmt(iter_id: int, inv_id: str) -> str:
    """Return the Java ``System.out.println`` statement that prints the start marker for one invocation."""
    return (
        f'System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id}'
        f' + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");'
    )


def _build_timing_var_decls(iter_id: int, call_counter: int) -> list[str]:
    """Return the Java declarations of the per-call end/start timers and serialized-result buffer."""
    slot = f"{iter_id}_{call_counter}"
    return [
        f"long _cf_end{slot} = -1;",
        f"long _cf_start{slot} = 0;",
        f"byte[] _cf_serializedResult{slot} = null;",
    ]


def _build_try_open_lines(statements: dict[str, str], call_stmt: str) -> list[str]:
    """Return the opening lines of the Java ``try`` block: start timer, call, end timer, serialize.

    The block is left open; the caller appends the lines from ``_generate_sqlite_write_code``.
    """
    # NOTE(review): the inner indentation of these literals is reproduced as it appears in
    # the reviewed text, where runs of spaces may have been collapsed — confirm upstream.
    return [
        "try {",
        f" {statements['start_stmt']}",
        f" {call_stmt}",
        f" {statements['end_stmt']}",
        f" {statements['serialize_stmt']}",
    ]


def _build_expression_statement_replacement(
    *,
    call: dict[str, Any],
    iter_id: int,
    call_counter: int,
    inv_id: str,
    precise_call_timing: bool,
    statements: dict[str, str],
    class_name: str,
    func_name: str,
    test_method_name: str,
) -> str:
    """Build the text that replaces a call which is a whole expression statement.

    Args:
        call: Call descriptor; ``_line_indent`` is read from it.
        iter_id: Identifier embedded in the generated Java variable names.
        call_counter: Position of the call, embedded in per-call variable names.
        inv_id: Invocation identifier printed in the start marker and forwarded to
            ``_generate_sqlite_write_code``.
        precise_call_timing: Emit the multi-line per-call ``try`` form instead of the
            single-line form.
        statements: Mapping produced by ``_build_call_statements``.
        class_name: Forwarded to ``_generate_sqlite_write_code``.
        func_name: Forwarded to ``_generate_sqlite_write_code``.
        test_method_name: Forwarded to ``_generate_sqlite_write_code``.

    Returns:
        Without precise timing: the call (or capturing declaration) and the serialize
        statement joined by one space. With precise timing: a multi-line string whose first
        line carries no indent (the text before the replaced statement already provides it)
        and whose remaining lines are prefixed with ``call["_line_indent"]``.
    """
    line_indent_str = call["_line_indent"]
    is_void = statements["is_void"] == "True"

    if not precise_call_timing:
        leading_stmt = statements["bare_call_stmt"] if is_void else statements["capture_stmt_with_decl"]
        return f"{leading_stmt} {statements['serialize_stmt']}"

    var_decls = _build_timing_var_decls(iter_id, call_counter)
    if is_void:
        call_stmt = statements["bare_call_stmt"]
    else:
        # Non-void calls need the result variable declared ahead of the try block.
        var_decls.insert(0, f"Object {statements['var_name']} = null;")
        call_stmt = statements["capture_stmt_assign"]

    finally_block = _generate_sqlite_write_code(
        iter_id,
        call_counter,
        "",
        class_name,
        func_name,
        test_method_name,
        invocation_id=inv_id,
        verification_type="void_state" if is_void else "function_call",
    )
    all_lines = [
        *var_decls,
        _build_start_marker_stmt(iter_id, inv_id),
        *_build_try_open_lines(statements, call_stmt),
        *finally_block,
    ]
    return all_lines[0] + "\n" + "\n".join(f"{line_indent_str}{repl_line}" for repl_line in all_lines[1:])


def _build_embedded_call_prefix_lines(
    *,
    call: dict[str, Any],
    iter_id: int,
    call_counter: int,
    inv_id: str,
    precise_call_timing: bool,
    statements: dict[str, str],
    class_name: str,
    func_name: str,
    test_method_name: str,
) -> list[str]:
    """Build the lines inserted ahead of the line that contains an embedded call.

    The statements read here (``capture_stmt_with_decl`` / ``capture_stmt_assign``) exist
    only for non-void targets, so ``statements`` must come from a non-void call.

    Args:
        call: Call descriptor; ``_line_indent`` is read from it.
        iter_id: Identifier embedded in the generated Java variable names.
        call_counter: Position of the call, embedded in per-call variable names.
        inv_id: Invocation identifier printed in the start marker and forwarded to
            ``_generate_sqlite_write_code``.
        precise_call_timing: Emit the per-call ``try`` form instead of the two-line form.
        statements: Mapping produced by ``_build_call_statements``.
        class_name: Forwarded to ``_generate_sqlite_write_code``.
        func_name: Forwarded to ``_generate_sqlite_write_code``.
        test_method_name: Forwarded to ``_generate_sqlite_write_code``.

    Returns:
        The prefix lines, each starting with ``call["_line_indent"]``; with precise timing
        the lines returned by ``_generate_sqlite_write_code`` are appended unchanged.
    """
    line_indent_str = call["_line_indent"]

    if not precise_call_timing:
        return [
            f"{line_indent_str}{statements['capture_stmt_with_decl']}",
            f"{line_indent_str}{statements['serialize_stmt']}",
        ]

    java_lines = [
        f"Object {statements['var_name']} = null;",
        *_build_timing_var_decls(iter_id, call_counter),
        _build_start_marker_stmt(iter_id, inv_id),
        *_build_try_open_lines(statements, statements["capture_stmt_assign"]),
    ]
    prefix_lines = [f"{line_indent_str}{java_line}" for java_line in java_lines]
    # The SQLite-writing lines receive the indent as an argument and are appended as returned.
    prefix_lines.extend(
        _generate_sqlite_write_code(
            iter_id,
            call_counter,
            line_indent_str,
            class_name,
            func_name,
            test_method_name,
            invocation_id=inv_id,
        )
    )
    return prefix_lines
- For behavior mode (precise_call_timing=True), each call is wrapped in its own - try-finally block with immediate SQLite write to prevent data loss from multiple calls. + Returns: + A tuple of ``(wrapped_body_lines, call_count)``. - Returns (wrapped_body_lines, call_counter). """ from codeflash.languages.java.parser import get_java_analyzer @@ -311,22 +515,9 @@ def wrap_target_calls_with_treesitter( offset += len(line.encode("utf8")) + 1 # +1 for \n from join # Filter out lambda and complex-expression calls, sort by start_byte ascending for counter assignment - valid_calls = [c for c in calls if not c["in_lambda"] and not c.get("in_complex", False)] + valid_calls = _prepare_valid_calls(calls, body_bytes, body_lines, line_byte_starts) if not valid_calls: return list(body_lines), 0 - valid_calls.sort(key=lambda c: c["start_byte"]) - - # Pre-compute character offsets and line info for each call (before any text modifications) - for i, call in enumerate(valid_calls, 1): - call["_counter"] = i - call["_call_start_char"] = len(body_bytes[: call["start_byte"]].decode("utf8")) - call["_call_end_char"] = len(body_bytes[: call["end_byte"]].decode("utf8")) - if call["parent_type"] == "expression_statement": - call["_es_start_char"] = len(body_bytes[: call["es_start_byte"]].decode("utf8")) - call["_es_end_char"] = len(body_bytes[: call["es_end_byte"]].decode("utf8")) - line_idx = _byte_to_line_index(call["start_byte"], line_byte_starts) - call["_line_idx"] = line_idx - call["_line_char_start"] = len(body_bytes[: line_byte_starts[line_idx]].decode("utf8")) # Process calls back-to-front so earlier character offsets stay valid for call in reversed(valid_calls): @@ -334,112 +525,26 @@ def wrap_target_calls_with_treesitter( line_idx = call["_line_idx"] call_absolute_line = body_start_line + line_idx + 1 inv_id = f"L{call_absolute_line}_{call_counter}" - - orig_line = body_lines[line_idx] - line_indent_str = " " * (len(orig_line) - len(orig_line.lstrip())) - - is_void = 
target_return_type == "void" - var_name = f"_cf_result{iter_id}_{call_counter}" - receiver = call.get("receiver", "this") - arg_texts: list[str] = call.get("arg_texts", []) - cast_type = _infer_array_cast_type(orig_line) - if not cast_type and target_return_type and not is_void: - cast_type = target_return_type - var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name - - if is_void: - bare_call_stmt = f"{call['full_call']};" - # For void methods, serialize the post-call state to capture side effects. - # We always serialize the arguments (which are mutated in place). - # For instance methods, we also include the receiver to capture object state changes. - # For static methods, the receiver is a class name (not a value), so args only. - is_static_call = receiver != "this" and receiver[:1].isupper() - parts: list[str] = [] - if not is_static_call: - parts.append(receiver) - parts.extend(arg_texts) - if parts: - serialize_target = f"new Object[]{{{', '.join(parts)}}}" - else: - serialize_target = "new Object[]{}" - if precise_call_timing: - serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize({serialize_target});" - start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();" - end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();" - else: - serialize_stmt = ( - f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize({serialize_target});" - ) - start_stmt = f"_cf_start{iter_id} = System.nanoTime();" - end_stmt = f"_cf_end{iter_id} = System.nanoTime();" - else: - capture_stmt_with_decl = f"var {var_name} = {call['full_call']};" - capture_stmt_assign = f"{var_name} = {call['full_call']};" - if precise_call_timing: - serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {var_name});" - start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();" - end_stmt = f"_cf_end{iter_id}_{call_counter} = 
System.nanoTime();" - else: - serialize_stmt = ( - f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});" - ) - start_stmt = f"_cf_start{iter_id} = System.nanoTime();" - end_stmt = f"_cf_end{iter_id} = System.nanoTime();" + call["_source_line"] = body_lines[line_idx] + statements = _build_call_statements(call, iter_id, call_counter, precise_call_timing, target_return_type) + is_void = statements["is_void"] == "True" if call["parent_type"] == "expression_statement": es_start = call["_es_start_char"] es_end = call["_es_end_char"] - if precise_call_timing: - # No indent on first line — body_text[:es_start] already has leading whitespace. - # Subsequent lines get line_indent_str. - if is_void: - var_decls = [ - f"long _cf_end{iter_id}_{call_counter} = -1;", - f"long _cf_start{iter_id}_{call_counter} = 0;", - f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;", - ] - else: - var_decls = [ - f"Object {var_name} = null;", - f"long _cf_end{iter_id}_{call_counter} = -1;", - f"long _cf_start{iter_id}_{call_counter} = 0;", - f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;", - ] - start_marker = f'System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." 
+ _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");' - if is_void: - try_block = [ - "try {", - f" {start_stmt}", - f" {bare_call_stmt}", - f" {end_stmt}", - f" {serialize_stmt}", - ] - else: - try_block = [ - "try {", - f" {start_stmt}", - f" {capture_stmt_assign}", - f" {end_stmt}", - f" {serialize_stmt}", - ] - finally_block = _generate_sqlite_write_code( - iter_id, - call_counter, - "", - class_name, - func_name, - test_method_name, - invocation_id=inv_id, - verification_type="void_state" if is_void else "function_call", - ) - all_lines = [*var_decls, start_marker, *try_block, *finally_block] - replacement = ( - all_lines[0] + "\n" + "\n".join(f"{line_indent_str}{repl_line}" for repl_line in all_lines[1:]) - ) - elif is_void: - replacement = f"{bare_call_stmt} {serialize_stmt}" - else: - replacement = f"{capture_stmt_with_decl} {serialize_stmt}" + # No indent on first line — body_text[:es_start] already has leading whitespace. + # Subsequent lines get the original line indentation. 
+ replacement = _build_expression_statement_replacement( + call=call, + iter_id=iter_id, + call_counter=call_counter, + inv_id=inv_id, + precise_call_timing=precise_call_timing, + statements=statements, + class_name=class_name, + func_name=func_name, + test_method_name=test_method_name, + ) body_text = body_text[:es_start] + replacement + body_text[es_end:] else: if is_void: @@ -451,35 +556,20 @@ def wrap_target_calls_with_treesitter( call_start = call["_call_start_char"] call_end = call["_call_end_char"] line_char_start = call["_line_char_start"] - - if precise_call_timing: - prefix_lines = [ - f"{line_indent_str}Object {var_name} = null;", - f"{line_indent_str}long _cf_end{iter_id}_{call_counter} = -1;", - f"{line_indent_str}long _cf_start{iter_id}_{call_counter} = 0;", - f"{line_indent_str}byte[] _cf_serializedResult{iter_id}_{call_counter} = null;", - f'{line_indent_str}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");', - f"{line_indent_str}try {{", - f"{line_indent_str} {start_stmt}", - f"{line_indent_str} {capture_stmt_assign}", - f"{line_indent_str} {end_stmt}", - f"{line_indent_str} {serialize_stmt}", - ] - finally_lines = _generate_sqlite_write_code( - iter_id, - call_counter, - line_indent_str, - class_name, - func_name, - test_method_name, - invocation_id=inv_id, - ) - prefix_lines.extend(finally_lines) - else: - prefix_lines = [f"{line_indent_str}{capture_stmt_with_decl}", f"{line_indent_str}{serialize_stmt}"] + prefix_lines = _build_embedded_call_prefix_lines( + call=call, + iter_id=iter_id, + call_counter=call_counter, + inv_id=inv_id, + precise_call_timing=precise_call_timing, + statements=statements, + class_name=class_name, + func_name=func_name, + test_method_name=test_method_name, + ) # Step 1: Replace the call with the variable (at higher offset, safe to do first) - body_text = body_text[:call_start] + var_with_cast + 
body_text[call_end:] + body_text = body_text[:call_start] + statements["var_with_cast"] + body_text[call_end:] # Step 2: Insert prefix lines before the line containing the call (at lower offset) prefix_text = "\n".join(prefix_lines) + "\n" body_text = body_text[:line_char_start] + prefix_text + body_text[line_char_start:] @@ -672,24 +762,29 @@ def instrument_existing_test( test_path: Path | None = None, test_class_name: str | None = None, ) -> tuple[bool, str | None]: - """Inject profiling code into an existing test file. + """Rewrite an existing Java test class for Codeflash behavior or performance runs. + + The rewritten source always renames the test class to match the instrumented filename, + which keeps Java's class-name/file-name contract intact. From there the function adds + the mode-specific instrumentation: - For Java, this: - 1. Renames the class to match the new file name (Java requires class name = file name) - 2. For behavior mode: adds timing instrumentation that writes to SQLite - 3. For performance mode: adds timing instrumentation with stdout markers + - ``behavior`` wraps target calls with per-invocation timing and SQLite persistence. + - ``performance`` adds stdout timing markers used by the Java benchmarking runner. Args: - test_string: String to the test file. - call_positions: List of code positions where the function is called. - function_to_optimize: The function being optimized. - tests_project_root: Root directory of tests. - mode: Testing mode - "behavior" or "performance". - analyzer: Optional JavaAnalyzer instance. - output_class_suffix: Optional suffix for the renamed class. + test_string: Original test source code. + function_to_optimize: The target function under evaluation. + mode: Instrumentation mode, either ``"behavior"`` or ``"performance"``. + test_path: Path to the existing test file, used to derive the original class name. + test_class_name: Optional class name override for generated or synthetic test inputs. 
Returns: - Tuple of (success, modified_source). + A ``(success, modified_source)`` tuple. ``success`` is ``False`` only when + instrumentation cannot be produced, and ``modified_source`` contains the rewritten + Java test when successful. + + Raises: + ValueError: If neither ``test_path`` nor ``test_class_name`` is provided. """ source = test_string @@ -1325,7 +1420,7 @@ def build_instrumented_body( def create_benchmark_test( target_function: FunctionToOptimize, test_setup_code: str, invocation_code: str, iterations: int = 1000 ) -> str: - """Create a benchmark test for a function. + """Create a standalone JUnit benchmark harness for a Java target function. Args: target_function: The function to benchmark. @@ -1399,21 +1494,22 @@ def instrument_generated_java_test( mode: str, # "behavior" or "performance" function_to_optimize: FunctionToOptimize, ) -> str: - """Instrument a generated Java test for behavior or performance testing. + """Instrument an AI-generated Java test for the requested evaluation mode. - For generated tests (AI-generated), this function: - 1. Removes assertions and captures function return values (for regression testing) - 2. Renames the class to include mode suffix - 3. Adds timing instrumentation for performance mode + Generated tests follow the same class-renaming and timing conventions as existing + tests, but they may start from assertion-stripped source prepared by the caller. + Performance mode instruments the generated class directly, while behavior mode + reuses ``instrument_existing_test`` so both paths emit the same persistence format. Args: - test_code: The generated test source code. + test_code: Generated Java test source. function_name: Name of the function being tested. - qualified_name: Fully qualified name of the function. - mode: "behavior" for behavior capture or "performance" for timing. + qualified_name: Fully qualified name of the function under test. 
+ mode: ``"behavior"`` for SQLite-backed capture or ``"performance"`` for timing markers. + function_to_optimize: Function metadata used when reusing the existing-test instrumentation path. Returns: - Instrumented test source code. + The instrumented Java test source. """ if not test_code or not test_code.strip():