import os import uuid import logging import torch import numpy as np from fastapi import FastAPI, HTTPException from pydantic import BaseModel from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler from diffusers.utils import export_to_video from huggingface_hub import hf_hub_download from safetensors.torch import load_file from fastapi.responses import FileResponse from fastapi.middleware.cors import CORSMiddleware app = FastAPI() app.add_middleware( CORSMiddleware, allow_origins=["*"], # Cho phép tất cả domain gọi API (hoặc liệt kê domain cụ thể ở đây) allow_credentials=True, allow_methods=["*"], # Cho phép tất cả phương thức (POST, GET, etc.) allow_headers=["*"], # Cho phép tất cả headers ) # Thiết lập logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Thiết lập thư mục cache cho Hugging Face ngay đầu file os.environ["HF_HOME"] = "/tmp/huggingface_cache" os.environ["XDG_CACHE_HOME"] = "/tmp/huggingface_cache" os.makedirs(os.environ["HF_HOME"], exist_ok=True) logger.info(f"HF_HOME set to {os.environ['HF_HOME']}") app = FastAPI() # Tạo thư mục lưu video trong /tmp output_dir = "/tmp/outputs" os.makedirs(output_dir, exist_ok=True) # Constants bases = { "Cartoon": "frankjoshua/toonyou_beta6", "Realistic": "emilianJR/epiCRealism", "3d": "Lykon/DreamShaper", "Anime": "Yntec/mistoonAnime2" } step_loaded = None base_loaded = "Realistic" motion_loaded = None # Thiết lập thiết bị CPU và kiểu dữ liệu device = "cpu" dtype = torch.float32 # Khởi tạo pipeline pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device) pipe.scheduler = EulerDiscreteScheduler.from_config( pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear" ) pipe.safety_checker = None # Mô hình dữ liệu cho request class VideoRequest(BaseModel): prompt: str base: str = "Realistic" motion: str = "" step: int = 1 # Endpoint tạo video @app.post("/generate_video") async def generate_video(request: VideoRequest): global step_loaded, base_loaded, motion_loaded prompt = request.prompt base = request.base motion = request.motion step = request.step logger.info(f"Tạo video với prompt: {prompt}, base: {base}, motion: {motion}, steps: {step}") try: # Kiểm tra base hợp lệ if base not in bases: raise HTTPException(status_code=400, detail="Base model không hợp lệ") # Tải AnimateDiff Lightning checkpoint if step_loaded != step: repo = "ByteDance/AnimateDiff-Lightning" ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors" pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False) step_loaded = step # Tải mô hình cơ sở nếu thay đổi if base_loaded != base: pipe.unet.load_state_dict( torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), strict=False ) base_loaded = base # Tải motion LoRA nếu có if motion_loaded != motion: try: pipe.unload_lora_weights() if motion: pipe.load_lora_weights(motion, adapter_name="motion") pipe.set_adapters(["motion"], [0.7]) motion_loaded = motion except Exception as e: logger.warning(f"Không thể tải motion LoRA: {e}") motion_loaded = "" # Suy luận with torch.no_grad(): output = pipe( prompt=prompt, guidance_scale=1.2, num_inference_steps=step, num_frames=32, width=256, height=256 ) # Chuẩn hóa khung hình cho 8 giây frames = output.frames[0] fps = 24 target_frames = fps * 8 if len(frames) < target_frames: frames = np.tile(frames, (target_frames // len(frames) + 1, 1, 1, 1))[:target_frames] else: frames = frames[:target_frames] # Tạo video name = str(uuid.uuid4()).replace("-", "") video_path = os.path.join(output_dir, f"{name}.mp4") export_to_video(frames, video_path, fps=fps) if not os.path.exists(video_path): raise FileNotFoundError("Video không được tạo") logger.info(f"Video sẵn sàng tại {video_path}") # Trả về file video return FileResponse(video_path, media_type="video/mp4", filename=f"{name}.mp4") except Exception as e: logger.error(f"Lỗi khi tạo video: {e}") raise HTTPException(status_code=500, detail=str(e)) # Endpoint kiểm tra trạng thái @app.get("/") async def root(): return {"message": "FastAPI AnimateDiff-Lightning API on Hugging Face Spaces"} if __name__ == "__main__": port = int(os.environ.get("PORT", 7860)) # Nếu PORT không có thì mặc định 7860 uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True)