import torch
from transformers import RobertaTokenizer, RobertaModel
import numpy as np
from scipy.special import softmax
import gradio as gr
import re
# Define the model class with dimension reduction
class CodeClassifier(torch.nn.Module):
    def __init__(self, base_model, num_labels=6):
        super(CodeClassifier, self).__init__()
        self.base = base_model
        self.reduction = torch.nn.Linear(768, 512)
        self.classifier = torch.nn.Linear(512, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.base(input_ids=input_ids, attention_mask=attention_mask)
        reduced = self.reduction(outputs.pooler_output)
        return self.classifier(reduced)
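# Shape walkthrough for one forward pass (batch size B, sequence length T):
#   input_ids, attention_mask: (B, T)
#   outputs.pooler_output:     (B, 768)  -- RoBERTa-base hidden size
#   reduced:                   (B, 512)
#   logits:                    (B, 6)    -- one score per entry in label_map below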
# Load model and tokenizer from Hugging Face Model Hub
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
base_model = RobertaModel.from_pretrained('martynattakit/CodeSentinel-Model') # Match your Model repo
model = CodeClassifier(base_model)
print("Loaded state dict keys:", model.state_dict().keys())
print("Classifier weight shape:", model.classifier.weight.shape)
model.eval()
model.to(device)
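
# NOTE: RobertaModel.from_pretrained() above restores only the encoder weights;
# the reduction and classifier layers defined in CodeClassifier start out randomly
# initialized. If the model repo also ships a full fine-tuned checkpoint, the head
# could be restored with a sketch like the one below (the filename is an
# assumption, not confirmed by this Space):
#
# from huggingface_hub import hf_hub_download
# ckpt = hf_hub_download('martynattakit/CodeSentinel-Model', 'pytorch_model.bin')  # hypothetical filename
# model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)  # strict=False tolerates missing/extra keys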
# Label mapping with descriptions
label_map = {
    0: ('none', 'No Vulnerability Detected'),
    1: ('cwe-121', 'Stack-based Buffer Overflow'),
    2: ('cwe-78', 'OS Command Injection'),
    3: ('cwe-190', 'Integer Overflow or Wraparound'),
    4: ('cwe-191', 'Integer Underflow'),
    5: ('cwe-122', 'Heap-based Buffer Overflow')
}
def load_c_file(file):
    try:
        if file is None:
            return ""
        with open(file.name, 'r', encoding='utf-8') as f:
            content = f.read()
        return content
    except Exception as e:
        return f"Error reading file: {str(e)}"
def clean_code(code):
    code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)   # strip /* ... */ block comments
    code = re.sub(r'//.*$', '', code, flags=re.MULTILINE)    # strip // line comments
    code = ' '.join(code.split())                            # collapse all whitespace to single spaces
    return code
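# Example: clean_code('int x = 1; /* init */ x++; // bump\nreturn x;')
# -> 'int x = 1; x++; return x;'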
def evaluate_code(code):
    try:
        if len(code) >= 1500000:
            return "Code too large"
        cleaned_code = clean_code(code)
        inputs = tokenizer(cleaned_code, return_tensors="pt", truncation=True, padding=True, max_length=256).to(device)
        print("Input shape:", inputs['input_ids'].shape)
        with torch.no_grad():
            outputs = model(**inputs)
        print("Raw logits:", outputs.cpu().numpy())
        probs = softmax(outputs.cpu().numpy(), axis=1)
        pred = np.argmax(probs, axis=1)[0]
        cwe, description = label_map[pred]
        return f"{cwe} {description}"
    except Exception as e:
        return f"Error during prediction: {str(e)}"
with gr.Blocks() as web:
    with gr.Row():
        with gr.Column(scale=1):
            code_box = gr.Textbox(lines=20, label="** C/C++ Code", placeholder="Paste your C or C++ code here...")
        with gr.Column(scale=1):
            cc_file = gr.File(label="Upload C/C++ File (.c or .cpp)", file_types=[".c", ".cpp"])
            check_btn = gr.Button("Check")
    with gr.Row():
        gr.Markdown("### Result:")
    with gr.Row():
        with gr.Column(scale=1):
            label_box = gr.Textbox(label="Vulnerability", interactive=False)

    cc_file.change(fn=load_c_file, inputs=cc_file, outputs=code_box)
    check_btn.click(fn=evaluate_code, inputs=code_box, outputs=[label_box])

web.launch()