# Sesame-AI-POC / main.py
import os
import json
import time
import boto3
import torchaudio
from dotenv import load_dotenv
from generator import load_csm_1b
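# load_csm_1b is defined in generator.py, which this Space presumably vendors
# from Sesame's CSM repository (github.com/SesameAILabs/csm).
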
# Load environment variables
load_dotenv()
REGION = os.getenv("AWS_DEFAULT_REGION", "us-east-2")
LLAMA_ENDPOINT = os.getenv("LLAMA_ENDPOINT_NAME", "Llama-3-2-3B-Instruct-streaming-endpoint")
HF_TOKEN = os.getenv("HF_TOKEN")
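# Note: HF_TOKEN is not used directly below; load_csm_1b presumably reads it
# from the environment to download the gated CSM-1B checkpoint.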

# Load local Sesame model
generator = load_csm_1b(device="cpu")  # Change to "cuda" if ZeroGPU is enabled

# Function to invoke LLaMA from SageMaker
def invoke_llama(prompt: str) -> str:
    """Send a prompt to the SageMaker LLaMA endpoint and return the generated text."""
    try:
        # Add prompt formatting for better LLaMA output
        formatted_prompt = f"### Question:\n{prompt.strip()}\n\n### Answer:"
        client_boto = boto3.client("sagemaker-runtime", region_name=REGION)
        payload = {"inputs": formatted_prompt}
        response = client_boto.invoke_endpoint(
            EndpointName=LLAMA_ENDPOINT,
            ContentType="application/json",
            Body=json.dumps(payload),
        )
        result = json.loads(response["Body"].read().decode("utf-8"))
        # The endpoint may return a list of generations or a single dict
        if isinstance(result, list):
            return result[0].get("generated_text", "")
        return result.get("generated_text", str(result))
    except Exception as e:
        print(f"Error calling LLaMA: {e}")
        return "I'm having trouble processing your request."
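
# The parsing above assumes a Hugging Face TGI-style response shape, e.g.:
#   [{"generated_text": "### Question:\n...\n\n### Answer: ..."}]
# The actual schema depends on how the endpoint container was deployed.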

# Function to generate speech using the local Sesame model
def speak_with_sesame(text: str) -> str | None:
    """Synthesize `text` to a WAV file and return its path, or None on failure."""
    try:
        audio = generator.generate(
            text=text,
            speaker=0,
            context=[],
            max_audio_length_ms=30000,
        )
        # generate() returns a 1-D waveform; add a channel dimension for saving
        output_path = f"csm_output_{int(time.time())}.wav"
        torchaudio.save(output_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
        return output_path
    except Exception as e:
        print(f"Error generating Sesame voice: {e}")
        return None

# Main logic to run full conversation pipeline
def conversation_with_voice(user_input: str) -> dict:
    text_response = invoke_llama(user_input)
    audio_path = speak_with_sesame(text_response)
    return {
        "text_response": text_response,
        "audio_path": audio_path,
    }
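
# Minimal smoke test: a sketch assuming valid AWS credentials, a deployed
# LLaMA endpoint, and a downloaded CSM-1B checkpoint.
if __name__ == "__main__":
    result = conversation_with_voice("What is the capital of France?")
    print("LLaMA:", result["text_response"])
    print("Audio saved to:", result["audio_path"])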