abdullahalioo committed on
Commit 6a547e4 · verified · 1 Parent(s): 29331bd

Update main.py

Files changed (1)
  1. main.py +21 -20
main.py CHANGED
@@ -5,23 +5,25 @@ from fastapi.responses import StreamingResponse
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import os
-os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
+import asyncio
 
+# ✅ Set a safe and writable HF cache directory
+os.environ["HF_HOME"] = "./hf_home"
+os.makedirs(os.environ["HF_HOME"], exist_ok=True)
 
-
-# Load Qwen model and tokenizer (once)
+# ✅ Model and tokenizer (only loaded once)
 model_name = "Qwen/Qwen2.5-0.5B-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
 
-# Set device
-device = torch.device("cpu")  # Or "cuda" if using GPU
+# ✅ Set device (use GPU if available)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
-# FastAPI app
+# ✅ FastAPI app
 app = FastAPI()
 
-# CORS settings
+# ✅ CORS settings
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -30,26 +32,25 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Request body model
+# ✅ Request schema
 class Question(BaseModel):
     question: str
 
-# System prompt (your custom instructions)
+# ✅ System prompt
 SYSTEM_PROMPT = "You are Orion, an intelligent AI assistant created by Abdullah Ali, a 13-year-old from Lahore. Respond kindly and wisely."
 
-# Chat response generator
+# ✅ Streaming generator
 async def generate_response_chunks(prompt: str):
-    # Build prompt using Qwen's expected format
     qwen_prompt = (
         f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
         f"<|im_start|>user\n{prompt}<|im_end|>\n"
         f"<|im_start|>assistant\n"
     )
 
-    # Tokenize input
+    # Tokenize prompt
     inputs = tokenizer(qwen_prompt, return_tensors="pt").to(device)
 
-    # Generate response
+    # Generate output
     outputs = model.generate(
         **inputs,
         max_new_tokens=256,
@@ -59,16 +60,16 @@ async def generate_response_chunks(prompt: str):
         pad_token_id=tokenizer.eos_token_id
     )
 
-    # Decode and yield line by line
+    # Decode output
     full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
     reply = full_output.split("<|im_start|>assistant\n")[-1].strip()
 
-    for chunk in reply.split():
-        yield chunk + " "
+    # Yield chunks word by word (simulating stream)
+    for word in reply.split():
+        yield word + " "
+        await asyncio.sleep(0.01)  # slight delay for streaming effect
 
+# ✅ POST endpoint
 @app.post("/ask")
 async def ask(question: Question):
-    return StreamingResponse(
-        generate_response_chunks(question.question),
-        media_type="text/plain"
-    )
+    return StreamingResponse(generate_response_chunks(question.question), media_type="text/plain")
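
For reference, a minimal client sketch for exercising the streaming /ask endpoint after this change. It assumes the app is served locally (for example with `uvicorn main:app --port 8000`); the URL and question text are illustrative and not part of the commit.

import requests

# Hypothetical client for the /ask endpoint; assumes uvicorn is serving main:app on port 8000.
resp = requests.post(
    "http://localhost:8000/ask",
    json={"question": "Who created you?"},
    stream=True,  # read the plain-text body as it is produced
)
resp.raise_for_status()

# Print each chunk as it arrives instead of waiting for the full reply.
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
print()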
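
The generator in this commit produces the complete reply first and then emits it word by word with a short asyncio.sleep, so the stream is simulated. A rough sketch of a true token-streaming alternative using transformers' TextIteratorStreamer follows; this is not what the commit does, and like the committed code it still runs blocking generation inside an async route.

from threading import Thread
from transformers import TextIteratorStreamer

# Hypothetical alternative to the word-splitting loop above (not part of this commit):
# decoded text is yielded as generate() produces tokens, rather than after it finishes.
async def generate_response_chunks_true_stream(prompt: str):
    qwen_prompt = (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{prompt}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    inputs = tokenizer(qwen_prompt, return_tensors="pt").to(device)

    # The streamer exposes decoded text pieces while generation runs in a background thread.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=256,
                    pad_token_id=tokenizer.eos_token_id),
    ).start()

    # Note: this loop is still synchronous and blocks the event loop between pieces.
    for text in streamer:
        yield text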