tyfeld commited on
Commit
9746992
·
verified ·
1 Parent(s): 563513d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -46,14 +46,14 @@ def get_num_transfer_tokens(mask_index, steps):
46
  num_transfer_tokens[i, :remainder[i].item()] += 1 # .item() for single value tensor to int
47
  return num_transfer_tokens
48
 
49
- MODEL = None
50
- TOKENIZER = None
51
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
52
- MASK_ID = None
53
- uni_prompting = None
 
 
 
54
  VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
55
 
56
- DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-Base" # Default
57
  CURRENT_MODEL_PATH = None
58
 
59
  MODEL_CHOICES = [
 
46
  num_transfer_tokens[i, :remainder[i].item()] += 1 # .item() for single value tensor to int
47
  return num_transfer_tokens
48
 
 
 
49
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
50
+ DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-Base" # Default
51
+ MASK_ID = 126336
52
+ MODEL = MMadaModelLM.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
53
+ TOKENIZER = DEFAULT_MODEL_PATH.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True)
54
+ uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
55
  VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
56
 
 
57
  CURRENT_MODEL_PATH = None
58
 
59
  MODEL_CHOICES = [