import gradio as gr from transformers import ViTImageProcessor, ViTForImageClassification, pipeline from PIL import Image import torch import pandas as pd import os # 1. Конфигурация моделей MODEL_CONFIG = { "image": { "Пневмония": { "processor": "nickmuchi/vit-finetuned-chest-xray-pneumonia", "model": "nickmuchi/vit-finetuned-chest-xray-pneumonia", "type": "image_classification" }, "Опухоль мозга": { "processor": "DunnBC22/vit-base-patch16-224-in21k_brain_tumor_diagnosis", "model": "DunnBC22/vit-base-patch16-224-in21k_brain_tumor_diagnosis", "type": "image_classification" }, "Диабетическая ретинопатия": { "processor": "Kontawat/vit-diabetic-retinopathy-classification", "model": "Kontawat/vit-diabetic-retinopathy-classification", "type": "image_classification" } }, "text": { "NER (Bio_ClinicalBERT)": { "model": "emilyalsentzer/Bio_ClinicalBERT", "type": "ner" }, "NER (BioBERT)": { "model": "dmis-lab/biobert-v1.1", "type": "ner" } } } # 2. Загрузка моделей def load_model_and_processor(analysis_type, model_name): if analysis_type == "Изображение": config = MODEL_CONFIG["image"].get(model_name) if not config: return None, None processor = ViTImageProcessor.from_pretrained(config["processor"]) model = ViTForImageClassification.from_pretrained(config["model"]) return processor, model elif analysis_type == "Текст": config = MODEL_CONFIG["text"].get(model_name) if not config: return None, None nlp = pipeline(config["type"], model=config["model"], tokenizer=config["model"]) return None, nlp return None, None # 3. Функция для классификации изображений def classify_image(image, model_name): processor, model = load_model_and_processor("Изображение", model_name) if not model or not processor: return "Ошибка: Модель или процессор не найдены." try: inputs = processor(images=image, return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits predicted_class_idx = logits.argmax(-1).item() predicted_class = model.config.id2label[predicted_class_idx] return f"Результат: {predicted_class}" except Exception as e: return f"Ошибка при обработке изображения: {str(e)}" # 4. Функция для обработки текста (NER) def extract_entities(text, model_name): _, nlp = load_model_and_processor("Текст", model_name) if not nlp: return "Ошибка: Модель не найдена." try: ner_results = nlp(text) entities = [] current_entity = "" current_label = "" for result in ner_results: word = result['word'] entity = result['entity'] if entity.startswith('B-'): if current_entity: entities.append((current_entity, current_label)) current_entity = word current_label = entity[2:] elif entity.startswith('I-') and current_label == entity[2:]: current_entity += " " + word else: if current_entity: entities.append((current_entity, current_label)) current_entity = "" current_label = "" if current_entity: entities.append((current_entity, current_label)) return "\n".join([f"{entity[0]}: {entity[1]}" for entity in entities]) if entities else "Сущности не найдены." except Exception as e: return f"Ошибка при обработке текста: {str(e)}" # 5. Функция для обработки CSV-файла def process_csv(file, model_name): try: df = pd.read_csv(file.name) if not all(col in df.columns for col in ['id', 'text', 'entities']): return "Ошибка: CSV должен содержать колонки id, text, entities" results = [] for _, row in df.iterrows(): text = row['text'] true_entities = row['entities'] predicted_entities = extract_entities(text, model_name) results.append({ "ID": row['id'], "Текст": text, "Ожидаемые сущности": true_entities, "Предсказанные сущности": predicted_entities }) results_df = pd.DataFrame(results) output_file = "ner_results.csv" results_df.to_csv(output_file, index=False) return results_df.to_string(), output_file except Exception as e: return f"Ошибка при обработке CSV: {str(e)}", None # 6. Gradio интерфейс with gr.Blocks(fill_height=True) as demo: with gr.Sidebar(): gr.Markdown("# Медицинский анализ") gr.Markdown("Универсальное приложение для анализа медицинских изображений и текстов. Выберите тип анализа и модель.") with gr.Row(): with gr.Column(): analysis_type = gr.Dropdown( choices=["Изображение", "Текст"], label="Тип анализа", value="Изображение" ) model_name = gr.Dropdown( choices=list(MODEL_CONFIG["image"].keys()), label="Выберите модель", value="Пневмония" ) image_input = gr.Image(type="pil", label="Загрузите изображение (для анализа изображений)") text_input = gr.Textbox(label="Введите текст (для анализа текста)", visible=False) csv_input = gr.File(label="Загрузите CSV-файл (для анализа текста)", visible=False) analyze_button = gr.Button("Анализировать") with gr.Column(): output = gr.Textbox(label="Результат") csv_output = gr.File(label="Результаты обработки CSV") # Динамическое обновление моделей и видимости входов def update_inputs(analysis_type): model_choices = list(MODEL_CONFIG[analysis_type.lower()].keys()) image_visible = analysis_type == "Изображение" text_visible = analysis_type == "Текст" csv_visible = analysis_type == "Текст" return ( gr.update(choices=model_choices, value=model_choices[0]), gr.update(visible=image_visible), gr.update(visible=text_visible), gr.update(visible=csv_visible) ) analysis_type.change( fn=update_inputs, inputs=analysis_type, outputs=[model_name, image_input, text_input, csv_input] ) # Обработка нажатия кнопки def analyze(analysis_type, model_name, image, text, csv_file): if analysis_type == "Изображение" and image: return classify_image(image, model_name), None elif analysis_type == "Текст" and text: return extract_entities(text, model_name), None elif analysis_type == "Текст" and csv_file: return process_csv(csv_file, model_name) return "Ошибка: Загрузите данные и выберите тип анализа.", None analyze_button.click( fn=analyze, inputs=[analysis_type, model_name, image_input, text_input, csv_input], outputs=[output, csv_output] ) # 7. Запуск приложения demo.launch(server_name="0.0.0.0", server_port=7860)