Ayamohamed commited on
Commit
ceb2520
·
verified ·
1 Parent(s): dda41bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -15
app.py CHANGED
@@ -9,28 +9,21 @@ import gradio as gr
9
  model_path = hf_hub_download(repo_id="Ayamohamed/DiaClassification", filename="dia_none_classifier_full.pth")
10
 
11
  # Load model
12
- model_hg = torch.load(model_path)
13
- model_hg.eval()
14
 
15
  transform = transforms.Compose([
16
  transforms.Resize((224, 224)),
17
  transforms.ToTensor(),
18
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
19
  ])
20
- def predict(image_path):
21
-
22
- try:
23
- image = Image.open(image_path).convert("RGB")
24
- image = transform(image).unsqueeze(0)
25
- with torch.no_grad():
26
- output = model_hg(image)
27
- print("Model output:", output)
28
- class_idx = torch.argmax(output, dim=1).item()
29
-
30
  return "Diagram" if class_idx == 0 else "Not Diagram"
31
- except Exception as e:
32
- print("Error during prediction:", str(e))
33
- return f"Prediction Error: {str(e)}"
34
 
35
  gr.Interface(
36
  fn=predict,
 
9
  model_path = hf_hub_download(repo_id="Ayamohamed/DiaClassification", filename="dia_none_classifier_full.pth")
10
 
11
  # Load model
12
+ model = torch.load(model_path)
13
+ model.eval()
14
 
15
  transform = transforms.Compose([
16
  transforms.Resize((224, 224)),
17
  transforms.ToTensor(),
18
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
19
  ])
20
+ def predict(image):
21
+ image = transform(image).unsqueeze(0)
22
+ with torch.no_grad():
23
+ output = local_model(image)
24
+ probabilities = F.softmax(output, dim=1)
25
+ class_idx = torch.argmax(probabilities, dim=1).item()
 
 
 
 
26
  return "Diagram" if class_idx == 0 else "Not Diagram"
 
 
 
27
 
28
  gr.Interface(
29
  fn=predict,