prouditalian commited on
Commit
5b92fa1
·
1 Parent(s): 8efe9c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
- from transformers import pipeline
4
 
5
  model_name = "ai-forever/mGPT"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -8,14 +8,14 @@ model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
 
10
  # Декоратор @st.cache говорит Streamlit, что модель нужно загрузить только один раз, чтобы избежать утечек памяти
11
- @st.cache_resource
12
  # загружает модель
13
- def load_model():
14
- return pipeline("text-generation", model=model, tokenizer=tokenizer)
15
 
16
 
17
  # Загружаем предварительно обученную модель
18
- answer = load_model()
19
 
20
  # Выводим заголовок страницы
21
  st.title("Помощник студента")
@@ -29,5 +29,17 @@ button = st.button('Получить ответ')
29
 
30
  # Выводим результат по нажатию кнопки
31
  if button:
 
 
 
 
 
 
 
 
 
 
 
 
32
  st.subheader("Вот мой ответ:")
33
- st.write(answer(text)[0]["generated_text"])
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ # from transformers import pipeline
4
 
5
  model_name = "ai-forever/mGPT"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
8
 
9
 
10
  # Декоратор @st.cache говорит Streamlit, что модель нужно загрузить только один раз, чтобы избежать утечек памяти
11
+ # @st.cache_resource
12
  # загружает модель
13
+ # def load_model():
14
+ # return pipeline("text-generation", model=model, tokenizer=tokenizer)
15
 
16
 
17
  # Загружаем предварительно обученную модель
18
+ # answer = load_model()
19
 
20
  # Выводим заголовок страницы
21
  st.title("Помощник студента")
 
29
 
30
  # Выводим результат по нажатию кнопки
31
  if button:
32
+ input_ids = tokenizer.encode(text, return_tensors="pt")
33
+ out = model.generate(
34
+ input_ids,
35
+ min_length=100,
36
+ max_length=100,
37
+ eos_token_id=5,
38
+ pad_token=1,
39
+ top_k=10,
40
+ top_p=0.0,
41
+ no_repeat_ngram_size=5
42
+ )
43
+ generated_text = list(map(tokenizer.decode, out))[0]
44
  st.subheader("Вот мой ответ:")
45
+ st.write(generated_text)