File size: 4,291 Bytes
3b6fde5
 
 
 
 
 
 
ca20e12
 
3b6fde5
49ddfe8
484861b
5d910a8
 
 
 
 
484861b
5d910a8
 
 
 
 
 
 
 
 
 
 
 
484861b
3b6fde5
 
ab3a251
 
3b6fde5
8102149
 
 
 
 
3b6fde5
8102149
 
 
7fd755a
0b1e3ba
839cd34
0b1e3ba
3b6fde5
 
 
 
 
0b1e3ba
 
 
 
 
 
 
 
 
 
839cd34
0b1e3ba
 
839cd34
0b1e3ba
 
 
3b6fde5
 
 
 
 
 
 
 
 
 
 
 
7fd755a
3b6fde5
 
 
 
e983560
0b1e3ba
 
 
 
 
 
 
 
 
3b6fde5
0b1e3ba
 
 
 
 
 
 
 
 
 
 
 
 
3b6fde5
 
 
 
 
 
 
 
 
 
e699979
 
c3eb545
 
3b6fde5
c3eb545
 
 
3b6fde5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import json
import subprocess
from tempfile import NamedTemporaryFile
import gradio as gr
from huggingface_hub import InferenceClient
import wandb
import shutil
import sys

# 1. Initialize W&B (free tier) for basic logging
# non-interactive login
#api_key = os.getenv("WANDB_API_KEY")
#if api_key:
#    wandb.login(key=api_key, relogin=True)
#    wandb.init(project="dipesh-gen-ai-2025-personal", entity="dipesh-gen-ai-2025")
#else:
    # disable wandb entirely if key missing
#    wandb.init(mode="disabled")

# Log in to W&B only when an API key is provided through the environment.
key = os.environ.get("WANDB_API_KEY")
if key:
    wandb.login(key=key, relogin=True)

# The run itself is always started with anonymous="must" (no entity
# permission needed), whether or not the login above took place.
_WANDB_INIT_KWARGS = {
    "project": "misra-smart-fixer",
    "mode": "online",
    "anonymous": "must",
}
wandb.init(**_WANDB_INIT_KWARGS)

# 2. Hugging Face Inference Client (CPU-only, free quota)
# The access token comes from the HF_API_TOKEN environment variable
# (None when the variable is unset).
HF_TOKEN = os.environ.get("HF_API_TOKEN")
# Earlier model choice, kept for reference: declare-lab/flan-alpaca-gpt4
client = InferenceClient(token=HF_TOKEN, model="codellama/CodeLlama-7b-hf")

def ensure_tool(name: str):
    """Exit the process with status 1 unless *name* is an executable on PATH.

    Args:
        name: Command name (or path) looked up with ``shutil.which``.

    Returns:
        None when the tool is found. Otherwise an error line is written to
        stderr and ``SystemExit(1)`` is raised via ``sys.exit``.
    """
    resolved = shutil.which(name)
    if resolved is not None:
        return

    message = f"Error: `{name}` not found. Please install it and retry."
    print(message, file=sys.stderr)
    sys.exit(1)

def run_cppcheck(source_code: str):
    """Run Cppcheck on a C/C++ source and return the analysed copy plus findings.

    Args:
        source_code: Either the path of an existing source file (this is what
            ``process_file`` passes on from the upload widget) or the raw
            source text itself. A value ending in ``.c`` is treated as C,
            everything else as C++.

    Returns:
        A ``(path, issues)`` tuple. ``path`` names a temporary copy of the
        source; it is created with ``delete=False`` and is NOT removed here
        because callers read it afterwards. ``issues`` is a list of dicts with
        the keys ``"id"``, ``"severity"``, ``"message"`` and ``"line"``; it is
        empty when nothing was reported or the report could not be parsed.
    """
    # Local import keeps this block self-contained.
    import xml.etree.ElementTree as ET

    # Check for the code checker tool
    ensure_tool("cppcheck")

    ext = ".c" if source_code.endswith(".c") else ".cpp"

    # When the argument names a real file, analyse the file's contents rather
    # than the path string. os.path.isfile is False for ordinary source text.
    if os.path.isfile(source_code):
        with open(source_code, "rb") as src_file:
            payload = src_file.read()
    else:
        payload = source_code.encode()

    with NamedTemporaryFile(suffix=ext, delete=False) as tf:
        tf.write(payload)

    # Select language/standard by extension.
    if ext == ".c":
        # MISRA C:2012 checking is provided by Cppcheck's "misra" addon.
        lang_args = ["--std=c11", "--language=c", "--addon=misra"]
    else:
        # NOTE(review): as far as I know open-source Cppcheck ships a MISRA
        # addon for C only, so C++ input gets the regular checks - confirm
        # whether a MISRA C++ checker is available in the deployment.
        lang_args = ["--std=c++17", "--language=c++"]

    # Cppcheck has no JSON report mode; --xml is its machine-readable format
    # and is written to stderr.
    cmd = [
        "cppcheck", "--enable=all", *lang_args,
        "--xml", "--xml-version=2", tf.name,
    ]

    # Capture bytes: the XML parser honours the encoding declared in the
    # report, independent of the process locale.
    res = subprocess.run(cmd, capture_output=True)

    try:
        root = ET.fromstring(res.stderr)
    except ET.ParseError:
        # Best effort: report the problem instead of failing silently.
        print("Warning: could not parse cppcheck XML output.", file=sys.stderr)
        return tf.name, []

    issues = []
    for err in root.iter("error"):
        severity = err.get("severity", "")
        # "information" entries (e.g. missing system includes) are tool
        # notices rather than findings about the code.
        if severity == "information":
            continue
        location = err.find("location")
        line_text = location.get("line") if location is not None else None
        rule_id = err.get("id", "")
        issues.append({
            "id": rule_id,
            "severity": severity,
            # Keep the rule id in the message so downstream prompts see it.
            "message": f"{err.get('msg', '')} [{rule_id}]",
            "line": int(line_text) if line_text and line_text.isdigit() else 0,
        })
    return tf.name, issues

