""" | |
OpenAI API ํด๋ผ์ด์ธํธ ๋ชจ๋ | |
""" | |
import os
import logging
from typing import Dict, List, Optional

from dotenv import load_dotenv
from openai import OpenAI
from openai.types.chat import ChatCompletion
# Load environment variables from a .env file
load_dotenv()

# Logger setup
logger = logging.getLogger("OpenAILLM")
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
logger.setLevel(logging.INFO)
class OpenAILLM:
    """Thin wrapper around the OpenAI chat completions API."""

    def __init__(self):
        """Read the API key and model name from the environment."""
        self.api_key = os.getenv("OPENAI_API_KEY")
        self.model = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")
        # Keep client defined even when no key is present, so attribute access is safe
        self.client: Optional[OpenAI] = None
        if not self.api_key:
            logger.warning("No OpenAI API key found in the .env file.")
            logger.warning("Please check OPENAI_API_KEY.")
        else:
            # Initialize the OpenAI client
            self.client = OpenAI(api_key=self.api_key)
            logger.info("OpenAI API key loaded.")
    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 1000,
        **kwargs,
    ) -> ChatCompletion:
        """
        Call the OpenAI chat completions API.

        Args:
            messages: List of chat messages.
            temperature: Sampling temperature (lower is more deterministic).
            max_tokens: Maximum number of tokens to generate.
            **kwargs: Additional API parameters.

        Returns:
            The API response (a ChatCompletion object).
        """
        if not self.api_key:
            logger.error("Cannot call the OpenAI API because no API key is configured.")
            raise ValueError("OpenAI API key is not configured.")
        try:
            logger.info(f"Sending OpenAI API request (model: {self.model})")
            # Call the API via the v1.x OpenAI SDK
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )
            return response
        except Exception as e:
            logger.error(f"OpenAI API request failed: {e}")
            raise RuntimeError(f"OpenAI API request failed: {e}") from e
    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1000,
        **kwargs,
    ) -> str:
        """
        Simple text-generation interface.

        Args:
            prompt: User prompt.
            system_prompt: System prompt (optional).
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            **kwargs: Additional API parameters.

        Returns:
            The generated text.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        try:
            response = self.chat_completion(
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )
            # The v1.x SDK returns a ChatCompletion object, not a dict
            if not response or not response.choices:
                logger.error("No generated text found in the OpenAI API response.")
                return ""
            # message.content can be None (e.g. for tool calls), so guard before strip()
            content = response.choices[0].message.content
            return content.strip() if content else ""
        except Exception as e:
            logger.error(f"Error during text generation: {e}")
            return f"Error: {e}"
    def rag_generate(
        self,
        query: str,
        context: List[str],
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        max_tokens: int = 1000,
        **kwargs,
    ) -> str:
        """
        Generate text grounded in RAG retrieval results.

        Args:
            query: User query.
            context: List of retrieved context documents.
            system_prompt: System prompt (optional).
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            **kwargs: Additional API parameters.

        Returns:
            The generated text.
        """
        if not system_prompt:
            system_prompt = """You are an assistant that answers questions based on search results.
- The search results are provided inside <context> tags.
- If the search results contain the answer, use that information to answer clearly.
- If the search results do not contain the answer, say "No relevant information was found in the search results."
- Do not copy the retrieved content verbatim; write the answer in natural Korean.
- Keep the answer concise and accurate."""
        # Important: cap the amount of context passed to the model.
        # Limits relaxed to suit gpt-4o-mini.
        max_context = 10  # increased from 3 to 10
        if len(context) > max_context:
            logger.warning(f"Context is too long; using only the first {max_context} documents.")
            context = context[:max_context]

        # Truncate each context document
        limited_context = []
        for i, doc in enumerate(context):
            # Cap each document at 1000 characters (up from the previous 500)
            if len(doc) > 1000:
                logger.warning(f"Document {i+1} was truncated ({len(doc)} -> 1000 characters)")
                doc = doc[:1000] + "...(truncated)"
            limited_context.append(doc)

        context_text = "\n\n".join(f"Document {i+1}: {doc}" for i, doc in enumerate(limited_context))
prompt = f"""์ง๋ฌธ: {query} | |
<context> | |
{context_text} | |
</context> | |
์ ๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ์ฐธ๊ณ ํ์ฌ ์ง๋ฌธ์ ๋ต๋ณํด ์ฃผ์ธ์.""" | |
try: | |
return self.generate( | |
prompt=prompt, | |
system_prompt=system_prompt, | |
temperature=temperature, | |
max_tokens=max_tokens, | |
**kwargs | |
) | |
except Exception as e: | |
logger.error(f"RAG ํ ์คํธ ์์ฑ ์ค ์ค๋ฅ ๋ฐ์: {e}") | |
return f"์ค๋ฅ: {str(e)}" | |