import torch
from transformers import RobertaTokenizer, RobertaModel
from huggingface_hub import hf_hub_download
import numpy as np
from scipy.special import softmax
import gradio as gr
import re


# Classifier head on top of CodeBERT: a 768 -> 512 reduction layer followed by
# a 512 -> num_labels classification layer.
class CodeClassifier(torch.nn.Module):
    def __init__(self, base_model, num_labels=6):
        super(CodeClassifier, self).__init__()
        self.base = base_model  # the microsoft/codebert-base encoder
        self.reduction = torch.nn.Linear(768, 512)
        self.classifier = torch.nn.Linear(512, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.base(input_ids=input_ids, attention_mask=attention_mask)
        reduced = self.reduction(outputs.pooler_output)
        return self.classifier(reduced)


# --- Model loading ---

# Hugging Face repository that hosts the fine-tuned checkpoint, and the exact
# filename of the .pt file inside that repository.
HF_MODEL_REPO_ID = 'martynattakit/CodeSentinel-Model'
HF_MODEL_FILENAME = 'best_model.pt'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the base tokenizer and encoder from the Hugging Face Hub.
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
base_codebert_model = RobertaModel.from_pretrained('microsoft/codebert-base')

# Wrap the base encoder in the custom classifier.
model = CodeClassifier(base_codebert_model, num_labels=6)

# Download the checkpoint from the Hub and load its weights.
print(f"Attempting to download {HF_MODEL_FILENAME} from {HF_MODEL_REPO_ID}...")
try:
    # hf_hub_download fetches the file (or reuses the local cache) and
    # returns its local path.
    downloaded_model_path = hf_hub_download(
        repo_id=HF_MODEL_REPO_ID,
        filename=HF_MODEL_FILENAME,
    )
    print(f"Model downloaded to: {downloaded_model_path}")

    state_dict = torch.load(downloaded_model_path, map_location=device)

    # Strip the 'module.' prefix that DataParallel adds to every key, so the
    # checkpoint loads into the unwrapped model.
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[7:]] = v
        else:
            new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
    print("Successfully loaded model state into CodeClassifier.")
except Exception as e:
    print(f"Error during model download or loading: {e}")
    print("Please ensure:")
    print(f"1. The repository '{HF_MODEL_REPO_ID}' exists and is public (or you are logged in via `huggingface-cli login`).")
    print(f"2. The file '{HF_MODEL_FILENAME}' exists in that repository and is spelled exactly right.")
    # Exit early so deployment environments such as Hugging Face Spaces
    # surface the error immediately.
    exit()
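# If the checkpoint repository is private, hf_hub_download also accepts an
# authentication token. A minimal sketch, assuming the token is stored in the
# HF_TOKEN environment variable (this variant would also need `import os`):
#
#   downloaded_model_path = hf_hub_download(
#       repo_id=HF_MODEL_REPO_ID,
#       filename=HF_MODEL_FILENAME,
#       token=os.environ.get("HF_TOKEN"),
#   )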
# Sanity checks on the loaded checkpoint.
print("Loaded state dict keys (after loading .pt):", model.state_dict().keys())
print("Classifier weight shape (after loading .pt):", model.classifier.weight.shape)

model.eval()
model.to(device)

# Label mapping: class index -> (CWE id, human-readable description).
label_map = {
    0: ('none', 'No Vulnerability Detected'),
    1: ('cwe-121', 'Stack-based Buffer Overflow'),
    2: ('cwe-78', 'OS Command Injection'),
    3: ('cwe-190', 'Integer Overflow or Wraparound'),
    4: ('cwe-191', 'Integer Underflow'),
    5: ('cwe-122', 'Heap-based Buffer Overflow')
}


def load_c_file(file):
    """Read an uploaded C/C++ file into the code textbox."""
    try:
        if file is None:
            return ""
        with open(file.name, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        return f"Error reading file: {str(e)}"


def clean_code(code):
    """Strip comments and collapse whitespace before tokenization."""
    code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)  # block comments
    code = re.sub(r'//.*$', '', code, flags=re.MULTILINE)   # line comments
    code = ' '.join(code.split())                           # collapse whitespace
    return code


def evaluate_code(code):
    """Classify a code snippet and return its predicted CWE label."""
    try:
        if len(code) >= 1500000:
            return "Code too large"
        cleaned_code = clean_code(code)
        inputs = tokenizer(cleaned_code, return_tensors="pt", truncation=True,
                           padding=True, max_length=256).to(device)
        print("Input shape:", inputs['input_ids'].shape)
        with torch.no_grad():
            outputs = model(**inputs)
        print("Raw logits:", outputs.cpu().numpy())
        probs = softmax(outputs.cpu().numpy(), axis=1)
        pred = np.argmax(probs, axis=1)[0]
        cwe, description = label_map[pred]
        return f"{cwe} {description}"
    except Exception as e:
        return f"Error during prediction: {str(e)}"


# --- Gradio UI ---
with gr.Blocks() as web:
    with gr.Row():
        with gr.Column(scale=1):
            code_box = gr.Textbox(lines=20, label="** C/C++ Code",
                                  placeholder="Paste your C or C++ code here...")
        with gr.Column(scale=1):
            cc_file = gr.File(label="Upload C/C++ File (.c or .cpp)",
                              file_types=[".c", ".cpp"])
            check_btn = gr.Button("Check")
    with gr.Row():
        gr.Markdown("### Result:")
    with gr.Row():
        with gr.Column(scale=1):
            label_box = gr.Textbox(label="Vulnerability", interactive=False)

    cc_file.change(fn=load_c_file, inputs=cc_file, outputs=code_box)
    check_btn.click(fn=evaluate_code, inputs=code_box, outputs=[label_box])

web.launch()
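# A minimal smoke test of the classifier, left commented out because
# web.launch() above blocks. The C snippet is a hypothetical example of a
# stack-based buffer overflow (CWE-121); the actual prediction depends on the
# trained checkpoint.
#
#   sample = '''
#   #include <string.h>
#   void copy(char *input) {
#       char buf[8];
#       strcpy(buf, input);  /* unbounded copy into a fixed-size stack buffer */
#   }
#   '''
#   print(evaluate_code(sample))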