xiaoyuxi committed
Commit cd14f82 · 1 Parent(s): 54c1d7b

support HubMixin

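Editor's note: this commit moves model loading onto huggingface_hub's PyTorchModelHubMixin. The front-end model is now fetched with VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front") and the tracker with Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline"), replacing the previous manual hf_hub_download + load_state_dict path. As a rough sketch of the pattern (generic class and repo names, not the project's own), any nn.Module that inherits the mixin gains save_pretrained / from_pretrained / push_to_hub:

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class TinyModel(nn.Module, PyTorchModelHubMixin):
    # The mixin serializes the __init__ kwargs to config.json, so they must be JSON-friendly.
    def __init__(self, hidden_dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = TinyModel(hidden_dim=32)
model.save_pretrained("tiny-model")                  # writes config.json + weights to a local folder
reloaded = TinyModel.from_pretrained("tiny-model")   # rebuilds the module with the stored kwargs
# reloaded = TinyModel.from_pretrained("username/tiny-model")  # hypothetical Hub repo id works the same way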
.gitignore CHANGED
@@ -23,47 +23,30 @@ __pycache__/
 /**/**/__pycache__
 /**/__pycache__
 
-outputs
-scripts/lauch_exp/config
-scripts/lauch_exp/submit_job.log
-scripts/lauch_exp/hydra_output
-scripts/lauch_wulan
-scripts/custom_video
 # ignore the visualizer
 viser
 viser_result
 benchmark/results
 benchmark
 
-ossutil_output
-
 prev_version
 spat_ceres
 wandb
 *.log
 seg_target.py
 
-eval_davis.py
-eval_multiple_gpu.py
-eval_pose_scan.py
-eval_single_gpu.py
-
 infer_cam.py
 infer_stream.py
 
 *.egg-info/
 **/*.egg-info
 
-eval_kinectics.py
-models/SpaTrackV2/datasets
 
-scripts
 config/fix_2d.yaml
 
-models/SpaTrackV2/datasets
-scripts/
 
 models/**/build
 models/**/dist
 
-temp_local
+temp_local
+examples/results
README.md CHANGED
@@ -11,4 +11,4 @@ license: mit
 short_description: Official Space for SpatialTrackerV2
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
_viz/viz_template.html CHANGED
@@ -671,6 +671,38 @@
             </div>
         </div>
 
+        <div class="settings-group">
+            <h3>Keep History</h3>
+            <div class="checkbox-container">
+                <label class="toggle-switch">
+                    <input type="checkbox" id="enable-keep-history">
+                    <span class="toggle-slider"></span>
+                </label>
+                <label for="enable-keep-history">Enable Keep History</label>
+            </div>
+            <div class="slider-container">
+                <label for="history-stride">Stride</label>
+                <select id="history-stride">
+                    <option value="1">1</option>
+                    <option value="2">2</option>
+                    <option value="5" selected>5</option>
+                    <option value="10">10</option>
+                    <option value="20">20</option>
+                </select>
+            </div>
+        </div>
+
+        <div class="settings-group">
+            <h3>Background</h3>
+            <div class="checkbox-container">
+                <label class="toggle-switch">
+                    <input type="checkbox" id="white-background">
+                    <span class="toggle-slider"></span>
+                </label>
+                <label for="white-background">White Background</label>
+            </div>
+        </div>
+
         <div class="settings-group">
             <div class="btn-group">
                 <button id="reset-view-btn" style="flex: 1; margin-right: 5px;">Reset View</button>
@@ -739,7 +771,10 @@
             showCameraFrustum: document.getElementById('show-camera-frustum'),
             frustumSize: document.getElementById('frustum-size'),
             hideSettingsBtn: document.getElementById('hide-settings-btn'),
-            showSettingsBtn: document.getElementById('show-settings-btn')
+            showSettingsBtn: document.getElementById('show-settings-btn'),
+            enableKeepHistory: document.getElementById('enable-keep-history'),
+            historyStride: document.getElementById('history-stride'),
+            whiteBackground: document.getElementById('white-background')
         };
 
         this.scene = null;
@@ -750,6 +785,12 @@
         this.trajectories = [];
        this.cameraFrustum = null;
 
+        // Keep History functionality
+        this.historyPointClouds = [];
+        this.historyTrajectories = [];
+        this.historyFrames = [];
+        this.maxHistoryFrames = 20;
+
         this.initThreeJS();
         this.loadDefaultSettings().then(() => {
             this.initEventListeners();
@@ -977,6 +1018,28 @@
                 this.ui.showSettingsBtn.style.display = 'none';
             });
         }
+
+        // Keep History event listeners
+        if (this.ui.enableKeepHistory) {
+            this.ui.enableKeepHistory.addEventListener('change', () => {
+                if (!this.ui.enableKeepHistory.checked) {
+                    this.clearHistory();
+                }
+            });
+        }
+
+        if (this.ui.historyStride) {
+            this.ui.historyStride.addEventListener('change', () => {
+                this.clearHistory();
+            });
+        }
+
+        // Background toggle event listener
+        if (this.ui.whiteBackground) {
+            this.ui.whiteBackground.addEventListener('change', () => {
+                this.toggleBackground();
+            });
+        }
     }
 
     makeElementDraggable(element) {
@@ -1296,6 +1359,9 @@
 
         this.updateTrajectories(frameIndex);
 
+        // Keep History management
+        this.updateHistory(frameIndex);
+
         const progress = (frameIndex + 1) / this.config.totalFrames;
         this.ui.progress.style.width = `${progress * 100}%`;
 
@@ -1752,15 +1818,286 @@
         this.updateCameraFrustum(this.currentFrame);
    }
 
+    // Keep History methods
+    updateHistory(frameIndex) {
+        if (!this.ui.enableKeepHistory.checked || !this.data) return;
+
+        const stride = parseInt(this.ui.historyStride.value);
+        const newHistoryFrames = this.calculateHistoryFrames(frameIndex, stride);
+
+        // Check if history frames changed
+        if (this.arraysEqual(this.historyFrames, newHistoryFrames)) return;
+
+        this.clearHistory();
+        this.historyFrames = newHistoryFrames;
+
+        // Create history point clouds and trajectories
+        this.historyFrames.forEach(historyFrame => {
+            if (historyFrame !== frameIndex) {
+                this.createHistoryPointCloud(historyFrame);
+                this.createHistoryTrajectories(historyFrame);
+            }
+        });
+    }
+
+    calculateHistoryFrames(currentFrame, stride) {
+        const frames = [];
+        let frame = 1; // Start from frame 1
+
+        while (frame <= currentFrame && frames.length < this.maxHistoryFrames) {
+            frames.push(frame);
+            frame += stride;
+        }
+
+        // Always include current frame
+        if (!frames.includes(currentFrame)) {
+            frames.push(currentFrame);
+        }
+
+        return frames.sort((a, b) => a - b);
+    }
+
+    createHistoryPointCloud(frameIndex) {
+        const numPoints = this.config.resolution[0] * this.config.resolution[1];
+        const positions = new Float32Array(numPoints * 3);
+        const colors = new Float32Array(numPoints * 3);
+
+        const geometry = new THREE.BufferGeometry();
+        geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3));
+        geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3));
+
+        const material = new THREE.PointsMaterial({
+            size: parseFloat(this.ui.pointSize.value),
+            vertexColors: true,
+            transparent: true,
+            opacity: 0.5, // Transparent for history
+            sizeAttenuation: true
+        });
+
+        const historyPointCloud = new THREE.Points(geometry, material);
+        this.scene.add(historyPointCloud);
+        this.historyPointClouds.push(historyPointCloud);
+
+        // Update the history point cloud with data
+        this.updateHistoryPointCloud(historyPointCloud, frameIndex);
+    }
+
+    updateHistoryPointCloud(pointCloud, frameIndex) {
+        const positions = pointCloud.geometry.attributes.position.array;
+        const colors = pointCloud.geometry.attributes.color.array;
+
+        const rgbVideo = this.data.rgb_video;
+        const depthsRgb = this.data.depths_rgb;
+        const intrinsics = this.data.intrinsics;
+        const invExtrinsics = this.data.inv_extrinsics;
+
+        const width = this.config.resolution[0];
+        const height = this.config.resolution[1];
+        const numPoints = width * height;
+
+        const K = this.get3x3Matrix(intrinsics.data, intrinsics.shape, frameIndex);
+        const fx = K[0][0], fy = K[1][1], cx = K[0][2], cy = K[1][2];
+
+        const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
+        const transform = this.getTransformElements(invExtrMat);
+
+        const rgbFrame = this.getFrame(rgbVideo.data, rgbVideo.shape, frameIndex);
+        const depthFrame = this.getFrame(depthsRgb.data, depthsRgb.shape, frameIndex);
+
+        const maxDepth = parseFloat(this.ui.maxDepth.value) || 10.0;
+
+        let validPointCount = 0;
+
+        for (let i = 0; i < numPoints; i++) {
+            const xPix = i % width;
+            const yPix = Math.floor(i / width);
+
+            const d0 = depthFrame[i * 3];
+            const d1 = depthFrame[i * 3 + 1];
+            const depthEncoded = d0 | (d1 << 8);
+            const depthValue = (depthEncoded / ((1 << 16) - 1)) *
+                (this.config.depthRange[1] - this.config.depthRange[0]) +
+                this.config.depthRange[0];
+
+            if (depthValue === 0 || depthValue > maxDepth) {
+                continue;
+            }
+
+            const X = ((xPix - cx) * depthValue) / fx;
+            const Y = ((yPix - cy) * depthValue) / fy;
+            const Z = depthValue;
+
+            const tx = transform.m11 * X + transform.m12 * Y + transform.m13 * Z + transform.m14;
+            const ty = transform.m21 * X + transform.m22 * Y + transform.m23 * Z + transform.m24;
+            const tz = transform.m31 * X + transform.m32 * Y + transform.m33 * Z + transform.m34;
+
+            const index = validPointCount * 3;
+            positions[index] = tx;
+            positions[index + 1] = -ty;
+            positions[index + 2] = -tz;
+
+            colors[index] = rgbFrame[i * 3] / 255;
+            colors[index + 1] = rgbFrame[i * 3 + 1] / 255;
+            colors[index + 2] = rgbFrame[i * 3 + 2] / 255;
+
+            validPointCount++;
+        }
+
+        pointCloud.geometry.setDrawRange(0, validPointCount);
+        pointCloud.geometry.attributes.position.needsUpdate = true;
+        pointCloud.geometry.attributes.color.needsUpdate = true;
+    }
+
+    createHistoryTrajectories(frameIndex) {
+        if (!this.data.trajectories) return;
+
+        const trajectoryData = this.data.trajectories.data;
+        const [totalFrames, numTrajectories] = this.data.trajectories.shape;
+        const palette = this.createColorPalette(numTrajectories);
+
+        const historyTrajectoryGroup = new THREE.Group();
+
+        for (let i = 0; i < numTrajectories; i++) {
+            const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
+            const sphereGeometry = new THREE.SphereGeometry(ballSize, 16, 16);
+            const sphereMaterial = new THREE.MeshBasicMaterial({
+                color: palette[i],
+                transparent: true,
+                opacity: 0.3 // Transparent for history
+            });
+            const positionMarker = new THREE.Mesh(sphereGeometry, sphereMaterial);
+
+            const currentOffset = (frameIndex * numTrajectories + i) * 3;
+            positionMarker.position.set(
+                trajectoryData[currentOffset],
+                -trajectoryData[currentOffset + 1],
+                -trajectoryData[currentOffset + 2]
+            );
+
+            historyTrajectoryGroup.add(positionMarker);
+        }
+
+        this.scene.add(historyTrajectoryGroup);
+        this.historyTrajectories.push(historyTrajectoryGroup);
+    }
+
+    clearHistory() {
+        // Clear history point clouds
+        this.historyPointClouds.forEach(pointCloud => {
+            if (pointCloud.geometry) pointCloud.geometry.dispose();
+            if (pointCloud.material) pointCloud.material.dispose();
+            this.scene.remove(pointCloud);
+        });
+        this.historyPointClouds = [];
+
+        // Clear history trajectories
+        this.historyTrajectories.forEach(trajectoryGroup => {
+            trajectoryGroup.children.forEach(child => {
+                if (child.geometry) child.geometry.dispose();
+                if (child.material) child.material.dispose();
+            });
+            this.scene.remove(trajectoryGroup);
+        });
+        this.historyTrajectories = [];
+
+        this.historyFrames = [];
+    }
+
+    arraysEqual(a, b) {
+        if (a.length !== b.length) return false;
+        for (let i = 0; i < a.length; i++) {
+            if (a[i] !== b[i]) return false;
+        }
+        return true;
+    }
+
+    toggleBackground() {
+        const isWhiteBackground = this.ui.whiteBackground.checked;
+
+        if (isWhiteBackground) {
+            // Switch to white background
+            document.body.style.backgroundColor = '#ffffff';
+            this.scene.background = new THREE.Color(0xffffff);
+
+            // Update UI elements for white background
+            document.documentElement.style.setProperty('--bg', '#ffffff');
+            document.documentElement.style.setProperty('--text', '#333333');
+            document.documentElement.style.setProperty('--text-secondary', '#666666');
+            document.documentElement.style.setProperty('--border', '#cccccc');
+            document.documentElement.style.setProperty('--surface', '#f5f5f5');
+            document.documentElement.style.setProperty('--shadow', 'rgba(0, 0, 0, 0.1)');
+            document.documentElement.style.setProperty('--shadow-hover', 'rgba(0, 0, 0, 0.2)');
+
+            // Update status bar and control panel backgrounds
+            this.ui.statusBar.style.background = 'rgba(245, 245, 245, 0.9)';
+            this.ui.statusBar.style.color = '#333333';
+
+            const controlPanel = document.getElementById('control-panel');
+            if (controlPanel) {
+                controlPanel.style.background = 'rgba(245, 245, 245, 0.95)';
+            }
+
+            const settingsPanel = document.getElementById('settings-panel');
+            if (settingsPanel) {
+                settingsPanel.style.background = 'rgba(245, 245, 245, 0.98)';
+            }
+
+        } else {
+            // Switch back to dark background
+            document.body.style.backgroundColor = '#1a1a1a';
+            this.scene.background = new THREE.Color(0x1a1a1a);
+
+            // Restore original dark theme variables
+            document.documentElement.style.setProperty('--bg', '#1a1a1a');
+            document.documentElement.style.setProperty('--text', '#e0e0e0');
+            document.documentElement.style.setProperty('--text-secondary', '#a0a0a0');
+            document.documentElement.style.setProperty('--border', '#444444');
+            document.documentElement.style.setProperty('--surface', '#2c2c2c');
+            document.documentElement.style.setProperty('--shadow', 'rgba(0, 0, 0, 0.2)');
+            document.documentElement.style.setProperty('--shadow-hover', 'rgba(0, 0, 0, 0.3)');
+
+            // Restore original UI backgrounds
+            this.ui.statusBar.style.background = 'rgba(30, 30, 30, 0.9)';
+            this.ui.statusBar.style.color = '#e0e0e0';
+
+            const controlPanel = document.getElementById('control-panel');
+            if (controlPanel) {
+                controlPanel.style.background = 'rgba(44, 44, 44, 0.95)';
+            }
+
+            const settingsPanel = document.getElementById('settings-panel');
+            if (settingsPanel) {
+                settingsPanel.style.background = 'rgba(44, 44, 44, 0.98)';
+            }
+        }
+
+        // Show status message
+        this.ui.statusBar.textContent = isWhiteBackground ? "Switched to white background" : "Switched to dark background";
+        this.ui.statusBar.classList.remove('hidden');
+
+        setTimeout(() => {
+            this.ui.statusBar.classList.add('hidden');
+        }, 2000);
+    }
+
     resetSettings() {
         if (!this.defaultSettings) return;
 
         this.applyDefaultSettings();
 
+        // Reset background to dark theme
+        if (this.ui.whiteBackground) {
+            this.ui.whiteBackground.checked = false;
+            this.toggleBackground();
+        }
+
         this.updatePointCloudSettings();
         this.updateTrajectorySettings();
         this.updateFrustumDimensions();
 
+        // Clear history when resetting settings
+        this.clearHistory();
+
         this.ui.statusBar.textContent = "Settings reset to defaults";
         this.ui.statusBar.classList.remove('hidden');
 
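Editor's note: the Keep History toggle added above keeps a strided subset of earlier frames in the scene as half-transparent point clouds and trajectory markers, capped at maxHistoryFrames (20). The selection rule in calculateHistoryFrames is: walk from frame 1 in steps of the chosen stride up to the current frame, then make sure the current frame itself is included. A small Python re-expression of that rule, for illustration only (the viewer implements it in JavaScript):

def history_frames(current_frame, stride, max_history=20):
    # Mirrors calculateHistoryFrames: strided frames starting at 1, current frame always kept.
    frames = []
    frame = 1
    while frame <= current_frame and len(frames) < max_history:
        frames.append(frame)
        frame += stride
    if current_frame not in frames:
        frames.append(current_frame)
    return sorted(frames)

print(history_frames(23, stride=5))  # [1, 6, 11, 16, 21, 23]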
 
app.py CHANGED
@@ -26,6 +26,9 @@ import logging
 from concurrent.futures import ThreadPoolExecutor
 import atexit
 import uuid
+from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
+from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
+from models.SpaTrackV2.models.predictor import Predictor
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -78,20 +81,15 @@ def create_user_temp_dir():
     return temp_dir
 
 from huggingface_hub import hf_hub_download
-# init the model
-os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
 
-if os.environ.get("VGGT_DIR", None) is not None:
-    from models.vggt.vggt.models.vggt_moe import VGGT_MoE
-    from models.vggt.vggt.utils.load_fn import preprocess_image
-    vggt_model = VGGT_MoE()
-    vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")), strict=False)
-    vggt_model.eval()
-    vggt_model = vggt_model.to("cuda")
+vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
+vggt4track_model.eval()
+vggt4track_model = vggt4track_model.to("cuda")
 
 # Global model initialization
 print("🚀 Initializing local models...")
-tracker_model, _ = get_tracker_predictor(".", vo_points=756)
+tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
+tracker_model.eval()
 predictor = get_sam_predictor()
 print("✅ Models loaded successfully!")
 
@@ -131,7 +129,8 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
         print("Initializing tracker models inside GPU function...")
         out_dir = os.path.join(temp_dir, "results")
         os.makedirs(out_dir, exist_ok=True)
-        tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, tracker_model=tracker_model)
+        tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
+                                                                     tracker_model=tracker_model.cuda())
 
     # Setup paths
     video_path = os.path.join(temp_dir, f"{video_name}.mp4")
@@ -161,25 +160,23 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
     data_npz_load = {}
 
     # run vggt
-    if os.environ.get("VGGT_DIR", None) is not None:
-        # process the image tensor
-        video_tensor = preprocess_image(video_tensor)[None]
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                # Predict attributes including cameras, depth maps, and point maps.
-                predictions = vggt_model(video_tensor.cuda()/255)
-                extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
-                depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
-
-        depth_tensor = depth_map.squeeze().cpu().numpy()
-        extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
-        extrs = extrinsic.squeeze().cpu().numpy()
-        intrs = intrinsic.squeeze().cpu().numpy()
-        video_tensor = video_tensor.squeeze()
-        #NOTE: 20% of the depth is not reliable
-        # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
-        unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
+    # process the image tensor
+    video_tensor = preprocess_image(video_tensor)[None]
+    with torch.no_grad():
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            # Predict attributes including cameras, depth maps, and point maps.
+            predictions = vggt4track_model(video_tensor.cuda()/255)
+            extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
+            depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
 
+    depth_tensor = depth_map.squeeze().cpu().numpy()
+    extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
+    extrs = extrinsic.squeeze().cpu().numpy()
+    intrs = intrinsic.squeeze().cpu().numpy()
+    video_tensor = video_tensor.squeeze()
+    #NOTE: 20% of the depth is not reliable
+    # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
+    unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
     # Load and process mask
     if os.path.exists(mask_path):
         mask = cv2.imread(mask_path)
@@ -201,7 +198,6 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
 
     query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
     print(f"Query points shape: {query_xyt.shape}")
-
     # Run model inference
     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
         (
@@ -212,8 +208,8 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
             queries=query_xyt,
             fps=1, full_point=False, iters_track=4,
             query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
-            support_frame=len(video_tensor)-1, replace_ratio=0.2)
-
+            support_frame=len(video_tensor)-1, replace_ratio=0.2)
+
     # Resize results to avoid large I/O
     max_size = 224
     h, w = video.shape[2:]
@@ -1117,7 +1113,7 @@ if __name__ == "__main__":
     demo.launch(
         server_name="0.0.0.0",
         server_port=7860,
-        share=True,
+        share=False,
        debug=True,
        show_error=True
    )
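Editor's note: with the Hub-backed loading, app.py no longer sets VGGT_DIR or builds the tracker at import time via get_tracker_predictor; both models are pulled by repo id, and the per-request call only wires the already-loaded tracker into the viser helper. A condensed sketch of the new startup path, assuming the repo IDs used in this commit and an available CUDA device:

from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
from models.SpaTrackV2.models.predictor import Predictor

# Front-end model: camera poses, depth and confidence for each frame.
vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
vggt4track_model.eval()
vggt4track_model = vggt4track_model.to("cuda")

# Offline 3D tracker; it is moved to the GPU later, inside gpu_run_tracker.
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
tracker_model.eval()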
app_3rd/spatrack_utils/infer_track.py CHANGED
@@ -20,7 +20,7 @@ from huggingface_hub import hf_hub_download
 
 config = {
     "ckpt_dir": "Yuxihenry/SpatialTrackerCkpts",  # HuggingFace repo ID
-    "cfg_dir": "config/magic_infer_moge.yaml",
+    "cfg_dir": "config/magic_infer_offline.yaml",
 }
 
 def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=None):
models/SpaTrackV2/models/SpaTrack.py CHANGED
@@ -40,6 +40,7 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
         resolution=518,
         max_len=600, # the maximum video length we can preprocess,
         track_num=768,
+        moge_as_base=False,
     ):
 
         self.chunk_size = chunk_size
@@ -51,26 +52,29 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
         backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None)
 
         super(SpaTrack2, self).__init__()
-        if os.path.exists(backbone_ckpt_dir)==False:
-            base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
+        if moge_as_base:
+            if os.path.exists(backbone_ckpt_dir)==False:
+                base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
+            else:
+                checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
+                base_model = MoGeModel(**checkpoint["model_config"])
+                base_model.load_state_dict(checkpoint['model'])
         else:
-            checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
-            base_model = MoGeModel(**checkpoint["model_config"])
-            base_model.load_state_dict(checkpoint['model'])
+            base_model = None
         # avoid the base_model is a member of SpaTrack2
         object.__setattr__(self, 'base_model', base_model)
 
         # Tracker model
         self.Track3D = TrackRefiner3D(Track_cfg)
-        track_base_ckpt_dir = Track_cfg.base_ckpt
+        track_base_ckpt_dir = Track_cfg["base_ckpt"]
         if os.path.exists(track_base_ckpt_dir):
             track_pretrain = torch.load(track_base_ckpt_dir)
             self.Track3D.load_state_dict(track_pretrain, strict=False)
 
         # wrap the function of make lora trainable
         self.make_paras_trainable = partial(self.make_paras_trainable,
-                                            mode=ft_cfg.mode,
-                                            paras_name=ft_cfg.paras_name)
+                                            mode=ft_cfg["mode"],
+                                            paras_name=ft_cfg["paras_name"])
         self.track_num = track_num
 
     def make_paras_trainable(self, mode: str = 'fix', paras_name: List[str] = []):
@@ -145,7 +149,7 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
     ):
         # step 1 allocate the query points on the grid
         T, C, H, W = video.shape
-
+
         if annots_train is not None:
             vis_gt = annots_train["vis"]
             _, _, N = vis_gt.shape
@@ -296,39 +300,6 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
                                 **kwargs, annots=annots)
             if self.training:
                 loss += out["loss"].squeeze()
-            # from models.SpaTrackV2.utils.visualizer import Visualizer
-            # vis_track = Visualizer(grayscale=False,
-            #                        fps=10, pad_value=50, tracks_leave_trace=0)
-            # vis_track.visualize(video=segment,
-            #                     tracks=out["traj_est"][...,:2],
-            #                     visibility=out["vis_est"],
-            #                     save_video=True)
-            # # visualize 4d
-            # import os, json
-            # import os.path as osp
-            # viser4d_dir = os.path.join("viser_4d_results")
-            # os.makedirs(viser4d_dir, exist_ok=True)
-            # depth_est = annots["depth_gt"][0]
-            # unc_metric = out["unc_metric"]
-            # mask = (unc_metric > 0.5).squeeze(1)
-            # # pose_est = out["poses_pred"].squeeze(0)
-            # pose_est = annots["traj_mat"][0]
-            # rgb_tracks = out["rgb_tracks"].squeeze(0)
-            # intrinsics = out["intrs"].squeeze(0)
-            # for i_k in range(out["depth"].shape[0]):
-            #     img_i = out["imgs_raw"][0][i_k].permute(1, 2, 0).cpu().numpy()
-            #     img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
-            #     cv2.imwrite(osp.join(viser4d_dir, f'frame_{i_k:04d}.png'), img_i)
-            #     if stage == 1:
-            #         depth = depth_est[i_k].squeeze().cpu().numpy()
-            #         np.save(osp.join(viser4d_dir, f'frame_{i_k:04d}.npy'), depth)
-            #     else:
-            #         point_map_vis = out["points_map"][i_k].cpu().numpy()
-            #         np.save(osp.join(viser4d_dir, f'point_{i_k:04d}.npy'), point_map_vis)
-            #     np.save(os.path.join(viser4d_dir, f'intrinsics.npy'), intrinsics.cpu().numpy())
-            #     np.save(os.path.join(viser4d_dir, f'extrinsics.npy'), pose_est.cpu().numpy())
-            #     np.save(os.path.join(viser4d_dir, f'conf.npy'), mask.float().cpu().numpy())
-            #     np.save(os.path.join(viser4d_dir, f'colored_track3d.npy'), rgb_tracks.cpu().numpy())
 
             queries_len = len(queries_new)
             # update the track3d and track2d
@@ -720,40 +691,3 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
         }
 
         return ret
-
-
-
-
-    # three stages of training
-
-    # stage 1:
-    #   gt depth and intrinsics synthetic (includes Dynamic Replica, Kubric, Pointodyssey, Vkitti, TartanAir and Indoor() ) Motion Patern (tapvid3d)
-    #   Tracking and Pose as well -> based on gt depth and intrinsics
-    #   (Finished) -> (megasam + base model) vs. tapip3d. (use depth from megasam or pose, which keep the same setting as tapip3d.)
-
-    # stage 2: fixed 3D tracking
-    #   Joint depth refiner
-    #   input depth from whatever + rgb -> temporal module + scale and shift token -> coarse alignment -> scale and shift
-    #   estimate the 3D tracks -> 3D tracks combine with pointmap -> update for pointmap (iteratively) -> residual map B T 3 H W
-    #   ongoing two days
-
-    # stage 3: train multi windows by propagation
-    #   4 frames overlapped -> train on 64 -> fozen image encoder and finetuning the transformer (learnable parameters pretty small)
-
-    # types of scenarioes:
-    #   1. auto driving (waymo open dataset)
-    #   2. robot
-    #   3. internet ego video
-
-
-
-    # Iterative Transformer -- Solver -- General Neural MegaSAM + Tracks
-    # Update Variables:
-    #   1. 3D tracks B T N 3 xyz.
-    #   2. 2D tracks B T N 2 x y.
-    #   3. Dynamic Mask B T H W.
-    #   4. Camera Pose B T 4 4.
-    #   5. Video Depth.
-
-    # (RGB, RGBD, RGBD+Pose) x (Static, Dynamic)
-    # Campatiablity by product.
models/SpaTrackV2/models/predictor.py CHANGED
@@ -16,80 +16,20 @@ from typing import Union, Optional
 import cv2
 import os
 import decord
+from huggingface_hub import PyTorchModelHubMixin  # used for model hub
 
-class Predictor(torch.nn.Module):
+class Predictor(torch.nn.Module, PyTorchModelHubMixin):
     def __init__(self, args=None):
         super().__init__()
         self.args = args
         self.spatrack = SpaTrack2(loggers=[None, None, None], **args)
-        self.S_wind = args.Track_cfg.s_wind
-        self.overlap = args.Track_cfg.overlap
+        self.S_wind = args["Track_cfg"]["s_wind"]
+        self.overlap = args["Track_cfg"]["overlap"]
 
     def to(self, device: Union[str, torch.device]):
         self.spatrack.to(device)
-        self.spatrack.base_model.to(device)
+        if self.spatrack.base_model is not None:
+            self.spatrack.base_model.to(device)
-
-    @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: Union[str, Path],
-        *,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        device: Optional[Union[str, torch.device]] = None,
-        model_cfg: Optional[dict] = None,
-        **kwargs,
-    ) -> "SpaTrack2":
-        """
-        Load a pretrained model from a local file or a remote repository.
-
-        Args:
-            pretrained_model_name_or_path (str or Path):
-                - Path to a local model file (e.g., `./model.pth`).
-                - HuggingFace Hub model ID (e.g., `username/model-name`).
-            force_download (bool, optional):
-                Whether to force re-download even if cached. Default: False.
-            cache_dir (str, optional):
-                Custom cache directory. Default: None (use default cache).
-            device (str or torch.device, optional):
-                Target device (e.g., "cuda", "cpu"). Default: None (keep original).
-            **kwargs:
-                Additional config overrides.
-
-        Returns:
-            SpaTrack2: Loaded pretrained model.
-        """
-        # (1) check the path is local or remote
-        if isinstance(pretrained_model_name_or_path, Path):
-            model_path = str(pretrained_model_name_or_path)
-        else:
-            model_path = pretrained_model_name_or_path
-        # (2) if the path is remote, download it
-        if not os.path.exists(model_path):
-            raise NotImplementedError("Remote download not implemented yet. Use a local path.")
-        # (3) load the model weights
-
-        state_dict = torch.load(model_path, map_location="cpu")
-        # (4) initialize the model (can load config.json if exists)
-        config_path = os.path.join(os.path.dirname(model_path), "config.json")
-        config = {}
-        if os.path.exists(config_path):
-            import json
-            with open(config_path, "r") as f:
-                config.update(json.load(f))
-        config.update(kwargs)  # allow override the config
-        if model_cfg is not None:
-            config = model_cfg
-        model = cls(config)
-        if "model" in state_dict:
-            model.spatrack.load_state_dict(state_dict["model"], strict=False)
-        else:
-            model.spatrack.load_state_dict(state_dict, strict=False)
-        # (5) device management
-        if device is not None:
-            model.to(device)
-
-        return model
 
     def forward(self, video: str|torch.Tensor|np.ndarray,
                 depth: str|torch.Tensor|np.ndarray=None,
@@ -145,7 +85,6 @@ class Predictor(torch.nn.Module):
             window_len=self.S_wind, overlap_len=self.overlap, track2d_gt=track2d_gt, full_point=full_point, iters_track=iters_track,
             fixed_cam=fixed_cam, query_no_BA=query_no_BA, stage=stage, support_frame=support_frame, replace_ratio=replace_ratio) + (video[:T_],)
 
-
        return ret
 
 
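Editor's note: the hand-written Predictor.from_pretrained (local paths only, with its own config.json handling) is deleted; inheriting PyTorchModelHubMixin provides an equivalent loader that also covers Hub downloads and caching. Loading and device placement then look roughly like this (a sketch; the repo id is the one used in app.py, and the overridden .to() now tolerates a missing MoGe backbone):

import torch
from models.SpaTrackV2.models.predictor import Predictor

tracker = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
tracker.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
tracker.to(device)  # skips base_model when it is None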
models/SpaTrackV2/models/tracker3D/TrackRefiner.py CHANGED
@@ -24,14 +24,13 @@ from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import (
 )
 from models.SpaTrackV2.models.tracker3D.spatrack_modules.ba import extract_static_from_3DTracks, ba_pycolmap
 from models.SpaTrackV2.models.tracker3D.spatrack_modules.pointmap_updator import PointMapUpdator
-from models.SpaTrackV2.models.depth_refiner.depth_refiner import TrackStablizer
 from models.SpaTrackV2.models.tracker3D.spatrack_modules.alignment import affine_invariant_global_loss
 from models.SpaTrackV2.models.tracker3D.delta_utils.upsample_transformer import UpsampleTransformerAlibi
 
 class TrackRefiner3D(CoTrackerThreeOffline):
 
     def __init__(self, args=None):
-        super().__init__(**args.base)
+        super().__init__(**args["base"])
 
         """
         This is 3D warpper from cotracker, which load the cotracker pretrain and
@@ -47,15 +46,7 @@ class TrackRefiner3D(CoTrackerThreeOffline):
         self.proj_xyz_embed = Mlp(in_features=1210+50, hidden_features=1110, out_features=1110)
         # get the anchor point's embedding, and init the pts refiner
         update_pts = True
-        # self.corr_transformer = nn.ModuleList([
-        #     CorrPointformer(
-        #         dim=128,
-        #         num_heads=8,
-        #         head_dim=128 // 8,
-        #         mlp_ratio=4.0,
-        #     )
-        #     for _ in range(self.corr_levels)
-        # ])
+
         self.corr_transformer = nn.ModuleList([
             CorrPointformer(
                 dim=128,
@@ -68,29 +59,11 @@ class TrackRefiner3D(CoTrackerThreeOffline):
         self.fnet = BasicEncoder(input_dim=3,
             output_dim=self.latent_dim, stride=self.stride)
         self.corr3d_radius = 3
-
-        if args.stablizer:
-            self.scale_shift_tokens = nn.Parameter(torch.randn(1, 2, self.latent_dim, requires_grad=True))
-            self.upsample_kernel_size = 5
-            self.residual_embedding = nn.Parameter(torch.randn(
-                self.latent_dim, self.model_resolution[0]//16,
-                self.model_resolution[1]//16, requires_grad=True))
-            self.dense_mlp = nn.Conv2d(2*self.latent_dim+63, self.latent_dim, kernel_size=1, stride=1, padding=0)
-            self.upsample_factor = 4
-            self.upsample_transformer = UpsampleTransformerAlibi(
-                kernel_size=self.upsample_kernel_size, # kernel_size=3, #
-                stride=self.stride,
-                latent_dim=self.latent_dim,
-                num_attn_blocks=2,
-                upsample_factor=4,
-            )
-        else:
-            self.update_pointmap = None
 
-        self.mode = args.mode
+        self.mode = args["mode"]
         if self.mode == "online":
-            self.s_wind = args.s_wind
-            self.overlap = args.overlap
+            self.s_wind = args["s_wind"]
+            self.overlap = args["overlap"]
 
     def upsample_with_mask(
         self, inp: torch.Tensor, mask: torch.Tensor
@@ -1062,29 +1035,7 @@ class TrackRefiner3D(CoTrackerThreeOffline):
                 vis_est = (vis_est>0.5).float()
                 sync_loss += (vis_est.detach()[...,None]*(coords_proj_curr - coords_proj).norm(dim=-1, keepdim=True)*(1-mask_nan[...,None].float())).mean()
                 # coords_proj_curr[~mask_nan.view(B*T, N)] = coords_proj.view(B*T, N, 2)[~mask_nan.view(B*T, N)].to(coords_proj_curr.dtype)
-                # if torch.isnan(coords_proj_curr).sum()>0:
-                #     import pdb; pdb.set_trace()
-
-                if False:
-                    point_map_resize = point_map.clone().view(B, T, 3, H, W)
-                    update_input = torch.cat([point_map_resize, metric_unc.view(B,T,1,H,W)], dim=2)
-                    coords_append_resize = coords.clone().detach()
-                    coords_append_resize[..., :2] = coords_append_resize[..., :2] * float(self.stride)
-                    update_track_input = self.norm_xyz(cam_pts_est)*5
-                    update_track_input = torch.cat([update_track_input, vis_est[...,None]], dim=-1)
-                    update_track_input = posenc(update_track_input, min_deg=0, max_deg=12)
-                    update = self.update_pointmap.stablizer(update_input,
-                        update_track_input, coords_append_resize)#, imgs=video, vis_track=viser)
-                    #NOTE: update the point map
-                    point_map_resize += update
-                    point_map_refine_out = F.interpolate(point_map_resize.view(B*T, -1, H, W),
-                        size=(self.image_size[0].item(), self.image_size[1].item()), mode='nearest')
-                    point_map_refine_out = rearrange(point_map_refine_out, '(b t) c h w -> b t c h w', t=T, b=B)
-                    point_map_preds.append(self.denorm_xyz(point_map_refine_out))
-                    point_map_org = self.denorm_xyz(point_map_refine_out).view(B*T, 3, H_, W_)
-
-                # if torch.isnan(coords).sum()>0:
-                #     import pdb; pdb.set_trace()
+
                 #NOTE: the 2d tracking + unproject depth
                 fix_cam_est = coords_append.clone()
                 fix_cam_est[...,2] = depth_unproj
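Editor's note: across SpaTrack.py, predictor.py and TrackRefiner.py, config fields are now read with dict indexing (args["base"], args["mode"], Track_cfg["base_ckpt"], ft_cfg["mode"]) instead of attribute access. A plausible reason, stated here as an assumption: PyTorchModelHubMixin round-trips the constructor kwargs through config.json, and what comes back from JSON is a plain dict, which supports indexing but not attribute-style access. Illustrative values only:

import json

cfg = json.loads('{"Track_cfg": {"s_wind": 16, "overlap": 4}, "mode": "offline"}')

print(cfg["Track_cfg"]["s_wind"])  # 16 -- indexing works on the deserialized config
try:
    cfg.Track_cfg                  # attribute access only works on OmegaConf/namespace-style configs
except AttributeError as err:
    print(err)                     # 'dict' object has no attribute 'Track_cfg'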
requirements.txt CHANGED
@@ -22,8 +22,8 @@ git+https://github.com/facebookresearch/segment-anything.git
 git+https://github.com/EasternJournalist/utils3d.git#egg=utils3d
 huggingface_hub
 pyceres
-kornia
-xformers
+kornia==0.8.1
+xformers==0.0.28
 timm
 PyJWT
 gdown
  gdown