# MixtureOfInputs / server.py  (commit 3d9b062 "fix" by yzhuang, 3.16 kB)
# server.py ── launch vLLM inside a Hugging Face Space (with clean shutdown)
import os, signal, sys, atexit, time, socket, subprocess
import spaces # only needed for the GPU decorator
# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------
def _wait_for_port(host: str, port: int, timeout: float = 240,
                   proc: "subprocess.Popen | None" = None) -> None:
    """Block until (host, port) accepts TCP connections or *timeout* expires.

    Args:
        host: Hostname or IP address to probe.
        port: TCP port to probe.
        timeout: Maximum number of seconds to keep retrying.
        proc: Optional process that is expected to open the port. If it
            exits before the port comes up, stop waiting immediately instead
            of burning the rest of the timeout.

    Raises:
        RuntimeError: if the port does not accept a connection before the
            deadline, or if *proc* exits first.
    """
    # monotonic() is immune to wall-clock adjustments (NTP, manual changes),
    # which could otherwise shorten or stretch the deadline.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        with socket.socket() as sock:
            sock.settimeout(2)
            if sock.connect_ex((host, port)) == 0:
                return
        if proc is not None and proc.poll() is not None:
            raise RuntimeError(
                f"vLLM server on {host}:{port} exited with code "
                f"{proc.returncode} before accepting connections"
            )
        time.sleep(1)
    raise RuntimeError(f"vLLM server on {host}:{port} never came up")
def _kill_proc_tree(proc: "subprocess.Popen | None", timeout: float = 15) -> None:
    """SIGTERM the whole process group started by *proc*, escalating to SIGKILL.

    Meant for processes launched with ``start_new_session=True``, so the child
    leads its own process group. Does nothing if *proc* is None or has already
    exited. Safe to call from an atexit hook: races with the process exiting
    on its own are tolerated rather than raised.

    Args:
        proc: The process whose group should be terminated (may be None).
        timeout: Seconds to wait after SIGTERM before sending SIGKILL.
    """
    if proc is None or proc.poll() is not None:
        return  # never started, or already exited
    try:
        pgid = os.getpgid(proc.pid)
        os.killpg(pgid, signal.SIGTERM)  # graceful
    except ProcessLookupError:
        # The process vanished between poll() and here; record its status.
        proc.poll()
        return
    try:
        proc.wait(timeout)
    except subprocess.TimeoutExpired:
        try:
            os.killpg(pgid, signal.SIGKILL)  # force
        except ProcessLookupError:
            pass  # the group exited on its own in the meantime
        # Reap the child so it does not linger as a zombie; bounded so a
        # process stuck in the kernel cannot hang shutdown forever.
        try:
            proc.wait(timeout)
        except subprocess.TimeoutExpired:
            pass
# ----------------------------------------------------------------------
# Setup – runs on *CPU* only; fast.
# ----------------------------------------------------------------------
def setup_mixinputs():
    """Run the one-off ``mixinputs setup`` command and wait for it to finish.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status.
    """
    setup_cmd = ["mixinputs", "setup"]
    subprocess.run(setup_cmd, check=True)
# ----------------------------------------------------------------------
# Serve – runs on the GPU; heavy. (Despite the `spaces` import, no GPU decorator is applied.)
# ----------------------------------------------------------------------
def launch_vllm_server(beta: float = 1.0, port: int = 8000,
                       model: str = "Qwen/Qwen3-4B",
                       timeout: float = 240) -> subprocess.Popen:
    """Start ``vllm serve`` in its own process group and wait until it listens.

    Args:
        beta: Value exported to the server as the ``MIXINPUTS_BETA``
            environment variable.
        port: TCP port the server is told to listen on.
        model: Model identifier passed to ``vllm serve``.
        timeout: Seconds to wait for the port to start accepting connections.

    Returns:
        The running server process, which leads its own process group.

    Raises:
        RuntimeError: if the server never starts accepting connections. The
            spawned process group is terminated before the error propagates.
        OSError: if the ``vllm`` executable cannot be started.
    """
    # NOTE(review): the section heading and the `spaces` import suggest a
    # @spaces.GPU decorator was intended, but none is applied. Confirm intent
    # before adding one to a function that launches a long-lived server.
    env = os.environ.copy()
    env["MIXINPUTS_BETA"] = str(beta)
    env["VLLM_USE_V1"] = "1"
    cmd = [
        "vllm", "serve",
        model,
        "--tensor-parallel-size", "1",
        "--enforce-eager",
        "--max-model-len", "2048",
        "--max-seq-len-to-capture", "2048",
        "--max-num-seqs", "1",
        "--port", str(port),
    ]
    # New session => its own process group, so the whole tree can be
    # signalled at once by _kill_proc_tree.
    proc = subprocess.Popen(cmd, env=env, start_new_session=True)
    try:
        _wait_for_port("localhost", port, timeout=timeout)  # block until ready
    except BaseException:
        # Don't leak a half-started server (and whatever resources it holds)
        # when startup times out or is interrupted (Ctrl-C / SystemExit).
        _kill_proc_tree(proc)
        raise
    return proc
# ----------------------------------------------------------------------
# MAIN
# ----------------------------------------------------------------------
if __name__ == "__main__":
    def _exit_on_sigterm(signum, frame):
        """Turn SIGTERM into SystemExit so atexit cleanup handlers run."""
        sys.exit(128 + signum)

    # atexit handlers only run on a normal interpreter exit. The default
    # SIGTERM action (the signal typically used to stop containers) kills
    # Python without running them, and the server lives in its own session,
    # so it would not receive the signal either. Convert SIGTERM into
    # SystemExit so the cleanup registered below actually happens.
    signal.signal(signal.SIGTERM, _exit_on_sigterm)

    setup_mixinputs()  # fast
    server_proc = launch_vllm_server()  # heavy

    # Ensures the server's process group dies when this script exits
    # (normal exit, SystemExit from SIGTERM, or after Ctrl-C below).
    atexit.register(_kill_proc_tree, server_proc)

    # ---- your Gradio / FastAPI app goes below ----
    # e.g. import gradio as gr
    #      with gr.Blocks() as demo:
    #          ...
    #      demo.launch(server_name="0.0.0.0", server_port=7860)
    # Server cleanup is handled by the atexit hook above.
    # NOTE(review): an earlier version of this example passed `teardown=` to
    # gr.Blocks; confirm that argument exists in your gradio version before
    # relying on it.
    #
    # For this snippet we just block until the server exits, so the
    # container doesn't exit immediately.
    try:
        server_proc.wait()
    except KeyboardInterrupt:
        pass
    else:
        # The server stopped on its own: report failure instead of exiting 0,
        # so the Space shows the crash.
        if server_proc.returncode:
            sys.exit(f"vLLM server exited with code {server_proc.returncode}")