RemoteAgent / utils /grok_client.py
jeongsoo's picture
add grok model
641f51b
"""
Grok API client module.
"""
import os
import json
import logging
import traceback
from typing import List, Dict, Any, Optional, Union
from dotenv import load_dotenv
import requests
# Load environment variables (from a .env file, if one is present) at import
# time, before GrokLLM reads GROK_API_KEY / GROK_API_BASE / GROK_MODEL.
load_dotenv()
# Logger setup.
# A stream handler is attached only when hasHandlers() reports that neither
# this logger nor its ancestors already have one, so an application that
# configured logging before importing this module keeps control of the output.
logger = logging.getLogger("GrokLLM")
if not logger.hasHandlers():
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    # NOTE(review): the original indentation was lost; setLevel is assumed to
    # belong inside this guard - confirm against the repository copy.
    logger.setLevel(logging.INFO)
class GrokLLM:
"""Grok API ๋ž˜ํผ ํด๋ž˜์Šค"""
def __init__(self):
"""Grok LLM ํด๋ž˜์Šค ์ดˆ๊ธฐํ™”"""
self.api_key = os.getenv("GROK_API_KEY")
self.api_base = os.getenv("GROK_API_BASE", "https://api.x.ai/v1")
self.model = os.getenv("GROK_MODEL", "grok-3-latest")
if not self.api_key:
logger.warning("Grok API ํ‚ค๊ฐ€ .env ํŒŒ์ผ์— ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
logger.warning("GROK_API_KEY๋ฅผ ํ™•์ธํ•˜์„ธ์š”.")
else:
logger.info("Grok API ํ‚ค ๋กœ๋“œ ์™„๋ฃŒ.")
logger.debug(f"API ๊ธฐ๋ณธ URL: {self.api_base}")
logger.debug(f"๊ธฐ๋ณธ ๋ชจ๋ธ: {self.model}")
def chat_completion(
self,
messages: List[Dict[str, str]],
temperature: float = 0.7,
max_tokens: int = 1000,
stream: bool = False,
**kwargs
) -> Dict[str, Any]:
"""
Grok ์ฑ„ํŒ… ์™„์„ฑ API ํ˜ธ์ถœ
Args:
messages: ์ฑ„ํŒ… ๋ฉ”์‹œ์ง€ ๋ชฉ๋ก
temperature: ์ƒ์„ฑ ์˜จ๋„ (๋‚ฎ์„์ˆ˜๋ก ๊ฒฐ์ •์ )
max_tokens: ์ƒ์„ฑํ•  ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜
stream: ์ŠคํŠธ๋ฆฌ๋ฐ ์‘๋‹ต ํ™œ์„ฑํ™” ์—ฌ๋ถ€
**kwargs: ์ถ”๊ฐ€ API ๋งค๊ฐœ๋ณ€์ˆ˜
Returns:
API ์‘๋‹ต (๋”•์…”๋„ˆ๋ฆฌ)
"""
if not self.api_key:
logger.error("API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•„ Grok API๋ฅผ ํ˜ธ์ถœํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
raise ValueError("Grok API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
try:
logger.info(f"Grok API ์š”์ฒญ ์ „์†ก ์ค‘ (๋ชจ๋ธ: {self.model})")
# API ์š”์ฒญ ํ—ค๋” ๋ฐ ๋ฐ์ดํ„ฐ ์ค€๋น„
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": self.model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream
}
# ์ถ”๊ฐ€ ๋งค๊ฐœ๋ณ€์ˆ˜ ๋ณ‘ํ•ฉ
for key, value in kwargs.items():
if key not in data:
data[key] = value
# API ์š”์ฒญ ๋ณด๋‚ด๊ธฐ
endpoint = f"{self.api_base}/chat/completions"
# ๋””๋ฒ„๊น…: API ์š”์ฒญ ๋ฐ์ดํ„ฐ ๋กœ๊น… (๋ฏผ๊ฐ ์ •๋ณด ์ œ์™ธ)
debug_data = data.copy()
debug_data["messages"] = f"[{len(data['messages'])}๊ฐœ ๋ฉ”์‹œ์ง€]"
logger.debug(f"Grok API ์š”์ฒญ ๋ฐ์ดํ„ฐ: {json.dumps(debug_data)}")
response = requests.post(
endpoint,
headers=headers,
json=data,
timeout=30 # 30์ดˆ ํƒ€์ž„์•„์›ƒ ์„ค์ •
)
# ์‘๋‹ต ์ƒํƒœ ์ฝ”๋“œ ํ™•์ธ
if not response.ok:
logger.error(f"Grok API ์˜ค๋ฅ˜: ์ƒํƒœ ์ฝ”๋“œ {response.status_code}")
logger.error(f"์‘๋‹ต ๋‚ด์šฉ: {response.text}")
return {"error": f"API ์˜ค๋ฅ˜: ์ƒํƒœ ์ฝ”๋“œ {response.status_code}", "detail": response.text}
# ์‘๋‹ต ํŒŒ์‹ฑ
try:
result = response.json()
logger.debug(f"API ์‘๋‹ต ๊ตฌ์กฐ: {list(result.keys())}")
return result
except json.JSONDecodeError as e:
logger.error(f"Grok API JSON ํŒŒ์‹ฑ ์‹คํŒจ: {e}")
logger.error(f"์›๋ณธ ์‘๋‹ต: {response.text[:500]}...")
return {"error": "API ์‘๋‹ต์„ ํŒŒ์‹ฑํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค", "detail": str(e)}
except requests.exceptions.RequestException as e:
logger.error(f"Grok API ์š”์ฒญ ์‹คํŒจ: {e}")
return {"error": f"API ์š”์ฒญ ์‹คํŒจ: {str(e)}"}
except Exception as e:
logger.error(f"Grok API ํ˜ธ์ถœ ์ค‘ ์˜ˆ์ƒ์น˜ ๋ชปํ•œ ์˜ค๋ฅ˜: {e}")
logger.error(traceback.format_exc())
return {"error": f"์˜ˆ์ƒ์น˜ ๋ชปํ•œ ์˜ค๋ฅ˜: {str(e)}"}
def generate(
self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 1000,
**kwargs
) -> str:
"""
๊ฐ„๋‹จํ•œ ํ…์ŠคํŠธ ์ƒ์„ฑ ์ธํ„ฐํŽ˜์ด์Šค
Args:
prompt: ์‚ฌ์šฉ์ž ํ”„๋กฌํ”„ํŠธ
system_prompt: ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ (์„ ํƒ ์‚ฌํ•ญ)
temperature: ์ƒ์„ฑ ์˜จ๋„
max_tokens: ์ƒ์„ฑํ•  ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜
**kwargs: ์ถ”๊ฐ€ API ๋งค๊ฐœ๋ณ€์ˆ˜
Returns:
์ƒ์„ฑ๋œ ํ…์ŠคํŠธ
"""
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
try:
response = self.chat_completion(
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
**kwargs
)
# ์˜ค๋ฅ˜ ์‘๋‹ต ํ™•์ธ
if "error" in response:
logger.error(f"ํ…์ŠคํŠธ ์ƒ์„ฑ ์ค‘ API ์˜ค๋ฅ˜: {response['error']}")
error_detail = response.get("detail", "")
return f"API ์˜ค๋ฅ˜: {response['error']} {error_detail}"
# ์‘๋‹ต ํ˜•์‹ ๊ฒ€์ฆ
if 'choices' not in response or not response['choices']:
logger.error(f"API ์‘๋‹ต์— 'choices' ํ•„๋“œ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค: {response}")
return "์‘๋‹ต ํ˜•์‹ ์˜ค๋ฅ˜: ์ƒ์„ฑ๋œ ํ…์ŠคํŠธ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
# ๋ฉ”์‹œ์ง€ ์ปจํ…์ธ  ํ™•์ธ
choice = response['choices'][0]
if 'message' not in choice or 'content' not in choice['message']:
logger.error(f"API ์‘๋‹ต์— ์˜ˆ์ƒ ํ•„๋“œ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค: {choice}")
return "์‘๋‹ต ํ˜•์‹ ์˜ค๋ฅ˜: ๋ฉ”์‹œ์ง€ ๋‚ด์šฉ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
generated_text = choice['message']['content'].strip()
logger.info(f"ํ…์ŠคํŠธ ์ƒ์„ฑ ์™„๋ฃŒ (๊ธธ์ด: {len(generated_text)})")
return generated_text
except Exception as e:
logger.error(f"ํ…์ŠคํŠธ ์ƒ์„ฑ ์ค‘ ์˜ˆ์™ธ ๋ฐœ์ƒ: {e}")
logger.error(traceback.format_exc())
return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
def rag_generate(
self,
query: str,
context: List[str],
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_tokens: int = 1000,
**kwargs
) -> str:
"""
RAG ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ํ™œ์šฉํ•œ ํ…์ŠคํŠธ ์ƒ์„ฑ
Args:
query: ์‚ฌ์šฉ์ž ์งˆ์˜
context: ๊ฒ€์ƒ‰๋œ ๋ฌธ๋งฅ ๋ชฉ๋ก
system_prompt: ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ (์„ ํƒ ์‚ฌํ•ญ)
temperature: ์ƒ์„ฑ ์˜จ๋„
max_tokens: ์ƒ์„ฑํ•  ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜
**kwargs: ์ถ”๊ฐ€ API ๋งค๊ฐœ๋ณ€์ˆ˜
Returns:
์ƒ์„ฑ๋œ ํ…์ŠคํŠธ
"""
if not system_prompt:
system_prompt = """๋‹น์‹ ์€ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•˜๋Š” ๋„์šฐ๋ฏธ์ž…๋‹ˆ๋‹ค.
- ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋Š” <context> ํƒœ๊ทธ ์•ˆ์— ์ œ๊ณต๋ฉ๋‹ˆ๋‹ค.
- ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ์— ๋‹ต๋ณ€์ด ์žˆ์œผ๋ฉด ํ•ด๋‹น ์ •๋ณด๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ช…ํ™•ํ•˜๊ฒŒ ๋‹ต๋ณ€ํ•˜์„ธ์š”.
- ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ์— ๋‹ต๋ณ€์ด ์—†์œผ๋ฉด "๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ์— ๊ด€๋ จ ์ •๋ณด๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค"๋ผ๊ณ  ๋งํ•˜์„ธ์š”.
- ๊ฒ€์ƒ‰ ๋‚ด์šฉ์„ ๊ทธ๋Œ€๋กœ ๋ณต์‚ฌํ•˜์ง€ ๋ง๊ณ , ์ž์—ฐ์Šค๋Ÿฌ์šด ํ•œ๊ตญ์–ด๋กœ ๋‹ต๋ณ€์„ ์ž‘์„ฑํ•˜์„ธ์š”.
- ๋‹ต๋ณ€์€ ๊ฐ„๊ฒฐํ•˜๊ณ  ์ •ํ™•ํ•˜๊ฒŒ ์ œ๊ณตํ•˜์„ธ์š”."""
try:
# ์ค‘์š”: ์ปจํ…์ŠคํŠธ ๊ธธ์ด ์ œํ•œ
max_context = 10
if len(context) > max_context:
logger.warning(f"์ปจํ…์ŠคํŠธ๊ฐ€ ๋„ˆ๋ฌด ๊ธธ์–ด ์ฒ˜์Œ {max_context}๊ฐœ๋งŒ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.")
context = context[:max_context]
# ๊ฐ ์ปจํ…์ŠคํŠธ ์•ก์„ธ์Šค
limited_context = []
for i, doc in enumerate(context):
# ๊ฐ ๋ฌธ์„œ๋ฅผ 1000์ž๋กœ ์ œํ•œ
if len(doc) > 1000:
logger.warning(f"๋ฌธ์„œ {i+1}์˜ ๊ธธ์ด๊ฐ€ ์ œํ•œ๋˜์—ˆ์Šต๋‹ˆ๋‹ค ({len(doc)} -> 1000)")
doc = doc[:1000] + "...(์ƒ๋žต)"
limited_context.append(doc)
context_text = "\n\n".join([f"๋ฌธ์„œ {i+1}: {doc}" for i, doc in enumerate(limited_context)])
prompt = f"""์งˆ๋ฌธ: {query}
<context>
{context_text}
</context>
์œ„ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ์ฐธ๊ณ ํ•˜์—ฌ ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•ด ์ฃผ์„ธ์š”."""
logger.info(f"RAG ํ”„๋กฌํ”„ํŠธ ์ƒ์„ฑ ์™„๋ฃŒ (๊ธธ์ด: {len(prompt)})")
result = self.generate(
prompt=prompt,
system_prompt=system_prompt,
temperature=temperature,
max_tokens=max_tokens,
**kwargs
)
# ๊ฒฐ๊ณผ๊ฐ€ ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€์ธ์ง€ ํ™•์ธ
if result.startswith("์˜ค๋ฅ˜") or result.startswith("API ์˜ค๋ฅ˜") or result.startswith("์‘๋‹ต ํ˜•์‹ ์˜ค๋ฅ˜"):
logger.error(f"RAG ์ƒ์„ฑ ๊ฒฐ๊ณผ๊ฐ€ ์˜ค๋ฅ˜๋ฅผ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค: {result}")
# ์ข€ ๋” ์‚ฌ์šฉ์ž ์นœํ™”์ ์ธ ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€ ๋ฐ˜ํ™˜
return "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ํ˜„์žฌ ์‘๋‹ต์„ ์ƒ์„ฑํ•˜๋Š”๋ฐ ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค. ์ž ์‹œ ํ›„ ๋‹ค์‹œ ์‹œ๋„ํ•ด์ฃผ์„ธ์š”."
return result
except Exception as e:
logger.error(f"RAG ํ…์ŠคํŠธ ์ƒ์„ฑ ์ค‘ ์˜ˆ์™ธ ๋ฐœ์ƒ: {str(e)}")
logger.error(traceback.format_exc())
return "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์‘๋‹ต ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค. ์ž ์‹œ ํ›„ ๋‹ค์‹œ ์‹œ๋„ํ•ด์ฃผ์„ธ์š”."