Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions benchmarks/benchmarker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .humaneval import HumanEvalBenchmarker
from .livecodebench import LCBBenchmarker
from .math500 import Math500Benchmarker
from .mbpp import MBPPBenchmarker
from .mmlu import MMLUBenchmarker
from .mmstar import MMStarBenchmarker
from .mtbench import MTBenchBenchmarker
Expand All @@ -25,5 +26,6 @@
"FinanceQABenchmarker",
"MMLUBenchmarker",
"LCBBenchmarker",
"MBPPBenchmarker",
"SimpleQABenchmarker",
]
134 changes: 134 additions & 0 deletions benchmarks/benchmarker/mbpp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
MBPP benchmark evaluation script.
"""

import re
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function


def extract_code_from_output(output: str) -> Optional[str]:
"""Extract Python code from model output (markdown block or `def ...:`)."""
code_block_pattern = r"```(?:python)?\s*(.*?)\s*```"
match = re.search(code_block_pattern, output, re.DOTALL)
if match:
return match.group(1).strip()
def_pattern = r"(def\s+\w+\([^)]*\):.*?)(?=\n\ndef\s+|\Z)"
match = re.search(def_pattern, output, re.DOTALL)
if match:
return match.group(1).strip()
return output.strip() if output.strip() else None


def check_code_passes_tests(code: str, test_code: str) -> bool:
"""Run `code` then `test_code` (which contains assertions) in a fresh namespace.

Returns True iff no exception is raised. Simplified vs. the official MBPP
evaluation framework — we just want a pass/fail signal.
"""
try:
namespace: Dict[str, Any] = {}
exec(code, namespace)
exec(test_code, namespace)
Comment on lines +36 to +37

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

Using exec to run model-generated code is a significant security risk and can cause the benchmark to hang indefinitely if the model produces an infinite loop. While this pattern exists in other benchmarks in this repository, it is highly recommended to execute the code in a separate process with a strict timeout (e.g., using the multiprocessing module) to ensure the benchmarker remains robust and responsive.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same pattern is used in humaneval.py (and the comment in check_code_passes_tests already calls out that this is a "simplified" evaluation, deferring to the official frameworks for rigorous pass@k). I'd prefer to keep mbpp consistent with humaneval

return True
except Exception:
return False
Comment on lines +15 to +40

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The functions extract_code_from_output and check_code_passes_tests are identical to those in humaneval.py. To improve maintainability and adhere to DRY (Don't Repeat Yourself) principles, these utility functions should be moved to a shared location like benchmarks/benchmarker/utils.py and imported in both benchmarkers.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Considered this initially had it as a shared code_eval.py module imported by both humaneval and mbpp.py Pulled back to inline the helpers in mbpp.py to keep this PR scoped to MBPP only and avoid touching humaneval.py at all



def build_mbpp_prompt(text: str, test_list: List[str]) -> str:
"""Standard MBPP prompt format used in the original paper."""
tests = "\n".join(test_list)
return (
"You are an expert Python programmer, and here is your task: "
f"{text} Your code should pass these tests:\n\n{tests}\n\n[BEGIN]\n"
)


@BENCHMARKS.register("mbpp")
class MBPPBenchmarker(Benchmarker):
"""MBPP benchmark implementation (sanitized split)."""

def __init__(self, num_samples: Optional[int] = None):
super().__init__(num_samples, None)

def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[Dict[str, Any]]]]:
# Sanitized split is the standard one quoted in DFlash and most
# other speculative-decoding benchmarks.
dataset = load_dataset("google-research-datasets/mbpp", "sanitized")["test"]
questions: List[Dict[str, Any]] = []
labels: List[Optional[Dict[str, Any]]] = []

for idx, q in enumerate(dataset):
if self.num_samples is not None and idx >= self.num_samples:
break

# Sanitized split uses `prompt`; full split uses `text`.
text = q.get("prompt") or q.get("text") or ""
test_list = q.get("test_list", []) or []
# Sanitized split exposes `test_imports` (List[str]); full split
# exposes `test_setup_code` (single str). Combine both into one
# block so accuracy checks can run imports the tests rely on.
test_imports = q.get("test_imports", []) or []
test_setup_code = q.get("test_setup_code", "") or ""
test_setup = "\n".join([*test_imports, test_setup_code]).strip()

prompt = build_mbpp_prompt(text, test_list)
questions.append({"question": prompt})
labels.append(
{
"test_list": test_list,
"test_setup": test_setup,
"canonical_solution": q.get("code", ""),
}
)

return questions, labels

def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
# MBPP responses sometimes wrap in [DONE]; strip that and any leading [BEGIN].
if output is None:
return None
cleaned = output.strip()
cleaned = cleaned.split("[DONE]")[0].strip()
if cleaned.startswith("[BEGIN]"):
cleaned = cleaned[len("[BEGIN]") :].strip()
return extract_code_from_output(cleaned)

def compute_accuracy(
self, predictions: List[Any], labels: List[Any]
) -> Optional[float]:
if not labels:
return None
if all(label is None for label in labels):
return None

correct = 0
valid = 0
for pred, label in zip(predictions, labels):
if label is None or not isinstance(label, dict):
continue
valid += 1
if pred is None:
continue
test_setup = label.get("test_setup", "") or ""
test_list = label.get("test_list", []) or []
test_code = test_setup + "\n" + "\n".join(test_list)
if check_code_passes_tests(str(pred), test_code):
correct += 1
return correct / valid if valid > 0 else 0.0

def create_sgl_function(self):
return create_simple_sgl_function(
function_name="get_mbpp_answer",
answer_key="answer",
max_tokens=self.get_max_new_tokens(),
stop=["[DONE]"],
)

def get_max_new_tokens(self) -> int:
return 1024
Loading