Changes from all commits (42 commits)
07a3268
NCU profiling wrapper generation and execution
Jan 7, 2026
3c4b124
Refactor profiling components and add kernel_perf_util
Jan 7, 2026
11f4e79
Refactor profiling components and add kernel_perf_util
Jan 7, 2026
251f419
Refactor profiling components and add kernel_perf_util
Jan 7, 2026
b789660
update directory name and add package in pyproject
Jan 7, 2026
4d35d57
Remove kernel_perf_util directory
Jan 7, 2026
d871678
move gpu_spec.py to future PR and fix import
Jan 7, 2026
db0c754
Add copyright header
Jan 7, 2026
cd29759
fix ruff
Jan 7, 2026
bbfa6cd
address previous comments
Jan 13, 2026
543453a
fix ruff
Jan 13, 2026
706c9cc
Add unified benchmarking module for kernel performance measurement
Jan 8, 2026
4febdd6
Introducing benchmarking infra for kernel performance
Jan 8, 2026
d92a7b7
fix ruff
Jan 9, 2026
2994315
fix ruff
Jan 9, 2026
1378fc3
address comments
Jan 14, 2026
45fec80
Diagnose module - prompt constructor
Jan 11, 2026
b640cde
Refactors the diagnose_prompt module into a modular architecture
Jan 13, 2026
e952123
fix diff issue
Jan 13, 2026
e7ba29a
fix ruff issue
Jan 13, 2026
72ac4d1
fix
Jan 15, 2026
e2c599e
fix ruff
Jan 15, 2026
8ab907c
Merge branch 'main' into kaiming/opt_component_3
kaiming-cheng Jan 27, 2026
e350802
fix gpu_spec based on feedback and remove judger_prompt for future PR
Jan 29, 2026
8541299
Remove judger_prompts.py changes from this PR
Jan 29, 2026
313a84f
Merge branch 'main' into kaiming/opt_component_3
kaiming-cheng Jan 29, 2026
9e608ac
Update gpu_specs_database.py
kaiming-cheng Jan 29, 2026
f3220e1
address feedback
Jan 29, 2026
4443f33
ruff fix
Jan 29, 2026
b12b138
Merge branch 'main' into kaiming/opt_component_3
kaiming-cheng Jan 29, 2026
31d0d70
introduce roofline analyzer
Jan 29, 2026
3c607b5
update doc string in init and fix ncu_roofline
Jan 29, 2026
1aad0ad
introduce judger prompt
Jan 31, 2026
c0bd09c
add optimization template
Jan 31, 2026
56fba36
update prompt manager
Jan 31, 2026
c4a3641
Introduce bottleneck_analyzer
Feb 2, 2026
c2223d8
updated orchestrator
Feb 2, 2026
b3ae632
Introduce end-to-end worker integration
Feb 2, 2026
6c2ccc1
fix ruff
Feb 2, 2026
7a408b5
Merge branch 'main' into kaiming/worker_clean
kaiming-cheng Feb 17, 2026
3d8992c
fix
Feb 17, 2026
95d9bee
fix
Feb 17, 2026
135 changes: 121 additions & 14 deletions triton_kernel_agent/worker.py
@@ -26,10 +26,11 @@
from pathlib import Path
from typing import Any

+from triton_kernel_agent.platform_config import get_platform
+from utils.providers import get_model_provider
+
from .prompt_manager import PromptManager
from .worker_util import _run_test_multiprocess
-from utils.providers import get_model_provider
-from triton_kernel_agent.platform_config import get_platform


DISALLOWED_TORCH_PATTERNS = [
@@ -472,28 +473,23 @@ def run(
            # Subsequent rounds: only update kernel, test remains unchanged
            self._write_kernel(current_kernel)

-            violation = self._detect_pytorch_compute(current_kernel)
+            # Run verification
+            success, stdout, stderr, violation = self._single_verification_pass(
+                current_kernel
+            )
+
            if violation:
-                message = f"Disallowed PyTorch usage detected: {violation}"
-                self.logger.error(message)
-                self._log_round(round_num + 1, False, current_kernel, "", message)
+                self._log_round(round_num + 1, False, current_kernel, "", violation)
                error_info = {
                    "stdout": "",
-                    "stderr": message,
+                    "stderr": violation,
                    "history": list(self.history),
                }
                current_kernel = self._refine_kernel(
                    current_kernel, error_info, problem_description, test_code
                )
                continue

-            # Run test
-            success, stdout, stderr = (
-                self._run_test()
-                if os.getenv("KA_PROCESS_USE_SYS_EXECUTABLE", "1") == "1"
-                else _run_test_multiprocess(self.logger, self.workdir, self.test_file)
-            )
-
            # Log round
            self._log_round(round_num + 1, success, current_kernel, stdout, stderr)

@@ -529,3 +525,114 @@ def run(
"rounds": self.max_rounds,
"history": list(self.history),
}

def _single_verification_pass(
self, kernel_code: str
) -> tuple[bool, str, str, str | None]:
"""
Run a single verification pass on the kernel.

Returns:
Tuple of (success, stdout, stderr, violation_message)
- violation_message is set if PyTorch usage detected, None otherwise
"""
violation = self._detect_pytorch_compute(kernel_code)
if violation:
message = f"Disallowed PyTorch usage detected: {violation}"
self.logger.error(message)
return False, "", message, message

success, stdout, stderr = (
self._run_test()
if os.getenv("KA_PROCESS_USE_SYS_EXECUTABLE", "1") == "1"
else _run_test_multiprocess(self.logger, self.workdir, self.test_file)
)

return success, stdout, stderr, None
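
The four-tuple return lets callers distinguish policy violations from ordinary test failures. A minimal sketch of a call site, assuming a `worker` instance, a `kernel_src` string, and two hypothetical handlers (none of these names appear in this diff):

```python
# Illustrative only: how a caller might unpack the four-tuple.
success, stdout, stderr, violation = worker._single_verification_pass(kernel_src)
if violation is not None:
    handle_policy_violation(violation)   # disallowed PyTorch compute detected
elif not success:
    handle_test_failure(stdout, stderr)  # ordinary failure: feed back into refinement
```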

[Review comment from Contributor: When is this used?]

+    def verify_with_refinement(
+        self,
+        kernel_code: str,
+        test_code: str,
+        problem_description: str,
+        max_refine_attempts: int = 3,
+    ) -> tuple[bool, str, str]:
+        """
+        Verify kernel correctness with refinement attempts.
+
+        This is a simpler API for single-pass verification with refinement,
+        useful for optimization loops that manage their own iteration.
+
+        Args:
+            kernel_code: Kernel code to verify
+            test_code: Test code for verification
+            problem_description: Problem description for refinement context
+            max_refine_attempts: Maximum refinement attempts if verification fails
+
+        Returns:
+            Tuple of (success, final_kernel_code, error_feedback)
+            - success: Whether the kernel passed verification
+            - final_kernel_code: The verified (possibly refined) kernel
+            - error_feedback: Error message if failed, empty string if success
+        """
+        current_kernel = kernel_code
+
+        # Write files for testing
+        self._write_files(current_kernel, test_code)
+
+        # Initial verification
+        success, stdout, stderr, violation = self._single_verification_pass(
+            current_kernel
+        )
+
+        if violation:
+            return False, current_kernel, violation
+
+        if success:
+            self.logger.info("✅ Verification passed on first attempt")
+            return True, current_kernel, ""
+
+        # Refinement loop
+        for attempt in range(1, max_refine_attempts + 1):
+            error_output = stderr if stderr.strip() else stdout
+            self.logger.info(f"Refinement attempt {attempt}/{max_refine_attempts}...")
+
+            error_info = {
+                "stdout": stdout,
+                "stderr": stderr,
+                "error_type": (
+                    "compilation"
+                    if "CompilationError" in error_output
+                    or "SyntaxError" in error_output
+                    else "runtime"
+                ),
+            }
+
+            # Refine kernel
+            refined_kernel = self._refine_kernel(
+                current_kernel, error_info, problem_description, test_code
+            )
+
+            # Write and test refined kernel
+            self._write_kernel(refined_kernel)
+            success, stdout, stderr, violation = self._single_verification_pass(
+                refined_kernel
+            )
+
+            if violation:
+                current_kernel = refined_kernel
+                continue
+
+            if success:
+                self.logger.info(
+                    f"✅ Verification passed after refinement (attempt {attempt})"
+                )
+                return True, refined_kernel, ""
+
+            current_kernel = refined_kernel
+
+        # All attempts exhausted
+        error_output = stderr if stderr.strip() else stdout
+        error_feedback = f"Verification failed after {max_refine_attempts} refinement attempts:\n{error_output[:2000]}"
+        self.logger.warning(f"❌ {error_feedback[:200]}...")
+        return False, current_kernel, error_feedback
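
On the reviewer's question above: one natural fit for `verify_with_refinement` is an optimization loop that proposes candidate kernels and delegates correctness checking. The sketch below is hypothetical; `optimize_once`, `baseline_kernel`, `max_opt_steps`, and the `worker` instance are all assumptions, not code from this PR:

```python
# Hypothetical optimization driver (illustrative only).
kernel = baseline_kernel
for step in range(max_opt_steps):
    candidate = optimize_once(kernel)  # assumed: one LLM-suggested rewrite
    ok, verified, feedback = worker.verify_with_refinement(
        candidate, test_code, problem_description, max_refine_attempts=3
    )
    if ok:
        kernel = verified  # keep the verified (possibly refined) candidate
    # On failure, `feedback` can seed the next optimization suggestion.
```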
79 changes: 77 additions & 2 deletions triton_kernel_agent/worker_util.py
@@ -12,13 +12,88 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Utility functions for the verification worker."""
+"""Utility functions for the verification and optimization workers."""

import multiprocessing as mp
import os
+import re
+from logging import Logger
from pathlib import Path

-from logging import Logger
+# ------------------------
+# LLM Utilities
+# ------------------------


+def _extract_history_usage_from_response(
+    response_text: str,
+    logger: Logger | None = None,
+) -> dict[str, str | int | None] | None:
+    """
+    Extract history usage metadata from an LLM response.
+
+    Looks for structured comments indicating how the LLM used history.
+
+    Returns:
+        Dict with history_usage, based_on_attempt, evolution_rationale, or None.
+    """
+    if not response_text:
+        return None
+
+    result: dict[str, str | int | None] = {}
+
+    # Look for history usage patterns
+    usage_match = re.search(r"History usage:\s*(\w+)", response_text, re.IGNORECASE)
+    if usage_match:
+        result["history_usage"] = usage_match.group(1)
+
+    # Look for "based on attempt N"
+    attempt_match = re.search(r"based on attempt\s*(\d+)", response_text, re.IGNORECASE)
+    if attempt_match:
+        result["based_on_attempt"] = int(attempt_match.group(1))
+
+    # Look for evolution rationale
+    rationale_match = re.search(
+        r"Evolution rationale:\s*(.+?)(?:\n|$)", response_text, re.IGNORECASE
+    )
+    if rationale_match:
+        result["evolution_rationale"] = rationale_match.group(1).strip()
+
+    return result if result else None
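
To illustrate the response format these regexes look for, here is a small example with assumed input text (not taken from the PR):

```python
# Illustrative only: a response in the format the patterns above match.
sample = (
    "History usage: incremental\n"
    "This kernel is based on attempt 2.\n"
    "Evolution rationale: vectorize the inner loop\n"
)
# _extract_history_usage_from_response(sample) would return:
# {"history_usage": "incremental",
#  "based_on_attempt": 2,
#  "evolution_rationale": "vectorize the inner loop"}
```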


+# ------------------------
+# File I/O Utilities
+# ------------------------
+
+
+def _write_kernel_file(
+    kernel_file: Path, kernel_code: str, logger: Logger | None = None
+) -> None:
+    """Write kernel code to file."""
+    kernel_file.write_text(kernel_code)
+    if logger:
+        logger.debug(f"Wrote kernel to {kernel_file}")
+
+
+def _save_debug_file(
+    filepath: Path,
+    content: str,
+    logger: Logger | None = None,
+) -> None:
+    """Save content to a file for debugging purposes."""
+    try:
+        filepath.write_text(content)
+        if logger:
+            logger.debug(f"Saved debug file: {filepath}")
+    except Exception as e:
+        if logger:
+            logger.warning(f"Failed to save debug file {filepath}: {e}")


+# ------------------------
+# Test Execution Utilities
+# ------------------------


def _run_test_process(test_file: Path, workdir: Path, result_queue: mp.Queue) -> None: