File size: 5,204 Bytes
9475ff0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0728e3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9475ff0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import base64
import io
import logging
from typing import List, Optional

import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from generator import load_csm_1b, Segment

# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application object; the keyword arguments are the API metadata.
app = FastAPI(
    title="CSM 1B API",
    description="API for Sesame's Conversational Speech Model",
    version="1.0.0",
)

# CORS is configured fully open: every origin, method and header is allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended outside local development and
# consider listing explicit origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model handle shared by the endpoints; remains None until the startup hook
# has finished loading the model.
generator = None

class SegmentRequest(BaseModel):
    """One prior utterance supplied as conversational context for generation."""

    # Speaker identifier; forwarded unchanged to Segment(speaker=...).
    speaker: int
    # Transcript of the utterance; forwarded unchanged to Segment(text=...).
    text: str
    # Base64-encoded audio file contents for the utterance. When omitted or
    # empty, the /generate-audio endpoint substitutes an empty float32 tensor.
    audio_base64: Optional[str] = None

class GenerateAudioRequest(BaseModel):
    """Request body of the /generate-audio endpoint."""

    # Text to synthesize; passed to generator.generate(text=...).
    text: str
    # Speaker identifier for the new utterance; passed as speaker=...
    speaker: int
    # Earlier utterances, each converted to a Segment and passed as context=...
    context: List[SegmentRequest] = []
    # Passed as max_audio_length_ms=...; milliseconds, going by the name.
    max_audio_length_ms: float = 10000
    # Passed as temperature=... (presumably the sampling temperature -- the
    # exact semantics are defined by the generator, not in this file).
    temperature: float = 0.9
    # Passed as topk=... (presumably top-k sampling -- confirm in generator).
    topk: int = 50

class AudioResponse(BaseModel):
    """Response body of the /generate-audio endpoint."""

    # Base64 (UTF-8 text) encoding of the generated audio saved as WAV.
    audio_base64: str
    # Sample rate of the returned audio, taken from generator.sample_rate.
    sample_rate: int

@app.on_event("startup")
async def startup_event():
    """Load the CSM 1B model into the module-level ``generator`` at startup.

    Uses CUDA when ``torch.cuda.is_available()`` and otherwise falls back to
    the CPU with a warning.

    Raises:
        Exception: whatever the model loading raised. The failure is logged
            with its traceback and re-raised unchanged, so it propagates out
            of the startup hook.
    """
    global generator
    logger.info("Loading CSM 1B model...")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        logger.warning("GPU not available. Using CPU, performance may be slow!")

    try:
        generator = _load_generator(device)
    except Exception as e:
        # logger.exception keeps the traceback, which the plain error log lost.
        logger.exception("Could not load model: %s", e)
        raise

    logger.info(f"Model loaded successfully on device: {device}")


def _load_generator(device: str):
    """Build the CSM 1B generator on ``device``.

    First calls ``load_csm_1b(device=device)``. If that raises the specific
    ``TypeError`` about a missing ``config`` argument, ``config.json`` is
    downloaded from the ``sesame/csm-1b`` Hugging Face repository and the
    load is retried with ``config=<parsed JSON>``.

    Args:
        device: Device string handed to ``load_csm_1b`` ("cuda" or "cpu").

    Returns:
        The object returned by ``load_csm_1b``.

    Raises:
        TypeError: if the first attempt fails with any other ``TypeError``.
        Exception: any error raised while fetching the configuration or
            during the second load attempt (logged, then re-raised).
    """
    logger.info("Attempting to load CSM 1B model...")
    try:
        return load_csm_1b(device=device)
    except TypeError as e:
        # Only the "loader wants a config" failure is recoverable here.
        if "missing 1 required positional argument: 'config'" not in str(e):
            raise
        logger.info("Model requires config. Attempting to load with configuration...")

    # These are needed by the fallback path only, so they are imported lazily:
    # a missing huggingface_hub must not break the default loading path.
    import json

    from huggingface_hub import hf_hub_download

    try:
        config_file = hf_hub_download(repo_id="sesame/csm-1b", filename="config.json")
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
        return load_csm_1b(device=device, config=config)
    except Exception as config_error:
        logger.error(f"Failed to load configuration: {str(config_error)}")
        raise

