diff --git a/triton_kernel_agent/worker.py b/triton_kernel_agent/worker.py
index 39185b2..2380468 100644
--- a/triton_kernel_agent/worker.py
+++ b/triton_kernel_agent/worker.py
@@ -26,10 +26,11 @@
 from pathlib import Path
 from typing import Any
 
+from triton_kernel_agent.platform_config import get_platform
+from utils.providers import get_model_provider
+
 from .prompt_manager import PromptManager
 from .worker_util import _run_test_multiprocess
-from utils.providers import get_model_provider
-from triton_kernel_agent.platform_config import get_platform
 
 
 DISALLOWED_TORCH_PATTERNS = [
@@ -472,14 +473,16 @@ def run(
                 # Subsequent rounds: only update kernel, test remains unchanged
                 self._write_kernel(current_kernel)
 
-            violation = self._detect_pytorch_compute(current_kernel)
+            # Run verification
+            success, stdout, stderr, violation = self._single_verification_pass(
+                current_kernel
+            )
+
             if violation:
-                message = f"Disallowed PyTorch usage detected: {violation}"
-                self.logger.error(message)
-                self._log_round(round_num + 1, False, current_kernel, "", message)
+                self._log_round(round_num + 1, False, current_kernel, "", violation)
                 error_info = {
                     "stdout": "",
-                    "stderr": message,
+                    "stderr": violation,
                     "history": list(self.history),
                 }
                 current_kernel = self._refine_kernel(
@@ -487,13 +490,6 @@ def run(
                 )
                 continue
 
-            # Run test
-            success, stdout, stderr = (
-                self._run_test()
-                if os.getenv("KA_PROCESS_USE_SYS_EXECUTABLE", "1") == "1"
-                else _run_test_multiprocess(self.logger, self.workdir, self.test_file)
-            )
-
             # Log round
             self._log_round(round_num + 1, success, current_kernel, stdout, stderr)
 
@@ -529,3 +525,119 @@ def run(
             "rounds": self.max_rounds,
             "history": list(self.history),
         }
+
+    def _single_verification_pass(
+        self, kernel_code: str
+    ) -> tuple[bool, str, str, str | None]:
+        """
+        Run a single verification pass on the kernel.
+
+        Returns:
+            Tuple of (success, stdout, stderr, violation_message)
+            - violation_message is set if PyTorch usage detected, None otherwise
+        """
+        violation = self._detect_pytorch_compute(kernel_code)
+        if violation:
+            message = f"Disallowed PyTorch usage detected: {violation}"
+            self.logger.error(message)
+            return False, "", message, message
+
+        success, stdout, stderr = (
+            self._run_test()
+            if os.getenv("KA_PROCESS_USE_SYS_EXECUTABLE", "1") == "1"
+            else _run_test_multiprocess(self.logger, self.workdir, self.test_file)
+        )
+
+        return success, stdout, stderr, None
+
+    def verify_with_refinement(
+        self,
+        kernel_code: str,
+        test_code: str,
+        problem_description: str,
+        max_refine_attempts: int = 3,
+    ) -> tuple[bool, str, str]:
+        """
+        Verify kernel correctness with refinement attempts.
+
+        This is a simpler API for single-pass verification with refinement,
+        useful for optimization loops that manage their own iteration.
+
+        Args:
+            kernel_code: Kernel code to verify
+            test_code: Test code for verification
+            problem_description: Problem description for refinement context
+            max_refine_attempts: Maximum refinement attempts if verification fails
+
+        Returns:
+            Tuple of (success, final_kernel_code, error_feedback)
+            - success: Whether the kernel passed verification
+            - final_kernel_code: The verified (possibly refined) kernel
+            - error_feedback: Error message if failed, empty string if success
+
+        Note:
+            A disallowed-PyTorch violation in the initial kernel returns
+            immediately without refinement; a violation in a refined kernel
+            consumes an attempt and is fed back into the next refinement.
+        """
+        current_kernel = kernel_code
+
+        # Write files for testing
+        self._write_files(current_kernel, test_code)
+
+        # Initial verification
+        success, stdout, stderr, violation = self._single_verification_pass(
+            current_kernel
+        )
+
+        if violation:
+            return False, current_kernel, violation
+
+        if success:
+            self.logger.info("✅ Verification passed on first attempt")
+            return True, current_kernel, ""
+
+        # Refinement loop
+        for attempt in range(1, max_refine_attempts + 1):
+            error_output = stderr if stderr.strip() else stdout
+            self.logger.info(f"Refinement attempt {attempt}/{max_refine_attempts}...")
+
+            error_info = {
+                "stdout": stdout,
+                "stderr": stderr,
+                "error_type": (
+                    "compilation"
+                    if "CompilationError" in error_output
+                    or "SyntaxError" in error_output
+                    else "runtime"
+                ),
+            }
+
+            # Refine kernel
+            refined_kernel = self._refine_kernel(
+                current_kernel, error_info, problem_description, test_code
+            )
+
+            # Write and test refined kernel
+            self._write_kernel(refined_kernel)
+            success, stdout, stderr, violation = self._single_verification_pass(
+                refined_kernel
+            )
+
+            if violation:
+                current_kernel = refined_kernel
+                continue
+
+            if success:
+                self.logger.info(
+                    f"✅ Verification passed after refinement (attempt {attempt})"
+                )
+                return True, refined_kernel, ""
+
+            current_kernel = refined_kernel
+
+        # All attempts exhausted
+        error_output = stderr if stderr.strip() else stdout
+        error_feedback = f"Verification failed after {max_refine_attempts} refinement attempts:\n{error_output[:2000]}"
+        self.logger.warning(f"❌ {error_feedback[:200]}...")
+        return False, current_kernel, error_feedback
diff --git a/triton_kernel_agent/worker_util.py b/triton_kernel_agent/worker_util.py
index 1a8bc22..ee5ef95 100644
--- a/triton_kernel_agent/worker_util.py
+++ b/triton_kernel_agent/worker_util.py
@@ -12,13 +12,88 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Utility functions for the verification worker."""
+"""Utility functions for the verification and optimization workers."""
 
 import multiprocessing as mp
 import os
+import re
+from logging import Logger
 from pathlib import Path
-from logging import Logger
 
+# ------------------------
+# LLM Utilities
+# ------------------------
+
+
+def _extract_history_usage_from_response(
+    response_text: str,
+    logger: Logger | None = None,
+) -> dict[str, str | int | None] | None:
+    """
+    Extract history usage metadata from LLM response.
+
+    Looks for structured comments indicating how the LLM used history.
+
+    Returns:
+        Dict with history_usage, based_on_attempt, evolution_rationale, or None.
+    """
+    if not response_text:
+        return None
+
+    result: dict[str, str | int | None] = {}
+
+    # Look for history usage patterns
+    usage_match = re.search(r"History usage:\s*(\w+)", response_text, re.IGNORECASE)
+    if usage_match:
+        result["history_usage"] = usage_match.group(1)
+
+    # Look for "based on attempt N"
+    attempt_match = re.search(r"based on attempt\s*(\d+)", response_text, re.IGNORECASE)
+    if attempt_match:
+        result["based_on_attempt"] = int(attempt_match.group(1))
+
+    # Look for evolution rationale
+    rationale_match = re.search(
+        r"Evolution rationale:\s*(.+?)(?:\n|$)", response_text, re.IGNORECASE
+    )
+    if rationale_match:
+        result["evolution_rationale"] = rationale_match.group(1).strip()
+
+    return result if result else None
+
+
+# ------------------------
+# File I/O Utilities
+# ------------------------
+
+
+def _write_kernel_file(
+    kernel_file: Path, kernel_code: str, logger: Logger | None = None
+) -> None:
+    """Write kernel code to file as UTF-8 (Python's default source encoding)."""
+    kernel_file.write_text(kernel_code, encoding="utf-8")
+    if logger:
+        logger.debug(f"Wrote kernel to {kernel_file}")
+
+
+def _save_debug_file(
+    filepath: Path,
+    content: str,
+    logger: Logger | None = None,
+) -> None:
+    """Best-effort save of content to a file (UTF-8) for debugging purposes."""
+    try:
+        filepath.write_text(content, encoding="utf-8")
+        if logger:
+            logger.debug(f"Saved debug file: {filepath}")
+    except Exception as e:
+        if logger:
+            logger.warning(f"Failed to save debug file {filepath}: {e}")
+
+
+# ------------------------
+# Test Execution Utilities
+# ------------------------
 
 
 def _run_test_process(test_file: Path, workdir: Path, result_queue: mp.Queue) -> None: