HoiAlice commited on
Commit
f4962f2
·
1 Parent(s): fb4c764
Files changed (2) hide show
  1. app.py +2 -3
  2. requirements.txt +1 -2
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
3
- from typing import List, Dict
4
  import torch
5
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -13,7 +12,7 @@ def load_model():
13
  inference_model = AutoModelForSequenceClassification.from_pretrained(model_path)
14
  return pipeline("text-classification", model=inference_model, tokenizer=inference_tokenizer, top_k=None)
15
 
16
- def top_pct(preds: List[Dict], threshold: float = 0.95) -> List[Dict]:
17
  """Возвращает топ предсказаний, пока их суммарная вероятность не превысит threshold"""
18
  if not preds:
19
  return []
@@ -25,7 +24,7 @@ def top_pct(preds: List[Dict], threshold: float = 0.95) -> List[Dict]:
25
  break
26
  return preds[:(i+1)]
27
 
28
- def format_predictions(preds: List[Dict]) -> str:
29
  """Форматирует предсказания для вывода"""
30
  if not preds:
31
  return "Нет результатов"
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 
3
  import torch
4
 
5
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
12
  inference_model = AutoModelForSequenceClassification.from_pretrained(model_path)
13
  return pipeline("text-classification", model=inference_model, tokenizer=inference_tokenizer, top_k=None)
14
 
15
+ def top_pct(preds, threshold: float = 0.95):
16
  """Возвращает топ предсказаний, пока их суммарная вероятность не превысит threshold"""
17
  if not preds:
18
  return []
 
24
  break
25
  return preds[:(i+1)]
26
 
27
+ def format_predictions(preds) -> str:
28
  """Форматирует предсказания для вывода"""
29
  if not preds:
30
  return "Нет результатов"
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  transformers>=4.40.0
2
  torch>=2.2.0
3
  streamlit>=1.32.0
4
- accelerate>=0.29.0
5
- typing>=4.9.0
 
1
  transformers>=4.40.0
2
  torch>=2.2.0
3
  streamlit>=1.32.0
4
+ accelerate>=0.29.0