Update app.py
Browse files
app.py
CHANGED
@@ -68,6 +68,7 @@ MODEL_ACTUAL_PATHS = {
|
|
68 |
def clear_outputs_action():
|
69 |
return None, None
|
70 |
|
|
|
71 |
def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_status):
|
72 |
global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting
|
73 |
|
@@ -77,42 +78,42 @@ def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_st
|
|
77 |
CURRENT_MODEL_PATH = model_path_to_load
|
78 |
|
79 |
status_msg_parts = [f"Loading '{model_display_name_for_status}'..."]
|
80 |
-
try:
|
81 |
-
|
82 |
-
|
83 |
|
84 |
-
|
85 |
-
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
else:
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
if TOKENIZER.eos_token_id is not None:
|
98 |
-
TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
|
99 |
-
TOKENIZER.pad_token = TOKENIZER.eos_token
|
100 |
-
status_msg_parts.append(f"Set pad_token_id to eos_token_id ({TOKENIZER.eos_token_id}).")
|
101 |
-
else:
|
102 |
-
status_msg_parts.append("Warning: pad_token_id is None and no eos_token_id.")
|
103 |
-
|
104 |
-
if TOKENIZER.eos_token_id is None: # Important for cleaning up output in visualization
|
105 |
-
status_msg_parts.append("Warning: tokenizer.eos_token_id is None. EOS cleanup might not work.")
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
except Exception as e:
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
|
117 |
def handle_model_selection_change(selected_model_name_ui):
|
118 |
if "coming soon" in selected_model_name_ui.lower():
|
|
|
68 |
def clear_outputs_action():
|
69 |
return None, None
|
70 |
|
71 |
+
@spaces.GPU
|
72 |
def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_status):
|
73 |
global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting
|
74 |
|
|
|
78 |
CURRENT_MODEL_PATH = model_path_to_load
|
79 |
|
80 |
status_msg_parts = [f"Loading '{model_display_name_for_status}'..."]
|
81 |
+
# try:
|
82 |
+
TOKENIZER = AutoTokenizer.from_pretrained(model_path_to_load, trust_remote_code=True)
|
83 |
+
status_msg_parts.append(f"Tokenizer for '{model_display_name_for_status}' loaded.")
|
84 |
|
85 |
+
MODEL = MMadaModelLM.from_pretrained(model_path_to_load, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
|
86 |
+
status_msg_parts.append(f"Model '{model_display_name_for_status}' loaded to {DEVICE}.")
|
87 |
|
88 |
+
uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
|
89 |
+
|
90 |
+
if hasattr(TOKENIZER, 'mask_token_id') and TOKENIZER.mask_token_id is not None:
|
91 |
+
MASK_ID = TOKENIZER.mask_token_id
|
92 |
+
status_msg_parts.append(f"Using MASK_ID from tokenizer: {MASK_ID}.")
|
93 |
+
else:
|
94 |
+
MASK_ID = 126336
|
95 |
+
status_msg_parts.append(f"Using default MASK_ID: {MASK_ID}.")
|
96 |
+
|
97 |
+
if TOKENIZER.pad_token_id is None:
|
98 |
+
if TOKENIZER.eos_token_id is not None:
|
99 |
+
TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
|
100 |
+
TOKENIZER.pad_token = TOKENIZER.eos_token
|
101 |
+
status_msg_parts.append(f"Set pad_token_id to eos_token_id ({TOKENIZER.eos_token_id}).")
|
102 |
else:
|
103 |
+
status_msg_parts.append("Warning: pad_token_id is None and no eos_token_id.")
|
104 |
+
|
105 |
+
if TOKENIZER.eos_token_id is None: # Important for cleaning up output in visualization
|
106 |
+
status_msg_parts.append("Warning: tokenizer.eos_token_id is None. EOS cleanup might not work.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
+
TOKENIZER.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"
|
109 |
+
|
110 |
+
return " ".join(status_msg_parts)
|
111 |
+
# except Exception as e:
|
112 |
+
# MODEL = None
|
113 |
+
# TOKENIZER = None
|
114 |
+
# MASK_ID = None
|
115 |
+
# CURRENT_MODEL_PATH = None
|
116 |
+
# return f"Error loading model '{model_display_name_for_status}': {str(e)}"
|
117 |
|
118 |
def handle_model_selection_change(selected_model_name_ui):
|
119 |
if "coming soon" in selected_model_name_ui.lower():
|