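"""Conversation pipeline: generate a text reply with a LLaMA model hosted on
Amazon SageMaker, then synthesize it to speech with a locally loaded Sesame
CSM-1B model."""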
import os
import json
import time

import boto3
import torchaudio
from dotenv import load_dotenv

from generator import load_csm_1b

# Load environment variables
load_dotenv()

REGION = os.getenv("AWS_DEFAULT_REGION", "us-east-2")
LLAMA_ENDPOINT = os.getenv("LLAMA_ENDPOINT_NAME", "Llama-3-2-3B-Instruct-streaming-endpoint")
HF_TOKEN = os.getenv("HF_TOKEN")
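# Example .env file (hypothetical values; only the variables read above):
#   AWS_DEFAULT_REGION=us-east-2
#   LLAMA_ENDPOINT_NAME=Llama-3-2-3B-Instruct-streaming-endpoint
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx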
# Load local Sesame model
generator = load_csm_1b(device="cpu")  # Change to "cuda" if ZeroGPU is enabled
# Function to invoke LLaMA from SageMaker
def invoke_llama(prompt: str) -> str:
    try:
        # Add prompt formatting for better LLaMA output
        formatted_prompt = f"### Question:\n{prompt.strip()}\n\n### Answer:"
        client_boto = boto3.client("sagemaker-runtime", region_name=REGION)
        payload = {"inputs": formatted_prompt}
        response = client_boto.invoke_endpoint(
            EndpointName=LLAMA_ENDPOINT,
            ContentType="application/json",
            Body=json.dumps(payload)
        )
        result = json.loads(response["Body"].read().decode("utf-8"))
        if isinstance(result, list):
            return result[0].get("generated_text", "")
        return result.get("generated_text", str(result))
    except Exception as e:
        print(f"Error calling LLaMA: {e}")
        return "I'm having trouble processing your request."
# Function to generate speech using the local Sesame model
def speak_with_sesame(text: str) -> str | None:
    try:
        audio = generator.generate(
            text=text,
            speaker=0,
            context=[],
            max_audio_length_ms=30000
        )
        # generate() returns a 1-D waveform tensor; add a channel dim for torchaudio
        output_path = f"csm_output_{int(time.time())}.wav"
        torchaudio.save(output_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
        return output_path
    except Exception as e:
        print(f"Error generating Sesame voice: {e}")
        return None
# Main logic to run full conversation pipeline
def conversation_with_voice(user_input: str) -> dict:
    text_response = invoke_llama(user_input)
    audio_path = speak_with_sesame(text_response)
    return {
        "text_response": text_response,
        "audio_path": audio_path
    }
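# Example usage (assumes the SageMaker endpoint is deployed and the CSM-1B
# weights are available; the question below is illustrative):
if __name__ == "__main__":
    result = conversation_with_voice("What is Amazon SageMaker?")
    print("Response:", result["text_response"])
    print("Audio file:", result["audio_path"])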