Spaces:
Runtime error
Runtime error
/** | |
* Copyright (c) Meta Platforms, Inc. and affiliates. | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
import {generateThumbnail} from '@/common/components/video/editor/VideoEditorUtils'; | |
import VideoWorkerContext from '@/common/components/video/VideoWorkerContext'; | |
import Logger from '@/common/logger/Logger'; | |
import { | |
SAM2ModelAddNewPointsMutation, | |
SAM2ModelAddNewPointsMutation$data, | |
} from '@/common/tracker/__generated__/SAM2ModelAddNewPointsMutation.graphql'; | |
import {SAM2ModelCancelPropagateInVideoMutation} from '@/common/tracker/__generated__/SAM2ModelCancelPropagateInVideoMutation.graphql'; | |
import {SAM2ModelClearPointsInFrameMutation} from '@/common/tracker/__generated__/SAM2ModelClearPointsInFrameMutation.graphql'; | |
import {SAM2ModelClearPointsInVideoMutation} from '@/common/tracker/__generated__/SAM2ModelClearPointsInVideoMutation.graphql'; | |
import {SAM2ModelCloseSessionMutation} from '@/common/tracker/__generated__/SAM2ModelCloseSessionMutation.graphql'; | |
import {SAM2ModelRemoveObjectMutation} from '@/common/tracker/__generated__/SAM2ModelRemoveObjectMutation.graphql'; | |
import {SAM2ModelStartSessionMutation} from '@/common/tracker/__generated__/SAM2ModelStartSessionMutation.graphql'; | |
import { | |
BaseTracklet, | |
Mask, | |
SegmentationPoint, | |
StreamingState, | |
Tracker, | |
Tracklet, | |
} from '@/common/tracker/Tracker'; | |
import {TrackerOptions} from '@/common/tracker/Trackers'; | |
import { | |
ClearPointsInVideoResponse, | |
SessionStartFailedResponse, | |
SessionStartedResponse, | |
StreamingCompletedResponse, | |
StreamingStartedResponse, | |
StreamingStateUpdateResponse, | |
TrackletCreatedResponse, | |
TrackletDeletedResponse, | |
TrackletsUpdatedResponse, | |
} from '@/common/tracker/TrackerTypes'; | |
import {convertMaskToRGBA} from '@/common/utils/MaskUtils'; | |
import multipartStream from '@/common/utils/MultipartStream'; | |
import {Stats} from '@/debug/stats/Stats'; | |
import {INFERENCE_API_ENDPOINT} from '@/demo/DemoConfig'; | |
import {createEnvironment} from '@/graphql/RelayEnvironment'; | |
import { | |
DataArray, | |
Masks, | |
RLEObject, | |
decode, | |
encode, | |
toBbox, | |
} from '@/jscocotools/mask'; | |
import {THEME_COLORS} from '@/theme/colors'; | |
import invariant from 'invariant'; | |
import {IEnvironment, commitMutation, graphql} from 'relay-runtime'; | |
/** Constructor options: only the inference endpoint is configurable. */
type Options = Pick<TrackerOptions, 'inferenceEndpoint'>;

/** Client-side view of the server session and the tracklets it tracks. */
type Session = {
  // Server-issued session id, or null when no session is active.
  id: string | null;
  // Tracklets keyed by their numeric object id.
  tracklets: {[id: number]: Tracklet};
};

/** Masks of the tracked objects on a single frame, as read from the propagation stream. */
type StreamMasksResult = {
  frameIndex: number;
  rleMaskList: Array<{
    objectId: number;
    rleMask: RLEObject;
  }>;
};

/**
 * Marker yielded when the client aborted the propagation stream.
 *
 * `aborted` is the literal `true` rather than `boolean`: the stream consumer
 * discriminates on the mere presence of the key (`'aborted' in result`), so a
 * `{aborted: false}` value would be misread as an abort.
 */
type StreamMasksAbortResult = {
  aborted: true;
};
/**
 * Tracker backed by the SAM 2 inference server.
 *
 * Session lifecycle, point prompts and object removal are issued as Relay
 * GraphQL mutations against `inferenceEndpoint`. Mask propagation through the
 * video is read from `${inferenceEndpoint}/propagate_in_video`, which must
 * answer with a `multipart/x-savi-stream` response. Decoded masks are stored
 * on the tracklets of the current session, forwarded to the worker context,
 * and summarized to the main thread via `_sendResponse`.
 */
