abdullahalioo committed
Commit 6f6ae2a · verified · 1 Parent(s): 463f46a

Update main.py

Files changed (1)
  1. main.py +16 -6
main.py CHANGED
@@ -7,14 +7,24 @@ import torch
  import os
  import asyncio
 
- # ✅ Use writable temp dir for Hugging Face cache
- os.environ["HF_HOME"] = "/tmp/hf_home"
- os.makedirs(os.environ["HF_HOME"], exist_ok=True)
+ # ✅ Set all cache directories to a writable location
+ cache_dir = "/tmp/hf_home"
+ os.environ["HF_HOME"] = cache_dir
+ os.environ["TRANSFORMERS_CACHE"] = cache_dir
+ os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
+
+ # ✅ Create cache directory with proper permissions
+ os.makedirs(cache_dir, exist_ok=True)
+ os.chmod(cache_dir, 0o777)  # Make writable by all
 
  # ✅ Load model and tokenizer
  model_name = "Qwen/Qwen2.5-0.5B-Instruct"
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+ try:
+     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, cache_dir=cache_dir)
+     model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, cache_dir=cache_dir)
+ except Exception as e:
+     print(f"Error loading model: {e}")
+     raise
 
  # ✅ Use CUDA if available
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -64,4 +74,4 @@ async def generate_response_chunks(prompt: str):
  # ✅ API route
  @app.post("/ask")
  async def ask(question: Question):
-     return StreamingResponse(generate_response_chunks(question.question), media_type="text/plain")
+     return StreamingResponse(generate_response_chunks(question.question), media_type="text/plain")
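
For reference, a minimal client sketch for exercising the patched /ask route. It assumes details the diff does not show: the app is served locally (for example with uvicorn main:app on port 8000), the Question model exposes a single question field (implied by question.question in the handler), and the httpx package is available. Treat it as an illustration, not part of the commit.

import httpx

def ask(prompt: str, base_url: str = "http://localhost:8000") -> None:
    # /ask returns a text/plain StreamingResponse, so consume it incrementally.
    payload = {"question": prompt}  # field name inferred from `question.question`
    with httpx.stream("POST", f"{base_url}/ask", json=payload, timeout=None) as response:
        response.raise_for_status()
        for chunk in response.iter_text():
            print(chunk, end="", flush=True)

if __name__ == "__main__":
    ask("What is the capital of France?")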