import spaces
import torch
import argparse
import os
import sys
import pickle
import gc
import tempfile
import subprocess
import logging
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from tenacity import retry, stop_after_attempt, wait_exponential
import numpy as np
from tqdm.auto import tqdm


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('app.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

@dataclass
class ModelConfig:
    name: str = 'iic/mPLUG-Owl3-7B-240728'
    cache_dir: str = '/data/models'
    max_num_frames: int = 32
    device: str = 'cuda'
    torch_dtype: torch.dtype = torch.half
    attn_implementation: str = 'sdpa'


@dataclass
class AzureConfig:
    container_name: str = "logs"
    connection_string: str = "BlobEndpoint=https://assentian.blob.core.windows.net/;QueueEndpoint=https://assentian.queue.core.windows.net/;FileEndpoint=https://assentian.file.core.windows.net/;TableEndpoint=https://assentian.table.core.windows.net/;SharedAccessSignature=sv=2024-11-04&ss=bfqt&srt=sco&sp=rwdlacupiytfx&se=2025-04-30T17:16:18Z&st=2025-04-22T09:16:18Z&spr=https&sig=AkJb79C%2FJ0G1HqfotIYuSfm%2Fb%2BQ2E%2FjvxV3ZG7ejVQo%3D"

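# Usage sketch for the config dataclasses above (hypothetical values, shown only to
# illustrate overriding the defaults):
#   cfg = ModelConfig(device='cuda', max_num_frames=16)
#   az_cfg = AzureConfig(container_name='logs')
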
class ResourceManager:
    def __init__(self, device: str = 'cuda'):
        self.device = device

    def setup_gpu(self):
        if self.device == 'cuda':
            torch.cuda.set_per_process_memory_fraction(1.)
            self.initial_mem = torch.cuda.memory_allocated()
            logger.info(f"Initial GPU memory: {self.initial_mem / 1024**2:.2f} MB")

    def cleanup(self):
        if self.device == 'cuda':
            torch.cuda.empty_cache()
            gc.collect()
            final_mem = torch.cuda.memory_allocated()
            logger.info(f"Final GPU memory: {final_mem / 1024**2:.2f} MB")

    def get_batch_size(self) -> int:
        if self.device == 'cuda':
            available_memory = torch.cuda.get_device_properties(0).total_memory
            return min(32, int(available_memory / (1024**3) * 4))
        return 32

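# get_batch_size() allows roughly 4 samples per GiB of total GPU memory, capped at 32:
# e.g. a 16 GiB card gives min(32, 64) = 32, while a 6 GiB card gives min(32, 24) = 24.
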
from azure.storage.blob import BlobServiceClient
from azure.core.exceptions import AzureError
from typing import List, Optional

class AzureStorageManager:
    def __init__(self, config: AzureConfig):
        self.config = config
        self._init_blob_client()

    def _init_blob_client(self) -> None:
        """Initialize the blob service client with proper error handling."""
        try:
            self.blob_service_client = BlobServiceClient.from_connection_string(
                conn_str=self.config.connection_string
            )
            self.blob_service_client.get_service_properties()
            logger.info("Successfully initialized Azure Blob Storage client")
        except Exception as e:
            logger.error(f"Failed to initialize Azure Blob Storage client: {str(e)}")
            raise

    def list_blobs(self) -> List[str]:
        """List video blobs in the specified container with proper error handling."""
        try:
            container_client = self.blob_service_client.get_container_client(self.config.container_name)

            video_extensions = ['.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v']
            blob_list = []

            try:
                blobs = container_client.list_blobs()
                blob_list = [
                    blob.name for blob in blobs
                    if any(blob.name.lower().endswith(ext) for ext in video_extensions)
                ]
            except AzureError as ae:
                logger.error(f"Azure error while listing blobs: {str(ae)}")
                raise

            logger.info(f"Successfully found {len(blob_list)} video blobs")
            return blob_list

        except Exception as e:
            logger.error(f"Error listing blobs: {str(e)}")
            return []

    def download_blob(self, blob_name: str) -> Optional[str]:
        """Download a blob to a temporary file and return its path."""
        try:
            blob_client = self.blob_service_client.get_container_client(
                self.config.container_name
            ).get_blob_client(blob_name)

            with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(blob_name)[1]) as temp_file:
                temp_path = temp_file.name
                blob_data = blob_client.download_blob()
                blob_data.readinto(temp_file)

            logger.info(f"Successfully downloaded blob {blob_name} to {temp_path}")
            return temp_path

        except Exception as e:
            logger.error(f"Error downloading blob {blob_name}: {str(e)}")
            return None

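# Typical flow used by the Gradio callback further below: list the video blobs, download
# the selected one to a temp file, process it, then remove the temp file.
#   names = azure_manager.list_blobs()
#   path = azure_manager.download_blob(names[0])   # returns None on failure
#   ...                                            # run detection / captioning
#   os.remove(path)
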
class ModelManager:
    def __init__(self, config: ModelConfig):
        self.config = config
        self.resource_manager = ResourceManager(config.device)
        self.setup_model()

    def setup_model(self):
        """Initialize model, tokenizer, and processor with proper error handling."""
        try:
            from transformers import AutoModel, AutoTokenizer

            self.resource_manager.setup_gpu()
            os.makedirs(self.config.cache_dir, exist_ok=True)

            try:
                from modelscope.hub.snapshot_download import snapshot_download
                model_path = snapshot_download(self.config.name, cache_dir=self.config.cache_dir)
            except Exception as e:
                logger.warning(f"Error downloading model via snapshot_download: {str(e)}")
                model_path = os.path.join(self.config.cache_dir, self.config.name)

            self.model = AutoModel.from_pretrained(
                model_path,
                attn_implementation=self.config.attn_implementation,
                trust_remote_code=True,
                torch_dtype=self.config.torch_dtype,
                device_map='auto',
                low_cpu_mem_usage=True
            )

            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            self.processor = self.model.init_processor(self.tokenizer)

            self.model.eval()
            logger.info("Successfully initialized model, tokenizer, and processor")

        except Exception as e:
            logger.error(f"Error setting up model: {str(e)}")
            raise

    def cleanup(self):
        """Clean up resources."""
        self.resource_manager.cleanup()

from PIL import Image
from decord import VideoReader, cpu
import cv2
from ultralytics import YOLO
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional, Union
import numpy as np

@dataclass
class VideoConfig:
    image_extensions: frozenset = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'})
    video_extensions: frozenset = frozenset({'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'})
    yolo_model_path: str = './best_yolov11.pt'
    sample_fps: int = 1

class MediaProcessor:
    def __init__(self, config: VideoConfig):
        self.config = config
        self.yolo_model = self._load_yolo_model()
        logger.info("Initialized MediaProcessor")

    def _load_yolo_model(self) -> YOLO:
        """Load YOLO model with error handling."""
        try:
            model = YOLO(self.config.yolo_model_path)
            logger.info("Successfully loaded YOLO model")
            return model
        except Exception as e:
            logger.error(f"Error loading YOLO model: {str(e)}")
            raise

    @staticmethod
    def get_file_extension(filename: str) -> str:
        return os.path.splitext(filename)[1].lower()

    def is_image(self, filename: str) -> bool:
        return self.get_file_extension(filename) in self.config.image_extensions

    def is_video(self, filename: str) -> bool:
        return self.get_file_extension(filename) in self.config.video_extensions

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    def encode_video_in_chunks(self, video_path: str, max_frames: int = 32) -> List[Tuple[int, List[Image.Image]]]:
        """Extract frames from a video in chunks, with a retry mechanism."""
        try:
            vr = VideoReader(video_path, ctx=cpu(0))
            # Stride (in source frames) between sampled frames; guarded so it never drops to zero
            # for very low-fps videos.
            frame_stride = max(1, round(vr.get_avg_fps() / self.config.sample_fps))
            frame_idx = [i for i in range(0, len(vr), frame_stride)]
            chunks = [frame_idx[i:i + max_frames] for i in range(0, len(frame_idx), max_frames)]

            processed_chunks = []
            for chunk_idx, chunk in enumerate(chunks):
                try:
                    frames = vr.get_batch(chunk).asnumpy()
                    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
                    processed_chunks.append((chunk_idx, frames))
                    logger.debug(f"Processed chunk {chunk_idx} with {len(frames)} frames")
                except Exception as e:
                    logger.error(f"Error processing chunk {chunk_idx}: {str(e)}")
                    continue

            return processed_chunks
        except Exception as e:
            logger.error(f"Error encoding video: {str(e)}")
            raise

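    # Example of the chunking above (assumed numbers, for illustration only): a 60 s clip
    # at 30 fps with sample_fps=1 gives a stride of 30, hence 60 sampled frames; with
    # max_frames=32 that becomes two chunks of 32 and 28 frames.
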
    def process_yolo_results(self, results) -> Tuple[int, int, Dict[str, int]]:
        """Process YOLO detection results and count people and machinery."""
        try:
            people_count = 0
            machine_types = {
                "Tower Crane": 0, "Mobile Crane": 0, "Compactor/Roller": 0,
                "Bulldozer": 0, "Excavator": 0, "Dump Truck": 0,
                "Concrete Mixer": 0, "Loader": 0, "Pump Truck": 0,
                "Pile Driver": 0, "Grader": 0, "Other Vehicle": 0
            }

            # Map YOLO class-name fragments to the display categories above.
            machinery_mapping = {
                'tower_crane': "Tower Crane",
                'mobile_crane': "Mobile Crane",
                'compactor': "Compactor/Roller",
                'roller': "Compactor/Roller",
                'bulldozer': "Bulldozer",
                'dozer': "Bulldozer",
                'excavator': "Excavator",
                'dump_truck': "Dump Truck",
                'truck': "Dump Truck",
                'concrete_mixer_truck': "Concrete Mixer",
                'loader': "Loader",
                'pump_truck': "Pump Truck",
                'pile_driver': "Pile Driver",
                'grader': "Grader",
                'other_vehicle': "Other Vehicle"
            }

            for r in results:
                boxes = r.boxes
                for box in boxes:
                    cls = int(box.cls[0])
                    conf = float(box.conf[0])
                    if conf < 0.5:
                        continue

                    class_name = self.yolo_model.names[cls]
                    if class_name.lower() == 'worker':
                        people_count += 1
                        continue

                    class_lower = class_name.lower()
                    for key, value in machinery_mapping.items():
                        if key in class_lower:
                            machine_types[value] += 1
                            break

            total_machinery = sum(machine_types.values())
            return people_count, total_machinery, machine_types

        except Exception as e:
            logger.error(f"Error processing YOLO results: {str(e)}")
            return 0, 0, {}

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    def detect_people_and_machinery(self, media_path: Union[str, Image.Image]) -> Tuple[int, int, Dict[str, int]]:
        """Detect people and machinery using YOLOv11 for both images and videos."""
        try:
            max_people_count = 0
            max_machine_types = {
                "Tower Crane": 0, "Mobile Crane": 0, "Compactor/Roller": 0,
                "Bulldozer": 0, "Excavator": 0, "Dump Truck": 0,
                "Concrete Mixer": 0, "Loader": 0, "Pump Truck": 0,
                "Pile Driver": 0, "Grader": 0, "Other Vehicle": 0
            }

            if isinstance(media_path, str) and self.is_video(media_path):
                cap = cv2.VideoCapture(media_path)
                fps = cap.get(cv2.CAP_PROP_FPS)
                sample_rate = max(1, int(fps))
                frame_count = 0

                with tqdm(desc="Processing video frames", unit="frame") as pbar:
                    while cap.isOpened():
                        ret, frame = cap.read()
                        if not ret:
                            break

                        if frame_count % sample_rate == 0:
                            results = self.yolo_model(frame)
                            people, _, machine_types = self.process_yolo_results(results)
                            max_people_count = max(max_people_count, people)
                            for k, v in machine_types.items():
                                max_machine_types[k] = max(max_machine_types[k], v)

                        frame_count += 1
                        pbar.update(1)

                cap.release()

            else:
                if isinstance(media_path, str):
                    img = cv2.imread(media_path)
                else:
                    img = cv2.cvtColor(np.array(media_path), cv2.COLOR_RGB2BGR)

                results = self.yolo_model(img)
                max_people_count, _, max_machine_types = self.process_yolo_results(results)

            max_machine_types = {k: v for k, v in max_machine_types.items() if v > 0}
            total_machinery_count = sum(max_machine_types.values())

            logger.info(f"Detection complete - People: {max_people_count}, Machinery: {total_machinery_count}")
            return max_people_count, total_machinery_count, max_machine_types

        except Exception as e:
            logger.error(f"Error in detection: {str(e)}")
            return 0, 0, {}

    def annotate_video_with_bboxes(self, video_path: str) -> str:
        """Annotate video with bounding boxes and detection summaries."""
        try:
            cap = cv2.VideoCapture(video_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
            annotated_video_path = out_file.name
            out_file.close()

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(annotated_video_path, fourcc, fps, (w, h))

            with tqdm(desc="Annotating video", unit="frame") as pbar:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break

                    results = self.yolo_model(frame)
                    frame_counts = {}

                    for r in results:
                        boxes = r.boxes
                        for box in boxes:
                            cls_id = int(box.cls[0])
                            conf = float(box.conf[0])
                            if conf < 0.5:
                                continue

                            x1, y1, x2, y2 = map(int, box.xyxy[0])
                            class_name = self.yolo_model.names[cls_id]

                            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

                            label_text = f"{class_name} {conf:.2f}"
                            cv2.putText(frame, label_text, (x1, y1 - 6),
                                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

                            frame_counts[class_name] = frame_counts.get(class_name, 0) + 1

                    summary_str = ", ".join(f"{cls_name}: {count}"
                                            for cls_name, count in frame_counts.items())
                    cv2.putText(frame, summary_str, (15, 30),
                                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)

                    writer.write(frame)
                    pbar.update(1)

            cap.release()
            writer.release()

            logger.info(f"Video annotation complete: {annotated_video_path}")
            return annotated_video_path

        except Exception as e:
            logger.error(f"Error annotating video: {str(e)}")
            raise

import gradio as gr

class InferenceEngine:
    def __init__(self, model_manager: ModelManager, media_processor: MediaProcessor):
        self.model_manager = model_manager
        self.media_processor = media_processor

    def process_image(self, image_path: str, prompt: str) -> str:
        try:
            image = Image.open(image_path)
            messages = [{
                "role": "user",
                "content": prompt,
                "images": [image]
            }]
            model_messages = []
            images = []
            for msg in messages:
                content_str = msg["content"]
                if "images" in msg and msg["images"]:
                    content_str += "<|image|>"
                    images.extend(msg["images"])
                model_messages.append({"role": msg["role"], "content": content_str})
            model_messages.append({"role": "assistant", "content": ""})
            inputs = self.model_manager.processor(model_messages, images=images, videos=None)
            inputs.to(self.model_manager.config.device)
            inputs.update({
                'tokenizer': self.model_manager.tokenizer,
                'max_new_tokens': 100,
                'decode_text': True,
            })
            response = self.model_manager.model.generate(**inputs)
            del inputs
            return response[0]
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            return "Error processing image"

    def analyze_image_activities(self, image_path: str) -> str:
        prompt = (
            "Analyze this construction site image and describe the activities happening. "
            "Focus on construction activities, machinery usage, and worker actions."
        )
        return self.process_image(image_path, prompt)

    def process_video_chunk(self, video_frames, prompt: str) -> str:
        messages = [{
            "role": "user",
            "content": prompt,
            "video_frames": video_frames
        }]
        model_messages = []
        videos = []
        for msg in messages:
            content_str = msg["content"]
            if "video_frames" in msg and msg["video_frames"]:
                content_str += "<|video|>"
                videos.append(msg["video_frames"])
            model_messages.append({"role": msg["role"], "content": content_str})
        model_messages.append({"role": "assistant", "content": ""})
        inputs = self.model_manager.processor(
            model_messages,
            images=None,
            videos=videos if videos else None
        )
        inputs.to(self.model_manager.config.device)
        inputs.update({
            'tokenizer': self.model_manager.tokenizer,
            'max_new_tokens': 100,
            'decode_text': True,
        })
        response = self.model_manager.model.generate(**inputs)
        del inputs
        return response[0]

    def analyze_video_activities(self, video_path: str) -> str:
        all_responses = []
        chunk_generator = self.media_processor.encode_video_in_chunks(
            video_path, max_frames=self.model_manager.config.max_num_frames
        )
        for chunk_idx, video_frames in chunk_generator:
            prompt = (
                "Analyze this construction site video chunk and describe the activities happening. "
                "Focus on construction activities, machinery usage, and worker actions."
            )
            response = self.process_video_chunk(video_frames, prompt)
            all_responses.append(f"Time period {chunk_idx + 1}:\n{response}")
        return "\n\n".join(all_responses)

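# For reference, a single-image request reaches the processor roughly in this shape
# (a sketch based on the code above, not taken from the model documentation):
#   [{"role": "user", "content": "<prompt text><|image|>"},
#    {"role": "assistant", "content": ""}]
# with the PIL image passed separately via the `images` argument; video chunks use the
# <|video|> placeholder and the `videos` argument instead.
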
def build_gradio_interface(
    model_manager: ModelManager,
    azure_manager: AzureStorageManager,
    media_processor: MediaProcessor,
    inference_engine: InferenceEngine
):
    blob_names = azure_manager.list_blobs()
    logger.info(f"Available Azure video blobs: {blob_names}")

    @spaces.GPU
    def process_diary(
        day, date, total_people, total_machinery, machinery_types, activities,
        media_source, local_file, azure_blob
    ):
        media_path = None
        try:
            if media_source == "Local File":
                if local_file is None:
                    return [day, date, "No media uploaded", "No media uploaded", "No media uploaded", "No media uploaded", None]
                # gr.File may hand back either a filepath string or a tempfile-like object
                # depending on the Gradio version; handle both.
                media_path = local_file.name if hasattr(local_file, "name") else local_file
            else:
                if not azure_blob:
                    return [day, date, "No blob selected", "No blob selected", "No blob selected", "No blob selected", None]
                media_path = azure_manager.download_blob(azure_blob)

            file_ext = media_processor.get_file_extension(media_path)
            if not (media_processor.is_image(media_path) or media_processor.is_video(media_path)):
                raise ValueError(f"Unsupported file type: {file_ext}")

            detected_people, detected_machinery, detected_machinery_types = media_processor.detect_people_and_machinery(media_path)
            logger.info(f"Detected people: {detected_people}, Detected machinery: {detected_machinery}, Machinery types: {detected_machinery_types}")

            annotated_video_path = None
            if media_processor.is_image(media_path):
                detected_activities = inference_engine.analyze_image_activities(media_path)
            else:
                detected_activities = inference_engine.analyze_video_activities(media_path)
                annotated_video_path = media_processor.annotate_video_with_bboxes(media_path)

            detected_types_str = ", ".join([f"{k}: {v}" for k, v in detected_machinery_types.items()])
            return [day, date, str(detected_people), str(detected_machinery), detected_types_str, detected_activities, annotated_video_path]

        except Exception as e:
            logger.error(f"Error processing media: {str(e)}")
            return [day, date, "Error processing media", "Error processing media", "Error processing media", "Error processing media", None]
        finally:
            if media_source == "Azure Blob" and media_path and os.path.exists(media_path):
                try:
                    os.remove(media_path)
                    logger.info(f"Removed temporary file: {media_path}")
                except Exception as e:
                    logger.error(f"Error removing temporary file: {str(e)}")

    with gr.Blocks(title="Digital Site Diary") as demo:
        gr.Markdown("# 📝 Digital Site Diary")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### User Input")
                day = gr.Textbox(label="Day", value='9')
                date = gr.Textbox(label="Date", placeholder="YYYY-MM-DD", value=datetime.now().strftime("%Y-%m-%d"))
                total_people = gr.Number(label="Total Number of People", precision=0, value=10)
                total_machinery = gr.Number(label="Total Number of Machinery", precision=0, value=3)
                machinery_types = gr.Textbox(label="Number of Machinery Per Type", placeholder="e.g., Excavator: 2, Roller: 1", value="Excavator: 2, Roller: 1")
                activities = gr.Textbox(label="Activity", placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting", value="9 AM: Excavation, 10 AM: Concreting", lines=3)
                media_source = gr.Radio(["Local File", "Azure Blob"], label="Media Source", value="Local File")
                local_file = gr.File(label="Upload Image/Video", file_types=["image", "video"], visible=True)
                azure_blob = gr.Dropdown(label="Select Video from Azure", choices=blob_names, visible=False)
                submit_btn = gr.Button("Submit", variant="primary")
            with gr.Column():
                gr.Markdown("### Model Detection")
                model_day = gr.Textbox(label="Day")
                model_date = gr.Textbox(label="Date")
                model_people = gr.Textbox(label="Total Number of People")
                model_machinery = gr.Textbox(label="Total Number of Machinery")
                model_machinery_types = gr.Textbox(label="Number of Machinery Per Type")
                model_activities = gr.Textbox(label="Activity", lines=5)
                model_annotated_video = gr.Video(label="Annotated Video")

        def update_blob_dropdown(source):
            if source == "Azure Blob":
                # Refresh the blob list each time the user switches to Azure.
                new_blob_names = azure_manager.list_blobs()
                return gr.update(visible=False), gr.update(visible=True, choices=new_blob_names)
            else:
                return gr.update(visible=True), gr.update(visible=False)

        media_source.change(
            fn=update_blob_dropdown,
            inputs=media_source,
            outputs=[local_file, azure_blob]
        )

        submit_btn.click(
            fn=process_diary,
            inputs=[day, date, total_people, total_machinery, machinery_types, activities, media_source, local_file, azure_blob],
            outputs=[model_day, model_date, model_people, model_machinery, model_machinery_types, model_activities, model_annotated_video]
        )

    return demo

if __name__ == "__main__":
    # Parse CLI arguments first so the --device flag actually reaches the model config.
    parser = argparse.ArgumentParser(description='Digital Site Diary')
    parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    args = parser.parse_args()

    model_config = ModelConfig(device=args.device)
    azure_config = AzureConfig()
    video_config = VideoConfig()

    model_manager = ModelManager(model_config)
    azure_manager = AzureStorageManager(azure_config)
    media_processor = MediaProcessor(video_config)
    inference_engine = InferenceEngine(model_manager, media_processor)

    demo = build_gradio_interface(model_manager, azure_manager, media_processor, inference_engine)
    demo.launch(share=False, debug=True, show_api=False, server_port=args.port, server_name=args.host)
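
# Local run sketch (assumes this module is saved as app.py and that the Python
# dependencies, YOLO weights, and model cache are available):
#   python app.py --device cuda --host 0.0.0.0 --port 7860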