from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import asyncio

# ✅ Set a safe and writable HF cache directory
os.environ["HF_HOME"] = "./hf_home"
os.makedirs(os.environ["HF_HOME"], exist_ok=True)

# ✅ Model and tokenizer (only loaded once)
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

# ✅ Set device (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# ✅ FastAPI app
app = FastAPI()

# ✅ CORS settings
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ✅ Request schema
class Question(BaseModel):
    question: str

# ✅ System prompt
SYSTEM_PROMPT = (
    "You are Orion, an intelligent AI assistant created by Abdullah Ali, "
    "a 13-year-old from Lahore. Respond kindly and wisely."
)

# ✅ Streaming generator
async def generate_response_chunks(prompt: str):
    qwen_prompt = (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{prompt}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )

    # Tokenize prompt
    inputs = tokenizer(qwen_prompt, return_tensors="pt").to(device)

    # Generate output (a blocking call; fine for single requests, but it will
    # hold up the event loop under concurrent load)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on "<|im_start|>assistant\n" does not work here, because
    # skip_special_tokens=True strips those markers and the prompt would leak
    # into the reply.
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    reply = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Yield chunks word by word (simulating a stream)
    for word in reply.split():
        yield word + " "
        await asyncio.sleep(0.01)  # slight delay for streaming effect

# ✅ POST endpoint
@app.post("/ask")
async def ask(question: Question):
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain",
    )
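
# ------------------------------------------------------------------
# Usage sketch (an assumption, not part of the app): one way to consume the
# /ask stream from a separate script using `requests`. The URL and port are
# hypothetical; start the server first, e.g. `uvicorn app:app --port 8000`,
# and save the snippet below as its own file (e.g. client.py) before running.
#
#   import requests
#
#   with requests.post(
#       "http://localhost:8000/ask",
#       json={"question": "Who created you?"},
#       stream=True,  # read the plain-text body as chunks arrive
#   ) as resp:
#       resp.raise_for_status()
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)
#       print()
# ------------------------------------------------------------------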