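"""
MISRA Smart Fixer (Hugging Face Space app).

Pipeline: the user uploads a C/C++ file, cppcheck scans it for MISRA
violations, the findings are summarized into a prompt, and a hosted code
model is asked to return a unified diff that fixes them. Prompts and
patches are logged to Weights & Biases.
"""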
import os
import json
import subprocess
from tempfile import NamedTemporaryFile
import gradio as gr
from huggingface_hub import InferenceClient
import wandb
import shutil
import sys
# 1. Initialize W&B (free tier) for basic logging.
# Non-interactive login; previous entity-scoped setup kept for reference:
# api_key = os.getenv("WANDB_API_KEY")
# if api_key:
#     wandb.login(key=api_key, relogin=True)
#     wandb.init(project="dipesh-gen-ai-2025-personal", entity="dipesh-gen-ai-2025")
# else:
#     # disable wandb entirely if key missing
#     wandb.init(mode="disabled")
key = os.getenv("WANDB_API_KEY")
if key:
    wandb.login(key=key, relogin=True)

# Always run anonymously (no entity permission needed)
wandb.init(
    project="misra-smart-fixer",
    mode="online",
    anonymous="must",
)
# 2. Hugging Face Inference Client (CPU-only, free quota)
HF_TOKEN = os.getenv("HF_API_TOKEN")
#client = InferenceClient(model="declare-lab/flan-alpaca-gpt4", token=HF_TOKEN)
client = InferenceClient(model="codellama/CodeLlama-7b-hf", token=HF_TOKEN)
def ensure_tool(name: str):
    if shutil.which(name) is None:
        print(f"Error: `{name}` not found. Please install it and retry.", file=sys.stderr)
        sys.exit(1)
def run_cppcheck(source_code: str, filename: str):
    """
    Runs cppcheck on the provided source code using a temporary file.

    Args:
        source_code (str): The content of the source file.
        filename (str): The name of the source file, used to determine the language.

    Returns:
        list: A list of issue dicts, each with a "message" and a "line" key.
    """
    # Check for the code checker tool.
    ensure_tool("cppcheck")
    print("cppcheck tool found.")
    issues = []
    # Use a 'with' statement so the temporary file is cleaned up automatically.
    ext = ".c" if filename.endswith(".c") else ".cpp"
    with NamedTemporaryFile(suffix=ext, mode='w', delete=True, encoding='utf-8') as tf:
        tf.write(source_code)
        tf.flush()
        print(f"Temporary file created and written at: {tf.name}")
        print(f"Content size: {len(source_code)} bytes.")
        # Select language/standard and MISRA checking by extension. The
        # open-source MISRA addon covers MISRA C:2012; full MISRA C++ rule
        # checking is not available in plain cppcheck, so C++ files fall back
        # to the addon plus general C++17 checks.
        if filename.endswith(".c"):
            lang_args = ["--std=c99", "--language=c", "--addon=misra"]
            print("misra-c-2012")
        else:
            lang_args = ["--std=c++17", "--language=c++", "--addon=misra"]
            print("misra-cpp (addon, best effort)")
        cmd = ["cppcheck", "--enable=all", *lang_args, tf.name]
        print("Running command: " + " ".join(cmd))
        res = subprocess.run(cmd, capture_output=True, text=True, encoding='utf-8')
        print("Command finished. Return code: " + str(res.returncode))
        # Print both stdout and stderr for debugging purposes.
        print("cppcheck stdout:")
        print(res.stdout)
        print("cppcheck stderr:")
        print(res.stderr)
        # Process the output to find issues. cppcheck reports diagnostics on
        # stderr as plain text (not JSON), one finding per line. MISRA addon
        # findings are typically reported with "style" severity, so include them.
        for line in res.stderr.splitlines():
            if "error" in line or "warning" in line or "style" in line:
                # Pull the line number out of "file:line:col: ..." when present;
                # keep the whole text as the message.
                parts = line.split(":")
                line_no = parts[1] if len(parts) > 2 and parts[1].isdigit() else "N/A"
                issues.append({"message": line, "line": line_no})
    print("Issues found (after parsing): " + str(issues))
    return issues
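# Example of the stderr lines the loop above matches (format assumed from
# cppcheck's default output; exact wording varies by version):
#   /tmp/tmpab12.c:14:5: style: Avoid octal constants. [misra-c2012-7.1]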
def build_prompt(source_code: str, issues: list, filename: str):
    """
    Builds the prompt for the language model.
    """
    if not issues:
        print("No issues to build prompt.")
        return None
    summary = "\n".join([
        f"- {item['message']} at line {item['line']}"
        for item in issues
    ])
    print("Summary of issues: \n" + summary)
    # Pick the rule set from the uploaded file's extension, not from the issue text.
    is_c = filename.endswith(".c")
    rule_set = "MISRA C:2012" if is_c else "MISRA C++:2008"
    expert = "C expert" if is_c else "C++ expert"
    prompt = f"""
You are a {expert} specializing in {rule_set} compliance.

Here is the source file:
```
{source_code}
```

The static analyzer reported the following violations:
{summary}

Produce a unified diff patch that fixes all violations. For each change, include a one-sentence rationale referencing the violated rule number.
Only return the diff. No extra commentary.
"""
    print("Generated prompt: \n" + prompt.strip())
    return prompt.strip()
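# Illustrative build_prompt input (values invented for clarity):
#   issues = [{"message": "/tmp/tmpab.c:3:1: style: ... [misra-c2012-8.4]", "line": "3"}]
# would produce a summary line such as:
#   - /tmp/tmpab.c:3:1: style: ... [misra-c2012-8.4] at line 3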
def predict_patch(prompt: str):
    """
    Calls the Hugging Face Inference Client to generate a patch.
    """
    print("Calling Hugging Face Inference Client...")
    # text_generation returns the generated text as a plain string by default
    # (details=False), so there is no .generated_text attribute to unwrap.
    patch = client.text_generation(prompt, max_new_tokens=256)
    print("Patch generated.")
    wandb.log({"prompt": prompt, "patch": patch})
    return patch
def process_file(file_obj):
    """
    Processes the uploaded file: reads its content and passes it to the
    static analyzer, then asks the model for a fix.
    """
    print("Processing file...")
    # Gradio may hand the function either a file path (newer versions) or a
    # tempfile wrapper with a .name attribute; handle both.
    filepath = file_obj if isinstance(file_obj, str) else file_obj.name
    filename = os.path.basename(filepath)
    with open(filepath, "r", encoding="utf-8", errors="replace") as f:
        src = f.read()
    print("Source file - " + filename)
    # Run cppcheck and get the issues.
    issues = run_cppcheck(src, filename)
    print("Source file issues - ")
    print(issues)
    prompt = build_prompt(src, issues, filename)
    if prompt is None:
        print("No MISRA violations found.")
        return "No MISRA violations found.", None
    patch = predict_patch(prompt)
    print("Patch to be returned.")
    return "Patch generated below:", patch
# Gradio UI
iface = gr.Interface(
    fn=process_file,
    inputs=gr.File(file_types=[".c", ".cpp", ".h", ".hpp"]),
    outputs=[gr.Text(), gr.Text()],
    title="MISRA Smart Fixer",
    description="Upload C/C++ code to auto-fix MISRA violations.",
    allow_flagging="never",
)
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)
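# Local run sketch (assuming this file is saved as app.py, the usual Space convention):
#   - install cppcheck plus the Python dependencies (gradio, huggingface_hub, wandb)
#   - export HF_API_TOKEN (and optionally WANDB_API_KEY)
#   - run `python app.py` and open http://localhost:7860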