import os
import subprocess
import shutil
import sys
from tempfile import NamedTemporaryFile

import gradio as gr
from huggingface_hub import InferenceClient
import wandb
# 1. Initialize W&B (free tier) for basic logging
# non-interactive login
# api_key = os.getenv("WANDB_API_KEY")
# if api_key:
#     wandb.login(key=api_key, relogin=True)
#     wandb.init(project="dipesh-gen-ai-2025-personal", entity="dipesh-gen-ai-2025")
# else:
#     # disable wandb entirely if key missing
#     wandb.init(mode="disabled")
key = os.getenv("WANDB_API_KEY")
if key:
    wandb.login(key=key, relogin=True)

# Always run anonymously (no entity permission needed)
wandb.init(
    project="misra-smart-fixer",
    mode="online",
    anonymous="must",
)
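# With anonymous="must", runs are logged to a temporary anonymous W&B
# account; the run URL printed at startup can be opened to view (and
# later claim) the logged prompts and patches.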
# 2. Hugging Face Inference Client (CPU-only, free quota)
HF_TOKEN = os.getenv("HF_API_TOKEN")
# client = InferenceClient(model="declare-lab/flan-alpaca-gpt4", token=HF_TOKEN)
client = InferenceClient(model="codellama/CodeLlama-7b-hf", token=HF_TOKEN)
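# Optional sanity check for the client; whether it succeeds depends on
# the token's quota and on the model currently being served (the prompt
# below is just an illustrative fragment):
# print(client.text_generation("int main(", max_new_tokens=8))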
def ensure_tool(name: str):
    """Exit with an error message if `name` is not available on PATH."""
    if shutil.which(name) is None:
        print(f"Error: `{name}` not found. Please install it and retry.", file=sys.stderr)
        sys.exit(1)
def run_cppcheck(source_code: str, filename: str):
    """
    Runs cppcheck on the provided source code using a temporary file.

    Args:
        source_code (str): The content of the source file.
        filename (str): The name of the source file, used to determine the language.

    Returns:
        list: The issues reported by cppcheck.
    """
    # Check for the code checker tool
    ensure_tool("cppcheck")
    print("cppcheck tool found.")

    issues = []
    ext = ".c" if filename.endswith(".c") else ".cpp"
    # A 'with' statement ensures the temporary file is closed and deleted.
    # cppcheck can still read it while it is open (POSIX semantics).
    with NamedTemporaryFile(suffix=ext, mode="w", delete=True, encoding="utf-8") as tf:
        tf.write(source_code)
        tf.flush()
        print(f"Temporary file created and written at: {tf.name}")
        print(f"Content size: {len(source_code)} bytes.")
        # Select language/standard/checks by extension. The open-source
        # misra addon implements MISRA C:2012; cppcheck has no --profile
        # flag, and MISRA C++ checking is not part of the open-source
        # build, so C++ files get only the generic checks here.
        if filename.endswith(".c"):
            lang_args = ["--std=c99", "--language=c", "--addon=misra"]
            print("misra-c-2012")
        else:
            lang_args = ["--std=c++17", "--language=c++"]
            print("misra-cpp-2012 (unavailable; generic C++ checks only)")
cmd = ["cppcheck", "--enable=all", *lang_args, tf.name] | |
print("Running command: " + " ".join(cmd)) | |
res = subprocess.run(cmd, capture_output=True, text=True, encoding='utf-8') | |
print("Command finished. Return code: " + str(res.returncode)) | |
# Print both stdout and stderr for debugging purposes | |
print("cppcheck stdout:") | |
print(res.stdout) | |
print("cppcheck stderr:") | |
print(res.stderr) | |
        # Diagnostics go to stderr in the template format defined above:
        # file:line:severity:id:message. Note that the misra addon reports
        # violations with "style" severity, so we keep every templated
        # line rather than filtering on "error"/"warning" only.
        for line in res.stderr.splitlines():
            parts = line.split(":", 4)
            if len(parts) == 5 and parts[1].isdigit():
                _file, line_no, severity, rule_id, message = parts
                issues.append({
                    "message": f"[{severity}] {rule_id}: {message.strip()}",
                    "line": line_no,
                })

    print("Issues found (after parsing): " + str(issues))
    return issues
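# Illustrative shape of the list returned above (the values are made up
# for the example; real rule ids come from the misra addon):
#   [{"message": "[style] misra-c2012-8.4: ...", "line": "3"}]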
def build_prompt(source_code: str, issues: list, filename: str):
    """
    Builds the prompt for the language model. The rule set is chosen
    from the filename, since issue messages do not carry the language.
    """
    if not issues:
        print("No issues to build prompt.")
        return None

    summary = "\n".join(
        f"- {item['message']} at line {item['line']}" for item in issues
    )
    print("Summary of issues: \n" + summary)

    rule_set = "MISRA C:2012" if filename.endswith(".c") else "MISRA C++:2012"
    expert = "C expert" if filename.endswith(".c") else "C++ expert"
prompt = f""" | |
You are a { 'C expert' if 'C:2012' in rule_set else 'C++ expert' } specializing in {rule_set} compliance. | |
Here is the source file: | |
``` | |
{source_code} | |
``` | |
The static analyzer reported the following violations: | |
{summary} | |
Produce a unified diff patch that fixes all violations. For each change, include a one‐sentence rationale referencing the violated rule number. | |
Only return the diff. No extra commentary. | |
""" | |
print("Generated prompt: \n" + prompt.strip()) | |
return prompt.strip() | |
def predict_patch(prompt: str):
    """
    Calls the Hugging Face Inference Client to generate a patch.
    """
    print("Calling Hugging Face Inference Client...")
    # text_generation() returns the generated string directly when
    # details=False (the default); there is no .generated_text attribute.
    patch = client.text_generation(prompt, max_new_tokens=256)
    print("Patch generated.")
    wandb.log({"prompt": prompt, "patch": patch})
    return patch
def process_file(file_obj):
    """
    Processes the uploaded file: reads its content, passes it to the
    static analyzer, then asks the model for a patch.
    """
    print("Processing file...")
    # Depending on the Gradio version, gr.File yields either a path
    # string or a tempfile-like object with a .name attribute.
    path = file_obj if isinstance(file_obj, str) else file_obj.name
    with open(path, "r", encoding="utf-8") as f:
        src = f.read()
    filename = path
    print("Source file - " + filename)

    # Run cppcheck and get the issues
    issues = run_cppcheck(src, filename)
    print("Source file issues - ")
    print(issues)

    prompt = build_prompt(src, issues, filename)
    if prompt is None:
        print("No MISRA violations found.")
        return "No MISRA violations found.", None

    patch = predict_patch(prompt)
    print("Patch to be returned.")
    return "Patch generated below:", patch
# Gradio UI
iface = gr.Interface(
    fn=process_file,
    inputs=gr.File(file_types=[".c", ".cpp", ".h", ".hpp"]),
    outputs=[gr.Text(label="Status"), gr.Text(label="Patch")],
    title="MISRA Smart Fixer",
    description="Upload C/C++ code to auto-fix MISRA violations.",
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)
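    # 7860 is the port Hugging Face Spaces expects; when run locally the
    # app is reachable at http://localhost:7860.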