File size: 2,205 Bytes
6275495
e6961b0
6275495
 
 
52ae47c
f478cdc
af4b77b
db606bb
af4b77b
9113cb8
b10ba12
80ceb8c
6275495
9c601ea
 
9d9c29a
db606bb
 
 
 
 
af5c917
7e8d01b
6275495
db606bb
af4b77b
c263659
db606bb
 
d334b30
db606bb
5e09a54
 
24fb973
c263659
af4b77b
52ae47c
af4b77b
52ae47c
 
5e09a54
52ae47c
af4b77b
6275495
 
4436ef3
 
 
 
52ae47c
 
4436ef3
 
 
52ae47c
 
4436ef3
 
3bb8597
4436ef3
3bb8597
4436ef3
 
3bb8597
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
from fastapi import FastAPI
import uvicorn

# === Модель ===
model_id = "sberbank-ai/rugpt3medium_based_on_gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

context = (
    "Университет Иннополис был основан в 2012 году. "
    "Это современный вуз в России, специализирующийся на IT и робототехнике, "
    "расположенный в городе Иннополис, Татарстан.\n"
)

def respond(message, history=None):
    prompt = f"Прочитай текст и ответь на вопрос:\n\n{context}\n\nВопрос: {message}\nОтвет:"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=100,
            temperature=0.8,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    if "Ответ:" in output:
        answer = output.split("Ответ:")[-1].strip()
    else:
        answer = output[len(prompt):].strip()

    return answer

# === Gradio интерфейс (обязательно должен быть `demo` для Hugging Face Spaces) ===
demo = gr.ChatInterface(fn=respond, title="Иннополис Бот")

# === FastAPI (опционально, если нужен API) ===
app = FastAPI()

@app.get("/health")
def health_check():
    return {"status": "OK"}

@app.post("/ask")
async def ask(question: str):
    return {"answer": respond(question)}

# === Если запускаем локально (не в Spaces) ===
if __name__ == "__main__":
    # Для локального теста с API
    app = gr.mount_gradio_app(app, demo, path="/")
    uvicorn.run(app, host="0.0.0.0", port=8000)