export class SAM2Model extends Tracker {
  private _endpoint: string;
  private _environment: IEnvironment;
  // Controller of the in-flight propagation stream (null when none). Aborting
  // it makes the stream generator stop and request server-side cancellation.
  private abortController: AbortController | null = null;
  private _session: Session = {
    id: null,
    tracklets: {},
  };
  private _streamingState: StreamingState = 'none';
  // RLE of `new Masks(height, width, 1)` at the video size (presumably an
  // all-background mask). Decoded masks whose counts match it are flagged
  // `isEmpty`. Created lazily in `updatePoints`.
  private _emptyMask: RLEObject | null = null;
  // Scratch canvas reused by `_compressMaskForCanvas`.
  private _maskCanvas: OffscreenCanvas;
  private _maskCtx: OffscreenCanvasRenderingContext2D;
  // Optional timing stats; only set after `enableStats()`.
  private _stats?: Stats;

  /**
   * @param context Worker-side video context that receives mask updates.
   * @param options Inference endpoint; defaults to `INFERENCE_API_ENDPOINT`.
   */
  constructor(
    context: VideoWorkerContext,
    options: Options = {
      inferenceEndpoint: INFERENCE_API_ENDPOINT,
    },
  ) {
    super(context);
    this._endpoint = options.inferenceEndpoint;
    this._environment = createEnvironment(options.inferenceEndpoint);
    this._maskCanvas = new OffscreenCanvas(0, 0);
    const maskCtx = this._maskCanvas.getContext('2d');
    invariant(maskCtx != null, 'context cannot be null');
    this._maskCtx = maskCtx;
  }

  /**
   * Starts a new server session for the video at `videoPath`.
   *
   * On success the session id is stored, `sessionStarted` is sent, tracklets
   * of the previous session are cleared and one empty tracklet is created.
   * On failure `sessionStartFailed` is sent. The returned promise always
   * resolves; failure is reported only through the response message.
   */
  public startSession(videoPath: string): Promise<void> {
    // Reset streaming state. Force update with the true flag to make sure the
    // UI updates its state.
    this._updateStreamingState('none', true);
    return new Promise(resolve => {
      try {
        commitMutation<SAM2ModelStartSessionMutation>(this._environment, {
          mutation: graphql`
            mutation SAM2ModelStartSessionMutation($input: StartSessionInput!) {
              startSession(input: $input) {
                sessionId
              }
            }
          `,
          variables: {
            input: {
              path: videoPath,
            },
          },
          onCompleted: response => {
            const {sessionId} = response.startSession;
            this._session.id = sessionId;
            this._sendResponse<SessionStartedResponse>('sessionStarted', {
              sessionId,
            });
            // Clear any tracklets from the previous session when
            // a new session is started
            this._clearTracklets();
            // Make an empty tracklet
            this.createTracklet();
            resolve();
          },
          onError: error => {
            Logger.error(error);
            this._sendResponse<SessionStartFailedResponse>(
              'sessionStartFailed',
            );
            resolve();
          },
        });
      } catch (error) {
        Logger.error(error);
        this._sendResponse<SessionStartFailedResponse>('sessionStartFailed');
        resolve();
      }
    });
  }

