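"""Gradio demo: skin lesion classification (benign vs. malignant) from a
lesion image plus free-text clinical notes, using a late-fusion ViT + BERT
model."""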
import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
from PIL import Image
import gradio as gr

# Late-fusion vision-language model: a ViT image encoder and a BERT text
# encoder run independently, and their pooled [CLS] embeddings are
# concatenated and passed to a linear classifier.
class VisionLanguageModel(nn.Module):
    def __init__(self):
        super(VisionLanguageModel, self).__init__()
        self.vision_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.language_model = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Linear(
            self.vision_model.config.hidden_size + self.language_model.config.hidden_size,
            2  # Number of classes: benign or malignant
        )

    def forward(self, input_ids, attention_mask, pixel_values):
        vision_outputs = self.vision_model(pixel_values=pixel_values)
        vision_pooled_output = vision_outputs.pooler_output

        language_outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        language_pooled_output = language_outputs.pooler_output

        combined_features = torch.cat(
            (vision_pooled_output, language_pooled_output),
            dim=1
        )
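        # Both base-size backbones use hidden_size 768, so combined_features
        # has shape [batch, 1536].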

        logits = self.classifier(combined_features)
        return logits

model = VisionLanguageModel()
# Load the fine-tuned checkpoint on CPU; weights_only=True (PyTorch >= 1.13)
# restricts unpickling to tensor data. 'best_model.pth' must be present in
# the working directory.
model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu'), weights_only=True))
model.eval()  # inference mode: disables dropout

# Preprocessors matching the pretrained backbones above.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

def predict(image, text_input):
    # Convert the PIL image to normalized pixel tensors for ViT.
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Tokenize the clinical notes, padded/truncated to a fixed length of 256.
    encoding = tokenizer(
        text_input,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    with torch.no_grad():
        outputs = model(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            pixel_values=pixel_values
        )
    _, prediction = torch.max(outputs, dim=1)
    return prediction.item()  # 1 for Malignant, 0 for Benign
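
# Example usage (hypothetical file name and clinical note):
#   label = predict(Image.open("lesion.png"), "67-year-old, lesion enlarging over 3 months")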

# UI with color-coded prediction display: the box matching the model's
# prediction is highlighted (green for benign, red for malignant).
with gr.Blocks(css="""
    .benign {background-color: white; border: 1px solid lightgray; padding: 10px; border-radius: 5px;}
    .malignant {background-color: white; border: 1px solid lightgray; padding: 10px; border-radius: 5px;}
    .benign.predicted {background-color: lightgreen;}
    .malignant.predicted {background-color: lightcoral;}
""") as demo:
    gr.Markdown(
        """
        # 🩺 SKIN LESION CLASSIFICATION 
        Upload an image of a skin lesion and provide clinical details to get a prediction of benign or malignant.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Skin Lesion Image")
            text_input = gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")

        with gr.Column(scale=1):
            benign_output = gr.HTML("<div class='benign'>Benign</div>")
            malignant_output = gr.HTML("<div class='malignant'>Malignant</div>")
            gr.Markdown("## Example:")
            example_image = gr.Image(value="skin_cancer_detection/Unknown-4.png", interactive=False)  # bundled example image; adjust the path to your layout
            example_text = gr.Textbox(value="consistent with resolving/involuting keratoacanthoma 67", interactive=False)

    def display_prediction(image, text_input):
        prediction = predict(image, text_input)
        # Attach the 'predicted' CSS class to whichever box matches the model output.
        benign_html = "<div class='benign{}'>Benign</div>".format(" predicted" if prediction == 0 else "")
        malignant_html = "<div class='malignant{}'>Malignant</div>".format(" predicted" if prediction == 1 else "")
        return benign_html, malignant_html

    # Submit button and prediction outputs
    submit_btn = gr.Button("Get Prediction")
    submit_btn.click(display_prediction, inputs=[image_input, text_input], outputs=[benign_output, malignant_output])

if __name__ == "__main__":
    demo.launch()