# NOTE(review): removed stray hosting-platform text ("Spaces:", "No application file")
# that was not part of the Python source and broke the module.
""" | |
DeepSeek API ํด๋ผ์ด์ธํธ ๋ชจ๋ | |
""" | |
import os | |
import json | |
import logging | |
from typing import List, Dict, Any, Optional, Union | |
from dotenv import load_dotenv | |
import requests | |
# ํ๊ฒฝ ๋ณ์ ๋ก๋ | |
load_dotenv() | |
# ๋ก๊ฑฐ ์ค์ | |
logger = logging.getLogger("DeepSeekLLM") | |
if not logger.hasHandlers(): | |
handler = logging.StreamHandler() | |
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
handler.setFormatter(formatter) | |
logger.addHandler(handler) | |
logger.setLevel(logging.INFO) | |
class DeepSeekLLM:
    """Thin wrapper around the DeepSeek chat-completion HTTP API.

    Configuration is read from environment variables (typically loaded from a
    .env file by this module):
      - DEEPSEEK_API_KEY:  required for any API call
      - DEEPSEEK_API_BASE: base URL (default: https://api.deepseek.com)
      - DEEPSEEK_MODEL:    model name (default: deepseek-chat)
      - DEEPSEEK_TIMEOUT:  HTTP request timeout in seconds (default: 60)
    """

    def __init__(self) -> None:
        """Read API configuration from the environment; warn if the key is missing."""
        self.api_key: Optional[str] = os.getenv("DEEPSEEK_API_KEY")
        # rstrip so a trailing slash in DEEPSEEK_API_BASE does not yield "//v1/...".
        self.api_base: str = os.getenv(
            "DEEPSEEK_API_BASE", "https://api.deepseek.com"
        ).rstrip("/")
        self.model: str = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
        # Fail-safe timeout: without one, requests.post() can block forever on an
        # unresponsive server. Falls back to 60s on an unparsable env value.
        try:
            self.timeout: float = float(os.getenv("DEEPSEEK_TIMEOUT", "60"))
        except ValueError:
            self.timeout = 60.0
        if not self.api_key:
            logger.warning("DeepSeek API 키가 .env 파일에 설정되지 않았습니다.")
            logger.warning("DEEPSEEK_API_KEY를 확인하세요.")
        else:
            logger.info("DeepSeek API 키 로드 완료.")

    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 1000,
        **kwargs
    ) -> Dict[str, Any]:
        """Call the DeepSeek chat-completions endpoint.

        Args:
            messages: Chat messages as [{"role": ..., "content": ...}, ...].
            temperature: Sampling temperature (lower = more deterministic).
            max_tokens: Maximum number of tokens to generate.
            **kwargs: Extra request-body parameters; they never override the
                explicit parameters above.

        Returns:
            The parsed JSON response as a dictionary.

        Raises:
            ValueError: If no API key is configured.
            Exception: If the HTTP request fails or the response is not valid
                JSON (original cause chained via ``__cause__``).
        """
        if not self.api_key:
            logger.error("API 키가 설정되지 않아 DeepSeek API를 호출할 수 없습니다.")
            raise ValueError("DeepSeek API 키가 설정되지 않았습니다.")
        try:
            logger.info(f"DeepSeek API 요청 전송 중 (모델: {self.model})")
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            data: Dict[str, Any] = {
                "model": self.model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            # Merge extra parameters without letting them clobber the explicit ones.
            for key, value in kwargs.items():
                data.setdefault(key, value)
            endpoint = f"{self.api_base}/v1/chat/completions"
            # timeout keeps an unresponsive server from blocking the caller forever.
            response = requests.post(
                endpoint,
                headers=headers,
                json=data,
                timeout=self.timeout
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"DeepSeek API 요청 실패: {e}")
            raise Exception(f"DeepSeek API 요청 실패: {e}") from e
        except json.JSONDecodeError as e:
            logger.error(f"DeepSeek API 응답 파싱 실패: {e}")
            raise Exception(f"DeepSeek API 응답 파싱 실패: {e}") from e

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1000,
        **kwargs
    ) -> str:
        """Generate text for a single user prompt.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt prepended to the conversation.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            **kwargs: Extra API parameters forwarded to chat_completion().

        Returns:
            The generated text; an empty string when the response contains no
            choices; an error string ("오류: ...") when the API call fails
            (errors are reported inline rather than raised — existing callers
            depend on this).
        """
        messages: List[Dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        try:
            response = self.chat_completion(
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
            # Defensive: an unexpected payload shape yields "" instead of a KeyError.
            if not response or 'choices' not in response or not response['choices']:
                logger.error("DeepSeek API 응답에서 생성된 텍스트를 찾을 수 없습니다.")
                return ""
            return response['choices'][0]['message']['content'].strip()
        except Exception as e:
            logger.error(f"텍스트 생성 중 오류 발생: {e}")
            return f"오류: {str(e)}"

    def rag_generate(
        self,
        query: str,
        context: List[str],
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        max_tokens: int = 1000,
        **kwargs
    ) -> str:
        """Generate an answer grounded in retrieved (RAG) context documents.

        Args:
            query: User question.
            context: Retrieved context passages; capped at 10 documents of
                1000 characters each to bound prompt size.
            system_prompt: Optional system prompt; a default Korean RAG prompt
                is used when omitted.
            temperature: Sampling temperature (lower default for factuality).
            max_tokens: Maximum number of tokens to generate.
            **kwargs: Extra API parameters forwarded to generate().

        Returns:
            The generated answer, or an error string ("오류: ...") on failure.
        """
        if not system_prompt:
            system_prompt = """당신은 검색 결과를 기반으로 질문에 답변하는 도우미입니다.
- 검색 결과는 <context> 태그 안에 제공됩니다.
- 검색 결과에 답변이 있으면 해당 정보를 사용하여 명확하게 답변하세요.
- 검색 결과에 답변이 없으면 "검색 결과에 관련 정보가 없습니다"라고 말하세요.
- 검색 내용을 그대로 복사하지 말고, 자연스러운 한국어로 답변을 작성하세요.
- 답변은 간결하고 정확하게 제공하세요."""
        # Bound prompt size: keep at most 10 documents.
        max_context = 10
        if len(context) > max_context:
            logger.warning(f"컨텍스트가 너무 길어 처음 {max_context}개만 사용합니다.")
            context = context[:max_context]
        # Truncate each document to 1000 characters.
        limited_context = []
        for i, doc in enumerate(context):
            if len(doc) > 1000:
                logger.warning(f"문서 {i+1}의 길이가 제한되었습니다 ({len(doc)} -> 1000)")
                doc = doc[:1000] + "...(생략)"
            limited_context.append(doc)
        context_text = "\n\n".join(
            [f"문서 {i+1}: {doc}" for i, doc in enumerate(limited_context)]
        )
        prompt = f"""질문: {query}
<context>
{context_text}
</context>
위 검색 결과를 참고하여 질문에 답변해 주세요."""
        try:
            return self.generate(
                prompt=prompt,
                system_prompt=system_prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
        except Exception as e:
            logger.error(f"RAG 텍스트 생성 중 오류 발생: {e}")
            return f"오류: {str(e)}"