---
base_model:
- Qwen/QwQ-32B
tags:
- code
---
| | |
# Model Summary
KernelCoder is trained on a curated dataset of reasoning traces and CUDA kernel pairs.

See details in the [paper](https://lkongam.github.io/ConCuR/).

# Usage
|
```python
| | from vllm import LLM, SamplingParams |
| | from transformers import AutoTokenizer |
| | import torch |
| | import re |
| | from typing import List, Tuple |
| | from string import Template |
# Prompt template used to wrap the user's source program before generation.
# NOTE(review): the template body is left blank in this card. It must contain
# a `$code` placeholder so that PROMPT_TEMPLATE.substitute(code=...) (used
# below) injects the original program; with an empty body, substitute()
# returns an empty prompt — confirm the intended template against the paper.
PROMPT_TEMPLATE = Template('''
''')
| | |
class KernelCoder:
    """Thin wrapper around a vLLM engine for generating CUDA-kernel code.

    Loads the model once in ``__init__`` and exposes ``generate_raw`` to
    produce a raw completion (reasoning trace + code) for a prompt.
    """

    def __init__(self, model_name="lkongam/KernelCoder", tensor_parallel_size=1, gpu_memory_utilization=0.9):
        """Create the vLLM engine and grab its tokenizer.

        Args:
            model_name: HF repo id or local path of the model to load.
            tensor_parallel_size: number of GPUs to shard the model over.
            gpu_memory_utilization: fraction of GPU memory vLLM may claim.
        """
        self.model_name = model_name
        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            # presumably required for the custom model/tokenizer code — confirm
            trust_remote_code=True,
            dtype="auto",
        )
        self.tokenizer = self.llm.get_tokenizer()
        self.device = torch.device("cuda")

    def generate_raw(self, prompt, temperature=1.0, max_tokens=16384):
        """Generate a raw completion (reasoning + code) for *prompt*.

        Fixes two defects in the original: it built the chat-formatted
        prompt but returned it WITHOUT ever calling the engine, and the
        *temperature* argument was silently ignored.

        Args:
            prompt: user prompt (already templated with the source code).
            temperature: sampling temperature passed to vLLM.
            max_tokens: generation budget; generous default because the
                model emits a long reasoning trace before the code.

        Returns:
            The model's generated text for the single prompt.
        """
        messages = [
            {"role": "user", "content": prompt}
        ]
        # Render the conversation with the model's chat template; the
        # generation prompt and thinking mode are enabled as in the card.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
        outputs = self.llm.generate([text], sampling_params)
        return outputs[0].outputs[0].text
| | |
def extract_last_code_block(text):
    """Return the last fenced code block in *text*, or a best-effort fallback.

    Generalization over the original: the fence regex only accepted
    ```` ```python ```` (or untagged) fences, but this model's output is
    CUDA code, which is commonly fenced as ```` ```cpp ```` or similar —
    those blocks were silently missed. Any language tag is now accepted.

    Fallback order when no fenced block exists:
      1. everything after a ``</think>`` marker (stripped),
      2. from the first ``import`` keyword onward,
      3. the remaining text as-is.

    Returns:
        The extracted code string, or ``None`` when nothing usable remains
        after the reasoning section.
    """
    # Accept any (or no) language tag after the opening fence.
    code_blocks = re.findall(r"```(?:[\w+-]*)\n(.*?)```", text, re.DOTALL)
    if code_blocks:
        return code_blocks[-1].strip()
    # No fenced block: drop the reasoning section if present.
    match = re.search(r"</think>(.*)", text, re.S)
    after_think = match.group(1).strip() if match else text
    if not after_think:
        return None
    # Heuristic: generated programs usually start at an import statement.
    import_match = re.search(r"\bimport\b", after_think)
    if import_match:
        return after_think[import_match.start():].strip()
    return after_think.strip()
| |
|
# The reference program to be converted into an optimized CUDA kernel.
# NOTE(review): intentionally left blank in this card — paste the source
# code here before running.
origin_code = """
"""

# Load the model, then run one generation round-trip:
# template -> chat prompt -> raw model output -> extracted code block.
model = KernelCoder(model_name="lkongam/KernelCoder")

prompt = PROMPT_TEMPLATE.substitute(code=origin_code)
code_output = model.generate_raw(prompt)
code = extract_last_code_block(code_output)
print(code)
```
| | |
# Evaluation
![image](result.png)

Left: Pass@1, Right: Pass@10.