JohanDL committed · Commit eaa86c1 · 1 Parent(s): 3c14524

Adding initial eval code

Files changed (1)
  1. app.py +3 -15
app.py CHANGED
@@ -53,6 +53,7 @@ def fused_sim(a:Image.Image,b:Image.Image,α=.5):
 bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
 
 # ---------- load models once at startup ---------------------
+base = None
 @spaces.GPU
 def load_models():
     from unsloth import FastLanguageModel
@@ -85,21 +86,7 @@ def build_prompt(desc:str):
 @spaces.GPU
 @torch.no_grad()
 def draw(model, desc:str):
-    # ensure_models()
-    from unsloth import FastLanguageModel
-    global base, tok, lora
-    if base is None:
-        print("Loading BASE …")
-        base, tok = FastLanguageModel.from_pretrained(
-            BASE_MODEL, max_seq_length=2048,
-            load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
-        tok.pad_token = tok.eos_token
-
-        print("Loading LoRA …")
-        lora, _ = FastLanguageModel.from_pretrained(
-            ADAPTER_DIR, max_seq_length=2048,
-            load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
-        print("✔ models loaded")
+    ensure_models()
     prompt = build_prompt(desc)
     ids = tok(prompt, return_tensors="pt").to(DEVICE)
     out = model.generate(**ids, max_new_tokens=MAX_NEW,
@@ -111,6 +98,7 @@ def draw(model, desc:str):
 
 # ---------- gradio interface --------------------------------
 def compare(desc):
+    ensure_models()
     img_base, svg_base = draw(base, desc)
     img_lora, svg_lora = draw(lora, desc)
     # sim = (fused_sim(img_lora, img_base) if img_base and img_lora else float("nan"))
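
The ensure_models() helper that this commit starts calling from draw() and compare() is not shown in the diff; a plausible reading is that it wraps the lazy-loading block removed from draw(). Below is a minimal sketch under that assumption, assembled from the removed lines (BASE_MODEL, ADAPTER_DIR, bnb_cfg are defined elsewhere in app.py; the commit itself only adds base = None at module level, and the real helper may differ):

# Hypothetical reconstruction of ensure_models(), based on the code removed
# from draw() in this commit; not the confirmed implementation.
base = None    # added at module level by this commit
tok = None     # assumed to be module-level globals as well
lora = None

def ensure_models():
    """Lazily load the 4-bit base model, tokenizer, and LoRA adapter once."""
    global base, tok, lora
    if base is not None:
        return  # already loaded on a previous call

    from unsloth import FastLanguageModel

    print("Loading BASE …")
    base, tok = FastLanguageModel.from_pretrained(
        BASE_MODEL, max_seq_length=2048,
        load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
    tok.pad_token = tok.eos_token

    print("Loading LoRA …")
    lora, _ = FastLanguageModel.from_pretrained(
        ADAPTER_DIR, max_seq_length=2048,
        load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
    print("✔ models loaded")

With a guard like this, calling the helper at the top of both draw() and compare() is cheap after the first request, and the per-request code path no longer carries any model-loading logic.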