# Page header captured with the file (not program code): gpu / app.py, vitorcalvi, 5f09150, 1.37 kB.
import gradio as gr
import torch
from tabs.FACS_analysis import create_facs_analysis_tab
from ui_components import CUSTOM_CSS, HEADER_HTML, DISCLAIMER_HTML
import spaces # Importing spaces to utilize GPU
import logging
# Configure the root logger at INFO level; this happens as a side effect of
# importing the module.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Declarative layout of the UI: a list of (main tab label, sub-tab list)
# pairs, where every sub-tab is itself a (label, builder callable) pair.
# create_demo() walks this structure and calls each builder with a `device`
# keyword argument.
TAB_STRUCTURE = [
    (
        "Visual Analysis",
        [
            ("FACS for Stress, Anxiety, Depression", create_facs_analysis_tab),
        ],
    ),
]
# Decorate GPU-dependent function with GPU
# NOTE(review): the decorator wraps the UI-building function itself, and that
# function is called once at import time further down in this file. Confirm
# this is the intended target rather than the GPU-bound inference callbacks.
@spaces.GPU(duration=300)  # Increased duration if necessary
def create_demo():
    """Assemble the Gradio Blocks application described by TAB_STRUCTURE.

    Selects "cuda" when torch reports CUDA as available and "cpu" otherwise,
    logs the choice, then builds the page: the header, one main tab per
    TAB_STRUCTURE entry (each holding its own nested tab group), and the
    disclaimer below the tabs.

    Returns:
        The gr.Blocks object created by this function body.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    logger.info(f"Using device: {device}")

    # Ensure that any models loaded within create_facs_analysis_tab use the
    # correct device: every builder receives the selected device by keyword.
    with gr.Blocks(css=CUSTOM_CSS) as demo:
        gr.Markdown(HEADER_HTML)

        with gr.Tabs(elem_classes=["main-tab"]):
            for group_label, group_entries in TAB_STRUCTURE:
                # One outer tab per group, with a nested tab set inside it.
                with gr.Tab(group_label), gr.Tabs():
                    for entry_label, build_tab in group_entries:
                        with gr.Tab(entry_label):
                            build_tab(device=device)  # Pass device if needed

        gr.HTML(DISCLAIMER_HTML)

    return demo
# Create the demo instance. It is built at import time, so `demo` exists as a
# module-level name whether or not this file is executed as a script.
demo = create_demo()

# Start the Gradio server only when the file is run directly.
if __name__ == "__main__":
    demo.launch()