huangrh9 commited on
Commit
7378375
·
verified ·
1 Parent(s): d59404b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -21,8 +21,7 @@ logging.getLogger("httpx").setLevel(logging.WARNING)
21
 
22
  import gradio as gr
23
 
24
- from illume.conversation import default_conversation, conv_templates, SeparatorStyle
25
- # from conversation import default_conversation, conv_templates, SeparatorStyle
26
 
27
  # --- Global Variables and Model Loading ---
28
  model = None # Global variable to hold the loaded ILLUME model
@@ -936,10 +935,10 @@ if __name__ == "__main__":
936
  # prepare models and processors
937
  model = AutoModel.from_pretrained(
938
  args.model_name,
939
- # torch_dtype=torch.bfloat16,
940
- # attn_implementation='flash_attention_2', # OR 'sdpa' for Ascend NPUs
941
- torch_dtype=args.torch_dtype,
942
- attn_implementation='sdpa', # OR 'sdpa' for Ascend NPUs
943
  low_cpu_mem_usage=True,
944
  trust_remote_code=True).eval().cuda()
945
  processor = AutoProcessor.from_pretrained(args.model_name, trust_remote_code=True)
 
21
 
22
  import gradio as gr
23
 
24
+ from conversation import default_conversation, conv_templates, SeparatorStyle
 
25
 
26
  # --- Global Variables and Model Loading ---
27
  model = None # Global variable to hold the loaded ILLUME model
 
935
  # prepare models and processors
936
  model = AutoModel.from_pretrained(
937
  args.model_name,
938
+ torch_dtype=torch.bfloat16,
939
+ attn_implementation='flash_attention_2', # OR 'sdpa' for Ascend NPUs
940
+ # torch_dtype=args.torch_dtype,
941
+ # attn_implementation='sdpa', # OR 'sdpa' for Ascend NPUs
942
  low_cpu_mem_usage=True,
943
  trust_remote_code=True).eval().cuda()
944
  processor = AutoProcessor.from_pretrained(args.model_name, trust_remote_code=True)