""" | |
Grok API ํด๋ผ์ด์ธํธ ๋ชจ๋ | |
""" | |
import os | |
import json | |
import logging | |
import traceback | |
from typing import List, Dict, Any, Optional, Union | |
from dotenv import load_dotenv | |
import requests | |
# ํ๊ฒฝ ๋ณ์ ๋ก๋ | |
load_dotenv() | |
# ๋ก๊ฑฐ ์ค์ | |
logger = logging.getLogger("GrokLLM") | |
if not logger.hasHandlers(): | |
handler = logging.StreamHandler() | |
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
handler.setFormatter(formatter) | |
logger.addHandler(handler) | |
logger.setLevel(logging.INFO) | |

class GrokLLM:
    """Wrapper class around the Grok API"""

    def __init__(self):
        """Initialize the Grok LLM client"""
        self.api_key = os.getenv("GROK_API_KEY")
        self.api_base = os.getenv("GROK_API_BASE", "https://api.x.ai/v1")
        self.model = os.getenv("GROK_MODEL", "grok-3-latest")
        if not self.api_key:
            logger.warning("Grok API key is not set in the .env file.")
            logger.warning("Please check GROK_API_KEY.")
        else:
            logger.info("Grok API key loaded.")
        logger.debug(f"API base URL: {self.api_base}")
        logger.debug(f"Default model: {self.model}")
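
    # Illustrative .env entries read above (placeholder values, not real credentials):
    #   GROK_API_KEY=xai-xxxxxxxxxxxxxxxx
    #   GROK_API_BASE=https://api.x.ai/v1
    #   GROK_MODEL=grok-3-latest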

    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 1000,
        stream: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Call the Grok chat completions API.

        Args:
            messages: List of chat messages
            temperature: Sampling temperature (lower is more deterministic)
            max_tokens: Maximum number of tokens to generate
            stream: Whether to enable streaming responses
            **kwargs: Additional API parameters

        Returns:
            API response (dictionary)
        """
        if not self.api_key:
            logger.error("Cannot call the Grok API because no API key is set.")
            raise ValueError("Grok API key is not set.")
        try:
            logger.info(f"Sending Grok API request (model: {self.model})")
            # Prepare request headers and payload
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            data = {
                "model": self.model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream
            }
            # Merge additional parameters
            for key, value in kwargs.items():
                if key not in data:
                    data[key] = value
            # Send the API request
            endpoint = f"{self.api_base}/chat/completions"
            # Debugging: log the request payload (excluding sensitive content)
            debug_data = data.copy()
            debug_data["messages"] = f"[{len(data['messages'])} messages]"
            logger.debug(f"Grok API request data: {json.dumps(debug_data)}")
            response = requests.post(
                endpoint,
                headers=headers,
                json=data,
                timeout=30  # 30-second timeout
            )
            # Check the response status code
            if not response.ok:
                logger.error(f"Grok API error: status code {response.status_code}")
                logger.error(f"Response body: {response.text}")
                return {"error": f"API error: status code {response.status_code}", "detail": response.text}
            # Parse the response
            try:
                result = response.json()
                logger.debug(f"API response keys: {list(result.keys())}")
                return result
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse Grok API JSON response: {e}")
                logger.error(f"Raw response: {response.text[:500]}...")
                return {"error": "Could not parse the API response", "detail": str(e)}
        except requests.exceptions.RequestException as e:
            logger.error(f"Grok API request failed: {e}")
            return {"error": f"API request failed: {str(e)}"}
        except Exception as e:
            logger.error(f"Unexpected error during Grok API call: {e}")
            logger.error(traceback.format_exc())
            return {"error": f"Unexpected error: {str(e)}"}

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1000,
        **kwargs
    ) -> str:
        """
        Simple text generation interface.

        Args:
            prompt: User prompt
            system_prompt: System prompt (optional)
            temperature: Sampling temperature
            max_tokens: Maximum number of tokens to generate
            **kwargs: Additional API parameters

        Returns:
            Generated text
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        try:
            response = self.chat_completion(
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
            # Check for an error response
            if "error" in response:
                logger.error(f"API error during text generation: {response['error']}")
                error_detail = response.get("detail", "")
                return f"API error: {response['error']} {error_detail}"
            # Validate the response format
            if 'choices' not in response or not response['choices']:
                logger.error(f"API response has no 'choices' field: {response}")
                return "Response format error: no generated text found."
            # Check the message content
            choice = response['choices'][0]
            if 'message' not in choice or 'content' not in choice['message']:
                logger.error(f"API response is missing expected fields: {choice}")
                return "Response format error: message content not found."
            generated_text = choice['message']['content'].strip()
            logger.info(f"Text generation complete (length: {len(generated_text)})")
            return generated_text
        except Exception as e:
            logger.error(f"Exception during text generation: {e}")
            logger.error(traceback.format_exc())
            return f"Error occurred: {str(e)}"

    def rag_generate(
        self,
        query: str,
        context: List[str],
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        max_tokens: int = 1000,
        **kwargs
    ) -> str:
        """
        Text generation grounded in RAG retrieval results.

        Args:
            query: User query
            context: List of retrieved context passages
            system_prompt: System prompt (optional)
            temperature: Sampling temperature
            max_tokens: Maximum number of tokens to generate
            **kwargs: Additional API parameters

        Returns:
            Generated text
        """
        if not system_prompt:
            system_prompt = """You are an assistant that answers questions based on search results.
- The search results are provided inside <context> tags.
- If the search results contain the answer, use that information to answer clearly.
- If the search results do not contain the answer, say "No relevant information was found in the search results."
- Do not copy the retrieved content verbatim; write the answer in natural Korean.
- Keep the answer concise and accurate."""
        try:
            # Important: limit the number of context passages
            max_context = 10
            if len(context) > max_context:
                logger.warning(f"Context is too long; using only the first {max_context} passages.")
                context = context[:max_context]
            # Truncate each context passage
            limited_context = []
            for i, doc in enumerate(context):
                # Limit each document to 1000 characters
                if len(doc) > 1000:
                    logger.warning(f"Document {i+1} was truncated ({len(doc)} -> 1000)")
                    doc = doc[:1000] + "...(truncated)"
                limited_context.append(doc)
            context_text = "\n\n".join([f"Document {i+1}: {doc}" for i, doc in enumerate(limited_context)])
            prompt = f"""Question: {query}
<context>
{context_text}
</context>
Please answer the question based on the search results above."""
            logger.info(f"RAG prompt built (length: {len(prompt)})")
            result = self.generate(
                prompt=prompt,
                system_prompt=system_prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
            # Check whether the result is an error message
            if result.startswith("Error") or result.startswith("API error") or result.startswith("Response format error"):
                logger.error(f"RAG generation result contains an error: {result}")
                # Return a more user-friendly error message
                return "Sorry, there was a problem generating the response. Please try again later."
            return result
        except Exception as e:
            logger.error(f"Exception during RAG text generation: {str(e)}")
            logger.error(traceback.format_exc())
            return "Sorry, an error occurred while generating the response. Please try again later."