jyo01 committed on
Commit
882f627
·
verified ·
1 Parent(s): 388937c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -106,16 +106,21 @@ def get_llm_response(prompt: str, model_name: str = "meta-llama/Llama-2-7b-chat-
106
  max_new_tokens = 1024 if is_detailed_query(prompt) else 256
107
 
108
  torch.cuda.empty_cache()
 
 
 
109
 
110
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HF_TOKEN)
111
  model = AutoModelForCausalLM.from_pretrained(
112
  model_name,
113
  device_map="auto",
 
114
  use_safetensors=False,
115
  trust_remote_code=True,
116
  torch_dtype=torch.float16,
117
  token=HF_TOKEN
118
  )
 
119
 
120
  text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
121
  outputs = text_gen(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7)
 
106
  max_new_tokens = 1024 if is_detailed_query(prompt) else 256
107
 
108
  torch.cuda.empty_cache()
109
+
110
+ if not os.path.exists("offload"):
111
+ os.makedirs("offload")
112
 
113
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HF_TOKEN)
114
  model = AutoModelForCausalLM.from_pretrained(
115
  model_name,
116
  device_map="auto",
117
+ offload_folder="offload", # Specify the folder where weights will be offloaded
118
  use_safetensors=False,
119
  trust_remote_code=True,
120
  torch_dtype=torch.float16,
121
  token=HF_TOKEN
122
  )
123
+
124
 
125
  text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
126
  outputs = text_gen(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7)