mrfakename commited on
Commit
dd656b8
·
verified ·
1 Parent(s): 728f6b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -29
app.py CHANGED
@@ -1,32 +1,56 @@
1
-
2
- from sonique import get_pretrained_model
3
- from sonique.interface.gradio import create_ui
4
- import json
5
- from huggingface_hub import login
 
 
6
  import torch
7
- import os
 
 
8
 
9
- login(token=os.getenv('HF_TOKEN'))
10
 
11
- interface = create_ui(
12
- model_config_path = str(cached_path('https://raw.githubusercontent.com/zxxwxyyy/sonique/refs/heads/main/best_model.json')),
13
- ckpt_path=str(cached_path('hf://mrfakename/SONIQUE/stable_ep=220.ckpt')),
14
- # pretrained_name=args.pretrained_name,
15
- pretransform_ckpt_path=None
16
- )
17
- interface.queue().launch()
18
-
19
-
20
-
21
-
22
- if __name__ == "__main__":
23
- import argparse
24
- parser = argparse.ArgumentParser(description='Run gradio interface')
25
- parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False)
26
- parser.add_argument('--model-config', type=str, help='Path to model config', required=False)
27
- parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False)
28
- parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional to model pretransform checkpoint', required=False)
29
- parser.add_argument('--username', type=str, help='Gradio username', required=False)
30
- parser.add_argument('--password', type=str, help='Gradio password', required=False)
31
- args = parser.parse_args()
32
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import subprocess
3
+ subprocess.run(
4
+ "pip install flash-attn --no-build-isolation",
5
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
6
+ shell=True,
7
+ )
8
  import torch
9
+ import gradio as gr
10
+ from PIL import Image
11
+ from transformers import AutoModelForCausalLM, AutoProcessor
12
 
13
+ model_id_or_path = "rhymes-ai/Aria"
14
 
15
+ model = AutoModelForCausalLM.from_pretrained(model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
16
+
17
+ processor = AutoProcessor.from_pretrained(model_id_or_path, trust_remote_code=True)
18
+
19
+ image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
20
+
21
+ image = Image.open(requests.get(image_path, stream=True).raw)
22
+
23
+ messages = [
24
+ {
25
+ "role": "user",
26
+ "content": [
27
+ {"text": None, "type": "image"},
28
+ {"text": "what is the image?", "type": "text"},
29
+ ],
30
+ }
31
+ ]
32
+
33
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
34
+ inputs = processor(text=text, images=image, return_tensors="pt")
35
+ inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
36
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
37
+ @spaces.GPU
38
+ def run():
39
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
40
+ output = model.generate(
41
+ **inputs,
42
+ max_new_tokens=500,
43
+ stop_strings=["<|im_end|>"],
44
+ tokenizer=processor.tokenizer,
45
+ do_sample=True,
46
+ temperature=0.9,
47
+ )
48
+ output_ids = output[0][inputs["input_ids"].shape[1]:]
49
+ result = processor.decode(output_ids, skip_special_tokens=True)
50
+
51
+ with gr.Blocks() as demo:
52
+ btn = gr.Button("Run")
53
+ out = gr.Markdown()
54
+ btn.click(run, outputs=out)
55
+
56
+ demo.queue().launch()