#!/usr/bin/env python3
# svg_compare_gradio.py
# ------------------------------------------------------------
import re, os, torch, cairosvg, lpips, clip, gradio as gr
from io import BytesIO
from pathlib import Path
from PIL import Image
from unsloth import FastLanguageModel
from transformers import BitsAndBytesConfig, AutoTokenizer
import spaces
# ---------- paths YOU may want to edit ----------------------
ADAPTER_DIR = "unsloth_trained_weights/checkpoint-1700"   # LoRA ckpt
BASE_MODEL  = "Qwen/Qwen2.5-Coder-7B-Instruct"
MAX_NEW     = 512
DEVICE      = "cuda"   # if torch.cuda.is_available() else "cpu"
# ---------- utils -------------------------------------------
SVG_PAT = re.compile(r"<svg[^>]*>.*?</svg>", re.S | re.I)

def extract_svg(txt: str):
    m = list(SVG_PAT.finditer(txt))
    return m[-1].group(0) if m else None   # take the last match
def svg2pil(svg: str):
    try:
        png = cairosvg.svg2png(bytestring=svg.encode())
        return Image.open(BytesIO(png)).convert("RGB")
    except Exception:
        return None
# ---------- backbone loaders (CLIP + LPIPS) -----------------
_CLIP, _PREP, _LP = None, None, None

def _load_backbones():
    global _CLIP, _PREP, _LP
    if _CLIP is None:
        _CLIP, _PREP = clip.load("ViT-L/14", device=DEVICE); _CLIP.eval()
    if _LP is None:
        _LP = lpips.LPIPS(net="vgg").to(DEVICE).eval()
@torch.no_grad()
def fused_sim(a: Image.Image, b: Image.Image, alpha=.5):
    _load_backbones()
    ta, tb = _PREP(a).unsqueeze(0).to(DEVICE), _PREP(b).unsqueeze(0).to(DEVICE)
    fa = _CLIP.encode_image(ta); fa /= fa.norm(dim=-1, keepdim=True)
    fb = _CLIP.encode_image(tb); fb /= fb.norm(dim=-1, keepdim=True)
    clip_sim = ((fa @ fb.T).item() + 1) / 2            # cosine similarity mapped to [0, 1]
    lp_sim   = 1 - _LP(ta, tb, normalize=True).item()  # LPIPS distance flipped to a similarity
    return alpha * clip_sim + (1 - alpha) * lp_sim
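# Minimal usage sketch for fused_sim(): given two rendered PIL images it returns
# a score in roughly [0, 1], with alpha weighing CLIP (semantic) similarity
# against LPIPS (perceptual) similarity. `img_a` / `img_b` are placeholder names.
#   score = fused_sim(img_a, img_b, alpha=0.5)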
# ---------- load models once at startup ---------------------
bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)

print("Loading BASE …")
base, tok = FastLanguageModel.from_pretrained(
    BASE_MODEL, max_seq_length=2048,
    load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
tok.pad_token = tok.eos_token

print("Loading LoRA …")
lora, _ = FastLanguageModel.from_pretrained(
    ADAPTER_DIR, max_seq_length=2048,
    load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
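# ADAPTER_DIR is a LoRA checkpoint saved during training; from_pretrained() is
# expected to resolve the underlying base weights from the adapter config and
# attach the adapter, so `lora` is the GRPO-fine-tuned variant of `base`.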
def build_prompt(desc: str):
    msgs = [{"role": "system", "content": "You are an SVG illustrator."},
            {"role": "user",
             "content": f"ONLY reply with a valid, complete <svg>…</svg> file that depicts: {desc}"}]
    return tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
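# apply_chat_template() renders the messages with the tokenizer's own chat
# markup (ChatML-style <|im_start|>/<|im_end|> tags for Qwen2.5 instruct models)
# and, with add_generation_prompt=True, appends the assistant header so the
# model starts generating the SVG reply directly.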
def draw(model, desc: str):
    prompt = build_prompt(desc)
    ids = tok(prompt, return_tensors="pt").to(DEVICE)
    out = model.generate(**ids, max_new_tokens=MAX_NEW,
                         do_sample=True, temperature=.7, top_p=.8)
    txt = tok.decode(out[0], skip_special_tokens=True)
    svg = extract_svg(txt)
    img = svg2pil(svg) if svg else None
    return img, svg or "(no SVG found)"
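# Note: out[0] contains the prompt tokens followed by the completion, so the
# decoded text still includes the instruction; extract_svg() therefore keeps
# only the last <svg>…</svg> block it finds.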
# ---------- gradio interface --------------------------------
@spaces.GPU   # ZeroGPU Spaces: allocate a GPU for the duration of the request
def compare(desc):
    img_base, svg_base = draw(base, desc)
    img_lora, svg_lora = draw(lora, desc)
    # sim = (fused_sim(img_lora, img_base) if img_base and img_lora else float("nan"))
    caption = "Thanks for trying our model!"
    return img_base, img_lora, caption, svg_base, svg_lora
with gr.Blocks(css="body{background:#111;color:#eee}") as demo:
    gr.Markdown("## Qwen-2.5 SVG Generator: base vs GRPO-LoRA")
    gr.Markdown(
        "Type an image **description** (e.g. *a purple forest at dusk*). "
        "Click **Generate** to see what the base model and your fine-tuned LoRA produce."
    )
    inp = gr.Textbox(label="Description", placeholder="a purple forest at dusk")
    btn = gr.Button("Generate")
    with gr.Row():
        out_base = gr.Image(label="Base model", type="pil")
        out_lora = gr.Image(label="LoRA-tuned model", type="pil")
    sim_lbl = gr.Markdown()
    with gr.Accordion("Raw SVG code", open=False):
        svg_base_box = gr.Textbox(label="Base SVG", lines=6)
        svg_lora_box = gr.Textbox(label="LoRA SVG", lines=6)
    btn.click(compare, inp, [out_base, out_lora, sim_lbl, svg_base_box, svg_lora_box])

demo.launch()