// plant-detector / flutter_integration_example.dart
// Uploaded by NimurAI with huggingface_hub (commit 5a3f083, verified).
import 'dart:io';
import 'dart:typed_data';
import 'dart:ui' as ui;

import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:flutter_pytorch_lite/flutter_pytorch_lite.dart';
/// Classifies an image as "plant" or "anomaly" from a model's reconstruction
/// error.
///
/// The model is loaded from the asset bundle by [loadModel]; [detectPlant]
/// runs one forward pass and compares the mean squared error between the
/// normalized input and the model output against a fixed threshold. Call
/// [dispose] to release the native module.
class PlantAnomalyDetector {
  /// The loaded module, or `null` before [loadModel] succeeds and after
  /// [dispose].
  Module? _module;

  /// Errors strictly above this value are reported as anomalies.
  static const double _threshold = 0.5687; // Your threshold from training

  // Normalization values from your training data
  static const List<double> _mean = [0.4682, 0.4865, 0.3050];
  static const List<double> _std = [0.2064, 0.1995, 0.1961];

  /// Initialize the model from assets.
  ///
  /// Copies the bundled `.ptl` file into the system temp directory (the
  /// loader takes a file path) and loads it. If a module was already loaded
  /// it is destroyed after the new one is in place. Any failure is logged and
  /// rethrown.
  Future<void> loadModel() async {
    try {
      // Load model from assets
      final filePath =
          '${Directory.systemTemp.path}/plant_anomaly_detector.ptl';
      final modelBytes =
          await _getBuffer('assets/models/plant_anomaly_detector.ptl');
      // Asynchronous write so the calling isolate is not blocked on disk I/O.
      await File(filePath).writeAsBytes(modelBytes, flush: true);

      final loaded = await FlutterPytorchLite.load(filePath);
      // Swap first, then release the old module, so a repeated call neither
      // leaks the previous module nor leaves the detector without one.
      final previous = _module;
      _module = loaded;
      if (previous != null) {
        await previous.destroy();
      }
      print('Model loaded successfully');
    } catch (e) {
      print('Error loading model: $e');
      rethrow;
    }
  }

  /// Get byte buffer from assets.
  ///
  /// Returns exactly the bytes of [assetFileName].
  static Future<Uint8List> _getBuffer(String assetFileName) async {
    final rawAssetFile = await rootBundle.load(assetFileName);
    // A ByteData may be a view into a larger buffer; restrict the result to
    // the view's own window instead of exposing the whole backing buffer.
    return rawAssetFile.buffer.asUint8List(
      rawAssetFile.offsetInBytes,
      rawAssetFile.lengthInBytes,
    );
  }

  /// Normalize tensor values using training statistics.
  ///
  /// Treats [input] as three consecutive, equally sized channel planes and
  /// applies `(value - mean[c]) / std[c]` to every element of plane `c`.
  ///
  /// Throws an [ArgumentError] if the length of [input] is not a multiple of
  /// three, since the trailing values could not be assigned to a channel.
  List<double> _normalize(List<double> input) {
    const channels = 3;
    if (input.length % channels != 0) {
      throw ArgumentError(
        'Input length (${input.length}) must be a multiple of $channels',
      );
    }
    final pixelsPerChannel = input.length ~/ channels;
    return [
      for (var c = 0; c < channels; c++)
        for (var i = 0; i < pixelsPerChannel; i++)
          (input[c * pixelsPerChannel + i] - _mean[c]) / _std[c],
    ];
  }

  /// Calculate reconstruction error (MSE) between original and reconstructed.
  ///
  /// Throws an [ArgumentError] if the lists differ in length or are empty
  /// (an empty input would otherwise yield NaN, which compares as
  /// "not above the threshold" and would be misreported as a plant).
  double _calculateReconstructionError(
    List<double> original,
    List<double> reconstructed,
  ) {
    if (original.length != reconstructed.length) {
      throw ArgumentError(
        'Original and reconstructed tensors must have same length',
      );
    }
    if (original.isEmpty) {
      throw ArgumentError('Cannot compute reconstruction error of empty tensors');
    }
    var sumSquaredError = 0.0;
    for (var i = 0; i < original.length; i++) {
      final diff = original[i] - reconstructed[i];
      sumSquaredError += diff * diff;
    }
    return sumSquaredError / original.length;
  }

  /// Detect if an image is a plant or anomaly.
  ///
  /// Converts [image] to a 224x224 float tensor, runs the model, and returns
  /// a [PlantDetectionResult] whose `isPlant` is true when the reconstruction
  /// error does not exceed the threshold.
  ///
  /// Throws a [StateError] if [loadModel] has not completed. Inference
  /// failures are logged and rethrown.
  Future<PlantDetectionResult> detectPlant(ui.Image image) async {
    // Copy to a local so the null check promotes and no `!` is needed.
    final module = _module;
    if (module == null) {
      throw StateError('Model not loaded. Call loadModel() first.');
    }
    try {
      // Convert image to tensor
      final inputTensor = await TensorImageUtils.imageToFloat32Tensor(
        image,
        width: 224,
        height: 224,
      );

      // Get original normalized values for reconstruction error calculation
      // NOTE(review): the model receives `inputTensor` as produced above,
      // while the error is measured against the copy normalized with
      // _mean/_std. Confirm this matches the preprocessing used in training.
      final normalizedOriginal = _normalize(inputTensor.dataAsFloat32List);

      // Forward pass through the model
      final output = await module.forward([IValue.from(inputTensor)]);

      // Get reconstruction
      final reconstruction = output.toTensor().dataAsFloat32List;

      // Calculate reconstruction error
      final reconstructionError = _calculateReconstructionError(
        normalizedOriginal,
        reconstruction,
      );

      // Determine if it's an anomaly
      final isAnomaly = reconstructionError > _threshold;
      // Relative distance of the error from the threshold; not clamped, so
      // it can exceed 1.0 for errors far above the threshold.
      final confidence =
          (reconstructionError - _threshold).abs() / _threshold;

      return PlantDetectionResult(
        isPlant: !isAnomaly,
        reconstructionError: reconstructionError,
        threshold: _threshold,
        confidence: confidence,
      );
    } catch (e) {
      print('Error during inference: $e');
      rethrow;
    }
  }

