Chanlefe commited on
Commit
6554f18
·
verified ·
1 Parent(s): 3711151

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -21
app.py CHANGED
@@ -4,34 +4,27 @@ from transformers import AutoProcessor, AutoModelForImageClassification
4
  import gradio as gr
5
  import pytesseract
6
 
7
- def classify_meme(image: Image.Image):
8
-     # OCR: extract text from image
9
-     extracted_text = pytesseract.image_to_string(image)
10
-
11
-     # Process image with SigLIP2 model
12
-     inputs = processor(images=image, return_tensors="pt").to(model.device)
13
-     with torch.no_grad():
14
-         outputs = model(**inputs)
15
-         probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
16
-     predictions = {labels[i]: float(probs[0][i]) for i in range(len(labels))}
17
-
18
-     return {
19
-         "Predictions": predictions,
20
-         "Extracted Text": extracted_text.strip()
21
-     }
22
-
23
- # Load model and processor from Hugging Face
24
  model = AutoModelForImageClassification.from_pretrained("google/siglip2-base-patch16-naflex")
25
  processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-naflex")
26
  labels = model.config.id2label
27
 
 
28
  def classify_meme(image: Image.Image):
 
 
 
 
29
  inputs = processor(images=image, return_tensors="pt").to(model.device)
30
  with torch.no_grad():
31
  outputs = model(**inputs)
32
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
33
  predictions = {labels[i]: float(probs[0][i]) for i in range(len(labels))}
34
- return predictions
 
 
 
 
35
 
36
  # Gradio interface
37
  demo = gr.Interface(
@@ -45,7 +38,5 @@ demo = gr.Interface(
45
  description="Upload a meme to classify its sentiment and extract text using OCR."
46
  )
47
 
48
-
49
  if __name__ == "__main__":
50
- demo.launch(share = True)
51
-
 
4
  import gradio as gr
5
  import pytesseract
6
 
7
+ # Load model and processor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  model = AutoModelForImageClassification.from_pretrained("google/siglip2-base-patch16-naflex")
9
  processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-naflex")
10
  labels = model.config.id2label
11
 
12
+ # Classify meme and extract text
13
  def classify_meme(image: Image.Image):
14
+ # OCR: extract text from image
15
+ extracted_text = pytesseract.image_to_string(image)
16
+
17
+ # Process image with SigLIP2 model
18
  inputs = processor(images=image, return_tensors="pt").to(model.device)
19
  with torch.no_grad():
20
  outputs = model(**inputs)
21
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
22
  predictions = {labels[i]: float(probs[0][i]) for i in range(len(labels))}
23
+
24
+ return {
25
+ "Predictions": predictions,
26
+ "Extracted Text": extracted_text.strip()
27
+ }
28
 
29
  # Gradio interface
30
  demo = gr.Interface(
 
38
  description="Upload a meme to classify its sentiment and extract text using OCR."
39
  )
40
 
 
41
  if __name__ == "__main__":
42
+ demo.launch()