import os

# ✅ Point all Hugging Face cache directories to a writable location.
# These environment variables must be set before transformers is imported,
# because the library resolves its cache paths at import time.
cache_dir = "/tmp/hf_home"
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir

# ✅ Create the cache directory with proper permissions
os.makedirs(cache_dir, exist_ok=True)
os.chmod(cache_dir, 0o777)  # Make writable by all

import asyncio

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# ✅ Load model and tokenizer
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, cache_dir=cache_dir)
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# ✅ Use CUDA if available and put the model in inference mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# ✅ Initialize FastAPI
app = FastAPI()

# ✅ Enable CORS so the API can be called from any origin
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ✅ Input data model
class Question(BaseModel):
    question: str

# ✅ Instructional system prompt
SYSTEM_PROMPT = (
    "You are Orion, an intelligent AI assistant created by Abdullah Ali, "
    "a 13-year-old from Lahore. Respond kindly and wisely."
)

# ✅ Streaming response generator
async def generate_response_chunks(prompt: str):
    # Build the prompt in the ChatML format that Qwen2.5-Instruct expects
    qwen_prompt = (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{prompt}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    inputs = tokenizer(qwen_prompt, return_tensors="pt").to(device)

    # This call blocks until the full reply has been generated;
    # the text is then streamed back to the client word by word.
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens. Splitting the full decoded text on
    # "<|im_start|>assistant" does not work because skip_special_tokens=True strips
    # the ChatML markers before the split, so the prompt would leak into the reply.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    for word in reply.split():
        yield word + " "
        await asyncio.sleep(0.01)

# ✅ API route
@app.post("/ask")
async def ask(question: Question):
    return StreamingResponse(generate_response_chunks(question.question), media_type="text/plain")
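
# Example usage — a minimal sketch, assuming this file is saved as app.py and that
# uvicorn and curl are available; port 7860 is only the Hugging Face Spaces default
# and can be any free port:
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   curl -N -X POST http://localhost:7860/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Who are you?"}'
#
# The -N flag disables curl's output buffering so the word-by-word stream is
# visible as it arrives instead of appearing all at once.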