import 'dart:io';
import 'dart:typed_data';
import 'dart:ui' as ui;

import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:flutter_pytorch_lite/flutter_pytorch_lite.dart';

/// Classifies images as "plant" or "anomaly" using an autoencoder.
///
/// The image is normalized with the training statistics and passed through
/// the model. The mean squared error between that normalized input and the
/// model's reconstruction is then compared with [_threshold]; errors above
/// the threshold are reported as anomalies.
class PlantAnomalyDetector {
  Module? _module;

  /// Reconstruction-error (MSE) cut-off chosen during training.
  static const double _threshold = 0.5687;

  /// Side length, in pixels, of the square model input.
  static const int _inputSize = 224;

  /// Number of colour channels fed to the model (R, G, B).
  static const int _channels = 3;

  /// Per-channel (R, G, B) mean of the training data, applied to pixel
  /// values scaled to [0, 1].
  static const List<double> _mean = [0.4682, 0.4865, 0.3050];

  /// Per-channel (R, G, B) standard deviation of the training data.
  static const List<double> _std = [0.2064, 0.1995, 0.1961];

  /// Loads the model bundled at `assets/models/plant_anomaly_detector.ptl`.
  ///
  /// The asset is first copied to the system temp directory because
  /// [FlutterPytorchLite.load] takes a file path. Calling this again replaces
  /// the previously loaded module and destroys it. Any error is logged and
  /// rethrown.
  Future<void> loadModel() async {
    try {
      final filePath =
          '${Directory.systemTemp.path}/plant_anomaly_detector.ptl';
      final modelBytes =
          await _getBuffer('assets/models/plant_anomaly_detector.ptl');
      // Asynchronous write so a large model file does not block the UI isolate.
      await File(filePath).writeAsBytes(modelBytes, flush: true);
      final module = await FlutterPytorchLite.load(filePath);
      // Release a module from an earlier call instead of leaking it.
      await dispose();
      _module = module;
      print('Model loaded successfully');
    } catch (e) {
      print('Error loading model: $e');
      rethrow;
    }
  }

  /// Returns the raw bytes of the bundled asset [assetFileName].
  static Future<Uint8List> _getBuffer(String assetFileName) async {
    final data = await rootBundle.load(assetFileName);
    // The ByteData may be a view into a larger buffer, so honour its offset
    // and length instead of exposing the whole backing buffer.
    return data.buffer.asUint8List(data.offsetInBytes, data.lengthInBytes);
  }

  /// Resizes [image] to [_inputSize] x [_inputSize] and returns its RGB
  /// values scaled to [0, 1], in channel-major (CHW) order.
  ///
  /// The pixels are read directly so the exact scaling and layout are known,
  /// rather than relying on the defaults of
  /// `TensorImageUtils.imageToFloat32Tensor`.
  ///
  /// NOTE(review): this assumes the training pipeline was
  /// Resize(224, 224) -> ToTensor -> Normalize(mean, std). Adjust it if
  /// training cropped the image instead.
  Future<Float32List> _imageToChwFloats(ui.Image image) async {
    const size = _inputSize;
    final recorder = ui.PictureRecorder();
    final canvas = ui.Canvas(recorder);
    canvas.drawImageRect(
      image,
      ui.Rect.fromLTWH(0, 0, image.width.toDouble(), image.height.toDouble()),
      ui.Rect.fromLTWH(0, 0, size.toDouble(), size.toDouble()),
      ui.Paint()..filterQuality = ui.FilterQuality.medium,
    );
    final picture = recorder.endRecording();
    final resized = await picture.toImage(size, size);
    picture.dispose();
    try {
      // rawRgba is premultiplied. That is identical to straight RGB for
      // opaque images such as JPEGs.
      final bytes =
          await resized.toByteData(format: ui.ImageByteFormat.rawRgba);
      if (bytes == null) {
        throw StateError('Could not read pixels from image');
      }
      const pixelCount = size * size;
      final chw = Float32List(_channels * pixelCount);
      for (var p = 0; p < pixelCount; p++) {
        final base = p * 4; // 4 bytes (R, G, B, A) per pixel.
        for (var c = 0; c < _channels; c++) {
          chw[c * pixelCount + p] = bytes.getUint8(base + c) / 255.0;
        }
      }
      return chw;
    } finally {
      resized.dispose();
    }
  }

