import gradio as gr from transformers import pipeline import torch from PIL import Image import numpy as np WOUND_TYPES = { "stage1": "Estágio 1 - Lesão Superficial", "stage2": "Estágio 2 - Lesão Parcial", "stage3": "Estágio 3 - Lesão Profunda", "stage4": "Estágio 4 - Lesão Grave", "unstageable": "Não Classificável", "healthy": "Pele Saudável" } def load_model(): # Usando um modelo público do Hugging Face para classificação de imagens classifier = pipeline( "image-classification", model="google/vit-base-patch16-224", device=0 if torch.cuda.is_available() else -1 ) return classifier def preprocess_image(image): if isinstance(image, np.ndarray): image = Image.fromarray(image) image = image.convert('RGB') return image def classify_wound(image): if image is None: return None classifier = load_model() processed_image = preprocess_image(image) # Classificação da imagem results = classifier(processed_image) # Formatando resultados formatted_results = [] for result in results: label = result['label'].replace('_', ' ').title() score = result['score'] formatted_results.append((label, score)) return formatted_results # Interface Gradio with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🏥 Classificador de Imagens Médicas Sistema de classificação de imagens usando Vision Transformer (ViT). """) with gr.Row(): with gr.Column(): input_image = gr.Image( label="Upload da Imagem", type="pil" ) submit_btn = gr.Button("Analisar Imagem", variant="primary") with gr.Column(): output = gr.Label( label="Classificação", num_top_classes=3 ) with gr.Row(): with gr.Accordion("Informações", open=False): gr.Markdown(""" ### Recomendações para Melhores Resultados: 1. Use imagens bem iluminadas 2. Capture a imagem em um ângulo perpendicular 3. Mantenha um fundo neutro e limpo 4. Evite sombras ou reflexos excessivos ### Observações: - Este é um modelo de classificação geral - Os resultados são aproximações e não substituem avaliação médica - Consulte sempre um profissional de saúde para diagnóstico """) # Configurando eventos submit_btn.click( fn=classify_wound, inputs=input_image, outputs=output ) if __name__ == "__main__": demo.launch()