Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import ViTImageProcessor, ViTForImageClassification, pipeline | |
from PIL import Image | |
import torch | |
import pandas as pd | |
import os | |
# 1. Конфигурация моделей | |
MODEL_CONFIG = { | |
"image": { | |
"Пневмония": { | |
"processor": "nickmuchi/vit-finetuned-chest-xray-pneumonia", | |
"model": "nickmuchi/vit-finetuned-chest-xray-pneumonia", | |
"type": "image_classification" | |
}, | |
"Опухоль мозга": { | |
"processor": "DunnBC22/vit-base-patch16-224-in21k_brain_tumor_diagnosis", | |
"model": "DunnBC22/vit-base-patch16-224-in21k_brain_tumor_diagnosis", | |
"type": "image_classification" | |
}, | |
"Диабетическая ретинопатия": { | |
"processor": "Kontawat/vit-diabetic-retinopathy-classification", | |
"model": "Kontawat/vit-diabetic-retinopathy-classification", | |
"type": "image_classification" | |
} | |
}, | |
"text": { | |
"NER (Bio_ClinicalBERT)": { | |
"model": "emilyalsentzer/Bio_ClinicalBERT", | |
"type": "ner" | |
}, | |
"NER (BioBERT)": { | |
"model": "dmis-lab/biobert-v1.1", | |
"type": "ner" | |
} | |
} | |
} | |
# 2. Загрузка моделей | |
def load_model_and_processor(analysis_type, model_name): | |
if analysis_type == "Изображение": | |
config = MODEL_CONFIG["image"].get(model_name) | |
if not config: | |
return None, None | |
processor = ViTImageProcessor.from_pretrained(config["processor"]) | |
model = ViTForImageClassification.from_pretrained(config["model"]) | |
return processor, model | |
elif analysis_type == "Текст": | |
config = MODEL_CONFIG["text"].get(model_name) | |
if not config: | |
return None, None | |
nlp = pipeline(config["type"], model=config["model"], tokenizer=config["model"]) | |
return None, nlp | |
return None, None | |
# 3. Функция для классификации изображений | |
def classify_image(image, model_name): | |
processor, model = load_model_and_processor("Изображение", model_name) | |
if not model or not processor: | |
return "Ошибка: Модель или процессор не найдены." | |
try: | |
inputs = processor(images=image, return_tensors="pt") | |
with torch.no_grad(): | |
outputs = model(**inputs) | |
logits = outputs.logits | |
predicted_class_idx = logits.argmax(-1).item() | |
predicted_class = model.config.id2label[predicted_class_idx] | |
return f"Результат: {predicted_class}" | |
except Exception as e: | |
return f"Ошибка при обработке изображения: {str(e)}" | |
# 4. Функция для обработки текста (NER) | |
def extract_entities(text, model_name): | |
_, nlp = load_model_and_processor("Текст", model_name) | |
if not nlp: | |
return "Ошибка: Модель не найдена." | |
try: | |
ner_results = nlp(text) | |
entities = [] | |
current_entity = "" | |
current_label = "" | |
for result in ner_results: | |
word = result['word'] | |
entity = result['entity'] | |
if entity.startswith('B-'): | |
if current_entity: | |
entities.append((current_entity, current_label)) | |
current_entity = word | |
current_label = entity[2:] | |
elif entity.startswith('I-') and current_label == entity[2:]: | |
current_entity += " " + word | |
else: | |
if current_entity: | |
entities.append((current_entity, current_label)) | |
current_entity = "" | |
current_label = "" | |
if current_entity: | |
entities.append((current_entity, current_label)) | |
return "\n".join([f"{entity[0]}: {entity[1]}" for entity in entities]) if entities else "Сущности не найдены." | |
except Exception as e: | |
return f"Ошибка при обработке текста: {str(e)}" | |
# 5. Функция для обработки CSV-файла | |
def process_csv(file, model_name): | |
try: | |
df = pd.read_csv(file.name) | |
if not all(col in df.columns for col in ['id', 'text', 'entities']): | |
return "Ошибка: CSV должен содержать колонки id, text, entities" | |
results = [] | |
for _, row in df.iterrows(): | |
text = row['text'] | |
true_entities = row['entities'] | |
predicted_entities = extract_entities(text, model_name) | |
results.append({ | |
"ID": row['id'], | |
"Текст": text, | |
"Ожидаемые сущности": true_entities, | |
"Предсказанные сущности": predicted_entities | |
}) | |
results_df = pd.DataFrame(results) | |
output_file = "ner_results.csv" | |
results_df.to_csv(output_file, index=False) | |
return results_df.to_string(), output_file | |
except Exception as e: | |
return f"Ошибка при обработке CSV: {str(e)}", None | |
# 6. Gradio интерфейс | |
with gr.Blocks(fill_height=True) as demo: | |
with gr.Sidebar(): | |
gr.Markdown("# Медицинский анализ") | |
gr.Markdown("Универсальное приложение для анализа медицинских изображений и текстов. Выберите тип анализа и модель.") | |
with gr.Row(): | |
with gr.Column(): | |
analysis_type = gr.Dropdown( | |
choices=["Изображение", "Текст"], | |
label="Тип анализа", | |
value="Изображение" | |
) | |
model_name = gr.Dropdown( | |
choices=list(MODEL_CONFIG["image"].keys()), | |
label="Выберите модель", | |
value="Пневмония" | |
) | |
image_input = gr.Image(type="pil", label="Загрузите изображение (для анализа изображений)") | |
text_input = gr.Textbox(label="Введите текст (для анализа текста)", visible=False) | |
csv_input = gr.File(label="Загрузите CSV-файл (для анализа текста)", visible=False) | |
analyze_button = gr.Button("Анализировать") | |
with gr.Column(): | |
output = gr.Textbox(label="Результат") | |
csv_output = gr.File(label="Результаты обработки CSV") | |
# Динамическое обновление моделей и видимости входов | |
def update_inputs(analysis_type): | |
model_choices = list(MODEL_CONFIG[analysis_type.lower()].keys()) | |
image_visible = analysis_type == "Изображение" | |
text_visible = analysis_type == "Текст" | |
csv_visible = analysis_type == "Текст" | |
return ( | |
gr.update(choices=model_choices, value=model_choices[0]), | |
gr.update(visible=image_visible), | |
gr.update(visible=text_visible), | |
gr.update(visible=csv_visible) | |
) | |
analysis_type.change( | |
fn=update_inputs, | |
inputs=analysis_type, | |
outputs=[model_name, image_input, text_input, csv_input] | |
) | |
# Обработка нажатия кнопки | |
def analyze(analysis_type, model_name, image, text, csv_file): | |
if analysis_type == "Изображение" and image: | |
return classify_image(image, model_name), None | |
elif analysis_type == "Текст" and text: | |
return extract_entities(text, model_name), None | |
elif analysis_type == "Текст" and csv_file: | |
return process_csv(csv_file, model_name) | |
return "Ошибка: Загрузите данные и выберите тип анализа.", None | |
analyze_button.click( | |
fn=analyze, | |
inputs=[analysis_type, model_name, image_input, text_input, csv_input], | |
outputs=[output, csv_output] | |
) | |
# 7. Запуск приложения | |
demo.launch(server_name="0.0.0.0", server_port=7860) |