from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import torch
import asyncio
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI app
app = FastAPI()

# CORS middleware (for frontend access)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update to the specific frontend URL in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body model
class Question(BaseModel):
    question: str
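# Clients are expected to POST JSON matching this model, e.g.:
#   {"question": "What is the capital of Pakistan?"}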
# Load the model and tokenizer
model_name = "Qwen/Qwen2.5-7B-Instruct"
try:
    logger.info(f"Loading model {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",  # requires the `accelerate` package
        trust_remote_code=True,
    )
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise
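# Note: a 7B-parameter model in float16 needs roughly 15 GB of GPU memory
# for the weights alone; on smaller hardware, a quantized variant of the
# model may be a better fit.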
async def generate_response_chunks(prompt: str):
    try:
        # Build the chat prompt using the model's chat template
        messages = [
            {"role": "system", "content": "You are Orion AI assistant created by Abdullah Ali, who is very intelligent, 13 years old, and lives in Lahore."},
            {"role": "user", "content": prompt}
        ]
        inputs = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)

        # generate() has no `streaming=True` flag; the supported pattern is to
        # run generation in a background thread and read decoded text off a
        # TextIteratorStreamer as it is produced
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs=inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield text chunks as they arrive
        for token_text in streamer:
            if token_text:
                yield token_text
                await asyncio.sleep(0.01)  # Control streaming speed
        logger.info("Streaming completed.")
    except Exception as e:
        logger.error(f"Error during generation: {e}")
        yield f"Error occurred: {e}"
# Chat endpoint: streams the model's reply as plain text
# (route path assumed from the function name)
@app.post("/ask")
async def ask(question: Question):
    logger.info(f"Received question: {question.question}")
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain"
    )
# Health-check endpoint
@app.get("/")
async def root():
    return {"message": "Orion AI Chat API is running!"}