import torch
from transformers import RobertaTokenizer, RobertaModel
from huggingface_hub import hf_hub_download # <--- NEW IMPORT
import numpy as np
from scipy.special import softmax
import gradio as gr
import re
import os # <--- NEW IMPORT (currently unused in this script)
# Define the model class with dimension reduction
class CodeClassifier(torch.nn.Module):
    def __init__(self, base_model, num_labels=6):
        super(CodeClassifier, self).__init__()
        self.base = base_model  # This will be the microsoft/codebert-base model
        self.reduction = torch.nn.Linear(768, 512)
        self.classifier = torch.nn.Linear(512, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.base(input_ids=input_ids, attention_mask=attention_mask)
        reduced = self.reduction(outputs.pooler_output)
        return self.classifier(reduced)
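# Shape sketch (illustrative only, not executed): for a batch of B tokenized
# sequences of length T, the data flows as
#   input_ids, attention_mask: (B, T)
#   outputs.pooler_output:     (B, 768)  CodeBERT's pooled [CLS] representation
#   reduced:                   (B, 512)  after the dimension-reduction layer
#   logits (returned):         (B, 6)    one raw score per vulnerability class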
# --- START OF MODIFIED LOADING LOGIC ---
# Hugging Face Model ID where your .pt file is located
HF_MODEL_REPO_ID = 'martynattakit/CodeSentinel-Model'
# The exact filename of your .pt file in that repository
HF_MODEL_FILENAME = 'best_model.pt' # <--- CONFIRM THIS IS THE EXACT FILENAME YOU UPLOADED
# Load the base tokenizer (from Hugging Face Hub as before)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
# Initialize the base CodeBERT model from Hugging Face (standard download)
base_codebert_model = RobertaModel.from_pretrained('microsoft/codebert-base')
# Instantiate your custom CodeClassifier model
model = CodeClassifier(base_codebert_model, num_labels=6)
# Download the .pt file from Hugging Face Hub
print(f"Attempting to download {HF_MODEL_FILENAME} from {HF_MODEL_REPO_ID}...")
try:
    # hf_hub_download fetches the file (or reuses the local cache) and returns its path
    downloaded_model_path = hf_hub_download(
        repo_id=HF_MODEL_REPO_ID,
        filename=HF_MODEL_FILENAME,
        # If your model repo is private, you may need to ensure that
        # `huggingface-cli login` has been run in your environment.
    )
    print(f"Model downloaded to: {downloaded_model_path}")

    # Load the state dictionary from the downloaded .pt file
    state_dict = torch.load(downloaded_model_path, map_location=device)

    # Handle the 'module.' prefix if the model was saved with DataParallel
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[7:]] = v  # remove the 'module.' prefix
        else:
            new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
    print("Successfully loaded model state into CodeClassifier.")
except Exception as e:
    print(f"Error during model download or loading: {e}")
    print("Please ensure:")
    print(f"1. The repository '{HF_MODEL_REPO_ID}' exists and is public (or you're logged in with `huggingface-cli login`).")
    print(f"2. The file '{HF_MODEL_FILENAME}' exists in that repository on Hugging Face and is spelled exactly right.")
    # Exiting here is good for deployment environments like Hugging Face Spaces,
    # as it makes the error clear early on.
    exit()
# --- END OF MODIFIED LOADING LOGIC ---
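# Note (assumption): on recent PyTorch releases, passing weights_only=True to
# torch.load above avoids unpickling arbitrary objects from the checkpoint;
# this presumes best_model.pt holds a plain state_dict rather than a full
# pickled nn.Module.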
print("Loaded state dict keys (after loading .pt):", model.state_dict().keys())
print("Classifier weight shape (after loading .pt):", model.classifier.weight.shape)
model.eval()
model.to(device)
# Label mapping with descriptions
label_map = {
    0: ('none', 'No Vulnerability Detected'),
    1: ('cwe-121', 'Stack-based Buffer Overflow'),
    2: ('cwe-78', 'OS Command Injection'),
    3: ('cwe-190', 'Integer Overflow or Wraparound'),
    4: ('cwe-191', 'Integer Underflow'),
    5: ('cwe-122', 'Heap-based Buffer Overflow')
}
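# Note: these integer keys must match the label encoding used when the
# classifier head was trained; a different ordering would map predictions
# to the wrong CWE names without raising any error.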
def load_c_file(file):
    try:
        if file is None:
            return ""
        with open(file.name, 'r', encoding='utf-8') as f:
            content = f.read()
        return content
    except Exception as e:
        return f"Error reading file: {str(e)}"
def clean_code(code):
    code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)  # strip /* ... */ block comments
    code = re.sub(r'//.*$', '', code, flags=re.MULTILINE)   # strip // line comments
    code = ' '.join(code.split())                           # collapse all whitespace runs
    return code
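# Illustrative example (not executed):
#   clean_code("int x = 1; // counter\n/* init */ x++;")
#   -> "int x = 1; x++;"   (comments removed, whitespace collapsed)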
def evaluate_code(code):
    try:
        if len(code) >= 1500000:
            return "Code too large"
        cleaned_code = clean_code(code)
        inputs = tokenizer(cleaned_code, return_tensors="pt", truncation=True, padding=True, max_length=256).to(device)
        print("Input shape:", inputs['input_ids'].shape)
        with torch.no_grad():
            outputs = model(**inputs)
        print("Raw logits:", outputs.cpu().numpy())
        probs = softmax(outputs.cpu().numpy(), axis=1)
        pred = np.argmax(probs, axis=1)[0]
        cwe, description = label_map[pred]
        return f"{cwe} {description}"
    except Exception as e:
        return f"Error during prediction: {str(e)}"
with gr.Blocks() as web:
    with gr.Row():
        with gr.Column(scale=1):
            code_box = gr.Textbox(lines=20, label="** C/C++ Code", placeholder="Paste your C or C++ code here...")
        with gr.Column(scale=1):
            cc_file = gr.File(label="Upload C/C++ File (.c or .cpp)", file_types=[".c", ".cpp"])
            check_btn = gr.Button("Check")
    with gr.Row():
        gr.Markdown("### Result:")
    with gr.Row():
        with gr.Column(scale=1):
            label_box = gr.Textbox(label="Vulnerability", interactive=False)

    cc_file.change(fn=load_c_file, inputs=cc_file, outputs=code_box)
    check_btn.click(fn=evaluate_code, inputs=code_box, outputs=[label_box])

web.launch()