  /// Dispose the model.
  ///
  /// Safe to call when nothing is loaded and safe to call repeatedly.
  Future<void> dispose() async {
    final module = _module;
    if (module == null) return;
    // Clear the field before destroying so a failing destroy() cannot leave
    // a handle that later calls would try to use or destroy again.
    _module = null;
    await module.destroy();
  }
}
/// Outcome of a single plant-detection run.
///
/// Bundles the verdict together with the numbers it was derived from.
class PlantDetectionResult {
  /// Whether the image was classified as a plant (i.e. not an anomaly).
  final bool isPlant;

  /// The reconstruction error measured for the image.
  final double reconstructionError;

  /// The threshold the error was compared against.
  final double threshold;

  /// Confidence as a fraction; rendered as a percentage by [toString].
  final double confidence;

  PlantDetectionResult({
    required this.isPlant,
    required this.reconstructionError,
    required this.threshold,
    required this.confidence,
  });

  /// Renders all fields on one line: error and threshold with four decimals,
  /// confidence as a percentage with two decimals.
  @override
  String toString() {
    final errorText = reconstructionError.toStringAsFixed(4);
    final thresholdText = threshold.toStringAsFixed(4);
    final confidencePercent = (confidence * 100).toStringAsFixed(2);

    final out = StringBuffer('PlantDetectionResult(')
      ..write('isPlant: $isPlant, ')
      ..write('reconstructionError: $errorText, ')
      ..write('threshold: $thresholdText, ')
      ..write('confidence: $confidencePercent%')
      ..write(')');
    return out.toString();
  }
}
/// Example usage in a Flutter widget.
///
/// Shows a progress indicator until the model is loaded, then a button that
/// runs detection on a bundled test image.
class PlantDetectionWidget extends StatefulWidget {
  /// Creates the example widget; [key] is forwarded to [StatefulWidget].
  const PlantDetectionWidget({Key? key}) : super(key: key);

  // Return the public State type so the private state class does not leak
  // into this widget's public signature.
  @override
  State<PlantDetectionWidget> createState() => _PlantDetectionWidgetState();
}
/// State for [PlantDetectionWidget].
///
/// Owns a [PlantAnomalyDetector], loads its model when the state is created,
/// and releases it when the state is disposed.
class _PlantDetectionWidgetState extends State<PlantDetectionWidget> {
  final PlantAnomalyDetector _detector = PlantAnomalyDetector();

  /// Whether the model finished loading; gates the detect button.
  bool _isModelLoaded = false;

  @override
  void initState() {
    super.initState();
    _loadModel();
  }

  /// Loads the model and flips [_isModelLoaded] on success.
  ///
  /// Failures are logged; the widget then keeps showing the progress
  /// indicator.
  Future<void> _loadModel() async {
    try {
      await _detector.loadModel();
      if (!mounted) {
        // The widget was removed while the model was loading: dispose() has
        // already run, so release the freshly loaded model here and skip
        // setState, which must not be called on a disposed State.
        await _detector.dispose();
        return;
      }
      setState(() {
        _isModelLoaded = true;
      });
    } catch (e) {
      print('Failed to load model: $e');
    }
  }

  /// Runs detection on the bundled image at [assetPath] and shows the result
  /// in a dialog.
  ///
  /// Does nothing while the model is not loaded. Errors are logged.
  Future<void> _detectFromAsset(String assetPath) async {
    if (!_isModelLoaded) return;
    try {
      // Load image from assets (honours the path passed by the caller).
      final assetImage = AssetImage(assetPath);
      final image = await TensorImageUtils.imageProviderToImage(assetImage);

      // Run detection
      final result = await _detector.detectPlant(image);

      // Show result
      print('Detection result: $result');

      // The awaits above may outlive the widget; `context` is only valid
      // while this State is mounted.
      if (!mounted) return;

      // You can update UI here with the result
      await showDialog<void>(
        context: context,
        builder: (dialogContext) => AlertDialog(
          title: Text(result.isPlant ? 'Plant Detected' : 'Anomaly Detected'),
          content: Text(
            'Reconstruction Error: ${result.reconstructionError.toStringAsFixed(4)}\n'
            'Confidence: ${(result.confidence * 100).toStringAsFixed(2)}%',
          ),
          actions: [
            TextButton(
              onPressed: () => Navigator.pop(dialogContext),
              child: const Text('OK'),
            ),
          ],
        ),
      );
    } catch (e) {
      print('Error during detection: $e');
    }
  }

  @override
  void dispose() {
    // dispose() cannot be async, so the release is fire-and-forget; attach a
    // handler so a failure is logged instead of becoming an unhandled error.
    _detector.dispose().catchError((Object e) {
      print('Failed to dispose detector: $e');
    });
    super.dispose();
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(title: const Text('Plant Anomaly Detection')),
      body: Center(
        child: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          children: [
            if (!_isModelLoaded)
              const CircularProgressIndicator()
            else
              ElevatedButton(
                onPressed: () =>
                    _detectFromAsset('assets/images/test_plant.jpg'),
                child: const Text('Detect Plant'),
              ),
          ],
        ),
      ),
    );
  }
}