Update app.py
app.py CHANGED

@@ -11,7 +11,7 @@ from src.data.single_video import SingleVideo
 from src.data.utils_asr import PromptASR
 from src.models.llama_inference import inference
 from src.test.vidchapters import get_chapters
-from tools.download.models import download_model
+from tools.download.models import download_base_model, download_model
 
 # Set up proxies
 # from urllib.request import getproxies
@@ -26,7 +26,7 @@ tokenizer = None
 current_peft_model = None
 inference_model = None
 
-LLAMA_CKPT_PATH = "meta-llama/Llama-3.1-8B-Instruct"
+LLAMA_CKPT_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 
 
 def load_base_model():
@@ -34,16 +34,28 @@ def load_base_model():
     global base_model, tokenizer
 
     if base_model is None:
-        print(f"Loading base model: {LLAMA_CKPT_PATH}")
-        base_model = load_model_llamarecipes(
-            model_name=LLAMA_CKPT_PATH,
-            device_map="auto",
-            quantization=None,
-            use_fast_kernels=True,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(LLAMA_CKPT_PATH)
+        try:
+            print(f"Loading base model: {LLAMA_CKPT_PATH}")
+            base_model = load_model_llamarecipes(
+                model_name=LLAMA_CKPT_PATH,
+                device_map="auto",
+                quantization=None,
+                use_fast_kernels=True,
+            )
+            tokenizer = AutoTokenizer.from_pretrained(LLAMA_CKPT_PATH)
+        except Exception as e:
+            # Try to get the local path using the download function
+            model_path = download_base_model(LLAMA_CKPT_PATH, local_dir=".")
+            print(f"Model path: {model_path}")
+            base_model = load_model_llamarecipes(
+                model_name=model_path,
+                device_map="auto",
+                quantization=None,
+                use_fast_kernels=True,
+            )
+            tokenizer = AutoTokenizer.from_pretrained(model_path)
 
-        base_model.eval()
+        base_model.eval()
         tokenizer.pad_token = tokenizer.eos_token
 
         print("Base model loaded successfully")
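
The change wraps the original direct Hub load in a try/except: if loading meta-llama/Meta-Llama-3.1-8B-Instruct by repo id fails, the Space now downloads the checkpoint into a local directory via the newly imported download_base_model and loads from that path, calling base_model.eval() before inference either way. Below is a minimal self-contained sketch of the same load-with-fallback pattern, using plain transformers and huggingface_hub in place of the Space's own load_model_llamarecipes helper; the function name and structure are illustrative, not the Space's actual code.

from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"


def load_with_fallback(model_id: str = MODEL_ID):
    try:
        # Happy path: let transformers resolve the Hub repo id directly.
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception:
        # Fallback: materialize the checkpoint on local disk first,
        # then load from the returned directory.
        local_path = snapshot_download(repo_id=model_id, local_dir=".")
        model = AutoModelForCausalLM.from_pretrained(local_path, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(local_path)

    model.eval()  # inference only, matching the Space
    tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship no pad token
    return model, tokenizer

One design note: the diff binds the exception as e but never uses it, and except Exception will also swallow authentication errors, which matter for this gated checkpoint; catching a narrower error and logging e would make the fallback easier to debug.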
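download_base_model lives in the Space's tools.download.models module and is not shown in this diff; all the call site reveals is that it takes a Hub repo id plus a local_dir and returns a local path. A hypothetical sketch, assuming the helper is a thin wrapper over huggingface_hub.snapshot_download (a guess at its behavior, not the Space's source):

from huggingface_hub import snapshot_download


def download_base_model(repo_id: str, local_dir: str = ".") -> str:
    # Hypothetical stand-in: pull every file in the model repo (weights,
    # config, tokenizer) into local_dir and return that directory's path.
    return snapshot_download(repo_id=repo_id, local_dir=local_dir)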