Create azure_openai.py
Browse files- azure_openai.py +60 -0
azure_openai.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import httpx
|
2 |
+
import logging
|
3 |
+
from pipecat.frames.frames import TextFrame, LLMResponseFrame
|
4 |
+
from pipecat.processors.frame_processor import FrameProcessor, FrameDirection
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
class AzureOpenAILLMService(FrameProcessor):
|
9 |
+
def __init__(self, api_key: str, preprompt: str = "", endpoint: str = "https://designcntrl-azure-openai-server.openai.azure.com/openai/deployments/ottopilot/chat/completions?api-version=2023-03-15-preview"):
|
10 |
+
super().__init__()
|
11 |
+
self.api_key = api_key
|
12 |
+
self.preprompt = preprompt
|
13 |
+
self.endpoint = endpoint
|
14 |
+
self.client = httpx.AsyncClient()
|
15 |
+
|
16 |
+
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
17 |
+
if isinstance(frame, TextFrame) and direction == FrameDirection.UPSTREAM:
|
18 |
+
try:
|
19 |
+
messages = []
|
20 |
+
if self.preprompt:
|
21 |
+
messages.append({"role": "system", "content": self.preprompt})
|
22 |
+
messages.append({"role": "user", "content": frame.text})
|
23 |
+
|
24 |
+
headers = {
|
25 |
+
"Content-Type": "application/json",
|
26 |
+
"api-key": self.api_key
|
27 |
+
}
|
28 |
+
data = {
|
29 |
+
"messages": messages,
|
30 |
+
"temperature": 0.5,
|
31 |
+
"max_tokens": 4000,
|
32 |
+
"top_p": 1,
|
33 |
+
"frequency_penalty": 0,
|
34 |
+
"presence_penalty": 0
|
35 |
+
}
|
36 |
+
|
37 |
+
response = await self.client.post(self.endpoint, headers=headers, json=data, timeout=30)
|
38 |
+
response.raise_for_status()
|
39 |
+
result = response.json()
|
40 |
+
|
41 |
+
if "choices" in result and len(result["choices"]) > 0:
|
42 |
+
content = result["choices"][0]["message"]["content"]
|
43 |
+
continue_flag = len(content) >= 4000
|
44 |
+
await self.push_frame(LLMResponseFrame(content=content, continue_flag=continue_flag))
|
45 |
+
else:
|
46 |
+
logger.error("No valid content in API response")
|
47 |
+
await self.push_frame(TextFrame("Error: No valid response from LLM"))
|
48 |
+
|
49 |
+
except httpx.HTTPStatusError as e:
|
50 |
+
logger.error(f"API error: {e}")
|
51 |
+
await self.push_frame(TextFrame(f"Error: API request failed - {str(e)}"))
|
52 |
+
except Exception as e:
|
53 |
+
logger.error(f"Unexpected error: {e}", exc_info=True)
|
54 |
+
await self.push_frame(TextFrame(f"Error: Unexpected error - {str(e)}"))
|
55 |
+
else:
|
56 |
+
await self.push_frame(frame, direction)
|
57 |
+
|
58 |
+
async def stop(self):
|
59 |
+
await self.client.aclose()
|
60 |
+
logger.info("AzureOpenAILLMService stopped")
|