File size: 8,372 Bytes
cab9826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import gradio as gr
from transformers import ViTImageProcessor, ViTForImageClassification, pipeline
from PIL import Image
import torch
import pandas as pd
import os

# 1. Конфигурация моделей
MODEL_CONFIG = {
    "image": {
        "Пневмония": {
            "processor": "nickmuchi/vit-finetuned-chest-xray-pneumonia",
            "model": "nickmuchi/vit-finetuned-chest-xray-pneumonia",
            "type": "image_classification"
        },
        "Опухоль мозга": {
            "processor": "DunnBC22/vit-base-patch16-224-in21k_brain_tumor_diagnosis",
            "model": "DunnBC22/vit-base-patch16-224-in21k_brain_tumor_diagnosis",
            "type": "image_classification"
        },
        "Диабетическая ретинопатия": {
            "processor": "Kontawat/vit-diabetic-retinopathy-classification",
            "model": "Kontawat/vit-diabetic-retinopathy-classification",
            "type": "image_classification"
        }
    },
    "text": {
        "NER (Bio_ClinicalBERT)": {
            "model": "emilyalsentzer/Bio_ClinicalBERT",
            "type": "ner"
        },
        "NER (BioBERT)": {
            "model": "dmis-lab/biobert-v1.1",
            "type": "ner"
        }
    }
}

# 2. Загрузка моделей
def load_model_and_processor(analysis_type, model_name):
    if analysis_type == "Изображение":
        config = MODEL_CONFIG["image"].get(model_name)
        if not config:
            return None, None
        processor = ViTImageProcessor.from_pretrained(config["processor"])
        model = ViTForImageClassification.from_pretrained(config["model"])
        return processor, model
    elif analysis_type == "Текст":
        config = MODEL_CONFIG["text"].get(model_name)
        if not config:
            return None, None
        nlp = pipeline(config["type"], model=config["model"], tokenizer=config["model"])
        return None, nlp
    return None, None

# 3. Функция для классификации изображений
def classify_image(image, model_name):
    processor, model = load_model_and_processor("Изображение", model_name)
    if not model or not processor:
        return "Ошибка: Модель или процессор не найдены."
    
    try:
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            predicted_class_idx = logits.argmax(-1).item()
            predicted_class = model.config.id2label[predicted_class_idx]
        return f"Результат: {predicted_class}"
    except Exception as e:
        return f"Ошибка при обработке изображения: {str(e)}"

# 4. Функция для обработки текста (NER)
def extract_entities(text, model_name):
    _, nlp = load_model_and_processor("Текст", model_name)
    if not nlp:
        return "Ошибка: Модель не найдена."
    
    try:
        ner_results = nlp(text)
        entities = []
        current_entity = ""
        current_label = ""
        
        for result in ner_results:
            word = result['word']
            entity = result['entity']
            if entity.startswith('B-'):
                if current_entity:
                    entities.append((current_entity, current_label))
                current_entity = word
                current_label = entity[2:]
            elif entity.startswith('I-') and current_label == entity[2:]:
                current_entity += " " + word
            else:
                if current_entity:
                    entities.append((current_entity, current_label))
                current_entity = ""
                current_label = ""
        
        if current_entity:
            entities.append((current_entity, current_label))
        
        return "\n".join([f"{entity[0]}: {entity[1]}" for entity in entities]) if entities else "Сущности не найдены."
    except Exception as e:
        return f"Ошибка при обработке текста: {str(e)}"

# 5. Функция для обработки CSV-файла
def process_csv(file, model_name):
    try:
        df = pd.read_csv(file.name)
        if not all(col in df.columns for col in ['id', 'text', 'entities']):
            return "Ошибка: CSV должен содержать колонки id, text, entities"
        
        results = []
        for _, row in df.iterrows():
            text = row['text']
            true_entities = row['entities']
            predicted_entities = extract_entities(text, model_name)
            results.append({
                "ID": row['id'],
                "Текст": text,
                "Ожидаемые сущности": true_entities,
                "Предсказанные сущности": predicted_entities
            })
        
        results_df = pd.DataFrame(results)
        output_file = "ner_results.csv"
        results_df.to_csv(output_file, index=False)
        return results_df.to_string(), output_file
    except Exception as e:
        return f"Ошибка при обработке CSV: {str(e)}", None

# 6. Gradio интерфейс
with gr.Blocks(fill_height=True) as demo:
    with gr.Sidebar():
        gr.Markdown("# Медицинский анализ")
        gr.Markdown("Универсальное приложение для анализа медицинских изображений и текстов. Выберите тип анализа и модель.")
    
    with gr.Row():
        with gr.Column():
            analysis_type = gr.Dropdown(
                choices=["Изображение", "Текст"],
                label="Тип анализа",
                value="Изображение"
            )
            model_name = gr.Dropdown(
                choices=list(MODEL_CONFIG["image"].keys()),
                label="Выберите модель",
                value="Пневмония"
            )
            image_input = gr.Image(type="pil", label="Загрузите изображение (для анализа изображений)")
            text_input = gr.Textbox(label="Введите текст (для анализа текста)", visible=False)
            csv_input = gr.File(label="Загрузите CSV-файл (для анализа текста)", visible=False)
            analyze_button = gr.Button("Анализировать")
        
        with gr.Column():
            output = gr.Textbox(label="Результат")
            csv_output = gr.File(label="Результаты обработки CSV")

    # Динамическое обновление моделей и видимости входов
    def update_inputs(analysis_type):
        model_choices = list(MODEL_CONFIG[analysis_type.lower()].keys())
        image_visible = analysis_type == "Изображение"
        text_visible = analysis_type == "Текст"
        csv_visible = analysis_type == "Текст"
        return (
            gr.update(choices=model_choices, value=model_choices[0]),
            gr.update(visible=image_visible),
            gr.update(visible=text_visible),
            gr.update(visible=csv_visible)
        )

    analysis_type.change(
        fn=update_inputs,
        inputs=analysis_type,
        outputs=[model_name, image_input, text_input, csv_input]
    )

    # Обработка нажатия кнопки
    def analyze(analysis_type, model_name, image, text, csv_file):
        if analysis_type == "Изображение" and image:
            return classify_image(image, model_name), None
        elif analysis_type == "Текст" and text:
            return extract_entities(text, model_name), None
        elif analysis_type == "Текст" and csv_file:
            return process_csv(csv_file, model_name)
        return "Ошибка: Загрузите данные и выберите тип анализа.", None

    analyze_button.click(
        fn=analyze,
        inputs=[analysis_type, model_name, image_input, text_input, csv_input],
        outputs=[output, csv_output]
    )

# 7. Запуск приложения
demo.launch(server_name="0.0.0.0", server_port=7860)