Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions codeflash/languages/java/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@
GRACEFUL_SHUTDOWN_WAIT = 5 # seconds to wait after SIGTERM before SIGKILL


def _run_java_with_graceful_timeout(
java_command: list[str], env: dict[str, str], timeout: int, stage_name: str
) -> None:
def _run_java_with_graceful_timeout(java_command: list[str], env: dict[str, str], timeout: int, stage_name: str) -> int:
"""Run a Java command with graceful timeout handling.

Sends SIGTERM first (allowing JFR dump and shutdown hooks to run),
then SIGKILL if the process doesn't exit within GRACEFUL_SHUTDOWN_WAIT seconds.

Returns the process exit code, or -1 if the process was killed due to timeout.
"""
if not timeout:
subprocess.run(java_command, env=env, check=False)
return
result = subprocess.run(java_command, env=env, check=False)
if result.returncode != 0:
logger.warning("%s exited with code %d", stage_name, result.returncode)
return result.returncode

import signal

Expand All @@ -46,6 +48,11 @@ def _run_java_with_graceful_timeout(
logger.warning("%s stage did not exit after SIGTERM, sending SIGKILL", stage_name)
proc.kill()
proc.wait()
return -1

if proc.returncode != 0:
logger.warning("%s exited with code %d", stage_name, proc.returncode)
return proc.returncode


# --add-opens flags needed for Kryo serialization on Java 16+
Expand Down Expand Up @@ -85,12 +92,23 @@ def trace(
combined_env = self.build_combined_env(jfr_file, config_path)

logger.info("Running combined JFR profiling + argument capture...")
_run_java_with_graceful_timeout(java_command, combined_env, timeout, "Combined tracing")
exit_code = _run_java_with_graceful_timeout(java_command, combined_env, timeout, "Combined tracing")

if not trace_db_path.exists():
msg = (
f"Combined tracing failed with exit code {exit_code} — trace database was not created at "
f"{trace_db_path}. Cannot proceed without trace data."
)
raise RuntimeError(msg)

if exit_code != 0:
logger.warning(
"Combined tracing exited with code %d but trace database was created — proceeding with partial data",
exit_code,
)

if not jfr_file.exists():
logger.warning("JFR file was not created at %s", jfr_file)
if not trace_db_path.exists():
logger.error("Trace database was not created at %s", trace_db_path)

return trace_db_path, jfr_file

Expand Down
151 changes: 151 additions & 0 deletions tests/test_languages/test_java/test_tracer_exit_codes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch

if TYPE_CHECKING:
from pathlib import Path

import pytest

from codeflash.languages.java.tracer import JavaTracer, _run_java_with_graceful_timeout


class TestRunJavaWithGracefulTimeout:
def test_returns_zero_on_success(self) -> None:
mock_result = MagicMock()
mock_result.returncode = 0
with patch("codeflash.languages.java.tracer.subprocess.run", return_value=mock_result):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 0, "test")
assert rc == 0

def test_returns_nonzero_on_failure(self) -> None:
mock_result = MagicMock()
mock_result.returncode = 1
with patch("codeflash.languages.java.tracer.subprocess.run", return_value=mock_result):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 0, "test")
assert rc == 1

def test_returns_exit_code_137_oom_kill(self) -> None:
mock_result = MagicMock()
mock_result.returncode = 137
with patch("codeflash.languages.java.tracer.subprocess.run", return_value=mock_result):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 0, "test")
assert rc == 137

def test_timeout_path_returns_zero_on_success(self) -> None:
mock_proc = MagicMock()
mock_proc.returncode = 0
with patch("codeflash.languages.java.tracer.subprocess.Popen", return_value=mock_proc):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 60, "test")
assert rc == 0

def test_timeout_path_returns_nonzero_on_failure(self) -> None:
mock_proc = MagicMock()
mock_proc.returncode = 1
with patch("codeflash.languages.java.tracer.subprocess.Popen", return_value=mock_proc):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 60, "test")
assert rc == 1

def test_timeout_returns_negative_one(self) -> None:
import subprocess

