"""Serve a fine-tuned two-class ResNet-50 image classifier through a Gradio UI.

Run this file directly (``python <this file>``) to start the web interface.
"""

import gradio as gr
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torchvision import transforms

# Path to the fine-tuned weights (a state_dict saved with ``torch.save``).
CHECKPOINT_PATH = "model.pth"

# Human-readable labels, indexed by the model's output logit position.
# TODO(review): the original read these from ``dataset.classes``, which is never
# defined in this script. Replace these placeholders with the training dataset's
# class list, in the same order that was used during training.
CLASS_NAMES = ("class_0", "class_1")

# Prefer a GPU when one is available (the original used an undefined ``device``).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inference preprocessing (the original used an undefined ``transform``).
# NOTE(review): these are the standard torchvision ImageNet eval transforms for
# ResNet-50; confirm they match the preprocessing used when training model.pth.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _build_model(checkpoint_path: str, num_classes: int) -> nn.Module:
    """Return a ResNet-50 with a ``num_classes``-way head loaded from ``checkpoint_path``.

    The returned model is already moved to ``device`` and switched to eval mode.
    """
    # ``weights=None`` replaces the deprecated ``pretrained=False``: random init,
    # since every parameter is overwritten by the checkpoint below anyway.
    net = torchvision.models.resnet50(weights=None)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    return net


# The head size is tied to the label list so the two cannot silently drift apart.
model = _build_model(CHECKPOINT_PATH, num_classes=len(CLASS_NAMES))


def predict(image: "Image.Image | None") -> str:
    """Classify one uploaded image and return its predicted class name.

    Args:
        image: PIL image supplied by the Gradio ``Image(type="pil")`` input,
            or ``None`` when the user submits without uploading anything.

    Returns:
        The entry of ``CLASS_NAMES`` for the highest-scoring logit.

    Raises:
        gr.Error: if no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Uploads may be RGBA, greyscale or palette images; the network and the
    # 3-channel Normalize step both expect RGB.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    predicted = logits.argmax(dim=1).item()
    return CLASS_NAMES[predicted]


# Input and output components (gr.inputs / gr.outputs were removed in Gradio 4).
image_input = gr.Image(type="pil", label="Upload Image")
label_output = gr.Label()

interface = gr.Interface(fn=predict, inputs=image_input, outputs=label_output)

if __name__ == "__main__":
    # Guarded so that importing this module does not start a blocking web server.
    interface.launch()