Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -108,31 +108,88 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm | |
| 108 | 
             
                    plt.tight_layout()  # Ensure the entire plot fits into the figure area
         | 
| 109 | 
             
                    #ax.set_facecolor('#D3D3D3')
         | 
| 110 | 
             
                elif output_extension.lower() in vid_extensions:
         | 
| 111 | 
            -
                    output_video = output_path  # Load the video file here
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 112 | 
             
                    output_image = None
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 113 | 
             
                    plt.style.use("ggplot")
         | 
| 114 | 
             
                    fig, ax = plt.subplots(figsize=(10, 6))
         | 
| 115 | 
            -
                     | 
| 116 | 
            -
             | 
| 117 | 
            -
                    sns.lineplot(ax | 
| 118 | 
            -
             | 
| 119 | 
             
                    ax.set_title('Number of Objects over Seconds', fontsize=20, pad=20)  # Increase padding for the title
         | 
| 120 | 
             
                    ax.set_xlabel('Second', fontsize=16)  # Increase font size
         | 
| 121 | 
             
                    ax.set_ylabel('Object Count', fontsize=16)  # Increase font size
         | 
| 122 | 
             
                    ax.tick_params(axis='x', labelsize=12)  # Increase label size for x-axis
         | 
| 123 | 
             
                    ax.tick_params(axis='y', labelsize=12)  # Increase label size for y-axis
         | 
| 124 | 
            -
             | 
| 125 | 
             
                    # Add grid but make it lighter and put it behind bars
         | 
| 126 | 
             
                    ax.grid(True, linestyle=':', linewidth=0.6, color='gray', alpha=0.6)
         | 
| 127 | 
             
                    ax.set_axisbelow(True)
         | 
| 128 | 
            -
             | 
| 129 | 
             
                    # Change the background color to a lighter shade
         | 
| 130 | 
             
                    ax.set_facecolor('#F0F0F0')
         | 
| 131 | 
            -
             | 
| 132 | 
             
                    # Add a legend with a smaller font size
         | 
| 133 | 
             
                    ax.legend(fontsize=10)
         | 
|  | |
|  | |
| 134 |  | 
| 135 | 
            -
                    plt.tight_layout()  # Ensure the entire
         | 
| 136 | 
             
                return output_image, output_video, fig
         | 
| 137 |  | 
| 138 |  | 
|  | |
| 108 | 
             
                    plt.tight_layout()  # Ensure the entire plot fits into the figure area
         | 
| 109 | 
             
                    #ax.set_facecolor('#D3D3D3')
         | 
| 110 | 
             
                elif output_extension.lower() in vid_extensions:
         | 
| 111 | 
            +
                    # output_video = output_path  # Load the video file here
         | 
| 112 | 
            +
                    # output_image = None
         | 
| 113 | 
            +
                    # plt.style.use("ggplot")
         | 
| 114 | 
            +
                    # fig, ax = plt.subplots(figsize=(10, 6))
         | 
| 115 | 
            +
                    # #for label in labels:
         | 
| 116 | 
            +
                    #     #df_label = frame_counts_df[frame_counts_df['label'] == label]
         | 
| 117 | 
            +
                    # sns.lineplot(ax = ax, data = frame_counts_df,  x = 'frame', y = 'count', hue = 'label', palette=palette,linewidth=2.5)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    # ax.set_title('Number of Objects over Seconds', fontsize=20, pad=20)  # Increase padding for the title
         | 
| 120 | 
            +
                    # ax.set_xlabel('Second', fontsize=16)  # Increase font size
         | 
| 121 | 
            +
                    # ax.set_ylabel('Object Count', fontsize=16)  # Increase font size
         | 
| 122 | 
            +
                    # ax.tick_params(axis='x', labelsize=12)  # Increase label size for x-axis
         | 
| 123 | 
            +
                    # ax.tick_params(axis='y', labelsize=12)  # Increase label size for y-axis
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    # # Add grid but make it lighter and put it behind bars
         | 
| 126 | 
            +
                    # ax.grid(True, linestyle=':', linewidth=0.6, color='gray', alpha=0.6)
         | 
| 127 | 
            +
                    # ax.set_axisbelow(True)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    # # Change the background color to a lighter shade
         | 
| 130 | 
            +
                    # ax.set_facecolor('#F0F0F0')
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    # # Add a legend with a smaller font size
         | 
| 133 | 
            +
                    # ax.legend(fontsize=10)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    # plt.tight_layout()  # Ensure the entire
         | 
| 136 | 
            +
                    output_video = output_path
         | 
| 137 | 
             
                    output_image = None
         | 
| 138 | 
            +
                    
         | 
| 139 | 
            +
                    # Interpolation preprocessing
         | 
| 140 | 
            +
                    interpolated_data = []
         | 
| 141 | 
            +
                    
         | 
| 142 | 
            +
                    labels = frame_counts_df['label'].unique()
         | 
| 143 | 
            +
                    for label in labels:
         | 
| 144 | 
            +
                        df_label = frame_counts_df[frame_counts_df['label'] == label]
         | 
| 145 | 
            +
                        
         | 
| 146 | 
            +
                        # Sort data by frame to ensure smooth interpolation
         | 
| 147 | 
            +
                        df_label = df_label.sort_values('frame')
         | 
| 148 | 
            +
                    
         | 
| 149 | 
            +
                        # Original data points
         | 
| 150 | 
            +
                        x = df_label['frame']
         | 
| 151 | 
            +
                        y = df_label['count']
         | 
| 152 | 
            +
                    
         | 
| 153 | 
            +
                        # Check if we have enough points for interpolation
         | 
| 154 | 
            +
                        if len(x) > 1:
         | 
| 155 | 
            +
                            # Create spline interpolation
         | 
| 156 | 
            +
                            x_smooth = np.linspace(x.min(), x.max(), 500)
         | 
| 157 | 
            +
                            spline = make_interp_spline(x, y, k=3)  # Cubic spline interpolation
         | 
| 158 | 
            +
                            y_smooth = spline(x_smooth)
         | 
| 159 | 
            +
                    
         | 
| 160 | 
            +
                            # Append the smoothed data to the list
         | 
| 161 | 
            +
                            interpolated_data.append(pd.DataFrame({'frame': x_smooth, 'count': y_smooth, 'label': label}))
         | 
| 162 | 
            +
                    
         | 
| 163 | 
            +
                    # Concatenate all interpolated data into a single DataFrame
         | 
| 164 | 
            +
                    if interpolated_data:
         | 
| 165 | 
            +
                        interpolated_df = pd.concat(interpolated_data)
         | 
| 166 | 
            +
                    else:
         | 
| 167 | 
            +
                        interpolated_df = pd.DataFrame(columns=['frame', 'count', 'label'])
         | 
| 168 | 
            +
                    
         | 
| 169 | 
             
                    plt.style.use("ggplot")
         | 
| 170 | 
             
                    fig, ax = plt.subplots(figsize=(10, 6))
         | 
| 171 | 
            +
                    
         | 
| 172 | 
            +
                    # Plot using Seaborn
         | 
| 173 | 
            +
                    sns.lineplot(ax=ax, data=interpolated_df, x='frame', y='count', hue='label', palette=palette, linewidth=2.5)
         | 
| 174 | 
            +
                    
         | 
| 175 | 
             
                    ax.set_title('Number of Objects over Seconds', fontsize=20, pad=20)  # Increase padding for the title
         | 
| 176 | 
             
                    ax.set_xlabel('Second', fontsize=16)  # Increase font size
         | 
| 177 | 
             
                    ax.set_ylabel('Object Count', fontsize=16)  # Increase font size
         | 
| 178 | 
             
                    ax.tick_params(axis='x', labelsize=12)  # Increase label size for x-axis
         | 
| 179 | 
             
                    ax.tick_params(axis='y', labelsize=12)  # Increase label size for y-axis
         | 
| 180 | 
            +
                    
         | 
| 181 | 
             
                    # Add grid but make it lighter and put it behind bars
         | 
| 182 | 
             
                    ax.grid(True, linestyle=':', linewidth=0.6, color='gray', alpha=0.6)
         | 
| 183 | 
             
                    ax.set_axisbelow(True)
         | 
| 184 | 
            +
                    
         | 
| 185 | 
             
                    # Change the background color to a lighter shade
         | 
| 186 | 
             
                    ax.set_facecolor('#F0F0F0')
         | 
| 187 | 
            +
                    
         | 
| 188 | 
             
                    # Add a legend with a smaller font size
         | 
| 189 | 
             
                    ax.legend(fontsize=10)
         | 
| 190 | 
            +
                    
         | 
| 191 | 
            +
                    plt.tight_layout()  # Ensure the entire plot is visible
         | 
| 192 |  | 
|  | |
| 193 | 
             
                return output_image, output_video, fig
         | 
| 194 |  | 
| 195 |  | 
