arya-ai-model committed on
Commit d36dc81 · 1 Parent(s): ad04391

updated requirement

Files changed (1)
  1. model.py +14 -3
model.py CHANGED
@@ -10,16 +10,27 @@ HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 if not HF_TOKEN:
     raise ValueError("Missing Hugging Face token. Set HUGGINGFACE_TOKEN as an environment variable.")
 
-# Load tokenizer and model with authentication
+# Set device
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load tokenizer with authentication
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=HF_TOKEN, device_map="auto")
+
+# Load model with optimizations
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    token=HF_TOKEN,
+    torch_dtype=torch.float16,   # Reduce memory usage
+    low_cpu_mem_usage=True,      # Optimize loading
+    device_map="auto",           # Automatic device placement
+    offload_folder="offload"     # Offload to disk if needed
+).to(device)
 
 def generate_code(prompt: str, max_tokens: int = 256):
     """Generates code based on the input prompt."""
     if not prompt.strip():
         return "Error: Empty prompt provided."
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
     output = model.generate(**inputs, max_new_tokens=max_tokens)
     return tokenizer.decode(output[0], skip_special_tokens=True)
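The hunk starts at line 10, so model.py's preamble (the imports and the MODEL_NAME definition) is not shown in this commit. A minimal usage sketch, assuming HUGGINGFACE_TOKEN is exported in the environment and that the preamble imports os, torch, AutoTokenizer, and AutoModelForCausalLM and sets MODEL_NAME to a causal-LM checkpoint (placeholder assumptions, not taken from the diff):

    # Hypothetical caller, not part of the commit.
    # Importing model.py runs its module-level code: it reads
    # HUGGINGFACE_TOKEN, picks the device, and loads the tokenizer and model.
    from model import generate_code

    if __name__ == "__main__":
        prompt = "Write a Python function that reverses a string."
        # generate_code tokenizes the prompt, runs model.generate with
        # max_new_tokens=128, and returns the decoded completion.
        print(generate_code(prompt, max_tokens=128))

The same call works against the pre-commit version as well; the commit only changes how the model is loaded, not the generate_code signature.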