Spaces:
Runtime error
Runtime error
File size: 2,733 Bytes
0218c20 03991d8 0218c20 29331bd 6a547e4 29331bd 2ba12d8 6f6ae2a 2ba12d8 29331bd 2ba12d8 62eaea3 0218c20 eb96984 2ba12d8 6a547e4 0218c20 03991d8 2ba12d8 03991d8 2ba12d8 03991d8 0218c20 03991d8 eb96984 0218c20 eb96984 62eaea3 03991d8 62eaea3 eb96984 2ba12d8 62eaea3 eb96984 62eaea3 eb96984 62eaea3 0bf5acd 62eaea3 0218c20 62eaea3 0218c20 62eaea3 eb96984 0bf5acd eb96984 62eaea3 eb96984 2ba12d8 6a547e4 62eaea3 03991d8 eb96984 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import asyncio
# Redirect every Hugging Face cache location into a writable temp directory
# (hosted runtimes often mount the default home read-only).
cache_dir = "/tmp/hf_home"
for _env_key in ("HF_HOME", "TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE"):
    os.environ[_env_key] = cache_dir
os.makedirs(cache_dir, exist_ok=True)
os.chmod(cache_dir, 0o777)  # world-writable so any runtime user can populate it
# Load model and tokenizer.
# DialoGPT-small is a lightweight conversational GPT-2 variant; weights are
# downloaded into cache_dir on first run (network access required at startup).
model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
# GPT-2-family tokenizers ship without a pad token; reuse EOS so any code
# that reads pad_token / pad_token_id does not hit None.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Select the compute device: prefer CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # moves the model's weights onto the chosen device in place
# Initialize FastAPI application instance.
app = FastAPI()
# Enable CORS for all origins/methods/headers.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — tighten allow_origins to the actual frontend host for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Question(BaseModel):
    """Request body for /ask: a single free-text user question."""
    question: str  # raw user prompt to be answered by the model


# System prompt prepended to every user turn to steer the model's persona.
SYSTEM_PROMPT = "You are a helpful, professional, and highly persuasive sales assistant..."
# Running token history of the conversation (None until the first request).
# NOTE(review): this module-global state is shared by ALL clients of the process.
chat_history_ids = None
# DialoGPT's GPT-2 backbone has a 1024-token context window; keep the stored
# conversation comfortably below it so generate() never overflows the model.
MAX_CONTEXT_TOKENS = 768


async def generate_response_chunks(prompt: str):
    """Yield the model's reply to *prompt* word by word (async generator).

    Builds the input from SYSTEM_PROMPT plus the user turn, appends it to the
    module-global chat history, generates up to 200 new sampled tokens,
    updates the history, and streams the decoded reply one word at a time.

    NOTE(review): chat_history_ids is process-global, so concurrent clients
    share (and can corrupt) each other's conversation state — confirm whether
    per-session history is intended.
    """
    global chat_history_ids
    # Combine system prompt and user input into a single dialogue turn.
    input_text = SYSTEM_PROMPT + "\nUser: " + prompt + "\nBot:"
    new_input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    # Attention mask: all real tokens (no padding is ever introduced here).
    attention_mask = torch.ones_like(new_input_ids)
    if chat_history_ids is not None:
        input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
        attention_mask = torch.cat([
            torch.ones_like(chat_history_ids),
            attention_mask
        ], dim=-1)
    else:
        input_ids = new_input_ids
    # Fix: cap the context so repeated turns cannot exceed the model's
    # 1024-token window (previously the history grew without bound and
    # generation would eventually fail or silently truncate).
    if input_ids.shape[-1] > MAX_CONTEXT_TOKENS:
        input_ids = input_ids[:, -MAX_CONTEXT_TOKENS:]
        attention_mask = attention_mask[:, -MAX_CONTEXT_TOKENS:]

    def _generate():
        # no_grad: inference only — skip building autograd graphs.
        with torch.no_grad():
            return model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=200,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                pad_token_id=tokenizer.eos_token_id
            )

    # Fix: run the seconds-long, compute-bound generation in a worker thread
    # so it does not block the asyncio event loop for every other request.
    output_ids = await asyncio.to_thread(_generate)
    # Update chat history with the full (prompt + reply) token sequence.
    chat_history_ids = output_ids
    # Decode only the tokens generated after the prompt.
    response = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
    # Stream the response word by word with a small delay for a typing effect.
    for word in response.split():
        yield word + " "
        await asyncio.sleep(0.03)
@app.post("/ask")
async def ask(question: Question):
    """POST /ask — stream the assistant's reply as plain text chunks."""
    token_stream = generate_response_chunks(question.question)
    return StreamingResponse(token_stream, media_type="text/plain")