from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import asyncio

# Set cache directories (note: transformers may read these env vars at import
# time, so the explicit cache_dir arguments below are what reliably take effect)
cache_dir = "/tmp/hf_home"
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
os.makedirs(cache_dir, exist_ok=True)
os.chmod(cache_dir, 0o777)

# Load model and tokenizer
model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)

# Set pad token if not defined
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Initialize FastAPI
app = FastAPI()

# Enable CORS (note: browsers reject a wildcard origin combined with
# credentials; list explicit origins if the client needs to send cookies)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class Question(BaseModel):
    question: str

# DialoGPT is not instruction-tuned, so this "system prompt" is treated as
# ordinary dialogue text rather than an instruction the model will follow
SYSTEM_PROMPT = "You are a helpful, professional, and highly persuasive sales assistant..."

# Shared across all requests; fine for a single-user demo, but concurrent
# clients will see each other's conversation history
chat_history_ids = None

async def generate_response_chunks(prompt: str):
    global chat_history_ids
    # Combine system prompt and user input
    input_text = SYSTEM_PROMPT + "\nUser: " + prompt + "\nBot:"
    new_input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    # Create attention mask (handle case where pad_token_id might be None)
    attention_mask = torch.ones_like(new_input_ids)
    if chat_history_ids is not None:
        input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
        attention_mask = torch.cat([
            torch.ones_like(chat_history_ids),
            attention_mask
        ], dim=-1)
    else:
        input_ids = new_input_ids
    # Generate response
    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=200,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id
    )
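    # Note: model.generate() is a blocking call, so it stalls the event loop
    # while it runs. A minimal sketch of one fix (an assumption, not part of
    # the original code) is to offload it to a worker thread:
    #
    #   output_ids = await asyncio.to_thread(
    #       model.generate,
    #       input_ids,
    #       attention_mask=attention_mask,
    #       max_new_tokens=200,
    #       do_sample=True,
    #       top_p=0.9,
    #       temperature=0.7,
    #       pad_token_id=tokenizer.eos_token_id,
    #   )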
    # Update chat history (grows without bound; DialoGPT's context window is
    # 1024 tokens, so long sessions will eventually overflow it)
    chat_history_ids = output_ids
    # Decode only the new tokens
    response = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
    # Stream the response word by word (generation has already finished at
    # this point, so this simulates streaming rather than streaming live)
    for word in response.split():
        yield word + " "
        await asyncio.sleep(0.03)

# The route decorator was missing, so this endpoint was never registered;
# the "/ask" path is an assumption, adjust it to match your client
@app.post("/ask")
async def ask(question: Question):
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain"
    )
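
# --- Example client (a sketch, run as a separate script) ---
# Assumes the "/ask" route above and a server reachable at
# http://localhost:7860 (the usual Spaces port); both the path and the
# host/port are assumptions, so adjust them to your deployment.
#
#   import requests
#
#   with requests.post(
#       "http://localhost:7860/ask",
#       json={"question": "Why should I buy this laptop?"},
#       stream=True,
#   ) as resp:
#       resp.raise_for_status()
#       # Print chunks as they arrive from the StreamingResponse
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)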