  /// Normalizes channel-major (CHW) values with [_mean] and [_std].
  ///
  /// [input] must hold all red values, then all green values, then all blue
  /// values. Throws [ArgumentError] if its length is not a multiple of 3.
  Float32List _normalize(List<double> input) {
    if (input.length % _channels != 0) {
      throw ArgumentError.value(
        input.length,
        'input.length',
        'must be a multiple of $_channels',
      );
    }
    final pixelsPerChannel = input.length ~/ _channels;
    final normalized = Float32List(input.length);
    for (var c = 0; c < _channels; c++) {
      final offset = c * pixelsPerChannel;
      for (var i = 0; i < pixelsPerChannel; i++) {
        normalized[offset + i] = (input[offset + i] - _mean[c]) / _std[c];
      }
    }
    return normalized;
  }

  /// Returns the mean squared error between [original] and [reconstructed].
  ///
  /// Throws [ArgumentError] if the lists differ in length or are empty. An
  /// empty pair would otherwise produce NaN.
  double _calculateReconstructionError(
    List<double> original,
    List<double> reconstructed,
  ) {
    if (original.length != reconstructed.length) {
      throw ArgumentError(
          'Original and reconstructed tensors must have same length');
    }
    if (original.isEmpty) {
      throw ArgumentError(
          'Cannot compute reconstruction error of empty tensors');
    }
    var sumSquaredError = 0.0;
    for (var i = 0; i < original.length; i++) {
      final diff = original[i] - reconstructed[i];
      sumSquaredError += diff * diff;
    }
    return sumSquaredError / original.length;
  }

  /// Classifies [image] as a plant or an anomaly.
  ///
  /// Throws [StateError] if [loadModel] has not completed successfully.
  /// Inference errors are logged and rethrown.
  Future<PlantDetectionResult> detectPlant(ui.Image image) async {
    final module = _module;
    if (module == null) {
      throw StateError('Model not loaded. Call loadModel() first.');
    }
    try {
      // The model must receive exactly the tensor its reconstruction is
      // compared against, so normalize first and feed the normalized values.
      final pixels = await _imageToChwFloats(image);
      final normalizedInput = _normalize(pixels);
      final inputTensor = Tensor.fromBlobFloat32(
        normalizedInput,
        Int64List.fromList([1, _channels, _inputSize, _inputSize]),
      );

      final output = await module.forward([IValue.from(inputTensor)]);
      final reconstruction = output.toTensor().dataAsFloat32List;

      final reconstructionError =
          _calculateReconstructionError(normalizedInput, reconstruction);
      final isAnomaly = reconstructionError > _threshold;
      // Relative distance from the decision boundary, capped at 1 (100%) so
      // it can be shown as a percentage.
      final confidence =
          ((reconstructionError - _threshold).abs() / _threshold)
              .clamp(0.0, 1.0)
              .toDouble();

      return PlantDetectionResult(
        isPlant: !isAnomaly,
        reconstructionError: reconstructionError,
        threshold: _threshold,
        confidence: confidence,
      );
    } catch (e) {
      print('Error during inference: $e');
      rethrow;
    }
  }

  /// Destroys the loaded module, if any. Safe to call more than once.
  Future<void> dispose() async {
    final module = _module;
    // Clear the field first so a concurrent call cannot destroy it twice.
    _module = null;
    await module?.destroy();
  }
}

/// The outcome of [PlantAnomalyDetector.detectPlant].
class PlantDetectionResult {
  /// Whether the image was judged to be a plant, i.e. its reconstruction
  /// error did not exceed [threshold].
  final bool isPlant;

