recognitions = new ArrayList<>();
+ int recognitionsSize = min(pq.size(), MAX_RESULTS);
+ for (int i = 0; i < recognitionsSize; ++i) {
+ recognitions.add(pq.poll());
+ }
+ return recognitions;
+ }
+
+ /** Gets the name of the model file stored in Assets. */
+ protected abstract String getModelPath();
+
+ /** Gets the name of the label file stored in Assets. */
+ protected abstract String getLabelPath();
+
+ /** Gets the TensorOperator to normalize the input image in preprocessing. */
+ protected abstract TensorOperator getPreprocessNormalizeOp();
+
+ /**
+ * Gets the TensorOperator to dequantize the output probability in post processing.
+ *
+ * For a quantized model, we need to dequantize the prediction with NormalizeOp (dequantization
+ * is essentially a linear transformation). For a float model, dequantization is not required,
+ * but to keep the API uniform it is applied to the float model too, with mean and std set to
+ * 0.0f and 1.0f, respectively.
+ */
+ protected abstract TensorOperator getPostprocessNormalizeOp();
+}
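Editor's note: the postprocess op described above is just a linear rescale. Below is a minimal, standalone sketch (plain Java, not TFLite Support Library code; the class and method names are illustrative) of what `NormalizeOp(mean, stddev)` does to a raw output value under the quantized and float settings used by the concrete classifiers that follow.

```java
// Sketch only: NormalizeOp(mean, stddev) maps x -> (x - mean) / stddev.
public final class DequantizeSketch {
  static float normalize(float x, float mean, float stddev) {
    return (x - mean) / stddev;
  }

  public static void main(String[] args) {
    // Quantized classifiers use mean 0.0f, stddev 255.0f: a raw uint8 score of 255 -> 1.0f.
    System.out.println(normalize(255f, 0.0f, 255.0f)); // 1.0
    // Float classifiers use mean 0.0f, stddev 1.0f: the probability passes through unchanged.
    System.out.println(normalize(0.87f, 0.0f, 1.0f)); // 0.87
  }
}
```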
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java
new file mode 100644
index 0000000000000000000000000000000000000000..14dd027b26baefaedd979a8ac37f0bf984210ed4
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java
@@ -0,0 +1,71 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.examples.classification.tflite;
+
+import android.app.Activity;
+import java.io.IOException;
+import org.tensorflow.lite.examples.classification.tflite.Classifier.Device;
+import org.tensorflow.lite.support.common.TensorOperator;
+import org.tensorflow.lite.support.common.ops.NormalizeOp;
+
+/** This TensorFlowLite classifier works with the float EfficientNet model. */
+public class ClassifierFloatEfficientNet extends Classifier {
+
+ private static final float IMAGE_MEAN = 115.0f; //127.0f;
+ private static final float IMAGE_STD = 58.0f; //128.0f;
+
+ /**
+ * The float model does not need dequantization in post-processing. Mean and std are set to 0.0f
+ * and 1.0f, respectively, to bypass the normalization.
+ */
+ private static final float PROBABILITY_MEAN = 0.0f;
+
+ private static final float PROBABILITY_STD = 1.0f;
+
+ /**
+ * Initializes a {@code ClassifierFloatEfficientNet}.
+ *
+ * @param activity the current Activity
+ */
+ public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads)
+ throws IOException {
+ super(activity, device, numThreads);
+ }
+
+ @Override
+ protected String getModelPath() {
+ // See build.gradle for where to obtain this file. It should be automatically
+ // downloaded into assets.
+ //return "efficientnet-lite0-fp32.tflite";
+ return "model_opt.tflite";
+ }
+
+ @Override
+ protected String getLabelPath() {
+ return "labels_without_background.txt";
+ }
+
+ @Override
+ protected TensorOperator getPreprocessNormalizeOp() {
+ return new NormalizeOp(IMAGE_MEAN, IMAGE_STD);
+ }
+
+ @Override
+ protected TensorOperator getPostprocessNormalizeOp() {
+ return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD);
+ }
+}
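Editor's note: the preprocessing constants above (IMAGE_MEAN = 115.0f, IMAGE_STD = 58.0f) mean each input channel value is mapped as (pixel - 115) / 58 before being fed to model_opt.tflite. A short illustrative calculation (plain Java, class name is hypothetical, not part of the app):

```java
// Illustration only: the effect of NormalizeOp(115.0f, 58.0f) on a few input pixel values.
public final class PreprocessSketch {
  public static void main(String[] args) {
    float mean = 115.0f;
    float stddev = 58.0f;
    for (float pixel : new float[] {0f, 115f, 255f}) {
      System.out.printf("%.0f -> %.3f%n", pixel, (pixel - mean) / stddev);
    }
    // Prints: 0 -> -1.983, 115 -> 0.000, 255 -> 2.414
  }
}
```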
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java
new file mode 100644
index 0000000000000000000000000000000000000000..40519de07cf5e887773250a4609a832b6060d684
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java
@@ -0,0 +1,72 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.examples.classification.tflite;
+
+import android.app.Activity;
+import java.io.IOException;
+import org.tensorflow.lite.examples.classification.tflite.Classifier.Device;
+import org.tensorflow.lite.support.common.TensorOperator;
+import org.tensorflow.lite.support.common.ops.NormalizeOp;
+
+/** This TensorFlowLite classifier works with the float MobileNet model. */
+public class ClassifierFloatMobileNet extends Classifier {
+
+ /** Float MobileNet requires additional normalization of the used input. */
+ private static final float IMAGE_MEAN = 127.5f;
+
+ private static final float IMAGE_STD = 127.5f;
+
+ /**
+ * The float model does not need dequantization in post-processing. Mean and std are set to 0.0f
+ * and 1.0f, respectively, to bypass the normalization.
+ */
+ private static final float PROBABILITY_MEAN = 0.0f;
+
+ private static final float PROBABILITY_STD = 1.0f;
+
+ /**
+ * Initializes a {@code ClassifierFloatMobileNet}.
+ *
+ * @param activity the current Activity
+ */
+ public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads)
+ throws IOException {
+ super(activity, device, numThreads);
+ }
+
+ @Override
+ protected String getModelPath() {
+ // See build.gradle for where to obtain this file. It should be automatically
+ // downloaded into assets.
+ return "model_0.tflite";
+ }
+
+ @Override
+ protected String getLabelPath() {
+ return "labels.txt";
+ }
+
+ @Override
+ protected TensorOperator getPreprocessNormalizeOp() {
+ return new NormalizeOp(IMAGE_MEAN, IMAGE_STD);
+ }
+
+ @Override
+ protected TensorOperator getPostprocessNormalizeOp() {
+ return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD);
+ }
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java
new file mode 100644
index 0000000000000000000000000000000000000000..d0d62f58d18333b6360ec30a4c85c9f1d38955ce
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java
@@ -0,0 +1,71 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.examples.classification.tflite;
+
+import android.app.Activity;
+import java.io.IOException;
+import org.tensorflow.lite.support.common.TensorOperator;
+import org.tensorflow.lite.support.common.ops.NormalizeOp;
+
+/** This TensorFlow Lite classifier works with the quantized EfficientNet model. */
+public class ClassifierQuantizedEfficientNet extends Classifier {
+
+ /**
+ * The quantized model does not require normalization, so mean is set to 0.0f and std to 1.0f to
+ * bypass the normalization.
+ */
+ private static final float IMAGE_MEAN = 0.0f;
+
+ private static final float IMAGE_STD = 1.0f;
+
+ /** Quantized EfficientNet requires additional dequantization of the output probability. */
+ private static final float PROBABILITY_MEAN = 0.0f;
+
+ private static final float PROBABILITY_STD = 255.0f;
+
+ /**
+ * Initializes a {@code ClassifierQuantizedEfficientNet}.
+ *
+ * @param activity the current Activity
+ */
+ public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads)
+ throws IOException {
+ super(activity, device, numThreads);
+ }
+
+ @Override
+ protected String getModelPath() {
+ // See build.gradle for where to obtain this file. It should be automatically
+ // downloaded into assets.
+ return "model_quant.tflite";
+ }
+
+ @Override
+ protected String getLabelPath() {
+ return "labels_without_background.txt";
+ }
+
+ @Override
+ protected TensorOperator getPreprocessNormalizeOp() {
+ return new NormalizeOp(IMAGE_MEAN, IMAGE_STD);
+ }
+
+ @Override
+ protected TensorOperator getPostprocessNormalizeOp() {
+ return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD);
+ }
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java
new file mode 100644
index 0000000000000000000000000000000000000000..94b06e3df659005c287733a8a37672863fdadd71
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.examples.classification.tflite;
+
+import android.app.Activity;
+import java.io.IOException;
+import org.tensorflow.lite.examples.classification.tflite.Classifier.Device;
+import org.tensorflow.lite.support.common.TensorOperator;
+import org.tensorflow.lite.support.common.ops.NormalizeOp;
+
+/** This TensorFlow Lite classifier works with the quantized MobileNet model. */
+public class ClassifierQuantizedMobileNet extends Classifier {
+
+ /**
+ * The quantized model does not require normalization, so mean is set to 0.0f and std to 1.0f to
+ * bypass the normalization.
+ */
+ private static final float IMAGE_MEAN = 0.0f;
+
+ private static final float IMAGE_STD = 1.0f;
+
+ /** Quantized MobileNet requires additional dequantization of the output probability. */
+ private static final float PROBABILITY_MEAN = 0.0f;
+
+ private static final float PROBABILITY_STD = 255.0f;
+
+ /**
+ * Initializes a {@code ClassifierQuantizedMobileNet}.
+ *
+ * @param activity the current Activity
+ */
+ public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads)
+ throws IOException {
+ super(activity, device, numThreads);
+ }
+
+ @Override
+ protected String getModelPath() {
+ // See build.gradle for where to obtain this file. It should be automatically
+ // downloaded into assets.
+ return "model_quant_0.tflite";
+ }
+
+ @Override
+ protected String getLabelPath() {
+ return "labels.txt";
+ }
+
+ @Override
+ protected TensorOperator getPreprocessNormalizeOp() {
+ return new NormalizeOp(IMAGE_MEAN, IMAGE_STD);
+ }
+
+ @Override
+ protected TensorOperator getPostprocessNormalizeOp() {
+ return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD);
+ }
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle
new file mode 100644
index 0000000000000000000000000000000000000000..b5983986e3d56a77a41676b9195b0d0882b5fb96
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle
@@ -0,0 +1,47 @@
+apply plugin: 'com.android.library'
+
+android {
+ compileSdkVersion 28
+ buildToolsVersion "28.0.0"
+
+ defaultConfig {
+ minSdkVersion 21
+ targetSdkVersion 28
+ versionCode 1
+ versionName "1.0"
+
+ testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
+
+ }
+
+ buildTypes {
+ release {
+ minifyEnabled false
+ proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
+ }
+ }
+ compileOptions {
+ sourceCompatibility = '1.8'
+ targetCompatibility = '1.8'
+ }
+ aaptOptions {
+ noCompress "tflite"
+ }
+
+ lintOptions {
+ checkReleaseBuilds false
+ // Or, if you prefer, you can continue to check for errors in release builds,
+ // but continue the build even when errors are found:
+ abortOnError false
+ }
+}
+
+dependencies {
+ implementation fileTree(dir: 'libs', include: ['*.jar'])
+ implementation project(":models")
+ implementation 'androidx.appcompat:appcompat:1.1.0'
+
+ // Build off of nightly TensorFlow Lite Task Library
+ implementation('org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly') { changing = true }
+ implementation('org.tensorflow:tensorflow-lite-metadata:0.0.0-nightly') { changing = true }
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro
new file mode 100644
index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro
@@ -0,0 +1,21 @@
+# Add project specific ProGuard rules here.
+# You can control the set of applied configuration files using the
+# proguardFiles setting in build.gradle.
+#
+# For more details, see
+# http://developer.android.com/guide/developing/tools/proguard.html
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+# public *;
+#}
+
+# Uncomment this to preserve the line number information for
+# debugging stack traces.
+#-keepattributes SourceFile,LineNumberTable
+
+# If you keep the line number information, uncomment this to
+# hide the original source file name.
+#-renamesourcefileattribute SourceFile
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml
new file mode 100644
index 0000000000000000000000000000000000000000..ebe3c56c60a9b67eec218d969aecfdf5311d7b49
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml
@@ -0,0 +1,3 @@
+
+
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java
new file mode 100644
index 0000000000000000000000000000000000000000..45da52a0d0dfa203255e0f2d44901ee0618e739f
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java
@@ -0,0 +1,278 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.examples.classification.tflite;
+
+import static java.lang.Math.min;
+
+import android.app.Activity;
+import android.graphics.Bitmap;
+import android.graphics.Rect;
+import android.graphics.RectF;
+import android.os.SystemClock;
+import android.os.Trace;
+import android.util.Log;
+import java.io.IOException;
+import java.nio.MappedByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import org.tensorflow.lite.examples.classification.tflite.Classifier.Device;
+import org.tensorflow.lite.support.common.FileUtil;
+import org.tensorflow.lite.support.image.TensorImage;
+import org.tensorflow.lite.support.label.Category;
+import org.tensorflow.lite.support.metadata.MetadataExtractor;
+import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
+import org.tensorflow.lite.task.core.vision.ImageProcessingOptions.Orientation;
+import org.tensorflow.lite.task.vision.classifier.Classifications;
+import org.tensorflow.lite.task.vision.classifier.ImageClassifier;
+import org.tensorflow.lite.task.vision.classifier.ImageClassifier.ImageClassifierOptions;
+
+/** A classifier specialized to label images using TensorFlow Lite. */
+public abstract class Classifier {
+ public static final String TAG = "ClassifierWithTaskApi";
+
+ /** The model type used for classification. */
+ public enum Model {
+ FLOAT_MOBILENET,
+ QUANTIZED_MOBILENET,
+ FLOAT_EFFICIENTNET,
+ QUANTIZED_EFFICIENTNET
+ }
+
+ /** The runtime device type used for executing classification. */
+ public enum Device {
+ CPU,
+ NNAPI,
+ GPU
+ }
+
+ /** Number of results to show in the UI. */
+ private static final int MAX_RESULTS = 3;
+
+ /** Image size along the x axis. */
+ private final int imageSizeX;
+
+ /** Image size along the y axis. */
+ private final int imageSizeY;
+ /** An instance of the driver class to run model inference with TensorFlow Lite. */
+ protected final ImageClassifier imageClassifier;
+
+ /**
+ * Creates a classifier with the provided configuration.
+ *
+ * @param activity The current Activity.
+ * @param model The model to use for classification.
+ * @param device The device to use for classification.
+ * @param numThreads The number of threads to use for classification.
+ * @return A classifier with the desired configuration.
+ */
+ public static Classifier create(Activity activity, Model model, Device device, int numThreads)
+ throws IOException {
+ if (model == Model.QUANTIZED_MOBILENET) {
+ return new ClassifierQuantizedMobileNet(activity, device, numThreads);
+ } else if (model == Model.FLOAT_MOBILENET) {
+ return new ClassifierFloatMobileNet(activity, device, numThreads);
+ } else if (model == Model.FLOAT_EFFICIENTNET) {
+ return new ClassifierFloatEfficientNet(activity, device, numThreads);
+ } else if (model == Model.QUANTIZED_EFFICIENTNET) {
+ return new ClassifierQuantizedEfficientNet(activity, device, numThreads);
+ } else {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ /** An immutable result returned by a Classifier describing what was recognized. */
+ public static class Recognition {
+ /**
+ * A unique identifier for what has been recognized. Specific to the class, not the instance of
+ * the object.
+ */
+ private final String id;
+
+ /** Display name for the recognition. */
+ private final String title;
+
+ /**
+ * A sortable score for how good the recognition is relative to others. Higher should be better.
+ */
+ private final Float confidence;
+
+ /** Optional location within the source image for the location of the recognized object. */
+ private RectF location;
+
+ public Recognition(
+ final String id, final String title, final Float confidence, final RectF location) {
+ this.id = id;
+ this.title = title;
+ this.confidence = confidence;
+ this.location = location;
+ }
+
+ public String getId() {
+ return id;
+ }
+
+ public String getTitle() {
+ return title;
+ }
+
+ public Float getConfidence() {
+ return confidence;
+ }
+
+ public RectF getLocation() {
+ return new RectF(location);
+ }
+
+ public void setLocation(RectF location) {
+ this.location = location;
+ }
+
+ @Override
+ public String toString() {
+ String resultString = "";
+ if (id != null) {
+ resultString += "[" + id + "] ";
+ }
+
+ if (title != null) {
+ resultString += title + " ";
+ }
+
+ if (confidence != null) {
+ resultString += String.format("(%.1f%%) ", confidence * 100.0f);
+ }
+
+ if (location != null) {
+ resultString += location + " ";
+ }
+
+ return resultString.trim();
+ }
+ }
+
+ /** Initializes a {@code Classifier}. */
+ protected Classifier(Activity activity, Device device, int numThreads) throws IOException {
+ if (device != Device.CPU || numThreads != 1) {
+ throw new IllegalArgumentException(
+ "Manipulating the hardware accelerators and numbers of threads is not allowed in the Task"
+ + " library currently. Only CPU + single thread is allowed.");
+ }
+
+ // Create the ImageClassifier instance.
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder().setMaxResults(MAX_RESULTS).build();
+ imageClassifier = ImageClassifier.createFromFileAndOptions(activity, getModelPath(), options);
+ Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
+
+ // Get the input image size information of the underlying tflite model.
+ MappedByteBuffer tfliteModel = FileUtil.loadMappedFile(activity, getModelPath());
+ MetadataExtractor metadataExtractor = new MetadataExtractor(tfliteModel);
+ // Image shape is in the format of {1, height, width, 3}.
+ int[] imageShape = metadataExtractor.getInputTensorShape(/*inputIndex=*/ 0);
+ imageSizeY = imageShape[1];
+ imageSizeX = imageShape[2];
+ }
+
+ /** Runs inference and returns the classification results. */
+ public List<Recognition> recognizeImage(final Bitmap bitmap, int sensorOrientation) {
+ // Logs this method so that it can be analyzed with systrace.
+ Trace.beginSection("recognizeImage");
+
+ TensorImage inputImage = TensorImage.fromBitmap(bitmap);
+ int width = bitmap.getWidth();
+ int height = bitmap.getHeight();
+ int cropSize = min(width, height);
+ // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy.
+ // Task Library resize the images using bilinear interpolation, which is slightly different from
+ // the nearest neighbor sampling algorithm used in lib_support. See
+ // https://github.com/tensorflow/examples/blob/0ef3d93e2af95d325c70ef3bcbbd6844d0631e07/lite/examples/image_classification/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java#L310.
+ ImageProcessingOptions imageOptions =
+ ImageProcessingOptions.builder()
+ .setOrientation(getOrientation(sensorOrientation))
+ // Set the ROI to the center of the image.
+ .setRoi(
+ new Rect(
+ /*left=*/ (width - cropSize) / 2,
+ /*top=*/ (height - cropSize) / 2,
+ /*right=*/ (width + cropSize) / 2,
+ /*bottom=*/ (height + cropSize) / 2))
+ .build();
+
+ // Runs the inference call.
+ Trace.beginSection("runInference");
+ long startTimeForReference = SystemClock.uptimeMillis();
+ List<Classifications> results = imageClassifier.classify(inputImage, imageOptions);
+ long endTimeForReference = SystemClock.uptimeMillis();
+ Trace.endSection();
+ Log.v(TAG, "Timecost to run model inference: " + (endTimeForReference - startTimeForReference));
+
+ Trace.endSection();
+
+ return getRecognitions(results);
+ }
+
+ /** Closes the interpreter and model to release resources. */
+ public void close() {
+ if (imageClassifier != null) {
+ imageClassifier.close();
+ }
+ }
+
+ /** Get the image size along the x axis. */
+ public int getImageSizeX() {
+ return imageSizeX;
+ }
+
+ /** Get the image size along the y axis. */
+ public int getImageSizeY() {
+ return imageSizeY;
+ }
+
+ /**
+ * Converts a list of {@link Classifications} objects into a list of {@link Recognition} objects
+ * to match the interface of other inference methods, such as the one using the TFLite
+ * Support Library.
+ */
+ private static List<Recognition> getRecognitions(List<Classifications> classifications) {
+
+ final ArrayList<Recognition> recognitions = new ArrayList<>();
+ // All the demo models are single head models. Get the first Classifications in the results.
+ for (Category category : classifications.get(0).getCategories()) {
+ recognitions.add(
+ new Recognition(
+ "" + category.getLabel(), category.getLabel(), category.getScore(), null));
+ }
+ return recognitions;
+ }
+
+ /** Converts the camera orientation in degrees into {@link ImageProcessingOptions#Orientation}. */
+ private static Orientation getOrientation(int cameraOrientation) {
+ switch (cameraOrientation / 90) {
+ case 3:
+ return Orientation.BOTTOM_LEFT;
+ case 2:
+ return Orientation.BOTTOM_RIGHT;
+ case 1:
+ return Orientation.TOP_RIGHT;
+ default:
+ return Orientation.TOP_LEFT;
+ }
+ }
+
+ /** Gets the name of the model file stored in Assets. */
+ protected abstract String getModelPath();
+}
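Editor's note: a minimal usage sketch of this Task Library classifier, assuming it is called from an Activity that already holds a Bitmap and the sensor orientation. The class and method names (`ClassifierUsageSketch`, `classifyOnce`) are illustrative and not part of the example app; the Task Library build above only supports CPU with a single thread.

```java
import android.app.Activity;
import android.graphics.Bitmap;
import java.io.IOException;
import java.util.List;
import org.tensorflow.lite.examples.classification.tflite.Classifier;
import org.tensorflow.lite.examples.classification.tflite.Classifier.Device;
import org.tensorflow.lite.examples.classification.tflite.Classifier.Model;
import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition;

final class ClassifierUsageSketch {
  /** Runs one classification pass and logs the top results. */
  static void classifyOnce(Activity activity, Bitmap bitmap, int sensorOrientation)
      throws IOException {
    Classifier classifier =
        Classifier.create(activity, Model.FLOAT_EFFICIENTNET, Device.CPU, /* numThreads= */ 1);
    List<Recognition> results = classifier.recognizeImage(bitmap, sensorOrientation);
    for (Recognition recognition : results) {
      android.util.Log.d("ClassifierUsage", recognition.toString());
    }
    classifier.close();
  }
}
```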
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java
new file mode 100644
index 0000000000000000000000000000000000000000..250794cc12d0e603aa47502322dc646d50689848
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java
@@ -0,0 +1,45 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.examples.classification.tflite;
+
+import android.app.Activity;
+import java.io.IOException;
+import org.tensorflow.lite.examples.classification.tflite.Classifier.Device;
+
+/** This TensorFlowLite classifier works with the float EfficientNet model. */
+public class ClassifierFloatEfficientNet extends Classifier {
+
+ /**
+ * Initializes a {@code ClassifierFloatEfficientNet}.
+ *
+ * @param device a {@link Device} object to configure the hardware accelerator
+ * @param numThreads the number of threads during the inference
+ * @throws IOException if the model is not loaded correctly
+ */
+ public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads)
+ throws IOException {
+ super(activity, device, numThreads);
+ }
+
+ @Override
+ protected String getModelPath() {
+ // See build.gradle for where to obtain this file. It should be automatically
+ // downloaded into assets.
+ //return "efficientnet-lite0-fp32.tflite";
+ return "model.tflite";
+ }
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java
new file mode 100644
index 0000000000000000000000000000000000000000..0707de98de41395eaf3ddcfd74d6e36229a63760
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java
@@ -0,0 +1,43 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.examples.classification.tflite;
+
+import android.app.Activity;
+import java.io.IOException;
+import org.tensorflow.lite.examples.classification.tflite.Classifier.Device;
+
+/** This TensorFlowLite classifier works with the float MobileNet model. */
+public class ClassifierFloatMobileNet extends Classifier {
+ /**
+ * Initializes a {@code ClassifierFloatMobileNet}.
+ *
+ * @param device a {@link Device} object to configure the hardware accelerator
+ * @param numThreads the number of threads during the inference
+ * @throws IOException if the model is not loaded correctly
+ */
+ public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads)
+ throws IOException {
+ super(activity, device, numThreads);
+ }
+
+ @Override
+ protected String getModelPath() {
+ // See build.gradle for where to obtain this file. It should be automatically
+ // downloaded into assets.
+ return "mobilenet_v1_1.0_224.tflite";
+ }
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java
new file mode 100644
index 0000000000000000000000000000000000000000..05ca4fa6c409d0274a396c9b26c3c39ca8a8194e
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.examples.classification.tflite;
+
+import android.app.Activity;
+import java.io.IOException;
+
+/** This TensorFlow Lite classifier works with the quantized EfficientNet model. */
+public class ClassifierQuantizedEfficientNet extends Classifier {
+
+ /**
+ * Initializes a {@code ClassifierQuantizedEfficientNet}.
+ *
+ * @param device a {@link Device} object to configure the hardware accelerator
+ * @param numThreads the number of threads during the inference
+ * @throws IOException if the model is not loaded correctly
+ */
+ public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads)
+ throws IOException {
+ super(activity, device, numThreads);
+ }
+
+ @Override
+ protected String getModelPath() {
+ // See build.gradle for where to obtain this file. It should be automatically
+ // downloaded into assets.
+ return "efficientnet-lite0-int8.tflite";
+ }
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java
new file mode 100644
index 0000000000000000000000000000000000000000..978b08eeaf52a23eede437d61045db08d1dff163
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.examples.classification.tflite;
+
+import android.app.Activity;
+import java.io.IOException;
+import org.tensorflow.lite.examples.classification.tflite.Classifier.Device;
+
+/** This TensorFlow Lite classifier works with the quantized MobileNet model. */
+public class ClassifierQuantizedMobileNet extends Classifier {
+
+ /**
+ * Initializes a {@code ClassifierQuantizedMobileNet}.
+ *
+ * @param device a {@link Device} object to configure the hardware accelerator
+ * @param numThreads the number of threads during the inference
+ * @throws IOException if the model is not loaded correctly
+ */
+ public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads)
+ throws IOException {
+ super(activity, device, numThreads);
+ }
+
+ @Override
+ protected String getModelPath() {
+ // See build.gradle for where to obtain this file. It should be automatically
+ // downloaded into assets.
+ return "mobilenet_v1_1.0_224_quant.tflite";
+ }
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle
new file mode 100644
index 0000000000000000000000000000000000000000..8d825707af20cbbead6c4599f075599148e3511c
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle
@@ -0,0 +1,40 @@
+apply plugin: 'com.android.library'
+apply plugin: 'de.undercouch.download'
+
+android {
+ compileSdkVersion 28
+ buildToolsVersion "28.0.0"
+
+ defaultConfig {
+ minSdkVersion 21
+ targetSdkVersion 28
+ versionCode 1
+ versionName "1.0"
+
+ testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
+
+ }
+
+ buildTypes {
+ release {
+ minifyEnabled false
+ proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
+ }
+ }
+
+ aaptOptions {
+ noCompress "tflite"
+ }
+
+ lintOptions {
+ checkReleaseBuilds false
+ // Or, if you prefer, you can continue to check for errors in release builds,
+ // but continue the build even when errors are found:
+ abortOnError false
+ }
+}
+
+// Download default models; if you wish to use your own models then
+// place them in the "assets" directory and comment out the "apply from" line below.
+project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets'
+apply from:'download.gradle'
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle
new file mode 100644
index 0000000000000000000000000000000000000000..ce76974a2c3bc6f8214461028e0dfa6ebc25d588
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle
@@ -0,0 +1,10 @@
+def modelFloatDownloadUrl = "https://github.com/isl-org/MiDaS/releases/download/v2_1/model_opt.tflite"
+def modelFloatFile = "model_opt.tflite"
+
+task downloadModelFloat(type: Download) {
+ src "${modelFloatDownloadUrl}"
+ dest project.ext.ASSET_DIR + "/${modelFloatFile}"
+ overwrite false
+}
+
+preBuild.dependsOn downloadModelFloat
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro
new file mode 100644
index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro
@@ -0,0 +1,21 @@
+# Add project specific ProGuard rules here.
+# You can control the set of applied configuration files using the
+# proguardFiles setting in build.gradle.
+#
+# For more details, see
+# http://developer.android.com/guide/developing/tools/proguard.html
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+# public *;
+#}
+
+# Uncomment this to preserve the line number information for
+# debugging stack traces.
+#-keepattributes SourceFile,LineNumberTable
+
+# If you keep the line number information, uncomment this to
+# hide the original source file name.
+#-renamesourcefileattribute SourceFile
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml
new file mode 100644
index 0000000000000000000000000000000000000000..42951a56497c5f947efe4aea6a07462019fb152c
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml
@@ -0,0 +1,3 @@
+
+
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle
new file mode 100644
index 0000000000000000000000000000000000000000..e86d89d2483f92b7e778589011fad60fbba3a318
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle
@@ -0,0 +1,2 @@
+rootProject.name = 'TFLite Image Classification Demo App'
+include ':app', ':lib_support', ':lib_task_api', ':models'
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..f1150e3379e4a38d31ca7bb46dc4f31d79f482c2
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore
@@ -0,0 +1,2 @@
+# ignore model file
+#*.tflite
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj
new file mode 100644
index 0000000000000000000000000000000000000000..4917371aa33a65fdfc66c02d914f05489c446430
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj
@@ -0,0 +1,538 @@
+// !$*UTF8*$!
+{
+ archiveVersion = 1;
+ classes = {
+ };
+ objectVersion = 50;
+ objects = {
+
+/* Begin PBXBuildFile section */
+ 0CDA8C85042ADF65D0787629 /* Pods_Midas.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = A1CE41C09920CCEC31985547 /* Pods_Midas.framework */; };
+ 8402440123D9834600704ABD /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = 8402440023D9834600704ABD /* README.md */; };
+ 840ECB20238BAA2300C7D88A /* InfoCell.swift in Sources */ = {isa = PBXBuildFile; fileRef = 840ECB1F238BAA2300C7D88A /* InfoCell.swift */; };
+ 840EDCFD2341DDD30017ED42 /* Launch Screen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */; };
+ 840EDD022341DE380017ED42 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 840EDD002341DE380017ED42 /* Main.storyboard */; };
+ 842DDB6E2372A82000F6BB94 /* OverlayView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 842DDB6D2372A82000F6BB94 /* OverlayView.swift */; };
+ 846499C2235DAB0D009CBBC7 /* ModelDataHandler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */; };
+ 846BAF7623E7FE13006FC136 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = 846BAF7523E7FE13006FC136 /* Constants.swift */; };
+ 8474FEC92341D36E00377D34 /* PreviewView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8474FEC82341D36E00377D34 /* PreviewView.swift */; };
+ 8474FECB2341D39800377D34 /* CameraFeedManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8474FECA2341D39800377D34 /* CameraFeedManager.swift */; };
+ 84952CB5236186BE0052C104 /* CVPixelBufferExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */; };
+ 84952CB92361874A0052C104 /* TFLiteExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84952CB82361874A0052C104 /* TFLiteExtension.swift */; };
+ 84B67CEF2326338300A11A08 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84B67CEE2326338300A11A08 /* AppDelegate.swift */; };
+ 84B67CF12326338300A11A08 /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84B67CF02326338300A11A08 /* ViewController.swift */; };
+ 84B67CF62326338400A11A08 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 84B67CF52326338400A11A08 /* Assets.xcassets */; };
+ 84D6576D2387BB7E0048171E /* CGSizeExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */; };
+ 84F232D5254C831E0011862E /* model_opt.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 84F232D4254C831E0011862E /* model_opt.tflite */; };
+ 84FCF5922387BD7900663812 /* tfl_logo.png in Resources */ = {isa = PBXBuildFile; fileRef = 84FCF5912387BD7900663812 /* tfl_logo.png */; };
+/* End PBXBuildFile section */
+
+/* Begin PBXFileReference section */
+ 8402440023D9834600704ABD /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
+ 840ECB1F238BAA2300C7D88A /* InfoCell.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InfoCell.swift; sourceTree = "<group>"; };
+ 840EDCFC2341DDD30017ED42 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = "Base.lproj/Launch Screen.storyboard"; sourceTree = "<group>"; };
+ 840EDD012341DE380017ED42 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = "<group>"; };
+ 842DDB6D2372A82000F6BB94 /* OverlayView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OverlayView.swift; sourceTree = "<group>"; };
+ 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelDataHandler.swift; sourceTree = "<group>"; };
+ 846BAF7523E7FE13006FC136 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = "<group>"; };
+ 8474FEC82341D36E00377D34 /* PreviewView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PreviewView.swift; sourceTree = "<group>"; };
+ 8474FECA2341D39800377D34 /* CameraFeedManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CameraFeedManager.swift; sourceTree = "<group>"; };
+ 84884291236FF0A30043FC4C /* download_models.sh */ = {isa = PBXFileReference; lastKnownFileType = text.script.sh; path = download_models.sh; sourceTree = "<group>"; };
+ 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CVPixelBufferExtension.swift; sourceTree = "<group>"; };
+ 84952CB82361874A0052C104 /* TFLiteExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TFLiteExtension.swift; sourceTree = "<group>"; };
+ 84B67CEB2326338300A11A08 /* Midas.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = Midas.app; sourceTree = BUILT_PRODUCTS_DIR; };
+ 84B67CEE2326338300A11A08 /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = "<group>"; };
+ 84B67CF02326338300A11A08 /* ViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ViewController.swift; sourceTree = "<group>"; };
+ 84B67CF52326338400A11A08 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
+ 84B67CFA2326338400A11A08 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
+ 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CGSizeExtension.swift; sourceTree = "<group>"; };
+ 84F232D4254C831E0011862E /* model_opt.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = model_opt.tflite; sourceTree = "<group>"; };
+ 84FCF5912387BD7900663812 /* tfl_logo.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; name = tfl_logo.png; path = Assets.xcassets/tfl_logo.png; sourceTree = "<group>"; };
+ A1CE41C09920CCEC31985547 /* Pods_Midas.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_Midas.framework; sourceTree = BUILT_PRODUCTS_DIR; };
+ D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Midas.release.xcconfig"; path = "Target Support Files/Pods-Midas/Pods-Midas.release.xcconfig"; sourceTree = "<group>"; };
+ FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Midas.debug.xcconfig"; path = "Target Support Files/Pods-Midas/Pods-Midas.debug.xcconfig"; sourceTree = "<group>"; };
+/* End PBXFileReference section */
+
+/* Begin PBXFrameworksBuildPhase section */
+ 84B67CE82326338300A11A08 /* Frameworks */ = {
+ isa = PBXFrameworksBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ 0CDA8C85042ADF65D0787629 /* Pods_Midas.framework in Frameworks */,
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ };
+/* End PBXFrameworksBuildPhase section */
+
+/* Begin PBXGroup section */
+ 840ECB1E238BAA0D00C7D88A /* Cells */ = {
+ isa = PBXGroup;
+ children = (
+ 840ECB1F238BAA2300C7D88A /* InfoCell.swift */,
+ );
+ path = Cells;
+ sourceTree = "<group>";
+ };
+ 842DDB6C2372A80E00F6BB94 /* Views */ = {
+ isa = PBXGroup;
+ children = (
+ 842DDB6D2372A82000F6BB94 /* OverlayView.swift */,
+ );
+ path = Views;
+ sourceTree = "<group>";
+ };
+ 846499C0235DAAE7009CBBC7 /* ModelDataHandler */ = {
+ isa = PBXGroup;
+ children = (
+ 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */,
+ );
+ path = ModelDataHandler;
+ sourceTree = "<group>";
+ };
+ 8474FEC62341D2BE00377D34 /* ViewControllers */ = {
+ isa = PBXGroup;
+ children = (
+ 84B67CF02326338300A11A08 /* ViewController.swift */,
+ );
+ path = ViewControllers;
+ sourceTree = "<group>";
+ };
+ 8474FEC72341D35800377D34 /* Camera Feed */ = {
+ isa = PBXGroup;
+ children = (
+ 8474FEC82341D36E00377D34 /* PreviewView.swift */,
+ 8474FECA2341D39800377D34 /* CameraFeedManager.swift */,
+ );
+ path = "Camera Feed";
+ sourceTree = "<group>";
+ };
+ 84884290236FF07F0043FC4C /* RunScripts */ = {
+ isa = PBXGroup;
+ children = (
+ 84884291236FF0A30043FC4C /* download_models.sh */,
+ );
+ path = RunScripts;
+ sourceTree = "<group>";
+ };
+ 848842A22370180C0043FC4C /* Model */ = {
+ isa = PBXGroup;
+ children = (
+ 84F232D4254C831E0011862E /* model_opt.tflite */,
+ );
+ path = Model;
+ sourceTree = "<group>";
+ };
+ 84952CB3236186A20052C104 /* Extensions */ = {
+ isa = PBXGroup;
+ children = (
+ 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */,
+ 84952CB82361874A0052C104 /* TFLiteExtension.swift */,
+ 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */,
+ );
+ path = Extensions;
+ sourceTree = "<group>";
+ };
+ 84B67CE22326338300A11A08 = {
+ isa = PBXGroup;
+ children = (
+ 8402440023D9834600704ABD /* README.md */,
+ 84884290236FF07F0043FC4C /* RunScripts */,
+ 84B67CED2326338300A11A08 /* Midas */,
+ 84B67CEC2326338300A11A08 /* Products */,
+ B4DFDCC28443B641BC36251D /* Pods */,
+ A3DA804B8D3F6891E3A02852 /* Frameworks */,
+ );
+ sourceTree = "<group>";
+ };
+ 84B67CEC2326338300A11A08 /* Products */ = {
+ isa = PBXGroup;
+ children = (
+ 84B67CEB2326338300A11A08 /* Midas.app */,
+ );
+ name = Products;
+ sourceTree = "<group>";
+ };
+ 84B67CED2326338300A11A08 /* Midas */ = {
+ isa = PBXGroup;
+ children = (
+ 840ECB1E238BAA0D00C7D88A /* Cells */,
+ 842DDB6C2372A80E00F6BB94 /* Views */,
+ 848842A22370180C0043FC4C /* Model */,
+ 84952CB3236186A20052C104 /* Extensions */,
+ 846499C0235DAAE7009CBBC7 /* ModelDataHandler */,
+ 8474FEC72341D35800377D34 /* Camera Feed */,
+ 8474FEC62341D2BE00377D34 /* ViewControllers */,
+ 84B67D002326339000A11A08 /* Storyboards */,
+ 84B67CEE2326338300A11A08 /* AppDelegate.swift */,
+ 846BAF7523E7FE13006FC136 /* Constants.swift */,
+ 84B67CF52326338400A11A08 /* Assets.xcassets */,
+ 84FCF5912387BD7900663812 /* tfl_logo.png */,
+ 84B67CFA2326338400A11A08 /* Info.plist */,
+ );
+ path = Midas;
+ sourceTree = "<group>";
+ };
+ 84B67D002326339000A11A08 /* Storyboards */ = {
+ isa = PBXGroup;
+ children = (
+ 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */,
+ 840EDD002341DE380017ED42 /* Main.storyboard */,
+ );
+ path = Storyboards;
+ sourceTree = "<group>";
+ };
+ A3DA804B8D3F6891E3A02852 /* Frameworks */ = {
+ isa = PBXGroup;
+ children = (
+ A1CE41C09920CCEC31985547 /* Pods_Midas.framework */,
+ );
+ name = Frameworks;
+ sourceTree = "<group>";
+ };
+ B4DFDCC28443B641BC36251D /* Pods */ = {
+ isa = PBXGroup;
+ children = (
+ FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */,
+ D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */,
+ );
+ path = Pods;
+ sourceTree = "<group>";
+ };
+/* End PBXGroup section */
+
+/* Begin PBXNativeTarget section */
+ 84B67CEA2326338300A11A08 /* Midas */ = {
+ isa = PBXNativeTarget;
+ buildConfigurationList = 84B67CFD2326338400A11A08 /* Build configuration list for PBXNativeTarget "Midas" */;
+ buildPhases = (
+ 14067F3CF309C9DB723C9F6F /* [CP] Check Pods Manifest.lock */,
+ 84884298237010B90043FC4C /* Download TensorFlow Lite model */,
+ 84B67CE72326338300A11A08 /* Sources */,
+ 84B67CE82326338300A11A08 /* Frameworks */,
+ 84B67CE92326338300A11A08 /* Resources */,
+ );
+ buildRules = (
+ );
+ dependencies = (
+ );
+ name = Midas;
+ productName = Midas;
+ productReference = 84B67CEB2326338300A11A08 /* Midas.app */;
+ productType = "com.apple.product-type.application";
+ };
+/* End PBXNativeTarget section */
+
+/* Begin PBXProject section */
+ 84B67CE32326338300A11A08 /* Project object */ = {
+ isa = PBXProject;
+ attributes = {
+ LastSwiftUpdateCheck = 1030;
+ LastUpgradeCheck = 1030;
+ ORGANIZATIONNAME = tensorflow;
+ TargetAttributes = {
+ 84B67CEA2326338300A11A08 = {
+ CreatedOnToolsVersion = 10.3;
+ };
+ };
+ };
+ buildConfigurationList = 84B67CE62326338300A11A08 /* Build configuration list for PBXProject "Midas" */;
+ compatibilityVersion = "Xcode 9.3";
+ developmentRegion = en;
+ hasScannedForEncodings = 0;
+ knownRegions = (
+ en,
+ Base,
+ );
+ mainGroup = 84B67CE22326338300A11A08;
+ productRefGroup = 84B67CEC2326338300A11A08 /* Products */;
+ projectDirPath = "";
+ projectRoot = "";
+ targets = (
+ 84B67CEA2326338300A11A08 /* Midas */,
+ );
+ };
+/* End PBXProject section */
+
+/* Begin PBXResourcesBuildPhase section */
+ 84B67CE92326338300A11A08 /* Resources */ = {
+ isa = PBXResourcesBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ 8402440123D9834600704ABD /* README.md in Resources */,
+ 84F232D5254C831E0011862E /* model_opt.tflite in Resources */,
+ 840EDD022341DE380017ED42 /* Main.storyboard in Resources */,
+ 840EDCFD2341DDD30017ED42 /* Launch Screen.storyboard in Resources */,
+ 84FCF5922387BD7900663812 /* tfl_logo.png in Resources */,
+ 84B67CF62326338400A11A08 /* Assets.xcassets in Resources */,
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ };
+/* End PBXResourcesBuildPhase section */
+
+/* Begin PBXShellScriptBuildPhase section */
+ 14067F3CF309C9DB723C9F6F /* [CP] Check Pods Manifest.lock */ = {
+ isa = PBXShellScriptBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputFileListPaths = (
+ );
+ inputPaths = (
+ "${PODS_PODFILE_DIR_PATH}/Podfile.lock",
+ "${PODS_ROOT}/Manifest.lock",
+ );
+ name = "[CP] Check Pods Manifest.lock";
+ outputFileListPaths = (
+ );
+ outputPaths = (
+ "$(DERIVED_FILE_DIR)/Pods-Midas-checkManifestLockResult.txt",
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/sh;
+ shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n";
+ showEnvVarsInLog = 0;
+ };
+ 84884298237010B90043FC4C /* Download TensorFlow Lite model */ = {
+ isa = PBXShellScriptBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputFileListPaths = (
+ );
+ inputPaths = (
+ );
+ name = "Download TensorFlow Lite model";
+ outputFileListPaths = (
+ );
+ outputPaths = (
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/bash;
+ shellScript = "\"$SRCROOT/RunScripts/download_models.sh\"\n";
+ };
+/* End PBXShellScriptBuildPhase section */
+
+/* Begin PBXSourcesBuildPhase section */
+ 84B67CE72326338300A11A08 /* Sources */ = {
+ isa = PBXSourcesBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ 842DDB6E2372A82000F6BB94 /* OverlayView.swift in Sources */,
+ 846BAF7623E7FE13006FC136 /* Constants.swift in Sources */,
+ 84952CB92361874A0052C104 /* TFLiteExtension.swift in Sources */,
+ 84D6576D2387BB7E0048171E /* CGSizeExtension.swift in Sources */,
+ 84B67CF12326338300A11A08 /* ViewController.swift in Sources */,
+ 84B67CEF2326338300A11A08 /* AppDelegate.swift in Sources */,
+ 8474FECB2341D39800377D34 /* CameraFeedManager.swift in Sources */,
+ 846499C2235DAB0D009CBBC7 /* ModelDataHandler.swift in Sources */,
+ 8474FEC92341D36E00377D34 /* PreviewView.swift in Sources */,
+ 84952CB5236186BE0052C104 /* CVPixelBufferExtension.swift in Sources */,
+ 840ECB20238BAA2300C7D88A /* InfoCell.swift in Sources */,
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ };
+/* End PBXSourcesBuildPhase section */
+
+/* Begin PBXVariantGroup section */
+ 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */ = {
+ isa = PBXVariantGroup;
+ children = (
+ 840EDCFC2341DDD30017ED42 /* Base */,
+ );
+ name = "Launch Screen.storyboard";
+ sourceTree = "<group>";
+ };
+ 840EDD002341DE380017ED42 /* Main.storyboard */ = {
+ isa = PBXVariantGroup;
+ children = (
+ 840EDD012341DE380017ED42 /* Base */,
+ );
+ name = Main.storyboard;
+ sourceTree = "<group>";
+ };
+/* End PBXVariantGroup section */
+
+/* Begin XCBuildConfiguration section */
+ 84B67CFB2326338400A11A08 /* Debug */ = {
+ isa = XCBuildConfiguration;
+ buildSettings = {
+ ALWAYS_SEARCH_USER_PATHS = NO;
+ CLANG_ANALYZER_NONNULL = YES;
+ CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
+ CLANG_CXX_LANGUAGE_STANDARD = "gnu++14";
+ CLANG_CXX_LIBRARY = "libc++";
+ CLANG_ENABLE_MODULES = YES;
+ CLANG_ENABLE_OBJC_ARC = YES;
+ CLANG_ENABLE_OBJC_WEAK = YES;
+ CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
+ CLANG_WARN_BOOL_CONVERSION = YES;
+ CLANG_WARN_COMMA = YES;
+ CLANG_WARN_CONSTANT_CONVERSION = YES;
+ CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
+ CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
+ CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
+ CLANG_WARN_EMPTY_BODY = YES;
+ CLANG_WARN_ENUM_CONVERSION = YES;
+ CLANG_WARN_INFINITE_RECURSION = YES;
+ CLANG_WARN_INT_CONVERSION = YES;
+ CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
+ CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
+ CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
+ CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
+ CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
+ CLANG_WARN_STRICT_PROTOTYPES = YES;
+ CLANG_WARN_SUSPICIOUS_MOVE = YES;
+ CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
+ CLANG_WARN_UNREACHABLE_CODE = YES;
+ CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
+ CODE_SIGN_IDENTITY = "iPhone Developer";
+ COPY_PHASE_STRIP = NO;
+ DEBUG_INFORMATION_FORMAT = dwarf;
+ ENABLE_STRICT_OBJC_MSGSEND = YES;
+ ENABLE_TESTABILITY = YES;
+ GCC_C_LANGUAGE_STANDARD = gnu11;
+ GCC_DYNAMIC_NO_PIC = NO;
+ GCC_NO_COMMON_BLOCKS = YES;
+ GCC_OPTIMIZATION_LEVEL = 0;
+ GCC_PREPROCESSOR_DEFINITIONS = (
+ "DEBUG=1",
+ "$(inherited)",
+ );
+ GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
+ GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
+ GCC_WARN_UNDECLARED_SELECTOR = YES;
+ GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
+ GCC_WARN_UNUSED_FUNCTION = YES;
+ GCC_WARN_UNUSED_VARIABLE = YES;
+ IPHONEOS_DEPLOYMENT_TARGET = 12.4;
+ MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
+ MTL_FAST_MATH = YES;
+ ONLY_ACTIVE_ARCH = YES;
+ SDKROOT = iphoneos;
+ SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG;
+ SWIFT_OPTIMIZATION_LEVEL = "-Onone";
+ };
+ name = Debug;
+ };
+ 84B67CFC2326338400A11A08 /* Release */ = {
+ isa = XCBuildConfiguration;
+ buildSettings = {
+ ALWAYS_SEARCH_USER_PATHS = NO;
+ CLANG_ANALYZER_NONNULL = YES;
+ CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
+ CLANG_CXX_LANGUAGE_STANDARD = "gnu++14";
+ CLANG_CXX_LIBRARY = "libc++";
+ CLANG_ENABLE_MODULES = YES;
+ CLANG_ENABLE_OBJC_ARC = YES;
+ CLANG_ENABLE_OBJC_WEAK = YES;
+ CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
+ CLANG_WARN_BOOL_CONVERSION = YES;
+ CLANG_WARN_COMMA = YES;
+ CLANG_WARN_CONSTANT_CONVERSION = YES;
+ CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
+ CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
+ CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
+ CLANG_WARN_EMPTY_BODY = YES;
+ CLANG_WARN_ENUM_CONVERSION = YES;
+ CLANG_WARN_INFINITE_RECURSION = YES;
+ CLANG_WARN_INT_CONVERSION = YES;
+ CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
+ CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
+ CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
+ CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
+ CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
+ CLANG_WARN_STRICT_PROTOTYPES = YES;
+ CLANG_WARN_SUSPICIOUS_MOVE = YES;
+ CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
+ CLANG_WARN_UNREACHABLE_CODE = YES;
+ CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
+ CODE_SIGN_IDENTITY = "iPhone Developer";
+ COPY_PHASE_STRIP = NO;
+ DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
+ ENABLE_NS_ASSERTIONS = NO;
+ ENABLE_STRICT_OBJC_MSGSEND = YES;
+ GCC_C_LANGUAGE_STANDARD = gnu11;
+ GCC_NO_COMMON_BLOCKS = YES;
+ GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
+ GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
+ GCC_WARN_UNDECLARED_SELECTOR = YES;
+ GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
+ GCC_WARN_UNUSED_FUNCTION = YES;
+ GCC_WARN_UNUSED_VARIABLE = YES;
+ IPHONEOS_DEPLOYMENT_TARGET = 12.4;
+ MTL_ENABLE_DEBUG_INFO = NO;
+ MTL_FAST_MATH = YES;
+ SDKROOT = iphoneos;
+ SWIFT_COMPILATION_MODE = wholemodule;
+ SWIFT_OPTIMIZATION_LEVEL = "-O";
+ VALIDATE_PRODUCT = YES;
+ };
+ name = Release;
+ };
+ 84B67CFE2326338400A11A08 /* Debug */ = {
+ isa = XCBuildConfiguration;
+ baseConfigurationReference = FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */;
+ buildSettings = {
+ ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
+ CODE_SIGN_IDENTITY = "iPhone Developer";
+ CODE_SIGN_STYLE = Automatic;
+ DEVELOPMENT_TEAM = BV6M48J3RX;
+ INFOPLIST_FILE = Midas/Info.plist;
+ LD_RUNPATH_SEARCH_PATHS = (
+ "$(inherited)",
+ "@executable_path/Frameworks",
+ );
+ PRODUCT_BUNDLE_IDENTIFIER = "org.midas.midas-tflite-npu";
+ PRODUCT_NAME = Midas;
+ PROVISIONING_PROFILE_SPECIFIER = "";
+ SWIFT_VERSION = 5.0;
+ TARGETED_DEVICE_FAMILY = "1,2";
+ };
+ name = Debug;
+ };
+ 84B67CFF2326338400A11A08 /* Release */ = {
+ isa = XCBuildConfiguration;
+ baseConfigurationReference = D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */;
+ buildSettings = {
+ ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
+ CODE_SIGN_IDENTITY = "iPhone Developer";
+ CODE_SIGN_STYLE = Automatic;
+ DEVELOPMENT_TEAM = BV6M48J3RX;
+ INFOPLIST_FILE = Midas/Info.plist;
+ LD_RUNPATH_SEARCH_PATHS = (
+ "$(inherited)",
+ "@executable_path/Frameworks",
+ );
+ PRODUCT_BUNDLE_IDENTIFIER = "org.midas.midas-tflite-npu";
+ PRODUCT_NAME = Midas;
+ PROVISIONING_PROFILE_SPECIFIER = "";
+ SWIFT_VERSION = 5.0;
+ TARGETED_DEVICE_FAMILY = "1,2";
+ };
+ name = Release;
+ };
+/* End XCBuildConfiguration section */
+
+/* Begin XCConfigurationList section */
+ 84B67CE62326338300A11A08 /* Build configuration list for PBXProject "Midas" */ = {
+ isa = XCConfigurationList;
+ buildConfigurations = (
+ 84B67CFB2326338400A11A08 /* Debug */,
+ 84B67CFC2326338400A11A08 /* Release */,
+ );
+ defaultConfigurationIsVisible = 0;
+ defaultConfigurationName = Release;
+ };
+ 84B67CFD2326338400A11A08 /* Build configuration list for PBXNativeTarget "Midas" */ = {
+ isa = XCConfigurationList;
+ buildConfigurations = (
+ 84B67CFE2326338400A11A08 /* Debug */,
+ 84B67CFF2326338400A11A08 /* Release */,
+ );
+ defaultConfigurationIsVisible = 0;
+ defaultConfigurationName = Release;
+ };
+/* End XCConfigurationList section */
+ };
+ rootObject = 84B67CE32326338300A11A08 /* Project object */;
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata
new file mode 100644
index 0000000000000000000000000000000000000000..919434a6254f0e9651f402737811be6634a03e9c
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata
@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<Workspace
+   version = "1.0">
+   <FileRef
+      location = "self:Midas.xcodeproj">
+   </FileRef>
+</Workspace>
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist
new file mode 100644
index 0000000000000000000000000000000000000000..18d981003d68d0546c4804ac2ff47dd97c6e7921
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+<dict>
+  <key>IDEDidComputeMac32BitWarning</key>
+  <true/>
+</dict>
+</plist>
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate
new file mode 100644
index 0000000000000000000000000000000000000000..1d20756ee57b79e9f9f886453bdb7997ca2ee2d4
Binary files /dev/null and b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate differ
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist
new file mode 100644
index 0000000000000000000000000000000000000000..6093f6160eedfdfc20e96396247a7dbc9247cc55
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+<dict>
+  <key>SchemeUserState</key>
+  <dict>
+    <key>PoseNet.xcscheme_^#shared#^_</key>
+    <dict>
+      <key>orderHint</key>
+      <integer>3</integer>
+    </dict>
+  </dict>
+</dict>
+</plist>
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift
new file mode 100644
index 0000000000000000000000000000000000000000..233f0291ab4f379067543bdad3cc198a2dc3ab0f
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift
@@ -0,0 +1,41 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import UIKit
+
+@UIApplicationMain
+class AppDelegate: UIResponder, UIApplicationDelegate {
+
+ var window: UIWindow?
+
+ func application(_ application: UIApplication, didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]?) -> Bool {
+ return true
+ }
+
+ func applicationWillResignActive(_ application: UIApplication) {
+ }
+
+ func applicationDidEnterBackground(_ application: UIApplication) {
+ }
+
+ func applicationWillEnterForeground(_ application: UIApplication) {
+ }
+
+ func applicationDidBecomeActive(_ application: UIApplication) {
+ }
+
+ func applicationWillTerminate(_ application: UIApplication) {
+ }
+}
+
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json
new file mode 100644
index 0000000000000000000000000000000000000000..65b74d7ef11fa59fafa829e681ac90906f3ac8b2
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json
@@ -0,0 +1 @@
+{"images":[{"size":"60x60","expected-size":"180","filename":"180.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"40x40","expected-size":"80","filename":"80.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"40x40","expected-size":"120","filename":"120.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"60x60","expected-size":"120","filename":"120.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"57x57","expected-size":"57","filename":"57.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"1x"},{"size":"29x29","expected-size":"58","filename":"58.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"29x29","expected-size":"29","filename":"29.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"1x"},{"size":"29x29","expected-size":"87","filename":"87.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"57x57","expected-size":"114","filename":"114.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"20x20","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"20x20","expected-size":"60","filename":"60.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"1024x1024","filename":"1024.png","expected-size":"1024","idiom":"ios-marketing","folder":"Assets.xcassets/AppIcon.appiconset/","scale":"1x"},{"size":"40x40","expected-size":"80","filename":"80.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"72x72","expected-size":"72","filename":"72.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"76x76","expected-size":"152","filename":"152.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"50x50","expected-size":"100","filename":"100.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"29x29","expected-size":"58","filename":"58.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"76x76","expected-size":"76","filename":"76.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"29x29","expected-size":"29","filename":"29.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"50x50","expected-size":"50","filename":"50.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"72x72","expected-size":"144","filename":"144.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"40x40","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"83.5x83.5","expected-size":"167","filename":"167.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"20x20","expected-size":"20","filename":"20.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"20x20","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"}]}
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json
new file mode 100644
index 0000000000000000000000000000000000000000..da4a164c918651cdd1e11dca5cc62c333f097601
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json
@@ -0,0 +1,6 @@
+{
+ "info" : {
+ "version" : 1,
+ "author" : "xcode"
+ }
+}
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift
new file mode 100644
index 0000000000000000000000000000000000000000..48d65b88ee220e722fbad2570c8e879a431cd0f5
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift
@@ -0,0 +1,316 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import AVFoundation
+import UIKit
+import os
+
+// MARK: - CameraFeedManagerDelegate Declaration
+@objc protocol CameraFeedManagerDelegate: class {
+ /// This method delivers the pixel buffer of the current frame seen by the device's camera.
+ @objc optional func cameraFeedManager(
+ _ manager: CameraFeedManager, didOutput pixelBuffer: CVPixelBuffer
+ )
+
+ /// This method indicates that a session runtime error has occurred.
+ func cameraFeedManagerDidEncounterSessionRunTimeError(_ manager: CameraFeedManager)
+
+ /// This method indicates that the session was interrupted.
+ func cameraFeedManager(
+ _ manager: CameraFeedManager, sessionWasInterrupted canResumeManually: Bool
+ )
+
+ /// This method indicates that the session interruption has ended.
+ func cameraFeedManagerDidEndSessionInterruption(_ manager: CameraFeedManager)
+
+ /// This method indicates that there was an error in the video configuration.
+ func presentVideoConfigurationErrorAlert(_ manager: CameraFeedManager)
+
+ /// This method indicates that the camera permissions have been denied.
+ func presentCameraPermissionsDeniedAlert(_ manager: CameraFeedManager)
+}
+
+/// This enum holds the state of the camera initialization.
+// MARK: - Camera Initialization State Enum
+enum CameraConfiguration {
+ case success
+ case failed
+ case permissionDenied
+}
+
+/// This class manages all camera related functionalities.
+// MARK: - Camera Related Functionalies Manager
+class CameraFeedManager: NSObject {
+ // MARK: Camera Related Instance Variables
+ private let session: AVCaptureSession = AVCaptureSession()
+
+ private let previewView: PreviewView
+ private let sessionQueue = DispatchQueue(label: "sessionQueue")
+ private var cameraConfiguration: CameraConfiguration = .failed
+ private lazy var videoDataOutput = AVCaptureVideoDataOutput()
+ private var isSessionRunning = false
+
+ // MARK: CameraFeedManagerDelegate
+ weak var delegate: CameraFeedManagerDelegate?
+
+ // MARK: Initializer
+ init(previewView: PreviewView) {
+ self.previewView = previewView
+ super.init()
+
+ // Initializes the session
+ session.sessionPreset = .high
+ self.previewView.session = session
+ self.previewView.previewLayer.connection?.videoOrientation = .portrait
+ self.previewView.previewLayer.videoGravity = .resizeAspectFill
+ self.attemptToConfigureSession()
+ }
+
+ // MARK: Session Start and End methods
+
+ /// This method starts an AVCaptureSession based on whether the camera configuration was successful.
+ func checkCameraConfigurationAndStartSession() {
+ sessionQueue.async {
+ switch self.cameraConfiguration {
+ case .success:
+ self.addObservers()
+ self.startSession()
+ case .failed:
+ DispatchQueue.main.async {
+ self.delegate?.presentVideoConfigurationErrorAlert(self)
+ }
+ case .permissionDenied:
+ DispatchQueue.main.async {
+ self.delegate?.presentCameraPermissionsDeniedAlert(self)
+ }
+ }
+ }
+ }
+
+ /// This method stops a running AVCaptureSession.
+ func stopSession() {
+ self.removeObservers()
+ sessionQueue.async {
+ if self.session.isRunning {
+ self.session.stopRunning()
+ self.isSessionRunning = self.session.isRunning
+ }
+ }
+
+ }
+
+ /// This method resumes an interrupted AVCaptureSession.
+ func resumeInterruptedSession(withCompletion completion: @escaping (Bool) -> Void) {
+ sessionQueue.async {
+ self.startSession()
+
+ DispatchQueue.main.async {
+ completion(self.isSessionRunning)
+ }
+ }
+ }
+
+ /// This method starts the AVCaptureSession
+ private func startSession() {
+ self.session.startRunning()
+ self.isSessionRunning = self.session.isRunning
+ }
+
+ // MARK: Session Configuration Methods.
+ /// This method requests camera permissions, handles the configuration of the session, and stores the result of the configuration.
+ private func attemptToConfigureSession() {
+ switch AVCaptureDevice.authorizationStatus(for: .video) {
+ case .authorized:
+ self.cameraConfiguration = .success
+ case .notDetermined:
+ self.sessionQueue.suspend()
+ self.requestCameraAccess(completion: { granted in
+ self.sessionQueue.resume()
+ })
+ case .denied:
+ self.cameraConfiguration = .permissionDenied
+ default:
+ break
+ }
+
+ self.sessionQueue.async {
+ self.configureSession()
+ }
+ }
+
+ /// This method requests camera permissions.
+ private func requestCameraAccess(completion: @escaping (Bool) -> Void) {
+ AVCaptureDevice.requestAccess(for: .video) { (granted) in
+ if !granted {
+ self.cameraConfiguration = .permissionDenied
+ } else {
+ self.cameraConfiguration = .success
+ }
+ completion(granted)
+ }
+ }
+
+ /// This method handles all the steps to configure an AVCaptureSession.
+ private func configureSession() {
+ guard cameraConfiguration == .success else {
+ return
+ }
+ session.beginConfiguration()
+
+ // Tries to add an AVCaptureDeviceInput.
+ guard addVideoDeviceInput() == true else {
+ self.session.commitConfiguration()
+ self.cameraConfiguration = .failed
+ return
+ }
+
+ // Tries to add an AVCaptureVideoDataOutput.
+ guard addVideoDataOutput() else {
+ self.session.commitConfiguration()
+ self.cameraConfiguration = .failed
+ return
+ }
+
+ session.commitConfiguration()
+ self.cameraConfiguration = .success
+ }
+
+ /// This method tries to add an AVCaptureDeviceInput to the current AVCaptureSession.
+ private func addVideoDeviceInput() -> Bool {
+ /// Tries to get the default back camera.
+ guard
+ let camera = AVCaptureDevice.default(.builtInWideAngleCamera, for: .video, position: .back)
+ else {
+ fatalError("Cannot find camera")
+ }
+
+ do {
+ let videoDeviceInput = try AVCaptureDeviceInput(device: camera)
+ if session.canAddInput(videoDeviceInput) {
+ session.addInput(videoDeviceInput)
+ return true
+ } else {
+ return false
+ }
+ } catch {
+ fatalError("Cannot create video device input")
+ }
+ }
+
+ /// This method tries to add an AVCaptureVideoDataOutput to the current AVCaptureSession.
+ private func addVideoDataOutput() -> Bool {
+ let sampleBufferQueue = DispatchQueue(label: "sampleBufferQueue")
+ videoDataOutput.setSampleBufferDelegate(self, queue: sampleBufferQueue)
+ videoDataOutput.alwaysDiscardsLateVideoFrames = true
+ videoDataOutput.videoSettings = [
+ String(kCVPixelBufferPixelFormatTypeKey): kCMPixelFormat_32BGRA
+ ]
+
+ if session.canAddOutput(videoDataOutput) {
+ session.addOutput(videoDataOutput)
+ videoDataOutput.connection(with: .video)?.videoOrientation = .portrait
+ return true
+ }
+ return false
+ }
+
+ // MARK: Notification Observer Handling
+ private func addObservers() {
+ NotificationCenter.default.addObserver(
+ self, selector: #selector(CameraFeedManager.sessionRuntimeErrorOccured(notification:)),
+ name: NSNotification.Name.AVCaptureSessionRuntimeError, object: session)
+ NotificationCenter.default.addObserver(
+ self, selector: #selector(CameraFeedManager.sessionWasInterrupted(notification:)),
+ name: NSNotification.Name.AVCaptureSessionWasInterrupted, object: session)
+ NotificationCenter.default.addObserver(
+ self, selector: #selector(CameraFeedManager.sessionInterruptionEnded),
+ name: NSNotification.Name.AVCaptureSessionInterruptionEnded, object: session)
+ }
+
+ private func removeObservers() {
+ NotificationCenter.default.removeObserver(
+ self, name: NSNotification.Name.AVCaptureSessionRuntimeError, object: session)
+ NotificationCenter.default.removeObserver(
+ self, name: NSNotification.Name.AVCaptureSessionWasInterrupted, object: session)
+ NotificationCenter.default.removeObserver(
+ self, name: NSNotification.Name.AVCaptureSessionInterruptionEnded, object: session)
+ }
+
+ // MARK: Notification Observers
+ @objc func sessionWasInterrupted(notification: Notification) {
+ if let userInfoValue = notification.userInfo?[AVCaptureSessionInterruptionReasonKey]
+ as AnyObject?,
+ let reasonIntegerValue = userInfoValue.integerValue,
+ let reason = AVCaptureSession.InterruptionReason(rawValue: reasonIntegerValue)
+ {
+ os_log("Capture session was interrupted with reason: %s", type: .error, reason.rawValue)
+
+ var canResumeManually = false
+ if reason == .videoDeviceInUseByAnotherClient {
+ canResumeManually = true
+ } else if reason == .videoDeviceNotAvailableWithMultipleForegroundApps {
+ canResumeManually = false
+ }
+
+ delegate?.cameraFeedManager(self, sessionWasInterrupted: canResumeManually)
+
+ }
+ }
+
+ @objc func sessionInterruptionEnded(notification: Notification) {
+ delegate?.cameraFeedManagerDidEndSessionInterruption(self)
+ }
+
+ @objc func sessionRuntimeErrorOccured(notification: Notification) {
+ guard let error = notification.userInfo?[AVCaptureSessionErrorKey] as? AVError else {
+ return
+ }
+
+ os_log("Capture session runtime error: %s", type: .error, error.localizedDescription)
+
+ if error.code == .mediaServicesWereReset {
+ sessionQueue.async {
+ if self.isSessionRunning {
+ self.startSession()
+ } else {
+ DispatchQueue.main.async {
+ self.delegate?.cameraFeedManagerDidEncounterSessionRunTimeError(self)
+ }
+ }
+ }
+ } else {
+ delegate?.cameraFeedManagerDidEncounterSessionRunTimeError(self)
+ }
+ }
+}
+
+/// AVCaptureVideoDataOutputSampleBufferDelegate
+extension CameraFeedManager: AVCaptureVideoDataOutputSampleBufferDelegate {
+ /// This method delivers the CVPixelBuffer of the frame currently seen by the camera to the delegate.
+ func captureOutput(
+ _ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer,
+ from connection: AVCaptureConnection
+ ) {
+
+ // Converts the CMSampleBuffer to a CVPixelBuffer.
+ let pixelBuffer: CVPixelBuffer? = CMSampleBufferGetImageBuffer(sampleBuffer)
+
+ guard let imagePixelBuffer = pixelBuffer else {
+ return
+ }
+
+ // Delegates the pixel buffer to the ViewController.
+ delegate?.cameraFeedManager?(self, didOutput: imagePixelBuffer)
+ }
+}
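
For reference, a minimal usage sketch of the CameraFeedManager API above. The host CameraViewController, its outlet, and the empty delegate bodies are illustrative assumptions; only the CameraFeedManager and CameraFeedManagerDelegate calls come from the file itself.

    import AVFoundation
    import UIKit

    // Hypothetical host view controller: wires the preview view to the camera
    // feed and receives per-frame pixel buffers through the delegate.
    final class CameraViewController: UIViewController, CameraFeedManagerDelegate {
      @IBOutlet private weak var previewView: PreviewView!
      private lazy var cameraFeedManager = CameraFeedManager(previewView: previewView)

      override func viewDidLoad() {
        super.viewDidLoad()
        cameraFeedManager.delegate = self
      }

      override func viewWillAppear(_ animated: Bool) {
        super.viewWillAppear(animated)
        // Starts the session only if configuration and permissions succeeded.
        cameraFeedManager.checkCameraConfigurationAndStartSession()
      }

      override func viewWillDisappear(_ animated: Bool) {
        super.viewWillDisappear(animated)
        cameraFeedManager.stopSession()
      }

      // MARK: CameraFeedManagerDelegate
      func cameraFeedManager(_ manager: CameraFeedManager, didOutput pixelBuffer: CVPixelBuffer) {
        // Hand the frame to the model here.
      }
      func cameraFeedManagerDidEncounterSessionRunTimeError(_ manager: CameraFeedManager) {}
      func cameraFeedManager(_ manager: CameraFeedManager, sessionWasInterrupted canResumeManually: Bool) {}
      func cameraFeedManagerDidEndSessionInterruption(_ manager: CameraFeedManager) {}
      func presentVideoConfigurationErrorAlert(_ manager: CameraFeedManager) {}
      func presentCameraPermissionsDeniedAlert(_ manager: CameraFeedManager) {}
    }
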
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift
new file mode 100644
index 0000000000000000000000000000000000000000..308c7ec54308af5c152ff6038670b26501a8e82c
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift
@@ -0,0 +1,39 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import UIKit
+import AVFoundation
+
+ /// The camera frame is displayed on this view.
+class PreviewView: UIView {
+ var previewLayer: AVCaptureVideoPreviewLayer {
+ guard let layer = layer as? AVCaptureVideoPreviewLayer else {
+ fatalError("Layer expected is of type VideoPreviewLayer")
+ }
+ return layer
+ }
+
+ var session: AVCaptureSession? {
+ get {
+ return previewLayer.session
+ }
+ set {
+ previewLayer.session = newValue
+ }
+ }
+
+ override class var layerClass: AnyClass {
+ return AVCaptureVideoPreviewLayer.self
+ }
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift
new file mode 100644
index 0000000000000000000000000000000000000000..c6be64af5678541ec09fc367b03c80155876f0ba
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift
@@ -0,0 +1,21 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import UIKit
+
+/// Table cell for inference result in bottom view.
+class InfoCell: UITableViewCell {
+ @IBOutlet weak var fieldNameLabel: UILabel!
+ @IBOutlet weak var infoLabel: UILabel!
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift
new file mode 100644
index 0000000000000000000000000000000000000000..b0789ee58a1ea373d441f05333d8ce8914adadb7
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift
@@ -0,0 +1,25 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+enum Constants {
+ // MARK: - Constants related to the image processing
+ static let bgraPixel = (channels: 4, alphaComponent: 3, lastBgrComponent: 2)
+ static let rgbPixelChannels = 3
+ static let maxRGBValue: Float32 = 255.0
+
+ // MARK: - Constants related to the model interpreter
+ static let defaultThreadCount = 2
+ static let defaultDelegate: Delegates = .CPU
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift
new file mode 100644
index 0000000000000000000000000000000000000000..031550ea0081963d18b5b83712854babaf7c0a34
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift
@@ -0,0 +1,45 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+import Accelerate
+import Foundation
+
+extension CGSize {
+ /// Returns a `CGAffineTransform` to resize `self` to fit in the destination size, keeping the
+ /// aspect ratio of `self`. The image is resized to be inscribed in the destination size and
+ /// centered in the destination.
+ ///
+ /// - Parameter toFitIn: destination size to be filled.
+ /// - Returns: `CGAffineTransform` to transform `self` image to `dest` image.
+ func transformKeepAspect(toFitIn dest: CGSize) -> CGAffineTransform {
+ let sourceRatio = self.height / self.width
+ let destRatio = dest.height / dest.width
+
+ // Calculates ratio `self` to `dest`.
+ var ratio: CGFloat
+ var x: CGFloat = 0
+ var y: CGFloat = 0
+ if sourceRatio > destRatio {
+ // Source size is taller than destination. Resized to fit in destination height, and find
+ // horizontal starting point to be centered.
+ ratio = dest.height / self.height
+ x = (dest.width - self.width * ratio) / 2
+ } else {
+ ratio = dest.width / self.width
+ y = (dest.height - self.height * ratio) / 2
+ }
+ return CGAffineTransform(a: ratio, b: 0, c: 0, d: ratio, tx: x, ty: y)
+ }
+}
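
A quick illustration of how the transform above can be used; the sizes and the rect are example values only, not taken from the project.

    import CoreGraphics

    // Map a rect from 256x256 model space into a 375x667 view, preserving the
    // aspect ratio and centering the result (all values illustrative).
    let modelSize = CGSize(width: 256, height: 256)
    let viewSize = CGSize(width: 375, height: 667)
    let transform = modelSize.transformKeepAspect(toFitIn: viewSize)

    let rectInModel = CGRect(x: 64, y: 64, width: 128, height: 128)
    let rectInView = rectInModel.applying(transform)
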
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift
new file mode 100644
index 0000000000000000000000000000000000000000..4899c76562a546c513736fbf4556629b08d2c929
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift
@@ -0,0 +1,172 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+import Accelerate
+import Foundation
+import os
+
+extension CVPixelBuffer {
+ var size: CGSize {
+ return CGSize(width: CVPixelBufferGetWidth(self), height: CVPixelBufferGetHeight(self))
+ }
+
+ /// Returns a new `CVPixelBuffer` created by cropping the given source area of `self` and resizing
+ /// it to the specified target size. The aspect ratios of the source image and the destination
+ /// image are expected to be the same.
+ ///
+ /// - Parameters:
+ /// - from: Source area of image to be cropped and resized.
+ /// - to: Size to scale the image to (i.e. the image size used while training the model).
+ /// - Returns: The cropped and resized image of itself.
+ func resize(from source: CGRect, to size: CGSize) -> CVPixelBuffer? {
+ let rect = CGRect(origin: CGPoint(x: 0, y: 0), size: self.size)
+ guard rect.contains(source) else {
+ os_log("Resizing Error: source area is out of index", type: .error)
+ return nil
+ }
+ guard rect.size.width / rect.size.height - source.size.width / source.size.height < 1e-5
+ else {
+ os_log(
+ "Resizing Error: source image ratio and destination image ratio is different",
+ type: .error)
+ return nil
+ }
+
+ let inputImageRowBytes = CVPixelBufferGetBytesPerRow(self)
+ let imageChannels = 4
+
+ CVPixelBufferLockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0))
+ defer { CVPixelBufferUnlockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0)) }
+
+ // Finds the address of the upper leftmost pixel of the source area.
+ guard
+ let inputBaseAddress = CVPixelBufferGetBaseAddress(self)?.advanced(
+ by: Int(source.minY) * inputImageRowBytes + Int(source.minX) * imageChannels)
+ else {
+ return nil
+ }
+
+ // Crops given area as vImage Buffer.
+ var croppedImage = vImage_Buffer(
+ data: inputBaseAddress, height: UInt(source.height), width: UInt(source.width),
+ rowBytes: inputImageRowBytes)
+
+ let resultRowBytes = Int(size.width) * imageChannels
+ guard let resultAddress = malloc(Int(size.height) * resultRowBytes) else {
+ return nil
+ }
+
+ // Allocates a vacant vImage buffer for resized image.
+ var resizedImage = vImage_Buffer(
+ data: resultAddress,
+ height: UInt(size.height), width: UInt(size.width),
+ rowBytes: resultRowBytes
+ )
+
+ // Performs the scale operation on cropped image and stores it in result image buffer.
+ guard vImageScale_ARGB8888(&croppedImage, &resizedImage, nil, vImage_Flags(0)) == kvImageNoError
+ else {
+ return nil
+ }
+
+ let releaseCallBack: CVPixelBufferReleaseBytesCallback = { mutablePointer, pointer in
+ if let pointer = pointer {
+ free(UnsafeMutableRawPointer(mutating: pointer))
+ }
+ }
+
+ var result: CVPixelBuffer?
+
+ // Converts the thumbnail vImage buffer to CVPixelBuffer
+ let conversionStatus = CVPixelBufferCreateWithBytes(
+ nil,
+ Int(size.width), Int(size.height),
+ CVPixelBufferGetPixelFormatType(self),
+ resultAddress,
+ resultRowBytes,
+ releaseCallBack,
+ nil,
+ nil,
+ &result
+ )
+
+ guard conversionStatus == kCVReturnSuccess else {
+ free(resultAddress)
+ return nil
+ }
+
+ return result
+ }
+
+ /// Returns the RGB `Data` representation of the given image buffer.
+ ///
+ /// - Parameters:
+ /// - isModelQuantized: Whether the model is quantized (i.e. fixed point values rather than
+ /// floating point values).
+ /// - Returns: The RGB data representation of the image buffer or `nil` if the buffer could not be
+ /// converted.
+ func rgbData(
+ isModelQuantized: Bool
+ ) -> Data? {
+ CVPixelBufferLockBaseAddress(self, .readOnly)
+ defer { CVPixelBufferUnlockBaseAddress(self, .readOnly) }
+ guard let sourceData = CVPixelBufferGetBaseAddress(self) else {
+ return nil
+ }
+
+ let width = CVPixelBufferGetWidth(self)
+ let height = CVPixelBufferGetHeight(self)
+ let sourceBytesPerRow = CVPixelBufferGetBytesPerRow(self)
+ let destinationBytesPerRow = Constants.rgbPixelChannels * width
+
+ // Assign input image to `sourceBuffer` to convert it.
+ var sourceBuffer = vImage_Buffer(
+ data: sourceData,
+ height: vImagePixelCount(height),
+ width: vImagePixelCount(width),
+ rowBytes: sourceBytesPerRow)
+
+ // Make `destinationBuffer` and `destinationData` for its data to be assigned.
+ guard let destinationData = malloc(height * destinationBytesPerRow) else {
+ os_log("Error: out of memory", type: .error)
+ return nil
+ }
+ defer { free(destinationData) }
+ var destinationBuffer = vImage_Buffer(
+ data: destinationData,
+ height: vImagePixelCount(height),
+ width: vImagePixelCount(width),
+ rowBytes: destinationBytesPerRow)
+
+ // Convert image type.
+ switch CVPixelBufferGetPixelFormatType(self) {
+ case kCVPixelFormatType_32BGRA:
+ vImageConvert_BGRA8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags))
+ case kCVPixelFormatType_32ARGB:
+ vImageConvert_ARGB8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags))
+ default:
+ os_log("The type of this image is not supported.", type: .error)
+ return nil
+ }
+
+ // Make `Data` with converted image.
+ let imageByteData = Data(
+ bytes: destinationBuffer.data, count: destinationBuffer.rowBytes * height)
+
+ if isModelQuantized { return imageByteData }
+
+ let imageBytes = [UInt8](imageByteData)
+ return Data(copyingBufferOf: imageBytes.map { Float($0) / Constants.maxRGBValue })
+ }
+}
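
A sketch of how the two helpers above might be chained on a camera frame before inference; the function name and the 0.5 scale factor are assumptions for illustration, not values from this project.

    import Accelerate
    import CoreGraphics
    import CoreVideo
    import Foundation

    // Scale the full frame down while keeping its aspect ratio, as
    // resize(from:to:) expects, then convert the result to RGB bytes.
    func rgbBytes(for pixelBuffer: CVPixelBuffer) -> Data? {
      let fullFrame = CGRect(origin: .zero, size: pixelBuffer.size)
      let targetSize = CGSize(width: pixelBuffer.size.width / 2,
                              height: pixelBuffer.size.height / 2)
      guard let scaled = pixelBuffer.resize(from: fullFrame, to: targetSize) else {
        return nil
      }
      // isModelQuantized: false returns Float32 values normalized to [0, 1].
      return scaled.rgbData(isModelQuantized: false)
    }
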
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift
new file mode 100644
index 0000000000000000000000000000000000000000..63f7ced786e2b550391c77af534d1d3c431522c6
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift
@@ -0,0 +1,75 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+import Accelerate
+import CoreImage
+import Foundation
+import TensorFlowLite
+
+// MARK: - Data
+extension Data {
+ /// Creates a new buffer by copying the buffer pointer of the given array.
+ ///
+ /// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
+ /// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
+ /// data from the resulting buffer has undefined behavior.
+ /// - Parameter array: An array with elements of type `T`.
+ init<T>(copyingBufferOf array: [T]) {
+ self = array.withUnsafeBufferPointer(Data.init)
+ }
+
+ /// Convert a Data instance to Array representation.
+ func toArray<T>(type: T.Type) -> [T] where T: AdditiveArithmetic {
+ var array = [T](repeating: T.zero, count: self.count / MemoryLayout<T>.stride)
+ _ = array.withUnsafeMutableBytes { self.copyBytes(to: $0) }
+ return array
+ }
+}
+
+// MARK: - Wrappers
+/// Struct for handling multidimension `Data` in flat `Array`.
+struct FlatArray<Element: AdditiveArithmetic> {
+ private var array: [Element]
+ var dimensions: [Int]
+
+ init(tensor: Tensor) {
+ dimensions = tensor.shape.dimensions
+ array = tensor.data.toArray(type: Element.self)
+ }
+
+ private func flatIndex(_ index: [Int]) -> Int {
+ guard index.count == dimensions.count else {
+ fatalError("Invalid index: got \(index.count) index(es) for \(dimensions.count) index(es).")
+ }
+
+ var result = 0
+ for i in 0..<dimensions.count {
+ guard dimensions[i] > index[i] else {
+ fatalError("Invalid index: \(index[i]) is bigger than \(dimensions[i])")
+ }
+ result = dimensions[i] * result + index[i]
+ }
+ return result
+ }
+
+ subscript(_ index: Int...) -> Element {
+ get {
+ return array[flatIndex(index)]
+ }
+ set(newValue) {
+ array[flatIndex(index)] = newValue
+ }
+ }
+}
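
A brief usage sketch for the FlatArray wrapper above; the [1, H, W] output shape and the function name are assumptions made for the example.

    import TensorFlowLite

    // Read a single value from a float output tensor via multidimensional indexing.
    func value(in outputTensor: Tensor, atY y: Int, x: Int) -> Float32 {
      let output = FlatArray<Float32>(tensor: outputTensor)
      precondition(output.dimensions.count == 3, "Example assumes a [1, H, W] tensor.")
      return output[0, y, x]
    }
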
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist
new file mode 100644
index 0000000000000000000000000000000000000000..4330d9b33f31010549802febc6f6f2bc9fd9b950
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist
@@ -0,0 +1,42 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+<dict>
+  <key>CFBundleDevelopmentRegion</key>
+  <string>$(DEVELOPMENT_LANGUAGE)</string>
+  <key>CFBundleExecutable</key>
+  <string>$(EXECUTABLE_NAME)</string>
+  <key>CFBundleIdentifier</key>
+  <string>$(PRODUCT_BUNDLE_IDENTIFIER)</string>
+  <key>CFBundleInfoDictionaryVersion</key>
+  <string>6.0</string>
+  <key>CFBundleName</key>
+  <string>$(PRODUCT_NAME)</string>
+  <key>CFBundlePackageType</key>
+  <string>APPL</string>
+  <key>CFBundleShortVersionString</key>
+  <string>1.0</string>
+  <key>CFBundleVersion</key>
+  <string>1</string>
+  <key>LSRequiresIPhoneOS</key>
+  <true/>
+  <key>NSCameraUsageDescription</key>
+  <string>This app will use camera to continuously estimate the depth map.</string>
+  <key>UILaunchStoryboardName</key>
+  <string>LaunchScreen</string>
+  <key>UIMainStoryboardFile</key>
+  <string>Main</string>
+  <key>UIRequiredDeviceCapabilities</key>
+  <array>
+    <string>armv7</string>
+  </array>
+  <key>UISupportedInterfaceOrientations</key>
+  <array>
+    <string>UIInterfaceOrientationPortrait</string>
+  </array>
+  <key>UISupportedInterfaceOrientations~ipad</key>
+  <array>
+    <string>UIInterfaceOrientationPortrait</string>
+  </array>
+</dict>
+</plist>
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift
new file mode 100644
index 0000000000000000000000000000000000000000..144cfe1fa3a65af5adcb572237f2bf9718e570ae
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift
@@ -0,0 +1,464 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Accelerate
+import CoreImage
+import Foundation
+import TensorFlowLite
+import UIKit
+
+/// This class handles all data preprocessing and makes calls to run inference on a given frame
+/// by invoking the `Interpreter`. It then formats the inferences obtained.
+class ModelDataHandler {
+ // MARK: - Private Properties
+
+ /// TensorFlow Lite `Interpreter` object for performing inference on a given model.
+ private var interpreter: Interpreter
+
+ /// TensorFlow lite `Tensor` of model input and output.
+ private var inputTensor: Tensor
+
+ //private var heatsTensor: Tensor
+ //private var offsetsTensor: Tensor
+ private var outputTensor: Tensor
+ // MARK: - Initialization
+
+ /// A failable initializer for `ModelDataHandler`. A new instance is created if the model is
+ /// successfully loaded from the app's main bundle. Default `threadCount` is 2.
+ init(
+ threadCount: Int = Constants.defaultThreadCount,
+ delegate: Delegates = Constants.defaultDelegate
+ ) throws {
+ // Construct the path to the model file.
+ guard
+ let modelPath = Bundle.main.path(
+ forResource: Model.file.name,
+ ofType: Model.file.extension
+ )
+ else {
+ fatalError("Failed to load the model file with name: \(Model.file.name).")
+ }
+
+ // Specify the options for the `Interpreter`.
+ var options = Interpreter.Options()
+ options.threadCount = threadCount
+
+ // Specify the delegates for the `Interpreter`.
+ var delegates: [Delegate]?
+ switch delegate {
+ case .Metal:
+ delegates = [MetalDelegate()]
+ case .CoreML:
+ if let coreMLDelegate = CoreMLDelegate() {
+ delegates = [coreMLDelegate]
+ } else {
+ delegates = nil
+ }
+ default:
+ delegates = nil
+ }
+
+ // Create the `Interpreter`.
+ interpreter = try Interpreter(modelPath: modelPath, options: options, delegates: delegates)
+
+ // Initialize input and output `Tensor`s.
+ // Allocate memory for the model's input `Tensor`s.
+ try interpreter.allocateTensors()
+
+ // Get allocated input and output `Tensor`s.
+ inputTensor = try interpreter.input(at: 0)
+ outputTensor = try interpreter.output(at: 0)
+ //heatsTensor = try interpreter.output(at: 0)
+ //offsetsTensor = try interpreter.output(at: 1)
+
+ /*
+ // Check if input and output `Tensor`s are in the expected formats.
+ guard (inputTensor.dataType == .uInt8) == Model.isQuantized else {
+ fatalError("Unexpected Model: quantization is \(!Model.isQuantized)")
+ }
+
+ guard inputTensor.shape.dimensions[0] == Model.input.batchSize,
+ inputTensor.shape.dimensions[1] == Model.input.height,
+ inputTensor.shape.dimensions[2] == Model.input.width,
+ inputTensor.shape.dimensions[3] == Model.input.channelSize
+ else {
+ fatalError("Unexpected Model: input shape")
+ }
+
+
+ guard heatsTensor.shape.dimensions[0] == Model.output.batchSize,
+ heatsTensor.shape.dimensions[1] == Model.output.height,
+ heatsTensor.shape.dimensions[2] == Model.output.width,
+ heatsTensor.shape.dimensions[3] == Model.output.keypointSize
+ else {
+ fatalError("Unexpected Model: heat tensor")
+ }
+
+ guard offsetsTensor.shape.dimensions[0] == Model.output.batchSize,
+ offsetsTensor.shape.dimensions[1] == Model.output.height,
+ offsetsTensor.shape.dimensions[2] == Model.output.width,
+ offsetsTensor.shape.dimensions[3] == Model.output.offsetSize
+ else {
+ fatalError("Unexpected Model: offset tensor")
+ }
+ */
+
+ }
+
+ /// Runs the MiDaS model on the given image, cropping the source area and scaling to the destination size.
+ ///
+ /// - Parameters:
+ /// - on: Input image to run the model on.
+ /// - from: Region of the input image to run the model on.
+ /// - to: Size of the view in which the result is rendered.
+ /// - Returns: Result of the inference and the time consumed by each step.
+ func runMidas(on pixelbuffer: CVPixelBuffer, from source: CGRect, to dest: CGSize)
+ //-> (Result, Times)?
+ //-> (FlatArray, Times)?
+ -> ([Float], Int, Int, Times)?
+ {
+ // Start times of each process.
+ let preprocessingStartTime: Date
+ let inferenceStartTime: Date
+ let postprocessingStartTime: Date
+
+ // Processing times in milliseconds.
+ let preprocessingTime: TimeInterval
+ let inferenceTime: TimeInterval
+ let postprocessingTime: TimeInterval
+
+ preprocessingStartTime = Date()
+ guard let data = preprocess(of: pixelbuffer, from: source) else {
+ os_log("Preprocessing failed", type: .error)
+ return nil
+ }
+ preprocessingTime = Date().timeIntervalSince(preprocessingStartTime) * 1000
+
+ inferenceStartTime = Date()
+ inference(from: data)
+ inferenceTime = Date().timeIntervalSince(inferenceStartTime) * 1000
+
+ postprocessingStartTime = Date()
+ //guard let result = postprocess(to: dest) else {
+ // os_log("Postprocessing failed", type: .error)
+ // return nil
+ //}
+ postprocessingTime = Date().timeIntervalSince(postprocessingStartTime) * 1000
+
+
+ let results: [Float]
+ switch outputTensor.dataType {
+ case .uInt8:
+ guard let quantization = outputTensor.quantizationParameters else {
+ print("No results returned because the quantization values for the output tensor are nil.")
+ return nil
+ }
+ let quantizedResults = [UInt8](outputTensor.data)
+ results = quantizedResults.map {
+ quantization.scale * Float(Int($0) - quantization.zeroPoint)
+ }
+ case .float32:
+ results = [Float32](unsafeData: outputTensor.data) ?? []
+ default:
+ print("Output tensor data type \(outputTensor.dataType) is unsupported for this example app.")
+ return nil
+ }
+
+
+ let times = Times(
+ preprocessing: preprocessingTime,
+ inference: inferenceTime,
+ postprocessing: postprocessingTime)
+
+ return (results, Model.input.width, Model.input.height, times)
+ }
+
+ // MARK: - Private functions to run model
+ /// Preprocesses the given rectangular image into `Data` of the desired size by cropping and resizing it.
+ ///
+ /// - Parameters:
+ /// - of: Input image to crop and resize.
+ /// - from: Target area to be cropped and resized.
+ /// - Returns: The cropped and resized image. `nil` if it cannot be processed.
+ private func preprocess(of pixelBuffer: CVPixelBuffer, from targetSquare: CGRect) -> Data? {
+ let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer)
+ assert(sourcePixelFormat == kCVPixelFormatType_32BGRA)
+
+ // Resize `targetSquare` of input image to `modelSize`.
+ let modelSize = CGSize(width: Model.input.width, height: Model.input.height)
+ guard let thumbnail = pixelBuffer.resize(from: targetSquare, to: modelSize)
+ else {
+ return nil
+ }
+
+ // Remove the alpha component from the image buffer to get the initialized `Data`.
+ let byteCount =
+ Model.input.batchSize
+ * Model.input.height * Model.input.width
+ * Model.input.channelSize
+ guard
+ let inputData = thumbnail.rgbData(
+ isModelQuantized: Model.isQuantized
+ )
+ else {
+ os_log("Failed to convert the image buffer to RGB data.", type: .error)
+ return nil
+ }
+
+ return inputData
+ }
+
+
+
+ /*
+ /// Postprocesses output `Tensor`s to `Result` with size of view to render the result.
+ ///
+ /// - Parameters:
+ /// - to: Size of the view to be displayed.
+ /// - Returns: Postprocessed `Result`. `nil` if it cannot be processed.
+ private func postprocess(to viewSize: CGSize) -> Result? {
+ // MARK: Formats output tensors
+ // Convert `Tensor` to `FlatArray`. As Midas is not quantized, convert them to Float type
+ // `FlatArray`.
+ let heats = FlatArray(tensor: heatsTensor)
+ let offsets = FlatArray(tensor: offsetsTensor)
+
+ // MARK: Find position of each key point
+ // Finds the (row, col) locations where the keypoints are most likely to be. The higher the
+ // `heats[0, row, col, keypoint]` value, the more likely `keypoint` is located at (`row`,
+ // `col`).
+ let keypointPositions = (0..<Model.output.keypointSize).map { keypoint -> (Int, Int) in
+ var maxValue = heats[0, 0, 0, keypoint]
+ var maxRow = 0
+ var maxCol = 0
+ for row in 0..<Model.output.height {
+ for col in 0..<Model.output.width {
+ if heats[0, row, col, keypoint] > maxValue {
+ maxValue = heats[0, row, col, keypoint]
+ maxRow = row
+ maxCol = col
+ }
+ }
+ }
+ return (maxRow, maxCol)
+ }
+
+ // MARK: Calculates total confidence score
+ // Calculates total confidence score of each key position.
+ let totalScoreSum = keypointPositions.enumerated().reduce(0.0) { accumulator, elem -> Float32 in
+ accumulator + sigmoid(heats[0, elem.element.0, elem.element.1, elem.offset])
+ }
+ let totalScore = totalScoreSum / Float32(Model.output.keypointSize)
+
+ // MARK: Calculate key point position on model input
+ // Calculates `KeyPoint` coordinates on the model input image, adjusted with `offsets`.
+ let coords = keypointPositions.enumerated().map { index, elem -> (y: Float32, x: Float32) in
+ let (y, x) = elem
+ let yCoord =
+ Float32(y) / Float32(Model.output.height - 1) * Float32(Model.input.height)
+ + offsets[0, y, x, index]
+ let xCoord =
+ Float32(x) / Float32(Model.output.width - 1) * Float32(Model.input.width)
+ + offsets[0, y, x, index + Model.output.keypointSize]
+ return (y: yCoord, x: xCoord)
+ }
+
+ // MARK: Transform key point position and make lines
+ // Make `Result` from `keypointPosition'. Each point is adjusted to `ViewSize` to be drawn.
+ var result = Result(dots: [], lines: [], score: totalScore)
+ var bodyPartToDotMap = [BodyPart: CGPoint]()
+ for (index, part) in BodyPart.allCases.enumerated() {
+ let position = CGPoint(
+ x: CGFloat(coords[index].x) * viewSize.width / CGFloat(Model.input.width),
+ y: CGFloat(coords[index].y) * viewSize.height / CGFloat(Model.input.height)
+ )
+ bodyPartToDotMap[part] = position
+ result.dots.append(position)
+ }
+
+ do {
+ try result.lines = BodyPart.lines.map { map throws -> Line in
+ guard let from = bodyPartToDotMap[map.from] else {
+ throw PostprocessError.missingBodyPart(of: map.from)
+ }
+ guard let to = bodyPartToDotMap[map.to] else {
+ throw PostprocessError.missingBodyPart(of: map.to)
+ }
+ return Line(from: from, to: to)
+ }
+ } catch PostprocessError.missingBodyPart(let missingPart) {
+ os_log("Postprocessing error: %s is missing.", type: .error, missingPart.rawValue)
+ return nil
+ } catch {
+ os_log("Postprocessing error: %s", type: .error, error.localizedDescription)
+ return nil
+ }
+
+ return result
+ }
+*/
+
+
+
+ /// Runs inference with the given `Data`.
+ ///
+ /// - Parameter from: `Data` of the input image to run the model on.
+ private func inference(from data: Data) {
+ // Copy the initialized `Data` to the input `Tensor`.
+ do {
+ try interpreter.copy(data, toInputAt: 0)
+
+ // Run inference by invoking the `Interpreter`.
+ try interpreter.invoke()
+
+ // Get the output `Tensor` to process the inference results.
+ outputTensor = try interpreter.output(at: 0)
+ //heatsTensor = try interpreter.output(at: 0)
+ //offsetsTensor = try interpreter.output(at: 1)
+
+
+ } catch let error {
+ os_log(
+ "Failed to invoke the interpreter with error: %s", type: .error,
+ error.localizedDescription)
+ return
+ }
+ }
+
+ /// Returns value within [0,1].
+ private func sigmoid(_ x: Float32) -> Float32 {
+ return (1.0 / (1.0 + exp(-x)))
+ }
+}
+
+// MARK: - Data types for inference result
+struct KeyPoint {
+ var bodyPart: BodyPart = BodyPart.NOSE
+ var position: CGPoint = CGPoint()
+ var score: Float = 0.0
+}
+
+struct Line {
+ let from: CGPoint
+ let to: CGPoint
+}
+
+struct Times {
+ var preprocessing: Double
+ var inference: Double
+ var postprocessing: Double
+}
+
+struct Result {
+ var dots: [CGPoint]
+ var lines: [Line]
+ var score: Float
+}
+
+enum BodyPart: String, CaseIterable {
+ case NOSE = "nose"
+ case LEFT_EYE = "left eye"
+ case RIGHT_EYE = "right eye"
+ case LEFT_EAR = "left ear"
+ case RIGHT_EAR = "right ear"
+ case LEFT_SHOULDER = "left shoulder"
+ case RIGHT_SHOULDER = "right shoulder"
+ case LEFT_ELBOW = "left elbow"
+ case RIGHT_ELBOW = "right elbow"
+ case LEFT_WRIST = "left wrist"
+ case RIGHT_WRIST = "right wrist"
+ case LEFT_HIP = "left hip"
+ case RIGHT_HIP = "right hip"
+ case LEFT_KNEE = "left knee"
+ case RIGHT_KNEE = "right knee"
+ case LEFT_ANKLE = "left ankle"
+ case RIGHT_ANKLE = "right ankle"
+
+ /// List of lines connecting each part.
+ static let lines = [
+ (from: BodyPart.LEFT_WRIST, to: BodyPart.LEFT_ELBOW),
+ (from: BodyPart.LEFT_ELBOW, to: BodyPart.LEFT_SHOULDER),
+ (from: BodyPart.LEFT_SHOULDER, to: BodyPart.RIGHT_SHOULDER),
+ (from: BodyPart.RIGHT_SHOULDER, to: BodyPart.RIGHT_ELBOW),
+ (from: BodyPart.RIGHT_ELBOW, to: BodyPart.RIGHT_WRIST),
+ (from: BodyPart.LEFT_SHOULDER, to: BodyPart.LEFT_HIP),
+ (from: BodyPart.LEFT_HIP, to: BodyPart.RIGHT_HIP),
+ (from: BodyPart.RIGHT_HIP, to: BodyPart.RIGHT_SHOULDER),
+ (from: BodyPart.LEFT_HIP, to: BodyPart.LEFT_KNEE),
+ (from: BodyPart.LEFT_KNEE, to: BodyPart.LEFT_ANKLE),
+ (from: BodyPart.RIGHT_HIP, to: BodyPart.RIGHT_KNEE),
+ (from: BodyPart.RIGHT_KNEE, to: BodyPart.RIGHT_ANKLE),
+ ]
+}
+
+// MARK: - Delegates Enum
+enum Delegates: Int, CaseIterable {
+ case CPU
+ case Metal
+ case CoreML
+
+ var description: String {
+ switch self {
+ case .CPU:
+ return "CPU"
+ case .Metal:
+ return "GPU"
+ case .CoreML:
+ return "NPU"
+ }
+ }
+}
+
+// MARK: - Custom Errors
+enum PostprocessError: Error {
+ case missingBodyPart(of: BodyPart)
+}
+
+// MARK: - Information about the model file.
+typealias FileInfo = (name: String, extension: String)
+
+enum Model {
+ static let file: FileInfo = (
+ name: "model_opt", extension: "tflite"
+ )
+
+ static let input = (batchSize: 1, height: 256, width: 256, channelSize: 3)
+ static let output = (batchSize: 1, height: 256, width: 256, channelSize: 1)
+ static let isQuantized = false
+}
+
+
+extension Array {
+ /// Creates a new array from the bytes of the given unsafe data.
+ ///
+ /// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
+ /// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
+ /// the `unsafeData`'s buffer to a new array returns an unsafe copy.
+ /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
+ /// `MemoryLayout<Element>.stride`.
+ /// - Parameter unsafeData: The data containing the bytes to turn into an array.
+ init?(unsafeData: Data) {
+ guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
+ #if swift(>=5.0)
+ self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
+ #else
+ self = unsafeData.withUnsafeBytes {
+ .init(UnsafeBufferPointer<Element>(
+ start: $0,
+ count: unsafeData.count / MemoryLayout<Element>.stride
+ ))
+ }
+ #endif // swift(>=5.0)
+ }
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard
new file mode 100644
index 0000000000000000000000000000000000000000..a04c79f554777863bd0dc8287bfd60704ce28bf2
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard
@@ -0,0 +1,48 @@
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard
new file mode 100644
index 0000000000000000000000000000000000000000..5f5623794bd35b9bb75efd7b7e249fd7357fdfbd
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard
@@ -0,0 +1,236 @@
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift
new file mode 100644
index 0000000000000000000000000000000000000000..fbb51b5a303412c0bbd158d76d025cf88fee6f8f
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift
@@ -0,0 +1,489 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import AVFoundation
+import UIKit
+import os
+
+
+public struct PixelData {
+ var a: UInt8
+ var r: UInt8
+ var g: UInt8
+ var b: UInt8
+}
+
+extension UIImage {
+ convenience init?(pixels: [PixelData], width: Int, height: Int) {
+ guard width > 0 && height > 0, pixels.count == width * height else { return nil }
+ var data = pixels
+ guard let providerRef = CGDataProvider(data: Data(bytes: &data, count: data.count * MemoryLayout<PixelData>.size) as CFData)
+ else { return nil }
+ guard let cgim = CGImage(
+ width: width,
+ height: height,
+ bitsPerComponent: 8,
+ bitsPerPixel: 32,
+ bytesPerRow: width * MemoryLayout<PixelData>.size,
+ space: CGColorSpaceCreateDeviceRGB(),
+ bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.premultipliedFirst.rawValue),
+ provider: providerRef,
+ decode: nil,
+ shouldInterpolate: false,
+ intent: .defaultIntent)
+ else { return nil }
+ self.init(cgImage: cgim)
+ }
+}
+
+
+class ViewController: UIViewController {
+ // MARK: Storyboards Connections
+ @IBOutlet weak var previewView: PreviewView!
+
+ //@IBOutlet weak var overlayView: OverlayView!
+ @IBOutlet weak var overlayView: UIImageView!
+
+ private var imageView : UIImageView = UIImageView(frame:CGRect(x:0, y:0, width:400, height:400))
+
+ private var imageViewInitialized: Bool = false
+
+ @IBOutlet weak var resumeButton: UIButton!
+ @IBOutlet weak var cameraUnavailableLabel: UILabel!
+
+ @IBOutlet weak var tableView: UITableView!
+
+ @IBOutlet weak var threadCountLabel: UILabel!
+ @IBOutlet weak var threadCountStepper: UIStepper!
+
+ @IBOutlet weak var delegatesControl: UISegmentedControl!
+
+ // MARK: ModelDataHandler traits
+ var threadCount: Int = Constants.defaultThreadCount
+ var delegate: Delegates = Constants.defaultDelegate
+
+ // MARK: Result Variables
+ // Inferenced data to render.
+ private var inferencedData: InferencedData?
+
+ // Minimum score to render the result.
+ private let minimumScore: Float = 0.5
+
+ private var avg_latency: Double = 0.0
+
+ // Relative location of `overlayView` to `previewView`.
+ private var overlayViewFrame: CGRect?
+
+ private var previewViewFrame: CGRect?
+
+ // MARK: Controllers that manage functionality
+ // Handles all the camera related functionality
+ private lazy var cameraCapture = CameraFeedManager(previewView: previewView)
+
+ // Handles all data preprocessing and makes calls to run inference.
+ private var modelDataHandler: ModelDataHandler?
+
+ // MARK: View Handling Methods
+ override func viewDidLoad() {
+ super.viewDidLoad()
+
+ do {
+ modelDataHandler = try ModelDataHandler()
+ } catch let error {
+ fatalError(error.localizedDescription)
+ }
+
+ cameraCapture.delegate = self
+ tableView.delegate = self
+ tableView.dataSource = self
+
+ // MARK: UI Initialization
+ // Setup thread count stepper with white color.
+ // https://forums.developer.apple.com/thread/121495
+ threadCountStepper.setDecrementImage(
+ threadCountStepper.decrementImage(for: .normal), for: .normal)
+ threadCountStepper.setIncrementImage(
+ threadCountStepper.incrementImage(for: .normal), for: .normal)
+ // Setup initial stepper value and its label.
+ threadCountStepper.value = Double(Constants.defaultThreadCount)
+ threadCountLabel.text = Constants.defaultThreadCount.description
+
+ // Setup segmented controller's color.
+ delegatesControl.setTitleTextAttributes(
+ [NSAttributedString.Key.foregroundColor: UIColor.lightGray],
+ for: .normal)
+ delegatesControl.setTitleTextAttributes(
+ [NSAttributedString.Key.foregroundColor: UIColor.black],
+ for: .selected)
+ // Remove existing segments to initialize it with `Delegates` entries.
+ delegatesControl.removeAllSegments()
+ Delegates.allCases.forEach { delegate in
+ delegatesControl.insertSegment(
+ withTitle: delegate.description,
+ at: delegate.rawValue,
+ animated: false)
+ }
+ delegatesControl.selectedSegmentIndex = 0
+ }
+
+ override func viewWillAppear(_ animated: Bool) {
+ super.viewWillAppear(animated)
+
+ cameraCapture.checkCameraConfigurationAndStartSession()
+ }
+
+ override func viewWillDisappear(_ animated: Bool) {
+ cameraCapture.stopSession()
+ }
+
+ override func viewDidLayoutSubviews() {
+ overlayViewFrame = overlayView.frame
+ previewViewFrame = previewView.frame
+ }
+
+ // MARK: Button Actions
+ @IBAction func didChangeThreadCount(_ sender: UIStepper) {
+ let changedCount = Int(sender.value)
+ if threadCountLabel.text == changedCount.description {
+ return
+ }
+
+ do {
+ modelDataHandler = try ModelDataHandler(threadCount: changedCount, delegate: delegate)
+ } catch let error {
+ fatalError(error.localizedDescription)
+ }
+ threadCount = changedCount
+ threadCountLabel.text = changedCount.description
+ os_log("Thread count is changed to: %d", threadCount)
+ }
+
+ @IBAction func didChangeDelegate(_ sender: UISegmentedControl) {
+ guard let changedDelegate = Delegates(rawValue: delegatesControl.selectedSegmentIndex) else {
+ fatalError("Unexpected value from delegates segemented controller.")
+ }
+ do {
+ modelDataHandler = try ModelDataHandler(threadCount: threadCount, delegate: changedDelegate)
+ } catch let error {
+ fatalError(error.localizedDescription)
+ }
+ delegate = changedDelegate
+ os_log("Delegate is changed to: %s", delegate.description)
+ }
+
+ @IBAction func didTapResumeButton(_ sender: Any) {
+ cameraCapture.resumeInterruptedSession { complete in
+
+ if complete {
+ self.resumeButton.isHidden = true
+ self.cameraUnavailableLabel.isHidden = true
+ } else {
+ self.presentUnableToResumeSessionAlert()
+ }
+ }
+ }
+
+ func presentUnableToResumeSessionAlert() {
+ let alert = UIAlertController(
+ title: "Unable to Resume Session",
+ message: "There was an error while attempting to resume session.",
+ preferredStyle: .alert
+ )
+ alert.addAction(UIAlertAction(title: "OK", style: .default, handler: nil))
+
+ self.present(alert, animated: true)
+ }
+}
+
+// MARK: - CameraFeedManagerDelegate Methods
+extension ViewController: CameraFeedManagerDelegate {
+ func cameraFeedManager(_ manager: CameraFeedManager, didOutput pixelBuffer: CVPixelBuffer) {
+ runModel(on: pixelBuffer)
+ }
+
+ // MARK: Session Handling Alerts
+ func cameraFeedManagerDidEncounterSessionRunTimeError(_ manager: CameraFeedManager) {
+ // Handles session run time error by updating the UI and providing a button if session can be
+ // manually resumed.
+ self.resumeButton.isHidden = false
+ }
+
+ func cameraFeedManager(
+ _ manager: CameraFeedManager, sessionWasInterrupted canResumeManually: Bool
+ ) {
+ // Updates the UI when session is interrupted.
+ if canResumeManually {
+ self.resumeButton.isHidden = false
+ } else {
+ self.cameraUnavailableLabel.isHidden = false
+ }
+ }
+
+ func cameraFeedManagerDidEndSessionInterruption(_ manager: CameraFeedManager) {
+ // Updates UI once session interruption has ended.
+ self.cameraUnavailableLabel.isHidden = true
+ self.resumeButton.isHidden = true
+ }
+
+ func presentVideoConfigurationErrorAlert(_ manager: CameraFeedManager) {
+ let alertController = UIAlertController(
+ title: "Confirguration Failed", message: "Configuration of camera has failed.",
+ preferredStyle: .alert)
+ let okAction = UIAlertAction(title: "OK", style: .cancel, handler: nil)
+ alertController.addAction(okAction)
+
+ present(alertController, animated: true, completion: nil)
+ }
+
+ func presentCameraPermissionsDeniedAlert(_ manager: CameraFeedManager) {
+ let alertController = UIAlertController(
+ title: "Camera Permissions Denied",
+ message:
+ "Camera permissions have been denied for this app. You can change this by going to Settings",
+ preferredStyle: .alert)
+
+ let cancelAction = UIAlertAction(title: "Cancel", style: .cancel, handler: nil)
+ let settingsAction = UIAlertAction(title: "Settings", style: .default) { action in
+ if let url = URL.init(string: UIApplication.openSettingsURLString) {
+ UIApplication.shared.open(url, options: [:], completionHandler: nil)
+ }
+ }
+
+ alertController.addAction(cancelAction)
+ alertController.addAction(settingsAction)
+
+ present(alertController, animated: true, completion: nil)
+ }
+
+ @objc func runModel(on pixelBuffer: CVPixelBuffer) {
+ guard let overlayViewFrame = overlayViewFrame, let previewViewFrame = previewViewFrame
+ else {
+ return
+ }
+ // To put `overlayView` area as model input, transform `overlayViewFrame` following transform
+ // from `previewView` to `pixelBuffer`. `previewView` area is transformed to fit in
+ // `pixelBuffer`, because `pixelBuffer` as a camera output is resized to fill `previewView`.
+ // https://developer.apple.com/documentation/avfoundation/avlayervideogravity/1385607-resizeaspectfill
+ let modelInputRange = overlayViewFrame.applying(
+ previewViewFrame.size.transformKeepAspect(toFitIn: pixelBuffer.size))
+
+ // Run Midas model.
+ guard
+ let (result, width, height, times) = self.modelDataHandler?.runMidas(
+ on: pixelBuffer,
+ from: modelInputRange,
+ to: overlayViewFrame.size)
+ else {
+ os_log("Cannot get inference result.", type: .error)
+ return
+ }
+
+ if avg_latency == 0 {
+ avg_latency = times.inference
+ } else {
+ avg_latency = times.inference*0.1 + avg_latency*0.9
+ }
+
+ // Update `inferencedData` to render data in `tableView`.
+ inferencedData = InferencedData(score: Float(avg_latency), times: times)
+
+ //let height = 256
+ //let width = 256
+
+ let outputs = result
+ let outputs_size = width * height;
+
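+ // Map the predicted inverse depth range [min_val, max_val] to 8-bit grayscale [0, 255].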
+ var multiplier : Float = 1.0;
+
+ let max_val : Float = outputs.max() ?? 0
+ let min_val : Float = outputs.min() ?? 0
+
+ if((max_val - min_val) > 0) {
+ multiplier = 255 / (max_val - min_val);
+ }
+
+ // Draw result.
+ DispatchQueue.main.async {
+ self.tableView.reloadData()
+
+ var pixels: [PixelData] = .init(repeating: .init(a: 255, r: 0, g: 0, b: 0), count: width * height)
+
+ for i in pixels.indices {
+ //if(i < 1000)
+ //{
+ let val = UInt8((outputs[i] - min_val) * multiplier)
+
+ pixels[i].r = val
+ pixels[i].g = val
+ pixels[i].b = val
+ //}
+ }
+
+
+ /*
+ pixels[i].a = 255
+ pixels[i].r = .random(in: 0...255)
+ pixels[i].g = .random(in: 0...255)
+ pixels[i].b = .random(in: 0...255)
+ }
+ */
+
+ DispatchQueue.main.async {
+ let image = UIImage(pixels: pixels, width: width, height: height)
+
+ self.imageView.image = image
+
+ if (self.imageViewInitialized == false) {
+ self.imageViewInitialized = true
+ self.overlayView.addSubview(self.imageView)
+ self.overlayView.setNeedsDisplay()
+ }
+ }
+
+ /*
+ let image = UIImage(pixels: pixels, width: width, height: height)
+
+ var imageView : UIImageView
+ imageView = UIImageView(frame:CGRect(x:0, y:0, width:400, height:400));
+ imageView.image = image
+ self.overlayView.addSubview(imageView)
+ self.overlayView.setNeedsDisplay()
+ */
+ }
+ }
+/*
+ func drawResult(of result: Result) {
+ self.overlayView.dots = result.dots
+ self.overlayView.lines = result.lines
+ self.overlayView.setNeedsDisplay()
+ }
+
+ func clearResult() {
+ self.overlayView.clear()
+ self.overlayView.setNeedsDisplay()
+ }
+ */
+
+}
+
+
+// MARK: - TableViewDelegate, TableViewDataSource Methods
+extension ViewController: UITableViewDelegate, UITableViewDataSource {
+ func numberOfSections(in tableView: UITableView) -> Int {
+ return InferenceSections.allCases.count
+ }
+
+ func tableView(_ tableView: UITableView, numberOfRowsInSection section: Int) -> Int {
+ guard let section = InferenceSections(rawValue: section) else {
+ return 0
+ }
+
+ return section.subcaseCount
+ }
+
+ func tableView(_ tableView: UITableView, cellForRowAt indexPath: IndexPath) -> UITableViewCell {
+ let cell = tableView.dequeueReusableCell(withIdentifier: "InfoCell") as! InfoCell
+ guard let section = InferenceSections(rawValue: indexPath.section) else {
+ return cell
+ }
+ guard let data = inferencedData else { return cell }
+
+ var fieldName: String
+ var info: String
+
+ switch section {
+ case .Score:
+ fieldName = section.description
+ info = String(format: "%.3f", data.score)
+ case .Time:
+ guard let row = ProcessingTimes(rawValue: indexPath.row) else {
+ return cell
+ }
+ var time: Double
+ switch row {
+ case .InferenceTime:
+ time = data.times.inference
+ }
+ fieldName = row.description
+ info = String(format: "%.2fms", time)
+ }
+
+ cell.fieldNameLabel.text = fieldName
+ cell.infoLabel.text = info
+
+ return cell
+ }
+
+ func tableView(_ tableView: UITableView, heightForRowAt indexPath: IndexPath) -> CGFloat {
+ guard let section = InferenceSections(rawValue: indexPath.section) else {
+ return 0
+ }
+
+ var height = Traits.normalCellHeight
+ if indexPath.row == section.subcaseCount - 1 {
+ height = Traits.separatorCellHeight + Traits.bottomSpacing
+ }
+ return height
+ }
+
+}
+
+// MARK: - Private enums
+/// UI constraint values
+fileprivate enum Traits {
+ static let normalCellHeight: CGFloat = 35.0
+ static let separatorCellHeight: CGFloat = 25.0
+ static let bottomSpacing: CGFloat = 30.0
+}
+
+fileprivate struct InferencedData {
+ var score: Float
+ var times: Times
+}
+
+/// Type of sections in Info Cell
+fileprivate enum InferenceSections: Int, CaseIterable {
+ case Score
+ case Time
+
+ var description: String {
+ switch self {
+ case .Score:
+ return "Average"
+ case .Time:
+ return "Processing Time"
+ }
+ }
+
+ var subcaseCount: Int {
+ switch self {
+ case .Score:
+ return 1
+ case .Time:
+ return ProcessingTimes.allCases.count
+ }
+ }
+}
+
+/// Type of processing times in Time section in Info Cell
+fileprivate enum ProcessingTimes: Int, CaseIterable {
+ case InferenceTime
+
+ var description: String {
+ switch self {
+ case .InferenceTime:
+ return "Inference Time"
+ }
+ }
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift
new file mode 100644
index 0000000000000000000000000000000000000000..3b53910b57563b6a195fd53321fa2a24ebaf3d3f
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift
@@ -0,0 +1,63 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import UIKit
+
+/// UIView for rendering inference output.
+class OverlayView: UIView {
+
+ var dots = [CGPoint]()
+ var lines = [Line]()
+
+ override func draw(_ rect: CGRect) {
+ for dot in dots {
+ drawDot(of: dot)
+ }
+ for line in lines {
+ drawLine(of: line)
+ }
+ }
+
+ func drawDot(of dot: CGPoint) {
+ let dotRect = CGRect(
+ x: dot.x - Traits.dot.radius / 2, y: dot.y - Traits.dot.radius / 2,
+ width: Traits.dot.radius, height: Traits.dot.radius)
+ let dotPath = UIBezierPath(ovalIn: dotRect)
+
+ Traits.dot.color.setFill()
+ dotPath.fill()
+ }
+
+ func drawLine(of line: Line) {
+ let linePath = UIBezierPath()
+ linePath.move(to: CGPoint(x: line.from.x, y: line.from.y))
+ linePath.addLine(to: CGPoint(x: line.to.x, y: line.to.y))
+ linePath.close()
+
+ linePath.lineWidth = Traits.line.width
+ Traits.line.color.setStroke()
+
+ linePath.stroke()
+ }
+
+ func clear() {
+ self.dots = []
+ self.lines = []
+ }
+}
+
+private enum Traits {
+ static let dot = (radius: CGFloat(5), color: UIColor.orange)
+ static let line = (width: CGFloat(1.0), color: UIColor.orange)
+}
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile
new file mode 100644
index 0000000000000000000000000000000000000000..5e9461fc96dbbe3c22ca6bbf2bfd7df3981b9462
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile
@@ -0,0 +1,12 @@
+# Uncomment the next line to define a global platform for your project
+ platform :ios, '12.0'
+
+target 'Midas' do
+ # Comment the next line if you're not using Swift and don't want to use dynamic frameworks
+ use_frameworks!
+
+ # Pods for Midas
+ pod 'TensorFlowLiteSwift', '~> 0.0.1-nightly'
+ pod 'TensorFlowLiteSwift/CoreML', '~> 0.0.1-nightly'
+ pod 'TensorFlowLiteSwift/Metal', '~> 0.0.1-nightly'
+end
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7b8eb29feaa21e67814b035dbd5c5fb2c62a4151
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md
@@ -0,0 +1,105 @@
+# TensorFlow Lite MiDaS iOS Example
+
+### Requirements
+
+- Xcode 11.0 or above
+- iOS 12.0 or above, [iOS 14 breaks the NPU Delegate](https://github.com/tensorflow/tensorflow/issues/43339)
+- TensorFlow 2.4.0, TensorFlowLiteSwift -> 0.0.1-nightly
+
+## Quick Start with a MiDaS Example
+
+MiDaS is a neural network to compute depth from a single image. It uses TensorFlowLiteSwift / C++ libraries on iOS. The code is written in Swift.
+
+Paper: https://arxiv.org/abs/1907.01341
+
+> Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer
+> René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun
+
+### Install TensorFlow
+
+Set the default Python version to Python 3:
+
+```
+echo 'export PATH=/usr/local/opt/python/libexec/bin:$PATH' >> ~/.zshenv
+echo 'alias python=python3' >> ~/.zshenv
+echo 'alias pip=pip3' >> ~/.zshenv
+```
+
+Install TensorFlow
+
+```shell
+pip install tensorflow
+```
+
+### Install TensorFlowLiteSwift via CocoaPods
+
+Set the required TensorFlowLiteSwift version in the Podfile (`0.0.1-nightly` is recommended): https://github.com/isl-org/MiDaS/blob/master/mobile/ios/Podfile#L9
+
+Install Homebrew, Ruby, and CocoaPods:
+
+```
+ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
+brew install mc rbenv ruby-build
+sudo gem install cocoapods
+```
+
+
+The TensorFlowLiteSwift library is available on [CocoaPods](https://cocoapods.org/). To integrate it into the project, run the following from the root directory of the project:
+
+```shell
+pod install
+```
+
+Now open the `Midas.xcworkspace` file in Xcode, select your iPhone device (Xcode -> Product -> Destination -> iPhone), and launch the app (Cmd + R). If everything works, you should see a real-time depth map from your camera.
+
+### Model
+
+The TensorFlow Lite (TFLite) model `model_opt.tflite` is in the folder `/Midas/Model` (it can be fetched with `RunScripts/download_models.sh`).
+
+
+To use another model, convert it from a TensorFlow saved model to a TFLite model so that it can be deployed:
+
+```python
+import tensorflow as tf
+
+saved_model_export_dir = "./saved_model"
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_export_dir)
+tflite_model = converter.convert()
+open("model.tflite", "wb").write(tflite_model)
+```
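+
+To sanity-check the converted file before bundling it, here is a minimal sketch (assuming TensorFlow 2.x and the `model.tflite` file written above; the expected input shape matches `Model.input` in `ModelDataHandler.swift`):
+
+```python
+import tensorflow as tf
+
+# Load the converted model and print its input/output shapes.
+interpreter = tf.lite.Interpreter(model_path="model.tflite")
+interpreter.allocate_tensors()
+print(interpreter.get_input_details()[0]["shape"])   # e.g. [1, 256, 256, 3] for MiDaS v2.1 small
+print(interpreter.get_output_details()[0]["shape"])
+```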
+
+### Setup Xcode
+
+* Open the `.xcworkspace` file in Xcode
+
+* Click your project name (top-left corner) -> change the Bundle Identifier to `com.midas.tflite-npu` or something similar (it must be unique)
+
+* Select your Developer Team (you should be signed in with your Apple ID)
+
+* Connect your iPhone (if you want to run the app on a real device instead of the simulator) and select it (Xcode -> Product -> Destination -> iPhone)
+
+* In Xcode, click: Product -> Run
+
+* On your iPhone, go to: Settings -> General -> Device Management (or Profiles) -> Apple Development -> Trust Apple Development
+
+----
+
+Original repository: https://github.com/isl-org/MiDaS
+
+
+## LICENSE
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d737b39d966278f5c6bc29802526ab86f8473de4
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+# Download TF Lite model from the internet if it does not exist.
+
+TFLITE_MODEL="model_opt.tflite"
+TFLITE_FILE="Midas/Model/${TFLITE_MODEL}"
+MODEL_SRC="https://github.com/isl-org/MiDaS/releases/download/v2/${TFLITE_MODEL}"
+
+if test -f "${TFLITE_FILE}"; then
+ echo "INFO: TF Lite model already exists. Skip downloading and use the local model."
+else
+ curl --create-dirs -o "${TFLITE_FILE}" -LJO "${MODEL_SRC}"
+ echo "INFO: Downloaded TensorFlow Lite model to ${TFLITE_FILE}."
+fi
+
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/output/.placeholder b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/output/.placeholder
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..6606ec028d1c629986e7019fe3564f5b4bfe425d
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Alexey
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1d43c2606767798ee46b34292e0483197424ec23
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md
@@ -0,0 +1,131 @@
+# MiDaS for ROS1 by using LibTorch in C++
+
+### Requirements
+
+- Ubuntu 17.10 / 18.04 / 20.04, Debian Stretch
+- ROS Melodic for Ubuntu (17.10 / 18.04) / Debian Stretch, ROS Noetic for Ubuntu 20.04
+- C++11
+- LibTorch >= 1.6
+
+## Quick Start with a MiDaS Example
+
+MiDaS is a neural network to compute depth from a single image.
+
+* input from `image_topic`: `sensor_msgs/Image` - `RGB8` image with any shape
+* output to `midas_topic`: `sensor_msgs/Image` - `TYPE_32FC1` inverse relative depth maps in the range [0, 255] with the original size and channels=1 (see the minimal subscriber sketch below)
+
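+A minimal Python subscriber sketch for `midas_topic` (an illustration only, assuming `rospy`, `cv_bridge`, OpenCV, and NumPy are installed; topic and message types as listed above):
+
+```python
+#!/usr/bin/env python3
+import cv2
+import numpy as np
+import rospy
+from cv_bridge import CvBridge
+from sensor_msgs.msg import Image
+
+bridge = CvBridge()
+
+def on_depth(msg):
+    # TYPE_32FC1 inverse relative depth in [0, 255]; normalize for display.
+    depth = bridge.imgmsg_to_cv2(msg)
+    vis = cv2.normalize(depth, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
+    cv2.imshow("midas depth", vis)
+    cv2.waitKey(1)
+
+rospy.init_node("midas_depth_listener")
+rospy.Subscriber("midas_topic", Image, on_depth, queue_size=1)
+rospy.spin()
+```
+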
+### Install Dependencies
+
+* install ROS Melodic for Ubuntu 17.10 / 18.04:
+```bash
+wget https://raw.githubusercontent.com/isl-org/MiDaS/master/ros/additions/install_ros_melodic_ubuntu_17_18.sh
+./install_ros_melodic_ubuntu_17_18.sh
+```
+
+or Noetic for Ubuntu 20.04:
+
+```bash
+wget https://raw.githubusercontent.com/isl-org/MiDaS/master/ros/additions/install_ros_noetic_ubuntu_20.sh
+./install_ros_noetic_ubuntu_20.sh
+```
+
+
+* install LibTorch 1.7 with CUDA 11.0:
+
+On **Jetson (ARM)**:
+```bash
+wget https://nvidia.box.com/shared/static/wa34qwrwtk9njtyarwt5nvo6imenfy26.whl -O torch-1.7.0-cp36-cp36m-linux_aarch64.whl
+sudo apt-get install python3-pip libopenblas-base libopenmpi-dev
+pip3 install Cython
+pip3 install numpy torch-1.7.0-cp36-cp36m-linux_aarch64.whl
+```
+Or compile LibTorch from source: https://github.com/pytorch/pytorch#from-source
+
+On **Linux (x86_64)**:
+```bash
+cd ~/
+wget https://download.pytorch.org/libtorch/cu110/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu110.zip
+unzip libtorch-cxx11-abi-shared-with-deps-1.7.0+cu110.zip
+```
+
+* create symlink for OpenCV:
+
+```bash
+sudo ln -s /usr/include/opencv4 /usr/include/opencv
+```
+
+* download and install MiDaS:
+
+```bash
+source ~/.bashrc
+cd ~/
+mkdir catkin_ws
+cd catkin_ws
+git clone https://github.com/isl-org/MiDaS
+mkdir src
+cp -r MiDaS/ros/* src
+
+chmod +x src/additions/*.sh
+chmod +x src/*.sh
+chmod +x src/midas_cpp/scripts/*.py
+cp src/additions/do_catkin_make.sh ./do_catkin_make.sh
+./do_catkin_make.sh
+./src/additions/downloads.sh
+```
+
+### Usage
+
+* run only `midas` node: `~/catkin_ws/src/launch_midas_cpp.sh`
+
+#### Test
+
+* Test - capture video and show result in the window:
+ * place any `test.mp4` video file to the directory `~/catkin_ws/src/`
+ * run `midas` node: `~/catkin_ws/src/launch_midas_cpp.sh`
+ * run test nodes in another terminal: `cd ~/catkin_ws/src && ./run_talker_listener_test.sh` and wait 30 seconds
+
+ (to use Python 2, run command `sed -i 's/python3/python2/' ~/catkin_ws/src/midas_cpp/scripts/*.py` )
+
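+For reference, a simplified talker sketch that feeds frames from `test.mp4` into `image_topic` (an illustration only; the actual scripts shipped in `midas_cpp/scripts` may differ, and OpenCV plus `cv_bridge` are assumed to be installed):
+
+```python
+#!/usr/bin/env python3
+import cv2
+import rospy
+from cv_bridge import CvBridge
+from sensor_msgs.msg import Image
+
+rospy.init_node("test_talker")
+pub = rospy.Publisher("image_topic", Image, queue_size=1)
+bridge = CvBridge()
+cap = cv2.VideoCapture("test.mp4")
+
+rate = rospy.Rate(10)  # publish at roughly 10 Hz
+while not rospy.is_shutdown():
+    ok, frame = cap.read()
+    if not ok:
+        break
+    # OpenCV gives BGR frames; convert to the RGB8 encoding expected on `image_topic`.
+    pub.publish(bridge.cv2_to_imgmsg(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), encoding="rgb8"))
+    rate.sleep()
+```
+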
+## Mobile version of MiDaS - Monocular Depth Estimation
+
+### Accuracy
+
+* MiDaS v2 small - ResNet50 default-decoder 384x384
+* MiDaS v2.1 small - EfficientNet-Lite3 small-decoder 256x256
+
+**Zero-shot error** (the lower - the better):
+
+| Model | DIW WHDR | Eth3d AbsRel | Sintel AbsRel | Kitti δ>1.25 | NyuDepthV2 δ>1.25 | TUM δ>1.25 |
+|---|---|---|---|---|---|---|
+| MiDaS v2 small 384x384 | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 |
+| MiDaS v2.1 small 256x256 | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | **14.53** |
+| Relative improvement, % | -8 % | **+13 %** | -2 % | -34 % | **+15 %** | **+15 %** |
+
+None of the Train/Valid/Test subsets of the datasets (DIW, Eth3d, Sintel, Kitti, NyuDepthV2, TUM) were involved in training or fine-tuning.
+
+### Inference speed (FPS) on NVIDIA GPU
+
+Inference speed excluding pre- and post-processing, batch=1, **Frames Per Second** (the higher - the better):
+
+| Model | Jetson Nano, FPS | RTX 2080Ti, FPS |
+|---|---|---|
+| MiDaS v2 small 384x384 | 1.6 | 117 |
+| MiDaS v2.1 small 256x256 | 8.1 | 232 |
+| SpeedUp, X times | **5x** | **2x** |
+
+### Citation
+
+This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3):
+
+>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer
+René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun
+
+Please cite our paper if you use this code or any of the models:
+```
+@article{Ranftl2020,
+ author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun},
+ title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer},
+ journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
+ year = {2020},
+}
+```
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0d416fc00282aab146326bbba12a9274e1ba29b8
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh
@@ -0,0 +1,5 @@
+mkdir src
+catkin_make
+source devel/setup.bash
+echo $ROS_PACKAGE_PATH
+chmod +x ./devel/setup.bash
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9c967d4e2dc7997da26399a063b5a54ecc314eb1
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh
@@ -0,0 +1,5 @@
+mkdir ~/.ros
+wget https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small-traced.pt
+cp ./model-small-traced.pt ~/.ros/model-small-traced.pt
+
+
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b868112631e9d9bc7bccb601407dfc857b8a99d5
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh
@@ -0,0 +1,34 @@
+#@title { display-mode: "code" }
+
+#from http://wiki.ros.org/indigo/Installation/Ubuntu
+
+#1.2 Setup sources.list
+sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list'
+
+# 1.3 Setup keys
+sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654
+sudo apt-key adv --keyserver 'hkp://ha.pool.sks-keyservers.net:80' --recv-key 421C365BD9FF1F717815A3895523BAEEB01FA116
+
+curl -sSL 'http://keyserver.ubuntu.com/pks/lookup?op=get&search=0xC1CF6E31E6BADE8868B172B4F42ED6FBAB17C654' | sudo apt-key add -
+
+# 1.4 Installation
+sudo apt-get update
+sudo apt-get upgrade
+
+# Desktop-Full Install:
+sudo apt-get install ros-melodic-desktop-full
+
+printf "\nsource /opt/ros/melodic/setup.bash\n" >> ~/.bashrc
+
+# 1.5 Initialize rosdep
+sudo rosdep init
+rosdep update
+
+
+# 1.7 Getting rosinstall (python)
+sudo apt-get install python-rosinstall
+sudo apt-get install python-catkin-tools
+sudo apt-get install python-rospy
+sudo apt-get install python-rosdep
+sudo apt-get install python-roscd
+sudo apt-get install python-pip
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d73ea1a3d92359819167d735a92d2a650b9bc245
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh
@@ -0,0 +1,33 @@
+#@title { display-mode: "code" }
+
+#from http://wiki.ros.org/indigo/Installation/Ubuntu
+
+#1.2 Setup sources.list
+sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list'
+
+# 1.3 Setup keys
+sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654
+
+curl -sSL 'http://keyserver.ubuntu.com/pks/lookup?op=get&search=0xC1CF6E31E6BADE8868B172B4F42ED6FBAB17C654' | sudo apt-key add -
+
+# 1.4 Installation
+sudo apt-get update
+sudo apt-get upgrade
+
+# Desktop-Full Install:
+sudo apt-get install ros-noetic-desktop-full
+
+printf "\nsource /opt/ros/noetic/setup.bash\n" >> ~/.bashrc
+
+# 1.5 Initialize rosdep
+sudo rosdep init
+rosdep update
+
+
+# 1.7 Getting rosinstall (python)
+sudo apt-get install python3-rosinstall
+sudo apt-get install python3-catkin-tools
+sudo apt-get install python3-rospy
+sudo apt-get install python3-rosdep
+sudo apt-get install python3-roscd
+sudo apt-get install python3-pip
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d0ef6073a9c9ce40744e1c81d557c1c68255b95e
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh
@@ -0,0 +1,16 @@
+cd ~/catkin_ws/src
+catkin_create_pkg midas_cpp std_msgs roscpp cv_bridge sensor_msgs image_transport
+cd ~/catkin_ws
+catkin_make
+
+chmod +x ~/catkin_ws/devel/setup.bash
+printf "\nsource ~/catkin_ws/devel/setup.bash" >> ~/.bashrc
+source ~/catkin_ws/devel/setup.bash
+
+
+sudo rosdep init
+rosdep update
+#rospack depends1 midas_cpp
+roscd midas_cpp
+#cat package.xml
+#rospack depends midas_cpp
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5a0d1583fffdc49216c625dfd07af2ae3b01a7a0
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh
@@ -0,0 +1,2 @@
+source ~/catkin_ws/devel/setup.bash
+roslaunch midas_cpp midas_cpp.launch model_name:="model-small-traced.pt" input_topic:="image_topic" output_topic:="midas_topic" out_orig_size:="true"
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..885341691d217f9c4c8fcb1e4ff568d87788c7b8
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt
@@ -0,0 +1,189 @@
+cmake_minimum_required(VERSION 3.0.2)
+project(midas_cpp)
+
+## Compile as C++11, supported in ROS Kinetic and newer
+# add_compile_options(-std=c++11)
+
+## Find catkin macros and libraries
+## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
+## is used, also find other catkin packages
+find_package(catkin REQUIRED COMPONENTS
+ cv_bridge
+ image_transport
+ roscpp
+ rospy
+ sensor_msgs
+ std_msgs
+)
+
+## System dependencies are found with CMake's conventions
+# find_package(Boost REQUIRED COMPONENTS system)
+
+list(APPEND CMAKE_PREFIX_PATH "~/libtorch")
+list(APPEND CMAKE_PREFIX_PATH "/usr/local/lib/python3.6/dist-packages/torch/lib")
+list(APPEND CMAKE_PREFIX_PATH "/usr/local/lib/python2.7/dist-packages/torch/lib")
+
+if(NOT EXISTS "~/libtorch")
+ if (EXISTS "/usr/local/lib/python3.6/dist-packages/torch")
+ include_directories(/usr/local/include)
+ include_directories(/usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include)
+ include_directories(/usr/local/lib/python3.6/dist-packages/torch/include)
+
+ link_directories(/usr/local/lib)
+ link_directories(/usr/local/lib/python3.6/dist-packages/torch/lib)
+
+ set(CMAKE_PREFIX_PATH /usr/local/lib/python3.6/dist-packages/torch)
+ set(Boost_USE_MULTITHREADED ON)
+ set(Torch_DIR /usr/local/lib/python3.6/dist-packages/torch)
+
+ elseif (EXISTS "/usr/local/lib/python2.7/dist-packages/torch")
+
+ include_directories(/usr/local/include)
+ include_directories(/usr/local/lib/python2.7/dist-packages/torch/include/torch/csrc/api/include)
+ include_directories(/usr/local/lib/python2.7/dist-packages/torch/include)
+
+ link_directories(/usr/local/lib)
+ link_directories(/usr/local/lib/python2.7/dist-packages/torch/lib)
+
+ set(CMAKE_PREFIX_PATH /usr/local/lib/python2.7/dist-packages/torch)
+ set(Boost_USE_MULTITHREADED ON)
+ set(Torch_DIR /usr/local/lib/python2.7/dist-packages/torch)
+ endif()
+endif()
+
+
+
+find_package(Torch REQUIRED)
+find_package(OpenCV REQUIRED)
+include_directories( ${OpenCV_INCLUDE_DIRS} )
+
+add_executable(midas_cpp src/main.cpp)
+target_link_libraries(midas_cpp "${TORCH_LIBRARIES}" ${OpenCV_LIBS} ${catkin_LIBRARIES})
+set_property(TARGET midas_cpp PROPERTY CXX_STANDARD 14)
+
+
+
+###################################
+## catkin specific configuration ##
+###################################
+## The catkin_package macro generates cmake config files for your package
+## Declare things to be passed to dependent projects
+## INCLUDE_DIRS: uncomment this if your package contains header files
+## LIBRARIES: libraries you create in this project that dependent projects also need
+## CATKIN_DEPENDS: catkin_packages dependent projects also need
+## DEPENDS: system dependencies of this project that dependent projects also need
+catkin_package(
+# INCLUDE_DIRS include
+# LIBRARIES midas_cpp
+# CATKIN_DEPENDS cv_bridge image_transport roscpp sensor_msgs std_msgs
+# DEPENDS system_lib
+)
+
+###########
+## Build ##
+###########
+
+## Specify additional locations of header files
+## Your package locations should be listed before other locations
+include_directories(
+# include
+ ${catkin_INCLUDE_DIRS}
+)
+
+## Declare a C++ library
+# add_library(${PROJECT_NAME}
+# src/${PROJECT_NAME}/midas_cpp.cpp
+# )
+
+## Add cmake target dependencies of the library
+## as an example, code may need to be generated before libraries
+## either from message generation or dynamic reconfigure
+# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Declare a C++ executable
+## With catkin_make all packages are built within a single CMake context
+## The recommended prefix ensures that target names across packages don't collide
+# add_executable(${PROJECT_NAME}_node src/midas_cpp_node.cpp)
+
+## Rename C++ executable without prefix
+## The above recommended prefix causes long target names, the following renames the
+## target back to the shorter version for ease of user use
+## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node"
+# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "")
+
+## Add cmake target dependencies of the executable
+## same as for the library above
+# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
+
+## Specify libraries to link a library or executable target against
+# target_link_libraries(${PROJECT_NAME}_node
+# ${catkin_LIBRARIES}
+# )
+
+#############
+## Install ##
+#############
+
+# all install targets should use catkin DESTINATION variables
+# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html
+
+## Mark executable scripts (Python etc.) for installation
+## in contrast to setup.py, you can choose the destination
+# catkin_install_python(PROGRAMS
+# scripts/my_python_script
+# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark executables for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html
+# install(TARGETS ${PROJECT_NAME}_node
+# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+# )
+
+## Mark libraries for installation
+## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html
+# install(TARGETS ${PROJECT_NAME}
+# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION}
+# )
+
+## Mark cpp header files for installation
+# install(DIRECTORY include/${PROJECT_NAME}/
+# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION}
+# FILES_MATCHING PATTERN "*.h"
+# PATTERN ".svn" EXCLUDE
+# )
+
+## Mark other files for installation (e.g. launch and bag files, etc.)
+# install(FILES
+# # myfile1
+# # myfile2
+# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+# )
+
+#############
+## Testing ##
+#############
+
+## Add gtest based cpp test target and link libraries
+# catkin_add_gtest(${PROJECT_NAME}-test test/test_midas_cpp.cpp)
+# if(TARGET ${PROJECT_NAME}-test)
+# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME})
+# endif()
+
+## Add folders to be run by python nosetests
+# catkin_add_nosetests(test)
+
+install(TARGETS ${PROJECT_NAME}
+ ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+ LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION}
+ RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
+)
+
+add_custom_command(
+ TARGET midas_cpp POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy
+ ${CMAKE_CURRENT_BINARY_DIR}/midas_cpp
+ ${CMAKE_SOURCE_DIR}/midas_cpp
+)
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch
new file mode 100644
index 0000000000000000000000000000000000000000..88e86f42f668e76ad4976ec6794a8cb0f20cac65
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch
@@ -0,0 +1,19 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch
new file mode 100644
index 0000000000000000000000000000000000000000..8817a4f4933c56986fe0edc0886b2fded3d3406d
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch
@@ -0,0 +1,23 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml
new file mode 100644
index 0000000000000000000000000000000000000000..9cac90eba75409bd170f73531c54c83c52ff047a
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml
@@ -0,0 +1,77 @@
+<?xml version="1.0"?>
+<package format="2">
+  <name>midas_cpp</name>
+  <version>0.1.0</version>
+  <description>The midas_cpp package</description>
+
+  <maintainer>Alexey Bochkovskiy</maintainer>
+  <license>MIT</license>
+  <url type="website">https://github.com/isl-org/MiDaS/tree/master/ros</url>
+  <author>TODO</author>
+
+  <buildtool_depend>catkin</buildtool_depend>
+
+  <build_depend>cv_bridge</build_depend>
+  <build_depend>image_transport</build_depend>
+  <build_depend>roscpp</build_depend>
+  <build_depend>rospy</build_depend>
+  <build_depend>sensor_msgs</build_depend>
+  <build_depend>std_msgs</build_depend>
+
+  <build_export_depend>cv_bridge</build_export_depend>
+  <build_export_depend>image_transport</build_export_depend>
+  <build_export_depend>roscpp</build_export_depend>
+  <build_export_depend>rospy</build_export_depend>
+  <build_export_depend>sensor_msgs</build_export_depend>
+  <build_export_depend>std_msgs</build_export_depend>
+
+  <exec_depend>cv_bridge</exec_depend>
+  <exec_depend>image_transport</exec_depend>
+  <exec_depend>roscpp</exec_depend>
+  <exec_depend>rospy</exec_depend>
+  <exec_depend>sensor_msgs</exec_depend>
+  <exec_depend>std_msgs</exec_depend>
+
+  <export>
+  </export>
+</package>
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py
new file mode 100644
index 0000000000000000000000000000000000000000..6927ea7a83ac9309e5f883ee974a5dcfa8a2aa3b
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+from __future__ import print_function
+
+import roslib
+#roslib.load_manifest('my_package')
+import sys
+import rospy
+import cv2
+import numpy as np
+from std_msgs.msg import String
+from sensor_msgs.msg import Image
+from cv_bridge import CvBridge, CvBridgeError
+
+class video_show:
+
+ def __init__(self):
+ self.show_output = rospy.get_param('~show_output', True)
+ self.save_output = rospy.get_param('~save_output', False)
+        self.output_video_file = rospy.get_param('~output_video_file','result.mp4')
+        self.video_writer_init = False  # the VideoWriter is created lazily on the first saved frame
+ # rospy.loginfo(f"Listener - params: show_output={self.show_output}, save_output={self.save_output}, output_video_file={self.output_video_file}")
+
+ self.bridge = CvBridge()
+ self.image_sub = rospy.Subscriber("midas_topic", Image, self.callback)
+
+ def callback(self, data):
+ try:
+ cv_image = self.bridge.imgmsg_to_cv2(data)
+ except CvBridgeError as e:
+ print(e)
+ return
+
+ if cv_image.size == 0:
+ return
+
+ rospy.loginfo("Listener: Received new frame")
+ cv_image = cv_image.astype("uint8")
+
+        if self.show_output:
+            cv2.imshow("video_show", cv_image)
+            cv2.waitKey(10)
+
+        if self.save_output:
+            if not self.video_writer_init:
+                fourcc = cv2.VideoWriter_fourcc(*'XVID')
+                self.out = cv2.VideoWriter(self.output_video_file, fourcc, 25, (cv_image.shape[1], cv_image.shape[0]))
+                self.video_writer_init = True  # create the writer only once
+
+            self.out.write(cv_image)
+
+
+
+def main(args):
+ rospy.init_node('listener', anonymous=True)
+ ic = video_show()
+ try:
+ rospy.spin()
+ except KeyboardInterrupt:
+ print("Shutting down")
+ cv2.destroyAllWindows()
+
+if __name__ == '__main__':
+ main(sys.argv)
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e235f6958d644b89383752ab18e9e2275f55e5
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+from __future__ import print_function
+
+import roslib
+#roslib.load_manifest('my_package')
+import sys
+import rospy
+import cv2
+import numpy as np
+from std_msgs.msg import String
+from sensor_msgs.msg import Image
+from cv_bridge import CvBridge, CvBridgeError
+
+class video_show:
+
+ def __init__(self):
+ self.show_output = rospy.get_param('~show_output', True)
+ self.save_output = rospy.get_param('~save_output', False)
+        self.output_video_file = rospy.get_param('~output_video_file','result.mp4')
+        self.video_writer_init = False  # the VideoWriter is created lazily on the first saved frame
+ # rospy.loginfo(f"Listener original - params: show_output={self.show_output}, save_output={self.save_output}, output_video_file={self.output_video_file}")
+
+ self.bridge = CvBridge()
+ self.image_sub = rospy.Subscriber("image_topic", Image, self.callback)
+
+ def callback(self, data):
+ try:
+ cv_image = self.bridge.imgmsg_to_cv2(data)
+ except CvBridgeError as e:
+ print(e)
+ return
+
+ if cv_image.size == 0:
+ return
+
+ rospy.loginfo("Listener_original: Received new frame")
+ cv_image = cv_image.astype("uint8")
+
+        if self.show_output:
+            cv2.imshow("video_show_orig", cv_image)
+            cv2.waitKey(10)
+
+        if self.save_output:
+            if not self.video_writer_init:
+                fourcc = cv2.VideoWriter_fourcc(*'XVID')
+                self.out = cv2.VideoWriter(self.output_video_file, fourcc, 25, (cv_image.shape[1], cv_image.shape[0]))
+                self.video_writer_init = True  # create the writer only once
+
+            self.out.write(cv_image)
+
+
+
+def main(args):
+ rospy.init_node('listener_original', anonymous=True)
+ ic = video_show()
+ try:
+ rospy.spin()
+ except KeyboardInterrupt:
+ print("Shutting down")
+ cv2.destroyAllWindows()
+
+if __name__ == '__main__':
+ main(sys.argv)
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py
new file mode 100644
index 0000000000000000000000000000000000000000..8219cc8632484a2efd02984347c615efad6b78b2
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+
+
+import roslib
+#roslib.load_manifest('my_package')
+import sys
+import rospy
+import cv2
+from std_msgs.msg import String
+from sensor_msgs.msg import Image
+from cv_bridge import CvBridge, CvBridgeError
+
+
+def talker():
+ rospy.init_node('talker', anonymous=True)
+
+ use_camera = rospy.get_param('~use_camera', False)
+ input_video_file = rospy.get_param('~input_video_file','test.mp4')
+ # rospy.loginfo(f"Talker - params: use_camera={use_camera}, input_video_file={input_video_file}")
+
+ # rospy.loginfo("Talker: Trying to open a video stream")
+    if use_camera:
+ cap = cv2.VideoCapture(0)
+ else:
+ cap = cv2.VideoCapture(input_video_file)
+
+ pub = rospy.Publisher('image_topic', Image, queue_size=1)
+ rate = rospy.Rate(30) # 30hz
+ bridge = CvBridge()
+
+ while not rospy.is_shutdown():
+ ret, cv_image = cap.read()
+        if not ret:
+            print("Talker: Video is over")
+            rospy.loginfo("Video is over")
+            return
+
+ try:
+ image = bridge.cv2_to_imgmsg(cv_image, "bgr8")
+ except CvBridgeError as e:
+            rospy.logerr("Talker: cv2 image conversion failed: %s", e)
+ print(e)
+ continue
+
+ rospy.loginfo("Talker: Publishing frame")
+ pub.publish(image)
+ rate.sleep()
+
+if __name__ == '__main__':
+ try:
+ talker()
+ except rospy.ROSInterruptException:
+ pass
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e4fc72c6955f66af71c9cb1fc7a7b1f643129685
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp
@@ -0,0 +1,285 @@
+#include <ros/ros.h>
+#include <image_transport/image_transport.h>
+#include <cv_bridge/cv_bridge.h>
+#include <sensor_msgs/image_encodings.h>
+
+#include <torch/torch.h>
+
+#include <torch/script.h> // One-stop header.
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <limits>
+#include <std_msgs/Header.h>
+
+// includes for OpenCV >= 3.x
+#ifndef CV_VERSION_EPOCH
+#include <opencv2/core/core.hpp>
+#include <opencv2/highgui/highgui.hpp>
+#include <opencv2/imgproc/imgproc.hpp>
+#endif
+
+// OpenCV includes for OpenCV 2.x
+#ifdef CV_VERSION_EPOCH
+#include <opencv2/core/core.hpp>
+#include <opencv2/highgui/highgui.hpp>
+#include <opencv2/imgproc/imgproc.hpp>
+#endif
+
+static const std::string OPENCV_WINDOW = "Image window";
+
+class Midas
+{
+ ros::NodeHandle nh_;
+ image_transport::ImageTransport it_;
+ image_transport::Subscriber image_sub_;
+ image_transport::Publisher image_pub_;
+
+ torch::jit::script::Module module;
+ torch::Device device;
+
+ auto ToTensor(cv::Mat img, bool show_output = false, bool unsqueeze = false, int unsqueeze_dim = 0)
+ {
+ //std::cout << "image shape: " << img.size() << std::endl;
+ at::Tensor tensor_image = torch::from_blob(img.data, { img.rows, img.cols, 3 }, at::kByte);
+
+ if (unsqueeze)
+ {
+ tensor_image.unsqueeze_(unsqueeze_dim);
+ //std::cout << "tensors new shape: " << tensor_image.sizes() << std::endl;
+ }
+
+ if (show_output)
+ {
+ std::cout << tensor_image.slice(2, 0, 1) << std::endl;
+ }
+ //std::cout << "tenor shape: " << tensor_image.sizes() << std::endl;
+ return tensor_image;
+ }
+
+ auto ToInput(at::Tensor tensor_image)
+ {
+ // Create a vector of inputs.
+        return std::vector<torch::jit::IValue>{tensor_image};
+ }
+
+ auto ToCvImage(at::Tensor tensor, int cv_type = CV_8UC3)
+ {
+ int width = tensor.sizes()[0];
+ int height = tensor.sizes()[1];
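+        // Note: the tensor arrives in HWC layout, so sizes()[0] is the number of rows and
+        // sizes()[1] the number of columns; the cv::Size{...} constructions below therefore
+        // receive (columns, rows), which matches OpenCV's (width, height) convention.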
+ try
+ {
+ cv::Mat output_mat;
+ if (cv_type == CV_8UC4 || cv_type == CV_8UC3 || cv_type == CV_8UC2 || cv_type == CV_8UC1) {
+ cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr());
+ output_mat = cv_image;
+ }
+ else if (cv_type == CV_32FC4 || cv_type == CV_32FC3 || cv_type == CV_32FC2 || cv_type == CV_32FC1) {
+ cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr());
+ output_mat = cv_image;
+ }
+ else if (cv_type == CV_64FC4 || cv_type == CV_64FC3 || cv_type == CV_64FC2 || cv_type == CV_64FC1) {
+ cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr());
+ output_mat = cv_image;
+ }
+
+ //show_image(output_mat, "converted image from tensor");
+ return output_mat.clone();
+ }
+ catch (const c10::Error& e)
+ {
+ std::cout << "an error has occured : " << e.msg() << std::endl;
+ }
+ return cv::Mat(height, width, CV_8UC3);
+ }
+
+ std::string input_topic, output_topic, model_name;
+ bool out_orig_size;
+ int net_width, net_height;
+ torch::NoGradGuard guard;
+ at::Tensor mean, std;
+ at::Tensor output, tensor;
+
+public:
+ Midas()
+ : nh_(), it_(nh_), device(torch::Device(torch::kCPU))
+ {
+ ros::param::param("~input_topic", input_topic, "image_topic");
+ ros::param::param("~output_topic", output_topic, "midas_topic");
+ ros::param::param("~model_name", model_name, "model-small-traced.pt");
+ ros::param::param("~out_orig_size", out_orig_size, true);
+ ros::param::param("~net_width", net_width, 256);
+ ros::param::param("~net_height", net_height, 256);
+
+ std::cout << ", input_topic = " << input_topic <<
+ ", output_topic = " << output_topic <<
+ ", model_name = " << model_name <<
+ ", out_orig_size = " << out_orig_size <<
+ ", net_width = " << net_width <<
+ ", net_height = " << net_height <<
+ std::endl;
+
+        // Subscribe to input video feed and publish output video feed
+ image_sub_ = it_.subscribe(input_topic, 1, &Midas::imageCb, this);
+ image_pub_ = it_.advertise(output_topic, 1);
+
+ std::cout << "Try to load torchscript model \n";
+
+ try {
+ // Deserialize the ScriptModule from a file using torch::jit::load().
+ module = torch::jit::load(model_name);
+ }
+ catch (const c10::Error& e) {
+ std::cerr << "error loading the model\n";
+            exit(1);
+ }
+
+ std::cout << "ok\n";
+
+ try {
+ module.eval();
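+            // disable the profiling graph executor and keep classic graph optimizations,
+            // so the first frames do not pay for profiling runs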
+ torch::jit::getProfilingMode() = false;
+ torch::jit::setGraphExecutorOptimize(true);
+
+ mean = torch::tensor({ 0.485, 0.456, 0.406 });
+ std = torch::tensor({ 0.229, 0.224, 0.225 });
+
+ if (torch::hasCUDA()) {
+ std::cout << "cuda is available" << std::endl;
+ at::globalContext().setBenchmarkCuDNN(true);
+ device = torch::Device(torch::kCUDA);
+ module.to(device);
+ mean = mean.to(device);
+ std = std.to(device);
+ }
+ }
+ catch (const c10::Error& e)
+ {
+ std::cerr << " module initialization: " << e.msg() << std::endl;
+ }
+ }
+
+ ~Midas()
+ {
+ }
+
+ void imageCb(const sensor_msgs::ImageConstPtr& msg)
+ {
+ cv_bridge::CvImagePtr cv_ptr;
+ try
+ {
+ // sensor_msgs::Image to cv::Mat
+ cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::RGB8);
+ }
+ catch (cv_bridge::Exception& e)
+ {
+ ROS_ERROR("cv_bridge exception: %s", e.what());
+ return;
+ }
+
+ // pre-processing
+ auto tensor_cpu = ToTensor(cv_ptr->image); // OpenCV-image -> Libtorch-tensor
+
+ try {
+ tensor = tensor_cpu.to(device); // move to device (CPU or GPU)
+
+ tensor = tensor.toType(c10::kFloat);
+ tensor = tensor.permute({ 2, 0, 1 }); // HWC -> CHW
+ tensor = tensor.unsqueeze(0);
+ tensor = at::upsample_bilinear2d(tensor, { net_height, net_width }, true); // resize
+ tensor = tensor.squeeze(0);
+ tensor = tensor.permute({ 1, 2, 0 }); // CHW -> HWC
+
+ tensor = tensor.div(255).sub(mean).div(std); // normalization
+ tensor = tensor.permute({ 2, 0, 1 }); // HWC -> CHW
+ tensor.unsqueeze_(0); // CHW -> NCHW
+ }
+ catch (const c10::Error& e)
+ {
+ std::cerr << " pre-processing exception: " << e.msg() << std::endl;
+ return;
+ }
+
+ auto input_to_net = ToInput(tensor); // input to the network
+
+ // inference
+ try {
+ output = module.forward(input_to_net).toTensor(); // run inference
+ }
+ catch (const c10::Error& e)
+ {
+ std::cerr << " module.forward() exception: " << e.msg() << std::endl;
+ return;
+ }
+
+ output = output.detach().to(torch::kF32);
+
+ // move to CPU temporary
+ at::Tensor output_tmp = output;
+ output_tmp = output_tmp.to(torch::kCPU);
+
+ // normalization
+        float min_val = std::numeric_limits<float>::max();
+        float max_val = std::numeric_limits<float>::lowest(); // lowest(), not min(): min() is the smallest positive float
+
+        for (int i = 0; i < net_width * net_height; ++i) {
+            float val = output_tmp.data_ptr<float>()[i];
+            if (min_val > val) min_val = val;
+            if (max_val < val) max_val = val;
+        }
+ float range_val = max_val - min_val;
+
+ output = output.sub(min_val).div(range_val).mul(255.0F).clamp(0, 255).to(torch::kF32); // .to(torch::kU8);
+
+ // resize to the original size if required
+ if (out_orig_size) {
+ try {
+ output = at::upsample_bilinear2d(output.unsqueeze(0), { cv_ptr->image.size().height, cv_ptr->image.size().width }, true);
+ output = output.squeeze(0);
+ }
+ catch (const c10::Error& e)
+ {
+ std::cout << " upsample_bilinear2d() exception: " << e.msg() << std::endl;
+ return;
+ }
+ }
+ output = output.permute({ 1, 2, 0 }).to(torch::kCPU);
+
+ int cv_type = CV_32FC1; // CV_8UC1;
+ auto cv_img = ToCvImage(output, cv_type);
+
+ sensor_msgs::Image img_msg;
+
+ try {
+ // cv::Mat -> sensor_msgs::Image
+ std_msgs::Header header; // empty header
+ header.seq = 0; // user defined counter
+ header.stamp = ros::Time::now();// time
+ //cv_bridge::CvImage img_bridge = cv_bridge::CvImage(header, sensor_msgs::image_encodings::MONO8, cv_img);
+ cv_bridge::CvImage img_bridge = cv_bridge::CvImage(header, sensor_msgs::image_encodings::TYPE_32FC1, cv_img);
+
+ img_bridge.toImageMsg(img_msg); // cv_bridge -> sensor_msgs::Image
+ }
+ catch (cv_bridge::Exception& e)
+ {
+ ROS_ERROR("cv_bridge exception: %s", e.what());
+ return;
+ }
+
+ // Output modified video stream
+ image_pub_.publish(img_msg);
+ }
+};
+
+int main(int argc, char** argv)
+{
+ ros::init(argc, argv, "midas", ros::init_options::AnonymousName);
+ Midas ic;
+ ros::spin();
+ return 0;
+}
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a997c4261072d0d627598fe06a723fcc7522d347
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+# place any test.mp4 file next to this file
+
+# roscore
+# rosnode kill -a
+
+source ~/catkin_ws/devel/setup.bash
+
+roscore &
+P1=$!
+rosrun midas_cpp talker.py &
+P2=$!
+rosrun midas_cpp listener_original.py &
+P3=$!
+rosrun midas_cpp listener.py &
+P4=$!
+wait $P1 $P2 $P3 $P4
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..5696ef0547af093713ea416d18edd77d11879d0a
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py
@@ -0,0 +1,277 @@
+"""Compute depth maps for images in the input folder.
+"""
+import os
+import glob
+import torch
+import utils
+import cv2
+import argparse
+import time
+
+import numpy as np
+
+from imutils.video import VideoStream
+from midas.model_loader import default_models, load_model
+
+first_execution = True
+def process(device, model, model_type, image, input_size, target_size, optimize, use_camera):
+ """
+ Run the inference and interpolate.
+
+ Args:
+ device (torch.device): the torch device used
+ model: the model used for inference
+ model_type: the type of the model
+ image: the image fed into the neural network
+ input_size: the size (width, height) of the neural network input (for OpenVINO)
+ target_size: the size (width, height) the neural network output is interpolated to
+ optimize: optimize the model to half-floats on CUDA?
+ use_camera: is the camera used?
+
+ Returns:
+ the prediction
+ """
+ global first_execution
+
+ if "openvino" in model_type:
+ if first_execution or not use_camera:
+ print(f" Input resized to {input_size[0]}x{input_size[1]} before entering the encoder")
+ first_execution = False
+
+ sample = [np.reshape(image, (1, 3, *input_size))]
+ prediction = model(sample)[model.output(0)][0]
+ prediction = cv2.resize(prediction, dsize=target_size,
+ interpolation=cv2.INTER_CUBIC)
+ else:
+ sample = torch.from_numpy(image).to(device).unsqueeze(0)
+
+ if optimize and device == torch.device("cuda"):
+ if first_execution:
+ print(" Optimization to half-floats activated. Use with caution, because models like Swin require\n"
+ " float precision to work properly and may yield non-finite depth values to some extent for\n"
+ " half-floats.")
+ sample = sample.to(memory_format=torch.channels_last)
+ sample = sample.half()
+
+ if first_execution or not use_camera:
+ height, width = sample.shape[2:]
+ print(f" Input resized to {width}x{height} before entering the encoder")
+ first_execution = False
+
+ prediction = model.forward(sample)
+ prediction = (
+ torch.nn.functional.interpolate(
+ prediction.unsqueeze(1),
+ size=target_size[::-1],
+ mode="bicubic",
+ align_corners=False,
+ )
+ .squeeze()
+ .cpu()
+ .numpy()
+ )
+
+ return prediction
+
+
+def create_side_by_side(image, depth, grayscale):
+ """
+ Take an RGB image and depth map and place them side by side. This includes a proper normalization of the depth map
+ for better visibility.
+
+ Args:
+ image: the RGB image
+ depth: the depth map
+ grayscale: use a grayscale colormap?
+
+ Returns:
+ the image and depth map place side by side
+ """
+ depth_min = depth.min()
+ depth_max = depth.max()
+ normalized_depth = 255 * (depth - depth_min) / (depth_max - depth_min)
+ normalized_depth *= 3
+
+ right_side = np.repeat(np.expand_dims(normalized_depth, 2), 3, axis=2) / 3
+ if not grayscale:
+ right_side = cv2.applyColorMap(np.uint8(right_side), cv2.COLORMAP_INFERNO)
+
+ if image is None:
+ return right_side
+ else:
+ return np.concatenate((image, right_side), axis=1)
+
+
+def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", optimize=False, side=False, height=None,
+ square=False, grayscale=False):
+ """Run MonoDepthNN to compute depth maps.
+
+ Args:
+ input_path (str): path to input folder
+ output_path (str): path to output folder
+ model_path (str): path to saved model
+ model_type (str): the model type
+ optimize (bool): optimize the model to half-floats on CUDA?
+ side (bool): RGB and depth side by side in output images?
+ height (int): inference encoder image height
+ square (bool): resize to a square resolution?
+ grayscale (bool): use a grayscale colormap?
+ """
+ print("Initialize")
+
+ # select device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print("Device: %s" % device)
+
+ model, transform, net_w, net_h = load_model(device, model_path, model_type, optimize, height, square)
+
+ # get input
+ if input_path is not None:
+ image_names = glob.glob(os.path.join(input_path, "*"))
+ num_images = len(image_names)
+ else:
+ print("No input path specified. Grabbing images from camera.")
+
+ # create output folder
+ if output_path is not None:
+ os.makedirs(output_path, exist_ok=True)
+
+ print("Start processing")
+
+ if input_path is not None:
+ if output_path is None:
+ print("Warning: No output path specified. Images will be processed but not shown or stored anywhere.")
+ for index, image_name in enumerate(image_names):
+
+ print(" Processing {} ({}/{})".format(image_name, index + 1, num_images))
+
+ # input
+ original_image_rgb = utils.read_image(image_name) # in [0, 1]
+ image = transform({"image": original_image_rgb})["image"]
+
+ # compute
+ with torch.no_grad():
+ prediction = process(device, model, model_type, image, (net_w, net_h), original_image_rgb.shape[1::-1],
+ optimize, False)
+
+ # output
+ if output_path is not None:
+ filename = os.path.join(
+ output_path, os.path.splitext(os.path.basename(image_name))[0] + '-' + model_type
+ )
+ if not side:
+ utils.write_depth(filename, prediction, grayscale, bits=2)
+ else:
+ original_image_bgr = np.flip(original_image_rgb, 2)
+ content = create_side_by_side(original_image_bgr*255, prediction, grayscale)
+ cv2.imwrite(filename + ".png", content)
+ utils.write_pfm(filename + ".pfm", prediction.astype(np.float32))
+
+ else:
+ with torch.no_grad():
+ fps = 1
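+            # imutils' VideoStream(0) opens the default webcam and reads frames on a background thread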
+ video = VideoStream(0).start()
+ time_start = time.time()
+ frame_index = 0
+ while True:
+ frame = video.read()
+ if frame is not None:
+ original_image_rgb = np.flip(frame, 2) # in [0, 255] (flip required to get RGB)
+ image = transform({"image": original_image_rgb/255})["image"]
+
+ prediction = process(device, model, model_type, image, (net_w, net_h),
+ original_image_rgb.shape[1::-1], optimize, True)
+
+ original_image_bgr = np.flip(original_image_rgb, 2) if side else None
+ content = create_side_by_side(original_image_bgr, prediction, grayscale)
+ cv2.imshow('MiDaS Depth Estimation - Press Escape to close window ', content/255)
+
+ if output_path is not None:
+ filename = os.path.join(output_path, 'Camera' + '-' + model_type + '_' + str(frame_index))
+ cv2.imwrite(filename + ".png", content)
+
+ alpha = 0.1
+ if time.time()-time_start > 0:
+ fps = (1 - alpha) * fps + alpha * 1 / (time.time()-time_start) # exponential moving average
+ time_start = time.time()
+ print(f"\rFPS: {round(fps,2)}", end="")
+
+ if cv2.waitKey(1) == 27: # Escape key
+ break
+
+ frame_index += 1
+ print()
+
+ print("Finished")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('-i', '--input_path',
+ default=None,
+                        help='Folder with input images (if no input path is specified, images are grabbed '
+                             'from the camera)'
+ )
+
+ parser.add_argument('-o', '--output_path',
+ default=None,
+ help='Folder for output images'
+ )
+
+ parser.add_argument('-m', '--model_weights',
+ default=None,
+ help='Path to the trained weights of model'
+ )
+
+ parser.add_argument('-t', '--model_type',
+ default='dpt_beit_large_512',
+ help='Model type: '
+ 'dpt_beit_large_512, dpt_beit_large_384, dpt_beit_base_384, dpt_swin2_large_384, '
+ 'dpt_swin2_base_384, dpt_swin2_tiny_256, dpt_swin_large_384, dpt_next_vit_large_384, '
+ 'dpt_levit_224, dpt_large_384, dpt_hybrid_384, midas_v21_384, midas_v21_small_256 or '
+ 'openvino_midas_v21_small_256'
+ )
+
+ parser.add_argument('-s', '--side',
+ action='store_true',
+ help='Output images contain RGB and depth images side by side'
+ )
+
+ parser.add_argument('--optimize', dest='optimize', action='store_true', help='Use half-float optimization')
+ parser.set_defaults(optimize=False)
+
+ parser.add_argument('--height',
+ type=int, default=None,
+                        help='Preferred height of images fed into the encoder during inference. Note that the '
+ 'preferred height may differ from the actual height, because an alignment to multiples of '
+ '32 takes place. Many models support only the height chosen during training, which is '
+ 'used automatically if this parameter is not set.'
+ )
+ parser.add_argument('--square',
+ action='store_true',
+ help='Option to resize images to a square resolution by changing their widths when images are '
+                             'fed into the encoder during inference. If this parameter is not set, the aspect '
+                             'ratio of images is preserved if supported by the model.'
+ )
+ parser.add_argument('--grayscale',
+ action='store_true',
+ help='Use a grayscale colormap instead of the inferno one. Although the inferno colormap, '
+ 'which is used by default, is better for visibility, it does not allow storing 16-bit '
+ 'depth values in PNGs but only 8-bit ones due to the precision limitation of this '
+ 'colormap.'
+ )
+
+ args = parser.parse_args()
+
+
+ if args.model_weights is None:
+ args.model_weights = default_models[args.model_type]
+
+ # set torch options
+ torch.backends.cudnn.enabled = True
+ torch.backends.cudnn.benchmark = True
+
+ # compute depth maps
+ run(args.input_path, args.output_path, args.model_weights, args.model_type, args.optimize, args.side, args.height,
+ args.square, args.grayscale)
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5b5fe0e63668eab45a55b140826cb3762862b17c
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md
@@ -0,0 +1,147 @@
+## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer
+
+### TensorFlow inference using `.pb` and `.onnx` models
+
+1. [Run inference on TensorFlow-model by using TensorFlow](#run-inference-on-tensorflow-model-by-using-tensorflow)
+
+2. [Run inference on ONNX-model by using ONNX-Runtime](#run-inference-on-onnx-model-by-using-onnx-runtime)
+
+3. [Make ONNX model from downloaded Pytorch model file](#make-onnx-model-from-downloaded-pytorch-model-file)
+
+
+### Run inference on TensorFlow-model by using TensorFlow
+
+1) Download the model weights [model-f6b98070.pb](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.pb)
+and [model-small.pb](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small.pb) and place the
+files in the `/tf/` folder.
+
+2) Set up dependencies:
+
+```shell
+# install OpenCV
+pip install --upgrade pip
+pip install opencv-python
+
+# install TensorFlow
+pip install -I grpcio tensorflow==2.3.0 tensorflow-addons==0.11.2 numpy==1.18.0
+```
+
+#### Usage
+
+1) Place one or more input images in the folder `tf/input`.
+
+2) Run the model:
+
+ ```shell
+ python tf/run_pb.py
+ ```
+
+ Or run the small model:
+
+ ```shell
+ python tf/run_pb.py --model_weights model-small.pb --model_type small
+ ```
+
+3) The resulting inverse depth maps are written to the `tf/output` folder.
+
+
+### Run inference on ONNX-model by using ONNX-Runtime
+
+1) Download the model weights [model-f6b98070.onnx](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.onnx)
+and [model-small.onnx](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small.onnx) and place the
+files in the `/tf/` folder.
+
+2) Set up dependencies:
+
+```shell
+# install OpenCV
+pip install --upgrade pip
+pip install opencv-python
+
+# install ONNX
+pip install onnx==1.7.0
+
+# install ONNX Runtime
+pip install onnxruntime==1.5.2
+```
+
+#### Usage
+
+1) Place one or more input images in the folder `tf/input`.
+
+2) Run the model:
+
+ ```shell
+ python tf/run_onnx.py
+ ```
+
+ Or run the small model:
+
+ ```shell
+ python tf/run_onnx.py --model_weights model-small.onnx --model_type small
+ ```
+
+3) The resulting inverse depth maps are written to the `tf/output` folder.
+
+
+
+### Make ONNX model from downloaded Pytorch model file
+
+1) Download the model weights [model-f6b98070.pt](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.pt) and place the
+file in the root folder.
+
+2) Set up dependencies:
+
+```shell
+# install OpenCV
+pip install --upgrade pip
+pip install opencv-python
+
+# install PyTorch TorchVision
+pip install -I torch==1.7.0 torchvision==0.8.0
+
+# install TensorFlow
+pip install -I grpcio tensorflow==2.3.0 tensorflow-addons==0.11.2 numpy==1.18.0
+
+# install ONNX
+pip install onnx==1.7.0
+
+# install ONNX-TensorFlow
+git clone https://github.com/onnx/onnx-tensorflow.git
+cd onnx-tensorflow
+git checkout 095b51b88e35c4001d70f15f80f31014b592b81e
+pip install -e .
+```
+
+#### Usage
+
+1) Run the converter:
+
+ ```shell
+ python tf/make_onnx_model.py
+ ```
+
+2) The resulting `model-f6b98070.onnx` file is written to the `/tf/` folder.
+
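+As an optional sanity check (not part of the original conversion script), the exported file can
+be validated with the `onnx` package before it is used for inference:
+
+```python
+import onnx
+
+# load the exported graph and run the structural checker
+model = onnx.load("model-f6b98070.onnx")
+onnx.checker.check_model(model)
+print(onnx.helper.printable_graph(model.graph))
+```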
+
+### Requirements
+
+ The code was tested with Python 3.6.9, PyTorch 1.5.1, TensorFlow 2.2.0, TensorFlow-addons 0.8.3, ONNX 1.7.0, ONNX-TensorFlow (GitHub-master-17.07.2020) and OpenCV 4.3.0.
+
+### Citation
+
+Please cite our paper if you use this code or any of the models:
+```
+@article{Ranftl2019,
+ author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun},
+ title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer},
+ journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
+ year = {2020},
+}
+```
+
+### License
+
+MIT License
+
+
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/input/.placeholder b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/input/.placeholder
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d14b0e4e1d2ea70fa315fd7ca7dfd72440a19376
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py
@@ -0,0 +1,112 @@
+"""Compute depth maps for images in the input folder.
+"""
+import os
+import ntpath
+import glob
+import torch
+import utils
+import cv2
+import numpy as np
+from torchvision.transforms import Compose, Normalize
+from torchvision import transforms
+
+from shutil import copyfile
+import fileinput
+import sys
+sys.path.append(os.getcwd() + '/..')
+
+def modify_file():
+ modify_filename = '../midas/blocks.py'
+ copyfile(modify_filename, modify_filename+'.bak')
+
+ with open(modify_filename, 'r') as file :
+ filedata = file.read()
+
+ filedata = filedata.replace('align_corners=True', 'align_corners=False')
+ filedata = filedata.replace('import torch.nn as nn', 'import torch.nn as nn\nimport torchvision.models as models')
+ filedata = filedata.replace('torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")', 'models.resnext101_32x8d()')
+
+ with open(modify_filename, 'w') as file:
+ file.write(filedata)
+
+def restore_file():
+ modify_filename = '../midas/blocks.py'
+ copyfile(modify_filename+'.bak', modify_filename)
+
+modify_file()
+
+from midas.midas_net import MidasNet
+from midas.transforms import Resize, NormalizeImage, PrepareForNet
+
+restore_file()
+
+
+class MidasNet_preprocessing(MidasNet):
+ """Network for monocular depth estimation.
+ """
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+
+ mean = torch.tensor([0.485, 0.456, 0.406])
+ std = torch.tensor([0.229, 0.224, 0.225])
+ x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
+
+ return MidasNet.forward(self, x)
+
+
+def run(model_path):
+ """Run MonoDepthNN to compute depth maps.
+
+ Args:
+ model_path (str): path to saved model
+ """
+ print("initialize")
+
+ # select device
+
+ # load network
+ #model = MidasNet(model_path, non_negative=True)
+ model = MidasNet_preprocessing(model_path, non_negative=True)
+
+ model.eval()
+
+ print("start processing")
+
+ # input
+ img_input = np.zeros((3, 384, 384), np.float32)
+
+ # compute
+ with torch.no_grad():
+ sample = torch.from_numpy(img_input).unsqueeze(0)
+ prediction = model.forward(sample)
+ prediction = (
+ torch.nn.functional.interpolate(
+ prediction.unsqueeze(1),
+                size=img_input.shape[1:],  # spatial size (H, W); img_input is CHW
+ mode="bicubic",
+ align_corners=False,
+ )
+ .squeeze()
+ .cpu()
+ .numpy()
+ )
+
+ torch.onnx.export(model, sample, ntpath.basename(model_path).rsplit('.', 1)[0]+'.onnx', opset_version=9)
+
+ print("finished")
+
+
+if __name__ == "__main__":
+ # set paths
+ # MODEL_PATH = "model.pt"
+ MODEL_PATH = "../model-f6b98070.pt"
+
+ # compute depth maps
+ run(MODEL_PATH)
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/output/.placeholder b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/output/.placeholder
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..7107b99969a127f951814f743d5c562a436b2430
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py
@@ -0,0 +1,119 @@
+"""Compute depth maps for images in the input folder.
+"""
+import os
+import glob
+import utils
+import cv2
+import sys
+import numpy as np
+import argparse
+
+import onnx
+import onnxruntime as rt
+
+from transforms import Resize, NormalizeImage, PrepareForNet
+
+
+def run(input_path, output_path, model_path, model_type="large"):
+ """Run MonoDepthNN to compute depth maps.
+
+ Args:
+ input_path (str): path to input folder
+ output_path (str): path to output folder
+ model_path (str): path to saved model
+ """
+ print("initialize")
+
+ # select device
+ device = "CUDA:0"
+ #device = "CPU"
+ print("device: %s" % device)
+
+ # network resolution
+ if model_type == "large":
+ net_w, net_h = 384, 384
+ elif model_type == "small":
+ net_w, net_h = 256, 256
+ else:
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
+ assert False
+
+ # load network
+ print("loading model...")
+ model = rt.InferenceSession(model_path)
+ input_name = model.get_inputs()[0].name
+ output_name = model.get_outputs()[0].name
+
+ resize_image = Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=32,
+ resize_method="upper_bound",
+ image_interpolation_method=cv2.INTER_CUBIC,
+ )
+
+ def compose2(f1, f2):
+ return lambda x: f2(f1(x))
+
+ transform = compose2(resize_image, PrepareForNet())
+
+ # get input
+ img_names = glob.glob(os.path.join(input_path, "*"))
+ num_images = len(img_names)
+
+ # create output folder
+ os.makedirs(output_path, exist_ok=True)
+
+ print("start processing")
+
+ for ind, img_name in enumerate(img_names):
+
+ print(" processing {} ({}/{})".format(img_name, ind + 1, num_images))
+
+ # input
+ img = utils.read_image(img_name)
+ img_input = transform({"image": img})["image"]
+
+ # compute
+ output = model.run([output_name], {input_name: img_input.reshape(1, 3, net_h, net_w).astype(np.float32)})[0]
+ prediction = np.array(output).reshape(net_h, net_w)
+ prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+ # output
+ filename = os.path.join(
+ output_path, os.path.splitext(os.path.basename(img_name))[0]
+ )
+ utils.write_depth(filename, prediction, bits=2)
+
+ print("finished")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('-i', '--input_path',
+ default='input',
+ help='folder with input images'
+ )
+
+ parser.add_argument('-o', '--output_path',
+ default='output',
+ help='folder for output images'
+ )
+
+ parser.add_argument('-m', '--model_weights',
+ default='model-f6b98070.onnx',
+ help='path to the trained weights of model'
+ )
+
+ parser.add_argument('-t', '--model_type',
+ default='large',
+ help='model type: large or small'
+ )
+
+ args = parser.parse_args()
+
+ # compute depth maps
+ run(args.input_path, args.output_path, args.model_weights, args.model_type)
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py
new file mode 100644
index 0000000000000000000000000000000000000000..e46254f7b37f72e7d87672d70fd4b2f393ad7658
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py
@@ -0,0 +1,135 @@
+"""Compute depth maps for images in the input folder.
+"""
+import os
+import glob
+import utils
+import cv2
+import argparse
+
+import tensorflow as tf
+
+from transforms import Resize, NormalizeImage, PrepareForNet
+
+def run(input_path, output_path, model_path, model_type="large"):
+ """Run MonoDepthNN to compute depth maps.
+
+ Args:
+ input_path (str): path to input folder
+ output_path (str): path to output folder
+ model_path (str): path to saved model
+ """
+ print("initialize")
+
+    # limit the GPU memory TensorFlow allocates at initialization so it does not grab the whole device
+ gpus = tf.config.experimental.list_physical_devices('GPU')
+ if gpus:
+ try:
+ for gpu in gpus:
+ #tf.config.experimental.set_memory_growth(gpu, True)
+ tf.config.experimental.set_virtual_device_configuration(gpu,
+ [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4000)])
+ except RuntimeError as e:
+ print(e)
+
+ # network resolution
+ if model_type == "large":
+ net_w, net_h = 384, 384
+ elif model_type == "small":
+ net_w, net_h = 256, 256
+ else:
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
+ assert False
+
+ # load network
+ graph_def = tf.compat.v1.GraphDef()
+ with tf.io.gfile.GFile(model_path, 'rb') as f:
+ graph_def.ParseFromString(f.read())
+ tf.import_graph_def(graph_def, name='')
+
+
+ model_operations = tf.compat.v1.get_default_graph().get_operations()
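+    # the script assumes the frozen MiDaS graph exposes its input as tensor '0:0' and that the
+    # last operation in the graph produces the depth prediction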
+ input_node = '0:0'
+ output_layer = model_operations[len(model_operations) - 1].name + ':0'
+ print("Last layer name: ", output_layer)
+
+ resize_image = Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=32,
+ resize_method="upper_bound",
+ image_interpolation_method=cv2.INTER_CUBIC,
+ )
+
+ def compose2(f1, f2):
+ return lambda x: f2(f1(x))
+
+ transform = compose2(resize_image, PrepareForNet())
+
+ # get input
+ img_names = glob.glob(os.path.join(input_path, "*"))
+ num_images = len(img_names)
+
+ # create output folder
+ os.makedirs(output_path, exist_ok=True)
+
+ print("start processing")
+
+ with tf.compat.v1.Session() as sess:
+ try:
+ # load images
+ for ind, img_name in enumerate(img_names):
+
+ print(" processing {} ({}/{})".format(img_name, ind + 1, num_images))
+
+ # input
+ img = utils.read_image(img_name)
+ img_input = transform({"image": img})["image"]
+
+ # compute
+ prob_tensor = sess.graph.get_tensor_by_name(output_layer)
+ prediction, = sess.run(prob_tensor, {input_node: [img_input] })
+ prediction = prediction.reshape(net_h, net_w)
+ prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+ # output
+ filename = os.path.join(
+ output_path, os.path.splitext(os.path.basename(img_name))[0]
+ )
+ utils.write_depth(filename, prediction, bits=2)
+
+ except KeyError:
+            print("Couldn't find input node: " + input_node + " or output layer: " + output_layer + ".")
+ exit(-1)
+
+ print("finished")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('-i', '--input_path',
+ default='input',
+ help='folder with input images'
+ )
+
+ parser.add_argument('-o', '--output_path',
+ default='output',
+ help='folder for output images'
+ )
+
+ parser.add_argument('-m', '--model_weights',
+ default='model-f6b98070.pb',
+ help='path to the trained weights of model'
+ )
+
+ parser.add_argument('-t', '--model_type',
+ default='large',
+ help='model type: large or small'
+ )
+
+ args = parser.parse_args()
+
+ # compute depth maps
+ run(args.input_path, args.output_path, args.model_weights, args.model_type)
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+    """Resize the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+ tuple: new size
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
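+        # round x to the nearest multiple of self.__multiple_of, then fall back to flooring or
+        # ceiling so the result stays within the optional [min_val, max_val] bounds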
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+                # scale as least as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "disparity" in sample:
+ disparity = sample["disparity"].astype(np.float32)
+ sample["disparity"] = np.ascontiguousarray(disparity)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ return sample
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff9a54bd55f5e31a90fad21242efbfda5a6cc1a7
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py
@@ -0,0 +1,82 @@
+import numpy as np
+import sys
+import cv2
+
+
+def write_pfm(path, image, scale=1):
+ """Write pfm file.
+ Args:
+        path (str): path to file
+ image (array): data
+ scale (int, optional): Scale. Defaults to 1.
+ """
+
+ with open(path, "wb") as file:
+ color = None
+
+ if image.dtype.name != "float32":
+ raise Exception("Image dtype must be float32.")
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif (
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+ ): # greyscale
+ color = False
+ else:
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+        file.write(("PF\n" if color else "Pf\n").encode())
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
+ scale = -scale
+
+ file.write("%f\n".encode() % scale)
+
+ image.tofile(file)
+
+def read_image(path):
+ """Read image and output RGB image (0-1).
+ Args:
+ path (str): path to file
+ Returns:
+ array: RGB image (0-1)
+ """
+ img = cv2.imread(path)
+
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+ return img
+
+def write_depth(path, depth, bits=1):
+ """Write depth map to pfm and png file.
+ Args:
+ path (str): filepath without extension
+ depth (array): depth
+ """
+ write_pfm(path + ".pfm", depth.astype(np.float32))
+
+ depth_min = depth.min()
+ depth_max = depth.max()
+
+ max_val = (2**(8*bits))-1
+
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+        out = np.zeros(depth.shape, dtype=depth.dtype)
+
+ if bits == 1:
+ cv2.imwrite(path + ".png", out.astype("uint8"))
+ elif bits == 2:
+ cv2.imwrite(path + ".png", out.astype("uint16"))
+
+ return
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3976fd97dfe6a9dc7d4fa144be8fcb0b18b2db
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py
@@ -0,0 +1,199 @@
+"""Utils for monoDepth.
+"""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+ """Read pfm file.
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ tuple: (data, scale)
+ """
+ with open(path, "rb") as file:
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header.decode("ascii") == "PF":
+ color = True
+ elif header.decode("ascii") == "Pf":
+ color = False
+ else:
+ raise Exception("Not a PFM file: " + path)
+
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+ if dim_match:
+ width, height = list(map(int, dim_match.groups()))
+ else:
+ raise Exception("Malformed PFM header.")
+
+ scale = float(file.readline().decode("ascii").rstrip())
+ if scale < 0:
+ # little-endian
+ endian = "<"
+ scale = -scale
+ else:
+ # big-endian
+ endian = ">"
+
+ data = np.fromfile(file, endian + "f")
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+
+ return data, scale
+
+
+def write_pfm(path, image, scale=1):
+ """Write pfm file.
+
+ Args:
+        path (str): path to file
+ image (array): data
+ scale (int, optional): Scale. Defaults to 1.
+ """
+
+ with open(path, "wb") as file:
+ color = None
+
+ if image.dtype.name != "float32":
+ raise Exception("Image dtype must be float32.")
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif (
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+ ): # greyscale
+ color = False
+ else:
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+        file.write(("PF\n" if color else "Pf\n").encode())
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
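+        # PFM encodes endianness in the sign of the scale value: negative means little-endian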
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
+ scale = -scale
+
+ file.write("%f\n".encode() % scale)
+
+ image.tofile(file)
+
+
+def read_image(path):
+ """Read image and output RGB image (0-1).
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ array: RGB image (0-1)
+ """
+ img = cv2.imread(path)
+
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+ return img
+
+
+def resize_image(img):
+ """Resize image and make it fit for network.
+
+ Args:
+ img (array): image
+
+ Returns:
+ tensor: data ready for network
+ """
+ height_orig = img.shape[0]
+ width_orig = img.shape[1]
+
+ if width_orig > height_orig:
+ scale = width_orig / 384
+ else:
+ scale = height_orig / 384
+
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+ img_resized = (
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+ )
+ img_resized = img_resized.unsqueeze(0)
+
+ return img_resized
+
+
+def resize_depth(depth, width, height):
+ """Resize depth map and bring to CPU (numpy).
+
+ Args:
+ depth (tensor): depth
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ array: processed depth
+ """
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+ depth_resized = cv2.resize(
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+ )
+
+ return depth_resized
+
+def write_depth(path, depth, grayscale, bits=1):
+ """Write depth map to png file.
+
+ Args:
+ path (str): filepath without extension
+ depth (array): depth
+ grayscale (bool): use a grayscale colormap?
+ """
+ if not grayscale:
+ bits = 1
+
+ if not np.isfinite(depth).all():
+ depth=np.nan_to_num(depth, nan=0.0, posinf=0.0, neginf=0.0)
+ print("WARNING: Non-finite depth values present")
+
+ depth_min = depth.min()
+ depth_max = depth.max()
+
+ max_val = (2**(8*bits))-1
+
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+ out = np.zeros(depth.shape, dtype=depth.dtype)
+
+ if not grayscale:
+ out = cv2.applyColorMap(np.uint8(out), cv2.COLORMAP_INFERNO)
+
+ if bits == 1:
+ cv2.imwrite(path + ".png", out.astype("uint8"))
+ elif bits == 2:
+ cv2.imwrite(path + ".png", out.astype("uint16"))
+
+ return
diff --git a/src/flux/annotator/zoe/zoedepth/models/builder.py b/src/flux/annotator/zoe/zoedepth/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..0818311b642561712a03a66655c638ce09a04cca
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/builder.py
@@ -0,0 +1,51 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+from importlib import import_module
+from .depth_model import DepthModel
+
+def build_model(config) -> DepthModel:
+ """Builds a model from a config. The model is specified by the model name and version in the config. The model is then constructed using the build_from_config function of the model interface.
+ This function should be used to construct models for training and evaluation.
+
+ Args:
+ config (dict): Config dict. Config is constructed in utils/config.py. Each model has its own config file(s) saved in its root model folder.
+
+ Returns:
+ torch.nn.Module: Model corresponding to name and version as specified in config
+ """
+ module_name = f"zoedepth.models.{config.model}"
+ try:
+ module = import_module(module_name)
+ except ModuleNotFoundError as e:
+ # print the original error message
+ print(e)
+ raise ValueError(
+ f"Model {config.model} not found. See the error above for details.") from e
+ try:
+ get_version = getattr(module, "get_version")
+ except AttributeError as e:
+ raise ValueError(
+ f"Model {config.model} has no get_version function.") from e
+ return get_version(config.version_name).build_from_config(config)
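+
+# Illustrative usage sketch (assumes the config helpers in zoedepth/utils/config.py,
+# e.g. a get_config(model_name, mode) function; treat the exact import as an assumption):
+#
+#   from zoedepth.utils.config import get_config
+#   config = get_config("zoedepth", "infer")
+#   model = build_model(config)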
diff --git a/src/flux/annotator/zoe/zoedepth/models/depth_model.py b/src/flux/annotator/zoe/zoedepth/models/depth_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc421c108ea3928c9add62b4c190500d9bd4eda1
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/depth_model.py
@@ -0,0 +1,152 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import transforms
+import PIL.Image
+from PIL import Image
+from typing import Union
+
+
+class DepthModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.device = 'cpu'
+
+ def to(self, device) -> nn.Module:
+ self.device = device
+ return super().to(device)
+
+ def forward(self, x, *args, **kwargs):
+ raise NotImplementedError
+
+ def _infer(self, x: torch.Tensor):
+ """
+ Inference interface for the model
+ Args:
+ x (torch.Tensor): input tensor of shape (b, c, h, w)
+ Returns:
+ torch.Tensor: output tensor of shape (b, 1, h, w)
+ """
+ return self(x)['metric_depth']
+
+ def _infer_with_pad_aug(self, x: torch.Tensor, pad_input: bool=True, fh: float=3, fw: float=3, upsampling_mode: str='bicubic', padding_mode="reflect", **kwargs) -> torch.Tensor:
+ """
+ Inference interface for the model with padding augmentation
+ Padding augmentation fixes the boundary artifacts in the output depth map.
+ Boundary artifacts are sometimes caused by the fact that the model is trained on the NYU raw dataset, which has a black or white border around the image.
+ This augmentation pads the input image and crops the prediction back to the original size / view.
+
+ Note: This augmentation is not required for the models trained with 'avoid_boundary'=True.
+ Args:
+ x (torch.Tensor): input tensor of shape (b, c, h, w)
+ pad_input (bool, optional): whether to pad the input or not. Defaults to True.
+ fh (float, optional): height padding factor. The padding is calculated as sqrt(h/2) * fh. Defaults to 3.
+ fw (float, optional): width padding factor. The padding is calculated as sqrt(w/2) * fw. Defaults to 3.
+ upsampling_mode (str, optional): upsampling mode. Defaults to 'bicubic'.
+ padding_mode (str, optional): padding mode. Defaults to "reflect".
+ Returns:
+ torch.Tensor: output tensor of shape (b, 1, h, w)
+ """
+ # assert x is nchw and c = 3
+ assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim())
+ assert x.shape[1] == 3, "x must have 3 channels, got {}".format(x.shape[1])
+
+ if pad_input:
+ assert fh > 0 or fw > 0, "at least one of fh and fw must be greater than 0"
+ pad_h = int(np.sqrt(x.shape[2]/2) * fh)
+ pad_w = int(np.sqrt(x.shape[3]/2) * fw)
+ padding = [pad_w, pad_w]
+ if pad_h > 0:
+ padding += [pad_h, pad_h]
+
+ x = F.pad(x, padding, mode=padding_mode, **kwargs)
+ out = self._infer(x)
+ if out.shape[-2:] != x.shape[-2:]:
+ out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False)
+ if pad_input:
+ # crop back to the original size, handling the case where pad_h or pad_w is 0
+ if pad_h > 0:
+ out = out[:, :, pad_h:-pad_h,:]
+ if pad_w > 0:
+ out = out[:, :, :, pad_w:-pad_w]
+ return out
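+
+ # Worked example of the padding rule above (comment only): for a 480x480 input with
+ # fh = fw = 3, pad_h = pad_w = int(sqrt(480/2) * 3) = 46, so the model sees a
+ # 572x572 reflection-padded image and the prediction is cropped back to 480x480.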
+
+ def infer_with_flip_aug(self, x, pad_input: bool=True, **kwargs) -> torch.Tensor:
+ """
+ Inference interface for the model with horizontal flip augmentation
+ Horizontal flip augmentation improves the accuracy of the model by averaging the output of the model with and without horizontal flip.
+ Args:
+ x (torch.Tensor): input tensor of shape (b, c, h, w)
+ pad_input (bool, optional): whether to use padding augmentation. Defaults to True.
+ Returns:
+ torch.Tensor: output tensor of shape (b, 1, h, w)
+ """
+ # infer with horizontal flip and average
+ out = self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs)
+ out_flip = self._infer_with_pad_aug(torch.flip(x, dims=[3]), pad_input=pad_input, **kwargs)
+ out = (out + torch.flip(out_flip, dims=[3])) / 2
+ return out
+
+ def infer(self, x, pad_input: bool=True, with_flip_aug: bool=True, **kwargs) -> torch.Tensor:
+ """
+ Inference interface for the model
+ Args:
+ x (torch.Tensor): input tensor of shape (b, c, h, w)
+ pad_input (bool, optional): whether to use padding augmentation. Defaults to True.
+ with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True.
+ Returns:
+ torch.Tensor: output tensor of shape (b, 1, h, w)
+ """
+ if with_flip_aug:
+ return self.infer_with_flip_aug(x, pad_input=pad_input, **kwargs)
+ else:
+ return self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs)
+
+ @torch.no_grad()
+ def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, PIL.Image.Image, torch.Tensor]:
+ """
+ Inference interface for the model for PIL image
+ Args:
+ pil_img (PIL.Image.Image): input PIL image
+ pad_input (bool, optional): whether to use padding augmentation. Defaults to True.
+ with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True.
+ output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy".
+ """
+ x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device)
+ out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs)
+ if output_type == "numpy":
+ return out_tensor.squeeze().cpu().numpy()
+ elif output_type == "pil":
+ # uint16 is required for depth pil image
+ out_16bit_numpy = (out_tensor.squeeze().cpu().numpy()*256).astype(np.uint16)
+ return Image.fromarray(out_16bit_numpy)
+ elif output_type == "tensor":
+ return out_tensor.squeeze().cpu()
+ else:
+ raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy', 'pil' and 'tensor'")
+
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/layers/attractor.py b/src/flux/annotator/zoe/zoedepth/models/layers/attractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a8efe645adea1d88a12e2ac5cc6bb2a251eef9d
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/layers/attractor.py
@@ -0,0 +1,208 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+import torch.nn as nn
+
+
+@torch.jit.script
+def exp_attractor(dx, alpha: float = 300, gamma: int = 2):
+ """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx , where dx = a - c, a = attractor point, c = bin center, dc = shift in bin centermmary for exp_attractor
+
+ Args:
+ dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
+ alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
+ gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.
+
+ Returns:
+ torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc
+ """
+ return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx)
+
+
+@torch.jit.script
+def inv_attractor(dx, alpha: float = 300, gamma: int = 2):
+ """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center
+ This is the default one according to the accompanying paper.
+
+ Args:
+ dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
+ alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
+ gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.
+
+ Returns:
+ torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc
+ """
+ return dx.div(1+alpha*dx.pow(gamma))
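+
+# Tiny numeric sketch of the inverse attractor (comment only): for a bin center c = 0.5
+# and attractor point a = 0.6, dx = 0.1 and dc = 0.1 / (1 + 300 * 0.1**2) = 0.025 with the
+# default alpha = 300, gamma = 2, so the center is pulled to 0.525, i.e. towards the attractor.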
+
+
+class AttractorLayer(nn.Module):
+ def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
+ alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
+ """
+ Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth)
+ """
+ super().__init__()
+
+ self.n_attractors = n_attractors
+ self.n_bins = n_bins
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+ self.alpha = alpha
+ self.gamma = gamma
+ self.kind = kind
+ self.attractor_type = attractor_type
+ self.memory_efficient = memory_efficient
+
+ self._net = nn.Sequential(
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0), # x2 for linear norm
+ nn.ReLU(inplace=True)
+ )
+
+ def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
+ """
+ Args:
+ x (torch.Tensor) : feature block; shape - n, c, h, w
+ b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w
+
+ Returns:
+ tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w
+ """
+ if prev_b_embedding is not None:
+ if interpolate:
+ prev_b_embedding = nn.functional.interpolate(
+ prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
+ x = x + prev_b_embedding
+
+ A = self._net(x)
+ eps = 1e-3
+ A = A + eps
+ n, c, h, w = A.shape
+ A = A.view(n, self.n_attractors, 2, h, w)
+ A_normed = A / A.sum(dim=2, keepdim=True) # n, a, 2, h, w
+ A_normed = A[:, :, 0, ...] # n, na, h, w
+
+ b_prev = nn.functional.interpolate(
+ b_prev, (h, w), mode='bilinear', align_corners=True)
+ b_centers = b_prev
+
+ if self.attractor_type == 'exp':
+ dist = exp_attractor
+ else:
+ dist = inv_attractor
+
+ if not self.memory_efficient:
+ func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
+ # .shape N, nbins, h, w
+ delta_c = func(dist(A_normed.unsqueeze(
+ 2) - b_centers.unsqueeze(1)), dim=1)
+ else:
+ delta_c = torch.zeros_like(b_centers, device=b_centers.device)
+ for i in range(self.n_attractors):
+ # .shape N, nbins, h, w
+ delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers)
+
+ if self.kind == 'mean':
+ delta_c = delta_c / self.n_attractors
+
+ b_new_centers = b_centers + delta_c
+ B_centers = (self.max_depth - self.min_depth) * \
+ b_new_centers + self.min_depth
+ B_centers, _ = torch.sort(B_centers, dim=1)
+ B_centers = torch.clip(B_centers, self.min_depth, self.max_depth)
+ return b_new_centers, B_centers
+
+
+class AttractorLayerUnnormed(nn.Module):
+ def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
+ alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
+ """
+ Attractor layer for bin centers. Bin centers are unbounded
+ """
+ super().__init__()
+
+ self.n_attractors = n_attractors
+ self.n_bins = n_bins
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+ self.alpha = alpha
+ self.gamma = gamma
+ self.kind = kind
+ self.attractor_type = attractor_type
+ self.memory_efficient = memory_efficient
+
+ self._net = nn.Sequential(
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0),
+ nn.Softplus()
+ )
+
+ def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
+ """
+ Args:
+ x (torch.Tensor) : feature block; shape - n, c, h, w
+ b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w
+
+ Returns:
+ tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version
+ """
+ if prev_b_embedding is not None:
+ if interpolate:
+ prev_b_embedding = nn.functional.interpolate(
+ prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
+ x = x + prev_b_embedding
+
+ A = self._net(x)
+ n, c, h, w = A.shape
+
+ b_prev = nn.functional.interpolate(
+ b_prev, (h, w), mode='bilinear', align_corners=True)
+ b_centers = b_prev
+
+ if self.attractor_type == 'exp':
+ dist = exp_attractor
+ else:
+ dist = inv_attractor
+
+ if not self.memory_efficient:
+ func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
+ # .shape N, nbins, h, w
+ delta_c = func(
+ dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1)
+ else:
+ delta_c = torch.zeros_like(b_centers, device=b_centers.device)
+ for i in range(self.n_attractors):
+ delta_c += dist(A[:, i, ...].unsqueeze(1) -
+ b_centers) # .shape N, nbins, h, w
+
+ if self.kind == 'mean':
+ delta_c = delta_c / self.n_attractors
+
+ b_new_centers = b_centers + delta_c
+ B_centers = b_new_centers
+
+ return b_new_centers, B_centers
diff --git a/src/flux/annotator/zoe/zoedepth/models/layers/dist_layers.py b/src/flux/annotator/zoe/zoedepth/models/layers/dist_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3208405dfb78fdfc28d5765e5a6d5dbe31967a23
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/layers/dist_layers.py
@@ -0,0 +1,121 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+import torch.nn as nn
+
+
+def log_binom(n, k, eps=1e-7):
+ """ log(nCk) using stirling approximation """
+ n = n + eps
+ k = k + eps
+ return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps)
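+
+# (Clarifying comment: with Stirling's first-order approximation log m! ~ m log m - m,
+# log C(n, k) ~ n log n - k log k - (n - k) log(n - k); the linear terms cancel, which is
+# exactly the expression returned above, up to the eps terms added for numerical stability.)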
+
+
+class LogBinomial(nn.Module):
+ def __init__(self, n_classes=256, act=torch.softmax):
+ """Compute log binomial distribution for n_classes
+
+ Args:
+ n_classes (int, optional): number of output classes. Defaults to 256.
+ """
+ super().__init__()
+ self.K = n_classes
+ self.act = act
+ self.register_buffer('k_idx', torch.arange(
+ 0, n_classes).view(1, -1, 1, 1))
+ self.register_buffer('K_minus_1', torch.Tensor(
+ [self.K-1]).view(1, -1, 1, 1))
+
+ def forward(self, x, t=1., eps=1e-4):
+ """Compute log binomial distribution for x
+
+ Args:
+ x (torch.Tensor - NCHW): probabilities
+ t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1.0.
+ eps (float, optional): Small number for numerical stability. Defaults to 1e-4.
+
+ Returns:
+ torch.Tensor -NCHW: log binomial distribution logbinomial(p;t)
+ """
+ if x.ndim == 3:
+ x = x.unsqueeze(1) # make it nchw
+
+ one_minus_x = torch.clamp(1 - x, eps, 1)
+ x = torch.clamp(x, eps, 1)
+ y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \
+ torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x)
+ return self.act(y/t, dim=1)
+
+
+class ConditionalLogBinomial(nn.Module):
+ def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax):
+ """Conditional Log Binomial distribution
+
+ Args:
+ in_features (int): number of input channels in main feature
+ condition_dim (int): number of input channels in condition feature
+ n_classes (int, optional): Number of classes. Defaults to 256.
+ bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2.
+ p_eps (float, optional): small eps value. Defaults to 1e-4.
+ max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50.
+ min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7.
+ """
+ super().__init__()
+ self.p_eps = p_eps
+ self.max_temp = max_temp
+ self.min_temp = min_temp
+ self.log_binomial_transform = LogBinomial(n_classes, act=act)
+ bottleneck = (in_features + condition_dim) // bottleneck_factor
+ self.mlp = nn.Sequential(
+ nn.Conv2d(in_features + condition_dim, bottleneck,
+ kernel_size=1, stride=1, padding=0),
+ nn.GELU(),
+ # 2 for p linear norm, 2 for t linear norm
+ nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0),
+ nn.Softplus()
+ )
+
+ def forward(self, x, cond):
+ """Forward pass
+
+ Args:
+ x (torch.Tensor - NCHW): Main feature
+ cond (torch.Tensor - NCHW): condition feature
+
+ Returns:
+ torch.Tensor: Output log binomial distribution
+ """
+ pt = self.mlp(torch.concat((x, cond), dim=1))
+ p, t = pt[:, :2, ...], pt[:, 2:, ...]
+
+ p = p + self.p_eps
+ p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...])
+
+ t = t + self.p_eps
+ t = t[:, 0, ...] / (t[:, 0, ...] + t[:, 1, ...])
+ t = t.unsqueeze(1)
+ t = (self.max_temp - self.min_temp) * t + self.min_temp
+
+ return self.log_binomial_transform(p, t)
diff --git a/src/flux/annotator/zoe/zoedepth/models/layers/localbins_layers.py b/src/flux/annotator/zoe/zoedepth/models/layers/localbins_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f94481605c3e6958ce50e73b2eb31d9f0c07dc67
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/layers/localbins_layers.py
@@ -0,0 +1,169 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+import torch.nn as nn
+
+
+class SeedBinRegressor(nn.Module):
+ def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
+ """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval.
+
+ Args:
+ in_features (int): input channels
+ n_bins (int, optional): Number of bin centers. Defaults to 16.
+ mlp_dim (int, optional): Hidden dimension. Defaults to 256.
+ min_depth (float, optional): Min depth value. Defaults to 1e-3.
+ max_depth (float, optional): Max depth value. Defaults to 10.
+ """
+ super().__init__()
+ self.version = "1_1"
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+
+ self._net = nn.Sequential(
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
+ nn.ReLU(inplace=True)
+ )
+
+ def forward(self, x):
+ """
+ Returns tensor of bin_width vectors (centers). One vector b for every pixel
+ """
+ B = self._net(x)
+ eps = 1e-3
+ B = B + eps
+ B_widths_normed = B / B.sum(dim=1, keepdim=True)
+ B_widths = (self.max_depth - self.min_depth) * \
+ B_widths_normed # .shape NCHW
+ # pad has the form (left, right, top, bottom, front, back)
+ B_widths = nn.functional.pad(
+ B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth)
+ B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW
+
+ B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...])
+ return B_widths_normed, B_centers
+
+
+class SeedBinRegressorUnnormed(nn.Module):
+ def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
+ """Bin center regressor network. Bin centers are unbounded
+
+ Args:
+ in_features (int): input channels
+ n_bins (int, optional): Number of bin centers. Defaults to 16.
+ mlp_dim (int, optional): Hidden dimension. Defaults to 256.
+ min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
+ max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
+ """
+ super().__init__()
+ self.version = "1_1"
+ self._net = nn.Sequential(
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
+ nn.Softplus()
+ )
+
+ def forward(self, x):
+ """
+ Returns tensor of bin_width vectors (centers). One vector b for every pixel
+ """
+ B_centers = self._net(x)
+ return B_centers, B_centers
+
+
+class Projector(nn.Module):
+ def __init__(self, in_features, out_features, mlp_dim=128):
+ """Projector MLP
+
+ Args:
+ in_features (int): input channels
+ out_features (int): output channels
+ mlp_dim (int, optional): hidden dimension. Defaults to 128.
+ """
+ super().__init__()
+
+ self._net = nn.Sequential(
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mlp_dim, out_features, 1, 1, 0),
+ )
+
+ def forward(self, x):
+ return self._net(x)
+
+
+
+class LinearSplitter(nn.Module):
+ def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10):
+ super().__init__()
+
+ self.prev_nbins = prev_nbins
+ self.split_factor = split_factor
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+
+ self._net = nn.Sequential(
+ nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
+ nn.GELU(),
+ nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0),
+ nn.ReLU()
+ )
+
+ def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
+ """
+ x : feature block; shape - n, c, h, w
+ b_prev : previous bin widths normed; shape - n, prev_nbins, h, w
+ """
+ if prev_b_embedding is not None:
+ if interpolate:
+ prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
+ x = x + prev_b_embedding
+ S = self._net(x)
+ eps = 1e-3
+ S = S + eps
+ n, c, h, w = S.shape
+ S = S.view(n, self.prev_nbins, self.split_factor, h, w)
+ S_normed = S / S.sum(dim=2, keepdim=True) # fractional splits
+
+ b_prev = nn.functional.interpolate(b_prev, (h,w), mode='bilinear', align_corners=True)
+
+
+ b_prev = b_prev / b_prev.sum(dim=1, keepdim=True) # renormalize for guarantees
+ # print(b_prev.shape, S_normed.shape)
+ # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat?
+ b = b_prev.unsqueeze(2) * S_normed
+ b = b.flatten(1,2) # .shape n, prev_nbins * split_factor, h, w
+
+ # calculate bin centers for loss calculation
+ B_widths = (self.max_depth - self.min_depth) * b # .shape N, nprev * splitfactor, H, W
+ # pad has the form (left, right, top, bottom, front, back)
+ B_widths = nn.functional.pad(B_widths, (0,0,0,0,1,0), mode='constant', value=self.min_depth)
+ B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW
+
+ B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:,1:,...])
+ return b, B_centers
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/layers/patch_transformer.py b/src/flux/annotator/zoe/zoedepth/models/layers/patch_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..99d9e51a06b981bae45ce7dd64eaef19a4121991
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/layers/patch_transformer.py
@@ -0,0 +1,91 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+import torch.nn as nn
+
+
+class PatchTransformerEncoder(nn.Module):
+ def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4, use_class_token=False):
+ """ViT-like transformer block
+
+ Args:
+ in_channels (int): Input channels
+ patch_size (int, optional): patch size. Defaults to 10.
+ embedding_dim (int, optional): Embedding dimension in transformer model. Defaults to 128.
+ num_heads (int, optional): number of attention heads. Defaults to 4.
+ use_class_token (bool, optional): Whether to prepend an extra token for global accumulation (a "class token"). Defaults to False.
+ """
+ super(PatchTransformerEncoder, self).__init__()
+ self.use_class_token = use_class_token
+ encoder_layers = nn.TransformerEncoderLayer(
+ embedding_dim, num_heads, dim_feedforward=1024)
+ self.transformer_encoder = nn.TransformerEncoder(
+ encoder_layers, num_layers=4) # takes shape S,N,E
+
+ self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim,
+ kernel_size=patch_size, stride=patch_size, padding=0)
+
+ def positional_encoding_1d(self, sequence_length, batch_size, embedding_dim, device='cpu'):
+ """Generate positional encodings
+
+ Args:
+ sequence_length (int): Sequence length
+ batch_size (int): Batch size
+ embedding_dim (int): Embedding dimension
+ device (str, optional): Device on which to create the encodings. Defaults to 'cpu'.
+
+ Returns:
+ torch.Tensor SBE: Positional encodings
+ """
+ position = torch.arange(
+ 0, sequence_length, dtype=torch.float32, device=device).unsqueeze(1)
+ index = torch.arange(
+ 0, embedding_dim, 2, dtype=torch.float32, device=device).unsqueeze(0)
+ div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim))
+ pos_encoding = position * div_term
+ pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1)
+ pos_encoding = pos_encoding.unsqueeze(1).repeat(1, batch_size, 1)
+ return pos_encoding
+
+
+ def forward(self, x):
+ """Forward pass
+
+ Args:
+ x (torch.Tensor - NCHW): Input feature tensor
+
+ Returns:
+ torch.Tensor - SNE: Transformer output embeddings. S - sequence length (=HW/patch_size^2), N - batch size, E - embedding dim
+ """
+ embeddings = self.embedding_convPxP(x).flatten(
+ 2) # .shape = n,c,s = n, embedding_dim, s
+ if self.use_class_token:
+ # extra special token at start ?
+ embeddings = nn.functional.pad(embeddings, (1, 0))
+
+ # change to S,N,E format required by transformer
+ embeddings = embeddings.permute(2, 0, 1)
+ S, N, E = embeddings.shape
+ embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device)
+ x = self.transformer_encoder(embeddings) # .shape = S, N, E
+ return x
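+
+# Shape sketch (comment only, with assumed example sizes): for x of shape (2, C, 40, 40)
+# and patch_size=10, embedding_convPxP gives (2, 128, 16) tokens, permuted to (16, 2, 128),
+# so S = 16 patch tokens per image (17 if use_class_token=True adds the leading token).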
diff --git a/src/flux/annotator/zoe/zoedepth/models/model_io.py b/src/flux/annotator/zoe/zoedepth/models/model_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..78b6579631dd847ac76651238cb5a948b5a66286
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/model_io.py
@@ -0,0 +1,92 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+
+def load_state_dict(model, state_dict):
+ """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict.
+
+ DataParallel prefixes state_dict keys with 'module.' when saving.
+ If the model is not a DataParallel model but the state_dict is, then prefixes are removed.
+ If the model is a DataParallel model but the state_dict is not, then prefixes are added.
+ """
+ state_dict = state_dict.get('model', state_dict)
+ # if model is a DataParallel model, then state_dict keys are prefixed with 'module.'
+
+ do_prefix = isinstance(
+ model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
+ state = {}
+ for k, v in state_dict.items():
+ if k.startswith('module.') and not do_prefix:
+ k = k[7:]
+
+ if not k.startswith('module.') and do_prefix:
+ k = 'module.' + k
+
+ state[k] = v
+
+ model.load_state_dict(state)
+ print("Loaded successfully")
+ return model
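+
+# (Clarifying comment: e.g. a key "module.conv.weight" saved from a DataParallel model is
+# loaded into a plain model as "conv.weight"; conversely, "conv.weight" gains the "module."
+# prefix when the target model is wrapped in DataParallel/DistributedDataParallel.)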
+
+
+def load_wts(model, checkpoint_path):
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
+ return load_state_dict(model, ckpt)
+
+
+def load_state_dict_from_url(model, url, **kwargs):
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs)
+ return load_state_dict(model, state_dict)
+
+
+def load_state_from_resource(model, resource: str):
+ """Loads weights to the model from a given resource. A resource can be of following types:
+ 1. URL. Prefixed with "url::"
+ e.g. url::http(s)://url.resource.com/ckpt.pt
+
+ 2. Local path. Prefixed with "local::"
+ e.g. local::/path/to/ckpt.pt
+
+
+ Args:
+ model (torch.nn.Module): Model
+ resource (str): resource string
+
+ Returns:
+ torch.nn.Module: Model with loaded weights
+ """
+ print(f"Using pretrained resource {resource}")
+
+ if resource.startswith('url::'):
+ url = resource.split('url::')[1]
+ return load_state_dict_from_url(model, url, progress=True)
+
+ elif resource.startswith('local::'):
+ path = resource.split('local::')[1]
+ return load_wts(model, path)
+
+ else:
+ raise ValueError("Invalid resource type, only url:: and local:: are supported")
+
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth/__init__.py b/src/flux/annotator/zoe/zoedepth/models/zoedepth/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc33f737d238766559f0e3a8def3c0b568f23b7f
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth/__init__.py
@@ -0,0 +1,31 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+from .zoedepth_v1 import ZoeDepth
+
+all_versions = {
+ "v1": ZoeDepth,
+}
+
+get_version = lambda v : all_versions[v]
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json b/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json
new file mode 100644
index 0000000000000000000000000000000000000000..3112ed78c89f00e1d13f5d6e5be87cd3216b6dc7
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json
@@ -0,0 +1,58 @@
+{
+ "model": {
+ "name": "ZoeDepth",
+ "version_name": "v1",
+ "n_bins": 64,
+ "bin_embedding_dim": 128,
+ "bin_centers_type": "softplus",
+ "n_attractors":[16, 8, 4, 1],
+ "attractor_alpha": 1000,
+ "attractor_gamma": 2,
+ "attractor_kind" : "mean",
+ "attractor_type" : "inv",
+ "midas_model_type" : "DPT_BEiT_L_384",
+ "min_temp": 0.0212,
+ "max_temp": 50.0,
+ "output_distribution": "logbinomial",
+ "memory_efficient": true,
+ "inverse_midas": false,
+ "img_size": [384, 512]
+ },
+
+ "train": {
+ "train_midas": true,
+ "use_pretrained_midas": true,
+ "trainer": "zoedepth",
+ "epochs": 5,
+ "bs": 16,
+ "optim_kwargs": {"lr": 0.000161, "wd": 0.01},
+ "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true},
+ "same_lr": false,
+ "w_si": 1,
+ "w_domain": 0.2,
+ "w_reg": 0,
+ "w_grad": 0,
+ "avoid_boundary": false,
+ "random_crop": false,
+ "input_width": 640,
+ "input_height": 480,
+ "midas_lr_factor": 1,
+ "encoder_lr_factor":10,
+ "pos_enc_lr_factor":10,
+ "freeze_midas_bn": true
+
+ },
+
+ "infer":{
+ "train_midas": false,
+ "use_pretrained_midas": false,
+ "pretrained_resource" : null,
+ "force_keep_ar": true
+ },
+
+ "eval":{
+ "train_midas": false,
+ "use_pretrained_midas": false,
+ "pretrained_resource" : null
+ }
+}
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json b/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json
new file mode 100644
index 0000000000000000000000000000000000000000..b51802aa44b91c39e15aacaac4b5ab6bec884414
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json
@@ -0,0 +1,22 @@
+{
+ "model": {
+ "bin_centers_type": "normed",
+ "img_size": [384, 768]
+ },
+
+ "train": {
+ },
+
+ "infer":{
+ "train_midas": false,
+ "use_pretrained_midas": false,
+ "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt",
+ "force_keep_ar": true
+ },
+
+ "eval":{
+ "train_midas": false,
+ "use_pretrained_midas": false,
+ "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt"
+ }
+}
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py b/src/flux/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc931b059d6165c84e8ff4f09d5c62d19930cee9
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py
@@ -0,0 +1,250 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import itertools
+
+import torch
+import torch.nn as nn
+from ..depth_model import DepthModel
+from ..base_models.midas import MidasCore
+from ..layers.attractor import AttractorLayer, AttractorLayerUnnormed
+from ..layers.dist_layers import ConditionalLogBinomial
+from ..layers.localbins_layers import (Projector, SeedBinRegressor,
+ SeedBinRegressorUnnormed)
+from ..model_io import load_state_from_resource
+
+
+class ZoeDepth(DepthModel):
+ def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10,
+ n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True,
+ midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs):
+ """ZoeDepth model. This is the version of ZoeDepth that has a single metric head
+
+ Args:
+ core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features
+ n_bins (int, optional): Number of bin centers. Defaults to 64.
+ bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers.
+ For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus".
+ bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128.
+ min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3.
+ max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10.
+ n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1].
+ attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300.
+ attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2.
+ attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'.
+ attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'.
+ min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5.
+ max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50.
+ train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True.
+ midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10.
+ encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10.
+ pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10.
+ """
+ super().__init__()
+
+ self.core = core
+ self.max_depth = max_depth
+ self.min_depth = min_depth
+ self.min_temp = min_temp
+ self.bin_centers_type = bin_centers_type
+
+ self.midas_lr_factor = midas_lr_factor
+ self.encoder_lr_factor = encoder_lr_factor
+ self.pos_enc_lr_factor = pos_enc_lr_factor
+ self.train_midas = train_midas
+ self.inverse_midas = inverse_midas
+
+ if self.encoder_lr_factor <= 0:
+ self.core.freeze_encoder(
+ freeze_rel_pos=self.pos_enc_lr_factor <= 0)
+
+ N_MIDAS_OUT = 32
+ btlnck_features = self.core.output_channels[0]
+ num_out_features = self.core.output_channels[1:]
+
+ self.conv2 = nn.Conv2d(btlnck_features, btlnck_features,
+ kernel_size=1, stride=1, padding=0) # btlnck conv
+
+ if bin_centers_type == "normed":
+ SeedBinRegressorLayer = SeedBinRegressor
+ Attractor = AttractorLayer
+ elif bin_centers_type == "softplus":
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
+ Attractor = AttractorLayerUnnormed
+ elif bin_centers_type == "hybrid1":
+ SeedBinRegressorLayer = SeedBinRegressor
+ Attractor = AttractorLayerUnnormed
+ elif bin_centers_type == "hybrid2":
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
+ Attractor = AttractorLayer
+ else:
+ raise ValueError(
+ "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'")
+
+ self.seed_bin_regressor = SeedBinRegressorLayer(
+ btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth)
+ self.seed_projector = Projector(btlnck_features, bin_embedding_dim)
+ self.projectors = nn.ModuleList([
+ Projector(num_out, bin_embedding_dim)
+ for num_out in num_out_features
+ ])
+ self.attractors = nn.ModuleList([
+ Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth,
+ alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type)
+ for i in range(len(num_out_features))
+ ])
+
+ last_in = N_MIDAS_OUT + 1 # +1 for relative depth
+
+ # use log binomial instead of softmax
+ self.conditional_log_binomial = ConditionalLogBinomial(
+ last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp)
+
+ def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs):
+ """
+ Args:
+ x (torch.Tensor): Input image tensor of shape (B, C, H, W)
+ return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False.
+ denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False.
+ return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False.
+
+ Returns:
+ dict: Dictionary containing the following keys:
+ - rel_depth (torch.Tensor): Relative depth map of shape (B, H, W)
+ - metric_depth (torch.Tensor): Metric depth map of shape (B, 1, H, W)
+ - bin_centers (torch.Tensor): Bin centers of shape (B, n_bins). Present only if return_final_centers is True
+ - probs (torch.Tensor): Output probability distribution of shape (B, n_bins, H, W). Present only if return_probs is True
+
+ """
+ b, c, h, w = x.shape
+ # print("input shape ", x.shape)
+ self.orig_input_width = w
+ self.orig_input_height = h
+ rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True)
+ # print("output shapes", rel_depth.shape, out.shape)
+
+ outconv_activation = out[0]
+ btlnck = out[1]
+ x_blocks = out[2:]
+
+ x_d0 = self.conv2(btlnck)
+ x = x_d0
+ _, seed_b_centers = self.seed_bin_regressor(x)
+
+ if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2':
+ b_prev = (seed_b_centers - self.min_depth) / \
+ (self.max_depth - self.min_depth)
+ else:
+ b_prev = seed_b_centers
+
+ prev_b_embedding = self.seed_projector(x)
+
+ # unroll this loop for better performance
+ for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks):
+ b_embedding = projector(x)
+ b, b_centers = attractor(
+ b_embedding, b_prev, prev_b_embedding, interpolate=True)
+ b_prev = b.clone()
+ prev_b_embedding = b_embedding.clone()
+
+ last = outconv_activation
+
+ if self.inverse_midas:
+ # invert depth followed by normalization
+ rel_depth = 1.0 / (rel_depth + 1e-6)
+ rel_depth = (rel_depth - rel_depth.min()) / \
+ (rel_depth.max() - rel_depth.min())
+ # concat rel depth with last. First interpolate rel depth to last size
+ rel_cond = rel_depth.unsqueeze(1)
+ rel_cond = nn.functional.interpolate(
+ rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True)
+ last = torch.cat([last, rel_cond], dim=1)
+
+ b_embedding = nn.functional.interpolate(
+ b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)
+ x = self.conditional_log_binomial(last, b_embedding)
+
+ # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
+ # print(x.shape, b_centers.shape)
+ b_centers = nn.functional.interpolate(
+ b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
+ out = torch.sum(x * b_centers, dim=1, keepdim=True)
+
+ # Structure output dict
+ output = dict(metric_depth=out)
+ if return_final_centers or return_probs:
+ output['bin_centers'] = b_centers
+
+ if return_probs:
+ output['probs'] = x
+
+ return output
+
+ def get_lr_params(self, lr):
+ """
+ Learning rate configuration for different layers of the model
+ Args:
+ lr (float) : Base learning rate
+ Returns:
+ list : list of parameters to optimize and their learning rates, in the format required by torch optimizers.
+ """
+ param_conf = []
+ if self.train_midas:
+ if self.encoder_lr_factor > 0:
+ param_conf.append({'params': self.core.get_enc_params_except_rel_pos(
+ ), 'lr': lr / self.encoder_lr_factor})
+
+ if self.pos_enc_lr_factor > 0:
+ param_conf.append(
+ {'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor})
+
+ midas_params = self.core.core.scratch.parameters()
+ midas_lr_factor = self.midas_lr_factor
+ param_conf.append(
+ {'params': midas_params, 'lr': lr / midas_lr_factor})
+
+ remaining_modules = []
+ for name, child in self.named_children():
+ if name != 'core':
+ remaining_modules.append(child)
+ remaining_params = itertools.chain(
+ *[child.parameters() for child in remaining_modules])
+
+ param_conf.append({'params': remaining_params, 'lr': lr})
+
+ return param_conf
+
+ @staticmethod
+ def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs):
+ core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas,
+ train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs)
+ model = ZoeDepth(core, **kwargs)
+ if pretrained_resource:
+ assert isinstance(pretrained_resource, str), "pretrained_resource must be a string"
+ model = load_state_from_resource(model, pretrained_resource)
+ return model
+
+ @staticmethod
+ def build_from_config(config):
+ return ZoeDepth.build(**config)
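+
+# Illustrative construction sketch (comment only; kwargs mirror config_zoedepth.json, and the
+# pass-through of extra kwargs to MidasCore.build is an assumption of this sketch):
+#
+#   model = ZoeDepth.build(midas_model_type="DPT_BEiT_L_384", use_pretrained_midas=True,
+#                          train_midas=False, n_bins=64, bin_centers_type="softplus")
+#   depth = model.infer_pil(pil_image)   # inherited from DepthModel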
diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..513a278b939c10c010e3c0250ec73544d5663886
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py
@@ -0,0 +1,31 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+from .zoedepth_nk_v1 import ZoeDepthNK
+
+all_versions = {
+ "v1": ZoeDepthNK,
+}
+
+get_version = lambda v : all_versions[v]
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json
new file mode 100644
index 0000000000000000000000000000000000000000..42bab2a3ad159a09599a5aba270c491021a3cf1a
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json
@@ -0,0 +1,67 @@
+{
+ "model": {
+ "name": "ZoeDepthNK",
+ "version_name": "v1",
+ "bin_conf" : [
+ {
+ "name": "nyu",
+ "n_bins": 64,
+ "min_depth": 1e-3,
+ "max_depth": 10.0
+ },
+ {
+ "name": "kitti",
+ "n_bins": 64,
+ "min_depth": 1e-3,
+ "max_depth": 80.0
+ }
+ ],
+ "bin_embedding_dim": 128,
+ "bin_centers_type": "softplus",
+ "n_attractors":[16, 8, 4, 1],
+ "attractor_alpha": 1000,
+ "attractor_gamma": 2,
+ "attractor_kind" : "mean",
+ "attractor_type" : "inv",
+ "min_temp": 0.0212,
+ "max_temp": 50.0,
+ "memory_efficient": true,
+ "midas_model_type" : "DPT_BEiT_L_384",
+ "img_size": [384, 512]
+ },
+
+ "train": {
+ "train_midas": true,
+ "use_pretrained_midas": true,
+ "trainer": "zoedepth_nk",
+ "epochs": 5,
+ "bs": 16,
+ "optim_kwargs": {"lr": 0.0002512, "wd": 0.01},
+ "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true},
+ "same_lr": false,
+ "w_si": 1,
+ "w_domain": 100,
+ "avoid_boundary": false,
+ "random_crop": false,
+ "input_width": 640,
+ "input_height": 480,
+ "w_grad": 0,
+ "w_reg": 0,
+ "midas_lr_factor": 10,
+ "encoder_lr_factor":10,
+ "pos_enc_lr_factor":10
+ },
+
+ "infer": {
+ "train_midas": false,
+ "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt",
+ "use_pretrained_midas": false,
+ "force_keep_ar": true
+ },
+
+ "eval": {
+ "train_midas": false,
+ "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt",
+ "use_pretrained_midas": false
+ }
+}
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..7368ae8031188a9f946d9d3f29633c96e791e68e
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py
@@ -0,0 +1,333 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import itertools
+
+import torch
+import torch.nn as nn
+
+from zoedepth.models.depth_model import DepthModel
+from zoedepth.models.base_models.midas import MidasCore
+from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed
+from zoedepth.models.layers.dist_layers import ConditionalLogBinomial
+from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor,
+ SeedBinRegressorUnnormed)
+from zoedepth.models.layers.patch_transformer import PatchTransformerEncoder
+from zoedepth.models.model_io import load_state_from_resource
+
+
+class ZoeDepthNK(DepthModel):
+ def __init__(self, core, bin_conf, bin_centers_type="softplus", bin_embedding_dim=128,
+ n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp',
+ min_temp=5, max_temp=50,
+ memory_efficient=False, train_midas=True,
+ is_midas_pretrained=True, midas_lr_factor=1, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs):
+ """ZoeDepthNK model. This is the version of ZoeDepth that has two metric heads and uses a learned router to route to experts.
+
+ Args:
+ core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features
+
+ bin_conf (List[dict]): A list of dictionaries that contain the bin configuration for each metric head. Each dictionary should contain the following keys:
+ "name" (str, typically same as the dataset name), "n_bins" (int), "min_depth" (float), "max_depth" (float)
+
+ The length of this list determines the number of metric heads.
+ bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, a linear normalization trick is applied, which results in bounded bin centers.
+ For "softplus", softplus activation is used and the bin centers are thus unbounded. Defaults to "softplus".
+ bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128.
+
+ n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1].
+ attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300.
+ attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2.
+ attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'.
+ attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'.
+
+ min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5.
+ max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50.
+
+ memory_efficient (bool, optional): Whether to use the memory-efficient version of the attractor layers. The memory-efficient version is slower but is recommended when multiple metric heads are used, in order to save GPU memory. Defaults to False.
+
+ train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True.
+ is_midas_pretrained (bool, optional): Is "core" pretrained? Defaults to True.
+ midas_lr_factor (int, optional): Learning rate reduction factor for the base midas model, excluding its encoder and positional encodings. Defaults to 1.
+ encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10.
+ pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10.
+
+ """
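+ # Illustrative note (not part of the upstream docstring): with the zoedepth_nk config
+ # shown earlier in this diff, bin_conf would be
+ #   [{"name": "nyu", "n_bins": 64, "min_depth": 1e-3, "max_depth": 10.0},
+ #    {"name": "kitti", "n_bins": 64, "min_depth": 1e-3, "max_depth": 80.0}]
+ # i.e. one metric head per dataset/domain.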
+
+ super().__init__()
+
+ self.core = core
+ self.bin_conf = bin_conf
+ self.min_temp = min_temp
+ self.max_temp = max_temp
+ self.memory_efficient = memory_efficient
+ self.train_midas = train_midas
+ self.is_midas_pretrained = is_midas_pretrained
+ self.midas_lr_factor = midas_lr_factor
+ self.encoder_lr_factor = encoder_lr_factor
+ self.pos_enc_lr_factor = pos_enc_lr_factor
+ self.inverse_midas = inverse_midas
+
+ N_MIDAS_OUT = 32
+ btlnck_features = self.core.output_channels[0]
+ num_out_features = self.core.output_channels[1:]
+ # self.scales = [16, 8, 4, 2] # spatial scale factors
+
+ self.conv2 = nn.Conv2d(
+ btlnck_features, btlnck_features, kernel_size=1, stride=1, padding=0)
+
+ # Transformer classifier on the bottleneck
+ self.patch_transformer = PatchTransformerEncoder(
+ btlnck_features, 1, 128, use_class_token=True)
+ self.mlp_classifier = nn.Sequential(
+ nn.Linear(128, 128),
+ nn.ReLU(),
+ nn.Linear(128, 2)
+ )
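+ # Note: this pair acts as the learned router. The patch transformer summarizes the
+ # bottleneck features into a 128-dim per-image embedding, and the MLP maps it to two
+ # domain logits ("nyu" vs "kitti"); see the routing step in forward().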
+
+ if bin_centers_type == "normed":
+ SeedBinRegressorLayer = SeedBinRegressor
+ Attractor = AttractorLayer
+ elif bin_centers_type == "softplus":
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
+ Attractor = AttractorLayerUnnormed
+ elif bin_centers_type == "hybrid1":
+ SeedBinRegressorLayer = SeedBinRegressor
+ Attractor = AttractorLayerUnnormed
+ elif bin_centers_type == "hybrid2":
+ SeedBinRegressorLayer = SeedBinRegressorUnnormed
+ Attractor = AttractorLayer
+ else:
+ raise ValueError(
+ "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'")
+ self.bin_centers_type = bin_centers_type
+ # We have bins for each bin conf.
+ # Create a map (ModuleDict) of 'name' -> seed_bin_regressor
+ self.seed_bin_regressors = nn.ModuleDict(
+ {conf['name']: SeedBinRegressorLayer(btlnck_features, conf["n_bins"], mlp_dim=bin_embedding_dim//2, min_depth=conf["min_depth"], max_depth=conf["max_depth"])
+ for conf in bin_conf}
+ )
+
+ self.seed_projector = Projector(
+ btlnck_features, bin_embedding_dim, mlp_dim=bin_embedding_dim//2)
+ self.projectors = nn.ModuleList([
+ Projector(num_out, bin_embedding_dim, mlp_dim=bin_embedding_dim//2)
+ for num_out in num_out_features
+ ])
+
+ # Create a map (ModuleDict) of 'name' -> attractors (ModuleList)
+ self.attractors = nn.ModuleDict(
+ {conf['name']: nn.ModuleList([
+ Attractor(bin_embedding_dim, n_attractors[i],
+ mlp_dim=bin_embedding_dim, alpha=attractor_alpha,
+ gamma=attractor_gamma, kind=attractor_kind,
+ attractor_type=attractor_type, memory_efficient=memory_efficient,
+ min_depth=conf["min_depth"], max_depth=conf["max_depth"])
+ for i in range(len(n_attractors))
+ ])
+ for conf in bin_conf}
+ )
+
+ last_in = N_MIDAS_OUT
+ # conditional log binomial for each bin conf
+ self.conditional_log_binomial = nn.ModuleDict(
+ {conf['name']: ConditionalLogBinomial(last_in, bin_embedding_dim, conf['n_bins'], bottleneck_factor=4, min_temp=self.min_temp, max_temp=self.max_temp)
+ for conf in bin_conf}
+ )
+
+ def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs):
+ """
+ Args:
+ x (torch.Tensor): Input image tensor of shape (B, C, H, W). Assumes all images are from the same domain.
+ return_final_centers (bool, optional): Whether to return the final centers of the attractors. Defaults to False.
+ denorm (bool, optional): Whether to denormalize the input image. Defaults to False.
+ return_probs (bool, optional): Whether to return the probabilities of the bins. Defaults to False.
+
+ Returns:
+ dict: Dictionary of outputs with keys:
+ - "rel_depth": Relative depth map of shape (B, 1, H, W)
+ - "metric_depth": Metric depth map of shape (B, 1, H, W)
+ - "domain_logits": Domain logits of shape (B, 2)
+ - "bin_centers": Bin centers of shape (B, N, H, W). Present only if return_final_centers is True
+ - "probs": Bin probabilities of shape (B, N, H, W). Present only if return_probs is True
+ """
+ b, c, h, w = x.shape
+ self.orig_input_width = w
+ self.orig_input_height = h
+ rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True)
+
+ outconv_activation = out[0]
+ btlnck = out[1]
+ x_blocks = out[2:]
+
+ x_d0 = self.conv2(btlnck)
+ x = x_d0
+
+ # Predict which path to take
+ embedding = self.patch_transformer(x)[0] # N, E
+ domain_logits = self.mlp_classifier(embedding) # N, 2
+ domain_vote = torch.softmax(domain_logits.sum(
+ dim=0, keepdim=True), dim=-1) # 1, 2
+
+ # Get the path
+ bin_conf_name = ["nyu", "kitti"][torch.argmax(
+ domain_vote, dim=-1).squeeze().item()]
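+ # Note: the vote is taken over the whole batch (logits are summed before softmax),
+ # so a single expert head is selected per batch. This relies on the assumption,
+ # stated in the forward() docstring, that all images in a batch share one domain.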
+
+ try:
+ conf = [c for c in self.bin_conf if c.name == bin_conf_name][0]
+ except IndexError:
+ raise ValueError(
+ f"bin_conf_name {bin_conf_name} not found in bin_confs")
+
+ min_depth = conf['min_depth']
+ max_depth = conf['max_depth']
+
+ seed_bin_regressor = self.seed_bin_regressors[bin_conf_name]
+ _, seed_b_centers = seed_bin_regressor(x)
+ if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2':
+ b_prev = (seed_b_centers - min_depth)/(max_depth - min_depth)
+ else:
+ b_prev = seed_b_centers
+ prev_b_embedding = self.seed_projector(x)
+
+ attractors = self.attractors[bin_conf_name]
+ for projector, attractor, x in zip(self.projectors, attractors, x_blocks):
+ b_embedding = projector(x)
+ b, b_centers = attractor(
+ b_embedding, b_prev, prev_b_embedding, interpolate=True)
+ b_prev = b
+ prev_b_embedding = b_embedding
+
+ last = outconv_activation
+
+ b_centers = nn.functional.interpolate(
+ b_centers, last.shape[-2:], mode='bilinear', align_corners=True)
+ b_embedding = nn.functional.interpolate(
+ b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)
+
+ clb = self.conditional_log_binomial[bin_conf_name]
+ x = clb(last, b_embedding)
+
+ # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
+ # print(x.shape, b_centers.shape)
+ # b_centers = nn.functional.interpolate(b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
+ out = torch.sum(x * b_centers, dim=1, keepdim=True)
+
+ output = dict(domain_logits=domain_logits, metric_depth=out)
+ if return_final_centers or return_probs:
+ output['bin_centers'] = b_centers
+
+ if return_probs:
+ output['probs'] = x
+ return output
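+ # Usage sketch (illustrative only; assumes a built model `model` and an RGB batch `x`
+ # of shape (B, 3, H, W)):
+ #   out = model(x, return_probs=True)
+ #   depth = out["metric_depth"]    # (B, 1, H, W)
+ #   logits = out["domain_logits"]  # (B, 2)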
+
+ def get_lr_params(self, lr):
+ """
+ Learning rate configuration for different layers of the model
+
+ Args:
+ lr (float) : Base learning rate
+ Returns:
+ list : list of parameters to optimize and their learning rates, in the format required by torch optimizers.
+ """
+ param_conf = []
+ if self.train_midas:
+ def get_rel_pos_params():
+ for name, p in self.core.core.pretrained.named_parameters():
+ if "relative_position" in name:
+ yield p
+
+ def get_enc_params_except_rel_pos():
+ for name, p in self.core.core.pretrained.named_parameters():
+ if "relative_position" not in name:
+ yield p
+
+ encoder_params = get_enc_params_except_rel_pos()
+ rel_pos_params = get_rel_pos_params()
+ midas_params = self.core.core.scratch.parameters()
+ midas_lr_factor = self.midas_lr_factor if self.is_midas_pretrained else 1.0
+ param_conf.extend([
+ {'params': encoder_params, 'lr': lr / self.encoder_lr_factor},
+ {'params': rel_pos_params, 'lr': lr / self.pos_enc_lr_factor},
+ {'params': midas_params, 'lr': lr / midas_lr_factor}
+ ])
+
+ remaining_modules = []
+ for name, child in self.named_children():
+ if name != 'core':
+ remaining_modules.append(child)
+ remaining_params = itertools.chain(
+ *[child.parameters() for child in remaining_modules])
+ param_conf.append({'params': remaining_params, 'lr': lr})
+ return param_conf
+
+ def get_conf_parameters(self, conf_name):
+ """
+ Returns parameters of all the ModuleDicts children that are exclusively used for the given bin configuration
+ """
+ params = []
+ for name, child in self.named_children():
+ if isinstance(child, nn.ModuleDict):
+ for bin_conf_name, module in child.items():
+ if bin_conf_name == conf_name:
+ params += list(module.parameters())
+ return params
+
+ def freeze_conf(self, conf_name):
+ """
+ Freezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration
+ """
+ for p in self.get_conf_parameters(conf_name):
+ p.requires_grad = False
+
+ def unfreeze_conf(self, conf_name):
+ """
+ Unfreezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration
+ """
+ for p in self.get_conf_parameters(conf_name):
+ p.requires_grad = True
+
+ def freeze_all_confs(self):
+ """
+ Freezes all the parameters of all the ModuleDicts children
+ """
+ for name, child in self.named_children():
+ if isinstance(child, nn.ModuleDict):
+ for bin_conf_name, module in child.items():
+ for p in module.parameters():
+ p.requires_grad = False
+
+ @staticmethod
+ def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs):
+ core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas,
+ train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs)
+ model = ZoeDepthNK(core, **kwargs)
+ if pretrained_resource:
+ assert isinstance(pretrained_resource, str), "pretrained_resource must be a string"
+ model = load_state_from_resource(model, pretrained_resource)
+ return model
+
+ @staticmethod
+ def build_from_config(config):
+ return ZoeDepthNK.build(**config)
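+# Build sketch (illustrative): ZoeDepthNK.build_from_config(config) simply forwards the
+# flattened "model" section of the zoedepth_nk config above (midas_model_type, bin_conf,
+# n_attractors, pretrained_resource, use_pretrained_midas, ...) as keyword arguments to
+# ZoeDepthNK.build().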
diff --git a/src/flux/annotator/zoe/zoedepth/trainers/base_trainer.py b/src/flux/annotator/zoe/zoedepth/trainers/base_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..33fbbea3a7d49efe11b005adb5127f441eabfaf6
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/trainers/base_trainer.py
@@ -0,0 +1,326 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import os
+import uuid
+import warnings
+from datetime import datetime as dt
+from typing import Dict
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.optim as optim
+import wandb
+from tqdm import tqdm
+
+from zoedepth.utils.config import flatten
+from zoedepth.utils.misc import RunningAverageDict, colorize, colors
+
+
+def is_rank_zero(args):
+ return args.rank == 0
+
+
+class BaseTrainer:
+ def __init__(self, config, model, train_loader, test_loader=None, device=None):
+ """ Base Trainer class for training a model."""
+
+ self.config = config
+ self.metric_criterion = "abs_rel"
+ if device is None:
+ device = torch.device(
+ 'cuda') if torch.cuda.is_available() else torch.device('cpu')
+ self.device = device
+ self.model = model
+ self.train_loader = train_loader
+ self.test_loader = test_loader
+ self.optimizer = self.init_optimizer()
+ self.scheduler = self.init_scheduler()
+
+ def resize_to_target(self, prediction, target):
+ if prediction.shape[2:] != target.shape[-2:]:
+ prediction = nn.functional.interpolate(
+ prediction, size=target.shape[-2:], mode="bilinear", align_corners=True
+ )
+ return prediction
+
+ def load_ckpt(self, checkpoint_dir="./checkpoints", ckpt_type="best"):
+ import glob
+ import os
+
+ from zoedepth.models.model_io import load_wts
+
+ if hasattr(self.config, "checkpoint"):
+ checkpoint = self.config.checkpoint
+ elif hasattr(self.config, "ckpt_pattern"):
+ pattern = self.config.ckpt_pattern
+ matches = glob.glob(os.path.join(
+ checkpoint_dir, f"*{pattern}*{ckpt_type}*"))
+ if not (len(matches) > 0):
+ raise ValueError(f"No matches found for the pattern {pattern}")
+ checkpoint = matches[0]
+ else:
+ return
+ model = load_wts(self.model, checkpoint)
+ # TODO : Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it.
+ print("Loaded weights from {0}".format(checkpoint))
+ warnings.warn(
+ "Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it.")
+ self.model = model
+
+ def init_optimizer(self):
+ m = self.model.module if self.config.multigpu else self.model
+
+ if self.config.same_lr:
+ print("Using same LR")
+ if hasattr(m, 'core'):
+ m.core.unfreeze()
+ params = self.model.parameters()
+ else:
+ print("Using diff LR")
+ if not hasattr(m, 'get_lr_params'):
+ raise NotImplementedError(
+ f"Model {m.__class__.__name__} does not implement get_lr_params. Please implement it or use the same LR for all parameters.")
+
+ params = m.get_lr_params(self.config.lr)
+
+ return optim.AdamW(params, lr=self.config.lr, weight_decay=self.config.wd)
+
+ def init_scheduler(self):
+ lrs = [l['lr'] for l in self.optimizer.param_groups]
+ return optim.lr_scheduler.OneCycleLR(self.optimizer, lrs, epochs=self.config.epochs, steps_per_epoch=len(self.train_loader),
+ cycle_momentum=self.config.cycle_momentum,
+ base_momentum=0.85, max_momentum=0.95, div_factor=self.config.div_factor, final_div_factor=self.config.final_div_factor, pct_start=self.config.pct_start, three_phase=self.config.three_phase)
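+ # Note: `lrs` is the list of per-parameter-group maximum learning rates coming from
+ # get_lr_params (or a single shared LR when same_lr is set); OneCycleLR accepts such a
+ # list, one max_lr per optimizer param group.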
+
+ def train_on_batch(self, batch, train_step):
+ raise NotImplementedError
+
+ def validate_on_batch(self, batch, val_step):
+ raise NotImplementedError
+
+ def raise_if_nan(self, losses):
+ for key, value in losses.items():
+ if torch.isnan(value):
+ raise ValueError(f"{key} is NaN, Stopping training")
+
+ @property
+ def iters_per_epoch(self):
+ return len(self.train_loader)
+
+ @property
+ def total_iters(self):
+ return self.config.epochs * self.iters_per_epoch
+
+ def should_early_stop(self):
+ if self.config.get('early_stop', False) and self.step > self.config.early_stop:
+ return True
+
+ def train(self):
+ print(f"Training {self.config.name}")
+ if self.config.uid is None:
+ self.config.uid = str(uuid.uuid4()).split('-')[-1]
+ run_id = f"{dt.now().strftime('%d-%h_%H-%M')}-{self.config.uid}"
+ self.config.run_id = run_id
+ self.config.experiment_id = f"{self.config.name}{self.config.version_name}_{run_id}"
+ self.should_write = ((not self.config.distributed)
+ or self.config.rank == 0)
+ self.should_log = self.should_write # and logging
+ if self.should_log:
+ tags = self.config.tags.split(
+ ',') if self.config.tags != '' else None
+ wandb.init(project=self.config.project, name=self.config.experiment_id, config=flatten(self.config), dir=self.config.root,
+ tags=tags, notes=self.config.notes, settings=wandb.Settings(start_method="fork"))
+
+ self.model.train()
+ self.step = 0
+ best_loss = np.inf
+ validate_every = int(self.config.validate_every * self.iters_per_epoch)
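+ # Note: config.validate_every is a fraction of an epoch (e.g. 0.25 in
+ # COMMON_TRAINING_CONFIG further below), converted here into a step interval.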
+
+
+ if self.config.prefetch:
+
+ for i, batch in tqdm(enumerate(self.train_loader), desc=f"Prefetching...",
+ total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader):
+ pass
+
+ losses = {}
+ def stringify_losses(L): return "; ".join(map(
+ lambda kv: f"{colors.fg.purple}{kv[0]}{colors.reset}: {round(kv[1].item(),3):.4e}", L.items()))
+ for epoch in range(self.config.epochs):
+ if self.should_early_stop():
+ break
+
+ self.epoch = epoch
+ ################################# Train loop ##########################################################
+ if self.should_log:
+ wandb.log({"Epoch": epoch}, step=self.step)
+ pbar = tqdm(enumerate(self.train_loader), desc=f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train",
+ total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader)
+ for i, batch in pbar:
+ if self.should_early_stop():
+ print("Early stopping")
+ break
+ # print(f"Batch {self.step+1} on rank {self.config.rank}")
+ losses = self.train_on_batch(batch, i)
+ # print(f"trained batch {self.step+1} on rank {self.config.rank}")
+
+ self.raise_if_nan(losses)
+ if is_rank_zero(self.config) and self.config.print_losses:
+ pbar.set_description(
+ f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train. Losses: {stringify_losses(losses)}")
+ self.scheduler.step()
+
+ if self.should_log and self.step % 50 == 0:
+ wandb.log({f"Train/{name}": loss.item()
+ for name, loss in losses.items()}, step=self.step)
+
+ self.step += 1
+
+ ########################################################################################################
+
+ if self.test_loader:
+ if (self.step % validate_every) == 0:
+ self.model.eval()
+ if self.should_write:
+ self.save_checkpoint(
+ f"{self.config.experiment_id}_latest.pt")
+
+ ################################# Validation loop ##################################################
+ # validate on the entire validation set in every process but save only from rank 0, I know, inefficient, but avoids divergence of processes
+ metrics, test_losses = self.validate()
+ # print("Validated: {}".format(metrics))
+ if self.should_log:
+ wandb.log(
+ {f"Test/{name}": tloss for name, tloss in test_losses.items()}, step=self.step)
+
+ wandb.log({f"Metrics/{k}": v for k,
+ v in metrics.items()}, step=self.step)
+
+ if (metrics[self.metric_criterion] < best_loss) and self.should_write:
+ self.save_checkpoint(
+ f"{self.config.experiment_id}_best.pt")
+ best_loss = metrics[self.metric_criterion]
+
+ self.model.train()
+
+ if self.config.distributed:
+ dist.barrier()
+ # print(f"Validated: {metrics} on device {self.config.rank}")
+
+ # print(f"Finished step {self.step} on device {self.config.rank}")
+ #################################################################################################
+
+ # Save / validate at the end
+ self.step += 1 # log as final point
+ self.model.eval()
+ self.save_checkpoint(f"{self.config.experiment_id}_latest.pt")
+ if self.test_loader:
+
+ ################################# Validation loop ##################################################
+ metrics, test_losses = self.validate()
+ # print("Validated: {}".format(metrics))
+ if self.should_log:
+ wandb.log({f"Test/{name}": tloss for name,
+ tloss in test_losses.items()}, step=self.step)
+ wandb.log({f"Metrics/{k}": v for k,
+ v in metrics.items()}, step=self.step)
+
+ if (metrics[self.metric_criterion] < best_loss) and self.should_write:
+ self.save_checkpoint(
+ f"{self.config.experiment_id}_best.pt")
+ best_loss = metrics[self.metric_criterion]
+
+ self.model.train()
+
+ def validate(self):
+ with torch.no_grad():
+ losses_avg = RunningAverageDict()
+ metrics_avg = RunningAverageDict()
+ for i, batch in tqdm(enumerate(self.test_loader), desc=f"Epoch: {self.epoch + 1}/{self.config.epochs}. Loop: Validation", total=len(self.test_loader), disable=not is_rank_zero(self.config)):
+ metrics, losses = self.validate_on_batch(batch, val_step=i)
+
+ if losses:
+ losses_avg.update(losses)
+ if metrics:
+ metrics_avg.update(metrics)
+
+ return metrics_avg.get_value(), losses_avg.get_value()
+
+ def save_checkpoint(self, filename):
+ if not self.should_write:
+ return
+ root = self.config.save_dir
+ if not os.path.isdir(root):
+ os.makedirs(root)
+
+ fpath = os.path.join(root, filename)
+ m = self.model.module if self.config.multigpu else self.model
+ torch.save(
+ {
+ "model": m.state_dict(),
+ "optimizer": None, # TODO : Change to self.optimizer.state_dict() if resume support is needed, currently None to reduce file size
+ "epoch": self.epoch
+ }, fpath)
+
+ def log_images(self, rgb: Dict[str, list] = {}, depth: Dict[str, list] = {}, scalar_field: Dict[str, list] = {}, prefix="", scalar_cmap="jet", min_depth=None, max_depth=None):
+ if not self.should_log:
+ return
+
+ if min_depth is None:
+ try:
+ min_depth = self.config.min_depth
+ max_depth = self.config.max_depth
+ except AttributeError:
+ min_depth = None
+ max_depth = None
+
+ depth = {k: colorize(v, vmin=min_depth, vmax=max_depth)
+ for k, v in depth.items()}
+ scalar_field = {k: colorize(
+ v, vmin=None, vmax=None, cmap=scalar_cmap) for k, v in scalar_field.items()}
+ images = {**rgb, **depth, **scalar_field}
+ wimages = {
+ prefix+"Predictions": [wandb.Image(v, caption=k) for k, v in images.items()]}
+ wandb.log(wimages, step=self.step)
+
+ def log_line_plot(self, data):
+ if not self.should_log:
+ return
+
+ plt.plot(data)
+ plt.ylabel("Scale factors")
+ wandb.log({"Scale factors": wandb.Image(plt)}, step=self.step)
+ plt.close()
+
+ def log_bar_plot(self, title, labels, values):
+ if not self.should_log:
+ return
+
+ data = [[label, val] for (label, val) in zip(labels, values)]
+ table = wandb.Table(data=data, columns=["label", "value"])
+ wandb.log({title: wandb.plot.bar(table, "label",
+ "value", title=title)}, step=self.step)
diff --git a/src/flux/annotator/zoe/zoedepth/trainers/builder.py b/src/flux/annotator/zoe/zoedepth/trainers/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a663541b08912ebedce21a68c7599ce4c06e85d0
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/trainers/builder.py
@@ -0,0 +1,48 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+from importlib import import_module
+
+
+def get_trainer(config):
+ """Builds and returns a trainer based on the config.
+
+ Args:
+ config (dict): the config dict (typically constructed using utils.config.get_config)
+ config.trainer (str): the name of the trainer to use. The module named "{config.trainer}_trainer" must exist in trainers root module
+
+ Raises:
+ ValueError: If the specified trainer does not exist under trainers/ folder
+
+ Returns:
+ Trainer (inherited from zoedepth.trainers.BaseTrainer): The Trainer object
+ """
+ assert "trainer" in config and config.trainer is not None and config.trainer != '', "Trainer not specified. Config: {0}".format(
+ config)
+ try:
+ Trainer = getattr(import_module(
+ f"zoedepth.trainers.{config.trainer}_trainer"), 'Trainer')
+ except ModuleNotFoundError as e:
+ raise ValueError(f"Trainer {config.trainer}_trainer not found.") from e
+ return Trainer
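+# Usage sketch (illustrative): with the zoedepth_nk config above ("trainer": "zoedepth_nk"),
+# get_trainer(config) imports zoedepth.trainers.zoedepth_nk_trainer and returns its Trainer
+# class, which is then instantiated as Trainer(config, model, train_loader, test_loader).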
diff --git a/src/flux/annotator/zoe/zoedepth/trainers/loss.py b/src/flux/annotator/zoe/zoedepth/trainers/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c5a1c15cdf5628c1474c566fdc6e58159d7f5ab
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/trainers/loss.py
@@ -0,0 +1,316 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.cuda.amp as amp
+import numpy as np
+
+
+KEY_OUTPUT = 'metric_depth'
+
+
+def extract_key(prediction, key):
+ if isinstance(prediction, dict):
+ return prediction[key]
+ return prediction
+
+
+# Main loss function used for ZoeDepth. Copy/paste from AdaBins repo (https://github.com/shariqfarooq123/AdaBins/blob/0952d91e9e762be310bb4cd055cbfe2448c0ce20/loss.py#L7)
+class SILogLoss(nn.Module):
+ """SILog loss (pixel-wise)"""
+ def __init__(self, beta=0.15):
+ super(SILogLoss, self).__init__()
+ self.name = 'SILog'
+ self.beta = beta
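+ # With g = log(input + 1e-7) - log(target + 1e-7) over the masked pixels, the loss
+ # computed below is 10 * sqrt(var(g) + beta * mean(g)**2), i.e. the scale-invariant
+ # log (SILog) loss copied from the AdaBins repo linked above.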
+
+ def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False):
+ input = extract_key(input, KEY_OUTPUT)
+ if input.shape[-1] != target.shape[-1] and interpolate:
+ input = nn.functional.interpolate(
+ input, target.shape[-2:], mode='bilinear', align_corners=True)
+ intr_input = input
+ else:
+ intr_input = input
+
+ if target.ndim == 3:
+ target = target.unsqueeze(1)
+
+ if mask is not None:
+ if mask.ndim == 3:
+ mask = mask.unsqueeze(1)
+
+ input = input[mask]
+ target = target[mask]
+
+ with amp.autocast(enabled=False): # amp causes NaNs in this loss function
+ alpha = 1e-7
+ g = torch.log(input + alpha) - torch.log(target + alpha)
+
+ # n, c, h, w = g.shape
+ # norm = 1/(h*w)
+ # Dg = norm * torch.sum(g**2) - (0.85/(norm**2)) * (torch.sum(g))**2
+
+ Dg = torch.var(g) + self.beta * torch.pow(torch.mean(g), 2)
+
+ loss = 10 * torch.sqrt(Dg)
+
+ if torch.isnan(loss):
+ print("Nan SILog loss")
+ print("input:", input.shape)
+ print("target:", target.shape)
+ print("G", torch.sum(torch.isnan(g)))
+ print("Input min max", torch.min(input), torch.max(input))
+ print("Target min max", torch.min(target), torch.max(target))
+ print("Dg", torch.isnan(Dg))
+ print("loss", torch.isnan(loss))
+
+ if not return_interpolated:
+ return loss
+
+ return loss, intr_input
+
+
+def grad(x):
+ # x.shape : n, c, h, w
+ diff_x = x[..., 1:, 1:] - x[..., 1:, :-1]
+ diff_y = x[..., 1:, 1:] - x[..., :-1, 1:]
+ mag = diff_x**2 + diff_y**2
+ # angle_ratio
+ angle = torch.atan(diff_y / (diff_x + 1e-10))
+ return mag, angle
+
+
+def grad_mask(mask):
+ return mask[..., 1:, 1:] & mask[..., 1:, :-1] & mask[..., :-1, 1:]
+
+
+class GradL1Loss(nn.Module):
+ """Gradient loss"""
+ def __init__(self):
+ super(GradL1Loss, self).__init__()
+ self.name = 'GradL1'
+
+ def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False):
+ input = extract_key(input, KEY_OUTPUT)
+ if input.shape[-1] != target.shape[-1] and interpolate:
+ input = nn.functional.interpolate(
+ input, target.shape[-2:], mode='bilinear', align_corners=True)
+ intr_input = input
+ else:
+ intr_input = input
+
+ grad_gt = grad(target)
+ grad_pred = grad(input)
+ mask_g = grad_mask(mask)
+
+ loss = nn.functional.l1_loss(grad_pred[0][mask_g], grad_gt[0][mask_g])
+ loss = loss + \
+ nn.functional.l1_loss(grad_pred[1][mask_g], grad_gt[1][mask_g])
+ if not return_interpolated:
+ return loss
+ return loss, intr_input
+
+
+class OrdinalRegressionLoss(object):
+
+ def __init__(self, ord_num, beta, discretization="SID"):
+ self.ord_num = ord_num
+ self.beta = beta
+ self.discretization = discretization
+
+ def _create_ord_label(self, gt):
+ N, one, H, W = gt.shape
+ # print("gt shape:", gt.shape)
+
+ ord_c0 = torch.ones(N, self.ord_num, H, W).to(gt.device)
+ if self.discretization == "SID":
+ label = self.ord_num * torch.log(gt) / np.log(self.beta)
+ else:
+ label = self.ord_num * (gt - 1.0) / (self.beta - 1.0)
+ label = label.long()
+ mask = torch.linspace(0, self.ord_num - 1, self.ord_num, requires_grad=False) \
+ .view(1, self.ord_num, 1, 1).to(gt.device)
+ mask = mask.repeat(N, 1, H, W).contiguous().long()
+ mask = (mask > label)
+ ord_c0[mask] = 0
+ ord_c1 = 1 - ord_c0
+ # implementation according to the paper.
+ # ord_label = torch.ones(N, self.ord_num * 2, H, W).to(gt.device)
+ # ord_label[:, 0::2, :, :] = ord_c0
+ # ord_label[:, 1::2, :, :] = ord_c1
+ # reimplementation for fast speed.
+ ord_label = torch.cat((ord_c0, ord_c1), dim=1)
+ return ord_label, mask
+
+ def __call__(self, prob, gt):
+ """
+ :param prob: ordinal regression probability, N x 2*Ord Num x H x W, torch.Tensor
+ :param gt: depth ground truth, NXHxW, torch.Tensor
+ :return: loss: loss value, torch.float
+ """
+ # N, C, H, W = prob.shape
+ valid_mask = gt > 0.
+ ord_label, mask = self._create_ord_label(gt)
+ # print("prob shape: {}, ord label shape: {}".format(prob.shape, ord_label.shape))
+ entropy = -prob * ord_label
+ loss = torch.sum(entropy, dim=1)[valid_mask.squeeze(1)]
+ return loss.mean()
+
+
+class DiscreteNLLLoss(nn.Module):
+ """Cross entropy loss"""
+ def __init__(self, min_depth=1e-3, max_depth=10, depth_bins=64):
+ super(DiscreteNLLLoss, self).__init__()
+ self.name = 'CrossEntropy'
+ self.ignore_index = -(depth_bins + 1)
+ # self._loss_func = nn.NLLLoss(ignore_index=self.ignore_index)
+ self._loss_func = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+ self.depth_bins = depth_bins
+ self.alpha = 1
+ self.zeta = 1 - min_depth
+ self.beta = max_depth + self.zeta
+
+ def quantize_depth(self, depth):
+ # depth : N1HW
+ # output : NCHW
+
+ # Quantize depth log-uniformly on [1, self.beta] into self.depth_bins bins
+ depth = torch.log(depth / self.alpha) / np.log(self.beta / self.alpha)
+ depth = depth * (self.depth_bins - 1)
+ depth = torch.round(depth)
+ depth = depth.long()
+ return depth
+
+
+
+ def _dequantize_depth(self, depth):
+ """
+ Inverse of quantization
+ depth : NCHW -> N1HW
+ """
+ # Get the center of the bin. This is the analytic inverse of quantize_depth above,
+ # assuming the same log-uniform bin spacing on [alpha, beta].
+ depth = depth.float() / (self.depth_bins - 1)
+ depth = self.alpha * torch.exp(depth * np.log(self.beta / self.alpha))
+ return depth
+
+ def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False):
+ input = extract_key(input, KEY_OUTPUT)
+ # assert torch.all(input <= 0), "Input should be negative"
+
+ if input.shape[-1] != target.shape[-1] and interpolate:
+ input = nn.functional.interpolate(
+ input, target.shape[-2:], mode='bilinear', align_corners=True)
+ intr_input = input
+ else:
+ intr_input = input
+
+ # assert torch.all(input <= 1)
+ if target.ndim == 3:
+ target = target.unsqueeze(1)
+
+ target = self.quantize_depth(target)
+ if mask is not None:
+ if mask.ndim == 3:
+ mask = mask.unsqueeze(1)
+
+ # Set the mask to ignore_index
+ mask = mask.long()
+ input = input * mask + (1 - mask) * self.ignore_index
+ target = target * mask + (1 - mask) * self.ignore_index
+
+
+
+ input = input.flatten(2) # N, nbins, H*W
+ target = target.flatten(1) # N, H*W
+ loss = self._loss_func(input, target)
+
+ if not return_interpolated:
+ return loss
+ return loss, intr_input
+
+
+
+
+def compute_scale_and_shift(prediction, target, mask):
+ # system matrix: A = [[a_00, a_01], [a_10, a_11]]
+ a_00 = torch.sum(mask * prediction * prediction, (1, 2))
+ a_01 = torch.sum(mask * prediction, (1, 2))
+ a_11 = torch.sum(mask, (1, 2))
+
+ # right hand side: b = [b_0, b_1]
+ b_0 = torch.sum(mask * prediction * target, (1, 2))
+ b_1 = torch.sum(mask * target, (1, 2))
+
+ # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
+ x_0 = torch.zeros_like(b_0)
+ x_1 = torch.zeros_like(b_1)
+
+ det = a_00 * a_11 - a_01 * a_01
+ # A needs to be a positive definite matrix.
+ valid = det > 0
+
+ x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
+ x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
+
+ return x_0, x_1
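+# Note: compute_scale_and_shift solves, per image, the masked least-squares problem
+#   min_{s, t}  sum_i mask_i * (s * prediction_i + t - target_i)^2
+# in closed form via the 2x2 normal equations assembled above; systems with det <= 0
+# are left at s = t = 0.
+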
+class ScaleAndShiftInvariantLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.name = "SSILoss"
+
+ def forward(self, prediction, target, mask, interpolate=True, return_interpolated=False):
+
+ if prediction.shape[-1] != target.shape[-1] and interpolate:
+ prediction = nn.functional.interpolate(prediction, target.shape[-2:], mode='bilinear', align_corners=True)
+ intr_input = prediction
+ else:
+ intr_input = prediction
+
+
+ prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze()
+ assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}."
+
+ scale, shift = compute_scale_and_shift(prediction, target, mask)
+
+ scaled_prediction = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
+
+ loss = nn.functional.l1_loss(scaled_prediction[mask], target[mask])
+ if not return_interpolated:
+ return loss
+ return loss, intr_input
+
+
+
+
+if __name__ == '__main__':
+ # Tests for DiscreteNLLLoss
+ celoss = DiscreteNLLLoss()
+ print(celoss(torch.rand(4, 64, 26, 32)*10, torch.rand(4, 1, 26, 32)*10, ))
+
+ d = torch.Tensor([6.59, 3.8, 10.0])
+ print(celoss._dequantize_depth(celoss.quantize_depth(d)))
diff --git a/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py b/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d528ae126f1c51b2f25fd31f94a39591ceb2f43a
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py
@@ -0,0 +1,143 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+
+from zoedepth.trainers.loss import GradL1Loss, SILogLoss
+from zoedepth.utils.config import DATASETS_CONFIG
+from zoedepth.utils.misc import compute_metrics
+
+from .base_trainer import BaseTrainer
+
+
+class Trainer(BaseTrainer):
+ def __init__(self, config, model, train_loader, test_loader=None, device=None):
+ super().__init__(config, model, train_loader,
+ test_loader=test_loader, device=device)
+ self.device = device
+ self.silog_loss = SILogLoss()
+ self.grad_loss = GradL1Loss()
+ self.domain_classifier_loss = nn.CrossEntropyLoss()
+
+ self.scaler = amp.GradScaler(enabled=self.config.use_amp)
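+ # AMP pattern used in train_on_batch below: losses are computed under autocast, the
+ # scaled loss is backpropagated, gradients are unscaled before optional clipping, and
+ # scaler.step()/scaler.update() then drive the optimizer.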
+
+ def train_on_batch(self, batch, train_step):
+ """
+ Expects a batch of images and depth as input
+ batch["image"].shape : batch_size, c, h, w
+ batch["depth"].shape : batch_size, 1, h, w
+
+ Assumes all images in a batch are from the same dataset
+ """
+
+ images, depths_gt = batch['image'].to(
+ self.device), batch['depth'].to(self.device)
+ # batch['dataset'] is a list of strings, all valued either 'nyu' or 'kitti'. Labels: nyu -> 0, kitti -> 1
+ dataset = batch['dataset'][0]
+ # Convert to 0s or 1s
+ domain_labels = torch.Tensor([dataset == 'kitti' for _ in range(
+ images.size(0))]).to(torch.long).to(self.device)
+
+ # m = self.model.module if self.config.multigpu else self.model
+
+ b, c, h, w = images.size()
+ mask = batch["mask"].to(self.device).to(torch.bool)
+
+ losses = {}
+
+ with amp.autocast(enabled=self.config.use_amp):
+ output = self.model(images)
+ pred_depths = output['metric_depth']
+ domain_logits = output['domain_logits']
+
+ l_si, pred = self.silog_loss(
+ pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True)
+ loss = self.config.w_si * l_si
+ losses[self.silog_loss.name] = l_si
+
+ if self.config.w_grad > 0:
+ l_grad = self.grad_loss(pred, depths_gt, mask=mask)
+ loss = loss + self.config.w_grad * l_grad
+ losses[self.grad_loss.name] = l_grad
+ else:
+ l_grad = torch.Tensor([0])
+
+ if self.config.w_domain > 0:
+ l_domain = self.domain_classifier_loss(
+ domain_logits, domain_labels)
+ loss = loss + self.config.w_domain * l_domain
+ losses["DomainLoss"] = l_domain
+ else:
+ l_domain = torch.Tensor([0.])
+
+ self.scaler.scale(loss).backward()
+
+ if self.config.clip_grad > 0:
+ self.scaler.unscale_(self.optimizer)
+ nn.utils.clip_grad_norm_(
+ self.model.parameters(), self.config.clip_grad)
+
+ self.scaler.step(self.optimizer)
+
+ if self.should_log and self.step > 1 and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0:
+ depths_gt[torch.logical_not(mask)] = -99
+ self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train",
+ min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth'])
+
+ self.scaler.update()
+ self.optimizer.zero_grad(set_to_none=True)
+
+ return losses
+
+ def validate_on_batch(self, batch, val_step):
+ images = batch['image'].to(self.device)
+ depths_gt = batch['depth'].to(self.device)
+ dataset = batch['dataset'][0]
+ if 'has_valid_depth' in batch:
+ if not batch['has_valid_depth']:
+ return None, None
+
+ depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0)
+ with amp.autocast(enabled=self.config.use_amp):
+ m = self.model.module if self.config.multigpu else self.model
+ pred_depths = m(images)["metric_depth"]
+ pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0)
+
+ mask = torch.logical_and(
+ depths_gt > self.config.min_depth, depths_gt < self.config.max_depth)
+ with amp.autocast(enabled=self.config.use_amp):
+ l_depth = self.silog_loss(
+ pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True)
+
+ metrics = compute_metrics(depths_gt, pred_depths, **self.config)
+ losses = {f"{self.silog_loss.name}": l_depth.item()}
+
+ if val_step == 1 and self.should_log:
+ depths_gt[torch.logical_not(mask)] = -99
+ self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test",
+ min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth'])
+
+ return metrics, losses
diff --git a/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py b/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ac1c24c0512c1c1b191670a7c24abb4fca47ba1
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py
@@ -0,0 +1,177 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+
+from zoedepth.trainers.loss import GradL1Loss, SILogLoss
+from zoedepth.utils.config import DATASETS_CONFIG
+from zoedepth.utils.misc import compute_metrics
+from zoedepth.data.preprocess import get_black_border
+
+from .base_trainer import BaseTrainer
+from torchvision import transforms
+from PIL import Image
+import numpy as np
+
+class Trainer(BaseTrainer):
+ def __init__(self, config, model, train_loader, test_loader=None, device=None):
+ super().__init__(config, model, train_loader,
+ test_loader=test_loader, device=device)
+ self.device = device
+ self.silog_loss = SILogLoss()
+ self.grad_loss = GradL1Loss()
+ self.scaler = amp.GradScaler(enabled=self.config.use_amp)
+
+ def train_on_batch(self, batch, train_step):
+ """
+ Expects a batch of images and depth as input
+ batch["image"].shape : batch_size, c, h, w
+ batch["depth"].shape : batch_size, 1, h, w
+ """
+
+ images, depths_gt = batch['image'].to(
+ self.device), batch['depth'].to(self.device)
+ dataset = batch['dataset'][0]
+
+ b, c, h, w = images.size()
+ mask = batch["mask"].to(self.device).to(torch.bool)
+
+ losses = {}
+
+ with amp.autocast(enabled=self.config.use_amp):
+
+ output = self.model(images)
+ pred_depths = output['metric_depth']
+
+ l_si, pred = self.silog_loss(
+ pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True)
+ loss = self.config.w_si * l_si
+ losses[self.silog_loss.name] = l_si
+
+ if self.config.w_grad > 0:
+ l_grad = self.grad_loss(pred, depths_gt, mask=mask)
+ loss = loss + self.config.w_grad * l_grad
+ losses[self.grad_loss.name] = l_grad
+ else:
+ l_grad = torch.Tensor([0])
+
+ self.scaler.scale(loss).backward()
+
+ if self.config.clip_grad > 0:
+ self.scaler.unscale_(self.optimizer)
+ nn.utils.clip_grad_norm_(
+ self.model.parameters(), self.config.clip_grad)
+
+ self.scaler.step(self.optimizer)
+
+ if self.should_log and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0:
+ # -99 is treated as invalid depth in the log_images function and is colored grey.
+ depths_gt[torch.logical_not(mask)] = -99
+
+ self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train",
+ min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth'])
+
+ if self.config.get("log_rel", False):
+ self.log_images(
+ scalar_field={"RelPred": output["relative_depth"][0]}, prefix="TrainRel")
+
+ self.scaler.update()
+ self.optimizer.zero_grad()
+
+ return losses
+
+ @torch.no_grad()
+ def eval_infer(self, x):
+ with amp.autocast(enabled=self.config.use_amp):
+ m = self.model.module if self.config.multigpu else self.model
+ pred_depths = m(x)['metric_depth']
+ return pred_depths
+
+ @torch.no_grad()
+ def crop_aware_infer(self, x):
+ # if we are not avoiding the black border, we can just use the normal inference
+ if not self.config.get("avoid_boundary", False):
+ return self.eval_infer(x)
+
+ # otherwise, we need to crop the image to avoid the black border
+ # For now, this may be a bit slow due to converting to numpy and back
+ # We assume no normalization is done on the input image
+
+ # get the black border
+ assert x.shape[0] == 1, "Only batch size 1 is supported for now"
+ x_pil = transforms.ToPILImage()(x[0].cpu())
+ x_np = np.array(x_pil, dtype=np.uint8)
+ black_border_params = get_black_border(x_np)
+ top, bottom, left, right = black_border_params.top, black_border_params.bottom, black_border_params.left, black_border_params.right
+ x_np_cropped = x_np[top:bottom, left:right, :]
+ x_cropped = transforms.ToTensor()(Image.fromarray(x_np_cropped))
+
+ # run inference on the cropped image
+ pred_depths_cropped = self.eval_infer(x_cropped.unsqueeze(0).to(self.device))
+
+ # resize the prediction to x_np_cropped's size
+ pred_depths_cropped = nn.functional.interpolate(
+ pred_depths_cropped, size=(x_np_cropped.shape[0], x_np_cropped.shape[1]), mode="bilinear", align_corners=False)
+
+
+ # pad the prediction back to the original size
+ pred_depths = torch.zeros((1, 1, x_np.shape[0], x_np.shape[1]), device=pred_depths_cropped.device, dtype=pred_depths_cropped.dtype)
+ pred_depths[:, :, top:bottom, left:right] = pred_depths_cropped
+
+ return pred_depths
+
+
+
+ def validate_on_batch(self, batch, val_step):
+ images = batch['image'].to(self.device)
+ depths_gt = batch['depth'].to(self.device)
+ dataset = batch['dataset'][0]
+ mask = batch["mask"].to(self.device)
+ if 'has_valid_depth' in batch:
+ if not batch['has_valid_depth']:
+ return None, None
+
+ depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0)
+ mask = mask.squeeze().unsqueeze(0).unsqueeze(0)
+ if dataset == 'nyu':
+ pred_depths = self.crop_aware_infer(images)
+ else:
+ pred_depths = self.eval_infer(images)
+ pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0)
+
+ with amp.autocast(enabled=self.config.use_amp):
+ l_depth = self.silog_loss(
+ pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True)
+
+ metrics = compute_metrics(depths_gt, pred_depths, **self.config)
+ losses = {f"{self.silog_loss.name}": l_depth.item()}
+
+ if val_step == 1 and self.should_log:
+ depths_gt[torch.logical_not(mask)] = -99
+ self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test",
+ min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth'])
+
+ return metrics, losses
diff --git a/src/flux/annotator/zoe/zoedepth/utils/__init__.py b/src/flux/annotator/zoe/zoedepth/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/utils/__init__.py
@@ -0,0 +1,24 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
diff --git a/src/flux/annotator/zoe/zoedepth/utils/arg_utils.py b/src/flux/annotator/zoe/zoedepth/utils/arg_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a3004ec3679c0a40fd8961253733fb4343ad545
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/utils/arg_utils.py
@@ -0,0 +1,33 @@
+
+
+def infer_type(x): # hacky way to infer type from string args
+ if not isinstance(x, str):
+ return x
+
+ try:
+ x = int(x)
+ return x
+ except ValueError:
+ pass
+
+ try:
+ x = float(x)
+ return x
+ except ValueError:
+ pass
+
+ return x
+
+
+def parse_unknown(unknown_args):
+ clean = []
+ for a in unknown_args:
+ if "=" in a:
+ k, v = a.split("=")
+ clean.extend([k, v])
+ else:
+ clean.append(a)
+
+ keys = clean[::2]
+ values = clean[1::2]
+ return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)}
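+# Example (illustrative): parse_unknown(["--bs=16", "--lr", "0.001"]) returns
+# {"bs": 16, "lr": 0.001}, with types inferred by infer_type above.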
diff --git a/src/flux/annotator/zoe/zoedepth/utils/config.py b/src/flux/annotator/zoe/zoedepth/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..84996564663dadf0e720de2a68ef8c53106ed666
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/utils/config.py
@@ -0,0 +1,437 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import json
+import os
+
+from .easydict import EasyDict as edict
+from .arg_utils import infer_type
+
+import pathlib
+import platform
+
+ROOT = pathlib.Path(__file__).parent.parent.resolve()
+
+HOME_DIR = os.path.expanduser("~")
+
+COMMON_CONFIG = {
+ "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"),
+ "project": "ZoeDepth",
+ "tags": '',
+ "notes": "",
+ "gpu": None,
+ "root": ".",
+ "uid": None,
+ "print_losses": False
+}
+
+DATASETS_CONFIG = {
+ "kitti": {
+ "dataset": "kitti",
+ "min_depth": 0.001,
+ "max_depth": 80,
+ "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
+ "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
+ "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
+ "input_height": 352,
+ "input_width": 1216, # 704
+ "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
+ "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
+ "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",
+
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+
+ "do_random_rotate": True,
+ "degree": 1.0,
+ "do_kb_crop": True,
+ "garg_crop": True,
+ "eigen_crop": False,
+ "use_right": False
+ },
+ "kitti_test": {
+ "dataset": "kitti",
+ "min_depth": 0.001,
+ "max_depth": 80,
+ "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
+ "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
+ "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
+ "input_height": 352,
+ "input_width": 1216,
+ "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"),
+ "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"),
+ "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",
+
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+
+ "do_random_rotate": False,
+ "degree": 1.0,
+ "do_kb_crop": True,
+ "garg_crop": True,
+ "eigen_crop": False,
+ "use_right": False
+ },
+ "nyu": {
+ "dataset": "nyu",
+ "avoid_boundary": False,
+ "min_depth": 1e-3, # originally 0.1
+ "max_depth": 10,
+ "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"),
+ "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"),
+ "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt",
+ "input_height": 480,
+ "input_width": 640,
+ "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"),
+ "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"),
+ "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt",
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 10,
+ "min_depth_diff": -10,
+ "max_depth_diff": 10,
+
+ "do_random_rotate": True,
+ "degree": 1.0,
+ "do_kb_crop": False,
+ "garg_crop": False,
+ "eigen_crop": True
+ },
+ "ibims": {
+ "dataset": "ibims",
+ "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"),
+ "eigen_crop": True,
+ "garg_crop": False,
+ "do_kb_crop": False,
+ "min_depth_eval": 0,
+ "max_depth_eval": 10,
+ "min_depth": 1e-3,
+ "max_depth": 10
+ },
+ "sunrgbd": {
+ "dataset": "sunrgbd",
+ "sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"),
+ "eigen_crop": True,
+ "garg_crop": False,
+ "do_kb_crop": False,
+ "min_depth_eval": 0,
+ "max_depth_eval": 8,
+ "min_depth": 1e-3,
+ "max_depth": 10
+ },
+ "diml_indoor": {
+ "dataset": "diml_indoor",
+ "diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"),
+ "eigen_crop": True,
+ "garg_crop": False,
+ "do_kb_crop": False,
+ "min_depth_eval": 0,
+ "max_depth_eval": 10,
+ "min_depth": 1e-3,
+ "max_depth": 10
+ },
+ "diml_outdoor": {
+ "dataset": "diml_outdoor",
+ "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"),
+ "eigen_crop": False,
+ "garg_crop": True,
+ "do_kb_crop": False,
+ "min_depth_eval": 2,
+ "max_depth_eval": 80,
+ "min_depth": 1e-3,
+ "max_depth": 80
+ },
+ "diode_indoor": {
+ "dataset": "diode_indoor",
+ "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"),
+ "eigen_crop": True,
+ "garg_crop": False,
+ "do_kb_crop": False,
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 10,
+ "min_depth": 1e-3,
+ "max_depth": 10
+ },
+ "diode_outdoor": {
+ "dataset": "diode_outdoor",
+ "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"),
+ "eigen_crop": False,
+ "garg_crop": True,
+ "do_kb_crop": False,
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+ "min_depth": 1e-3,
+ "max_depth": 80
+ },
+ "hypersim_test": {
+ "dataset": "hypersim_test",
+ "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"),
+ "eigen_crop": True,
+ "garg_crop": False,
+ "do_kb_crop": False,
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+ "min_depth": 1e-3,
+ "max_depth": 10
+ },
+ "vkitti": {
+ "dataset": "vkitti",
+ "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"),
+ "eigen_crop": False,
+ "garg_crop": True,
+ "do_kb_crop": True,
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+ "min_depth": 1e-3,
+ "max_depth": 80
+ },
+ "vkitti2": {
+ "dataset": "vkitti2",
+ "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"),
+ "eigen_crop": False,
+ "garg_crop": True,
+ "do_kb_crop": True,
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+ "min_depth": 1e-3,
+ "max_depth": 80,
+ },
+ "ddad": {
+ "dataset": "ddad",
+ "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"),
+ "eigen_crop": False,
+ "garg_crop": True,
+ "do_kb_crop": True,
+ "min_depth_eval": 1e-3,
+ "max_depth_eval": 80,
+ "min_depth": 1e-3,
+ "max_depth": 80,
+ },
+}
+
+ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"]
+ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"]
+ALL_EVAL_DATASETS = ALL_INDOOR + ALL_OUTDOOR
+
+COMMON_TRAINING_CONFIG = {
+ "dataset": "nyu",
+ "distributed": True,
+ "workers": 16,
+ "clip_grad": 0.1,
+ "use_shared_dict": False,
+ "shared_dict": None,
+ "use_amp": False,
+
+ "aug": True,
+ "random_crop": False,
+ "random_translate": False,
+ "translate_prob": 0.2,
+ "max_translation": 100,
+
+ "validate_every": 0.25,
+ "log_images_every": 0.1,
+ "prefetch": False,
+}
+
+
+def flatten(config, except_keys=('bin_conf',)):
+ def recurse(inp):
+ if isinstance(inp, dict):
+ for key, value in inp.items():
+ if key in except_keys:
+ yield (key, value)
+ if isinstance(value, dict):
+ yield from recurse(value)
+ else:
+ yield (key, value)
+
+ return dict(list(recurse(config)))
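+
+# Illustrative example (a sketch, not part of the original module): flatten() collapses
+# nested dicts into a single level, while values whose key is listed in `except_keys`
+# are kept intact, e.g.
+#   flatten({"model": {"n_bins": 64}, "train": {"lr": 1e-4}})  ->  {"n_bins": 64, "lr": 1e-4}
+#   flatten({"bin_conf": [{"n_bins": 64}]})                    ->  {"bin_conf": [{"n_bins": 64}]}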
+
+
+def split_combined_args(kwargs):
+ """Splits the arguments that are combined with '__' into multiple arguments.
+ Combined arguments should have equal number of keys and values.
+ Keys are separated by '__' and Values are separated with ';'.
+ For example, '__n_bins__lr=256;0.001'
+
+ Args:
+ kwargs (dict): key-value pairs of arguments where key-value is optionally combined according to the above format.
+
+ Returns:
+ dict: Parsed dict with the combined arguments split into individual key-value pairs.
+ """
+ new_kwargs = dict(kwargs)
+ for key, value in kwargs.items():
+ if key.startswith("__"):
+ keys = key.split("__")[1:]
+ values = value.split(";")
+ assert len(keys) == len(
+ values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})"
+ for k, v in zip(keys, values):
+ new_kwargs[k] = v
+ return new_kwargs
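+
+# Illustrative example (a sketch, not part of the original module):
+#   split_combined_args({"__n_bins__lr": "256;0.001"})
+#   -> {"__n_bins__lr": "256;0.001", "n_bins": "256", "lr": "0.001"}
+# The combined key is kept alongside the split keys, and the values stay strings until
+# infer_type() casts them inside get_config().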
+
+
+def parse_list(config, key, dtype=int):
+ """Parse a list of values for the key if the value is a string. The values are separated by a comma.
+ Modifies the config in place.
+ """
+ if key in config:
+ if isinstance(config[key], str):
+ config[key] = list(map(dtype, config[key].split(',')))
+ assert isinstance(config[key], list) and all([isinstance(e, dtype) for e in config[key]]
+ ), f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}."
+
+
+def get_model_config(model_name, model_version=None):
+ """Find and parse the .json config file for the model.
+
+ Args:
+ model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory.
+ model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None.
+
+ Returns:
+ easydict: the config dictionary for the model.
+ """
+ config_fname = f"config_{model_name}_{model_version}.json" if model_version is not None else f"config_{model_name}.json"
+ config_file = os.path.join(ROOT, "models", model_name, config_fname)
+ if not os.path.exists(config_file):
+ return None
+
+ with open(config_file, "r") as f:
+ config = edict(json.load(f))
+
+ # handle dictionary inheritance
+ # only training config is supported for inheritance
+ if "inherit" in config.train and config.train.inherit is not None:
+ inherit_config = get_model_config(config.train["inherit"]).train
+ for key, value in inherit_config.items():
+ if key not in config.train:
+ config.train[key] = value
+ return edict(config)
+
+
+def update_model_config(config, mode, model_name, model_version=None, strict=False):
+ model_config = get_model_config(model_name, model_version)
+ if model_config is not None:
+ config = {**config, **
+ flatten({**model_config.model, **model_config[mode]})}
+ elif strict:
+ raise ValueError(f"Config file for model {model_name} not found.")
+ return config
+
+
+def check_choices(name, value, choices):
+ # return # No checks in dev branch
+ if value not in choices:
+ raise ValueError(f"{name} {value} not in supported choices {choices}")
+
+
+KEYS_TYPE_BOOL = ["use_amp", "distributed", "use_shared_dict", "same_lr", "aug", "three_phase",
+ "prefetch", "cycle_momentum"] # Casting is not necessary as their int casted values in config are 0 or 1
+
+
+def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs):
+ """Main entry point to get the config for the model.
+
+ Args:
+ model_name (str): name of the desired model.
+ mode (str, optional): "train" or "infer". Defaults to 'train'.
+ dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None.
+
+ Keyword Args: key-value pairs of arguments to overwrite the default config.
+
+ The order of precedence for overwriting the config is (Higher precedence first):
+ # 1. overwrite_kwargs
+ # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json
+ # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json
+ # 4. common_config: Default config for all models specified in COMMON_CONFIG
+
+ Returns:
+ easydict: The config dictionary for the model.
+ """
+
+
+ check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"])
+ check_choices("Mode", mode, ["train", "infer", "eval"])
+ if mode == "train":
+ check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None])
+
+ config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG})
+ config = update_model_config(config, mode, model_name)
+
+ # update with model version specific config
+ version_name = overwrite_kwargs.get("version_name", config["version_name"])
+ config = update_model_config(config, mode, model_name, version_name)
+
+ # update with config version if specified
+ config_version = overwrite_kwargs.get("config_version", None)
+ if config_version is not None:
+ print("Overwriting config with config_version", config_version)
+ config = update_model_config(config, mode, model_name, config_version)
+
+ # update with overwrite_kwargs
+ # Combined args are useful for hyperparameter search
+ overwrite_kwargs = split_combined_args(overwrite_kwargs)
+ config = {**config, **overwrite_kwargs}
+
+ # Casting to bool # TODO: Not necessary. Remove and test
+ for key in KEYS_TYPE_BOOL:
+ if key in config:
+ config[key] = bool(config[key])
+
+ # Model specific post processing of config
+ parse_list(config, "n_attractors")
+
+ # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs
+ if 'bin_conf' in config and 'n_bins' in overwrite_kwargs:
+ bin_conf = config['bin_conf'] # list of dicts
+ n_bins = overwrite_kwargs['n_bins']
+ new_bin_conf = []
+ for conf in bin_conf:
+ conf['n_bins'] = n_bins
+ new_bin_conf.append(conf)
+ config['bin_conf'] = new_bin_conf
+
+ if mode == "train":
+ orig_dataset = dataset
+ if dataset == "mix":
+ dataset = 'nyu' # Use nyu as default for mix. Dataset config is changed accordingly while loading the dataloader
+ if dataset is not None:
+ config['project'] = f"MonoDepth3-{orig_dataset}" # Set project for wandb
+
+ if dataset is not None:
+ config['dataset'] = dataset
+ config = {**DATASETS_CONFIG[dataset], **config}
+
+
+ config['model'] = model_name
+ typed_config = {k: infer_type(v) for k, v in config.items()}
+ # add hostname to config
+ config['hostname'] = platform.node()
+ return edict(typed_config)
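+
+# Illustrative usage (a sketch, not part of the original module), assuming the
+# config_zoedepth*.json files are present under models/zoedepth:
+#   conf = get_config("zoedepth", mode="infer")
+#   conf = get_config("zoedepth", mode="train", dataset="nyu", __n_bins__lr="64;0.0002")
+# Both calls return an EasyDict with the common, model-specific and (for training)
+# dataset entries merged, with overwrite_kwargs taking the highest precedence.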
+
+
+def change_dataset(config, new_dataset):
+ config.update(DATASETS_CONFIG[new_dataset])
+ return config
diff --git a/src/flux/annotator/zoe/zoedepth/utils/easydict/__init__.py b/src/flux/annotator/zoe/zoedepth/utils/easydict/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..15928179b0182c6045d98bc0a7be1c6ca45f675e
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/utils/easydict/__init__.py
@@ -0,0 +1,158 @@
+"""
+EasyDict
+Copy/pasted from https://github.com/makinacorpus/easydict
+Original author: Mathieu Leplatre
+"""
+
+class EasyDict(dict):
+ """
+ Get attributes
+
+ >>> d = EasyDict({'foo':3})
+ >>> d['foo']
+ 3
+ >>> d.foo
+ 3
+ >>> d.bar
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'bar'
+
+ Works recursively
+
+ >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
+ >>> isinstance(d.bar, dict)
+ True
+ >>> d.bar.x
+ 1
+
+ Bullet-proof
+
+ >>> EasyDict({})
+ {}
+ >>> EasyDict(d={})
+ {}
+ >>> EasyDict(None)
+ {}
+ >>> d = {'a': 1}
+ >>> EasyDict(**d)
+ {'a': 1}
+ >>> EasyDict((('a', 1), ('b', 2)))
+ {'a': 1, 'b': 2}
+
+ Set attributes
+
+ >>> d = EasyDict()
+ >>> d.foo = 3
+ >>> d.foo
+ 3
+ >>> d.bar = {'prop': 'value'}
+ >>> d.bar.prop
+ 'value'
+ >>> d
+ {'foo': 3, 'bar': {'prop': 'value'}}
+ >>> d.bar.prop = 'newer'
+ >>> d.bar.prop
+ 'newer'
+
+
+ Values extraction
+
+ >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
+ >>> isinstance(d.bar, list)
+ True
+ >>> from operator import attrgetter
+ >>> list(map(attrgetter('x'), d.bar))
+ [1, 3]
+ >>> list(map(attrgetter('y'), d.bar))
+ [2, 4]
+ >>> d = EasyDict()
+ >>> list(d.keys())
+ []
+ >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
+ >>> d.foo
+ 3
+ >>> d.bar.x
+ 1
+
+ Still like a dict though
+
+ >>> o = EasyDict({'clean':True})
+ >>> list(o.items())
+ [('clean', True)]
+
+ And like a class
+
+ >>> class Flower(EasyDict):
+ ... power = 1
+ ...
+ >>> f = Flower()
+ >>> f.power
+ 1
+ >>> f = Flower({'height': 12})
+ >>> f.height
+ 12
+ >>> f['power']
+ 1
+ >>> sorted(f.keys())
+ ['height', 'power']
+
+ update and pop items
+ >>> d = EasyDict(a=1, b='2')
+ >>> e = EasyDict(c=3.0, a=9.0)
+ >>> d.update(e)
+ >>> d.c
+ 3.0
+ >>> d['c']
+ 3.0
+ >>> d.get('c')
+ 3.0
+ >>> d.update(a=4, b=4)
+ >>> d.b
+ 4
+ >>> d.pop('a')
+ 4
+ >>> d.a
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'a'
+ """
+ def __init__(self, d=None, **kwargs):
+ if d is None:
+ d = {}
+ else:
+ d = dict(d)
+ if kwargs:
+ d.update(**kwargs)
+ for k, v in d.items():
+ setattr(self, k, v)
+ # Class attributes
+ for k in self.__class__.__dict__.keys():
+ if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'):
+ setattr(self, k, getattr(self, k))
+
+ def __setattr__(self, name, value):
+ if isinstance(value, (list, tuple)):
+ value = [self.__class__(x)
+ if isinstance(x, dict) else x for x in value]
+ elif isinstance(value, dict) and not isinstance(value, self.__class__):
+ value = self.__class__(value)
+ super(EasyDict, self).__setattr__(name, value)
+ super(EasyDict, self).__setitem__(name, value)
+
+ __setitem__ = __setattr__
+
+ def update(self, e=None, **f):
+ d = e or dict()
+ d.update(f)
+ for k in d:
+ setattr(self, k, d[k])
+
+ def pop(self, k, d=None):
+ delattr(self, k)
+ return super(EasyDict, self).pop(k, d)
+
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
\ No newline at end of file
diff --git a/src/flux/annotator/zoe/zoedepth/utils/geometry.py b/src/flux/annotator/zoe/zoedepth/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3da8c75b5a8e39b4b58a4dcd827b84d79b9115c
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/utils/geometry.py
@@ -0,0 +1,98 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+import numpy as np
+
+def get_intrinsics(H,W):
+ """
+ Intrinsics for a pinhole camera model.
+ Assume fov of 55 degrees and central principal point.
+ """
+ f = 0.5 * W / np.tan(0.5 * 55 * np.pi / 180.0)
+ cx = 0.5 * W
+ cy = 0.5 * H
+ return np.array([[f, 0, cx],
+ [0, f, cy],
+ [0, 0, 1]])
+
+def depth_to_points(depth, R=None, t=None):
+
+ K = get_intrinsics(depth.shape[1], depth.shape[2])
+ Kinv = np.linalg.inv(K)
+ if R is None:
+ R = np.eye(3)
+ if t is None:
+ t = np.zeros(3)
+
+ # M converts from your coordinate to PyTorch3D's coordinate system
+ M = np.eye(3)
+ M[0, 0] = -1.0
+ M[1, 1] = -1.0
+
+ height, width = depth.shape[1:3]
+
+ x = np.arange(width)
+ y = np.arange(height)
+ coord = np.stack(np.meshgrid(x, y), -1)
+ coord = np.concatenate((coord, np.ones_like(coord)[:, :, [0]]), -1) # z=1
+ coord = coord.astype(np.float32)
+ # coord = torch.as_tensor(coord, dtype=torch.float32, device=device)
+ coord = coord[None] # bs, h, w, 3
+
+ D = depth[:, :, :, None, None]
+ # print(D.shape, Kinv[None, None, None, ...].shape, coord[:, :, :, :, None].shape )
+ pts3D_1 = D * Kinv[None, None, None, ...] @ coord[:, :, :, :, None]
+ # pts3D_1 live in your coordinate system. Convert them to Py3D's
+ pts3D_1 = M[None, None, None, ...] @ pts3D_1
+ # from the reference to the target viewpoint
+ pts3D_2 = R[None, None, None, ...] @ pts3D_1 + t[None, None, None, :, None]
+ # pts3D_2 = pts3D_1
+ # depth_2 = pts3D_2[:, :, :, 2, :] # b,1,h,w
+ return pts3D_2[:, :, :, :3, 0][0]
+
+
+def create_triangles(h, w, mask=None):
+ """
+ Reference: https://github.com/google-research/google-research/blob/e96197de06613f1b027d20328e06d69829fa5a89/infinite_nature/render_utils.py#L68
+ Creates mesh triangle indices from a given pixel grid size.
+ This function is not and need not be differentiable as triangle indices are
+ fixed.
+ Args:
+ h: (int) denoting the height of the image.
+ w: (int) denoting the width of the image.
+ Returns:
+ triangles: 2D numpy array of indices (int) with shape (2(W-1)(H-1) x 3)
+ """
+ x, y = np.meshgrid(range(w - 1), range(h - 1))
+ tl = y * w + x
+ tr = y * w + x + 1
+ bl = (y + 1) * w + x
+ br = (y + 1) * w + x + 1
+ triangles = np.array([tl, bl, tr, br, tr, bl])
+ triangles = np.transpose(triangles, (1, 2, 0)).reshape(
+ ((w - 1) * (h - 1) * 2, 3))
+ if mask is not None:
+ mask = mask.reshape(-1)
+ triangles = triangles[mask[triangles].all(1)]
+ return triangles
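+
+# Illustrative usage (a sketch, not part of the original file): given a depth map of
+# shape (1, H, W), the two helpers above can be combined into a simple mesh, e.g.
+#   pts = depth_to_points(depth)                          # (H, W, 3) points in 3D
+#   tri = create_triangles(depth.shape[1], depth.shape[2])
+# where `depth` is assumed to be a numpy array of metric depths.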
diff --git a/src/flux/annotator/zoe/zoedepth/utils/misc.py b/src/flux/annotator/zoe/zoedepth/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bbe403d3669829eecdf658458c76aa5e87e2b33
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/utils/misc.py
@@ -0,0 +1,368 @@
+# MIT License
+
+# Copyright (c) 2022 Intelligent Systems Lab Org
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# File author: Shariq Farooq Bhat
+
+"""Miscellaneous utility functions."""
+
+from scipy import ndimage
+
+import base64
+import math
+import re
+from io import BytesIO
+
+import matplotlib
+import matplotlib.cm
+import numpy as np
+import requests
+import torch
+import torch.distributed as dist
+import torch.nn
+import torch.nn as nn
+import torch.utils.data.distributed
+from PIL import Image
+from torchvision.transforms import ToTensor
+
+
+class RunningAverage:
+ def __init__(self):
+ self.avg = 0
+ self.count = 0
+
+ def append(self, value):
+ self.avg = (value + self.count * self.avg) / (self.count + 1)
+ self.count += 1
+
+ def get_value(self):
+ return self.avg
+
+
+def denormalize(x):
+ """Reverses the imagenet normalization applied to the input.
+
+ Args:
+ x (torch.Tensor - shape(N,3,H,W)): input tensor
+
+ Returns:
+ torch.Tensor - shape(N,3,H,W): Denormalized input
+ """
+ mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
+ std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
+ return x * std + mean
+
+
+class RunningAverageDict:
+ """A dictionary of running averages."""
+ def __init__(self):
+ self._dict = None
+
+ def update(self, new_dict):
+ if new_dict is None:
+ return
+
+ if self._dict is None:
+ self._dict = dict()
+ for key, value in new_dict.items():
+ self._dict[key] = RunningAverage()
+
+ for key, value in new_dict.items():
+ self._dict[key].append(value)
+
+ def get_value(self):
+ if self._dict is None:
+ return None
+ return {key: value.get_value() for key, value in self._dict.items()}
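+
+# Illustrative example (a sketch, not part of the original module): RunningAverageDict
+# keeps a per-key running mean, e.g.
+#   metrics = RunningAverageDict()
+#   metrics.update({"abs_rel": 0.10, "rmse": 0.40})
+#   metrics.update({"abs_rel": 0.20, "rmse": 0.60})
+#   metrics.get_value()  # -> {"abs_rel": 0.15, "rmse": 0.50}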
+
+
+def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
+ """Converts a depth map to a color image.
+
+ Args:
+ value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
+ vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
+ vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
+ cmap (str, optional): matplotlib colormap to use. Defaults to 'gray_r'.
+ invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
+ invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
+ background_color (tuple[int], optional): 4-tuple RGBA color to give to invalid pixels. Defaults to (128, 128, 128, 255).
+ gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
+ value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.
+
+ Returns:
+ numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
+ """
+ if isinstance(value, torch.Tensor):
+ value = value.detach().cpu().numpy()
+
+ value = value.squeeze()
+ if invalid_mask is None:
+ invalid_mask = value == invalid_val
+ mask = np.logical_not(invalid_mask)
+
+ # normalize
+ vmin = np.percentile(value[mask],2) if vmin is None else vmin
+ vmax = np.percentile(value[mask],85) if vmax is None else vmax
+ if vmin != vmax:
+ value = (value - vmin) / (vmax - vmin) # vmin..vmax
+ else:
+ # Avoid 0-division
+ value = value * 0.
+
+ # squeeze last dim if it exists
+ # grey out the invalid values
+
+ value[invalid_mask] = np.nan
+ cmapper = matplotlib.cm.get_cmap(cmap)
+ if value_transform:
+ value = value_transform(value)
+ # value = value / value.max()
+ value = cmapper(value, bytes=True) # (nxmx4)
+
+ # img = value[:, :, :]
+ img = value[...]
+ img[invalid_mask] = background_color
+
+ # return img.transpose((2, 0, 1))
+ if gamma_corrected:
+ # gamma correction
+ img = img / 255
+ img = np.power(img, 2.2)
+ img = img * 255
+ img = img.astype(np.uint8)
+ return img
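+
+# Illustrative usage (a sketch, not part of the original module; `depth` is a
+# hypothetical (H, W) tensor of metric depths):
+#   colored = colorize(depth, vmin=1e-3, vmax=10, cmap="magma_r")   # (H, W, 4) uint8
+#   Image.fromarray(colored).save("depth_colored.png")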
+
+
+def count_parameters(model, include_all=False):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad or include_all)
+
+
+def compute_errors(gt, pred):
+ """Compute metrics for 'pred' compared to 'gt'
+
+ Args:
+ gt (numpy.ndarray): Ground truth values
+ pred (numpy.ndarray): Predicted values
+
+ gt.shape should be equal to pred.shape
+
+ Returns:
+ dict: Dictionary containing the following metrics:
+ 'a1': Delta1 accuracy: Fraction of pixels that are within a scale factor of 1.25
+ 'a2': Delta2 accuracy: Fraction of pixels that are within a scale factor of 1.25^2
+ 'a3': Delta3 accuracy: Fraction of pixels that are within a scale factor of 1.25^3
+ 'abs_rel': Absolute relative error
+ 'rmse': Root mean squared error
+ 'log_10': Absolute log10 error
+ 'sq_rel': Squared relative error
+ 'rmse_log': Root mean squared error on the log scale
+ 'silog': Scale invariant log error
+ """
+ thresh = np.maximum((gt / pred), (pred / gt))
+ a1 = (thresh < 1.25).mean()
+ a2 = (thresh < 1.25 ** 2).mean()
+ a3 = (thresh < 1.25 ** 3).mean()
+
+ abs_rel = np.mean(np.abs(gt - pred) / gt)
+ sq_rel = np.mean(((gt - pred) ** 2) / gt)
+
+ rmse = (gt - pred) ** 2
+ rmse = np.sqrt(rmse.mean())
+
+ rmse_log = (np.log(gt) - np.log(pred)) ** 2
+ rmse_log = np.sqrt(rmse_log.mean())
+
+ err = np.log(pred) - np.log(gt)
+ silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
+
+ log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean()
+ return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log,
+ silog=silog, sq_rel=sq_rel)
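+
+# Worked example (illustrative, not part of the original module): for
+#   gt = np.array([1.0, 2.0, 4.0]) and pred = np.array([1.0, 2.0, 5.0])
+# the ratio thresholds are [1.0, 1.0, 1.25], so a1 = 2/3 (the strict '<' excludes 1.25),
+# and abs_rel = mean(|gt - pred| / gt) = (0 + 0 + 0.25) / 3 ≈ 0.083.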
+
+
+def compute_metrics(gt, pred, interpolate=True, garg_crop=False, eigen_crop=True, dataset='nyu', min_depth_eval=0.1, max_depth_eval=10, **kwargs):
+ """Compute metrics of predicted depth maps. Applies cropping and masking as necessary or specified via arguments. Refer to compute_errors for more details on metrics.
+ """
+ if 'config' in kwargs:
+ config = kwargs['config']
+ garg_crop = config.garg_crop
+ eigen_crop = config.eigen_crop
+ min_depth_eval = config.min_depth_eval
+ max_depth_eval = config.max_depth_eval
+
+ if gt.shape[-2:] != pred.shape[-2:] and interpolate:
+ pred = nn.functional.interpolate(
+ pred, gt.shape[-2:], mode='bilinear', align_corners=True)
+
+ pred = pred.squeeze().cpu().numpy()
+ pred[pred < min_depth_eval] = min_depth_eval
+ pred[pred > max_depth_eval] = max_depth_eval
+ pred[np.isinf(pred)] = max_depth_eval
+ pred[np.isnan(pred)] = min_depth_eval
+
+ gt_depth = gt.squeeze().cpu().numpy()
+ valid_mask = np.logical_and(
+ gt_depth > min_depth_eval, gt_depth < max_depth_eval)
+
+ if garg_crop or eigen_crop:
+ gt_height, gt_width = gt_depth.shape
+ eval_mask = np.zeros(valid_mask.shape)
+
+ if garg_crop:
+ eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
+ int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
+
+ elif eigen_crop:
+ # print("-"*10, " EIGEN CROP ", "-"*10)
+ if dataset == 'kitti':
+ eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height),
+ int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
+ else:
+ # assert gt_depth.shape == (480, 640), "Error: Eigen crop is currently only valid for (480, 640) images"
+ eval_mask[45:471, 41:601] = 1
+ else:
+ eval_mask = np.ones(valid_mask.shape)
+ valid_mask = np.logical_and(valid_mask, eval_mask)
+ return compute_errors(gt_depth[valid_mask], pred[valid_mask])
+
+
+#################################### Model utils ################################################
+
+
+def parallelize(config, model, find_unused_parameters=True):
+
+ if config.gpu is not None:
+ torch.cuda.set_device(config.gpu)
+ model = model.cuda(config.gpu)
+
+ config.multigpu = False
+ if config.distributed:
+ # Use DDP
+ config.multigpu = True
+ config.rank = config.rank * config.ngpus_per_node + config.gpu
+ dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url,
+ world_size=config.world_size, rank=config.rank)
+ config.batch_size = int(config.batch_size / config.ngpus_per_node)
+ # config.batch_size = 8
+ config.workers = int(
+ (config.num_workers + config.ngpus_per_node - 1) / config.ngpus_per_node)
+ print("Device", config.gpu, "Rank", config.rank, "batch size",
+ config.batch_size, "Workers", config.workers)
+ torch.cuda.set_device(config.gpu)
+ model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = model.cuda(config.gpu)
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu], output_device=config.gpu,
+ find_unused_parameters=find_unused_parameters)
+
+ elif config.gpu is None:
+ # Use DP
+ config.multigpu = True
+ model = model.cuda()
+ model = torch.nn.DataParallel(model)
+
+ return model
+
+
+#################################################################################################
+
+
+#####################################################################################################
+
+
+class colors:
+ '''Colors class:
+ Reset all colors with colors.reset
+ Two subclasses fg for foreground and bg for background.
+ Use as colors.subclass.colorname.
+ i.e. colors.fg.red or colors.bg.green
+ Also, the generic bold, disable, underline, reverse, strikethrough,
+ and invisible work with the main class
+ i.e. colors.bold
+ '''
+ reset = '\033[0m'
+ bold = '\033[01m'
+ disable = '\033[02m'
+ underline = '\033[04m'
+ reverse = '\033[07m'
+ strikethrough = '\033[09m'
+ invisible = '\033[08m'
+
+ class fg:
+ black = '\033[30m'
+ red = '\033[31m'
+ green = '\033[32m'
+ orange = '\033[33m'
+ blue = '\033[34m'
+ purple = '\033[35m'
+ cyan = '\033[36m'
+ lightgrey = '\033[37m'
+ darkgrey = '\033[90m'
+ lightred = '\033[91m'
+ lightgreen = '\033[92m'
+ yellow = '\033[93m'
+ lightblue = '\033[94m'
+ pink = '\033[95m'
+ lightcyan = '\033[96m'
+
+ class bg:
+ black = '\033[40m'
+ red = '\033[41m'
+ green = '\033[42m'
+ orange = '\033[43m'
+ blue = '\033[44m'
+ purple = '\033[45m'
+ cyan = '\033[46m'
+ lightgrey = '\033[47m'
+
+
+def printc(text, color):
+ print(f"{color}{text}{colors.reset}")
+
+############################################
+
+def get_image_from_url(url):
+ response = requests.get(url)
+ img = Image.open(BytesIO(response.content)).convert("RGB")
+ return img
+
+def url_to_torch(url, size=(384, 384)):
+ img = get_image_from_url(url)
+ img = img.resize(size, Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
+ img = torch.from_numpy(np.asarray(img)).float()
+ img = img.permute(2, 0, 1)
+ img.div_(255)
+ return img
+
+def pil_to_batched_tensor(img):
+ return ToTensor()(img).unsqueeze(0)
+
+def save_raw_16bit(depth, fpath="raw.png"):
+ if isinstance(depth, torch.Tensor):
+ depth = depth.squeeze().cpu().numpy()
+
+ assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array"
+ assert depth.ndim == 2, "Depth must be 2D"
+ depth = depth * 256 # scale for 16-bit png
+ depth = depth.astype(np.uint16)
+ depth = Image.fromarray(depth)
+ depth.save(fpath)
+ print("Saved raw depth to", fpath)
\ No newline at end of file
diff --git a/src/flux/api.py b/src/flux/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08202adb35d2ffae320bb9b47f567e538837836
--- /dev/null
+++ b/src/flux/api.py
@@ -0,0 +1,194 @@
+import io
+import os
+import time
+from pathlib import Path
+
+import requests
+from PIL import Image
+
+API_ENDPOINT = "https://api.bfl.ml"
+
+
+class ApiException(Exception):
+ def __init__(self, status_code: int, detail: str | list[dict] | None = None):
+ super().__init__()
+ self.detail = detail
+ self.status_code = status_code
+
+ def __str__(self) -> str:
+ return self.__repr__()
+
+ def __repr__(self) -> str:
+ if self.detail is None:
+ message = None
+ elif isinstance(self.detail, str):
+ message = self.detail
+ else:
+ message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
+ return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
+
+
+class ImageRequest:
+ def __init__(
+ self,
+ prompt: str,
+ width: int = 1024,
+ height: int = 1024,
+ name: str = "flux.1-pro",
+ num_steps: int = 50,
+ prompt_upsampling: bool = False,
+ seed: int | None = None,
+ validate: bool = True,
+ launch: bool = True,
+ api_key: str | None = None,
+ ):
+ """
+ Manages an image generation request to the API.
+
+ Args:
+ prompt: Prompt to sample
+ width: Width of the image in pixel
+ height: Height of the image in pixel
+ name: Name of the model
+ num_steps: Number of network evaluations
+ prompt_upsampling: Use prompt upsampling
+ seed: Fix the generation seed
+ validate: Run input validation
+ launch: Directly launches request
+ api_key: Your API key if not provided by the environment
+
+ Raises:
+ ValueError: For invalid input
+ ApiException: For errors raised from the API
+ """
+ if validate:
+ if name not in ["flux.1-pro"]:
+ raise ValueError(f"Invalid model {name}")
+ elif width % 32 != 0:
+ raise ValueError(f"width must be divisible by 32, got {width}")
+ elif not (256 <= width <= 1440):
+ raise ValueError(f"width must be between 256 and 1440, got {width}")
+ elif height % 32 != 0:
+ raise ValueError(f"height must be divisible by 32, got {height}")
+ elif not (256 <= height <= 1440):
+ raise ValueError(f"height must be between 256 and 1440, got {height}")
+ elif not (1 <= num_steps <= 50):
+ raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
+
+ self.request_json = {
+ "prompt": prompt,
+ "width": width,
+ "height": height,
+ "variant": name,
+ "steps": num_steps,
+ "prompt_upsampling": prompt_upsampling,
+ }
+ if seed is not None:
+ self.request_json["seed"] = seed
+
+ self.request_id: str | None = None
+ self.result: dict | None = None
+ self._image_bytes: bytes | None = None
+ self._url: str | None = None
+ if api_key is None:
+ self.api_key = os.environ.get("BFL_API_KEY")
+ else:
+ self.api_key = api_key
+
+ if launch:
+ self.request()
+
+ def request(self):
+ """
+ Request to generate the image.
+ """
+ if self.request_id is not None:
+ return
+ response = requests.post(
+ f"{API_ENDPOINT}/v1/image",
+ headers={
+ "accept": "application/json",
+ "x-key": self.api_key,
+ "Content-Type": "application/json",
+ },
+ json=self.request_json,
+ )
+ result = response.json()
+ if response.status_code != 200:
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
+ self.request_id = result["id"]
+
+ def retrieve(self) -> dict:
+ """
+ Wait for the generation to finish and retrieve response.
+ """
+ if self.request_id is None:
+ self.request()
+ while self.result is None:
+ response = requests.get(
+ f"{API_ENDPOINT}/v1/get_result",
+ headers={
+ "accept": "application/json",
+ "x-key": self.api_key,
+ },
+ params={
+ "id": self.request_id,
+ },
+ )
+ result = response.json()
+ if "status" not in result:
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
+ elif result["status"] == "Ready":
+ self.result = result["result"]
+ elif result["status"] == "Pending":
+ time.sleep(0.5)
+ else:
+ raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
+ return self.result
+
+ @property
+ def bytes(self) -> bytes:
+ """
+ Generated image as bytes.
+ """
+ if self._image_bytes is None:
+ response = requests.get(self.url)
+ if response.status_code == 200:
+ self._image_bytes = response.content
+ else:
+ raise ApiException(status_code=response.status_code)
+ return self._image_bytes
+
+ @property
+ def url(self) -> str:
+ """
+ Public url to retrieve the image from
+ """
+ if self._url is None:
+ result = self.retrieve()
+ self._url = result["sample"]
+ return self._url
+
+ @property
+ def image(self) -> Image.Image:
+ """
+ Load the image as a PIL Image
+ """
+ return Image.open(io.BytesIO(self.bytes))
+
+ def save(self, path: str):
+ """
+ Save the generated image to a local path
+ """
+ suffix = Path(self.url).suffix
+ if not path.endswith(suffix):
+ path = path + suffix
+ Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
+ with open(path, "wb") as file:
+ file.write(self.bytes)
+
+
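+# Illustrative usage (a sketch, not part of the original file), assuming BFL_API_KEY is
+# set in the environment:
+#   request = ImageRequest("a photo of a forest with mist", num_steps=28)
+#   request.save("outputs/api_sample.jpg")
+# The constructor launches the request, and save() polls until the result is ready
+# before downloading the image (the URL's file extension is appended if missing).
+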
+if __name__ == "__main__":
+ from fire import Fire
+
+ Fire(ImageRequest)
diff --git a/src/flux/cli.py b/src/flux/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3624bc6c387f359162e68f46995b12ce341970a
--- /dev/null
+++ b/src/flux/cli.py
@@ -0,0 +1,254 @@
+import os
+import re
+import time
+from dataclasses import dataclass
+from glob import iglob
+
+import torch
+from einops import rearrange
+from fire import Fire
+from PIL import ExifTags, Image
+
+from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
+from flux.util import (configs, embed_watermark, load_ae, load_clip,
+ load_flow_model, load_t5)
+from transformers import pipeline
+
+NSFW_THRESHOLD = 0.85
+
+@dataclass
+class SamplingOptions:
+ prompt: str
+ width: int
+ height: int
+ num_steps: int
+ guidance: float
+ seed: int | None
+
+
+def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
+ user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
+ usage = (
+ "Usage: Either write your prompt directly, leave this field empty "
+ "to repeat the prompt or write a command starting with a slash:\n"
+ "- '/w ' will set the width of the generated image\n"
+ "- '/h ' will set the height of the generated image\n"
+ "- '/s ' sets the next seed\n"
+ "- '/g ' sets the guidance (flux-dev only)\n"
+ "- '/n ' sets the number of steps\n"
+ "- '/q' to quit"
+ )
+
+ while (prompt := input(user_question)).startswith("/"):
+ if prompt.startswith("/w"):
+ if prompt.count(" ") != 1:
+ print(f"Got invalid command '{prompt}'\n{usage}")
+ continue
+ _, width = prompt.split()
+ options.width = 16 * (int(width) // 16)
+ print(
+ f"Setting resolution to {options.width} x {options.height} "
+ f"({options.height *options.width/1e6:.2f}MP)"
+ )
+ elif prompt.startswith("/h"):
+ if prompt.count(" ") != 1:
+ print(f"Got invalid command '{prompt}'\n{usage}")
+ continue
+ _, height = prompt.split()
+ options.height = 16 * (int(height) // 16)
+ print(
+ f"Setting resolution to {options.width} x {options.height} "
+ f"({options.height *options.width/1e6:.2f}MP)"
+ )
+ elif prompt.startswith("/g"):
+ if prompt.count(" ") != 1:
+ print(f"Got invalid command '{prompt}'\n{usage}")
+ continue
+ _, guidance = prompt.split()
+ options.guidance = float(guidance)
+ print(f"Setting guidance to {options.guidance}")
+ elif prompt.startswith("/s"):
+ if prompt.count(" ") != 1:
+ print(f"Got invalid command '{prompt}'\n{usage}")
+ continue
+ _, seed = prompt.split()
+ options.seed = int(seed)
+ print(f"Setting seed to {options.seed}")
+ elif prompt.startswith("/n"):
+ if prompt.count(" ") != 1:
+ print(f"Got invalid command '{prompt}'\n{usage}")
+ continue
+ _, steps = prompt.split()
+ options.num_steps = int(steps)
+ print(f"Setting seed to {options.num_steps}")
+ elif prompt.startswith("/q"):
+ print("Quitting")
+ return None
+ else:
+ if not prompt.startswith("/h"):
+ print(f"Got invalid command '{prompt}'\n{usage}")
+ print(usage)
+ if prompt != "":
+ options.prompt = prompt
+ return options
+
+
+@torch.inference_mode()
+def main(
+ name: str = "flux-schnell",
+ width: int = 1360,
+ height: int = 768,
+ seed: int | None = None,
+ prompt: str = (
+ "a photo of a forest with mist swirling around the tree trunks. The word "
+ '"FLUX" is painted over it in big, red brush strokes with visible texture'
+ ),
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ num_steps: int | None = None,
+ loop: bool = False,
+ guidance: float = 3.5,
+ offload: bool = False,
+ output_dir: str = "output",
+ add_sampling_metadata: bool = True,
+):
+ """
+ Sample the flux model. Either interactively (set `--loop`) or run for a
+ single image.
+
+ Args:
+ name: Name of the model to load
+ height: height of the sample in pixels (should be a multiple of 16)
+ width: width of the sample in pixels (should be a multiple of 16)
+ seed: Set a seed for sampling
+ output_name: where to save the output image, `{idx}` will be replaced
+ by the index of the sample
+ prompt: Prompt used for sampling
+ device: Pytorch device
+ num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
+ loop: start an interactive session and sample multiple times
+ guidance: guidance value used for guidance distillation
+ add_sampling_metadata: Add the prompt to the image Exif metadata
+ """
+ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
+
+ if name not in configs:
+ available = ", ".join(configs.keys())
+ raise ValueError(f"Got unknown model name: {name}, chose from {available}")
+
+ torch_device = torch.device(device)
+ if num_steps is None:
+ num_steps = 4 if name == "flux-schnell" else 50
+
+ # allow for packing and conversion to latent space
+ height = 16 * (height // 16)
+ width = 16 * (width // 16)
+
+ output_name = os.path.join(output_dir, "img_{idx}.jpg")
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ idx = 0
+ else:
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
+ if len(fns) > 0:
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
+ else:
+ idx = 0
+
+ # init all components
+ t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
+ clip = load_clip(torch_device)
+ model = load_flow_model(name, device="cpu" if offload else torch_device)
+ ae = load_ae(name, device="cpu" if offload else torch_device)
+
+ rng = torch.Generator(device="cpu")
+ opts = SamplingOptions(
+ prompt=prompt,
+ width=width,
+ height=height,
+ num_steps=num_steps,
+ guidance=guidance,
+ seed=seed,
+ )
+
+ if loop:
+ opts = parse_prompt(opts)
+
+ while opts is not None:
+ if opts.seed is None:
+ opts.seed = rng.seed()
+ print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
+ t0 = time.perf_counter()
+
+ # prepare input
+ x = get_noise(
+ 1,
+ opts.height,
+ opts.width,
+ device=torch_device,
+ dtype=torch.bfloat16,
+ seed=opts.seed,
+ )
+ opts.seed = None
+ if offload:
+ ae = ae.cpu()
+ torch.cuda.empty_cache()
+ t5, clip = t5.to(torch_device), clip.to(torch_device)
+ inp = prepare(t5, clip, x, prompt=opts.prompt)
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
+
+ # offload TEs to CPU, load model to gpu
+ if offload:
+ t5, clip = t5.cpu(), clip.cpu()
+ torch.cuda.empty_cache()
+ model = model.to(torch_device)
+
+ # denoise initial noise
+ x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
+
+ # offload model, load autoencoder to gpu
+ if offload:
+ model.cpu()
+ torch.cuda.empty_cache()
+ ae.decoder.to(x.device)
+
+ # decode latents to pixel space
+ x = unpack(x.float(), opts.height, opts.width)
+ with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
+ x = ae.decode(x)
+ t1 = time.perf_counter()
+
+ fn = output_name.format(idx=idx)
+ print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
+ # bring into PIL format and save
+ x = x.clamp(-1, 1)
+ x = embed_watermark(x.float())
+ x = rearrange(x[0], "c h w -> h w c")
+
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+ nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
+
+ if nsfw_score < NSFW_THRESHOLD:
+ exif_data = Image.Exif()
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
+ exif_data[ExifTags.Base.Model] = name
+ if add_sampling_metadata:
+ exif_data[ExifTags.Base.ImageDescription] = prompt
+ img.save(fn, exif=exif_data, quality=95, subsampling=0)
+ idx += 1
+ else:
+ print("Your generated image may contain NSFW content.")
+
+ if loop:
+ print("-" * 80)
+ opts = parse_prompt(opts)
+ else:
+ opts = None
+
+
+def app():
+ Fire(main)
+
+
+if __name__ == "__main__":
+ app()
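+
+# Illustrative invocation (a sketch, not part of the original file), assuming the flux
+# package and model weights are available locally:
+#   python -m flux.cli --name flux-schnell --width 1360 --height 768 \
+#       --prompt "a forest in the mist" --output_dir output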
diff --git a/src/flux/controlnet.py b/src/flux/controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5a04cc0234b2b726a550cbe62d027943f6bbcbb
--- /dev/null
+++ b/src/flux/controlnet.py
@@ -0,0 +1,222 @@
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor, nn
+from einops import rearrange
+
+from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
+ MLPEmbedder, SingleStreamBlock,
+ timestep_embedding)
+
+
+@dataclass
+class FluxParams:
+ in_channels: int
+ vec_in_dim: int
+ context_in_dim: int
+ hidden_size: int
+ mlp_ratio: float
+ num_heads: int
+ depth: int
+ depth_single_blocks: int
+ axes_dim: list[int]
+ theta: int
+ qkv_bias: bool
+ guidance_embed: bool
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
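+
+# Zero-initialising these projection layers means the ControlNet branch contributes
+# nothing at the start of training, the standard ControlNet trick for preserving the
+# behaviour of the frozen base model until the new branch has learned something useful.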
+
+
+class ControlNetFlux(nn.Module):
+ """
+ Transformer model for flow matching on sequences.
+ """
+ _supports_gradient_checkpointing = True
+
+ def __init__(self, params: FluxParams, controlnet_depth=2):
+ super().__init__()
+
+ self.params = params
+ self.in_channels = params.in_channels
+ self.out_channels = self.in_channels
+ if params.hidden_size % params.num_heads != 0:
+ raise ValueError(
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+ )
+ pe_dim = params.hidden_size // params.num_heads
+ if sum(params.axes_dim) != pe_dim:
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+ self.hidden_size = params.hidden_size
+ self.num_heads = params.num_heads
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+ self.guidance_in = (
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
+ )
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+ self.double_blocks = nn.ModuleList(
+ [
+ DoubleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ qkv_bias=params.qkv_bias,
+ )
+ for _ in range(controlnet_depth)
+ ]
+ )
+
+ # add ControlNet blocks
+ self.controlnet_blocks = nn.ModuleList([])
+ for _ in range(controlnet_depth):
+ controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_blocks.append(controlnet_block)
+ self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+ self.gradient_checkpointing = False
+ self.input_hint_block = nn.Sequential(
+ nn.Conv2d(3, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1, stride=2),
+ nn.SiLU(),
+ zero_module(nn.Conv2d(16, 16, 3, padding=1))
+ )
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+
+ @property
+ def attn_processors(self):
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def forward(
+ self,
+ img: Tensor,
+ img_ids: Tensor,
+ controlnet_cond: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ timesteps: Tensor,
+ y: Tensor,
+ guidance: Tensor | None = None,
+ ) -> Tensor:
+ if img.ndim != 3 or txt.ndim != 3:
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+ # running on sequences img
+ img = self.img_in(img)
+ controlnet_cond = self.input_hint_block(controlnet_cond)
+ controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+ controlnet_cond = self.pos_embed_input(controlnet_cond)
+ img = img + controlnet_cond
+ vec = self.time_in(timestep_embedding(timesteps, 256))
+ if self.params.guidance_embed:
+ if guidance is None:
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+ vec = vec + self.vector_in(y)
+ txt = self.txt_in(txt)
+
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ pe = self.pe_embedder(ids)
+
+ block_res_samples = ()
+
+ for block in self.double_blocks:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ # Checkpoint the double block to trade compute for memory during training; the
+ # block returns the updated (img, txt) pair, which is reassigned so the residual
+ # collection below sees the block output.
+ img, txt = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ img,
+ txt,
+ vec,
+ pe,
+ use_reentrant=False,
+ )
+ else:
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
+ block_res_samples = block_res_samples + (img,)
+
+ controlnet_block_res_samples = ()
+ for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+ block_res_sample = controlnet_block(block_res_sample)
+ controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+ return controlnet_block_res_samples
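+
+# The tuple of per-block residuals returned here is what Flux.forward() receives as
+# `block_controlnet_hidden_states` and adds onto the image stream of its own double
+# blocks (see src/flux/model.py below).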
diff --git a/src/flux/math.py b/src/flux/math.py
new file mode 100644
index 0000000000000000000000000000000000000000..0156bb6a205dec340e029f0c87cf70ae8709ae12
--- /dev/null
+++ b/src/flux/math.py
@@ -0,0 +1,30 @@
+import torch
+from einops import rearrange
+from torch import Tensor
+
+
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+ q, k = apply_rope(q, k, pe)
+
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ x = rearrange(x, "B H L D -> B L (H D)")
+
+ return x
+
+
+def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
+ assert dim % 2 == 0
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+ out = torch.einsum("...n,d->...nd", pos, omega)
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+ return out.float()
+
+
+def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
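+
+# Shape sketch (illustrative, not part of the original file): for positions of shape
+# (B, N) and an axis dimension `dim`, rope() returns rotation matrices of shape
+# (B, N, dim // 2, 2, 2); apply_rope() then rotates each consecutive pair of query/key
+# channels by the corresponding 2x2 matrix before scaled-dot-product attention.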
diff --git a/src/flux/model.py b/src/flux/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e811491dd9861d92e00c5859ba7420ab20394e79
--- /dev/null
+++ b/src/flux/model.py
@@ -0,0 +1,329 @@
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor, nn
+from einops import rearrange
+
+from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
+ MLPEmbedder, SingleStreamBlock,
+ timestep_embedding)
+
+
+@dataclass
+class FluxParams:
+ in_channels: int
+ vec_in_dim: int
+ context_in_dim: int
+ hidden_size: int
+ mlp_ratio: float
+ num_heads: int
+ depth: int
+ depth_single_blocks: int
+ axes_dim: list[int]
+ theta: int
+ qkv_bias: bool
+ guidance_embed: bool
+
+
+class Flux(nn.Module):
+ """
+ Transformer model for flow matching on sequences.
+ """
+ _supports_gradient_checkpointing = True
+
+ def __init__(self, params: FluxParams):
+ super().__init__()
+
+ self.params = params
+ self.in_channels = params.in_channels
+ self.out_channels = self.in_channels
+ if params.hidden_size % params.num_heads != 0:
+ raise ValueError(
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+ )
+ pe_dim = params.hidden_size // params.num_heads
+ if sum(params.axes_dim) != pe_dim:
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+ self.hidden_size = params.hidden_size
+ self.num_heads = params.num_heads
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+ self.guidance_in = (
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
+ )
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+ self.double_blocks = nn.ModuleList(
+ [
+ DoubleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ qkv_bias=params.qkv_bias,
+ )
+ for _ in range(params.depth)
+ ]
+ )
+
+ self.single_blocks = nn.ModuleList(
+ [
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
+ for _ in range(params.depth_single_blocks)
+ ]
+ )
+
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+ self.gradient_checkpointing = True # False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ @property
+ def attn_processors(self):
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def forward(
+ self,
+ img: Tensor,
+ img_ids: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ timesteps: Tensor,
+ y: Tensor, # clip
+ block_controlnet_hidden_states=None,
+ guidance: Tensor | None = None,
+ image_proj: Tensor | None = None,
+ ip_scale: Tensor | float = 1.0,
+ use_share_weight_referencenet=False,
+ single_img_ids: Tensor | None = None,
+ single_block_refnet=False,
+ double_block_refnet=False,
+ ) -> Tensor:
+ if single_block_refnet or double_block_refnet:
+ assert use_share_weight_referencenet
+ if img.ndim != 3 or txt.ndim != 3:
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+ # running on sequences img
+ img = self.img_in(img)
+ vec = self.time_in(timestep_embedding(timesteps, 256))
+ # print("vec shape 1:", vec.shape)
+ # print("y shape 1:", y.shape)
+ if self.params.guidance_embed:
+ if guidance is None:
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+ # print("vec shape 1.5:", vec.shape)
+ vec = vec + self.vector_in(y)
+ # print("vec shape 2:", vec.shape)
+ txt = self.txt_in(txt)
+
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ pe = self.pe_embedder(ids)
+ if use_share_weight_referencenet:
+ # print("In img shape:", img.shape)
+ img_latent_length = img.shape[1]
+ single_ids = torch.cat((txt_ids, single_img_ids), dim=1)
+ single_pe = self.pe_embedder(single_ids)
+ if double_block_refnet and (not single_block_refnet):
+ double_block_pe = pe
+ double_block_img = img
+ single_block_pe = single_pe
+
+ elif single_block_refnet and (not double_block_refnet):
+ double_block_pe = single_pe
+ double_block_img = img[:, img_latent_length//2:, :]
+ single_block_pe = pe
+ ref_img_latent = img[:, :img_latent_length//2, :]
+ else:
+ print("RefNet only support either double blocks or single blocks. If you want to turn on all blocks for RefNet, please use Spatial Condition.")
+ raise NotImplementedError
+
+ if block_controlnet_hidden_states is not None:
+ controlnet_depth = len(block_controlnet_hidden_states)
+ for index_block, block in enumerate(self.double_blocks):
+ if self.training and self.gradient_checkpointing:
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+ if not use_share_weight_referencenet:
+ img, txt = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ img,
+ txt,
+ vec,
+ pe,
+ image_proj,
+ ip_scale,
+ use_reentrant=True,
+ )
+ else:
+ double_block_img, txt = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ double_block_img,
+ txt,
+ vec,
+ double_block_pe,
+ image_proj,
+ ip_scale,
+ use_reentrant=True,
+ )
+ else:
+ if not use_share_weight_referencenet:
+ img, txt = block(
+ img=img,
+ txt=txt,
+ vec=vec,
+ pe=pe,
+ image_proj=image_proj,
+ ip_scale=ip_scale,
+ )
+ else:
+ double_block_img, txt = block(
+ img=double_block_img,
+ txt=txt,
+ vec=vec,
+ pe=double_block_pe,
+ image_proj=image_proj,
+ ip_scale=ip_scale,
+ )
+ # controlnet residual
+ if block_controlnet_hidden_states is not None:
+ if not use_share_weight_referencenet:
+ img = img + block_controlnet_hidden_states[index_block % 2]
+ else:
+ double_block_img = double_block_img + block_controlnet_hidden_states[index_block % 2]
+
+ if use_share_weight_referencenet:
+ mid_img = double_block_img
+ # print("After double blocks img shape:",mid_img.shape)
+ if double_block_refnet and (not single_block_refnet):
+ single_block_img = mid_img[:, img_latent_length//2:, :]
+ elif single_block_refnet and (not double_block_refnet):
+ single_block_img = torch.cat([ref_img_latent, mid_img], dim=1)
+ single_block_img = torch.cat((txt, single_block_img), 1)
+ else:
+ img = torch.cat((txt, img), 1)
+ # print("single block input img shape:", single_block_img.shape)
+ for block in self.single_blocks:
+ if self.training and self.gradient_checkpointing:
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+ if not use_share_weight_referencenet:
+ img = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ img,
+ vec,
+ pe,
+ use_reentrant=True,
+ )
+ else:
+ single_block_img = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ single_block_img,
+ vec,
+ single_block_pe,
+ use_reentrant=True,
+ )
+ else:
+ if not use_share_weight_referencenet:
+ img = block(
+ img,
+ vec=vec,
+ pe=pe,
+ )
+ else:
+ single_block_img = block(
+ single_block_img,
+ vec=vec,
+ pe=single_block_pe,
+ )
+ if use_share_weight_referencenet:
+ out_img = single_block_img
+ if double_block_refnet and (not single_block_refnet):
+ out_img = out_img[:, txt.shape[1]:, ...]
+ elif single_block_refnet and (not double_block_refnet):
+ out_img = out_img[:, txt.shape[1]:, ...]
+ out_img = out_img[:, img_latent_length//2:, :]
+ img = out_img
+ # print("output img shape:", img.shape)
+ else:
+ img = img[:, txt.shape[1] :, ...]
+
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
+ return img
+
+
+
+# Recorded tensor shapes for the two RefNet modes (flux-dev hidden size 3072):
+#
+# single-block RefNet:
+# In img shape: torch.Size([1, 2048, 3072])
+# After double blocks img shape: torch.Size([1, 1024, 3072])
+# single block input img shape: torch.Size([1, 2560, 3072])
+# output img shape: torch.Size([1, 1024, 3072])
+#
+# double-block RefNet:
+# In img shape: torch.Size([1, 2048, 3072])
+# After double blocks img shape: torch.Size([1, 2048, 3072])
+# single block input img shape: torch.Size([1, 1536, 3072])
+# output img shape: torch.Size([1, 1024, 3072])
\ No newline at end of file
diff --git a/src/flux/modules/autoencoder.py b/src/flux/modules/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..75159f711f65f064107a1a1b9be6f09fc9872028
--- /dev/null
+++ b/src/flux/modules/autoencoder.py
@@ -0,0 +1,312 @@
+from dataclasses import dataclass
+
+import torch
+from einops import rearrange
+from torch import Tensor, nn
+
+
+@dataclass
+class AutoEncoderParams:
+ resolution: int
+ in_channels: int
+ ch: int
+ out_ch: int
+ ch_mult: list[int]
+ num_res_blocks: int
+ z_channels: int
+ scale_factor: float
+ shift_factor: float
+
+
+def swish(x: Tensor) -> Tensor:
+ return x * torch.sigmoid(x)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+
+ def attention(self, h_: Tensor) -> Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
+
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x + self.proj_out(self.attention(x))
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = swish(h)
+ h = self.conv1(h)
+
+ h = self.norm2(h)
+ h = swish(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x: Tensor):
+ pad = (0, 1, 0, 1)
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x: Tensor):
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+ return x
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ resolution: int,
+ in_channels: int,
+ ch: int,
+ ch_mult: list[int],
+ num_res_blocks: int,
+ z_channels: int,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ # downsampling
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ block_in = self.ch
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+ block_in = block_out
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+ # end
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x: Tensor) -> Tensor:
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1])
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+ # end
+ h = self.norm_out(h)
+ h = swish(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ ch: int,
+ out_ch: int,
+ ch_mult: list[int],
+ num_res_blocks: int,
+ in_channels: int,
+ resolution: int,
+ z_channels: int,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.ffactor = 2 ** (self.num_resolutions - 1)
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+
+ # z to block_in
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks + 1):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+ block_in = block_out
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, z: Tensor) -> Tensor:
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = swish(h)
+ h = self.conv_out(h)
+ return h
+
+
+class DiagonalGaussian(nn.Module):
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
+ super().__init__()
+ self.sample = sample
+ self.chunk_dim = chunk_dim
+
+ def forward(self, z: Tensor) -> Tensor:
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
+ if self.sample:
+ std = torch.exp(0.5 * logvar)
+ return mean + std * torch.randn_like(mean)
+ else:
+ return mean
+
+
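+# Minimal sketch (illustrative only): the encoder emits 2 * z_channels "moments",
+# which are split into mean and log-variance along `chunk_dim`, and a latent is
+# drawn via the reparameterization z = mean + exp(0.5 * logvar) * eps.
+#
+#   reg = DiagonalGaussian()                 # sample=True by default
+#   moments = torch.randn(1, 32, 64, 64)     # hypothetical encoder output (2 * 16 channels)
+#   z = reg(moments)                         # -> (1, 16, 64, 64)
+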
+class AutoEncoder(nn.Module):
+ def __init__(self, params: AutoEncoderParams):
+ super().__init__()
+ self.encoder = Encoder(
+ resolution=params.resolution,
+ in_channels=params.in_channels,
+ ch=params.ch,
+ ch_mult=params.ch_mult,
+ num_res_blocks=params.num_res_blocks,
+ z_channels=params.z_channels,
+ )
+ self.decoder = Decoder(
+ resolution=params.resolution,
+ in_channels=params.in_channels,
+ ch=params.ch,
+ out_ch=params.out_ch,
+ ch_mult=params.ch_mult,
+ num_res_blocks=params.num_res_blocks,
+ z_channels=params.z_channels,
+ )
+ self.reg = DiagonalGaussian()
+
+ self.scale_factor = params.scale_factor
+ self.shift_factor = params.shift_factor
+
+ def encode(self, x: Tensor) -> Tensor:
+ z = self.reg(self.encoder(x))
+ z = self.scale_factor * (z - self.shift_factor)
+ return z
+
+ def decode(self, z: Tensor) -> Tensor:
+ z = z / self.scale_factor + self.shift_factor
+ return self.decoder(z)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.decode(self.encode(x))
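+
+# Rough usage sketch (illustrative; shapes are made up): encode an RGB image in
+# [-1, 1] to a 16-channel latent at 1/8 resolution and decode it back. The params
+# below mirror the flux-dev AutoEncoderParams defined in util.py.
+#
+#   ae = AutoEncoder(AutoEncoderParams(
+#       resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4],
+#       num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159,
+#   ))
+#   x = torch.randn(1, 3, 512, 512)
+#   z = ae.encode(x)       # -> (1, 16, 64, 64)
+#   x_rec = ae.decode(z)   # -> (1, 3, 512, 512)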
diff --git a/src/flux/modules/conditioner.py b/src/flux/modules/conditioner.py
new file mode 100644
index 0000000000000000000000000000000000000000..207a1693b2dd9e57b534b43e35f19ee100573aaa
--- /dev/null
+++ b/src/flux/modules/conditioner.py
@@ -0,0 +1,39 @@
+from torch import Tensor, nn
+from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
+ T5Tokenizer)
+
+
+class HFEmbedder(nn.Module):
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
+ super().__init__()
+ # self.is_clip = version.startswith("openai")
+ self.is_clip = "openai" in version
+ self.max_length = max_length
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
+
+ if self.is_clip:
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
+ else:
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
+
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
+
+ def forward(self, text: list[str]) -> Tensor:
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=False,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+
+ outputs = self.hf_module(
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+ attention_mask=None,
+ output_hidden_states=False,
+ )
+ return outputs[self.output_key]
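+
+# Usage sketch (illustrative): the same wrapper serves both text encoders. With a
+# CLIP checkpoint it returns the pooled output (B, 768); with the T5 checkpoint
+# used in util.py it returns the full last hidden state (B, max_length, 4096).
+#
+#   clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77)
+#   t5 = HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=512)
+#   vec = clip(["a photo of a cat"])   # pooled embedding for the `y` input
+#   txt = t5(["a photo of a cat"])     # per-token embeddings for the `txt` input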
diff --git a/src/flux/modules/layers.py b/src/flux/modules/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9e745832b842c311bc19c951ccb410039a91cfa
--- /dev/null
+++ b/src/flux/modules/layers.py
@@ -0,0 +1,625 @@
+import math
+from dataclasses import dataclass
+
+import torch
+from einops import rearrange
+from torch import Tensor, nn
+
+from ..math import attention, rope
+import torch.nn.functional as F
+
+class EmbedND(nn.Module):
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: Tensor) -> Tensor:
+ n_axes = ids.shape[-1]
+ emb = torch.cat(
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ dim=-3,
+ )
+
+ return emb.unsqueeze(1)
+
+
+def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ t = time_factor * t
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+ t.device
+ )
+
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ if torch.is_floating_point(t):
+ embedding = embedding.to(t)
+ return embedding
+
+
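+# Worked example (illustrative): timesteps in [0, 1] are first scaled by
+# time_factor=1000, then embedded with `half` cosine terms followed by `half`
+# sine terms.
+#
+#   t = torch.tensor([0.25, 0.5, 1.0])
+#   emb = timestep_embedding(t, dim=256)   # -> (3, 256)
+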
+class MLPEmbedder(nn.Module):
+ def __init__(self, in_dim: int, hidden_dim: int):
+ super().__init__()
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+ self.silu = nn.SiLU()
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.scale = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x: Tensor):
+ x_dtype = x.dtype
+ x = x.float()
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
+ return (x * rrms).to(dtype=x_dtype) * self.scale
+
+
+class QKNorm(torch.nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.query_norm = RMSNorm(dim)
+ self.key_norm = RMSNorm(dim)
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
+ q = self.query_norm(q)
+ k = self.key_norm(k)
+ return q.to(v), k.to(v)
+
+class LoRALinearLayer(nn.Module):
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
+ super().__init__()
+
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+ self.network_alpha = network_alpha
+ self.rank = rank
+
+ nn.init.normal_(self.down.weight, std=1 / rank)
+ nn.init.zeros_(self.up.weight)
+
+ def forward(self, hidden_states):
+ orig_dtype = hidden_states.dtype
+ dtype = self.down.weight.dtype
+
+ down_hidden_states = self.down(hidden_states.to(dtype))
+ up_hidden_states = self.up(down_hidden_states)
+
+ if self.network_alpha is not None:
+ up_hidden_states *= self.network_alpha / self.rank
+
+ return up_hidden_states.to(orig_dtype)
+
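+# Sketch of the intended use (illustrative): the low-rank path is added on top of
+# a frozen linear layer; with `up` zero-initialized the LoRA branch contributes
+# nothing until training updates it.
+#
+#   base = nn.Linear(3072, 3072)   # frozen base projection (3072 = flux-dev hidden size)
+#   lora = LoRALinearLayer(3072, 3072, rank=16, network_alpha=16)
+#   x = torch.randn(1, 64, 3072)
+#   y = base(x) + lora(x)          # equals base(x) at initialization
+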
+class FLuxSelfAttnProcessor:
+ def __call__(self, attn, x, pe, **attention_kwargs):
+ qkv = attn.qkv(x)
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+ q, k = attn.norm(q, k, v)
+ x = attention(q, k, v, pe=pe)
+ x = attn.proj(x)
+ return x
+
+class LoraFluxAttnProcessor(nn.Module):
+
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
+ super().__init__()
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
+ self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
+ self.lora_weight = lora_weight
+
+
+ def __call__(self, attn, x, pe, **attention_kwargs):
+ qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+ q, k = attn.norm(q, k, v)
+ x = attention(q, k, v, pe=pe)
+ x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
+ return x
+
+class SelfAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.norm = QKNorm(head_dim)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self):
+ # qkv/norm/proj are consumed directly by the block-level processors; this module
+ # only holds parameters and is not meant to be called on its own.
+ pass
+
+
+@dataclass
+class ModulationOut:
+ shift: Tensor
+ scale: Tensor
+ gate: Tensor
+
+
+class Modulation(nn.Module):
+ def __init__(self, dim: int, double: bool):
+ super().__init__()
+ self.is_double = double
+ self.multiplier = 6 if double else 3
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+
+ return (
+ ModulationOut(*out[:3]),
+ ModulationOut(*out[3:]) if self.is_double else None,
+ )
+
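+# Illustrative sketch: each ModulationOut is applied as an adaptive-LayerNorm style
+# transform, (1 + scale) * norm(x) + shift, and `gate` scales the residual branch,
+# as in the block processors below.
+#
+#   mod = Modulation(3072, double=False)
+#   vec = torch.randn(2, 3072)        # pooled conditioning vector
+#   m, _ = mod(vec)                   # m.shift / m.scale / m.gate: (2, 1, 3072)
+#   # x = x + m.gate * sublayer((1 + m.scale) * norm(x) + m.shift)
+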
+class DoubleStreamBlockLoraProcessor(nn.Module):
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
+ super().__init__()
+ self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
+ self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
+ self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
+ self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
+ self.lora_weight = lora_weight
+
+ def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
+ img_mod1, img_mod2 = attn.img_mod(vec)
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
+
+ img_modulated = attn.img_norm1(img)
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+ img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
+
+ # prepare txt for attention
+ txt_modulated = attn.txt_norm1(txt)
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+ txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
+
+ # run actual attention
+ q = torch.cat((txt_q, img_q), dim=2)
+ k = torch.cat((txt_k, img_k), dim=2)
+ v = torch.cat((txt_v, img_v), dim=2)
+
+ attn1 = attention(q, k, v, pe=pe)
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
+
+ # calculate the img blocks
+ img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
+
+ # calculate the txt blocks
+ txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
+ return img, txt
+
+class IPDoubleStreamBlockProcessor(nn.Module):
+ """Attention processor for handling IP-adapter with double stream block."""
+
+ def __init__(self, context_dim, hidden_dim):
+ super().__init__()
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch."
+ )
+
+ # Ensure context_dim matches the dimension of image_proj
+ self.context_dim = context_dim
+ self.hidden_dim = hidden_dim
+
+ # Initialize projections for IP-adapter
+ self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True)
+ self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True)
+
+ nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight)
+ nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias)
+
+ nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight)
+ nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)
+
+ def __call__(self, attn, img, txt, vec, pe, image_proj, ip_scale=1.0, **attention_kwargs):
+
+ # Prepare image for attention
+ img_mod1, img_mod2 = attn.img_mod(vec)
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
+
+ img_modulated = attn.img_norm1(img)
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+ img_qkv = attn.img_attn.qkv(img_modulated)
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
+
+ txt_modulated = attn.txt_norm1(txt)
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+ txt_qkv = attn.txt_attn.qkv(txt_modulated)
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
+
+ q = torch.cat((txt_q, img_q), dim=2)
+ k = torch.cat((txt_k, img_k), dim=2)
+ v = torch.cat((txt_v, img_v), dim=2)
+
+ attn1 = attention(q, k, v, pe=pe)
+ txt_attn, img_attn = attn1[:, :txt.shape[1]], attn1[:, txt.shape[1]:]
+
+ # print(f"txt_attn shape: {txt_attn.size()}")
+ # print(f"img_attn shape: {img_attn.size()}")
+
+ img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
+
+ txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
+
+
+ # IP-adapter processing
+ ip_query = img_q # latent sample query
+ ip_key = self.ip_adapter_double_stream_k_proj(image_proj)
+ ip_value = self.ip_adapter_double_stream_v_proj(image_proj)
+
+ # Reshape projections for multi-head attention
+ ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim)
+ ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim)
+
+ # Compute attention between IP projections and the latent query
+ ip_attention = F.scaled_dot_product_attention(
+ ip_query,
+ ip_key,
+ ip_value,
+ dropout_p=0.0,
+ is_causal=False
+ )
+ ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim)
+
+ img = img + ip_scale * ip_attention
+
+ return img, txt
+
+# TODO: ReferenceNet Block Processor
+class DoubleStreamBlockProcessor:
+ def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
+ # img: latent; txt: text embedding; vec: timestep embedding; pe: position embedding;
+ # prepare image for attention
+ img_mod1, img_mod2 = attn.img_mod(vec)
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
+ # prepare image for attention
+ img_modulated = attn.img_norm1(img)
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+ # if use_share_weight_referencenet:
+ # real_image_batch_size = 1
+ # ref_img_modulated = img_modulated[:1, ...]
+ # print("ref latent shape:", ref_img_modulated.shape)
+ # noisy_img_modulated = img_modulated[1:, ...]
+ # print("noise latent shape:", noisy_img_modulated.shape)
+ # img_modulated = torch.cat([noisy_img_modulated, ref_img_modulated], dim=1)
+ # print("input latent shape:", img_modulated.shape)
+ img_qkv = attn.img_attn.qkv(img_modulated)
+ # print("img_qkv shape:", img_qkv.shape)
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", B=1, K=3, H=attn.num_heads, D=attn.head_dim)
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
+
+ # prepare txt for attention
+ txt_modulated = attn.txt_norm1(txt)
+
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+
+ txt_qkv = attn.txt_attn.qkv(txt_modulated)
+ # print("txt_qkv shape:", txt_qkv.shape)
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", B=1, K=3, H=attn.num_heads, D=attn.head_dim)
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
+ # print("txt_q shape:", txt_q.shape)
+ # print("img_q shape:", img_q.shape)
+ # print("txt_k shape:", txt_k.shape)
+ # print("img_k shape:", img_k.shape)
+ # print("txt_v shape:", txt_v.shape)
+ # print("img_v shape:", img_v.shape)
+ # run actual attention
+ q = torch.cat((txt_q, img_q), dim=2)
+ k = torch.cat((txt_k, img_k), dim=2)
+ v = torch.cat((txt_v, img_v), dim=2)
+ attn1 = attention(q, k, v, pe=pe)
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
+ # print(" txt_attn shape:", txt_attn.shape)
+ # print(" img_attn shape:", img_attn.shape)
+ # print(" img shape:", img.shape)
+ # calculate the img blocks
+ img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
+
+ # calculate the txt blocks
+ txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
+ # raise NotImplementedError
+ # img: torch.Size([1, 1024, 3072])
+ # txt: torch.Size([1, 512, 3072])
+ # txt_q shape: torch.Size([1, 24, 512, 128])
+ # img_q shape: torch.Size([1, 24, 1024, 128])
+ # txt_attn shape: torch.Size([1, 512, 3072])
+ # img_attn shape: torch.Size([1, 1024, 3072])
+ # img shape: torch.Size([1, 1024, 3072])
+ return img, txt
+
+class DoubleStreamBlock(nn.Module):
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
+ super().__init__()
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.head_dim = hidden_size // num_heads
+
+ self.img_mod = Modulation(hidden_size, double=True)
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.img_mlp = nn.Sequential(
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+ nn.GELU(approximate="tanh"),
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+ )
+
+ self.txt_mod = Modulation(hidden_size, double=True)
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.txt_mlp = nn.Sequential(
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+ nn.GELU(approximate="tanh"),
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+ )
+ processor = DoubleStreamBlockProcessor()
+ self.set_processor(processor)
+
+ def set_processor(self, processor) -> None:
+ self.processor = processor
+
+ def get_processor(self):
+ return self.processor
+
+ def forward(
+ self,
+ img: Tensor,
+ txt: Tensor,
+ vec: Tensor,
+ pe: Tensor,
+ image_proj: Tensor | None = None,
+ ip_scale: float = 1.0,
+ use_share_weight_referencenet=False,
+ ) -> tuple[Tensor, Tensor]:
+ if image_proj is None and use_share_weight_referencenet:
+ return self.processor(self, img, txt, vec, pe, use_share_weight_referencenet=use_share_weight_referencenet)
+ elif image_proj is None:
+ return self.processor(self, img, txt, vec, pe)
+ else:
+ return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
+
+class IPSingleStreamBlockProcessor(nn.Module):
+ """Attention processor for handling IP-adapter with single stream block."""
+ def __init__(self, context_dim, hidden_dim):
+ super().__init__()
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch."
+ )
+
+ # Ensure context_dim matches the dimension of image_proj
+ self.context_dim = context_dim
+ self.hidden_dim = hidden_dim
+
+ # Initialize projections for IP-adapter
+ self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
+ self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
+
+ nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight)
+ nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight)
+
+ def __call__(
+ self,
+ attn: nn.Module,
+ x: Tensor,
+ vec: Tensor,
+ pe: Tensor,
+ image_proj: Tensor | None = None,
+ ip_scale: float = 1.0
+ ) -> Tensor:
+
+ mod, _ = attn.modulation(vec)
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
+
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
+ q, k = attn.norm(q, k, v)
+
+ # compute attention
+ attn_1 = attention(q, k, v, pe=pe)
+
+ # IP-adapter processing
+ ip_query = q
+ ip_key = self.ip_adapter_single_stream_k_proj(image_proj)
+ ip_value = self.ip_adapter_single_stream_v_proj(image_proj)
+
+ # Reshape projections for multi-head attention
+ ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim)
+ ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim)
+
+
+ # Compute attention between IP projections and the latent query
+ ip_attention = F.scaled_dot_product_attention(
+ ip_query,
+ ip_key,
+ ip_value
+ )
+ ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)")
+
+ attn_out = attn_1 + ip_scale * ip_attention
+
+ # compute activation in mlp stream, cat again and run second linear layer
+ output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2))
+ out = x + mod.gate * output
+
+ return out
+
+
+class SingleStreamBlockLoraProcessor(nn.Module):
+ def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1):
+ super().__init__()
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
+ # 15360 = hidden_size + mlp_hidden_dim (3072 + 4 * 3072), i.e. the width of the
+ # concatenated attention + MLP activations fed to attn.linear2
+ self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)
+ self.lora_weight = lora_weight
+
+ def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
+
+ mod, _ = attn.modulation(vec)
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
+ qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight
+
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+ q, k = attn.norm(q, k, v)
+
+ # compute attention
+ attn_1 = attention(q, k, v, pe=pe)
+
+ # compute activation in mlp stream, cat again and run second linear layer
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
+ output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight
+ output = x + mod.gate * output
+ return output
+
+# TODO: ReferenceNet Block Processor
+class SingleStreamBlockProcessor:
+ def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
+
+ mod, _ = attn.modulation(vec)
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
+
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+ q, k = attn.norm(q, k, v)
+
+ # compute attention
+ attn_1 = attention(q, k, v, pe=pe)
+
+ # compute activation in mlp stream, cat again and run second linear layer
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
+ output = x + mod.gate * output
+ return output
+
+class SingleStreamBlock(nn.Module):
+ """
+ A DiT block with parallel linear layers as described in
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qk_scale: float | None = None,
+ ):
+ super().__init__()
+ self.hidden_dim = hidden_size
+ self.num_heads = num_heads
+ self.head_dim = hidden_size // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ # qkv and mlp_in
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+ # proj and mlp_out
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+ self.norm = QKNorm(self.head_dim)
+
+ self.hidden_size = hidden_size
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ self.mlp_act = nn.GELU(approximate="tanh")
+ self.modulation = Modulation(hidden_size, double=False)
+
+ processor = SingleStreamBlockProcessor()
+ self.set_processor(processor)
+
+
+ def set_processor(self, processor) -> None:
+ self.processor = processor
+
+ def get_processor(self):
+ return self.processor
+
+ def forward(
+ self,
+ x: Tensor,
+ vec: Tensor,
+ pe: Tensor,
+ image_proj: Tensor | None = None,
+ ip_scale: float = 1.0
+ ) -> Tensor:
+ if image_proj is None:
+ return self.processor(self, x, vec, pe)
+ else:
+ return self.processor(self, x, vec, pe, image_proj, ip_scale)
+
+
+
+class LastLayer(nn.Module):
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+ x = self.linear(x)
+ return x
+
+class ImageProjModel(torch.nn.Module):
+ """Projection Model
+ https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28
+ """
+
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+ super().__init__()
+
+ self.generator = None
+ self.cross_attention_dim = cross_attention_dim
+ self.clip_extra_context_tokens = clip_extra_context_tokens
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds):
+ embeds = image_embeds
+ clip_extra_context_tokens = self.proj(embeds).reshape(
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
+ )
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+ return clip_extra_context_tokens
+
diff --git a/src/flux/sampling.py b/src/flux/sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..875927d025ddc5618e93a4c85992aad6ffb3846d
--- /dev/null
+++ b/src/flux/sampling.py
@@ -0,0 +1,304 @@
+import math
+from typing import Callable
+
+import torch
+from einops import rearrange, repeat
+from torch import Tensor
+
+from .model import Flux
+from .modules.conditioner import HFEmbedder
+
+
+def get_noise(
+ num_samples: int,
+ height: int,
+ width: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ seed: int,
+):
+ return torch.randn(
+ num_samples,
+ 16,
+ # allow for packing
+ 2 * math.ceil(height / 16),
+ 2 * math.ceil(width / 16),
+ device=device,
+ dtype=dtype,
+ generator=torch.Generator(device=device).manual_seed(seed),
+ )
+
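+# Shape note (illustrative): height/width refer to the pixel image; the latent has
+# 16 channels at 1/8 resolution, and the factor of 2 leaves room for the 2x2 patch
+# packing done in prepare(). For a 1024x1024 image:
+#
+#   x = get_noise(1, 1024, 1024, device=torch.device("cpu"), dtype=torch.bfloat16, seed=0)
+#   # x.shape == (1, 16, 128, 128)
+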
+
+def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str], use_spatial_condition=False, use_share_weight_referencenet=False, share_position_embedding=False) -> dict[str, Tensor]:
+ bs, c, h, w = img.shape
+ if bs == 1 and not isinstance(prompt, str):
+ bs = len(prompt)
+
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+ if img.shape[0] == 1 and bs > 1:
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+ if use_spatial_condition:
+ if share_position_embedding:
+ img_ids = torch.zeros(h // 2, w // 2, 3)
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+ img_ids = torch.cat([img_ids, img_ids], dim=1)
+ else:
+ img_ids = torch.zeros(h // 2, w, 3)
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w)[None, :]
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+ elif use_share_weight_referencenet:
+ if share_position_embedding:
+ single_img_ids = torch.zeros(h // 2, w // 2, 3)
+ single_img_ids[..., 1] = single_img_ids[..., 1] + torch.arange(h // 2)[:, None]
+ single_img_ids[..., 2] = single_img_ids[..., 2] + torch.arange(w // 2)[None, :]
+ single_img_ids = repeat(single_img_ids, "h w c -> b (h w) c", b=bs)
+ img_ids = torch.cat([single_img_ids, single_img_ids], dim=1)
+ else:
+ # single_img_position_embedding
+ single_img_ids = torch.zeros(h // 2, w // 2, 3)
+ single_img_ids[..., 1] = single_img_ids[..., 1] + torch.arange(h // 2)[:, None]
+ single_img_ids[..., 2] = single_img_ids[..., 2] + torch.arange(w // 2)[None, :]
+ single_img_ids = repeat(single_img_ids, "h w c -> b (h w) c", b=bs)
+ # ref_and_noise_img_position_embedding
+ img_ids = torch.zeros(h // 2, w, 3)
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w)[None, :]
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+ else:
+ img_ids = torch.zeros(h // 2, w // 2, 3)
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ txt = t5(prompt)
+ if txt.shape[0] == 1 and bs > 1:
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+ vec = clip(prompt)
+ if vec.shape[0] == 1 and bs > 1:
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+ if not use_share_weight_referencenet:
+ return {
+ "img": img,
+ "img_ids": img_ids.to(img.device),
+ "txt": txt.to(img.device),
+ "txt_ids": txt_ids.to(img.device),
+ "vec": vec.to(img.device),
+ }
+ else:
+ return {
+ "img": img,
+ "img_ids": img_ids.to(img.device),
+ "txt": txt.to(img.device),
+ "txt_ids": txt_ids.to(img.device),
+ "vec": vec.to(img.device),
+ "single_img_ids": single_img_ids.to(img.device),
+ }
+
+
+def time_shift(mu: float, sigma: float, t: Tensor):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+def get_lin_function(
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+) -> Callable[[float], float]:
+ m = (y2 - y1) / (x2 - x1)
+ b = y1 - m * x1
+ return lambda x: m * x + b
+
+
+def get_schedule(
+ num_steps: int,
+ image_seq_len: int,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+ shift: bool = True,
+) -> list[float]:
+ # extra step for zero
+ timesteps = torch.linspace(1, 0, num_steps + 1)
+
+ # shifting the schedule to favor high timesteps for higher signal images
+ if shift:
+ # estimate mu with a linear interpolation between the two reference points
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+ timesteps = time_shift(mu, 1.0, timesteps)
+
+ return timesteps.tolist()
+
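+# Worked example (illustrative): a 1024x1024 image packs to (1024 / 16) ** 2 = 4096
+# tokens, so the linear estimate gives mu = max_shift = 1.15 and time_shift() bends
+# the uniform schedule toward t close to 1 (more of the budget spent at high noise).
+#
+#   ts = get_schedule(num_steps=25, image_seq_len=4096)   # 26 values from 1.0 down to 0.0
+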
+
+def denoise(
+ model: Flux,
+ # model input
+ img: Tensor,
+ img_ids: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ vec: Tensor,
+ neg_txt: Tensor,
+ neg_txt_ids: Tensor,
+ neg_vec: Tensor,
+ # sampling parameters
+ timesteps: list[float],
+ guidance: float = 4.0,
+ true_gs = 1,
+ timestep_to_start_cfg=0,
+ # ip-adapter parameters
+ image_proj: Tensor=None,
+ neg_image_proj: Tensor=None,
+ ip_scale: Tensor | float = 1.0,
+ neg_ip_scale: Tensor | float = 1.0,
+ source_image: Tensor=None,
+ use_share_weight_referencenet=False,
+ single_img_ids=None,
+ neg_single_img_ids=None,
+ single_block_refnet=False,
+ double_block_refnet=False,
+):
+ i = 0
+ # this is ignored for schnell
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+ if source_image is not None: # spatial condition or refnet
+ img = torch.cat([source_image, img],dim=-2)
+ pred = model(
+ img=img,
+ img_ids=img_ids,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ image_proj=image_proj,
+ ip_scale=ip_scale,
+ use_share_weight_referencenet=use_share_weight_referencenet,
+ single_img_ids=single_img_ids,
+ single_block_refnet=single_block_refnet,
+ double_block_refnet=double_block_refnet,
+ )
+ if i >= timestep_to_start_cfg:
+ neg_pred = model(
+ img=img,
+ img_ids=img_ids,
+ txt=neg_txt,
+ txt_ids=neg_txt_ids,
+ y=neg_vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ image_proj=neg_image_proj,
+ ip_scale=neg_ip_scale,
+ use_share_weight_referencenet=use_share_weight_referencenet,
+ single_img_ids=neg_single_img_ids,
+ single_block_refnet=single_block_refnet,
+ double_block_refnet=double_block_refnet,
+ )
+ pred = neg_pred + true_gs * (pred - neg_pred)
+ if use_share_weight_referencenet:
+ zero_buffer = torch.zeros_like(pred)
+ pred = torch.cat([zero_buffer, pred], dim=1)
+ img = img + (t_prev - t_curr) * pred
+ if (source_image is not None): # spatial condition or refnet
+ latent_length = img.shape[-2] // 2
+ img = img[:,latent_length:,:]
+ i += 1
+ return img
+
+def denoise_controlnet(
+ model: Flux,
+ controlnet:None,
+ # model input
+ img: Tensor,
+ img_ids: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ vec: Tensor,
+ neg_txt: Tensor,
+ neg_txt_ids: Tensor,
+ neg_vec: Tensor,
+ controlnet_cond,
+ # sampling parameters
+ timesteps: list[float],
+ guidance: float = 4.0,
+ true_gs = 1,
+ controlnet_gs=0.7,
+ timestep_to_start_cfg=0,
+ # ip-adapter parameters
+ image_proj: Tensor=None,
+ neg_image_proj: Tensor=None,
+ ip_scale: Tensor | float = 1,
+ neg_ip_scale: Tensor | float = 1,
+):
+ # this is ignored for schnell
+ i = 0
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+ block_res_samples = controlnet(
+ img=img,
+ img_ids=img_ids,
+ controlnet_cond=controlnet_cond,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ )
+ pred = model(
+ img=img,
+ img_ids=img_ids,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ block_controlnet_hidden_states=[sample * controlnet_gs for sample in block_res_samples],
+ image_proj=image_proj,
+ ip_scale=ip_scale,
+ )
+ if i >= timestep_to_start_cfg:
+ neg_block_res_samples = controlnet(
+ img=img,
+ img_ids=img_ids,
+ controlnet_cond=controlnet_cond,
+ txt=neg_txt,
+ txt_ids=neg_txt_ids,
+ y=neg_vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ )
+ neg_pred = model(
+ img=img,
+ img_ids=img_ids,
+ txt=neg_txt,
+ txt_ids=neg_txt_ids,
+ y=neg_vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ block_controlnet_hidden_states=[sample * controlnet_gs for sample in neg_block_res_samples],
+ image_proj=neg_image_proj,
+ ip_scale=neg_ip_scale,
+ )
+ pred = neg_pred + true_gs * (pred - neg_pred)
+
+ img = img + (t_prev - t_curr) * pred
+
+ i += 1
+ return img
+
+def unpack(x: Tensor, height: int, width: int) -> Tensor:
+ return rearrange(
+ x,
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+ h=math.ceil(height / 16),
+ w=math.ceil(width / 16),
+ ph=2,
+ pw=2,
+ )
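+
+# Round-trip note (illustrative): unpack() inverts the patch packing done in
+# prepare(), turning (B, (H/16)*(W/16), 64) token sequences back into
+# (B, 16, H/8, W/8) latents.
+#
+#   z = torch.randn(1, 16, 128, 128)
+#   tokens = rearrange(z, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)  # as in prepare()
+#   assert torch.equal(unpack(tokens, 1024, 1024), z)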
diff --git a/src/flux/util.py b/src/flux/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..186422b1b3a244697966b58e26fa3c1152d493cc
--- /dev/null
+++ b/src/flux/util.py
@@ -0,0 +1,456 @@
+import os
+from dataclasses import dataclass
+
+import torch
+import json
+import cv2
+import numpy as np
+from PIL import Image
+from huggingface_hub import hf_hub_download
+from safetensors import safe_open
+from safetensors.torch import load_file as load_sft
+from imwatermark import WatermarkEncoder  # used by WatermarkEmbedder below
+
+from optimum.quanto import requantize
+
+from .model import Flux, FluxParams
+from .controlnet import ControlNetFlux
+from .modules.autoencoder import AutoEncoder, AutoEncoderParams
+from .modules.conditioner import HFEmbedder
+from .annotator.dwpose import DWposeDetector
+from .annotator.mlsd import MLSDdetector
+from .annotator.canny import CannyDetector
+from .annotator.midas import MidasDetector
+from .annotator.hed import HEDdetector
+from .annotator.tile import TileDetector
+from .annotator.zoe import ZoeDetector
+
+def tensor_to_pil_image(in_image):
+ tensor = in_image.squeeze(0)
+ tensor = (tensor + 1) / 2
+ tensor = tensor * 255
+ numpy_array = tensor.permute(1, 2, 0).byte().numpy()
+ pil_image = Image.fromarray(numpy_array)
+ return pil_image
+
+def save_image(in_image, output_path):
+ tensor_to_pil_image(in_image).save(output_path)
+
+def load_safetensors(path):
+ tensors = {}
+ with safe_open(path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ tensors[key] = f.get_tensor(key)
+ return tensors
+
+def get_lora_rank(checkpoint):
+ for k in checkpoint.keys():
+ if k.endswith(".down.weight"):
+ return checkpoint[k].shape[0]
+
+def load_checkpoint(local_path, repo_id, name):
+ if local_path is not None:
+ if '.safetensors' in local_path:
+ print(f"Loading .safetensors checkpoint from {local_path}")
+ checkpoint = load_safetensors(local_path)
+ else:
+ print(f"Loading checkpoint from {local_path}")
+ checkpoint = torch.load(local_path, map_location='cpu')
+ elif repo_id is not None and name is not None:
+ print(f"Loading checkpoint {name} from repo id {repo_id}")
+ checkpoint = load_from_repo_id(repo_id, name)
+ else:
+ raise ValueError(
+ "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
+ )
+ return checkpoint
+
+
+def c_crop(image):
+ width, height = image.size
+ new_size = min(width, height)
+ left = (width - new_size) / 2
+ top = (height - new_size) / 2
+ right = (width + new_size) / 2
+ bottom = (height + new_size) / 2
+ return image.crop((left, top, right, bottom))
+
+def pad64(x):
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
+
+def HWC3(x):
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+def safer_memory(x):
+ # Fix many MAC/AMD problems
+ return np.ascontiguousarray(x.copy()).copy()
+
+#https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17
+#Added upscale_method, mode params
+def resize_image_with_pad(input_image, resolution, skip_hwc3=False, mode='edge'):
+ if skip_hwc3:
+ img = input_image
+ else:
+ img = HWC3(input_image)
+ H_raw, W_raw, _ = img.shape
+ if resolution == 0:
+ return img, lambda x: x
+ k = float(resolution) / float(min(H_raw, W_raw))
+ H_target = int(np.round(float(H_raw) * k))
+ W_target = int(np.round(float(W_raw) * k))
+ img = cv2.resize(img, (W_target, H_target), interpolation=cv2.INTER_AREA)
+ H_pad, W_pad = pad64(H_target), pad64(W_target)
+ img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)
+
+ def remove_pad(x):
+ return safer_memory(x[:H_target, :W_target, ...])
+
+ return safer_memory(img_padded), remove_pad
+
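+# Worked example (illustrative): a 720x1280 input at resolution=512 is scaled so the
+# short side becomes 512 (-> 512x910), then edge-padded up to multiples of 64
+# (-> 512x960); remove_pad() crops detector output back to the 512x910 region.
+#
+#   img = np.zeros((720, 1280, 3), dtype=np.uint8)
+#   padded, remove_pad = resize_image_with_pad(img, 512)
+#   # padded.shape == (512, 960, 3); remove_pad(padded).shape == (512, 910, 3)
+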
+class Annotator:
+ def __init__(self, name: str, device: str):
+ if name == "canny":
+ processor = CannyDetector()
+ elif name == "openpose":
+ processor = DWposeDetector(device)
+ elif name == "depth":
+ processor = MidasDetector()
+ elif name == "hed":
+ processor = HEDdetector()
+ elif name == "hough":
+ processor = MLSDdetector()
+ elif name == "tile":
+ processor = TileDetector()
+ elif name == "zoe":
+ processor = ZoeDetector()
+ self.name = name
+ self.processor = processor
+
+ def __call__(self, image: Image, width: int, height: int):
+ image = np.array(image)
+ detect_resolution = max(width, height)
+ image, remove_pad = resize_image_with_pad(image, detect_resolution)
+
+ image = np.array(image)
+ if self.name == "canny":
+ result = self.processor(image, low_threshold=100, high_threshold=200)
+ elif self.name == "hough":
+ result = self.processor(image, thr_v=0.05, thr_d=5)
+ elif self.name == "depth":
+ result = self.processor(image)
+ result, _ = result
+ else:
+ result = self.processor(image)
+
+ result = HWC3(remove_pad(result))
+ result = cv2.resize(result, (width, height))
+ return result
+
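+# Usage sketch (illustrative; the file name is hypothetical): the input is resized
+# and padded to max(width, height) for detection, and the hint is resized back to
+# (width, height) as an HWC uint8 array.
+#
+#   annotator = Annotator("canny", device="cuda")
+#   hint = annotator(Image.open("example.jpg"), width=1024, height=1024)
+#   # hint.shape == (1024, 1024, 3)
+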
+
+@dataclass
+class ModelSpec:
+ params: FluxParams
+ ae_params: AutoEncoderParams
+ ckpt_path: str | None
+ ae_path: str | None
+ repo_id: str | None
+ repo_flow: str | None
+ repo_ae: str | None
+ repo_id_ae: str | None
+
+
+configs = {
+ "flux-dev": ModelSpec(
+ repo_id="black-forest-labs/FLUX.1-dev",
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
+ repo_flow="flux1-dev.safetensors",
+ repo_ae="ae.safetensors",
+ ckpt_path=os.getenv("FLUX_DEV"),
+ params=FluxParams(
+ in_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=True,
+ ),
+ ae_path=os.getenv("AE"),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
+ "flux-dev-fp8": ModelSpec(
+ repo_id="XLabs-AI/flux-dev-fp8",
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
+ repo_flow="flux-dev-fp8.safetensors",
+ repo_ae="ae.safetensors",
+ ckpt_path=os.getenv("FLUX_DEV_FP8"),
+ params=FluxParams(
+ in_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=True,
+ ),
+ ae_path=os.getenv("AE"),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
+ "flux-schnell": ModelSpec(
+ repo_id="black-forest-labs/FLUX.1-schnell",
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
+ repo_flow="flux1-schnell.safetensors",
+ repo_ae="ae.safetensors",
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
+ params=FluxParams(
+ in_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=False,
+ ),
+ ae_path=os.getenv("AE"),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
+}
+
+
+def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
+ if len(missing) > 0 and len(unexpected) > 0:
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+ print("\n" + "-" * 79 + "\n")
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+ elif len(missing) > 0:
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+ elif len(unexpected) > 0:
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+
+def load_from_repo_id(repo_id, checkpoint_name):
+ ckpt_path = hf_hub_download(repo_id, checkpoint_name)
+ sd = load_sft(ckpt_path, device='cpu')
+ return sd
+
+def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
+ # Loading Flux
+ print("Init model")
+ ckpt_path = configs[name].ckpt_path
+ if (
+ ckpt_path is None
+ and configs[name].repo_id is not None
+ and configs[name].repo_flow is not None
+ and hf_download
+ ):
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
+
+ with torch.device("meta" if ckpt_path is not None else device):
+ model = Flux(configs[name].params).to(torch.bfloat16)
+
+ if ckpt_path is not None:
+ print("Loading checkpoint")
+ # load_sft doesn't support torch.device
+ sd = load_sft(ckpt_path, device=str(device))
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
+ print_load_warning(missing, unexpected)
+ return model
+
+def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
+ # Loading Flux
+ print("Init model")
+ ckpt_path = configs[name].ckpt_path
+ if (
+ ckpt_path is None
+ and configs[name].repo_id is not None
+ and configs[name].repo_flow is not None
+ and hf_download
+ ):
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
+
+ with torch.device("meta" if ckpt_path is not None else device):
+ model = Flux(configs[name].params)
+
+ if ckpt_path is not None:
+ print("Loading checkpoint")
+ # load_sft doesn't support torch.device
+ sd = load_sft(ckpt_path, device=str(device))
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
+ print_load_warning(missing, unexpected)
+ return model
+
+def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
+ # Loading Flux
+ print("Init model")
+ ckpt_path = configs[name].ckpt_path
+ if (
+ ckpt_path is None
+ and configs[name].repo_id is not None
+ and configs[name].repo_flow is not None
+ and hf_download
+ ):
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
+ json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
+
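+    # The fp8 checkpoint is paired with a JSON quantization map: the model is first built in
+    # bfloat16, and requantize() then applies the checkpoint and map to it.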
+ model = Flux(configs[name].params).to(torch.bfloat16)
+
+ print("Loading checkpoint")
+ # load_sft doesn't support torch.device
+ sd = load_sft(ckpt_path, device='cpu')
+ with open(json_path, "r") as f:
+ quantization_map = json.load(f)
+    print("Starting the requantization process...")
+ requantize(model, sd, quantization_map, device=device)
+ print("Model is quantized!")
+ return model
+
+def load_controlnet(name, device, transformer=None):
+ with torch.device(device):
+ controlnet = ControlNetFlux(configs[name].params)
+ if transformer is not None:
+ controlnet.load_state_dict(transformer.state_dict(), strict=False)
+ return controlnet
+
+def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
+    # max_length of 64, 128, 256 or 512 should work (as long as your prompt fits within it)
+ t5_path = os.getenv("T5")
+ if t5_path is None:
+ return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
+ else:
+ return HFEmbedder(t5_path, max_length=max_length, torch_dtype=torch.bfloat16).to(device)
+
+def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
+ clip_path = os.getenv("CLIP_VIT")
+ if clip_path is None:
+ return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
+ else:
+ return HFEmbedder(clip_path, max_length=77, torch_dtype=torch.bfloat16).to(device)
+
+
+def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
+ ckpt_path = configs[name].ae_path
+ if (
+ ckpt_path is None
+ and configs[name].repo_id is not None
+ and configs[name].repo_ae is not None
+ and hf_download
+ ):
+ ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
+
+ # Loading the autoencoder
+ print("Init AE")
+ with torch.device("meta" if ckpt_path is not None else device):
+ ae = AutoEncoder(configs[name].ae_params)
+
+ if ckpt_path is not None:
+ sd = load_sft(ckpt_path, device=str(device))
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
+ print_load_warning(missing, unexpected)
+ return ae
+
+
+class WatermarkEmbedder:
+ def __init__(self, watermark):
+ self.watermark = watermark
+ self.num_bits = len(WATERMARK_BITS)
+ self.encoder = WatermarkEncoder()
+ self.encoder.set_watermark("bits", self.watermark)
+
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
+ """
+ Adds a predefined watermark to the input image
+
+ Args:
+ image: ([N,] B, RGB, H, W) in range [-1, 1]
+
+ Returns:
+ same as input but watermarked
+ """
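+        # map the input from [-1, 1] to [0, 1]; a missing leading dim is added temporarily so
+        # batched and unbatched inputs go through the same path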
+ image = 0.5 * image + 0.5
+ squeeze = len(image.shape) == 4
+ if squeeze:
+ image = image[None, ...]
+ n = image.shape[0]
+ image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
+        # the watermarking library expects input in cv2 BGR format
+ for k in range(image_np.shape[0]):
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
+ image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
+ image.device
+ )
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
+ if squeeze:
+ image = image[0]
+ image = 2 * image - 1
+ return image
+
+
+# A fixed 48-bit message that was chosen at random
+WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
+# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
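+
+# Example usage (sketch): load the schnell flow model and its autoencoder. If the
+# FLUX_SCHNELL / AE environment variables are unset, the weights are pulled from Hugging Face.
+#
+#   model = load_flow_model("flux-schnell", device="cuda")
+#   ae = load_ae("flux-schnell", device="cuda")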
diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..cda3a073e84966dfac7742bab40720a3d5145a23
--- /dev/null
+++ b/src/flux/xflux_pipeline.py
@@ -0,0 +1,407 @@
+from PIL import Image, ExifTags
+import numpy as np
+import torch
+from torch import Tensor
+
+from einops import rearrange
+import uuid
+import os
+
+from src.flux.modules.layers import (
+ SingleStreamBlockProcessor,
+ DoubleStreamBlockProcessor,
+ SingleStreamBlockLoraProcessor,
+ DoubleStreamBlockLoraProcessor,
+ IPDoubleStreamBlockProcessor,
+ ImageProjModel,
+)
+from src.flux.sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack
+from src.flux.util import (
+ load_ae,
+ load_clip,
+ load_flow_model,
+ load_t5,
+ load_controlnet,
+ load_flow_model_quintized,
+ Annotator,
+ get_lora_rank,
+ load_checkpoint
+)
+
+from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+
+class XFluxPipeline:
+ def __init__(self, model_type, device, offload: bool = False):
+ self.device = torch.device(device)
+ self.offload = offload
+ self.model_type = model_type
+
+ self.clip = load_clip(self.device)
+ self.t5 = load_t5(self.device, max_length=512)
+ self.ae = load_ae(model_type, device="cpu" if offload else self.device)
+ if "fp8" in model_type:
+ self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
+ else:
+ self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
+
+ self.image_encoder_path = "openai/clip-vit-large-patch14"
+ self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
+ self.lora_types_to_names = {
+ "realism": "lora.safetensors",
+ }
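+        # runtime feature flags: ControlNet / IP-Adapter are enabled via set_controlnet() and
+        # set_ip(); the spatial-condition and ReferenceNet flags default to off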
+ self.controlnet_loaded = False
+ self.ip_loaded = False
+ self.spatial_condition = False
+ self.share_position_embedding = False
+ self.use_share_weight_referencenet = False
+ self.single_block_refnet = False
+ self.double_block_refnet = False
+
+    def set_ip(self, local_path: str = None, repo_id: str = None, name: str = None):
+ self.model.to(self.device)
+
+ # unpack checkpoint
+ checkpoint = load_checkpoint(local_path, repo_id, name)
+ prefix = "double_blocks."
+ blocks = {}
+ proj = {}
+
+ for key, value in checkpoint.items():
+ if key.startswith(prefix):
+ blocks[key[len(prefix):].replace('.processor.', '.')] = value
+ if key.startswith("ip_adapter_proj_model"):
+ proj[key[len("ip_adapter_proj_model."):]] = value
+
+ # load image encoder
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
+ self.device, dtype=torch.float16
+ )
+ self.clip_image_processor = CLIPImageProcessor()
+
+ # setup image embedding projection model
+ self.improj = ImageProjModel(4096, 768, 4)
+ self.improj.load_state_dict(proj)
+ self.improj = self.improj.to(self.device, dtype=torch.bfloat16)
+
+ ip_attn_procs = {}
+
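+        # install an IP-Adapter processor only for attention blocks that have weights in the
+        # checkpoint; every other block keeps its existing processor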
+ for name, _ in self.model.attn_processors.items():
+ ip_state_dict = {}
+ for k in checkpoint.keys():
+ if name in k:
+ ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k]
+ if ip_state_dict:
+ ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
+ ip_attn_procs[name].load_state_dict(ip_state_dict)
+ ip_attn_procs[name].to(self.device, dtype=torch.bfloat16)
+ else:
+ ip_attn_procs[name] = self.model.attn_processors[name]
+
+ self.model.set_attn_processor(ip_attn_procs)
+ self.ip_loaded = True
+
+    def set_lora(self, local_path: str = None, repo_id: str = None,
+                 name: str = None, lora_weight: float = 0.7):
+ checkpoint = load_checkpoint(local_path, repo_id, name)
+ self.update_model_with_lora(checkpoint, lora_weight)
+
+    def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
+ checkpoint = load_checkpoint(
+ None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
+ )
+ self.update_model_with_lora(checkpoint, lora_weight)
+
+ def update_model_with_lora(self, checkpoint, lora_weight):
+ rank = get_lora_rank(checkpoint)
+ lora_attn_procs = {}
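+        # LoRA tensors are pre-scaled by lora_weight here, so the processors apply the merged
+        # strength directly instead of tracking a separate scale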
+
+ for name, _ in self.model.attn_processors.items():
+ lora_state_dict = {}
+ for k in checkpoint.keys():
+ if name in k:
+ lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
+
+ if len(lora_state_dict):
+ if name.startswith("single_blocks"):
+ lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
+ else:
+ lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
+ lora_attn_procs[name].load_state_dict(lora_state_dict)
+ lora_attn_procs[name].to(self.device)
+ else:
+ if name.startswith("single_blocks"):
+ lora_attn_procs[name] = SingleStreamBlockProcessor()
+ else:
+ lora_attn_procs[name] = DoubleStreamBlockProcessor()
+
+ self.model.set_attn_processor(lora_attn_procs)
+
+ def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
+ self.model.to(self.device)
+ self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)
+
+ checkpoint = load_checkpoint(local_path, repo_id, name)
+ self.controlnet.load_state_dict(checkpoint, strict=False)
+ self.annotator = Annotator(control_type, self.device)
+ self.controlnet_loaded = True
+ self.control_type = control_type
+
+ def get_image_proj(
+ self,
+ image_prompt: Tensor,
+ ):
+ # encode image-prompt embeds
+ image_prompt = self.clip_image_processor(
+ images=image_prompt,
+ return_tensors="pt"
+ ).pixel_values
+ image_prompt = image_prompt.to(self.image_encoder.device)
+ image_prompt_embeds = self.image_encoder(
+ image_prompt
+ ).image_embeds.to(
+ device=self.device, dtype=torch.bfloat16,
+ )
+ # encode image
+ image_proj = self.improj(image_prompt_embeds)
+ return image_proj
+
+ def __call__(self,
+ prompt: str,
+                 image_prompt: Image.Image = None,
+                 source_image: Tensor = None,
+                 controlnet_image: Image.Image = None,
+ width: int = 512,
+ height: int = 512,
+ guidance: float = 4,
+ num_steps: int = 50,
+ seed: int = 123456789,
+ true_gs: float = 3.5, # 3
+ control_weight: float = 0.9,
+ ip_scale: float = 1.0,
+ neg_ip_scale: float = 1.0,
+ neg_prompt: str = '',
+                 neg_image_prompt: Image.Image = None,
+ timestep_to_start_cfg: int = 1, # 0
+ ):
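+        # snap the requested resolution down to a multiple of 16: the AE downsamples by 8 and
+        # latents are packed into 2x2 patches, so each token covers a 16x16 pixel area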
+ width = 16 * (width // 16)
+ height = 16 * (height // 16)
+ image_proj = None
+ neg_image_proj = None
+        if image_prompt is not None or neg_image_prompt is not None:
+ assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input'
+
+ if image_prompt is None:
+ image_prompt = np.zeros((width, height, 3), dtype=np.uint8)
+ if neg_image_prompt is None:
+ neg_image_prompt = np.zeros((width, height, 3), dtype=np.uint8)
+
+ image_proj = self.get_image_proj(image_prompt)
+ neg_image_proj = self.get_image_proj(neg_image_prompt)
+
+ if self.controlnet_loaded:
+ controlnet_image = self.annotator(controlnet_image, width, height)
+ controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
+ controlnet_image = controlnet_image.permute(
+ 2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)
+
+ return self.forward(
+ prompt,
+ width,
+ height,
+ guidance,
+ num_steps,
+ seed,
+ controlnet_image,
+ timestep_to_start_cfg=timestep_to_start_cfg,
+ true_gs=true_gs,
+ control_weight=control_weight,
+ neg_prompt=neg_prompt,
+ image_proj=image_proj,
+ neg_image_proj=neg_image_proj,
+ ip_scale=ip_scale,
+ neg_ip_scale=neg_ip_scale,
+ spatial_condition=self.spatial_condition,
+ source_image=source_image,
+ share_position_embedding=self.share_position_embedding
+ )
+
+ @torch.inference_mode()
+ def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance,
+ num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
+ neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
+ lora_weight, local_path, lora_local_path, ip_local_path):
+ if controlnet_image is not None:
+ controlnet_image = Image.fromarray(controlnet_image)
+ if ((self.controlnet_loaded and control_type != self.control_type)
+ or not self.controlnet_loaded):
+ if local_path is not None:
+ self.set_controlnet(control_type, local_path=local_path)
+ else:
+ self.set_controlnet(control_type, local_path=None,
+ repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3",
+ name=f"flux-{control_type}-controlnet-v3.safetensors")
+ if lora_local_path is not None:
+ self.set_lora(local_path=lora_local_path, lora_weight=lora_weight)
+ if image_prompt is not None:
+ image_prompt = Image.fromarray(image_prompt)
+ if neg_image_prompt is not None:
+ neg_image_prompt = Image.fromarray(neg_image_prompt)
+ if not self.ip_loaded:
+ if ip_local_path is not None:
+ self.set_ip(local_path=ip_local_path)
+ else:
+ self.set_ip(repo_id="xlabs-ai/flux-ip-adapter",
+ name="flux-ip-adapter.safetensors")
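+        # a seed of -1 requests a random seed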
+ seed = int(seed)
+ if seed == -1:
+ seed = torch.Generator(device="cpu").seed()
+
+        # pass everything by keyword so the optional source_image parameter keeps its default
+        # and the remaining arguments land on the intended parameters
+        img = self(prompt, image_prompt=image_prompt, controlnet_image=controlnet_image,
+                   width=width, height=height, guidance=guidance, num_steps=num_steps,
+                   seed=seed, true_gs=true_gs, control_weight=control_weight,
+                   ip_scale=ip_scale, neg_ip_scale=neg_ip_scale, neg_prompt=neg_prompt,
+                   neg_image_prompt=neg_image_prompt, timestep_to_start_cfg=timestep_to_start_cfg)
+
+ filename = f"output/gradio/{uuid.uuid4()}.jpg"
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ exif_data = Image.Exif()
+ exif_data[ExifTags.Base.Make] = "XLabs AI"
+ exif_data[ExifTags.Base.Model] = self.model_type
+ img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
+ return img, filename
+
+ def forward(
+ self,
+ prompt,
+ width,
+ height,
+ guidance,
+ num_steps,
+ seed,
+ controlnet_image = None,
+ timestep_to_start_cfg = 0,
+ true_gs = 3.5,
+ control_weight = 0.9,
+ neg_prompt="",
+ image_proj=None,
+ neg_image_proj=None,
+ ip_scale=1.0,
+ neg_ip_scale=1.0,
+ spatial_condition=False,
+ source_image=None,
+ share_position_embedding=False
+ ):
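+        # draw the initial latent noise for this seed and build the shifted timestep schedule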
+ x = get_noise(
+ 1, height, width, device=self.device,
+ dtype=torch.bfloat16, seed=seed
+ )
+ timesteps = get_schedule(
+ num_steps,
+ (width // 8) * (height // 8) // (16 * 16),
+ shift=True,
+ )
+ torch.manual_seed(seed)
+ with torch.no_grad():
+ if self.offload:
+ self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
+ # print("x noise shape:", x.shape)
+            inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt,
+                               use_spatial_condition=spatial_condition,
+                               share_position_embedding=share_position_embedding,
+                               use_share_weight_referencenet=self.use_share_weight_referencenet)
+            # print("input img noise shape:", inp_cond['img'].shape)
+            neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt,
+                                   use_spatial_condition=spatial_condition,
+                                   share_position_embedding=share_position_embedding,
+                                   use_share_weight_referencenet=self.use_share_weight_referencenet)
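+            # when spatial conditioning / shared-weight ReferenceNet is active, the source image
+            # is AE-encoded and packed into the same 2x2 token layout as the latent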
+ if spatial_condition or self.use_share_weight_referencenet:
+ # TODO here:
+ source_image = self.ae.encode(source_image.to(self.device).to(torch.float32))
+ # print("ae source image shape:", source_image.shape)
+                source_image = rearrange(
+                    source_image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
+                ).to(inp_cond['img'].dtype)
+
+ # print("rearrange ae source image shape:", source_image.shape)
+
+ if self.offload:
+ self.offload_model_to_cpu(self.t5, self.clip)
+ self.model = self.model.to(self.device)
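+            # route through the ControlNet-conditioned denoiser when a ControlNet is attached,
+            # otherwise use the plain denoiser (with optional IP-Adapter / ReferenceNet inputs)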
+ if self.controlnet_loaded:
+ x = denoise_controlnet(
+ self.model,
+ img=inp_cond['img'],
+ img_ids=inp_cond['img_ids'],
+ txt=inp_cond['txt'],
+ txt_ids=inp_cond['txt_ids'],
+ vec=inp_cond['vec'],
+ controlnet=self.controlnet,
+ timesteps=timesteps,
+ guidance=guidance,
+ controlnet_cond=controlnet_image,
+ timestep_to_start_cfg=timestep_to_start_cfg,
+ neg_txt=neg_inp_cond['txt'],
+ neg_txt_ids=neg_inp_cond['txt_ids'],
+ neg_vec=neg_inp_cond['vec'],
+ true_gs=true_gs,
+ controlnet_gs=control_weight,
+ image_proj=image_proj,
+ neg_image_proj=neg_image_proj,
+ ip_scale=ip_scale,
+ neg_ip_scale=neg_ip_scale,
+ )
+ else:
+ x = denoise(
+ self.model,
+ img=inp_cond['img'],
+ img_ids=inp_cond['img_ids'],
+ txt=inp_cond['txt'],
+ txt_ids=inp_cond['txt_ids'],
+ vec=inp_cond['vec'],
+ timesteps=timesteps,
+ guidance=guidance,
+ timestep_to_start_cfg=timestep_to_start_cfg,
+ neg_txt=neg_inp_cond['txt'],
+ neg_txt_ids=neg_inp_cond['txt_ids'],
+ neg_vec=neg_inp_cond['vec'],
+ true_gs=true_gs,
+ image_proj=image_proj,
+ neg_image_proj=neg_image_proj,
+ ip_scale=ip_scale,
+ neg_ip_scale=neg_ip_scale,
+ source_image=source_image, # spatial_condition source image
+ use_share_weight_referencenet=self.use_share_weight_referencenet,
+ single_img_ids=inp_cond['single_img_ids'] if self.use_share_weight_referencenet else None,
+ neg_single_img_ids=neg_inp_cond['single_img_ids'] if self.use_share_weight_referencenet else None,
+ single_block_refnet=self.single_block_refnet,
+ double_block_refnet=self.double_block_refnet,
+ )
+
+ if self.offload:
+ self.offload_model_to_cpu(self.model)
+ self.ae.decoder.to(x.device)
+ x = unpack(x.float(), height, width)
+ x = self.ae.decode(x)
+ self.offload_model_to_cpu(self.ae.decoder)
+
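+        # map the decoded image from [-1, 1] to an 8-bit RGB PIL image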
+ x1 = x.clamp(-1, 1)
+ x1 = rearrange(x1[-1], "c h w -> h w c")
+ output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
+ return output_img
+
+    def offload_model_to_cpu(self, *models):
+        if not self.offload:
+            return
+        for model in models:
+            model.cpu()
+            torch.cuda.empty_cache()
+
+
+class XFluxSampler(XFluxPipeline):
+    def __init__(self, clip, t5, ae, model, device, controlnet_loaded=False, ip_loaded=False,
+                 spatial_condition=False, offload=False, clip_image_processor=None,
+                 image_encoder=None, improj=None, share_position_embedding=False,
+                 use_share_weight_referencenet=False, single_block_refnet=False,
+                 double_block_refnet=False):
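+        # reuse components that were already loaded elsewhere instead of calling the parent
+        # constructor, which would reload the text encoders, AE and flow model from scratch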
+ self.clip = clip
+ self.t5 = t5
+ self.ae = ae
+ self.model = model
+ self.model.eval()
+ self.device = device
+ self.controlnet_loaded = controlnet_loaded
+ self.ip_loaded = ip_loaded
+ self.offload = offload
+ self.clip_image_processor = clip_image_processor
+ self.image_encoder = image_encoder
+ self.improj = improj
+ self.spatial_condition = spatial_condition
+ self.share_position_embedding = share_position_embedding
+ self.use_share_weight_referencenet = use_share_weight_referencenet
+ self.single_block_refnet = single_block_refnet
+ self.double_block_refnet = double_block_refnet
\ No newline at end of file