import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
from PIL import Image
import gradio as gr
# Model definition and setup
class VisionLanguageModel(nn.Module):
    def __init__(self):
        super(VisionLanguageModel, self).__init__()
        self.vision_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.language_model = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Linear(
            self.vision_model.config.hidden_size + self.language_model.config.hidden_size,
            2  # Number of classes: benign or malignant
        )
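
    # Fuse the two modalities by concatenating the pooled ViT and BERT
    # embeddings, then score the combined vector with the linear head.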
    def forward(self, input_ids, attention_mask, pixel_values):
        vision_outputs = self.vision_model(pixel_values=pixel_values)
        vision_pooled_output = vision_outputs.pooler_output
        language_outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        language_pooled_output = language_outputs.pooler_output
        combined_features = torch.cat(
            (vision_pooled_output, language_pooled_output),
            dim=1
        )
        logits = self.classifier(combined_features)
        return logits
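
# Instantiate the model and load the fine-tuned weights saved during training;
# eval() disables dropout so inference is deterministic.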
model = VisionLanguageModel()
model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu'), weights_only=True))
model.eval()
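
# Preprocessors matching the two pretrained backbones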
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
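
# Single-example inference: preprocess the image and clinical text, then
# return the predicted class index (0 = benign, 1 = malignant).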
def predict(image, text_input):
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    encoding = tokenizer(
        text_input,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    with torch.no_grad():
        outputs = model(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            pixel_values=pixel_values
        )
    _, prediction = torch.max(outputs, dim=1)
    return prediction.item()  # 1 for Malignant, 0 for Benign
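
# A minimal usage sketch (hypothetical file name and clinical note, for illustration):
#   img = Image.open('lesion_example.jpg').convert('RGB')
#   label = predict(img, 'irregular pigmented lesion, patient age 67')
#   print('Malignant' if label == 1 else 'Benign')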
# Enhanced UI with color-coded prediction display
with gr.Blocks(css="""
.benign {background-color: white; border: 1px solid lightgray; padding: 10px; border-radius: 5px;}
.malignant {background-color: white; border: 1px solid lightgray; padding: 10px; border-radius: 5px;}
.benign.correct {background-color: lightgreen;}
.malignant.correct {background-color: lightgreen;}
""") as demo:
    gr.Markdown(
        """
        # 🩺 SKIN LESION CLASSIFICATION
        Upload an image of a skin lesion and provide clinical details to get a prediction of benign or malignant.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Skin Lesion Image")
            text_input = gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")
        with gr.Column(scale=1):
            benign_output = gr.HTML("<div class='benign'>Benign</div>")
            malignant_output = gr.HTML("<div class='malignant'>Malignant</div>")

    gr.Markdown("## Example:")
    example_image = gr.Image(value="skin_cancer_detection/Unknown-4.png")  # Provide path to an example image
    example_text = gr.Textbox(value="consistent with resolving/involuting keratoacanthoma 67", interactive=False)
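
    # Wrap predict() so the UI can highlight whichever class was predicted
    # (the 'correct' CSS class turns the matching box green).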
    def display_prediction(image, text_input):
        prediction = predict(image, text_input)
        benign_html = "<div class='benign{}'>Benign</div>".format(" correct" if prediction == 0 else "")
        malignant_html = "<div class='malignant{}'>Malignant</div>".format(" correct" if prediction == 1 else "")
        return benign_html, malignant_html

    # Submit button and prediction outputs
    submit_btn = gr.Button("Get Prediction")
    submit_btn.click(display_prediction, inputs=[image_input, text_input], outputs=[benign_output, malignant_output])
demo.launch()