File size: 2,733 Bytes
0218c20
03991d8
 
 
0218c20
 
29331bd
6a547e4
29331bd
2ba12d8
6f6ae2a
 
 
 
 
 
2ba12d8
29331bd
2ba12d8
62eaea3
 
 
0218c20
eb96984
 
 
 
2ba12d8
6a547e4
0218c20
03991d8
2ba12d8
03991d8
 
2ba12d8
03991d8
 
0218c20
03991d8
 
 
 
 
 
 
 
eb96984
0218c20
eb96984
62eaea3
03991d8
62eaea3
 
eb96984
 
 
 
 
 
2ba12d8
62eaea3
 
eb96984
 
 
 
62eaea3
 
 
eb96984
62eaea3
 
0bf5acd
62eaea3
0218c20
 
62eaea3
0218c20
 
62eaea3
eb96984
 
0bf5acd
eb96984
62eaea3
 
eb96984
2ba12d8
6a547e4
62eaea3
03991d8
 
 
eb96984
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import asyncio
import os
import threading

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Point the Hugging Face cache environment variables at one directory.
# NOTE(review): these variables are assigned after `transformers` has already
# been imported at the top of the file — confirm they still take effect. The
# from_pretrained calls below pass cache_dir explicitly as well.
cache_dir = "/tmp/hf_home"
os.environ.update(
    dict.fromkeys(
        ("HF_HOME", "TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE"),
        cache_dir,
    )
)

# Create the directory when it is missing, then make it readable, writable and
# traversable by every user (mode 0o777).
os.makedirs(cache_dir, exist_ok=True)
os.chmod(cache_dir, 0o777)

# Checkpoint used for both the tokenizer and the language model.
model_name = "microsoft/DialoGPT-small"

# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the tokenizer, storing downloaded files under cache_dir.
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
if tokenizer.pad_token is None:
    # No padding token is defined for this tokenizer: reuse end-of-sequence.
    tokenizer.pad_token = tokenizer.eos_token

# Load the model weights and move them to the selected device.
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
model.to(device)

# ASGI application object; routes are registered on it further down.
app = FastAPI()

# CORS policy: any origin, any method, any header, credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is the
# most permissive setting available — confirm this is intended before exposing
# the service beyond development.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Request body model accepted by the /ask endpoint.
class Question(BaseModel):
    # User message; the /ask handler forwards it to generate_response_chunks
    # as the prompt.
    question: str

# Instruction text placed in front of every user message before tokenization.
SYSTEM_PROMPT = "You are a helpful, professional, and highly persuasive sales assistant..."

# Token ids of the conversation so far (prompts plus generated replies), or
# None until the first reply has been generated. This is one module-level
# value, so every request handled by this process shares the same history.
chat_history_ids = None

# Number of tokens sampled per reply; the same amount is reserved out of the
# model's context window when the input is trimmed.
_MAX_NEW_TOKENS = 200

# Held for the whole read-generate-write cycle so that concurrent requests
# cannot interleave their updates to the shared chat_history_ids.
_history_lock = threading.Lock()


def _generate_reply(prompt: str) -> str:
    """Generate one reply for *prompt* and return it as decoded text.

    This is the blocking part of the request: it tokenizes the prompt, appends
    it to the shared conversation history, samples a continuation from the
    model and stores the result back as the new history.

    Args:
        prompt: Raw user message.

    Returns:
        The newly generated text with special tokens removed.
    """
    global chat_history_ids

    # Combine system prompt and user input.
    input_text = SYSTEM_PROMPT + "\nUser: " + prompt + "\nBot:"

    with _history_lock:
        new_input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)

        if chat_history_ids is not None:
            input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
        else:
            input_ids = new_input_ids

        # The history grows with every turn. Keep only the most recent tokens
        # so that input plus reply always fits the model's position limit;
        # otherwise generation starts failing once the limit is crossed and
        # never recovers because the stored history stays too long.
        # Falls back to 1024 when the config exposes no limit (assumed
        # GPT-2-style context size — TODO confirm if the checkpoint changes).
        context_limit = getattr(model.config, "max_position_embeddings", None) or 1024
        max_input_len = max(context_limit - _MAX_NEW_TOKENS, 1)
        if input_ids.shape[-1] > max_input_len:
            input_ids = input_ids[:, -max_input_len:]

        # No padding is present in a single sequence: attend to every token.
        attention_mask = torch.ones_like(input_ids)

        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=_MAX_NEW_TOKENS,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )

        # The output holds the input followed by the reply; it becomes the
        # history for the next turn.
        chat_history_ids = output_ids

    # Decode only the tokens that follow the input.
    return tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)


async def generate_response_chunks(prompt: str):
    """Yield the model's reply to *prompt* one word at a time.

    Generation runs in the event loop's default executor so the loop stays
    free to serve other requests while the model is working. Once the full
    reply is available it is split on whitespace and streamed word by word.

    Args:
        prompt: Raw user message.

    Yields:
        Each word of the reply followed by a single space.
    """
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, _generate_reply, prompt)

    # Stream the response with a short pause between words.
    for word in response.split():
        yield word + " "
        await asyncio.sleep(0.03)

@app.post("/ask")
async def ask(question: Question):
    # Build the async word generator for this question, then hand it to a
    # plain-text streaming response so the client receives words as they are
    # yielded.
    word_stream = generate_response_chunks(question.question)
    return StreamingResponse(word_stream, media_type="text/plain")