lucas-ventura committed
Commit a762f5a · verified · 1 Parent(s): dd653bd

Update app.py

Files changed (1)
  app.py  +23 -11
app.py CHANGED
@@ -11,7 +11,7 @@ from src.data.single_video import SingleVideo
 from src.data.utils_asr import PromptASR
 from src.models.llama_inference import inference
 from src.test.vidchapters import get_chapters
-from tools.download.models import download_model
+from tools.download.models import download_base_model, download_model
 
 # Set up proxies
 # from urllib.request import getproxies
@@ -26,7 +26,7 @@ tokenizer = None
 current_peft_model = None
 inference_model = None
 
-LLAMA_CKPT_PATH = "meta-llama/Llama-3.1-8B-Instruct"
+LLAMA_CKPT_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 
 
 def load_base_model():
@@ -34,16 +34,28 @@ def load_base_model():
     global base_model, tokenizer
 
     if base_model is None:
-        print(f"Loading base model: {LLAMA_CKPT_PATH}")
-        base_model = load_model_llamarecipes(
-            model_name=LLAMA_CKPT_PATH,
-            device_map="auto",
-            quantization=None,
-            use_fast_kernels=True,
-        )
-        base_model.eval()
-
-        tokenizer = AutoTokenizer.from_pretrained(LLAMA_CKPT_PATH)
+        try:
+            print(f"Loading base model: {LLAMA_CKPT_PATH}")
+            base_model = load_model_llamarecipes(
+                model_name=LLAMA_CKPT_PATH,
+                device_map="auto",
+                quantization=None,
+                use_fast_kernels=True,
+            )
+            tokenizer = AutoTokenizer.from_pretrained(LLAMA_CKPT_PATH)
+        except Exception as e:
+            # Try to get the local path using the download function
+            model_path = download_base_model(LLAMA_CKPT_PATH, local_dir=".")
+            print(f"Model path: {model_path}")
+            base_model = load_model_llamarecipes(
+                model_name=model_path,
+                device_map="auto",
+                quantization=None,
+                use_fast_kernels=True,
+            )
+            tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+        base_model.eval()
         tokenizer.pad_token = tokenizer.eos_token
 
         print("Base model loaded successfully")