cody82 commited on
Commit
e6961b0
·
verified ·
1 Parent(s): 7e8d01b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -1,12 +1,11 @@
1
  import os
2
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" # отключаем нестабильную загрузку
3
 
4
  import torch
5
  import gradio as gr
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
 
8
  model_id = "sberbank-ai/rugpt3medium_based_on_gpt2"
9
-
10
  tokenizer = AutoTokenizer.from_pretrained(model_id)
11
  model = AutoModelForCausalLM.from_pretrained(model_id)
12
 
@@ -21,7 +20,6 @@ context = (
21
 
22
  def respond(message, history=None):
23
  prompt = f"Прочитай текст и ответь на вопрос:\n\n{context}\n\nВопрос: {message}\nОтвет:"
24
-
25
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
26
 
27
  with torch.no_grad():
@@ -36,7 +34,6 @@ def respond(message, history=None):
36
 
37
  full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
38
 
39
- # Извлекаем только текст после "Ответ:"
40
  if "Ответ:" in full_output:
41
  answer = full_output.split("Ответ:")[-1].strip()
42
  else:
@@ -44,12 +41,25 @@ def respond(message, history=None):
44
 
45
  return answer
46
 
47
- iface = gr.ChatInterface(
 
48
  fn=respond,
49
  title="Бот об Университете Иннополис (на русском)",
50
  chatbot=gr.Chatbot(label="Диалог"),
51
  textbox=gr.Textbox(placeholder="Задай вопрос на русском...", label="Твой вопрос")
52
  )
53
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  if __name__ == "__main__":
55
- iface.launch()
 
1
  import os
2
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
3
 
4
  import torch
5
  import gradio as gr
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
 
8
  model_id = "sberbank-ai/rugpt3medium_based_on_gpt2"
 
9
  tokenizer = AutoTokenizer.from_pretrained(model_id)
10
  model = AutoModelForCausalLM.from_pretrained(model_id)
11
 
 
20
 
21
  def respond(message, history=None):
22
  prompt = f"Прочитай текст и ответь на вопрос:\n\n{context}\n\nВопрос: {message}\nОтвет:"
 
23
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
24
 
25
  with torch.no_grad():
 
34
 
35
  full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
36
 
 
37
  if "Ответ:" in full_output:
38
  answer = full_output.split("Ответ:")[-1].strip()
39
  else:
 
41
 
42
  return answer
43
 
44
+ # основной Gradio чат
45
+ chat = gr.ChatInterface(
46
  fn=respond,
47
  title="Бот об Университете Иннополис (на русском)",
48
  chatbot=gr.Chatbot(label="Диалог"),
49
  textbox=gr.Textbox(placeholder="Задай вопрос на русском...", label="Твой вопрос")
50
  )
51
 
52
+ # добавим простой API endpoint
53
+ demo = gr.Blocks()
54
+
55
+ with demo:
56
+ gr.Markdown("### Иннополис Бот + API")
57
+ chat.render()
58
+
59
+ # API endpoint
60
+ @gr.api()
61
+ def ask_api(question: str):
62
+ return {"answer": respond(question)}
63
+
64
  if __name__ == "__main__":
65
+ demo.launch()