"""FastAPI server exposing Sesame's CSM 1B conversational speech model over HTTP."""
import base64
import io
import logging
from typing import List, Optional
import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from generator import load_csm_1b, Segment
# Module-level logging: one logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The HTTP application serving the speech model.
app = FastAPI(
    title="CSM 1B API",
    description="API for Sesame's Conversational Speech Model",
    version="1.0.0",
)

# Fully permissive CORS: any origin, method, and header may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Populated by the startup hook; None means the model is not loaded yet.
generator = None
class SegmentRequest(BaseModel):
    """One prior conversation turn supplied as context for generation."""

    # Integer id of the speaker for this turn.
    speaker: int
    # Transcript text of the turn.
    text: str
    # Optional base64-encoded audio for the turn; None means no audio.
    audio_base64: Optional[str] = None
class GenerateAudioRequest(BaseModel):
    """Request body for the /generate-audio endpoint."""

    # Text to synthesize.
    text: str
    # Integer id of the speaker voice to use.
    speaker: int
    # Prior turns providing conversational context (pydantic copies this
    # default per instance, so the empty list is not shared).
    context: List[SegmentRequest] = []
    # Upper bound on the generated audio length, in milliseconds.
    max_audio_length_ms: float = 10000
    # Sampling temperature for generation.
    temperature: float = 0.9
    # Top-k cutoff for sampling.
    topk: int = 50
class AudioResponse(BaseModel):
    """Response body: generated audio as a base64 WAV plus its sample rate."""

    # Base64-encoded WAV bytes.
    audio_base64: str
    # Sample rate of the encoded audio, in Hz.
    sample_rate: int
@app.on_event("startup")
async def startup_event():
    """Load the CSM 1B generator once at process start.

    Tries the plain ``load_csm_1b(device=...)`` entry point first; if that
    signature demands a ``config`` argument, downloads ``config.json`` from
    the Hugging Face hub and retries with it.  Any failure is logged with a
    full traceback and re-raised so the server does not start without a model.
    """
    global generator
    logger.info("Loading CSM 1B model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        logger.warning("GPU not available. Using CPU, performance may be slow!")
    try:
        logger.info("Attempting to load CSM 1B model...")
        # Lazy imports used only by the config-fallback path below.
        from huggingface_hub import hf_hub_download
        import json

        try:
            # First attempt: the simple loader signature.
            generator = load_csm_1b(device=device)
        except TypeError as e:
            # Guard clause: only the "missing config" signature is recoverable.
            if "missing 1 required positional argument: 'config'" not in str(e):
                raise
            logger.info("Model requires config. Attempting to load with configuration...")
            try:
                # Fetch the model configuration from the hub and retry.
                model_id = "sesame/csm-1b"
                config_file = hf_hub_download(repo_id=model_id, filename="config.json")
                with open(config_file, 'r') as f:
                    config = json.load(f)
                generator = load_csm_1b(device=device, config=config)
            except Exception as config_error:
                logger.error(f"Failed to load configuration: {str(config_error)}")
                raise
        logger.info(f"Model loaded successfully on device: {device}")
    except Exception:
        # logger.exception records the traceback; bare `raise` (instead of
        # the original `raise e`) re-raises without adding a new frame.
        logger.exception("Could not load model")
        raise
@app.post("/generate-audio", response_model=AudioResponse)
async def generate_audio(request: GenerateAudioRequest):
    """Synthesize speech for ``request.text`` and return it as base64 WAV.

    Context segments carrying audio are base64-decoded, loaded via torchaudio
    and resampled to the model's sample rate when needed; segments without
    audio contribute an empty waveform.  Responds 503 while the model is
    still loading and 500 on any generation failure.
    """
    global generator
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Please try again later.")
    try:
        context_segments = []
        for segment in request.context:
            if segment.audio_base64:
                audio_bytes = base64.b64decode(segment.audio_base64)
                audio_tensor, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))
                # Drop the channel dimension (assumes mono input — TODO confirm).
                audio_tensor = audio_tensor.squeeze(0)
                # Skip the resample entirely when the rates already match.
                if sample_rate != generator.sample_rate:
                    audio_tensor = torchaudio.functional.resample(
                        audio_tensor,
                        orig_freq=sample_rate,
                        new_freq=generator.sample_rate,
                    )
            else:
                # No audio supplied: use an empty waveform for this turn.
                audio_tensor = torch.zeros(0, dtype=torch.float32)
            context_segments.append(
                Segment(text=segment.text, speaker=segment.speaker, audio=audio_tensor)
            )
        audio = generator.generate(
            text=request.text,
            speaker=request.speaker,
            context=context_segments,
            max_audio_length_ms=request.max_audio_length_ms,
            temperature=request.temperature,
            topk=request.topk,
        )
        # Encode the waveform as an in-memory WAV.  A leftover debug write to
        # "audio.wav" on disk was removed: it clobbered the file on every
        # request and raced under concurrent calls.
        buffer = io.BytesIO()
        torchaudio.save(buffer, audio.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
        buffer.seek(0)
        audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")
        return AudioResponse(
            audio_base64=audio_base64,
            sample_rate=generator.sample_rate,
        )
    except Exception as e:
        # Full traceback in the log; sanitized message to the client.
        logger.exception("error when building audio")
        raise HTTPException(status_code=500, detail=f"error when building audio: {str(e)}")
@app.get("/health")
async def health_check():
    """Report whether the model has finished loading and the API can serve."""
    ready = generator is not None
    if ready:
        return {"status": "ready", "message": "API is ready to serve"}
    return {"status": "not_ready", "message": "Model is loading"}