yakine commited on
Commit
a686c13
·
verified ·
1 Parent(s): 8628226

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -35,13 +35,14 @@ model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
35
  text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
36
 
37
  # Load the Llama-3 model and tokenizer once during startup
 
38
  tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", token=hf_token)
39
  model_llama = AutoModelForCausalLM.from_pretrained(
40
  "meta-llama/Meta-Llama-3-8B",
41
  torch_dtype='float16',
42
  device_map='auto',
43
  token=hf_token
44
- )
45
 
46
  # Define your prompt template
47
  prompt_template = """...""" # Your existing prompt template here
 
35
  text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
36
 
37
  # Load the Llama-3 model and tokenizer once during startup
38
+ device = "cuda" if torch.cuda.is_available() else "cpu"
39
  tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", token=hf_token)
40
  model_llama = AutoModelForCausalLM.from_pretrained(
41
  "meta-llama/Meta-Llama-3-8B",
42
  torch_dtype='float16',
43
  device_map='auto',
44
  token=hf_token
45
+ ).to(device)
46
 
47
  # Define your prompt template
48
  prompt_template = """...""" # Your existing prompt template here