|
import argparse |
|
import uvicorn |
|
import threading |
|
import os |
|
import config |
|
from app import app as gradio_app |
|
from api import app as api_app |
|
|
|
def run_api():
    """Serve the API application with uvicorn.

    The bind address and port are read from ``config.API_HOST`` and
    ``config.API_PORT`` at call time.  The call to ``uvicorn.run`` occupies
    the calling thread for as long as the server is running.
    """
    bind_host = config.API_HOST
    bind_port = config.API_PORT
    uvicorn.run(api_app, host=bind_host, port=bind_port)
|
|
|
def run_gradio():
    """Launch the Gradio interface.

    The server name and port are read from ``config.GRADIO_HOST`` and
    ``config.GRADIO_PORT`` at call time.  ``share`` is always passed as
    ``False``.
    """
    launch_options = {
        "server_name": config.GRADIO_HOST,
        "server_port": config.GRADIO_PORT,
        "share": False,
    }
    gradio_app.launch(**launch_options)
|
|
|
def main():
    """Parse the command line and start the requested component(s).

    ``--mode`` selects what runs:

    * ``api`` -- only the API server, in the calling thread.
    * ``ui``  -- only the Gradio interface, in the calling thread.
    * ``all`` -- (default) the API server on a background daemon thread,
      followed by the Gradio interface in the calling thread.

    Before anything is started, a two-line warning is printed when
    ``config.HF_TOKEN`` is falsy.
    """
    cli = argparse.ArgumentParser(description="Run Diffusion Models App")
    cli.add_argument(
        "--mode",
        type=str,
        default="all",
        choices=["all", "api", "ui"],
        help="Which component to run: 'all' (default), 'api', or 'ui'",
    )
    mode = cli.parse_args().mode

    if not config.HF_TOKEN:
        print("Warning: HF_TOKEN environment variable is not set. Please set it for API access.")
        print("You can create a .env file with HF_TOKEN=your_token or set it in your environment.")

    # API-only mode: serve in the foreground and finish when the server does.
    if mode == "api":
        print(f"Starting API server at http://{config.API_HOST}:{config.API_PORT}")
        run_api()
        return

    # argparse's ``choices`` guarantees that mode is "all" or "ui" from here on.
    if mode == "all":
        # Daemon thread, so it does not keep the process alive once the
        # foreground Gradio call returns.
        threading.Thread(target=run_api, daemon=True).start()
        print(f"API server running at http://{config.API_HOST}:{config.API_PORT}")

    # Both remaining modes finish by running the UI in the calling thread.
    print(f"Starting Gradio UI at http://{config.GRADIO_HOST}:{config.GRADIO_PORT}")
    run_gradio()
|
|
|
# Script entry point: start the launcher only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|