saritha commited on
Commit
9159a27
·
verified ·
1 Parent(s): 4249d8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -46
app.py CHANGED
@@ -8,10 +8,11 @@ import os
8
  import contextlib
9
  from transformers import ViTForImageClassification, pipeline
10
 
11
- # Suppress warnings
12
  warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
13
  warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
14
 
 
15
  @contextlib.contextmanager
16
  def suppress_stdout():
17
  with open(os.devnull, 'w') as devnull:
@@ -22,12 +23,16 @@ def suppress_stdout():
22
  finally:
23
  sys.stdout = old_stdout
24
 
25
- # Load the saved model and suppress the warnings
26
  with suppress_stdout():
27
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=6)
28
  model.load_state_dict(torch.load('vit_sugarcane_disease_detection.pth', map_location=torch.device('cpu')))
29
  model.eval()
30
 
 
 
 
 
31
  # Define the same transformation used during training
32
  transform = transforms.Compose([
33
  transforms.Resize((224, 224)),
@@ -35,12 +40,9 @@ transform = transforms.Compose([
35
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
36
  ])
37
 
38
- # Load the class names
39
  class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
40
 
41
- # Load AI response generator
42
- ai_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2")
43
-
44
  # Knowledge base for sugarcane diseases
45
  knowledge_base = {
46
  'BacterialBlights': "Bacterial blights cause water-soaked lesions on leaves, leading to yellowing and withering. To manage, apply copper-based fungicides and ensure proper drainage.",
@@ -51,10 +53,31 @@ knowledge_base = {
51
  'Healthy': "The sugarcane crop is healthy. Continue regular monitoring and good agronomic practices."
52
  }
53
 
54
- # Update the predict_disease function to handle non-sugarcane images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def predict_disease(image):
 
 
 
 
 
 
 
56
  # Apply transformations to the image
57
- img_tensor = transform(image).unsqueeze(0) # Add batch dimension
58
 
59
  # Make prediction
60
  with torch.no_grad():
@@ -62,65 +85,37 @@ def predict_disease(image):
62
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
63
  max_prob, predicted_class = torch.max(probabilities, 1)
64
 
65
- # Confidence threshold for non-sugarcane detection
66
- confidence_threshold = 0.8 # Adjust based on experimentation
67
-
68
- # Check if the confidence is below the threshold
69
- if max_prob.item() < confidence_threshold:
70
- return """
71
- <div style='font-size: 18px; color: #FF5722; font-weight: bold;'>
72
- The uploaded image does not belong to the sugarcane dataset.
73
- </div>
74
- """
75
-
76
  # Get the predicted label
77
  predicted_label = class_names[predicted_class.item()]
78
 
79
  # Retrieve response from knowledge base
80
- if predicted_label in knowledge_base:
81
- detailed_response = knowledge_base[predicted_label]
82
- else:
83
- # Fallback to AI-generated response
84
- prompt = f"The detected sugarcane disease is '{predicted_label}'. Provide detailed advice for managing this condition."
85
- detailed_response = ai_pipeline(prompt, max_length=100, num_return_sequences=1, truncation=True)[0]['generated_text']
86
-
87
- # Create a styled HTML output
88
  output_message = f"""
89
  <div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>
90
  Detected Disease: <span style='color: #FF5722;'>{predicted_label}</span>
91
  </div>
92
  """
93
-
94
- if predicted_label != "Healthy":
95
- output_message += f"""
96
- <p style='font-size: 16px; color: #757575;'>
97
- {detailed_response}
98
- </p>
99
- """
100
- else:
101
- output_message += f"""
102
- <p style='font-size: 16px; color: #757575;'>
103
- {detailed_response}
104
- </p>
105
- """
106
-
107
  return output_message
108
 
109
  # Create Gradio interface
110
  inputs = gr.Image(type="pil")
111
  outputs = gr.HTML()
112
 
113
- EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"]
114
-
115
  demo_app = gr.Interface(
116
  fn=predict_disease,
117
  inputs=inputs,
118
  outputs=outputs,
119
  title="Sugarcane Disease Detection",
120
- examples=EXAMPLES,
121
- live=True,
122
- theme="huggingface"
123
  )
124
 
125
  demo_app.launch(debug=True)
126
 
 
 
8
  import contextlib
9
  from transformers import ViTForImageClassification, pipeline
10
 
11
+ # Suppress warnings related to the model weights initialization
12
  warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
13
  warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
14
 
15
+ # Suppress output for copying files and verbose model initialization messages
16
  @contextlib.contextmanager
17
  def suppress_stdout():
18
  with open(os.devnull, 'w') as devnull:
 
23
  finally:
24
  sys.stdout = old_stdout
25
 
26
+ # Load the sugarcane disease model
27
  with suppress_stdout():
28
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=6)
29
  model.load_state_dict(torch.load('vit_sugarcane_disease_detection.pth', map_location=torch.device('cpu')))
30
  model.eval()
31
 
32
+ # Load a general-purpose classifier (e.g., MobileNetV2)
33
+ general_model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
34
+ general_model.eval()
35
+
36
  # Define the same transformation used during training
37
  transform = transforms.Compose([
38
  transforms.Resize((224, 224)),
 
40
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
41
  ])
42
 
43
+ # Load the class names (disease types)
44
  class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
45
 
 
 
 
46
  # Knowledge base for sugarcane diseases
47
  knowledge_base = {
48
  'BacterialBlights': "Bacterial blights cause water-soaked lesions on leaves, leading to yellowing and withering. To manage, apply copper-based fungicides and ensure proper drainage.",
 
53
  'Healthy': "The sugarcane crop is healthy. Continue regular monitoring and good agronomic practices."
54
  }
55
 
56
+ # Function to check if the image is plant-related
57
+ def is_plant_image(image):
58
+ general_transform = transforms.Compose([
59
+ transforms.Resize((224, 224)),
60
+ transforms.ToTensor(),
61
+ ])
62
+ img_tensor = general_transform(image).unsqueeze(0)
63
+ with torch.no_grad():
64
+ outputs = general_model(img_tensor)
65
+ _, predicted_class = torch.max(outputs, 1)
66
+ # Check if the predicted class corresponds to plant-like images
67
+ plant_related_classes = range(20, 25) # Replace with specific classes for plants
68
+ return predicted_class.item() in plant_related_classes
69
+
70
+ # Predict disease or detect non-sugarcane images
71
  def predict_disease(image):
72
+ if not is_plant_image(image):
73
+ return """
74
+ <div style='font-size: 18px; color: #FF5722; font-weight: bold;'>
75
+ The uploaded image is not related to sugarcane. Please upload a sugarcane image.
76
+ </div>
77
+ """
78
+
79
  # Apply transformations to the image
80
+ img_tensor = transform(image).unsqueeze(0)
81
 
82
  # Make prediction
83
  with torch.no_grad():
 
85
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
86
  max_prob, predicted_class = torch.max(probabilities, 1)
87
 
 
 
 
 
 
 
 
 
 
 
 
88
  # Get the predicted label
89
  predicted_label = class_names[predicted_class.item()]
90
 
91
  # Retrieve response from knowledge base
92
+ detailed_response = knowledge_base.get(predicted_label, "No additional information available.")
93
+
94
+ # Create styled HTML output
 
 
 
 
 
95
  output_message = f"""
96
  <div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>
97
  Detected Disease: <span style='color: #FF5722;'>{predicted_label}</span>
98
  </div>
99
  """
100
+ output_message += f"""
101
+ <p style='font-size: 16px; color: #757575;'>
102
+ {detailed_response}
103
+ </p>
104
+ """
 
 
 
 
 
 
 
 
 
105
  return output_message
106
 
107
  # Create Gradio interface
108
  inputs = gr.Image(type="pil")
109
  outputs = gr.HTML()
110
 
 
 
111
  demo_app = gr.Interface(
112
  fn=predict_disease,
113
  inputs=inputs,
114
  outputs=outputs,
115
  title="Sugarcane Disease Detection",
116
+ live=True
 
 
117
  )
118
 
119
  demo_app.launch(debug=True)
120
 
121
+