File size: 4,944 Bytes
6575706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
LLM ์ธํ„ฐํŽ˜์ด์Šค ๋ชจ๋“ˆ - ๋‹ค์–‘ํ•œ LLM์„ ํ†ตํ•ฉ ๊ด€๋ฆฌ
"""

import os
import logging
from typing import List, Dict, Any, Optional, Union
from dotenv import load_dotenv

# LLM ํด๋ผ์ด์–ธํŠธ ์ž„ํฌํŠธ
from utils.openai_client import OpenAILLM
from utils.deepseek_client import DeepSeekLLM

# ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ๋กœ๋“œ
load_dotenv()

# ๋กœ๊ฑฐ ์„ค์ •
logger = logging.getLogger("LLMInterface")
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

class LLMInterface:
    """Facade that manages multiple LLM API clients behind one interface.

    Holds one client instance per supported backend and routes generation
    requests to either the currently selected client or an explicitly
    requested one.
    """

    # Supported LLMs: display name (shown in the UI) -> internal identifier.
    SUPPORTED_LLMS = {
        "OpenAI": "openai",
        "DeepSeek": "deepseek"
    }

    def __init__(self, default_llm: str = "openai"):
        """Initialize the LLM interface.

        Args:
            default_llm: Default LLM identifier ('openai' or 'deepseek').
                Falls back to 'openai' when the value is not supported.
        """
        # One client instance per supported backend.
        self.llm_clients = {
            "openai": OpenAILLM(),
            "deepseek": DeepSeekLLM()
        }

        # Validate the requested default; fall back to 'openai' otherwise.
        if default_llm not in self.llm_clients:
            logger.warning(f"지정된 기본 LLM '{default_llm}'가 유효하지 않습니다. 'openai'로 설정됩니다.")
            default_llm = "openai"

        self.default_llm = default_llm
        self.current_llm = default_llm

        logger.info(f"LLM 인터페이스 초기화 완료, 기본 LLM: {default_llm}")

    def set_llm(self, llm_id: str) -> bool:
        """Select the LLM used for subsequent requests.

        Args:
            llm_id: LLM identifier.

        Returns:
            True on success, False when the identifier is unsupported.
        """
        if llm_id not in self.llm_clients:
            logger.error(f"지원되지 않는 LLM 식별자: {llm_id}")
            return False

        self.current_llm = llm_id
        logger.info(f"현재 LLM이 '{llm_id}'로 설정되었습니다.")
        return True

    def get_current_llm_name(self) -> str:
        """Return the display name of the current LLM ("Unknown" if unmapped)."""
        # Reverse lookup: display name -> identifier. Loop variable renamed
        # so it no longer shadows the builtin `id`.
        for name, llm_id in self.SUPPORTED_LLMS.items():
            if llm_id == self.current_llm:
                return name
        return "Unknown"

    def get_current_llm_details(self) -> Dict[str, str]:
        """Return display name, identifier and model of the current LLM."""
        # getattr generalizes over any client exposing a `.model` attribute,
        # replacing one hard-coded branch per backend; yields "" when the
        # client is missing or has no model attribute (same as before).
        model = getattr(self.llm_clients.get(self.current_llm), "model", "")

        return {
            "name": self.get_current_llm_name(),
            "id": self.current_llm,
            "model": model
        }

    def _resolve_client(self, llm_id: Optional[str]):
        """Return (identifier, client), preferring a valid explicit id over the current LLM."""
        llm_to_use = llm_id if llm_id and llm_id in self.llm_clients else self.current_llm
        return llm_to_use, self.llm_clients[llm_to_use]

    def generate(
        self, 
        prompt: str, 
        system_prompt: Optional[str] = None,
        llm_id: Optional[str] = None,
        **kwargs
    ) -> str:
        """Generate text.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt.
            llm_id: LLM identifier to use (current LLM when omitted or invalid).
            **kwargs: Extra generation arguments (temperature, max_tokens, ...).

        Returns:
            The generated text.
        """
        # Shared resolution logic lives in _resolve_client (also used by
        # rag_generate) instead of being duplicated inline.
        llm_to_use, llm_client = self._resolve_client(llm_id)

        logger.info(f"텍스트 생성 요청, LLM: {llm_to_use}")

        return llm_client.generate(
            prompt=prompt,
            system_prompt=system_prompt,
            **kwargs
        )

    def rag_generate(
        self, 
        query: str, 
        context: List[str],
        llm_id: Optional[str] = None,
        **kwargs
    ) -> str:
        """Generate text grounded in retrieved context (RAG).

        Args:
            query: User query.
            context: Retrieved context passages.
            llm_id: LLM identifier to use (current LLM when omitted or invalid).
            **kwargs: Extra generation arguments (temperature, max_tokens, ...).

        Returns:
            The generated text.
        """
        llm_to_use, llm_client = self._resolve_client(llm_id)

        logger.info(f"RAG 텍스트 생성 요청, LLM: {llm_to_use}")

        return llm_client.rag_generate(
            query=query,
            context=context,
            **kwargs
        )