blumenstiel commited on
Commit
ef153ea
·
1 Parent(s): 9aec2d7

Init models only once

Browse files
Files changed (1) hide show
  1. app.py +18 -8
app.py CHANGED
@@ -1,8 +1,8 @@
1
- import base64
2
- import os.path
3
- from io import BytesIO
4
- from pathlib import Path
5
 
 
 
 
 
6
  import spaces
7
  import glob
8
  import numpy as np
@@ -10,19 +10,22 @@ import gradio as gr
10
  import rasterio as rio
11
  import matplotlib.pyplot as plt
12
  import matplotlib as mpl
 
 
13
  from PIL import Image
14
  from matplotlib import rcParams
15
  from msclip.inference import run_inference_classification
 
16
 
17
  rcParams["font.size"] = 9
18
  rcParams["axes.titlesize"] = 9
19
  IMG_PX = 300
20
 
21
- import sys
22
- import csv
23
-
24
  csv.field_size_limit(sys.maxsize)
25
 
 
 
 
26
  EXAMPLES = {
27
  "EuroSAT": {
28
  "images": glob.glob("examples/eurosat/*.tif"),
@@ -163,7 +166,14 @@ def classify(images, class_text):
163
  class_names = [c.strip() for c in class_text.split(",") if c.strip()]
164
  cards = []
165
 
166
- df = run_inference_classification(image_path=images, class_names=class_names, verbose=False)
 
 
 
 
 
 
 
167
  for img_path, (id, row) in zip(images, df.iterrows()):
168
  scores = row[2:].astype(float) # drop filename column
169
  top = scores.sort_values(ascending=False)[:3]
 
 
 
 
 
1
 
2
+ import base64
3
+ import os
4
+ import sys
5
+ import csv
6
  import spaces
7
  import glob
8
  import numpy as np
 
10
  import rasterio as rio
11
  import matplotlib.pyplot as plt
12
  import matplotlib as mpl
13
+ from io import BytesIO
14
+ from pathlib import Path
15
  from PIL import Image
16
  from matplotlib import rcParams
17
  from msclip.inference import run_inference_classification
18
+ from msclip.inference.utils import build_model
19
 
20
  rcParams["font.size"] = 9
21
  rcParams["axes.titlesize"] = 9
22
  IMG_PX = 300
23
 
 
 
 
24
  csv.field_size_limit(sys.maxsize)
25
 
26
+ # Init Llama3-MS-CLIP from Hugging Face
27
+ model, preprocess, tokenizer = build_model()
28
+
29
  EXAMPLES = {
30
  "EuroSAT": {
31
  "images": glob.glob("examples/eurosat/*.tif"),
 
166
  class_names = [c.strip() for c in class_text.split(",") if c.strip()]
167
  cards = []
168
 
169
+ df = run_inference_classification(
170
+ model=model,
171
+ preprocess=preprocess,
172
+ tokenizer=tokenizer,
173
+ image_path=images,
174
+ class_names=class_names,
175
+ verbose=False
176
+ )
177
  for img_path, (id, row) in zip(images, df.iterrows()):
178
  scores = row[2:].astype(float) # drop filename column
179
  top = scores.sort_values(ascending=False)[:3]