|
import shlex |
|
import subprocess |
|
import os |
|
import sys |
|
import logging |
|
from pathlib import Path |
|
from huggingface_hub import snapshot_download |
|
from huggingface_hub.utils import RepositoryNotFoundError, HfHubError |
|
import torch |
|
import fire |
|
import gradio as gr |
|
from gradio_app.gradio_3dgen import create_ui as create_3d_ui |
|
from gradio_app.all_models import model_zoo |
|
|
|
|
|
# Configure the root logger once for the whole script: INFO level and
# above, each line prefixed with a timestamp and the level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
|
|
|
def setup_dependencies(packages=None):
    """Install the pinned pip version and the bundled local wheels.

    Every wheel is force-reinstalled without its dependencies. All files are
    validated before anything is installed, so a missing wheel fails fast
    without touching the environment.

    Args:
        packages: Paths (``str`` or ``Path``) to local wheel files. Defaults
            to the onnxruntime-gpu and nvdiffrast wheels under ``package/``.

    Raises:
        FileNotFoundError: If any wheel file does not exist.
        subprocess.CalledProcessError: If any pip invocation exits non-zero.
    """
    if packages is None:
        packages = [
            "package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl",
            "package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl",
        ]

    # Check every wheel up front. Otherwise pip could already be downgraded
    # by the time we discover a later file is missing.
    missing = [str(p) for p in packages if not Path(p).is_file()]
    if missing:
        msg = "Package file not found: " + ", ".join(missing)
        logger.error(msg)
        raise FileNotFoundError(msg)

    # Run pip through the current interpreter so packages are installed
    # into the environment this script runs in. Bare `pip` on PATH may
    # belong to a different Python. Passing argv as a list (no shell, no
    # shlex splitting) keeps paths that contain spaces intact.
    pip_install = [sys.executable, "-m", "pip", "install"]

    try:
        logger.info("Installing dependencies...")
        subprocess.run([*pip_install, "pip==24.0"], check=True)

        for package in packages:
            logger.info("Installing %s", package)
            subprocess.run(
                [*pip_install, str(package), "--force-reinstall", "--no-deps"],
                check=True,
            )

        logger.info("Dependencies installed successfully")
    except subprocess.CalledProcessError as e:
        logger.error("Failed to install dependencies: %s", e)
        raise
|
|
|
def _download_checkpoints(repo_id, local_dir, max_retries, retry_delay):
    """Download the Hub snapshot ``repo_id`` into ``local_dir``, retrying failures with exponential backoff."""
    import time  # local import: only needed for the retry backoff

    for attempt in range(1, max_retries + 1):
        try:
            snapshot_download(
                repo_id,
                repo_type="model",
                local_dir=str(local_dir),
                # Authenticates with HF_TOKEN from the environment, if set.
                token=os.getenv("HF_TOKEN"),
            )
            return
        except RepositoryNotFoundError as e:
            # Retrying cannot make a missing repository appear; fail now.
            logger.error("Repository not found: %s", e)
            raise
        except Exception as e:
            if attempt == max_retries:
                logger.error(
                    "Failed to download model after %d attempts: %s", max_retries, e
                )
                raise
            # Back off instead of retrying immediately: 1x, 2x, 4x ... retry_delay.
            delay = retry_delay * 2 ** (attempt - 1)
            logger.warning(
                "Download attempt %d failed (%s), retrying in %.1fs...",
                attempt, e, delay,
            )
            time.sleep(delay)


def _configure_torch():
    """Apply process-wide PyTorch numeric and autograd settings."""
    # 'medium' permits reduced-precision internal math for float32 matmuls.
    # allow_tf32 additionally enables TF32 for CUDA matmuls explicitly.
    torch.set_float32_matmul_precision('medium')
    torch.backends.cuda.matmul.allow_tf32 = True
    # Disables autograd for the whole process. The app appears to be
    # inference-only; TODO confirm no code path needs gradients.
    torch.set_grad_enabled(False)


def setup_model(repo_id="public-data/Unique3D", local_dir="./ckpt",
                max_retries=3, retry_delay=2.0):
    """Download the model checkpoints and configure PyTorch.

    Args:
        repo_id: Hugging Face Hub model repository to download.
        local_dir: Directory the snapshot is written to (created if missing).
        max_retries: Total number of download attempts; must be >= 1.
        retry_delay: Seconds to wait before the first retry. The wait doubles
            after each subsequent failed attempt.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        RepositoryNotFoundError: If the Hub reports the repository as not
            found. This error is not retried.
        Exception: The last download error once all attempts fail, or any
            error raised while configuring PyTorch.
    """
    if max_retries < 1:
        # Without this guard the retry loop would not run at all, and the
        # function would "succeed" without downloading anything.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    try:
        logger.info("Downloading model checkpoints...")
        ckpt_dir = Path(local_dir)
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        _download_checkpoints(repo_id, ckpt_dir, max_retries, retry_delay)
        logger.info("Model checkpoints downloaded successfully")

        _configure_torch()
        logger.info("PyTorch configured successfully")
    except Exception as e:
        logger.error("Error during model setup: %s", e)
        raise
|
|
|
|
|
# Browser-tab title of the Gradio app, also shown as its Markdown heading.
_TITLE = "Text to 3D"
|
|
|
def launch(share=True, server_name=None, server_port=None):
    """Initialize the models and serve the Gradio Text-to-3D interface.

    This function is passed to ``fire.Fire`` as the CLI entry point, so each
    keyword below becomes a command-line flag (e.g. ``--share=False``).

    Args:
        share: Whether Gradio should create a public share link. The default
            of True preserves the previous behavior.
        server_name: Host/interface to bind to. ``None`` uses Gradio's default.
        server_port: Port to listen on. ``None`` uses Gradio's default.

    Raises:
        Exception: Any error from model initialization, UI construction, or
            server launch. It is logged with its traceback and then re-raised.
    """
    try:
        logger.info("Initializing models...")
        model_zoo.init_models()

        logger.info("Creating Gradio interface...")
        with gr.Blocks(title=_TITLE) as demo:
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown('# ' + _TITLE)
            # "wkl" is forwarded to create_3d_ui unchanged; its meaning is
            # defined in gradio_app.gradio_3dgen (TODO confirm there).
            create_3d_ui("wkl")

        demo.queue().launch(
            share=share,
            server_name=server_name,
            server_port=server_port,
        )
    except Exception:
        # logger.exception keeps the full traceback, not just str(e).
        logger.exception("Error launching application")
        raise
|
|
|
if __name__ == '__main__':
    try:
        logger.info("Starting application setup...")
        setup_dependencies()
        # Make modules relative to the working directory importable.
        # NOTE(review): the gradio_app imports at the top of the file have
        # already run by this point, so this presumably serves imports made
        # later (e.g. during model init). Confirm before removing.
        sys.path.append(os.curdir)
        setup_model()
        fire.Fire(launch)
    except Exception:
        # Record the full traceback; str(e) alone can be uninformative
        # (e.g. a bare KeyError shows only the key).
        logger.exception("Application startup failed")
        sys.exit(1)