import torch
from transformers import RobertaTokenizer, RobertaModel
from huggingface_hub import hf_hub_download  # used to fetch the fine-tuned .pt weights from the Hub
import numpy as np
from scipy.special import softmax
import gradio as gr
import re

# Define the model class with dimension reduction
class CodeClassifier(torch.nn.Module):
    def __init__(self, base_model, num_labels=6):
        super().__init__()
        self.base = base_model # This will be the microsoft/codebert-base model
        self.reduction = torch.nn.Linear(768, 512)
        self.classifier = torch.nn.Linear(512, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.base(input_ids=input_ids, attention_mask=attention_mask)
        reduced = self.reduction(outputs.pooler_output)
        return self.classifier(reduced)
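
# Shape sketch (for reference): tokenized input (batch, seq_len) flows through
# CodeBERT -> pooler_output (batch, 768) -> reduction (batch, 512) ->
# classifier logits (batch, 6), one score per label in label_map below.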

# --- START OF MODIFIED LOADING LOGIC ---

# Hugging Face Model ID where your .pt file is located
HF_MODEL_REPO_ID = 'martynattakit/CodeSentinel-Model'
# The exact filename of your .pt file in that repository
HF_MODEL_FILENAME = 'best_model.pt'  # must match the filename in the repo exactly (case-sensitive)

# Load the base tokenizer (from Hugging Face Hub as before)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')

# Initialize the base CodeBERT model from Hugging Face (standard download)
base_codebert_model = RobertaModel.from_pretrained('microsoft/codebert-base')

# Instantiate your custom CodeClassifier model
model = CodeClassifier(base_codebert_model, num_labels=6)
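# Note: at this point the classifier head is still randomly initialized; the
# fine-tuned weights are loaded from the downloaded state dict below.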

# Download the .pt file from Hugging Face Hub
print(f"Attempting to download {HF_MODEL_FILENAME} from {HF_MODEL_REPO_ID}...")
try:
    # hf_hub_download will download the file and return its local path (which might be in cache)
    downloaded_model_path = hf_hub_download(
        repo_id=HF_MODEL_REPO_ID,
        filename=HF_MODEL_FILENAME,
        # If the model repo is private, pass token=... here or make sure
        # `huggingface-cli login` has been run in this environment.
    )
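    # hf_hub_download caches files under ~/.cache/huggingface/hub by default,
    # so repeated launches reuse the local copy instead of re-downloading.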
    print(f"Model downloaded to: {downloaded_model_path}")

    # Load the state dictionary from the downloaded .pt file
    state_dict = torch.load(downloaded_model_path, map_location=device)

    # Handle 'module.' prefix if the model was saved with DataParallel
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[7:]] = v # remove 'module.' prefix
        else:
            new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
    print(f"Successfully loaded model state into CodeClassifier.")

except Exception as e:
    print(f"Error during model download or loading: {e}")
    print("Please ensure:")
    print(f"1. The repository '{HF_MODEL_REPO_ID}' exists and is public (or you're logged in with `huggingface-cli login`).")
    print(f"2. The file '{HF_MODEL_FILENAME}' exists within that repository on Hugging Face and is spelled exactly correctly.")
    # Fail fast: exiting here makes the misconfiguration obvious early in
    # deployment environments like Hugging Face Spaces.
    raise SystemExit(1)

# --- END OF MODIFIED LOADING LOGIC ---


print("Loaded state dict keys (after loading .pt):", model.state_dict().keys())
print("Classifier weight shape (after loading .pt):", model.classifier.weight.shape)
model.eval()
model.to(device)

# Label mapping with descriptions
label_map = {
    0: ('none', 'No Vulnerability Detected'),
    1: ('cwe-121', 'Stack-based Buffer Overflow'),
    2: ('cwe-78', 'OS Command Injection'),
    3: ('cwe-190', 'Integer Overflow or Wraparound'),
    4: ('cwe-191', 'Integer Underflow'),
    5: ('cwe-122', 'Heap-based Buffer Overflow')
}
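# Note: the integer keys above must match the label order used during training;
# a mismatch would silently map predictions to the wrong CWE.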

def load_c_file(file):
    try:
        if file is None:
            return ""
        # Gradio may pass a filepath string or a tempfile-like object with
        # .name, depending on the Gradio version; handle both.
        path = file if isinstance(file, str) else file.name
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        return f"Error reading file: {str(e)}"

def clean_code(code):
    code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
    code = re.sub(r'//.*$', '', code, flags=re.MULTILINE)
    code = ' '.join(code.split())
    return code
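
# Worked example: clean_code("int x = 0; /* counter */ // init") -> "int x = 0;"
# Caveat: the regexes do not respect string literals, so a "//" inside a C
# string is also stripped; only the classifier sees the cleaned text.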

def evaluate_code(code):
    try:
        # Reject pathologically large inputs (~1.5 MB of source text).
        if len(code) >= 1500000:
            return "Code too large"

        cleaned_code = clean_code(code)
        # Tokenize the cleaned code; anything past 256 tokens is truncated.
        inputs = tokenizer(cleaned_code, return_tensors="pt", truncation=True, padding=True, max_length=256).to(device)
        print("Input shape:", inputs['input_ids'].shape)

        with torch.no_grad():
            # Move the logits to the CPU once and reuse them for both
            # logging and the softmax.
            logits = model(**inputs).cpu().numpy()
            print("Raw logits:", logits)
            probs = softmax(logits, axis=1)
            pred = np.argmax(probs, axis=1)[0]
            cwe, description = label_map[pred]
            return f"{cwe} {description}"

    except Exception as e:
        return f"Error during prediction: {str(e)}"

with gr.Blocks() as web:
    with gr.Row():
        with gr.Column(scale=1):
            code_box = gr.Textbox(lines=20, label="C/C++ Code", placeholder="Paste your C or C++ code here...")
        with gr.Column(scale=1):
            cc_file = gr.File(label="Upload C/C++ File (.c or .cpp)", file_types=[".c", ".cpp"])
            check_btn = gr.Button("Check")

    with gr.Row():
        gr.Markdown("### Result:")

    with gr.Row():
        with gr.Column(scale=1):
            label_box = gr.Textbox(label="Vulnerability", interactive=False)

    cc_file.change(fn=load_c_file, inputs=cc_file, outputs=code_box)
    check_btn.click(fn=evaluate_code, inputs=code_box, outputs=[label_box])

web.launch()
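
# Locally, launch() serves on http://127.0.0.1:7860 by default; pass
# share=True for a temporary public link when testing outside Spaces.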