tigrica007's picture
Create app.py
cab9826 verified
import gradio as gr
from transformers import ViTImageProcessor, ViTForImageClassification, pipeline
from PIL import Image
import torch
import pandas as pd
import os
# 1. Конфигурация моделей
MODEL_CONFIG = {
"image": {
"Пневмония": {
"processor": "nickmuchi/vit-finetuned-chest-xray-pneumonia",
"model": "nickmuchi/vit-finetuned-chest-xray-pneumonia",
"type": "image_classification"
},
"Опухоль мозга": {
"processor": "DunnBC22/vit-base-patch16-224-in21k_brain_tumor_diagnosis",
"model": "DunnBC22/vit-base-patch16-224-in21k_brain_tumor_diagnosis",
"type": "image_classification"
},
"Диабетическая ретинопатия": {
"processor": "Kontawat/vit-diabetic-retinopathy-classification",
"model": "Kontawat/vit-diabetic-retinopathy-classification",
"type": "image_classification"
}
},
"text": {
"NER (Bio_ClinicalBERT)": {
"model": "emilyalsentzer/Bio_ClinicalBERT",
"type": "ner"
},
"NER (BioBERT)": {
"model": "dmis-lab/biobert-v1.1",
"type": "ner"
}
}
}
# 2. Загрузка моделей
def load_model_and_processor(analysis_type, model_name):
if analysis_type == "Изображение":
config = MODEL_CONFIG["image"].get(model_name)
if not config:
return None, None
processor = ViTImageProcessor.from_pretrained(config["processor"])
model = ViTForImageClassification.from_pretrained(config["model"])
return processor, model
elif analysis_type == "Текст":
config = MODEL_CONFIG["text"].get(model_name)
if not config:
return None, None
nlp = pipeline(config["type"], model=config["model"], tokenizer=config["model"])
return None, nlp
return None, None
# 3. Функция для классификации изображений
def classify_image(image, model_name):
processor, model = load_model_and_processor("Изображение", model_name)
if not model or not processor:
return "Ошибка: Модель или процессор не найдены."
try:
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
predicted_class = model.config.id2label[predicted_class_idx]
return f"Результат: {predicted_class}"
except Exception as e:
return f"Ошибка при обработке изображения: {str(e)}"
# 4. Функция для обработки текста (NER)
def extract_entities(text, model_name):
_, nlp = load_model_and_processor("Текст", model_name)
if not nlp:
return "Ошибка: Модель не найдена."
try:
ner_results = nlp(text)
entities = []
current_entity = ""
current_label = ""
for result in ner_results:
word = result['word']
entity = result['entity']
if entity.startswith('B-'):
if current_entity:
entities.append((current_entity, current_label))
current_entity = word
current_label = entity[2:]
elif entity.startswith('I-') and current_label == entity[2:]:
current_entity += " " + word
else:
if current_entity:
entities.append((current_entity, current_label))
current_entity = ""
current_label = ""
if current_entity:
entities.append((current_entity, current_label))
return "\n".join([f"{entity[0]}: {entity[1]}" for entity in entities]) if entities else "Сущности не найдены."
except Exception as e:
return f"Ошибка при обработке текста: {str(e)}"
# 5. Функция для обработки CSV-файла
def process_csv(file, model_name):
try:
df = pd.read_csv(file.name)
if not all(col in df.columns for col in ['id', 'text', 'entities']):
return "Ошибка: CSV должен содержать колонки id, text, entities"
results = []
for _, row in df.iterrows():
text = row['text']
true_entities = row['entities']
predicted_entities = extract_entities(text, model_name)
results.append({
"ID": row['id'],
"Текст": text,
"Ожидаемые сущности": true_entities,
"Предсказанные сущности": predicted_entities
})
results_df = pd.DataFrame(results)
output_file = "ner_results.csv"
results_df.to_csv(output_file, index=False)
return results_df.to_string(), output_file
except Exception as e:
return f"Ошибка при обработке CSV: {str(e)}", None
# 6. Gradio интерфейс
with gr.Blocks(fill_height=True) as demo:
with gr.Sidebar():
gr.Markdown("# Медицинский анализ")
gr.Markdown("Универсальное приложение для анализа медицинских изображений и текстов. Выберите тип анализа и модель.")
with gr.Row():
with gr.Column():
analysis_type = gr.Dropdown(
choices=["Изображение", "Текст"],
label="Тип анализа",
value="Изображение"
)
model_name = gr.Dropdown(
choices=list(MODEL_CONFIG["image"].keys()),
label="Выберите модель",
value="Пневмония"
)
image_input = gr.Image(type="pil", label="Загрузите изображение (для анализа изображений)")
text_input = gr.Textbox(label="Введите текст (для анализа текста)", visible=False)
csv_input = gr.File(label="Загрузите CSV-файл (для анализа текста)", visible=False)
analyze_button = gr.Button("Анализировать")
with gr.Column():
output = gr.Textbox(label="Результат")
csv_output = gr.File(label="Результаты обработки CSV")
# Динамическое обновление моделей и видимости входов
def update_inputs(analysis_type):
model_choices = list(MODEL_CONFIG[analysis_type.lower()].keys())
image_visible = analysis_type == "Изображение"
text_visible = analysis_type == "Текст"
csv_visible = analysis_type == "Текст"
return (
gr.update(choices=model_choices, value=model_choices[0]),
gr.update(visible=image_visible),
gr.update(visible=text_visible),
gr.update(visible=csv_visible)
)
analysis_type.change(
fn=update_inputs,
inputs=analysis_type,
outputs=[model_name, image_input, text_input, csv_input]
)
# Обработка нажатия кнопки
def analyze(analysis_type, model_name, image, text, csv_file):
if analysis_type == "Изображение" and image:
return classify_image(image, model_name), None
elif analysis_type == "Текст" and text:
return extract_entities(text, model_name), None
elif analysis_type == "Текст" and csv_file:
return process_csv(csv_file, model_name)
return "Ошибка: Загрузите данные и выберите тип анализа.", None
analyze_button.click(
fn=analyze,
inputs=[analysis_type, model_name, image_input, text_input, csv_input],
outputs=[output, csv_output]
)
# 7. Запуск приложения
demo.launch(server_name="0.0.0.0", server_port=7860)