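"""Voice conversation pipeline.

Generates a text reply with a Llama 3.2 Instruct model hosted on a SageMaker
endpoint, then speaks the reply locally with the Sesame CSM-1B model.
"""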
import os
import json
import time
import boto3
import torchaudio
from dotenv import load_dotenv
from generator import load_csm_1b

# Load environment variables
load_dotenv()

REGION = os.getenv("AWS_DEFAULT_REGION", "us-east-2")
LLAMA_ENDPOINT = os.getenv("LLAMA_ENDPOINT_NAME", "Llama-3-2-3B-Instruct-streaming-endpoint")
HF_TOKEN = os.getenv("HF_TOKEN")
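# Note: HF_TOKEN is not used directly below; huggingface_hub reads it from the
# environment when load_csm_1b downloads the CSM-1B weights (an assumption
# about the generator module's download path).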

# Load local Sesame model
generator = load_csm_1b(device="cpu")  # Change to "cuda" if ZeroGPU is enabled

# Invoke the LLaMA model hosted on a SageMaker endpoint
def invoke_llama(prompt: str) -> str:
    try:
        # Add prompt formatting for better LLaMA output
        formatted_prompt = f"### Question:\n{prompt.strip()}\n\n### Answer:"
        client_boto = boto3.client("sagemaker-runtime", region_name=REGION)
        payload = {"inputs": formatted_prompt}

        response = client_boto.invoke_endpoint(
            EndpointName=LLAMA_ENDPOINT,
            ContentType="application/json",
            Body=json.dumps(payload)
        )

        result = json.loads(response["Body"].read().decode("utf-8"))
        if isinstance(result, list):
            return result[0].get("generated_text", "")
        return result.get("generated_text", str(result))

    except Exception as e:
        print(f"Error calling LLaMA: {e}")
        return "I'm having trouble processing your request."

# Function to generate speech using the local Sesame model
def speak_with_sesame(text: str) -> str | None:
    try:
        audio = generator.generate(
            text=text,
            speaker=0,
            context=[],
            max_audio_length_ms=30000
        )
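        # generator.generate returns a 1-D waveform tensor; unsqueeze adds the
        # channel dimension torchaudio.save expects.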
        output_path = f"csm_output_{int(time.time())}.wav"
        torchaudio.save(output_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
        return output_path
    except Exception as e:
        print(f"Error generating Sesame voice: {e}")
        return None

# Main logic: run the full conversation pipeline
def conversation_with_voice(user_input: str) -> dict:
    text_response = invoke_llama(user_input)
    audio_path = speak_with_sesame(text_response)
    return {
        "text_response": text_response,
        "audio_path": audio_path
    }
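
# Example usage: a minimal sketch, assuming valid AWS credentials and that the
# CSM-1B weights are available locally (or downloadable via HF_TOKEN).
if __name__ == "__main__":
    result = conversation_with_voice("What is text-to-speech?")
    print("LLaMA:", result["text_response"])
    if result["audio_path"]:
        print("Audio saved to:", result["audio_path"])
    else:
        print("Audio generation failed.")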