cody82 committed on
Commit
9113cb8
·
verified ·
1 Parent(s): 7bf924e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -26
app.py CHANGED
@@ -1,15 +1,13 @@
1
- import os
2
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" # отключаем нестабильную загрузку
3
-
4
- import torch
5
- import gradio as gr
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
7
 
8
- model_id = "sberbank-ai/rugpt3medium_based_on_gpt2"
9
 
 
10
  tokenizer = AutoTokenizer.from_pretrained(model_id)
11
  model = AutoModelForCausalLM.from_pretrained(model_id)
12
-
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  model.to(device)
15
 
@@ -19,9 +17,12 @@ context = (
19
  "расположенный в городе Иннополис, Татарстан.\n"
20
  )
21
 
22
- def respond(message, history=None):
23
- prompt = f"Прочитай текст и ответь на вопрос:\n\n{context}\n\nВопрос: {message}\nОтвет:"
24
 
 
 
 
25
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
26
 
27
  with torch.no_grad():
@@ -34,22 +35,10 @@ def respond(message, history=None):
34
  pad_token_id=tokenizer.eos_token_id
35
  )
36
 
37
- full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
38
-
39
- # Извлекаем только текст после "Ответ:"
40
- if "Ответ:" in full_output:
41
- answer = full_output.split("Ответ:")[-1].strip()
42
  else:
43
- answer = full_output[len(prompt):].strip()
44
-
45
- return answer
46
-
47
- iface = gr.ChatInterface(
48
- fn=respond,
49
- title="Бот об Университете Иннополис",
50
- chatbot=gr.Chatbot(label="Диалог"),
51
- textbox=gr.Textbox(placeholder="Задай вопрос на русском...", label="Твой вопрос")
52
- )
53
 
54
- if __name__ == "__main__":
55
- iface.launch()
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
 
 
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import torch
5
 
6
+ app = FastAPI()
7
 
8
+ model_id = "sberbank-ai/rugpt3medium_based_on_gpt2"
9
  tokenizer = AutoTokenizer.from_pretrained(model_id)
10
  model = AutoModelForCausalLM.from_pretrained(model_id)
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  model.to(device)
13
 
 
17
  "расположенный в городе Иннополис, Татарстан.\n"
18
  )
19
 
20
+ class Question(BaseModel):
21
+ message: str
22
 
23
+ @app.post("/ask")
24
+ def ask(q: Question):
25
+ prompt = f"{context}\nВопрос: {q.message}\nОтвет:"
26
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
27
 
28
  with torch.no_grad():
 
35
  pad_token_id=tokenizer.eos_token_id
36
  )
37
 
38
+ output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
39
+ if "Ответ:" in output:
40
+ answer = output.split("Ответ:")[-1].strip()
 
 
41
  else:
42
+ answer = output.strip()
 
 
 
 
 
 
 
 
 
43
 
44
+ return {"answer": answer}