""" | |
DeepSeek LLM API ํด๋ผ์ด์ธํธ ๋ชจ๋ | |
""" | |
import os
import json
import logging
import requests
from typing import List, Dict, Any, Optional, Union
from dotenv import load_dotenv

# Load environment variables from a local .env file
# (DEEPSEEK_API_KEY, DEEPSEEK_ENDPOINT, DEEPSEEK_MODEL).
load_dotenv()
# Module logger: configure a stream handler only if neither this logger nor
# any ancestor already has handlers, to avoid duplicate log lines on re-import.
logger = logging.getLogger("DeepSeekLLM")
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
class DeepSeekLLM:
    """Wrapper class for the DeepSeek chat-completions HTTP API.

    Configuration is read from environment variables:
      - DEEPSEEK_API_KEY:  bearer token; required before any API call.
      - DEEPSEEK_ENDPOINT: chat-completions URL (defaults to the public API).
      - DEEPSEEK_MODEL:    model name (defaults to "deepseek-chat").
    """

    def __init__(self):
        """Initialize the client from environment variables and log key status."""
        self.api_key = os.getenv("DEEPSEEK_API_KEY")
        self.endpoint = os.getenv(
            "DEEPSEEK_ENDPOINT", "https://api.deepseek.com/v1/chat/completions"
        )
        self.model = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
        if not self.api_key:
            # A missing key is only warned about here; actual calls raise
            # ValueError in chat_completion().
            logger.warning("DeepSeek API ํค๊ฐ .env ํ์ผ์ ์ค์ ๋์ง ์์์ต๋๋ค.")
            logger.warning("DEEPSEEK_API_KEY๋ฅผ ํ์ธํ์ธ์.")
        else:
            logger.info("DeepSeek LLM API ํค ๋ก๋ ์๋ฃ.")

    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 1000,
        stream: bool = False,
        **kwargs,
    ) -> Union[Dict[str, Any], requests.Response]:
        """Call the DeepSeek chat-completions endpoint.

        Args:
            messages: Chat messages, each a {"role": ..., "content": ...} dict.
            temperature: Sampling temperature (lower is more deterministic).
            max_tokens: Maximum number of tokens to generate.
            stream: If True, request a streaming response and return the raw
                ``requests.Response`` for the caller to iterate over.
            **kwargs: Extra API parameters merged into the request payload.

        Returns:
            Parsed JSON response dict, or the raw ``requests.Response`` when
            ``stream`` is True.

        Raises:
            ValueError: If no API key is configured.
            TimeoutError: If the request exceeds the 60-second timeout.
            ConnectionError: For any other request failure.
        """
        if not self.api_key:
            logger.error("API ํค๊ฐ ์ค์ ๋์ง ์์ DeepSeek API๋ฅผ ํธ์ถํ ์ ์์ต๋๋ค.")
            raise ValueError("DeepSeek API ํค๊ฐ ์ค์ ๋์ง ์์์ต๋๋ค.")
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": stream,
            **kwargs,
        }
        try:
            logger.info(f"DeepSeek API ์์ฒญ ์ ์ก ์ค: {self.endpoint}")
            # BUGFIX: forward `stream` to requests.post. Without it, requests
            # buffers the entire response body before returning, so the
            # "streaming" Response handed back below was already fully
            # downloaded — defeating the purpose of stream=True.
            response = requests.post(
                self.endpoint,
                headers=headers,
                json=payload,
                timeout=60,  # request timeout in seconds
                stream=stream,
            )
            response.raise_for_status()
            if stream:
                # Caller iterates the raw Response (e.g. iter_lines) for SSE chunks.
                return response
            return response.json()
        except requests.exceptions.Timeout:
            logger.error("DeepSeek API ์์ฒญ ์๊ฐ ์ด๊ณผ")
            raise TimeoutError("DeepSeek API ์์ฒญ ์๊ฐ ์ด๊ณผ")
        except requests.exceptions.RequestException as e:
            logger.error(f"DeepSeek API ์์ฒญ ์คํจ: {e}")
            # Include the server's status and body when available — invaluable
            # for debugging 4xx responses.
            if hasattr(e, 'response') and e.response is not None:
                logger.error(f"์๋ต ์ฝ๋: {e.response.status_code}, ๋ด์ฉ: {e.response.text}")
            raise ConnectionError(f"DeepSeek API ์์ฒญ ์คํจ: {e}")

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1000,
        **kwargs,
    ) -> str:
        """Generate text for a single user prompt (simple non-chat interface).

        Args:
            prompt: User prompt text.
            system_prompt: Optional system prompt prepended to the conversation.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            **kwargs: Extra API parameters forwarded to chat_completion().

        Returns:
            The generated text; "" when the response contains no choices; or an
            error-description string when the API call fails (errors are
            deliberately returned as text rather than re-raised).
        """
        messages: List[Dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        try:
            response = self.chat_completion(
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )
            # Defensive: the API may return an empty or malformed payload.
            if not response or "choices" not in response or not response["choices"]:
                logger.error("DeepSeek API ์๋ต์์ ์์ฑ๋ ํ ์คํธ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค.")
                return ""
            return response["choices"][0]["message"]["content"].strip()
        except Exception as e:
            logger.error(f"ํ ์คํธ ์์ฑ ์ค ์ค๋ฅ ๋ฐ์: {e}")
            return f"์ค๋ฅ: {str(e)}"

    def rag_generate(
        self,
        query: str,
        context: List[str],
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        max_tokens: int = 1000,
        **kwargs,
    ) -> str:
        """Generate an answer grounded in RAG retrieval results.

        Args:
            query: User question.
            context: Retrieved context documents to ground the answer in.
            system_prompt: Optional system prompt; a default RAG prompt is used
                when omitted.
            temperature: Sampling temperature (lower default than generate(),
                since grounded answers should be deterministic).
            max_tokens: Maximum number of tokens to generate.
            **kwargs: Extra API parameters forwarded to generate().

        Returns:
            The generated answer, or an error-description string on failure.
        """
        if not system_prompt:
            system_prompt = """๋น์ ์ ๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ง๋ฌธ์ ๋ต๋ณํ๋ ๋์ฐ๋ฏธ์ ๋๋ค.
- ๊ฒ์ ๊ฒฐ๊ณผ๋ <context> ํ๊ทธ ์์ ์ ๊ณต๋ฉ๋๋ค.
- ๊ฒ์ ๊ฒฐ๊ณผ์ ๋ต๋ณ์ด ์์ผ๋ฉด ํด๋น ์ ๋ณด๋ฅผ ์ฌ์ฉํ์ฌ ๋ช ํํ๊ฒ ๋ต๋ณํ์ธ์.
- ๊ฒ์ ๊ฒฐ๊ณผ์ ๋ต๋ณ์ด ์์ผ๋ฉด "๊ฒ์ ๊ฒฐ๊ณผ์ ๊ด๋ จ ์ ๋ณด๊ฐ ์์ต๋๋ค"๋ผ๊ณ ๋งํ์ธ์.
- ๊ฒ์ ๋ด์ฉ์ ๊ทธ๋๋ก ๋ณต์ฌํ์ง ๋ง๊ณ , ์์ฐ์ค๋ฌ์ด ํ๊ตญ์ด๋ก ๋ต๋ณ์ ์์ฑํ์ธ์.
- ๋ต๋ณ์ ๊ฐ๊ฒฐํ๊ณ ์ ํํ๊ฒ ์ ๊ณตํ์ธ์."""
        # Number each retrieved document so the model can reference them.
        context_text = "\n\n".join([f"๋ฌธ์ {i+1}: {doc}" for i, doc in enumerate(context)])
        prompt = f"""์ง๋ฌธ: {query}
<context>
{context_text}
</context>
์ ๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ์ฐธ๊ณ ํ์ฌ ์ง๋ฌธ์ ๋ต๋ณํด ์ฃผ์ธ์."""
        try:
            return self.generate(
                prompt=prompt,
                system_prompt=system_prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )
        except Exception as e:
            # generate() already converts errors to strings; this guard only
            # catches failures raised before/around that call.
            logger.error(f"RAG ํ ์คํธ ์์ฑ ์ค ์ค๋ฅ ๋ฐ์: {e}")
            return f"์ค๋ฅ: {str(e)}"