Spaces: Running on Zero

Commit cd14f82 · xiaoyuxi committed
Parent(s): 54c1d7b

support HubMixin
Files changed:
- .gitignore +2 -19
- README.md +1 -1
- _viz/viz_template.html +338 -1
- app.py +29 -33
- app_3rd/spatrack_utils/infer_track.py +1 -1
- models/SpaTrackV2/models/SpaTrack.py +13 -79
- models/SpaTrackV2/models/predictor.py +6 -67
- models/SpaTrackV2/models/tracker3D/TrackRefiner.py +6 -55
- requirements.txt +2 -2
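The core of the commit is that both checkpoints are now loaded through huggingface_hub's PyTorchModelHubMixin interface instead of a manual hf_hub_download + load_state_dict path. A minimal sketch of the resulting loading pattern, assembled from the app.py diff below (repo IDs and module paths are taken verbatim from that diff; running it requires this repository's code and a CUDA device):

from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
from models.SpaTrackV2.models.predictor import Predictor

# Front-end model (camera poses / depth): weights pulled from the Hub via the mixin.
vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
vggt4track_model.eval()
vggt4track_model = vggt4track_model.to("cuda")

# Tracker: Predictor now subclasses PyTorchModelHubMixin, so it gains
# from_pretrained() / save_pretrained() without the hand-rolled loader it had before.
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
tracker_model.eval()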
.gitignore
CHANGED
@@ -23,47 +23,30 @@ __pycache__/
 /**/**/__pycache__
 /**/__pycache__
 
-outputs
-scripts/lauch_exp/config
-scripts/lauch_exp/submit_job.log
-scripts/lauch_exp/hydra_output
-scripts/lauch_wulan
-scripts/custom_video
 # ignore the visualizer
 viser
 viser_result
 benchmark/results
 benchmark
 
-ossutil_output
-
 prev_version
 spat_ceres
 wandb
 *.log
 seg_target.py
 
-eval_davis.py
-eval_multiple_gpu.py
-eval_pose_scan.py
-eval_single_gpu.py
-
 infer_cam.py
 infer_stream.py
 
 *.egg-info/
 **/*.egg-info
 
-eval_kinectics.py
-models/SpaTrackV2/datasets
 
-scripts
 config/fix_2d.yaml
 
-models/SpaTrackV2/datasets
-scripts/
 
 models/**/build
 models/**/dist
 
-temp_local
+temp_local
+examples/results
README.md
CHANGED
@@ -11,4 +11,4 @@ license: mit
 short_description: Official Space for SpatialTrackerV2
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
_viz/viz_template.html
CHANGED
@@ -671,6 +671,38 @@
             </div>
         </div>
 
+        <div class="settings-group">
+            <h3>Keep History</h3>
+            <div class="checkbox-container">
+                <label class="toggle-switch">
+                    <input type="checkbox" id="enable-keep-history">
+                    <span class="toggle-slider"></span>
+                </label>
+                <label for="enable-keep-history">Enable Keep History</label>
+            </div>
+            <div class="slider-container">
+                <label for="history-stride">Stride</label>
+                <select id="history-stride">
+                    <option value="1">1</option>
+                    <option value="2">2</option>
+                    <option value="5" selected>5</option>
+                    <option value="10">10</option>
+                    <option value="20">20</option>
+                </select>
+            </div>
+        </div>
+
+        <div class="settings-group">
+            <h3>Background</h3>
+            <div class="checkbox-container">
+                <label class="toggle-switch">
+                    <input type="checkbox" id="white-background">
+                    <span class="toggle-slider"></span>
+                </label>
+                <label for="white-background">White Background</label>
+            </div>
+        </div>
+
         <div class="settings-group">
             <div class="btn-group">
                 <button id="reset-view-btn" style="flex: 1; margin-right: 5px;">Reset View</button>
@@ -739,7 +771,10 @@
             showCameraFrustum: document.getElementById('show-camera-frustum'),
             frustumSize: document.getElementById('frustum-size'),
             hideSettingsBtn: document.getElementById('hide-settings-btn'),
-            showSettingsBtn: document.getElementById('show-settings-btn')
+            showSettingsBtn: document.getElementById('show-settings-btn'),
+            enableKeepHistory: document.getElementById('enable-keep-history'),
+            historyStride: document.getElementById('history-stride'),
+            whiteBackground: document.getElementById('white-background')
         };
 
         this.scene = null;
@@ -750,6 +785,12 @@
         this.trajectories = [];
         this.cameraFrustum = null;
 
+        // Keep History functionality
+        this.historyPointClouds = [];
+        this.historyTrajectories = [];
+        this.historyFrames = [];
+        this.maxHistoryFrames = 20;
+
         this.initThreeJS();
         this.loadDefaultSettings().then(() => {
             this.initEventListeners();
@@ -977,6 +1018,28 @@
                 this.ui.showSettingsBtn.style.display = 'none';
             });
         }
+
+        // Keep History event listeners
+        if (this.ui.enableKeepHistory) {
+            this.ui.enableKeepHistory.addEventListener('change', () => {
+                if (!this.ui.enableKeepHistory.checked) {
+                    this.clearHistory();
+                }
+            });
+        }
+
+        if (this.ui.historyStride) {
+            this.ui.historyStride.addEventListener('change', () => {
+                this.clearHistory();
+            });
+        }
+
+        // Background toggle event listener
+        if (this.ui.whiteBackground) {
+            this.ui.whiteBackground.addEventListener('change', () => {
+                this.toggleBackground();
+            });
+        }
     }
 
     makeElementDraggable(element) {
@@ -1296,6 +1359,9 @@
 
         this.updateTrajectories(frameIndex);
 
+        // Keep History management
+        this.updateHistory(frameIndex);
+
         const progress = (frameIndex + 1) / this.config.totalFrames;
         this.ui.progress.style.width = `${progress * 100}%`;
 
@@ -1752,15 +1818,286 @@
         this.updateCameraFrustum(this.currentFrame);
     }
 
+    // Keep History methods
+    updateHistory(frameIndex) {
+        if (!this.ui.enableKeepHistory.checked || !this.data) return;
+
+        const stride = parseInt(this.ui.historyStride.value);
+        const newHistoryFrames = this.calculateHistoryFrames(frameIndex, stride);
+
+        // Check if history frames changed
+        if (this.arraysEqual(this.historyFrames, newHistoryFrames)) return;
+
+        this.clearHistory();
+        this.historyFrames = newHistoryFrames;
+
+        // Create history point clouds and trajectories
+        this.historyFrames.forEach(historyFrame => {
+            if (historyFrame !== frameIndex) {
+                this.createHistoryPointCloud(historyFrame);
+                this.createHistoryTrajectories(historyFrame);
+            }
+        });
+    }
+
+    calculateHistoryFrames(currentFrame, stride) {
+        const frames = [];
+        let frame = 1; // Start from frame 1
+
+        while (frame <= currentFrame && frames.length < this.maxHistoryFrames) {
+            frames.push(frame);
+            frame += stride;
+        }
+
+        // Always include current frame
+        if (!frames.includes(currentFrame)) {
+            frames.push(currentFrame);
+        }
+
+        return frames.sort((a, b) => a - b);
+    }
+
+    createHistoryPointCloud(frameIndex) {
+        const numPoints = this.config.resolution[0] * this.config.resolution[1];
+        const positions = new Float32Array(numPoints * 3);
+        const colors = new Float32Array(numPoints * 3);
+
+        const geometry = new THREE.BufferGeometry();
+        geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3));
+        geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3));
+
+        const material = new THREE.PointsMaterial({
+            size: parseFloat(this.ui.pointSize.value),
+            vertexColors: true,
+            transparent: true,
+            opacity: 0.5, // Transparent for history
+            sizeAttenuation: true
+        });
+
+        const historyPointCloud = new THREE.Points(geometry, material);
+        this.scene.add(historyPointCloud);
+        this.historyPointClouds.push(historyPointCloud);
+
+        // Update the history point cloud with data
+        this.updateHistoryPointCloud(historyPointCloud, frameIndex);
+    }
+
+    updateHistoryPointCloud(pointCloud, frameIndex) {
+        const positions = pointCloud.geometry.attributes.position.array;
+        const colors = pointCloud.geometry.attributes.color.array;
+
+        const rgbVideo = this.data.rgb_video;
+        const depthsRgb = this.data.depths_rgb;
+        const intrinsics = this.data.intrinsics;
+        const invExtrinsics = this.data.inv_extrinsics;
+
+        const width = this.config.resolution[0];
+        const height = this.config.resolution[1];
+        const numPoints = width * height;
+
+        const K = this.get3x3Matrix(intrinsics.data, intrinsics.shape, frameIndex);
+        const fx = K[0][0], fy = K[1][1], cx = K[0][2], cy = K[1][2];
+
+        const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
+        const transform = this.getTransformElements(invExtrMat);
+
+        const rgbFrame = this.getFrame(rgbVideo.data, rgbVideo.shape, frameIndex);
+        const depthFrame = this.getFrame(depthsRgb.data, depthsRgb.shape, frameIndex);
+
+        const maxDepth = parseFloat(this.ui.maxDepth.value) || 10.0;
+
+        let validPointCount = 0;
+
+        for (let i = 0; i < numPoints; i++) {
+            const xPix = i % width;
+            const yPix = Math.floor(i / width);
+
+            const d0 = depthFrame[i * 3];
+            const d1 = depthFrame[i * 3 + 1];
+            const depthEncoded = d0 | (d1 << 8);
+            const depthValue = (depthEncoded / ((1 << 16) - 1)) *
+                (this.config.depthRange[1] - this.config.depthRange[0]) +
+                this.config.depthRange[0];
+
+            if (depthValue === 0 || depthValue > maxDepth) {
+                continue;
+            }
+
+            const X = ((xPix - cx) * depthValue) / fx;
+            const Y = ((yPix - cy) * depthValue) / fy;
+            const Z = depthValue;
+
+            const tx = transform.m11 * X + transform.m12 * Y + transform.m13 * Z + transform.m14;
+            const ty = transform.m21 * X + transform.m22 * Y + transform.m23 * Z + transform.m24;
+            const tz = transform.m31 * X + transform.m32 * Y + transform.m33 * Z + transform.m34;
+
+            const index = validPointCount * 3;
+            positions[index] = tx;
+            positions[index + 1] = -ty;
+            positions[index + 2] = -tz;
+
+            colors[index] = rgbFrame[i * 3] / 255;
+            colors[index + 1] = rgbFrame[i * 3 + 1] / 255;
+            colors[index + 2] = rgbFrame[i * 3 + 2] / 255;
+
+            validPointCount++;
+        }
+
+        pointCloud.geometry.setDrawRange(0, validPointCount);
+        pointCloud.geometry.attributes.position.needsUpdate = true;
+        pointCloud.geometry.attributes.color.needsUpdate = true;
+    }
+
+    createHistoryTrajectories(frameIndex) {
+        if (!this.data.trajectories) return;
+
+        const trajectoryData = this.data.trajectories.data;
+        const [totalFrames, numTrajectories] = this.data.trajectories.shape;
+        const palette = this.createColorPalette(numTrajectories);
+
+        const historyTrajectoryGroup = new THREE.Group();
+
+        for (let i = 0; i < numTrajectories; i++) {
+            const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
+            const sphereGeometry = new THREE.SphereGeometry(ballSize, 16, 16);
+            const sphereMaterial = new THREE.MeshBasicMaterial({
+                color: palette[i],
+                transparent: true,
+                opacity: 0.3 // Transparent for history
+            });
+            const positionMarker = new THREE.Mesh(sphereGeometry, sphereMaterial);
+
+            const currentOffset = (frameIndex * numTrajectories + i) * 3;
+            positionMarker.position.set(
+                trajectoryData[currentOffset],
+                -trajectoryData[currentOffset + 1],
+                -trajectoryData[currentOffset + 2]
+            );
+
+            historyTrajectoryGroup.add(positionMarker);
+        }
+
+        this.scene.add(historyTrajectoryGroup);
+        this.historyTrajectories.push(historyTrajectoryGroup);
+    }
+
+    clearHistory() {
+        // Clear history point clouds
+        this.historyPointClouds.forEach(pointCloud => {
+            if (pointCloud.geometry) pointCloud.geometry.dispose();
+            if (pointCloud.material) pointCloud.material.dispose();
+            this.scene.remove(pointCloud);
+        });
+        this.historyPointClouds = [];
+
+        // Clear history trajectories
+        this.historyTrajectories.forEach(trajectoryGroup => {
+            trajectoryGroup.children.forEach(child => {
+                if (child.geometry) child.geometry.dispose();
+                if (child.material) child.material.dispose();
+            });
+            this.scene.remove(trajectoryGroup);
+        });
+        this.historyTrajectories = [];
+
+        this.historyFrames = [];
+    }
+
+    arraysEqual(a, b) {
+        if (a.length !== b.length) return false;
+        for (let i = 0; i < a.length; i++) {
+            if (a[i] !== b[i]) return false;
+        }
+        return true;
+    }
+
+    toggleBackground() {
+        const isWhiteBackground = this.ui.whiteBackground.checked;
+
+        if (isWhiteBackground) {
+            // Switch to white background
+            document.body.style.backgroundColor = '#ffffff';
+            this.scene.background = new THREE.Color(0xffffff);
+
+            // Update UI elements for white background
+            document.documentElement.style.setProperty('--bg', '#ffffff');
+            document.documentElement.style.setProperty('--text', '#333333');
+            document.documentElement.style.setProperty('--text-secondary', '#666666');
+            document.documentElement.style.setProperty('--border', '#cccccc');
+            document.documentElement.style.setProperty('--surface', '#f5f5f5');
+            document.documentElement.style.setProperty('--shadow', 'rgba(0, 0, 0, 0.1)');
+            document.documentElement.style.setProperty('--shadow-hover', 'rgba(0, 0, 0, 0.2)');
+
+            // Update status bar and control panel backgrounds
+            this.ui.statusBar.style.background = 'rgba(245, 245, 245, 0.9)';
+            this.ui.statusBar.style.color = '#333333';
+
+            const controlPanel = document.getElementById('control-panel');
+            if (controlPanel) {
+                controlPanel.style.background = 'rgba(245, 245, 245, 0.95)';
+            }
+
+            const settingsPanel = document.getElementById('settings-panel');
+            if (settingsPanel) {
+                settingsPanel.style.background = 'rgba(245, 245, 245, 0.98)';
+            }
+
+        } else {
+            // Switch back to dark background
+            document.body.style.backgroundColor = '#1a1a1a';
+            this.scene.background = new THREE.Color(0x1a1a1a);
+
+            // Restore original dark theme variables
+            document.documentElement.style.setProperty('--bg', '#1a1a1a');
+            document.documentElement.style.setProperty('--text', '#e0e0e0');
+            document.documentElement.style.setProperty('--text-secondary', '#a0a0a0');
+            document.documentElement.style.setProperty('--border', '#444444');
+            document.documentElement.style.setProperty('--surface', '#2c2c2c');
+            document.documentElement.style.setProperty('--shadow', 'rgba(0, 0, 0, 0.2)');
+            document.documentElement.style.setProperty('--shadow-hover', 'rgba(0, 0, 0, 0.3)');
+
+            // Restore original UI backgrounds
+            this.ui.statusBar.style.background = 'rgba(30, 30, 30, 0.9)';
+            this.ui.statusBar.style.color = '#e0e0e0';
+
+            const controlPanel = document.getElementById('control-panel');
+            if (controlPanel) {
+                controlPanel.style.background = 'rgba(44, 44, 44, 0.95)';
+            }
+
+            const settingsPanel = document.getElementById('settings-panel');
+            if (settingsPanel) {
+                settingsPanel.style.background = 'rgba(44, 44, 44, 0.98)';
+            }
+        }
+
+        // Show status message
+        this.ui.statusBar.textContent = isWhiteBackground ? "Switched to white background" : "Switched to dark background";
+        this.ui.statusBar.classList.remove('hidden');
+
+        setTimeout(() => {
+            this.ui.statusBar.classList.add('hidden');
+        }, 2000);
+    }
+
     resetSettings() {
         if (!this.defaultSettings) return;
 
         this.applyDefaultSettings();
 
+        // Reset background to dark theme
+        if (this.ui.whiteBackground) {
+            this.ui.whiteBackground.checked = false;
+            this.toggleBackground();
+        }
+
         this.updatePointCloudSettings();
         this.updateTrajectorySettings();
         this.updateFrustumDimensions();
 
+        // Clear history when resetting settings
+        this.clearHistory();
+
         this.ui.statusBar.textContent = "Settings reset to defaults";
         this.ui.statusBar.classList.remove('hidden');
 
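For reference, the frame-selection rule behind the new Keep History option: starting from frame 1 it keeps every stride-th frame, caps the list at maxHistoryFrames (20), and always appends the current frame; history frames are then drawn at reduced opacity (0.5 for point clouds, 0.3 for trajectory markers). A small Python sketch of the same rule (the authoritative version is calculateHistoryFrames in the JavaScript above; this function name is only illustrative):

def history_frames(current_frame, stride, max_history=20):
    # Mirrors calculateHistoryFrames in _viz/viz_template.html.
    frames = []
    f = 1
    while f <= current_frame and len(frames) < max_history:
        frames.append(f)
        f += stride
    if current_frame not in frames:  # always keep the current frame
        frames.append(current_frame)
    return sorted(frames)

# Example: with the default stride of 5, frame 23 keeps [1, 6, 11, 16, 21, 23].
print(history_frames(23, 5))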
app.py
CHANGED
@@ -26,6 +26,9 @@ import logging
 from concurrent.futures import ThreadPoolExecutor
 import atexit
 import uuid
+from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
+from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
+from models.SpaTrackV2.models.predictor import Predictor
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -78,20 +81,15 @@ def create_user_temp_dir():
     return temp_dir
 
 from huggingface_hub import hf_hub_download
-# init the model
-os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
 
-vggt_model = VGGT_MoE()
-vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")), strict=False)
-vggt_model.eval()
-vggt_model = vggt_model.to("cuda")
+vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
+vggt4track_model.eval()
+vggt4track_model = vggt4track_model.to("cuda")
 
 # Global model initialization
 print("🚀 Initializing local models...")
-tracker_model
+tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
+tracker_model.eval()
 predictor = get_sam_predictor()
 print("✅ Models loaded successfully!")
 
@@ -131,7 +129,8 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
     print("Initializing tracker models inside GPU function...")
     out_dir = os.path.join(temp_dir, "results")
     os.makedirs(out_dir, exist_ok=True)
-    tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
+    tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
+                                                                 tracker_model=tracker_model.cuda())
 
     # Setup paths
     video_path = os.path.join(temp_dir, f"{video_name}.mp4")
@@ -161,25 +160,23 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
     data_npz_load = {}
 
     # run vggt
-    with torch.
-    depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
-    depth_tensor = depth_map.squeeze().cpu().numpy()
-    extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
-    extrs = extrinsic.squeeze().cpu().numpy()
-    intrs = intrinsic.squeeze().cpu().numpy()
-    video_tensor = video_tensor.squeeze()
-    #NOTE: 20% of the depth is not reliable
-    # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
-    unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
+    # process the image tensor
+    video_tensor = preprocess_image(video_tensor)[None]
+    with torch.no_grad():
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            # Predict attributes including cameras, depth maps, and point maps.
+            predictions = vggt4track_model(video_tensor.cuda()/255)
+            extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
+            depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
 
+    depth_tensor = depth_map.squeeze().cpu().numpy()
+    extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
+    extrs = extrinsic.squeeze().cpu().numpy()
+    intrs = intrinsic.squeeze().cpu().numpy()
+    video_tensor = video_tensor.squeeze()
+    #NOTE: 20% of the depth is not reliable
+    # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
+    unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
     # Load and process mask
     if os.path.exists(mask_path):
         mask = cv2.imread(mask_path)
@@ -201,7 +198,6 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
 
     query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
     print(f"Query points shape: {query_xyt.shape}")
-
     # Run model inference
     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
         (
@@ -212,8 +208,8 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
             queries=query_xyt,
             fps=1, full_point=False, iters_track=4,
             query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
-            support_frame=len(video_tensor)-1, replace_ratio=0.2)
-
+            support_frame=len(video_tensor)-1, replace_ratio=0.2)
+
     # Resize results to avoid large I/O
     max_size = 224
     h, w = video.shape[2:]
@@ -1117,7 +1113,7 @@ if __name__ == "__main__":
     demo.launch(
         server_name="0.0.0.0",
        server_port=7860,
-        share=
+        share=False,
         debug=True,
         show_error=True
     )
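Pulling the scattered hunks together, the refactored VGGT front-end pass inside gpu_run_tracker now looks roughly like the sketch below. It is not a drop-in function from the repo: the function name is illustrative, and video_tensor is assumed to be a stack of 0-255 RGB frames as used elsewhere in app.py.

import torch
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image

def run_vggt_frontend(vggt4track_model, video_tensor):
    # video_tensor is assumed to be a (T, 3, H, W) uint8 frame stack.
    video_tensor = preprocess_image(video_tensor)[None]
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            # The HubMixin-loaded model returns poses, intrinsics, a point map and an
            # uncertainty map in one dict (keys as used in app.py).
            predictions = vggt4track_model(video_tensor.cuda() / 255)
            extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
            depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]

    depth_tensor = depth_map.squeeze().cpu().numpy()
    extrs = extrinsic.squeeze().cpu().numpy()
    intrs = intrinsic.squeeze().cpu().numpy()
    # Depth with confidence below 0.5 is treated as unreliable downstream.
    unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
    return depth_tensor, extrs, intrs, unc_metric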
app_3rd/spatrack_utils/infer_track.py
CHANGED
@@ -20,7 +20,7 @@ from huggingface_hub import hf_hub_download
 
 config = {
     "ckpt_dir": "Yuxihenry/SpatialTrackerCkpts",  # HuggingFace repo ID
-    "cfg_dir": "config/
+    "cfg_dir": "config/magic_infer_offline.yaml",
 }
 
 def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=None):
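With the config path pinned to config/magic_infer_offline.yaml, the helper can be driven the same way app.py now drives it. A hedged usage sketch: the output directory and variable names here are arbitrary, the checkpoint ID comes from the app.py diff, and the helper presumably reuses the passed-in tracker_model instead of constructing one from ckpt_dir.

from models.SpaTrackV2.models.predictor import Predictor
from app_3rd.spatrack_utils.infer_track import get_tracker_predictor

tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
tracker_model.eval()

# Returns the tracker and its visualizer, as unpacked in app.py.
tracker, viser = get_tracker_predictor("./results", vo_points=756,
                                       tracker_model=tracker_model.cuda())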
models/SpaTrackV2/models/SpaTrack.py
CHANGED
@@ -40,6 +40,7 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
         resolution=518,
         max_len=600, # the maximum video length we can preprocess,
         track_num=768,
+        moge_as_base=False,
     ):
 
         self.chunk_size = chunk_size
@@ -51,26 +52,29 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
         backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None)
 
         super(SpaTrack2, self).__init__()
-        if
+        if moge_as_base:
+            if os.path.exists(backbone_ckpt_dir)==False:
+                base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
+            else:
+                checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
+                base_model = MoGeModel(**checkpoint["model_config"])
+                base_model.load_state_dict(checkpoint['model'])
         else:
-            base_model = MoGeModel(**checkpoint["model_config"])
-            base_model.load_state_dict(checkpoint['model'])
+            base_model = None
         # avoid the base_model is a member of SpaTrack2
         object.__setattr__(self, 'base_model', base_model)
 
         # Tracker model
         self.Track3D = TrackRefiner3D(Track_cfg)
-        track_base_ckpt_dir = Track_cfg
+        track_base_ckpt_dir = Track_cfg["base_ckpt"]
         if os.path.exists(track_base_ckpt_dir):
             track_pretrain = torch.load(track_base_ckpt_dir)
             self.Track3D.load_state_dict(track_pretrain, strict=False)
 
         # wrap the function of make lora trainable
         self.make_paras_trainable = partial(self.make_paras_trainable,
-                                            mode=ft_cfg
-                                            paras_name=ft_cfg
+                                            mode=ft_cfg["mode"],
+                                            paras_name=ft_cfg["paras_name"])
         self.track_num = track_num
 
     def make_paras_trainable(self, mode: str = 'fix', paras_name: List[str] = []):
@@ -145,7 +149,7 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
         ):
         # step 1 allocate the query points on the grid
         T, C, H, W = video.shape
-
+
         if annots_train is not None:
             vis_gt = annots_train["vis"]
             _, _, N = vis_gt.shape
@@ -296,39 +300,6 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
                                 **kwargs, annots=annots)
             if self.training:
                 loss += out["loss"].squeeze()
-            # from models.SpaTrackV2.utils.visualizer import Visualizer
-            # vis_track = Visualizer(grayscale=False,
-            #                        fps=10, pad_value=50, tracks_leave_trace=0)
-            # vis_track.visualize(video=segment,
-            #                     tracks=out["traj_est"][...,:2],
-            #                     visibility=out["vis_est"],
-            #                     save_video=True)
-            # # visualize 4d
-            # import os, json
-            # import os.path as osp
-            # viser4d_dir = os.path.join("viser_4d_results")
-            # os.makedirs(viser4d_dir, exist_ok=True)
-            # depth_est = annots["depth_gt"][0]
-            # unc_metric = out["unc_metric"]
-            # mask = (unc_metric > 0.5).squeeze(1)
-            # # pose_est = out["poses_pred"].squeeze(0)
-            # pose_est = annots["traj_mat"][0]
-            # rgb_tracks = out["rgb_tracks"].squeeze(0)
-            # intrinsics = out["intrs"].squeeze(0)
-            # for i_k in range(out["depth"].shape[0]):
-            #     img_i = out["imgs_raw"][0][i_k].permute(1, 2, 0).cpu().numpy()
-            #     img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
-            #     cv2.imwrite(osp.join(viser4d_dir, f'frame_{i_k:04d}.png'), img_i)
-            #     if stage == 1:
-            #         depth = depth_est[i_k].squeeze().cpu().numpy()
-            #         np.save(osp.join(viser4d_dir, f'frame_{i_k:04d}.npy'), depth)
-            #     else:
-            #         point_map_vis = out["points_map"][i_k].cpu().numpy()
-            #         np.save(osp.join(viser4d_dir, f'point_{i_k:04d}.npy'), point_map_vis)
-            # np.save(os.path.join(viser4d_dir, f'intrinsics.npy'), intrinsics.cpu().numpy())
-            # np.save(os.path.join(viser4d_dir, f'extrinsics.npy'), pose_est.cpu().numpy())
-            # np.save(os.path.join(viser4d_dir, f'conf.npy'), mask.float().cpu().numpy())
 
             queries_len = len(queries_new)
             # update the track3d and track2d
@@ -720,40 +691,3 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
         }
 
         return ret
-
-
-        # three stages of training
-
-        # stage 1:
-        # gt depth and intrinsics synthetic (includes Dynamic Replica, Kubric, Pointodyssey, Vkitti, TartanAir and Indoor() ) Motion Patern (tapvid3d)
-        # Tracking and Pose as well -> based on gt depth and intrinsics
-        # (Finished) -> (megasam + base model) vs. tapip3d. (use depth from megasam or pose, which keep the same setting as tapip3d.)
-
-        # stage 2: fixed 3D tracking
-        # Joint depth refiner
-        # input depth from whatever + rgb -> temporal module + scale and shift token -> coarse alignment -> scale and shift
-        # estimate the 3D tracks -> 3D tracks combine with pointmap -> update for pointmap (iteratively) -> residual map B T 3 H W
-        # ongoing two days
-
-        # stage 3: train multi windows by propagation
-        # 4 frames overlapped -> train on 64 -> fozen image encoder and finetuning the transformer (learnable parameters pretty small)
-
-        # types of scenarioes:
-        # 1. auto driving (waymo open dataset)
-        # 2. robot
-        # 3. internet ego video
-
-        # Iterative Transformer -- Solver -- General Neural MegaSAM + Tracks
-        # Update Variables:
-        # 1. 3D tracks B T N 3 xyz.
-        # 2. 2D tracks B T N 2 x y.
-        # 3. Dynamic Mask B T H W.
-        # 4. Camera Pose B T 4 4.
-        # 5. Video Depth.
-
-        # (RGB, RGBD, RGBD+Pose) x (Static, Dynamic)
-        # Campatiablity by product.
models/SpaTrackV2/models/predictor.py
CHANGED
@@ -16,80 +16,20 @@ from typing import Union, Optional
 import cv2
 import os
 import decord
+from huggingface_hub import PyTorchModelHubMixin # used for model hub
 
-class Predictor(torch.nn.Module):
+class Predictor(torch.nn.Module, PyTorchModelHubMixin):
     def __init__(self, args=None):
         super().__init__()
         self.args = args
         self.spatrack = SpaTrack2(loggers=[None, None, None], **args)
-        self.S_wind = args
-        self.overlap = args
+        self.S_wind = args["Track_cfg"]["s_wind"]
+        self.overlap = args["Track_cfg"]["overlap"]
 
     def to(self, device: Union[str, torch.device]):
         self.spatrack.to(device)
-        self.spatrack.base_model
-
-    @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: Union[str, Path],
-        *,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        device: Optional[Union[str, torch.device]] = None,
-        model_cfg: Optional[dict] = None,
-        **kwargs,
-    ) -> "SpaTrack2":
-        """
-        Load a pretrained model from a local file or a remote repository.
-
-        Args:
-            pretrained_model_name_or_path (str or Path):
-                - Path to a local model file (e.g., `./model.pth`).
-                - HuggingFace Hub model ID (e.g., `username/model-name`).
-            force_download (bool, optional):
-                Whether to force re-download even if cached. Default: False.
-            cache_dir (str, optional):
-                Custom cache directory. Default: None (use default cache).
-            device (str or torch.device, optional):
-                Target device (e.g., "cuda", "cpu"). Default: None (keep original).
-            **kwargs:
-                Additional config overrides.
-
-        Returns:
-            SpaTrack2: Loaded pretrained model.
-        """
-        # (1) check the path is local or remote
-        if isinstance(pretrained_model_name_or_path, Path):
-            model_path = str(pretrained_model_name_or_path)
-        else:
-            model_path = pretrained_model_name_or_path
-        # (2) if the path is remote, download it
-        if not os.path.exists(model_path):
-            raise NotImplementedError("Remote download not implemented yet. Use a local path.")
-        # (3) load the model weights
-
-        state_dict = torch.load(model_path, map_location="cpu")
-        # (4) initialize the model (can load config.json if exists)
-        config_path = os.path.join(os.path.dirname(model_path), "config.json")
-        config = {}
-        if os.path.exists(config_path):
-            import json
-            with open(config_path, "r") as f:
-                config.update(json.load(f))
-        config.update(kwargs) # allow override the config
-        if model_cfg is not None:
-            config = model_cfg
-        model = cls(config)
-        if "model" in state_dict:
-            model.spatrack.load_state_dict(state_dict["model"], strict=False)
-        else:
-            model.spatrack.load_state_dict(state_dict, strict=False)
-        # (5) device management
-        if device is not None:
-            model.to(device)
-
-        return model
+        if self.spatrack.base_model is not None:
+            self.spatrack.base_model.to(device)
 
     def forward(self, video: str|torch.Tensor|np.ndarray,
                 depth: str|torch.Tensor|np.ndarray=None,
@@ -145,7 +85,6 @@ class Predictor(torch.nn.Module):
                 window_len=self.S_wind, overlap_len=self.overlap, track2d_gt=track2d_gt, full_point=full_point, iters_track=iters_track,
                 fixed_cam=fixed_cam, query_no_BA=query_no_BA, stage=stage, support_frame=support_frame, replace_ratio=replace_ratio) + (video[:T_],)
 
-
         return ret
 
 
models/SpaTrackV2/models/tracker3D/TrackRefiner.py
CHANGED
@@ -24,14 +24,13 @@ from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import (
 )
 from models.SpaTrackV2.models.tracker3D.spatrack_modules.ba import extract_static_from_3DTracks, ba_pycolmap
 from models.SpaTrackV2.models.tracker3D.spatrack_modules.pointmap_updator import PointMapUpdator
-from models.SpaTrackV2.models.depth_refiner.depth_refiner import TrackStablizer
 from models.SpaTrackV2.models.tracker3D.spatrack_modules.alignment import affine_invariant_global_loss
 from models.SpaTrackV2.models.tracker3D.delta_utils.upsample_transformer import UpsampleTransformerAlibi
 
 class TrackRefiner3D(CoTrackerThreeOffline):
 
     def __init__(self, args=None):
-        super().__init__(**args
+        super().__init__(**args["base"])
 
         """
         This is 3D warpper from cotracker, which load the cotracker pretrain and
@@ -47,15 +46,7 @@ class TrackRefiner3D(CoTrackerThreeOffline):
         self.proj_xyz_embed = Mlp(in_features=1210+50, hidden_features=1110, out_features=1110)
         # get the anchor point's embedding, and init the pts refiner
         update_pts = True
-
-        # CorrPointformer(
-        #     dim=128,
-        #     num_heads=8,
-        #     head_dim=128 // 8,
-        #     mlp_ratio=4.0,
-        # )
-        # for _ in range(self.corr_levels)
-        # ])
+
         self.corr_transformer = nn.ModuleList([
             CorrPointformer(
                 dim=128,
@@ -68,29 +59,11 @@ class TrackRefiner3D(CoTrackerThreeOffline):
         self.fnet = BasicEncoder(input_dim=3,
                                  output_dim=self.latent_dim, stride=self.stride)
         self.corr3d_radius = 3
-
-        if args.stablizer:
-            self.scale_shift_tokens = nn.Parameter(torch.randn(1, 2, self.latent_dim, requires_grad=True))
-            self.upsample_kernel_size = 5
-            self.residual_embedding = nn.Parameter(torch.randn(
-                self.latent_dim, self.model_resolution[0]//16,
-                self.model_resolution[1]//16, requires_grad=True))
-            self.dense_mlp = nn.Conv2d(2*self.latent_dim+63, self.latent_dim, kernel_size=1, stride=1, padding=0)
-            self.upsample_factor = 4
-            self.upsample_transformer = UpsampleTransformerAlibi(
-                kernel_size=self.upsample_kernel_size, # kernel_size=3, #
-                stride=self.stride,
-                latent_dim=self.latent_dim,
-                num_attn_blocks=2,
-                upsample_factor=4,
-            )
-        else:
-            self.update_pointmap = None
 
-        self.mode = args
+        self.mode = args["mode"]
         if self.mode == "online":
-            self.s_wind = args
-            self.overlap = args
+            self.s_wind = args["s_wind"]
+            self.overlap = args["overlap"]
 
     def upsample_with_mask(
         self, inp: torch.Tensor, mask: torch.Tensor
@@ -1062,29 +1035,7 @@ class TrackRefiner3D(CoTrackerThreeOffline):
         vis_est = (vis_est>0.5).float()
         sync_loss += (vis_est.detach()[...,None]*(coords_proj_curr - coords_proj).norm(dim=-1, keepdim=True)*(1-mask_nan[...,None].float())).mean()
         # coords_proj_curr[~mask_nan.view(B*T, N)] = coords_proj.view(B*T, N, 2)[~mask_nan.view(B*T, N)].to(coords_proj_curr.dtype)
-
-        # import pdb; pdb.set_trace()
-
-        if False:
-            point_map_resize = point_map.clone().view(B, T, 3, H, W)
-            update_input = torch.cat([point_map_resize, metric_unc.view(B,T,1,H,W)], dim=2)
-            coords_append_resize = coords.clone().detach()
-            coords_append_resize[..., :2] = coords_append_resize[..., :2] * float(self.stride)
-            update_track_input = self.norm_xyz(cam_pts_est)*5
-            update_track_input = torch.cat([update_track_input, vis_est[...,None]], dim=-1)
-            update_track_input = posenc(update_track_input, min_deg=0, max_deg=12)
-            update = self.update_pointmap.stablizer(update_input,
-                update_track_input, coords_append_resize)#, imgs=video, vis_track=viser)
-            #NOTE: update the point map
-            point_map_resize += update
-            point_map_refine_out = F.interpolate(point_map_resize.view(B*T, -1, H, W),
-                size=(self.image_size[0].item(), self.image_size[1].item()), mode='nearest')
-            point_map_refine_out = rearrange(point_map_refine_out, '(b t) c h w -> b t c h w', t=T, b=B)
-            point_map_preds.append(self.denorm_xyz(point_map_refine_out))
-            point_map_org = self.denorm_xyz(point_map_refine_out).view(B*T, 3, H_, W_)
-
-        # if torch.isnan(coords).sum()>0:
-        #     import pdb; pdb.set_trace()
+
         #NOTE: the 2d tracking + unproject depth
         fix_cam_est = coords_append.clone()
         fix_cam_est[...,2] = depth_unproj
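A side effect of the HubMixin switch visible in the last three files is that the model config is now consumed as a plain nested dict (presumably restored from the checkpoint repo's config.json) rather than an attribute-style config, which is why accesses changed from the form args.x to args["x"]. A sketch of the nesting the refactored code indexes; the key names come from the diffs above, while every value here is only a placeholder:

predictor_args = {
    "Track_cfg": {                      # forwarded to SpaTrack2 and TrackRefiner3D
        "base": {},                     # kwargs passed through to CoTrackerThreeOffline
        "base_ckpt": "",                # optional path to pretrained tracker weights
        "mode": "offline",              # "online" mode additionally reads s_wind / overlap
        "s_wind": 60,                   # sliding-window length (placeholder)
        "overlap": 4,                   # window overlap (placeholder)
    },
    "ft_cfg": {"mode": "fix", "paras_name": []},  # fine-tuning switches (defaults from make_paras_trainable)
    "moge_as_base": False,              # new flag: skip loading the MoGe base model entirely
}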
requirements.txt
CHANGED
@@ -22,8 +22,8 @@ git+https://github.com/facebookresearch/segment-anything.git
 git+https://github.com/EasternJournalist/utils3d.git#egg=utils3d
 huggingface_hub
 pyceres
-kornia
-xformers
+kornia==0.8.1
+xformers==0.0.28
 timm
 PyJWT
 gdown
|