jaqen79 commited on
Commit
227f288
·
verified ·
1 Parent(s): befcffa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -15
app.py CHANGED
@@ -1,31 +1,26 @@
1
  import gradio as gr
2
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification
3
  import torch
4
  from PIL import Image
5
 
6
- # Carica feature extractor e modello dal tuo Hub
7
- MODEL_ID = "jaqen79/retail_images_classification_v1"
8
- extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
9
- model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
10
 
11
  def predict(image: Image.Image):
12
- # Preprocess
13
- inputs = extractor(images=image, return_tensors="pt")
14
  with torch.no_grad():
15
  outputs = model(**inputs)
16
- probs = outputs.logits.softmax(dim=-1).tolist()[0]
17
- # Assumi che model.config.id2label esista
18
  labels = [model.config.id2label[i] for i in range(len(probs))]
19
- # Ritorna dizionario label→probabilità
20
  return {labels[i]: float(probs[i]) for i in range(len(probs))}
21
 
22
- # Interfaccia Gradio
23
  demo = gr.Interface(
24
  fn=predict,
25
- inputs=gr.components.Image(type="pil"), # Changed to gr.components.Image
26
- outputs=gr.components.Label(num_top_classes=5), # Changed to gr.components.Label
27
- title="Vision Transformer Demo",
28
- description="Carica un'immagine e il modello ritorna le classi con probabilità."
29
  )
30
 
31
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
3
  import torch
4
  from PIL import Image
5
 
6
+ MODEL_ID = "jaqen79/retail_images_classification_v1"
7
+ processor = AutoImageProcessor.from_pretrained(MODEL_ID)
8
+ model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
 
9
 
10
  def predict(image: Image.Image):
11
+ inputs = processor(images=image, return_tensors="pt")
 
12
  with torch.no_grad():
13
  outputs = model(**inputs)
14
+ probs = outputs.logits.softmax(dim=-1).tolist()[0]
 
15
  labels = [model.config.id2label[i] for i in range(len(probs))]
 
16
  return {labels[i]: float(probs[i]) for i in range(len(probs))}
17
 
 
18
  demo = gr.Interface(
19
  fn=predict,
20
+ inputs=gr.components.Image(type="pil"),
21
+ outputs=gr.components.Label(num_top_classes=5),
22
+ title="Retail Image classification using fine-tuned ViT",
23
+ description="Upload an image and the model returns the classes with probabilities."
24
  )
25
 
26
  if __name__ == "__main__":