Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,10 +8,11 @@ import os
|
|
| 8 |
import contextlib
|
| 9 |
from transformers import ViTForImageClassification, pipeline
|
| 10 |
|
| 11 |
-
# Suppress warnings
|
| 12 |
warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
|
| 13 |
warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
|
| 14 |
|
|
|
|
| 15 |
@contextlib.contextmanager
|
| 16 |
def suppress_stdout():
|
| 17 |
with open(os.devnull, 'w') as devnull:
|
|
@@ -35,13 +36,13 @@ transform = transforms.Compose([
|
|
| 35 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 36 |
])
|
| 37 |
|
| 38 |
-
# Load the class names
|
| 39 |
class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
|
| 40 |
|
| 41 |
-
# Load AI response generator
|
| 42 |
ai_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2")
|
| 43 |
|
| 44 |
-
# Knowledge base for sugarcane diseases
|
| 45 |
knowledge_base = {
|
| 46 |
'BacterialBlights': "Bacterial blights cause water-soaked lesions on leaves, leading to yellowing and withering. To manage, apply copper-based fungicides and ensure proper drainage.",
|
| 47 |
'Mosaic': "Mosaic disease results in streaked and mottled leaves, reducing photosynthesis. Use disease-resistant varieties and control aphids to prevent spread.",
|
|
@@ -51,7 +52,7 @@ knowledge_base = {
|
|
| 51 |
'Healthy': "The sugarcane crop is healthy. Continue regular monitoring and good agronomic practices."
|
| 52 |
}
|
| 53 |
|
| 54 |
-
# Update the predict_disease function
|
| 55 |
def predict_disease(image):
|
| 56 |
# Apply transformations to the image
|
| 57 |
img_tensor = transform(image).unsqueeze(0) # Add batch dimension
|
|
@@ -59,20 +60,8 @@ def predict_disease(image):
|
|
| 59 |
# Make prediction
|
| 60 |
with torch.no_grad():
|
| 61 |
outputs = model(img_tensor)
|
| 62 |
-
|
| 63 |
-
max_prob, predicted_class = torch.max(probabilities, 1)
|
| 64 |
|
| 65 |
-
# Confidence threshold for non-sugarcane detection
|
| 66 |
-
confidence_threshold = 0.6 # Adjust based on experimentation
|
| 67 |
-
|
| 68 |
-
# Check if the confidence is below the threshold
|
| 69 |
-
if max_prob.item() < confidence_threshold:
|
| 70 |
-
return """
|
| 71 |
-
<div style='font-size: 18px; color: #FF5722; font-weight: bold;'>
|
| 72 |
-
The uploaded image does not belong to the sugarcane dataset.
|
| 73 |
-
</div>
|
| 74 |
-
"""
|
| 75 |
-
|
| 76 |
# Get the predicted label
|
| 77 |
predicted_label = class_names[predicted_class.item()]
|
| 78 |
|
|
@@ -108,7 +97,7 @@ def predict_disease(image):
|
|
| 108 |
|
| 109 |
# Create Gradio interface
|
| 110 |
inputs = gr.Image(type="pil")
|
| 111 |
-
outputs = gr.HTML()
|
| 112 |
|
| 113 |
EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"]
|
| 114 |
|
|
@@ -124,4 +113,3 @@ demo_app = gr.Interface(
|
|
| 124 |
|
| 125 |
demo_app.launch(debug=True)
|
| 126 |
|
| 127 |
-
|
|
|
|
| 8 |
import contextlib
|
| 9 |
from transformers import ViTForImageClassification, pipeline
|
| 10 |
|
| 11 |
+
# Suppress warnings related to the model weights initialization
|
| 12 |
warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
|
| 13 |
warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
|
| 14 |
|
| 15 |
+
# Suppress output for copying files and verbose model initialization messages
|
| 16 |
@contextlib.contextmanager
|
| 17 |
def suppress_stdout():
|
| 18 |
with open(os.devnull, 'w') as devnull:
|
|
|
|
| 36 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 37 |
])
|
| 38 |
|
| 39 |
+
# Load the class names (disease types)
|
| 40 |
class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
|
| 41 |
|
| 42 |
+
# Load AI response generator (using a local GPT pipeline or OpenAI's GPT-3/4 API)
|
| 43 |
ai_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2")
|
| 44 |
|
| 45 |
+
# Knowledge base for sugarcane diseases (example data from the website)
|
| 46 |
knowledge_base = {
|
| 47 |
'BacterialBlights': "Bacterial blights cause water-soaked lesions on leaves, leading to yellowing and withering. To manage, apply copper-based fungicides and ensure proper drainage.",
|
| 48 |
'Mosaic': "Mosaic disease results in streaked and mottled leaves, reducing photosynthesis. Use disease-resistant varieties and control aphids to prevent spread.",
|
|
|
|
| 52 |
'Healthy': "The sugarcane crop is healthy. Continue regular monitoring and good agronomic practices."
|
| 53 |
}
|
| 54 |
|
| 55 |
+
# Update the predict_disease function
|
| 56 |
def predict_disease(image):
|
| 57 |
# Apply transformations to the image
|
| 58 |
img_tensor = transform(image).unsqueeze(0) # Add batch dimension
|
|
|
|
| 60 |
# Make prediction
|
| 61 |
with torch.no_grad():
|
| 62 |
outputs = model(img_tensor)
|
| 63 |
+
_, predicted_class = torch.max(outputs.logits, 1)
|
|
|
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
# Get the predicted label
|
| 66 |
predicted_label = class_names[predicted_class.item()]
|
| 67 |
|
|
|
|
| 97 |
|
| 98 |
# Create Gradio interface
|
| 99 |
inputs = gr.Image(type="pil")
|
| 100 |
+
outputs = gr.HTML() # Use HTML output for styled text
|
| 101 |
|
| 102 |
EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"]
|
| 103 |
|
|
|
|
| 113 |
|
| 114 |
demo_app.launch(debug=True)
|
| 115 |
|
|
|