MostoHF committed on
Commit
3392637
·
verified ·
1 Parent(s): 4b51a56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -9
app.py CHANGED
@@ -1,15 +1,95 @@
1
  import streamlit as st
 
 
 
 
 
 
2
 
3
- st.set_page_config(page_title="Simple Input App", layout="centered")
4
 
5
- st.title("📄 Text Input Demo")
6
 
7
- # Input fields
8
- title = st.text_input("Title", value="enter title...")
9
- summary = st.text_input("Summary", value="enter summary...")
10
 
11
- # Submit button
12
- if st.button("Submit"):
13
- st.success(f"✅ Title: {title}")
14
- st.info(f"💬 Summary: {summary}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from torch import nn
4
+ import csv
5
+ from transformers import AutoModel, AutoTokenizer
6
+ from huggingface_hub import hf_hub_download
7
+ from model import ClassificationModel
8
 
 
9
 
10
# Inference runs on the GPU when PyTorch reports one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Token length that every tokenized input is padded/truncated to.
MAX_LENGTH = 512
 
 
13
 
14
@st.cache_resource
def get_model():
    """Build the fine-tuned classifier and return it ready for inference.

    The pretrained ``distilbert-base-cased`` encoder is wrapped in
    ``ClassificationModel``; the fine-tuned weights are downloaded from the
    ``MostoHF/TunedDistillBertCased`` Hub repository (``pytorch_model.bin``),
    loaded into the wrapper, and the model is moved to ``device`` and put
    into eval mode.

    Returns:
        The ``ClassificationModel`` instance with the downloaded weights.
    """
    backbone = AutoModel.from_pretrained("distilbert-base-cased")
    classifier = ClassificationModel(backbone)

    checkpoint_path = hf_hub_download(
        repo_id="MostoHF/TunedDistillBertCased",
        filename="pytorch_model.bin",
    )

    # NOTE(review): torch.load unpickles the downloaded file. Only load
    # checkpoints from a trusted repo; consider `weights_only=True` if the
    # installed torch version supports it.
    weights = torch.load(checkpoint_path, map_location=device)
    classifier.load_state_dict(weights)
    classifier.to(device)
    classifier.eval()

    return classifier
30
+
31
@st.cache_resource
def get_tokenizer():
    """Return the tokenizer for the ``distilbert-base-cased`` checkpoint."""
    cased_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
    return cased_tokenizer
34
+
35
@st.cache_resource
def get_ind_to_cat():
    """Load the class-index -> category-name mapping from ``ind_to_category.csv``.

    The file is read from the current working directory. Its first row is
    treated as a header and skipped; every following non-blank row must have
    exactly two columns: an integer index and the category name.

    Returns:
        dict mapping ``int`` class index to category name.

    Raises:
        FileNotFoundError: if ``ind_to_category.csv`` does not exist.
        ValueError: if a non-blank row does not have exactly two columns or
            its first column is not an integer.
    """
    ind_to_category_copy = {}
    # Explicit encoding: without it open() uses the locale's default, which
    # can differ between machines and mis-decode non-ASCII category names.
    with open('ind_to_category.csv', mode='r', newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header; the default avoids StopIteration on an empty file
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing newline).
                continue
            key, value = row
            ind_to_category_copy[int(key)] = value  # keys are ints
    return ind_to_category_copy
44
+
45
# Module-level singletons used by inference(). Each loader is decorated with
# st.cache_resource, so these calls are cheap after the first script run.
class_model = get_model()
tokenizer = get_tokenizer()
ind_to_category = get_ind_to_cat()
48
+
49
def inference(title, abstract, threshold=0.95):
    """Predict the most probable themes for an article.

    The title and abstract are joined with ``'@'``, tokenized (padded and
    truncated to ``MAX_LENGTH``) and passed through ``class_model``. The
    exponentiated model output is treated as per-class probabilities
    (presumably the model returns log-probabilities of shape (1, num_classes)
    -- confirm against ``ClassificationModel``). Classes are taken in
    descending probability order until their running sum reaches
    ``threshold``; at least one class is always returned.

    Args:
        title: Article title.
        abstract: Article abstract.
        threshold: Cumulative probability at which selection stops.

    Returns:
        Tuple ``(themes, probs)``: category names looked up in
        ``ind_to_category`` and the matching probabilities as Python floats,
        both ordered from most to least probable.
    """
    text = title + '@' + abstract

    batch = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    with torch.no_grad():
        class_probs = torch.exp(class_model(input_ids, attention_mask))

    # Drop the batch dimension, then rank classes from most to least probable.
    ranked_probs, ranked_indices = torch.sort(class_probs.squeeze(0), descending=True)

    cumulative = 0.0
    chosen_indices = []
    chosen_probs = []
    for prob, class_idx in zip(ranked_probs.tolist(), ranked_indices.tolist()):
        cumulative += prob
        chosen_indices.append(class_idx)
        chosen_probs.append(prob)
        if cumulative >= threshold:
            break

    ans_themes = [ind_to_category[class_idx] for class_idx in chosen_indices]
    return ans_themes, chosen_probs
75
+
76
+
77
# ------------------- Streamlit UI -------------------
# Page flow: collect a title/abstract and a cumulative-probability threshold,
# then on "Submit" echo the inputs and list the predicted themes.

# NOTE(review): Streamlit documents set_page_config as the first Streamlit
# command of a script; the cached loaders above run before it -- confirm this
# is accepted by the deployed Streamlit version.
st.set_page_config(page_title="Article Theme Classifier", layout="centered")
st.title("📄 Article Theme Classifier")

# The hint text is a placeholder, not the field's value: as a default value it
# made both fields non-empty, so the "fill in at least one field" check never
# fired and the hint text itself was sent to the classifier.
title = st.text_input("Title", placeholder="Введите title...")
abstract = st.text_input("Abstract", placeholder="Введите abstract...")
threshold = st.slider("Выберите cumulative probability threshold", 0.0, 1.0, step=0.01, value=0.95)

if st.button("Submit"):
    # Whitespace-only input counts as empty.
    title = title.strip()
    abstract = abstract.strip()
    if title or abstract:
        st.success(f"✅ Title: {title}")
        st.info(f"📑 Abstract: {abstract}")
        themes, probs = inference(title, abstract, threshold)
        st.subheader("Predicted Themes:")
        for theme, prob in zip(themes, probs):
            st.write(f"**{theme}** — {prob:.4f}")
    else:
        st.warning("❌ Please fill in at least one of the fields.")