  /**
   * Closes the current server session.
   *
   * Local session state is reset before the request is sent. Resolves
   * immediately when no session is active. Rejects when the server reports
   * `success: false` or the mutation errors.
   */
  public closeSession(): Promise<void> {
    const sessionId = this._session.id;
    // Do not call cleanup before retrieving the session id because cleanup
    // will reset the session id. If the order would be changed, it would
    // never execute the closeSession mutation.
    this._cleanup();
    if (sessionId === null) {
      return Promise.resolve();
    }
    return new Promise((resolve, reject) => {
      commitMutation<SAM2ModelCloseSessionMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelCloseSessionMutation($input: CloseSessionInput!) {
            closeSession(input: $input) {
              success
            }
          }
        `,
        variables: {
          input: {
            sessionId,
          },
        },
        onCompleted: response => {
          const {success} = response.closeSession;
          if (success === false) {
            reject(new Error('Failed to close session'));
            return;
          }
          resolve();
        },
        onError: error => {
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Adds a new, uninitialized tracklet locally.
   *
   * Its id is one more than the largest existing tracklet id, or 0 when there
   * are none. Sends `trackletsUpdated` followed by `trackletCreated`.
   */
  public createTracklet(): void {
    // This will return 0 for empty tracklets and otherwise the next
    // largest number.
    const nextId =
      Object.values(this._session.tracklets).reduce(
        (prev, curr) => Math.max(prev, curr.id),
        -1,
      ) + 1;
    const newTracklet = {
      id: nextId,
      color: THEME_COLORS[nextId % THEME_COLORS.length],
      thumbnail: null,
      points: [],
      masks: [],
      isInitialized: false,
    };
    this._session.tracklets[nextId] = newTracklet;
    // Notify the main thread
    this._updateTracklets();
    this._sendResponse<TrackletCreatedResponse>('trackletCreated', {
      tracklet: newTracklet,
    });
  }

  /**
   * Removes the object `trackletId` on the server, then locally.
   *
   * The mask updates returned by the server are applied in order. The
   * tracklet is dropped only after all of them have been applied, and the
   * promise settles only after that. Rejects with 'No active session' when
   * no session is active, or with the error when the mutation or the local
   * update fails.
   */
  public deleteTracklet(trackletId: number): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }
    const tracklet = this._session.tracklets[trackletId];
    invariant(
      tracklet != null,
      'tracklet for tracklet id %s not initialized',
      trackletId,
    );
    return new Promise((resolve, reject) => {
      commitMutation<SAM2ModelRemoveObjectMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelRemoveObjectMutation($input: RemoveObjectInput!) {
            removeObject(input: $input) {
              frameIndex
              rleMaskList {
                objectId
                rleMask {
                  counts
                  size
                }
              }
            }
          }
        `,
        variables: {
          input: {objectId: trackletId, sessionId},
        },
        onCompleted: response => {
          const trackletUpdates = response.removeObject;
          this._sendResponse<TrackletDeletedResponse>('trackletDeleted', {
            isSuccessful: true,
          });
          // `_updateTrackletMasks` is async (thumbnail generation awaits), so
          // wait for every update before removing the tracklet and resolving.
          const applyUpdates = async (): Promise<void> => {
            for (const trackletUpdate of trackletUpdates) {
              await this._updateTrackletMasks(
                trackletUpdate,
                // Thumbnails are drawn from the current frame, so only
                // refresh them for the frame currently shown.
                trackletUpdate.frameIndex === this._context.frameIndex,
                false, // shouldGoToFrame
              );
            }
            await this._removeTrackletMasks(tracklet);
          };
          applyUpdates().then(
            () => resolve(),
            error => {
              Logger.error(error);
              reject(error);
            },
          );
        },
        onError: error => {
          this._sendResponse<TrackletDeletedResponse>('trackletDeleted', {
            isSuccessful: false,
          });
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Replaces the prompt points of `objectId` on `frameIndex`.
   *
   * `points` are in video pixel coordinates: `[x, y, label]`. They are
   * normalized by the video width/height before being sent, and old points on
   * that frame are cleared (`clearOldPoints: true`). An empty `points` array
   * delegates to `clearPointsInFrame`. The promise resolves once the returned
   * masks (and thumbnails) have been applied locally. It rejects with
   * 'No active session' when no session is active, or with the error when the
   * mutation or the local update fails.
   */
  public updatePoints(
    frameIndex: number,
    objectId: number,
    points: SegmentationPoint[],
  ): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }
    // TODO: This is not the right place to initialize the empty mask.
    // Move this into the constructor and listen to events on the context.
    // Note, the initial context.width and context.height is 0, so it needs
    // to happen based on an event, so when the video is initialized, it needs
    // to notify the tracker to update the empty mask.
    if (this._emptyMask === null) {
      // We need to round the height/width to the nearest integer since
      // Masks.toTensor() expects an integer value for the height/width.
      const tensor = new Masks(
        Math.trunc(this._context.height),
        Math.trunc(this._context.width),
        1,
      ).toDataArray();
      this._emptyMask = encode(tensor)[0];
    }
    const tracklet = this._session.tracklets[objectId];
    invariant(
      tracklet != null,
      'tracklet for object id %s not initialized',
      objectId,
    );
    // Changing points invalidates any previous propagation.
    this._updateStreamingState('required');
    // Clear all points in frame if no points are provided.
    if (points.length === 0) {
      return this.clearPointsInFrame(frameIndex, objectId);
    }
    return new Promise((resolve, reject) => {
      const normalizedPoints = points.map(p => [
        p[0] / this._context.width,
        p[1] / this._context.height,
      ]);
      const labels = points.map(p => p[2]);
      commitMutation<SAM2ModelAddNewPointsMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelAddNewPointsMutation($input: AddPointsInput!) {
            addPoints(input: $input) {
              frameIndex
              rleMaskList {
                objectId
                rleMask {
                  counts
                  size
                }
              }
            }
          }
        `,
        variables: {
          input: {
            sessionId,
            frameIndex,
            objectId,
            labels: labels,
            points: normalizedPoints,
            clearOldPoints: true,
          },
        },
        onCompleted: response => {
          tracklet.points[frameIndex] = points;
          tracklet.isInitialized = true;
          // Settle only after the async mask/thumbnail update has finished so
          // callers observe consistent state and failures are not dropped.
          this._updateTrackletMasks(response.addPoints, true).then(
            () => resolve(),
            error => {
              Logger.error(error);
              reject(error);
            },
          );
        },
        onError: error => {
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Clears all prompt points of `objectId` on `frameIndex`.
   *
   * The promise resolves once the returned masks have been applied locally.
   * It rejects with 'No active session' when no session is active, or with the
   * error when the mutation or the local update fails.
   */
  public clearPointsInFrame(
    frameIndex: number,
    objectId: number,
  ): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }
    const tracklet = this._session.tracklets[objectId];
    invariant(
      tracklet != null,
      'tracklet for object id %s not initialized',
      objectId,
    );
    // Changing points invalidates any previous propagation.
    this._updateStreamingState('required');
    return new Promise((resolve, reject) => {
      commitMutation<SAM2ModelClearPointsInFrameMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelClearPointsInFrameMutation(
            $input: ClearPointsInFrameInput!
          ) {
            clearPointsInFrame(input: $input) {
              frameIndex
              rleMaskList {
                objectId
                rleMask {
                  counts
                  size
                }
              }
            }
          }
        `,
        variables: {
          input: {
            sessionId,
            frameIndex,
            objectId,
          },
        },
        onCompleted: response => {
          tracklet.points[frameIndex] = [];
          tracklet.isInitialized = true;
          // Settle only after the async mask/thumbnail update has finished.
          this._updateTrackletMasks(response.clearPointsInFrame, true).then(
            () => resolve(),
            error => {
              Logger.error(error);
              reject(error);
            },
          );
        },
        onError: error => {
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Clears all points of all objects in the video on the server.
   *
   * On success local tracklets and masks are cleared and
   * `clearPointsInVideo` is sent with `isSuccessful: true`. On failure it is
   * sent with `isSuccessful: false`. Once a session is active the promise
   * always resolves; failure is reported only through the response message
   * (it previously never settled on failure).
   */
  public clearPointsInVideo(): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }
    // With no points left there is nothing to propagate.
    this._updateStreamingState('none');
    return new Promise(resolve => {
      commitMutation<SAM2ModelClearPointsInVideoMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelClearPointsInVideoMutation(
            $input: ClearPointsInVideoInput!
          ) {
            clearPointsInVideo(input: $input) {
              success
            }
          }
        `,
        variables: {
          input: {
            sessionId,
          },
        },
        onCompleted: response => {
          const {success} = response.clearPointsInVideo;
          if (!success) {
            this._sendResponse<ClearPointsInVideoResponse>(
              'clearPointsInVideo',
              {isSuccessful: false},
            );
            resolve();
            return;
          }
          // Reset points and masks for each tracklet
          this._clearTracklets();
          // Notify the main thread
          this._context.goToFrame(this._context.frameIndex);
          this._updateTracklets();
          this._sendResponse<ClearPointsInVideoResponse>('clearPointsInVideo', {
            isSuccessful: true,
          });
          resolve();
        },
        onError: error => {
          this._sendResponse<ClearPointsInVideoResponse>('clearPointsInVideo', {
            isSuccessful: false,
          });
          Logger.error(error);
          resolve();
        },
      });
    });
  }

  /**
   * Propagates the current prompts through the video, starting at
   * `frameIndex`, and applies masks as they stream in.
   *
   * Previous masks are cleared first. Streaming state moves through
   * 'requesting' and 'partial' to 'full', or to 'aborting'/'aborted' when
   * `abortStreamMasks()` was called. `streamingCompleted` is sent when the
   * stream ends without error. Errors are logged and rethrown.
   */
  public async streamMasks(frameIndex: number): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }
    const controller = new AbortController();
    try {
      this._sendResponse<StreamingStartedResponse>('streamingStarted');
      // 1. Clear previous masks
      this._context.clearMasks();
      this._clearTrackletMasks();
      // 2. Register abort controller and create async generator
      this.abortController = controller;
      this._updateStreamingState('requesting');
      const generator = this._streamMasksForSession(
        controller,
        sessionId,
        frameIndex,
      );
      // 3. parse stream response and update masks in session objects
      let isAborted = false;
      for await (const result of generator) {
        if ('aborted' in result) {
          this._updateStreamingState('aborting');
          await this._abortRequest();
          this._updateStreamingState('aborted');
          isAborted = true;
        } else {
          await this._updateTrackletMasks(result, false);
          this._updateStreamingState('partial');
        }
      }
      if (!isAborted) {
        // Propagation ran to the end of the stream.
        this._updateStreamingState('full');
      }
    } catch (error) {
      Logger.error(error);
      throw error;
    } finally {
      // Drop the reference unless a newer stream has already replaced it.
      if (this.abortController === controller) {
        this.abortController = null;
      }
    }
    this._sendResponse<StreamingCompletedResponse>('streamingCompleted');
  }

  /**
   * Signals the in-flight propagation stream (if any) to stop and sends
   * `streamingCompleted`.
   */
  public abortStreamMasks() {
    this.abortController?.abort();
    this._sendResponse<StreamingCompletedResponse>('streamingCompleted');
  }

  /** Enables timing stats for mask decoding. */
  public enableStats(): void {
    this._stats = new Stats('ms', 'D', 1000 / 25);
  }

  // PRIVATE

  /** Forgets the session id and all tracklets (local state only). */
  private _cleanup() {
    this._session.id = null;
    // Clear existing tracklets
    this._session.tracklets = {};
  }

  /** Drops all tracklets and clears the masks held by the worker context. */
  private _clearTracklets() {
    this._session.tracklets = {};
    this._context.clearMasks();
  }

  /**
   * Records `state` and sends `streamingStateUpdate`. The message is skipped
   * when the state is unchanged, unless `forceUpdate` is true.
   */
  private _updateStreamingState(
    state: StreamingState,
    forceUpdate: boolean = false,
  ) {
    if (!forceUpdate && this._streamingState === state) {
      return;
    }
    this._streamingState = state;
    this._sendResponse<StreamingStateUpdateResponse>('streamingStateUpdate', {
      state,
    });
  }

  /**
   * Clears `tracklet`'s masks in the context, forgets the tracklet, re-renders
   * the current frame and notifies the main thread.
   */
  private async _removeTrackletMasks(tracklet: Tracklet) {
    this._context.clearTrackletMasks(tracklet);
    delete this._session.tracklets[tracklet.id];
    // Notify the main thread
    this._context.goToFrame(this._context.frameIndex);
    this._updateTracklets();
  }

  /**
   * Decodes the RLE masks in `data` and stores them on their tracklets at
   * `data.frameIndex`. Then pushes all tracklets to the worker context and the
   * main thread.
   *
   * When `updateThumbnails` is true, thumbnails of non-empty masks are
   * regenerated from `this._context.currentFrame`. Callers should therefore
   * only request this for the frame currently shown.
   *
   * @throws when a mask references an object id with no local tracklet.
   */
  private async _updateTrackletMasks(
    data: SAM2ModelAddNewPointsMutation$data['addPoints'],
    updateThumbnails: boolean,
    shouldGoToFrame: boolean = true,
  ) {
    const {frameIndex, rleMaskList} = data;
    // 1. parse and decode masks for all objects
    for (const {objectId, rleMask} of rleMaskList) {
      const track = this._session.tracklets[objectId];
      invariant(
        track != null,
        'tracklet for object id %s not initialized',
        objectId,
      );
      const {size, counts} = rleMask;
      const rleObject: RLEObject = {
        size: [size[0], size[1]],
        counts: counts,
      };
      const isEmpty = counts === this._emptyMask?.counts;
      // NOTE(review): no matching `end()` call is visible here; confirm the
      // Stats API does not expect one.
      this._stats?.begin();
      const decodedMask = decode([rleObject]);
      const bbox = toBbox([rleObject]);
      const mask: Mask = {
        data: rleObject as RLEObject,
        shape: [...decodedMask.shape],
        // Convert [x, y, width, height] into [[x0, y0], [x1, y1]].
        bounds: [
          [bbox[0], bbox[1]],
          [bbox[0] + bbox[2], bbox[1] + bbox[3]],
        ],
        isEmpty,
      } as const;
      track.masks[frameIndex] = mask;
      if (updateThumbnails && !isEmpty) {
        const {ctx} = await this._compressMaskForCanvas(decodedMask);
        const frame = this._context.currentFrame as VideoFrame;
        await generateThumbnail(track, frameIndex, mask, frame, ctx);
      }
    }
    this._context.updateTracklets(
      frameIndex,
      Object.values(this._session.tracklets),
      shouldGoToFrame,
    );
    // Notify the main thread
    this._updateTracklets();
  }

  /**
   * Sends `trackletsUpdated` with a summary of every tracklet. Masks are
   * reduced to shape, bounds and emptiness; the RLE data is not sent.
   */
  private _updateTracklets() {
    const tracklets: BaseTracklet[] = Object.values(
      this._session.tracklets,
    ).map(tracklet => {
      const {
        id,
        color,
        isInitialized,
        points: trackletPoints,
        thumbnail,
        masks,
      } = tracklet;
      return {
        id,
        color,
        isInitialized,
        points: trackletPoints,
        thumbnail,
        masks: masks.map(mask => ({
          shape: mask.shape,
          bounds: mask.bounds,
          isEmpty: mask.isEmpty,
        })),
      };
    });
    this._sendResponse<TrackletsUpdatedResponse>('trackletsUpdated', {
      tracklets,
    });
  }

  /**
   * Replaces every tracklet with a shallow copy whose `masks` is empty, then
   * notifies the main thread. Points and other fields are kept.
   */
  private _clearTrackletMasks() {
    const keys = Object.keys(this._session.tracklets);
    for (const key of keys) {
      const trackletId = Number(key);
      const tracklet = {...this._session.tracklets[trackletId], masks: []};
      this._session.tracklets[trackletId] = tracklet;
    }
    this._updateTracklets();
  }

  /**
   * Rasterizes a decoded mask onto a new canvas and PNG-encodes it.
   *
   * The RGBA data is first written to the scratch canvas with
   * width = shape[0] and height = shape[1]. It is then drawn onto a canvas of
   * width = shape[1] and height = shape[0] under rotate(π/2) followed by
   * scale(1, -1). That transform maps (x, y) to (y, x), i.e. it transposes
   * the image (presumably because the decoded RLE data is column-major).
   *
   * @returns The PNG blob and the 2D context of the transposed canvas.
   */
  private async _compressMaskForCanvas(
    decodedMask: DataArray,
  ): Promise<{compressedData: Blob; ctx: OffscreenCanvasRenderingContext2D}> {
    const data = convertMaskToRGBA(decodedMask.data as Uint8Array);
    this._maskCanvas.width = decodedMask.shape[0];
    this._maskCanvas.height = decodedMask.shape[1];
    const imageData = new ImageData(
      data,
      decodedMask.shape[0],
      decodedMask.shape[1],
    );
    this._maskCtx.putImageData(imageData, 0, 0);
    const canvas = new OffscreenCanvas(
      decodedMask.shape[1],
      decodedMask.shape[0],
    );
    const ctx = canvas.getContext('2d');
    invariant(ctx != null, 'context cannot be null');
    ctx.save();
    ctx.rotate(Math.PI / 2);
    // Combined with the rotation above, flipping Y turns the draw into a
    // transpose: (x, y) -> (y, x).
    ctx.scale(1, -1);
    ctx.drawImage(this._maskCanvas, 0, 0);
    ctx.restore();
    const compressedData = await canvas.convertToBlob({type: 'image/png'});
    return {compressedData, ctx};
  }

  /**
   * POSTs to `${endpoint}/propagate_in_video` and yields one result per JSON
   * part of the multipart response.
   *
   * Before each read the abort signal is checked. If it is set, the reader
   * lock is released, `{aborted: true}` is yielded once and the generator
   * ends. Non-JSON parts are ignored.
   *
   * @throws when the HTTP status is not OK, the response is not
   *   `multipart/x-savi-stream`, or the body is missing.
   */
  private async *_streamMasksForSession(
    abortController: AbortController,
    sessionId: string,
    startFrameIndex: undefined | number = 0,
  ): AsyncGenerator<StreamMasksResult | StreamMasksAbortResult, undefined> {
    const url = `${this._endpoint}/propagate_in_video`;
    const requestBody = {
      session_id: sessionId,
      start_frame_index: startFrameIndex,
    };
    const headers: {[name: string]: string} = {
      'Content-Type': 'application/json',
    };
    const response = await fetch(url, {
      method: 'POST',
      body: JSON.stringify(requestBody),
      headers,
    });
    // Report HTTP failures directly instead of as a Content-Type mismatch.
    if (!response.ok) {
      throw new Error(
        `propagate_in_video request failed with status ${response.status}`,
      );
    }
    const contentType = response.headers.get('Content-Type');
    if (
      contentType == null ||
      !contentType.startsWith('multipart/x-savi-stream;')
    ) {
      throw new Error(
        'endpoint needs to support Content-Type "multipart/x-savi-stream"',
      );
    }
    const responseBody = response.body;
    if (responseBody == null) {
      throw new Error('response body is null');
    }
    const reader = multipartStream(contentType, responseBody).getReader();
    const textDecoder = new TextDecoder();
    while (true) {
      if (abortController.signal.aborted) {
        reader.releaseLock();
        yield {aborted: true};
        return;
      }
      const {done, value} = await reader.read();
      if (done) {
        return;
      }
      const {headers: partHeaders, body: partBody} = value;
      const partContentType = partHeaders.get('Content-Type');
      // Parts without a JSON Content-Type carry no mask results; skip them.
      if (
        partContentType == null ||
        !partContentType.startsWith('application/json')
      ) {
        continue;
      }
      const jsonResponse = JSON.parse(textDecoder.decode(partBody));
      const maskResults = jsonResponse.results;
      const rleMaskList = maskResults.map(
        (mask: {object_id: number; mask: RLEObject}) => {
          return {
            objectId: mask.object_id,
            rleMask: mask.mask,
          };
        },
      );
      yield {
        frameIndex: jsonResponse.frame_index,
        rleMaskList,
      };
    }
  }

  /**
   * Asks the server to cancel propagation for the current session.
   *
   * Rejects when the server reports `success: false` or the mutation errors.
   * Throws (via `invariant`) when no session is active.
   */
  private async _abortRequest(): Promise<void> {
    const sessionId = this._session.id;
    invariant(sessionId != null, 'session id cannot be empty');
    return new Promise((resolve, reject) => {
      try {
        commitMutation<SAM2ModelCancelPropagateInVideoMutation>(
          this._environment,
          {
            mutation: graphql`
              mutation SAM2ModelCancelPropagateInVideoMutation(
                $input: CancelPropagateInVideoInput!
              ) {
                cancelPropagateInVideo(input: $input) {
                  success
                }
              }
            `,
            variables: {
              input: {
                sessionId,
              },
            },
            onCompleted: response => {
              const {success} = response.cancelPropagateInVideo;
              if (!success) {
                reject(`could not abort session ${sessionId}`);
                return;
              }
              resolve();
            },
            onError: error => {
              Logger.error(error);
              reject(error);
            },
          },
        );
      } catch (error) {
        Logger.error(error);
        reject(error);
      }
    });
  }
}