HoiAlice commited on
Commit
2afe35a
·
1 Parent(s): f81186a
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
3
+ from typing import List, Dict
4
+ import torch
5
+
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
+
8
+ @st.cache_resource
9
+ def load_model():
10
+ model_path = "HoiAlice/bert-paper-classifier-arxiv"
11
+ inference_tokenizer = AutoTokenizer.from_pretrained(model_path)
12
+ inference_model = AutoModelForSequenceClassification.from_pretrained(model_path)
13
+ return pipeline("text-classification", model=inference_model, tokenizer=inference_tokenizer, top_k=None)
14
+
15
+ def top_pct(preds: List[Dict], threshold: float = 0.95) -> List[Dict]:
16
+ """Возвращает топ предсказаний, пока их суммарная вероятность не превысит threshold"""
17
+ if not preds:
18
+ return []
19
+ preds = sorted(preds, key=lambda x: -x["score"])
20
+ cum_score = 0
21
+ for i, item in enumerate(preds):
22
+ cum_score += item["score"]
23
+ if cum_score >= threshold:
24
+ break
25
+ return preds[:(i+1)]
26
+
27
+ def format_predictions(preds: List[Dict]) -> str:
28
+ """Форматирует предсказания для вывода"""
29
+ if not preds:
30
+ return "Нет результатов"
31
+ return "\n".join([f"{i+1}. {item['label']} (score {item['score']:.2f})" for i, item in enumerate(preds)])
32
+
33
+ # Интерфейс Streamlit
34
+ st.set_page_config(page_title="Классификатор научных статей", page_icon="📚")
35
+ st.title("📚 Классификатор научных статей по тематикам")
36
+ st.write("Введите текст абстракта статьи, и модель определит наиболее подходящие тематики:")
37
+
38
+ # Поле для ввода текста
39
+ abstract = st.text_area(
40
+ "Текст абстракта:",
41
+ height=200,
42
+ placeholder="Введите текст научного абстракта здесь..."
43
+ )
44
+
45
+ # Слайдер для выбора порога уверенности
46
+ threshold = st.slider(
47
+ "Порог уверенности (суммарная вероятность тематик):",
48
+ min_value=0.5,
49
+ max_value=1.0,
50
+ value=0.95,
51
+ step=0.05
52
+ )
53
+
54
+ if st.button("Определить тематики"):
55
+ if not abstract.strip():
56
+ st.warning("Пожалуйста, введите текст абстракта")
57
+ else:
58
+ with st.spinner("Загружаем модель... (это может занять некоторое время при первом запуске)"):
59
+ classifier = load_model()
60
+
61
+ if classifier is not None:
62
+ with st.spinner("Анализируем текст..."):
63
+ try:
64
+ # Получаем предсказания
65
+ predictions = classifier(abstract)[0]
66
+ # Фильтруем по порогу
67
+ top_predictions = top_pct(predictions, threshold)
68
+
69
+ # Выводим результаты
70
+ st.subheader("Результаты классификации:")
71
+ st.text(format_predictions(top_predictions))
72
+
73
+ # Визуализация в виде столбчатой диаграммы
74
+ st.subheader("Визуализация:")
75
+ chart_data = {p['label']: p['score'] for p in top_predictions}
76
+ st.bar_chart(chart_data)
77
+
78
+ except Exception as e:
79
+ st.error(f"Произошла ошибка при анализе текста: {str(e)}")
80
+
81
+ # Добавляем пояснения в сайдбар
82
+ with st.sidebar:
83
+ st.markdown("""
84
+ ## О сервисе
85
+ Этот сервис использует модель PubMedBERT, обученную для классификации научных статей по тематикам.
86
+
87
+ ### Как использовать:
88
+ 1. Введите текст абстракта в поле ввода
89
+ 2. Отрегулируйте порог уверенности (по умолчанию 0.95)
90
+ 3. Нажмите кнопку "Определить тематики"
91
+
92
+ ### Техническая информация:
93
+ - Используемое устройство: {'GPU' if device == 'cuda' else 'CPU'}
94
+ - Модель: oracat/bert-paper-classifier-arxiv
95
+ """)