# aiapp/main.py
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import torch
import asyncio
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI app
app = FastAPI()

# CORS middleware (for frontend access)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update to a specific frontend URL in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body model
class Question(BaseModel):
    question: str
# Load the model and tokenizer
model_name = "Qwen/Qwen2.5-7B-Instruct"
try:
    logger.info(f"Loading model {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise
async def generate_response_chunks(prompt: str):
    try:
        # Build the chat prompt with a system persona and the user's question
        messages = [
            {"role": "system", "content": "You are Orion AI assistant created by Abdullah Ali, who is very intelligent, 13 years old, and lives in Lahore."},
            {"role": "user", "content": prompt},
        ]
        inputs = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)

        # model.generate() does not accept a `streaming=True` flag; instead,
        # TextIteratorStreamer yields decoded text chunks while generation runs
        # in a background thread.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs=inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Relay decoded chunks to the client as they arrive
        for token_text in streamer:
            if token_text:
                yield token_text
                await asyncio.sleep(0.01)  # Control streaming speed
        thread.join()
        logger.info("Streaming completed.")
    except Exception as e:
        logger.error(f"Error during generation: {e}")
        yield f"Error occurred: {e}"
@app.post("/ask")
async def ask(question: Question):
    logger.info(f"Received question: {question.question}")
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain",
    )
@app.get("/")
async def root():
    return {"message": "Orion AI Chat API is running!"}