  /// Mean squared error between the normalized input and the reconstruction.
  final double reconstructionError;

  /// The decision threshold that [reconstructionError] was compared with.
  final double threshold;

  /// Relative distance of [reconstructionError] from [threshold]. As
  /// produced by `detectPlant`, it lies in the range 0 to 1.
  final double confidence;

  const PlantDetectionResult({
    required this.isPlant,
    required this.reconstructionError,
    required this.threshold,
    required this.confidence,
  });

  @override
  String toString() {
    return 'PlantDetectionResult('
        'isPlant: $isPlant, '
        'reconstructionError: ${reconstructionError.toStringAsFixed(4)}, '
        'threshold: ${threshold.toStringAsFixed(4)}, '
        'confidence: ${(confidence * 100).toStringAsFixed(2)}%'
        ')';
  }
}
/// Example usage in a Flutter widget.
///
/// Loads a [PlantAnomalyDetector] when the widget is first built. The button
/// runs detection on a bundled test image and shows the result in a dialog.
class PlantDetectionWidget extends StatefulWidget {
  const PlantDetectionWidget({super.key});

  @override
  State<PlantDetectionWidget> createState() => _PlantDetectionWidgetState();
}

class _PlantDetectionWidgetState extends State<PlantDetectionWidget> {
  final PlantAnomalyDetector _detector = PlantAnomalyDetector();
  bool _isModelLoaded = false;

  /// Message describing why model loading failed, or null if it has not.
  String? _loadError;

  @override
  void initState() {
    super.initState();
    _loadModel();
  }

  /// Loads the model and updates the UI when loading finishes or fails.
  Future<void> _loadModel() async {
    try {
      await _detector.loadModel();
      // The widget may have been disposed while the model was loading.
      if (!mounted) return;
      setState(() => _isModelLoaded = true);
    } catch (e) {
      print('Failed to load model: $e');
      if (!mounted) return;
      // Without this, the progress indicator would spin forever.
      setState(() => _loadError = '$e');
    }
  }

  /// Runs detection on the image asset at [assetPath] and shows the result.
  Future<void> _detectFromAsset(String assetPath) async {
    if (!_isModelLoaded) return;

    try {
      final image = await TensorImageUtils.imageProviderToImage(
        AssetImage(assetPath),
      );
      final result = await _detector.detectPlant(image);
      print('Detection result: $result');

      // The widget may have been disposed while awaiting; using its context
      // after that would be an error.
      if (!mounted) return;
      await showDialog<void>(
        context: context,
        builder: (dialogContext) => AlertDialog(
          title: Text(result.isPlant ? 'Plant Detected' : 'Anomaly Detected'),
          content: Text(
            'Reconstruction Error: ${result.reconstructionError.toStringAsFixed(4)}\n'
            'Confidence: ${(result.confidence * 100).toStringAsFixed(2)}%',
          ),
          actions: [
            TextButton(
              onPressed: () => Navigator.pop(dialogContext),
              child: const Text('OK'),
            ),
          ],
        ),
      );
    } catch (e) {
      print('Error during detection: $e');
    }
  }

  @override
  void dispose() {
    // State.dispose is synchronous, so releasing the native module here is
    // fire-and-forget.
    _detector.dispose();
    super.dispose();
  }

  @override
  Widget build(BuildContext context) {
    final loadError = _loadError;
    return Scaffold(
      appBar: AppBar(title: const Text('Plant Anomaly Detection')),
      body: Center(
        child: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          children: [
            if (loadError != null)
              Text('Failed to load model: $loadError')
            else if (!_isModelLoaded)
              const CircularProgressIndicator()
            else
              ElevatedButton(
                onPressed: () =>
                    _detectFromAsset('assets/images/test_plant.jpg'),
                child: const Text('Detect Plant'),
              ),
          ],
        ),
      ),
    );
  }
}