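# live-train / app.py
# Header residue copied from the hosting site's file page: author pjxcharya,
# commit 6ad4f49 "Update app.py" (verified), file size 24.1 kB,
# page links: raw, history, blame.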
import gradio as gr
import cv2
import numpy as np
import mediapipe as mp
import time
import traceback
# Import your exercise classes
from exercises.hammer_curl import HammerCurl
from exercises.push_up import PushUp
from exercises.squat import Squat
# --- MediaPipe Pose setup ---
# One module-level Pose estimator, reused by process_frame for every streamed
# frame; detection and tracking confidence thresholds are both 0.5.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)
# Drawing helpers used to render the generic pose skeleton.
mp_drawing, mp_drawing_styles = (
    mp.solutions.drawing_utils,
    mp.solutions.drawing_styles,
)
# --- State Variables ---
# Module-level session state shared by every Gradio callback in this file.

# One tracker per exercise, keyed by the display name used in the UI.
exercise_trackers = {
    name: tracker_cls()
    for name, tracker_cls in (
        ("Hammer Curl", HammerCurl),
        ("Push Up", PushUp),
        ("Squat", Squat),
    )
}
current_exercise_tracker = None  # assigned when an exercise is selected or a workout starts
selected_exercise_name = "Hammer Curl"  # default exercise
target_reps, target_sets = 10, 3  # also the initial values of the Reps/Sets inputs
current_set_count = 1  # 1-based number of the set in progress
workout_complete_message = ""  # becomes non-empty once the final set is finished
workout_active = False  # True between Start and Stop (or automatic completion)
def update_targets_and_exercise_display(exercise_name_choice, reps_in, sets_in):
    """Apply exercise-selection and target changes coming from the UI.

    Selecting a different exercise installs a fresh tracker for it. Changing
    the rep or set target resets the current tracker's reps. Any effective
    change stops the active workout, rewinds progress to set 1 and clears the
    completion message. Target values that are empty (``None``), non-numeric
    or not positive are ignored, and the previous target stays in place.

    Args:
        exercise_name_choice: Display name of the selected exercise.
        reps_in: Requested reps per set (truncated with ``int()``).
        sets_in: Requested number of sets (truncated with ``int()``).

    Returns:
        ``(exercise_name, reps_text, sets_text, angle_text, feedback_text,
        workout_status_text)``. ``angle_text`` is always ``"N/A"`` here.
    """
    global selected_exercise_name, target_reps, target_sets
    global current_set_count, workout_complete_message, current_exercise_tracker, workout_active

    def _positive_int(value):
        """Return ``int(value)`` when it is a positive integer, else ``None``."""
        try:
            parsed = int(value)
        except (TypeError, ValueError, OverflowError):
            # None (cleared gr.Number), non-numeric text, NaN or infinity.
            return None
        return parsed if parsed > 0 else None

    # --- Exercise selection ---
    if selected_exercise_name != exercise_name_choice:
        selected_exercise_name = exercise_name_choice
        tracker_factories = {"Hammer Curl": HammerCurl, "Push Up": PushUp, "Squat": Squat}
        if selected_exercise_name in exercise_trackers and selected_exercise_name in tracker_factories:
            # Replace the tracker so the newly selected exercise starts from zero.
            exercise_trackers[selected_exercise_name] = tracker_factories[selected_exercise_name]()
        current_exercise_tracker = exercise_trackers.get(selected_exercise_name)
        current_set_count = 1
        workout_complete_message = ""
        workout_active = False  # Changing exercise stops the active workout
        print(f"Exercise changed to: {selected_exercise_name}. Workout stopped and progress reset.")

    # --- Rep / set targets ---
    targets_changed = False
    new_reps = _positive_int(reps_in)
    if new_reps is not None and new_reps != target_reps:
        target_reps = new_reps
        targets_changed = True
        print(f"Target reps updated to: {target_reps}. Workout stopped and progress reset.")
    new_sets = _positive_int(sets_in)
    if new_sets is not None and new_sets != target_sets:
        target_sets = new_sets
        targets_changed = True
        print(f"Target sets updated to: {target_sets}. Workout stopped and progress reset.")
    if targets_changed:
        # New targets invalidate the session in progress.
        current_set_count = 1
        workout_complete_message = ""
        workout_active = False
        if current_exercise_tracker:
            current_exercise_tracker.reset_reps()

    # --- Display values ---
    tracker = current_exercise_tracker
    if selected_exercise_name == "Hammer Curl":
        # Hammer Curl counts each arm; the right arm drives the target.
        right = tracker.counter_right if tracker else 0
        left = tracker.counter_left if tracker else 0
        reps_disp = f"R: {right}, L: {left} (Target: {target_reps} for R)"
    else:
        if tracker and hasattr(tracker, "counter"):
            reps_so_far = tracker.counter
        elif tracker and hasattr(tracker, "counter_right"):
            reps_so_far = tracker.counter_right
        else:
            reps_so_far = 0
        reps_disp = f"{reps_so_far}/{target_reps}"

    if workout_complete_message:
        status = workout_complete_message
    elif workout_active:
        status = ""
    else:
        status = "Workout Not Active"

    return (selected_exercise_name,
            reps_disp,
            f"{current_set_count}/{target_sets}",
            "N/A",  # Angle
            "Select exercise, set targets, then press 'Start Workout'.",  # Feedback
            status)  # Workout Status
def trigger_start_workout():
    """Begin a fresh workout for the currently selected exercise.

    Marks the session active, rewinds to set 1, clears any completion message
    and resets the rep counters of the selected exercise's tracker. If no
    tracker is registered for that name, one is created first (when the name
    is one of the three known exercises).

    Returns:
        ``(exercise_name, reps_text, sets_text, angle_text, feedback_text,
        workout_status_text)`` for the status widgets.
    """
    global current_set_count, workout_complete_message, workout_active, selected_exercise_name, current_exercise_tracker, target_reps, target_sets
    print("Start Workout button clicked.")
    workout_active = True
    current_set_count = 1
    workout_complete_message = ""

    current_exercise_tracker = exercise_trackers.get(selected_exercise_name)
    if not current_exercise_tracker:
        # Not expected while selected_exercise_name is always a known exercise.
        print(f"Error: No tracker found for {selected_exercise_name} on start.")
        factory = {"Hammer Curl": HammerCurl, "Push Up": PushUp, "Squat": Squat}.get(selected_exercise_name)
        if factory is not None:
            exercise_trackers[selected_exercise_name] = factory()
        current_exercise_tracker = exercise_trackers.get(selected_exercise_name)
        if current_exercise_tracker:
            current_exercise_tracker.reset_reps()
    else:
        current_exercise_tracker.reset_reps()
        print(f"Tracker for {selected_exercise_name} reset.")

    if selected_exercise_name == "Hammer Curl":
        reps_disp = f"R: 0, L: 0 (Target: {target_reps} for R)"
    else:
        reps_disp = f"0/{target_reps}"

    feedback = f"Workout Started: {selected_exercise_name}. Go!"
    return (selected_exercise_name, reps_disp, f"1/{target_sets}", "N/A", feedback, "Workout Active")
def trigger_stop_workout():
    """Pause the workout; rep/set progress is left untouched.

    Only the ``workout_active`` flag changes. Counters, the set number and the
    targets keep their values so Start can be pressed again.

    Returns:
        ``(feedback_text, workout_status_text)`` for the two status widgets.
    """
    global workout_active
    print("Stop Workout button clicked.")
    workout_active = False
    feedback = "Workout Stopped. Press Start to resume or change settings."
    status = "Workout Stopped"
    return feedback, status
def _as_rgb_frame(frame):
    """Coerce a webcam frame to 3-channel RGB.

    Grayscale frames (2-D or single-channel) are expanded and an alpha channel
    is dropped. Any other frame is returned unchanged.
    """
    if frame.ndim == 2 or (frame.ndim == 3 and frame.shape[2] == 1):
        return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    if frame.ndim == 3 and frame.shape[2] == 4:
        return cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
    return frame


def _message_frame(height, width, text, origin, scale):
    """Return a black ``height`` x ``width`` RGB frame with ``text`` drawn in red."""
    frame = np.zeros((height, width, 3), dtype=np.uint8)
    # This frame goes straight back to Gradio (RGB), so red is (255, 0, 0).
    cv2.putText(frame, text, origin, cv2.FONT_HERSHEY_SIMPLEX, scale, (255, 0, 0), 2)
    return frame


def _draw_annotations(canvas, annotations):
    """Draw tracker-supplied line/circle/text instructions onto ``canvas`` in place.

    Each annotation is a dict with a ``"type"`` key plus the geometry and
    ``"color_bgr"`` keys read below. Unknown types are skipped.
    """
    for ann in annotations:
        kind = ann["type"]
        if kind == "line":
            cv2.line(canvas, tuple(ann["start_point"]), tuple(ann["end_point"]),
                     ann["color_bgr"], ann["thickness"])
        elif kind == "circle":
            thickness = -1 if ann.get("filled", False) else ann["thickness"]
            cv2.circle(canvas, tuple(ann["center_point"]), ann["radius"], ann["color_bgr"], thickness)
        elif kind == "text":
            cv2.putText(canvas, ann["text_content"], tuple(ann["position"]), cv2.FONT_HERSHEY_SIMPLEX,
                        ann["font_scale"], ann["color_bgr"], ann["thickness"])


def _track_exercise(tracker, exercise_name, landmarks, canvas, reps_target):
    """Run the exercise-specific tracker on one frame's pose landmarks.

    The Hammer Curl tracker receives ``canvas`` directly. For Push Up and
    Squat, any annotations returned by ``get_drawing_annotations`` are drawn
    onto ``canvas``.

    Returns:
        ``(reps_this_set, reps_text, angle_text, feedback_text)``. For Hammer
        Curl the right-arm count is the one that counts toward the target.
        Exceptions from the tracker propagate to the caller.
    """
    if exercise_name == "Hammer Curl":
        (r_count, r_angle, l_count, l_angle, warn_r, warn_l,
         _, _, _, _) = tracker.track_hammer_curl(landmarks, canvas)
        warnings = [f"{side}: {warn}" for side, warn in (("R", warn_r), ("L", warn_l)) if warn]
        return (r_count,
                f"R: {r_count}, L: {l_count} (Target: {reps_target} for R)",
                f"R Ang: {int(r_angle)}, L Ang: {int(l_angle)}",
                " | ".join(warnings) if warnings else "Good form!")

    frame_height, frame_width = canvas.shape[:2]
    if exercise_name == "Push Up":
        exercise_data = tracker.track_push_up(landmarks, frame_width, frame_height)
    elif exercise_name == "Squat":
        exercise_data = tracker.track_squat(landmarks, frame_width, frame_height)
    else:
        # Unknown exercise: nothing is tracked this frame.
        return 0, f"0/{reps_target}", "N/A", "Processing..."

    reps_this_set = exercise_data.get("counter", 0)
    angle_text = f"L: {int(exercise_data.get('angle_left', 0))}, R: {int(exercise_data.get('angle_right', 0))}"
    feedback_text = str(exercise_data.get("feedback", "No feedback"))
    if hasattr(tracker, "get_drawing_annotations"):
        _draw_annotations(canvas, tracker.get_drawing_annotations(landmarks, frame_width, frame_height, exercise_data))
    return reps_this_set, f"{reps_this_set}/{reps_target}", angle_text, feedback_text


def process_frame(video_frame_np):
    """Analyse one streamed webcam frame and advance the workout state.

    While a workout is active, pose landmarks are estimated and passed to the
    selected exercise tracker. Reaching the rep target advances to the next
    set. Finishing the last set marks the workout complete and deactivates it.
    When landmarks were found, the generic pose skeleton is drawn as well.
    Workout state is read from and written to the module-level globals.

    Args:
        video_frame_np: Frame from the Gradio webcam component as a numpy
            array (RGB for ``type="numpy"``), or ``None`` when no frame is
            available.

    Returns:
        6-tuple ``(annotated_rgb_frame, reps_text, sets_text, angle_text,
        feedback_text, workout_status_text)``. A red-text error frame is
        returned instead of raising.
    """
    global current_exercise_tracker, selected_exercise_name, target_reps, target_sets
    global current_set_count, workout_complete_message, workout_active

    default_h, default_w = 480, 640
    if video_frame_np is None:
        # Must return all 6 values expected by the stream outputs.
        return (_message_frame(default_h, default_w, "No Camera Input", (50, default_h // 2), 1),
                f"0/{target_reps}", f"{current_set_count}/{target_sets}", "No frame", "No camera", "Error")

    try:
        frame_rgb = _as_rgb_frame(video_frame_np)
        default_h, default_w = frame_rgb.shape[:2]
        # Tracker annotations are specified as `color_bgr`, so all drawing is
        # done on a BGR canvas that is converted back to RGB for Gradio.
        canvas = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        results = None  # stays None unless pose estimation runs this frame
        tracker = current_exercise_tracker

        if selected_exercise_name == "Hammer Curl" and tracker:
            reps_display = f"R: {tracker.counter_right}, L: {tracker.counter_left} (Target: {target_reps} for R)"
        elif tracker and hasattr(tracker, "counter"):
            reps_display = f"{tracker.counter}/{target_reps}"
        else:
            reps_display = f"0/{target_reps}"
        sets_display = f"{current_set_count}/{target_sets}"
        angle_display = "N/A"
        feedback_display = "Waiting for workout to start..."
        current_workout_status = "Workout Not Active"

        if workout_active and not workout_complete_message and tracker:
            feedback_display = "Processing..."
            current_workout_status = "Workout Active"
            # MediaPipe expects RGB input; a read-only copy lets it pass the
            # buffer by reference without touching the caller's array.
            pose_input = frame_rgb.copy()
            pose_input.flags.writeable = False
            results = pose.process(pose_input)
            if results.pose_landmarks:
                try:
                    reps_this_set, reps_display, angle_display, feedback_display = _track_exercise(
                        tracker, selected_exercise_name, results.pose_landmarks.landmark, canvas, target_reps)
                    if reps_this_set >= target_reps:
                        if current_set_count < target_sets:
                            current_set_count += 1
                            tracker.reset_reps()
                            feedback_display = f"Set {current_set_count-1} complete! Starting set {current_set_count}."
                            if selected_exercise_name == "Hammer Curl":
                                reps_display = f"R: 0, L: 0 (Target: {target_reps} for R)"
                            else:
                                reps_display = f"0/{target_reps}"
                        else:
                            feedback_display = "Workout Complete!"
                            workout_complete_message = "Workout Complete!"
                            workout_active = False  # Stop workout automatically
                            if selected_exercise_name == "Hammer Curl":
                                reps_display = f"R: {target_reps}, L: {target_reps} (Target: {target_reps} for R)"
                            else:
                                reps_display = f"{target_reps}/{target_reps}"
                        # Show the new set number on this frame rather than the next one.
                        sets_display = f"{current_set_count}/{target_sets}"
                    current_workout_status = workout_complete_message if workout_complete_message else "Workout Active"
                except Exception as e_exercise:
                    print(f"PROCESS_FRAME: Error during exercise '{selected_exercise_name}' logic: {e_exercise}")
                    print(traceback.format_exc())
                    cv2.putText(canvas, f"Error in {selected_exercise_name}", (10, 60),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
                    feedback_display = f"Error in {selected_exercise_name} processing."
            else:
                feedback_display = "No person detected. Adjust position."
        elif workout_complete_message:
            feedback_display = workout_complete_message
            current_workout_status = workout_complete_message
            if selected_exercise_name != "Hammer Curl":
                reps_display = f"{target_reps}/{target_reps}"
            else:
                reps_display = f"R: {target_reps}, L: {target_reps} (Target: {target_reps} for R)"
            sets_display = f"{target_sets}/{target_sets}"
        elif not workout_active:
            feedback_display = "Workout stopped or not started. Press 'Start Workout'."
            current_workout_status = "Workout Stopped / Not Started"

        # Draw the generic skeleton whenever pose estimation ran and found a person.
        if results is not None and results.pose_landmarks:
            mp_drawing.draw_landmarks(canvas, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                                      landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style())

        annotated_image = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
        return annotated_image, reps_display, sets_display, angle_display, feedback_display, current_workout_status
    except Exception as e_main:
        print(f"PROCESS_FRAME: CRITICAL error in process_frame: {e_main}")
        print(traceback.format_exc())
        error_frame = _message_frame(default_h, default_w, "Critical Error", (10, 30), 0.7)
        return error_frame, "Error", "Error", "Error", "Critical Error", "Error"
# --- Custom CSS ---
# Stylesheet handed to gr.Blocks(css=...) below, layered on top of `theme`.
# NOTE(review): the `.controls-section` and `.status-text` selectors only match
# components that carry those classes (e.g. via `elem_classes`); the UI code in
# this file does not assign them, so those rules are currently inert.
custom_css = """
body, .gradio-container {
background: linear-gradient(to bottom right, #2A002A, #5D3FD3) !important; /* Darker Violet to a brighter Violet */
color: #E0E0E0 !important;
}
.gradio-container { font-family: 'Exo 2', sans-serif !important; }
label, .gr-checkbox-label span {
color: #D0D0D0 !important;
font-weight: bold !important;
}
h1, h3 { /* Targeting h1 and h3 for titles */
color: #FFFFFF !important;
text-align: center !important;
font-family: 'Exo 2', sans-serif !important;
}
.prose { color: #E8E8E8 !important; text-align: center !important; }
.gr-button { /* General button styling */
font-family: 'Exo 2', sans-serif !important;
border-radius: 8px !important;
font-weight: bold !important;
}
/* Specific styling for control panel sections - you might need to inspect element for exact classes */
.controls-section .gr-panel { /* Assuming you wrap sections in gr.Panel or gr.Group */
background-color: rgba(0,0,0,0.2) !important;
border-radius: 10px !important;
padding: 15px !important;
margin-bottom: 15px !important;
}
.status-text { /* Class for status text boxes if needed */
font-weight: bold !important;
}
"""
# --- Gradio Theme ---
# Dark violet look: purple primary and pink secondary accents, slate neutrals,
# light text, and the "Exo 2" Google font with system sans-serif fallbacks.
theme = gr.themes.Base(
    primary_hue=gr.themes.colors.purple,
    secondary_hue=gr.themes.colors.pink,
    neutral_hue=gr.themes.colors.slate,
    font=[gr.themes.GoogleFont("Exo 2"), "ui-sans-serif", "system-ui", "sans-serif"],
)
theme = theme.set(
    # Page text and backgrounds (used where the CSS gradient does not reach).
    body_text_color="#E0E0E0",
    background_fill_primary="#1E001E",
    background_fill_secondary="#2A0A2A",
    # Inputs: translucent white fill and border, white text.
    input_background_fill="rgba(255,255,255,0.05)",
    input_border_color="rgba(255,255,255,0.2)",
    input_text_color="#FFFFFF",
    # Primary buttons purple, secondary buttons pink, both with white text.
    button_primary_background_fill=gr.themes.colors.purple[600],
    button_primary_background_fill_hover=gr.themes.colors.purple[500],
    button_primary_text_color="#FFFFFF",
    button_secondary_background_fill=gr.themes.colors.pink[600],
    button_secondary_background_fill_hover=gr.themes.colors.pink[500],
    button_secondary_text_color="#FFFFFF",
    # Block titles/labels and accent borders.
    block_title_text_color="#FFFFFF",
    block_label_text_color="#E0E0E0",
    border_color_accent=gr.themes.colors.purple[400],
)
# --- Gradio Interface ---
exercise_choices_list = ["Hammer Curl", "Push Up", "Squat"]
with gr.Blocks(theme=theme, css=custom_css) as iface:
    gr.Markdown("# LIVE TRAINING SESSION")
    gr.Markdown("AI-powered exercise tracking and feedback")
    # The selected exercise lives in the module-level global
    # `selected_exercise_name`; callbacks read it at call time instead of
    # going through a gr.State.
    with gr.Row(equal_height=False):
        with gr.Column(scale=2):  # Video feed
            webcam_input = gr.Image(sources=["webcam"], streaming=True, type="numpy", label="Live Workout Feed")
        with gr.Column(scale=1):  # Controls and Status
            with gr.Group():  # Card-like section; can be styled with CSS if needed
                gr.Markdown("### Select Exercise")
                with gr.Row():
                    hc_btn = gr.Button("Hammer Curl")
                    pu_btn = gr.Button("Push Up")
                    sq_btn = gr.Button("Squat")
            with gr.Group():
                gr.Markdown("### Configure Workout")
                with gr.Row():
                    target_sets_number = gr.Number(value=target_sets, label="Sets", precision=0, minimum=1, scale=1)
                    target_reps_number = gr.Number(value=target_reps, label="Reps", precision=0, minimum=1, scale=1)
                with gr.Row():
                    start_button = gr.Button("Start Workout", variant="primary", scale=1)
                    stop_button = gr.Button("Stop Workout", variant="stop", scale=1)
            with gr.Group():
                gr.Markdown("### Current Status")
                current_exercise_display = gr.Textbox(label="Exercise", value=selected_exercise_name, interactive=False)
                sets_output = gr.Textbox(label="Set", interactive=False)
                reps_output = gr.Textbox(label="Repetitions", interactive=False)
                # The angle details are computed by the callbacks but not shown in this UI.
                feedback_output = gr.Textbox(label="Feedback", lines=3, max_lines=5, interactive=False)
                workout_status_output = gr.Textbox(label="Workout Status", interactive=False)

    # --- Define component interactions ---
    # Status widgets refreshed by the exercise buttons, target inputs and Start.
    shared_outputs = [current_exercise_display, reps_output, sets_output, feedback_output, workout_status_output]
    # Stream outputs: the annotated frame (written back into the webcam
    # component) plus the status widgets. There is no angle widget.
    process_frame_outputs = [webcam_input, reps_output, sets_output, feedback_output, workout_status_output]

    def _without_angle(values):
        """Drop the angle entry (index 3) from a 6-value callback result."""
        return tuple(values[:3]) + tuple(values[4:])

    def handle_config_change_and_select(exercise_name, reps, sets):
        """Apply a selection/target change and return the 5 values for shared_outputs."""
        return _without_angle(update_targets_and_exercise_display(exercise_name, reps, sets))

    def handle_target_change(reps, sets):
        """Re-apply targets for whichever exercise is selected when the event fires."""
        return handle_config_change_and_select(selected_exercise_name, reps, sets)

    def handle_start():
        """Start the workout; trigger_start_workout's 6 values are trimmed to shared_outputs."""
        return _without_angle(trigger_start_workout())

    def handle_stream(frame):
        """Process one webcam frame; process_frame's angle text is not displayed."""
        return _without_angle(process_frame(frame))

    target_inputs = [target_reps_number, target_sets_number]

    # Exercise selection buttons (Gradio passes the two Number values positionally).
    hc_btn.click(lambda r, s: handle_config_change_and_select("Hammer Curl", r, s), inputs=target_inputs, outputs=shared_outputs)
    pu_btn.click(lambda r, s: handle_config_change_and_select("Push Up", r, s), inputs=target_inputs, outputs=shared_outputs)
    sq_btn.click(lambda r, s: handle_config_change_and_select("Squat", r, s), inputs=target_inputs, outputs=shared_outputs)
    # Target number changes
    target_reps_number.change(handle_target_change, inputs=target_inputs, outputs=shared_outputs)
    target_sets_number.change(handle_target_change, inputs=target_inputs, outputs=shared_outputs)
    # Start and Stop buttons
    start_button.click(handle_start, inputs=None, outputs=shared_outputs)
    # Stop only updates the feedback and status widgets.
    stop_button.click(trigger_stop_workout, inputs=None, outputs=[feedback_output, workout_status_output])
    # Video stream processing: other state is read from module-level globals.
    webcam_input.stream(fn=handle_stream, inputs=[webcam_input], outputs=process_frame_outputs)
if __name__ == "__main__":
    # Serve locally without a public share link, debug mode off.
    iface.launch(share=False, debug=False)