mock_proc = MagicMock()
# First wait() times out, SIGTERM wait succeeds
mock_proc.wait.side_effect = [
subprocess.TimeoutExpired(cmd="java", timeout=60),
None, # SIGTERM wait succeeds
]
with patch("codeflash.languages.java.tracer.subprocess.Popen", return_value=mock_proc):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 60, "test")
assert rc == -1

def test_timeout_sends_sigterm_then_sigkill(self) -> None:
import signal
import subprocess

mock_proc = MagicMock()
# First wait() times out, SIGTERM wait also times out
mock_proc.wait.side_effect = [
subprocess.TimeoutExpired(cmd="java", timeout=60),
subprocess.TimeoutExpired(cmd="java", timeout=5),
None,
]
with patch("codeflash.languages.java.tracer.subprocess.Popen", return_value=mock_proc):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 60, "test")

assert rc == -1
mock_proc.send_signal.assert_called_once_with(signal.SIGTERM)
mock_proc.kill.assert_called_once()


class TestJavaTracerExitCodeHandling:
def test_success_with_trace_db_created(self, tmp_path: Path) -> None:
trace_db_path = (tmp_path / "trace.db").resolve()
tracer = JavaTracer()

def mock_run_timeout(java_command: list[str], env: dict, timeout: int, stage_name: str) -> int:
trace_db_path.write_bytes(b"fake-db")
return 0

with (
patch("codeflash.languages.java.tracer._run_java_with_graceful_timeout", side_effect=mock_run_timeout),
patch.object(tracer, "build_combined_env", return_value={}),
patch.object(tracer, "create_tracer_config", return_value=tmp_path / "config.json"),
):
trace_db, _jfr_file = tracer.trace(
java_command=["java", "-cp", ".", "Main"], trace_db_path=trace_db_path, packages=["com.example"]
)
assert trace_db == trace_db_path

def test_failure_without_trace_db_raises(self, tmp_path: Path) -> None:
trace_db_path = (tmp_path / "trace.db").resolve()
tracer = JavaTracer()

def mock_run_timeout(java_command: list[str], env: dict, timeout: int, stage_name: str) -> int:
return 1

with (
patch("codeflash.languages.java.tracer._run_java_with_graceful_timeout", side_effect=mock_run_timeout),
patch.object(tracer, "build_combined_env", return_value={}),
patch.object(tracer, "create_tracer_config", return_value=tmp_path / "config.json"),
pytest.raises(RuntimeError, match="Combined tracing failed with exit code 1"),
):
tracer.trace(
java_command=["java", "-cp", ".", "Main"], trace_db_path=trace_db_path, packages=["com.example"]
)

def test_nonzero_exit_with_trace_db_continues(self, tmp_path: Path) -> None:
trace_db_path = (tmp_path / "trace.db").resolve()
tracer = JavaTracer()

def mock_run_timeout(java_command: list[str], env: dict, timeout: int, stage_name: str) -> int:
trace_db_path.write_bytes(b"fake-db")
return 1

with (
patch("codeflash.languages.java.tracer._run_java_with_graceful_timeout", side_effect=mock_run_timeout),
patch.object(tracer, "build_combined_env", return_value={}),
patch.object(tracer, "create_tracer_config", return_value=tmp_path / "config.json"),
):
trace_db, _jfr_file = tracer.trace(
java_command=["java", "-cp", ".", "Main"], trace_db_path=trace_db_path, packages=["com.example"]
)
assert trace_db == trace_db_path

def test_timeout_without_trace_db_raises(self, tmp_path: Path) -> None:
trace_db_path = (tmp_path / "trace.db").resolve()
tracer = JavaTracer()

def mock_run_timeout(java_command: list[str], env: dict, timeout: int, stage_name: str) -> int:
return -1

with (
patch("codeflash.languages.java.tracer._run_java_with_graceful_timeout", side_effect=mock_run_timeout),
patch.object(tracer, "build_combined_env", return_value={}),
patch.object(tracer, "create_tracer_config", return_value=tmp_path / "config.json"),
pytest.raises(RuntimeError, match="Combined tracing failed with exit code -1"),
):
tracer.trace(
java_command=["java", "-cp", ".", "Main"], trace_db_path=trace_db_path, packages=["com.example"]
)
Loading