def build_prompt(filename: str, issues: list):
    """Compose the model prompt asking for a patch that fixes *issues*.

    Args:
        filename: Path of the source file to embed in the prompt. A name
            ending in ``.c`` selects the C wording, anything else the C++
            wording.
        issues: Analyzer findings; each entry is a mapping with a
            ``"message"`` and a ``"line"`` key.

    Returns:
        The prompt text, or ``None`` when *issues* is empty (the file is not
        opened in that case).

    Raises:
        OSError: If *issues* is non-empty and *filename* cannot be read.
    """
    # Nothing to fix: return before touching the file system.
    if not issues:
        return None

    # Uploaded sources may not be valid UTF-8; replace undecodable bytes
    # instead of failing the whole request.
    with open(filename, encoding="utf-8", errors="replace") as f:
        src = f.read()

    summary = "\n".join(
        f"- {item.get('message', 'unknown issue')} at line {item.get('line', '?')}"
        for item in issues
    )

    # Detect language from filename.
    is_c = filename.endswith(".c")
    # NOTE(review): MISRA C++ editions are, as far as I know, 2008 and 2023 -
    # confirm the intended rule-set name for the C++ case.
    rule_set = "MISRA C:2012" if is_c else "MISRA C++:2012"
    role = "C expert" if is_c else "C++ expert"

    # Assemble line by line so the prompt carries no indentation from this
    # source file (an indented triple-quoted literal would leak tabs).
    prompt_lines = [
        f"You are a {role} specializing in {rule_set} compliance.",
        "",
        f"Here is the source file (`{filename}`):",
        "```",
        src,
        "```",
        "",
        "The static analyzer reported the following violations:",
        summary,
        "",
        "Produce a unified diff patch that fixes all violations. For each change, include a one‐sentence rationale referencing the violated rule number.",
        "Only return the diff. No extra commentary.",
    ]
    return "\n".join(prompt_lines).strip()

def predict_patch(prompt: str):
    """Ask the hosted model for a patch and log the exchange to W&B.

    Args:
        prompt: Full prompt text sent to the module-level ``client``.

    Returns:
        The generated text returned by the model.
    """
    response = client.text_generation(prompt, max_new_tokens=256)
    # text_generation returns a plain string unless details=True is passed;
    # only the detailed result object carries a .generated_text attribute.
    patch = response if isinstance(response, str) else response.generated_text
    wandb.log({"prompt": prompt, "patch": patch})
    return patch

def process_file(file_obj):
    """Gradio handler: analyse an uploaded file and request a patch for it.

    Args:
        file_obj: Value delivered by the ``gr.File`` input. Presumably a path
            string, or a file-like object exposing the path as ``.name``
            depending on the Gradio version - TODO confirm. ``None`` when
            nothing was uploaded.

    Returns:
        A ``(status_message, patch)`` tuple; ``patch`` is ``None`` when no
        file was given or no violations were found.
    """
    # The handler can be triggered without an upload.
    if file_obj is None:
        return "Please upload a C/C++ source file.", None

    # Normalise to a path string: plain strings have no .name attribute and
    # are passed through unchanged.
    src = getattr(file_obj, "name", file_obj)
    print("Source file - ")
    print(src)
    fname, issues = run_cppcheck(src)
    print("Source file issues - ")
    print(fname)
    print(issues)
    try:
        prompt = build_prompt(fname, issues)
    finally:
        # run_cppcheck leaves its temporary copy on disk; it is no longer
        # needed once the prompt has been built.
        try:
            os.remove(fname)
        except OSError:
            pass
    if prompt is None:
        return "No MISRA violations found.", None
    patch = predict_patch(prompt)
    return "Patch generated below:", patch

# Gradio UI
# One file-upload input restricted to C/C++ sources and headers; two text
# outputs that receive the (status, patch) pair returned by process_file.
_upload_input = gr.File(file_types=[".c", ".cpp", ".h", ".hpp"])
_text_outputs = [gr.Text(), gr.Text()]

iface = gr.Interface(
    fn=process_file,
    inputs=_upload_input,
    outputs=_text_outputs,
    title="MISRA Smart Fixer",
    description="Upload C/C++ code to auto-fix MISRA violations.",
    allow_flagging="never",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    # Launch on address 0.0.0.0, port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)