@app.post("/generate-audio", response_model=AudioResponse)
async def generate_audio(request: GenerateAudioRequest):
    """Synthesize speech for ``request.text`` and return it as base64 WAV.

    Each context entry is converted to a ``Segment`` (its audio decoded,
    downmixed to mono and resampled to ``generator.sample_rate``), then
    ``generator.generate`` is called with the request's sampling settings.
    The result is encoded as WAV in memory; nothing is written to disk.

    Args:
        request: Text, speaker, optional context and sampling parameters.

    Returns:
        AudioResponse with the base64-encoded WAV and its sample rate.

    Raises:
        HTTPException: 503 if the model has not been loaded yet, 400 if a
            context segment's audio cannot be decoded, 500 for any other
            failure during generation or encoding.
    """
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Please try again later.")

    try:
        context_segments = [
            Segment(
                text=segment.text,
                speaker=segment.speaker,
                audio=_decode_context_audio(segment.audio_base64, generator.sample_rate),
            )
            for segment in request.context
        ]

        # NOTE(review): generate() is a blocking call inside an async handler,
        # so the event loop (including /health) stalls while it runs. It is
        # left in place because it also serializes access to the shared model;
        # moving it to a worker thread would need a lock around the generator.
        audio = generator.generate(
            text=request.text,
            speaker=request.speaker,
            context=context_segments,
            max_audio_length_ms=request.max_audio_length_ms,
            temperature=request.temperature,
            topk=request.topk,
        )

        # Encode to WAV in memory only.
        buffer = io.BytesIO()
        torchaudio.save(buffer, audio.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
        audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return AudioResponse(
            audio_base64=audio_base64,
            sample_rate=generator.sample_rate,
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 for bad context audio) must not
        # be rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.exception("error when building audio: %s", e)
        raise HTTPException(status_code=500, detail=f"error when building audio: {str(e)}") from e


def _decode_context_audio(audio_base64: Optional[str], target_sample_rate: int) -> torch.Tensor:
    """Turn a context segment's base64 audio into a mono waveform tensor.

    Args:
        audio_base64: Base64-encoded audio file contents, or None/empty.
        target_sample_rate: Sample rate the waveform is resampled to.

    Returns:
        A 1-D tensor at ``target_sample_rate``; an empty float32 tensor when
        no audio was supplied.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or cannot be
            decoded as audio.
    """
    if not audio_base64:
        return torch.zeros(0, dtype=torch.float32)

    try:
        audio_bytes = base64.b64decode(audio_base64)
        waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))
    except (ValueError, RuntimeError) as e:
        # binascii.Error is a ValueError; undecodable audio is assumed to
        # surface from torchaudio as RuntimeError. Both are client errors.
        raise HTTPException(status_code=400, detail=f"Invalid context audio: {e}") from e

    # torchaudio.load returns (channels, frames) by default. Averaging over
    # the channel axis gives the same values as squeeze(0) for mono input and
    # also collapses multi-channel input, which squeeze(0) left 2-D.
    mono = waveform.mean(dim=0)
    return torchaudio.functional.resample(
        mono,
        orig_freq=sample_rate,
        new_freq=target_sample_rate,
    )

@app.get("/health")
async def health_check():
    """Report whether the model is loaded and the API can serve requests.

    Returns:
        A dict with ``status`` ("ready" or "not_ready") and a human-readable
        ``message``.
    """
    model_loaded = generator is not None
    if model_loaded:
        return {"status": "ready", "message": "API is ready to serve"}
    return {"status": "not_ready", "message": "Model is loading"}