# Spaces:
# Runtime error
# Runtime error
import torch | |
import torch.nn as nn | |
from torchvision.transforms._transforms_video import (CenterCropVideo,NormalizeVideo) | |
from torchvision.transforms import (Compose,Lambda,RandomCrop,RandomHorizontalFlip,Resize) | |
from pytorchvideo.transforms import (ApplyTransformToKey,Normalize,RandomShortSideScale,UniformTemporalSubsample,Permute) | |
import numpy as np | |
import gradio as gr | |
import spaces | |
# torch.save(model.state_dict(), 'model.pth') | |
# Inference-time preprocessing for the 'video' entry of a pytorchvideo clip dict.
# Train-time random augmentations are replaced by deterministic equivalents so
# the same upload always produces the same prediction:
#   * RandomShortSideScale is pinned to 256 (min_size == max_size, so the
#     sampled size is always 256 — upper end of the original 248-256 range);
#   * RandomHorizontalFlip is dropped.
video_transform = Compose([
    ApplyTransformToKey(
        key='video',
        transform=Compose([
            UniformTemporalSubsample(20),  # keep 20 evenly spaced frames
            # assumes raw pixel values in [0, 255] — TODO confirm decoder output
            Lambda(lambda x: x / 255),
            Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
            RandomShortSideScale(min_size=256, max_size=256),
            CenterCropVideo(224),
        ]),
    ),
])
#============================================================ | |
class mymodel_test(nn.Module):
    """Binary video classifier: a video backbone followed by a one-logit head.

    By default the backbone is ``efficient_x3d_xs`` from the
    ``facebookresearch/pytorchvideo`` torch.hub repo, built without pretrained
    weights (the fine-tuned weights are loaded separately). A backbone module
    may be passed in instead, e.g. to build the model offline.

    The attribute names ``video_model``, ``relu`` and ``Linear`` determine the
    ``state_dict`` keys and must not change, or saved checkpoints stop loading.

    Args:
        video_model: optional backbone module; if ``None`` it is fetched via
            ``torch.hub.load`` (requires network access or a warm hub cache).
        in_features: width of the backbone output fed to the linear head
            (400 for the default backbone, as in the original code).
    """

    def __init__(self, video_model=None, in_features=400):
        super().__init__()
        if video_model is None:
            video_model = torch.hub.load(
                'facebookresearch/pytorchvideo', 'efficient_x3d_xs', pretrained=False
            )
        self.video_model = video_model
        self.relu = nn.ReLU()
        self.Linear = nn.Linear(in_features, 1)

    def forward(self, x):
        """Return one raw (un-squashed) score per batch element, shape (batch, 1).

        ``x`` is presumably (batch, channels, frames, height, width) as built by
        the inference code — TODO confirm against the backbone's expectations.
        """
        return self.Linear(self.relu(self.video_model(x)))
#============================================================ | |
# Build the network and load its weights from model.pth (presumably produced
# by the commented-out torch.save call near the top of the file).
# map_location='cpu' lets weights saved on a GPU load on a CPU-only machine.
model_test = mymodel_test()
model_test.load_state_dict(torch.load('model.pth', map_location='cpu'))
# Switch BatchNorm/Dropout layers (if any) to inference behaviour; without this
# a batch of one video would be normalised with its own statistics.
model_test.eval()
from pytorchvideo.data.encoded_video import EncodedVideo | |
def interface_video(video_path):
    """Classify the first two seconds of a video as 'violence' or 'non violence'.

    Args:
        video_path: path supplied by the ``gr.Video`` input, or ``None`` when
            Submit is pressed without an uploaded video.

    Returns:
        The label string displayed by the ``gr.Label`` output.

    Raises:
        gr.Error: if no video was provided or no frames could be decoded.
    """
    if not video_path:
        raise gr.Error('Please upload a video first.')

    video = EncodedVideo.from_path(video_path)
    clip = video.get_clip(0, 2)  # seconds 0 to 2 only
    if clip.get('video') is None:
        raise gr.Error('Could not decode any frames from the video.')

    clip = video_transform(clip)
    inputs = torch.unsqueeze(clip['video'], 0)  # add a batch dimension of 1

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        score = model_test(inputs).item()

    # NOTE(review): the threshold is applied to the raw Linear output (no
    # sigmoid). If the model was trained with BCEWithLogitsLoss, the equivalent
    # of "probability > 0.5" is "logit > 0" — confirm against training code.
    return 'violence' if score > 0.5 else 'non violence'
# Gradio UI: a video upload plus Submit button on the left, the predicted
# label on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_video = gr.Video(label='Input Video', height=360)
            with gr.Row():
                submit_video_button = gr.Button('Submit')
        with gr.Column():
            label_video = gr.Label()
    # Wire the button to the classifier. The handler name was previously
    # misspelled `iinterface_video`, which raised NameError at startup.
    submit_video_button.click(
        fn=interface_video, inputs=input_video, outputs=label_video
    )

demo.launch()