import httpx
import logging
import os
from pipecat.frames.frames import TextFrame, LLMResponseFrame
from pipecat.processors.frame_processor import FrameProcessor, FrameDirection

logger = logging.getLogger(__name__)

class AzureOpenAILLMService(FrameProcessor):
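    """Pipecat FrameProcessor that forwards user text to an Azure OpenAI
    chat-completions deployment and pushes the model's reply downstream."""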
    def __init__(self, preprompt: str = "", endpoint: str = "https://designcntrl-azure-openai-server.openai.azure.com/openai/deployments/ottopilot/chat/completions?api-version=2023-03-15-preview"):
        super().__init__()
        self.api_key = os.environ.get("azure_openai")
        if not self.api_key:
            logger.error("Missing Azure OpenAI API key: azure_openai")
            raise ValueError("Azure OpenAI API key not found in environment variable 'azure_openai'")
        self.preprompt = preprompt
        self.endpoint = endpoint
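        # Shared async HTTP client; reused across requests and closed in stop().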
        self.client = httpx.AsyncClient()

    async def process_frame(self, frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        # User text flows downstream through a pipecat pipeline (input ->
        # STT -> LLM -> ...), so trigger a completion on downstream TextFrames
        # and pass every other frame through untouched.
        if isinstance(frame, TextFrame) and direction == FrameDirection.DOWNSTREAM:
            try:
                messages = []
                if self.preprompt:
                    messages.append({"role": "system", "content": self.preprompt})
                messages.append({"role": "user", "content": frame.text})

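                # Azure OpenAI authenticates with an "api-key" header rather
                # than the Bearer token used by the public OpenAI API.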
                headers = {
                    "Content-Type": "application/json",
                    "api-key": self.api_key
                }
                data = {
                    "messages": messages,
                    "temperature": 0.5,
                    "max_tokens": 4000,
                    "top_p": 1,
                    "frequency_penalty": 0,
                    "presence_penalty": 0
                }

                response = await self.client.post(self.endpoint, headers=headers, json=data, timeout=30)
                response.raise_for_status()
                result = response.json()

                if "choices" in result and len(result["choices"]) > 0:
                    content = result["choices"][0]["message"]["content"]
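                    # Approximate truncation check: compares *character*
                    # length against the max_tokens budget, so it only
                    # roughly signals that the reply may have been cut off.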
                    continue_flag = len(content) >= 4000
                    await self.push_frame(LLMResponseFrame(content=content, continue_flag=continue_flag))
                else:
                    logger.error("No valid content in API response")
                    await self.push_frame(TextFrame("Error: No valid response from LLM"))

            except httpx.HTTPStatusError as e:
                logger.error(f"API error: {e}")
                await self.push_frame(TextFrame(f"Error: API request failed - {str(e)}"))
            except httpx.RequestError as e:
                logger.error(f"Network error: {e}")
                await self.push_frame(TextFrame(f"Error: could not reach Azure OpenAI - {str(e)}"))
            except Exception as e:
                logger.error(f"Unexpected error: {e}", exc_info=True)
                await self.push_frame(TextFrame(f"Error: Unexpected error - {str(e)}"))
        else:
            await self.push_frame(frame, direction)

    async def stop(self):
        await self.client.aclose()
        logger.info("AzureOpenAILLMService stopped")
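

# Usage sketch (illustrative addition, not from the original file): one way to
# exercise the service in a bare pipecat pipeline. Pipeline, PipelineTask,
# PipelineRunner, and EndFrame are pipecat's standard pipeline primitives, but
# exact signatures vary between pipecat releases; a real deployment would put
# transport, STT, and TTS processors around the service. Assumes the
# 'azure_openai' environment variable holds a valid API key.
if __name__ == "__main__":
    import asyncio

    from pipecat.frames.frames import EndFrame
    from pipecat.pipeline.pipeline import Pipeline
    from pipecat.pipeline.runner import PipelineRunner
    from pipecat.pipeline.task import PipelineTask

    async def main():
        llm = AzureOpenAILLMService(preprompt="You are a helpful assistant.")
        task = PipelineTask(Pipeline([llm]))
        # Feed one user utterance downstream, then signal end-of-stream so
        # the runner can shut the pipeline down cleanly.
        await task.queue_frames([TextFrame("Hello!"), EndFrame()])
        await PipelineRunner().run(task)
        await llm.stop()

    asyncio.run(main())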