import torch
import math
import cv2
import json
import time
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import clip
import numpy as np
from tqdm import tqdm
import os
from dotenv import load_dotenv
from IPython.display import Audio
import re
from groq import Groq
from moviepy.editor import VideoFileClip, AudioFileClip, CompositeAudioClip
from pydub import AudioSegment
import shutil
import gradio as gr
from huggingface_hub import hf_hub_download
from TTS.api import TTS

# Load environment variables (e.g. from a local .env file) before reading the Groq key
load_dotenv()
groq_key = os.environ["GROQ_API_KEY"]


class TemporalTransformerEncoder(nn.Module):
    """Aggregates per-frame CLIP features with a Transformer encoder and a learnable [CLS] token."""

    def __init__(self, embed_dim, num_heads, num_layers, num_frames, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_frames = num_frames

        # Learnable [CLS] token and positional embeddings (one extra slot for the [CLS] token)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.position_embed = nn.Parameter(torch.zeros(1, num_frames + 1, embed_dim))
        nn.init.trunc_normal_(self.position_embed, std=0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        B = x.size(0)
        cls_token = self.cls_token.expand(B, 1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.position_embed[:, :x.size(1)]
        x = self.transformer(x)
        return {
            "cls": x[:, 0],      # pooled clip-level representation
            "tokens": x[:, 1:]   # per-frame representations
        }


class CricketCommentator(nn.Module):
    def __init__(self, train_mode=False, num_frames=16, train_layers=2):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.num_frames = num_frames

        # CLIP image encoder (ViT-B/32, 512-dim features); its parameters are frozen when training
        self.clip, self.preprocess = clip.load("ViT-B/32", device=self.device)
        self.clip = self.clip.float()
        if train_mode:
            for param in self.clip.parameters():
                param.requires_grad = False

        self.temporal_encoder = TemporalTransformerEncoder(
            embed_dim=512,
            num_heads=8,
            num_layers=3,
            num_frames=num_frames,
            dropout=0.1
        ).to(self.device).float()

        # Projection from CLIP space into the DeepSeek embedding space (2048-dim)
        self.projection = nn.Sequential(
            nn.Linear(512, 2048),
            nn.GELU(),
            nn.LayerNorm(2048),
            nn.Dropout(0.1),
            nn.Linear(2048, 2048),
            nn.Tanh()
        ).to(self.device).float()

        # DeepSeek model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-1.3b-instruct")
        self.model = AutoModelForCausalLM.from_pretrained(
            "deepseek-ai/deepseek-coder-1.3b-instruct"
        ).to(self.device).float()
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Freeze all language-model parameters initially
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze the last N transformer blocks (plus final norm and LM head) when training
        if train_mode and train_layers > 0:
            for block in self.model.model.layers[-train_layers:]:
                for param in block.parameters():
                    param.requires_grad = True
            for param in self.model.model.norm.parameters():
                param.requires_grad = True
            for param in self.model.lm_head.parameters():
                param.requires_grad = True

    def forward(self, frames):
        batch_size = frames.shape[0]
        frames = frames.view(-1, 3, 224, 224)

        # Encode frames with CLIP (no gradients through the backbone)
        with torch.no_grad():
            frame_features = self.clip.encode_image(frames.to(self.device))
        frame_features = frame_features.view(batch_size, self.num_frames, -1).float()
        frame_features = F.normalize(frame_features, p=2, dim=-1)

        # Temporal aggregation, then projection into the language-model embedding space
        temporal_out = self.temporal_encoder(frame_features)
        visual_embeds = self.projection(temporal_out["cls"])
        return F.normalize(visual_embeds, p=2, dim=-1).unsqueeze(1)
    def extract_frames(self, video_path):
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        stride = max(1, total_frames // self.num_frames)

        frames = []
        for i in range(0, total_frames, stride):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if ret:
                # Centre-crop half the shorter side, convert BGR -> RGB, then apply CLIP preprocessing
                h, w, _ = frame.shape
                crop_size = min(h, w) // 2
                y, x = (h - crop_size) // 2, (w - crop_size) // 2
                cropped = cv2.cvtColor(frame[y:y+crop_size, x:x+crop_size], cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(cropped)
                frames.append(self.preprocess(pil_image))
                if len(frames) >= self.num_frames:
                    break
            else:
                break
        cap.release()

        # Pad with blank frames if the video yielded fewer than num_frames samples
        if len(frames) < self.num_frames:
            frames.extend([torch.zeros(3, 224, 224)] * (self.num_frames - len(frames)))
        return torch.stack(frames)

    def generate_commentary(self, video_path):
        frames = self.extract_frames(video_path).unsqueeze(0).to(self.device)
        visual_embeds = self.forward(frames)  # Shape: [1, 1, 2048]

        # Prepare text prompt
        prompt = ("USER: