UnhurriedDawn committed on
Commit
12437ad
·
1 Parent(s): c1e0817
Files changed (1)
  1. app.py +41 -26
app.py CHANGED
@@ -10,7 +10,7 @@ import time
 import os
 from typing import List, Dict, Optional, Tuple, Iterator, Set
 import gradio as gr
-import spaces  # new: import the spaces module
+import spaces  # import the spaces module
 
 # Suppress some Hugging Face warnings
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -24,6 +24,7 @@ from model_cache.llada.configuration_llada import LLaDAConfig
 def set_seed(seed):
     torch.manual_seed(seed); random.seed(seed); np.random.seed(seed);
     if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed); torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False
+
 def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
     if dtype is None: dtype = torch.bfloat16
     attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)
@@ -38,12 +39,14 @@ def create_full_block_attention_mask(prompt_length, max_length, block_size, devi
         attention_mask[:, :, block_start:block_end, prev_start:prev_end] = 0
         attention_mask[:, :, block_start:block_end, block_start:block_end] = 0
     return attention_mask
+
 def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
     end_pos = start_pos + input_length; total_length = cache_length + input_length
     extracted_mask = torch.full((1, 1, input_length, total_length), -torch.inf, device=full_mask.device, dtype=full_mask.dtype)
     extracted_mask[:, :, :, :cache_length] = full_mask[:, :, start_pos:end_pos, :cache_length]
     extracted_mask[:, :, :, cache_length:] = full_mask[:, :, start_pos:end_pos, start_pos:end_pos]
     return extracted_mask
+
 def top_p_logits(logits, top_p=None):
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@@ -54,11 +57,13 @@ def top_p_logits(logits, top_p=None):
     mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
     logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
     return logits
+
 def top_k_logits(logits, top_k=None):
     top_k = min(top_k, logits.size(-1))
     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
     logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
     return logits
+
 def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
     if temperature > 0: logits = logits / temperature
     if top_p is not None and top_p < 1: logits = top_p_logits(logits, top_p)
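
A quick way to sanity-check the sampling helpers above in isolation. This is an illustrative, self-contained sketch of what `top_k_logits` does with toy values; it is not part of the app:

    import torch
    import torch.nn.functional as F

    # Toy logits over a 6-token vocabulary (batch of 1).
    logits = torch.tensor([[2.0, 1.5, 0.5, 0.1, -1.0, -2.0]])

    # Same rule as top_k_logits(logits, top_k=3): everything below the
    # 3rd-largest logit is pushed to the dtype minimum.
    kth = torch.topk(logits, 3)[0][..., -1, None]
    filtered = logits.masked_fill(logits < kth, torch.finfo(logits.dtype).min)

    # After softmax the dropped tokens carry (numerically) zero probability,
    # so sampling can only return one of the top-3 indices.
    probs = F.softmax(filtered, dim=-1)
    token = torch.multinomial(probs, num_samples=1)
    print(token.item())  # always 0, 1, or 2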
@@ -178,39 +183,46 @@ class DreamLoRAInference:
 
     def __init__(self, **kwargs):
         print("Initializing DreamLoRAInference...")
+        # Only store the config here; do not load the model yet
+        self.config = kwargs
+        self.model = None
+        self.tokenizer = None
         self.device = torch.device(kwargs.get("device", "cuda") if torch.cuda.is_available() else "cpu")
-        self.__dict__.update(kwargs)
-        if self.dtype == "bfloat16" and torch.cuda.is_bf16_supported(): self.target_dtype = torch.bfloat16
-        elif self.dtype == "float16": self.target_dtype = torch.float16
-        else: self.target_dtype = torch.float32
-        self._setup_model(self.pretrained_path, self.lora_path)
-        print("Model and tokenizer setup complete.")
+
+        if kwargs.get("dtype") == "bfloat16" and torch.cuda.is_bf16_supported():
+            self.target_dtype = torch.bfloat16
+        elif kwargs.get("dtype") == "float16":
+            self.target_dtype = torch.float16
+        else:
+            self.target_dtype = torch.float32
+
+        # Pick up the remaining attributes from the config
+        for key, value in kwargs.items():
+            if not hasattr(self, key):
+                setattr(self, key, value)
+
+        print("DreamLoRAInference initialized (model will be loaded on first use).")
+
+    def _ensure_model_loaded(self):
+        """Lazily load the model, only when it is actually needed."""
+        if self.model is None:
+            print("Loading model for the first time...")
+            self._setup_model(self.config["pretrained_path"], self.config["lora_path"])
+            print("Model and tokenizer setup complete.")
 
     def _setup_model(self, pretrained_path, lora_path):
-        # --- MODIFICATION START ---
-        # The `trust_remote_code=True` argument has been removed, as it is not needed here
-        # and was causing warnings in the log.
         config = LLaDAConfig.from_pretrained(pretrained_path)
         self.model = LLaDAModelLM.from_pretrained(
             pretrained_path,
             config=config,
             torch_dtype=self.target_dtype,
-            # device_map="auto" is handled by accelerate for better memory management on Spaces
             device_map="auto"
         ).eval()
-
-        # THIS IS THE CRITICAL FIX: Tie the weights before loading the adapter.
-        # This resolves the error message from the log and allows `device_map="auto"` to work correctly.
-        # print("Tying model weights...")
-        # self.model.tie_weights()
-        # print("Weights tied.")
 
-        # Now, load the PEFT adapter on top of the correctly configured base model
         self.model = PeftModel.from_pretrained(self.model, lora_path)
-        # --- MODIFICATION END ---
-
         self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
-        if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
 
     def _apply_chat_template(self, prompt):
         chat_history = [{"role": "user", "content": prompt}]
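
The rewritten `__init__` above defers all weight loading: construction only records configuration, and `_ensure_model_loaded` performs the expensive `from_pretrained` calls on first use. The same pattern, reduced to a self-contained sketch (class and attribute names here are illustrative, not the app's API):

    class LazyEngine:
        """Store configuration at construction; load heavy state on first use."""

        def __init__(self, **kwargs):
            self.config = kwargs
            self.model = None  # nothing loaded yet

        def _ensure_model_loaded(self):
            # Idempotent guard: the expensive load runs at most once.
            if self.model is None:
                self.model = self._load_model(self.config["pretrained_path"])
            return self.model

        def _load_model(self, path):
            # Stand-in for the real LLaDAModelLM.from_pretrained(...) call.
            return f"model loaded from {path}"

    engine = LazyEngine(pretrained_path="ckpt/")  # cheap: no load happens here
    engine._ensure_model_loaded()                 # first real use triggers the load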
@@ -225,13 +237,11 @@ class DreamLoRAInference:
             if (next_block_id := block_id + 1) in block_states:
                 block_states[next_block_id]['is_complete'] = True
 
-    # The rest of your class methods (_render_visualization_html, _render_status_html, stream_and_capture_for_gradio)
-    # remain completely unchanged.
     def _render_visualization_html(self, step: int, x_t: torch.Tensor, block_states: Dict, cache_length: int, updated_block_ids: Set[int]) -> str:
         timestamp = int(time.time() * 1000)
 
         html_parts = []
-        for block_id in sorted(k for k in block_states.keys() if k > 0):  # Only render generated part (block_id > 0)
+        for block_id in sorted(k for k in block_states.keys() if k > 0):
             state = block_states[block_id]
             container_classes = ["block-container"]
             if block_id in updated_block_ids: container_classes.append("block-updating")
@@ -370,7 +380,7 @@
 
         return complete_html
 
-    @spaces.GPU  # ← new: key fix - add the GPU decorator
+    @spaces.GPU  # key fix: the GPU decorator
     @torch.inference_mode()
     def stream_and_capture_for_gradio(
         self,
@@ -382,6 +392,9 @@
         skip_threshold: float
     ) -> Iterator[Tuple[str, List[Tuple[str, str]], str, str, str]]:
 
+        # Make sure the model is loaded before generating
+        self._ensure_model_loaded()
+
         start_time = time.time()
         captured_frames: List[Tuple[str, str]] = []
 
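The two changes above work together: on ZeroGPU Spaces a GPU is attached only while a function decorated with `@spaces.GPU` is running, so model loading has to happen inside such a function rather than at import time. A minimal sketch of the pattern, assuming the `spaces` package that Hugging Face provides in Spaces environments (the `generate` function and the model stand-in are illustrative):

    import spaces  # available in Hugging Face Spaces (ZeroGPU) environments
    import torch

    model = None  # deliberately not loaded at import time

    @spaces.GPU  # a GPU is attached only for the duration of this call
    def generate(prompt: str) -> str:
        global model
        if model is None:
            # First call: safe to touch CUDA now that a GPU is attached.
            model = torch.nn.Linear(8, 8).cuda()  # stand-in for the real model load
        x = torch.randn(1, 8, device="cuda")
        return f"{prompt}: {model(x).sum().item():.3f}"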
@@ -396,7 +409,7 @@
 
         # Capture initial state
         initial_viz_html = self._render_visualization_html(0, x_t, block_states, 0, set())
-        initial_status_html = self._render_status_html(0, x_t, block_states, 0)
+        initial_status_html = self._render_status_html(0, block_states, 0)
         captured_frames.append((initial_viz_html, initial_status_html))
 
         yield "", captured_frames, "Initializing generation process...", "Initializing visualization...", "Initializing block status..."
@@ -507,6 +520,8 @@ if __name__ == "__main__":
         "sampling_strategy": "default",
     }
     set_seed(42)
+
+    # Create the inference engine without loading the model yet
     inference_engine = DreamLoRAInference(**config)
 
     def animate_visualization(html_frames_list: List[Tuple[str, str]], delay: float) -> Iterator[Tuple[str, str]]: