| import gradio as gr |
| import torch |
| import numpy as np |
|
|
| from sentence_transformers import SentenceTransformer, util |
|
|
| |
| |
# Model identifier handed to SentenceTransformer (a Hugging Face Hub repo id).
model_name = "juanwisz/modernbert-python-code-retrieval"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Built once at module import so every search request reuses the same model object.
embedding_model = SentenceTransformer(model_name, device=device)
|
|
| |
| |
| |
| |
| def retrieve_top_snippets(query, code_input): |
| |
| |
| snippets = [s.strip() for s in code_input.split("---") if s.strip()] |
|
|
| |
| if len(snippets) == 0: |
| return "No code snippets detected (make sure to separate them with ---)." |
|
|
| |
| query_emb = embedding_model.encode(query, convert_to_tensor=True) |
| snippets_emb = embedding_model.encode(snippets, convert_to_tensor=True) |
|
|
| |
| cos_scores = util.cos_sim(query_emb, snippets_emb)[0] |
|
|
| |
| |
| top_indices = torch.topk(cos_scores, k=min(3, len(snippets))).indices |
|
|
| |
| results = [] |
| for idx in top_indices: |
| score = cos_scores[idx].item() |
| snippet_text = snippets[idx] |
| results.append(f"**Score**: {score:.4f}\n```python\n{snippet_text}\n```") |
|
|
| |
| return "\n\n".join(results) |
|
|
|
|
| |
| |
| |
# Custom stylesheet for the app: centers the element with id "container"
# and limits its width to 700px.
css = (
    "\n"
    "#container {\n"
    "margin: 0 auto;\n"
    "max-width: 700px;\n"
    "}\n"
)
|
|
demo = gr.Blocks(css=css)
with demo:
    # Page title followed by short usage instructions.
    gr.Markdown(
        "# Code Retrieval using ModernBERT\n"
        "Enter a natural language query and paste multiple Python code snippets, "
        "delimited by `---`. We'll return the top 3 matches."
    )

    # Column with elem_id "container"; the module-level `css` string targets this id.
    with gr.Column(elem_id="container"):
        with gr.Row():
            query_input = gr.Textbox(
                placeholder="What does your function do? e.g., 'Parse JSON from a string'",
                label="Natural Language Query",
            )

        code_snippets_input = gr.Textbox(
            placeholder=(
                "Example:\n"
                "---\n"
                "def parse_json(data):\n"
                " return json.loads(data)\n"
                "---\n"
                "def add_numbers(a, b):\n"
                " return a + b\n"
                "---"
            ),
            lines=10,
            label="Paste Python functions (delimited by ---)",
        )

        search_btn = gr.Button("Search", variant="primary")
        results_output = gr.Markdown(label="Top 3 Matches")

    # Clicking Search ranks the pasted snippets against the query.
    search_btn.click(
        retrieve_top_snippets,
        inputs=[query_input, code_snippets_input],
        outputs=results_output,
    )
|
|
def _launch_app():
    """Start the Gradio server for the module-level ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    _launch_app()
|
|