import os
from abc import ABC, abstractmethod
from typing import Optional

from google import genai
from google.genai import types
from pydantic import BaseModel

class LLMClient(ABC):
    """
    Abstract base class for calling LLM APIs.
    """
    def __init__(self, config: Optional[dict] = None):
        """
        Initializes the LLMClient with a configuration dictionary.

        Args:
            config (Optional[dict]): Configuration settings for the LLM client.
                Defaults to an empty dict when omitted.
        """
        self.config = config or {}

    @abstractmethod
    def call_api(self, prompt: str) -> str:
        """
        Call the underlying LLM API with the given prompt.
        
        Args:
            prompt (str): The prompt or input text for the LLM.

        Returns:
            str: The response from the LLM.
        """
        pass


class GeminiLLMClient(LLMClient):
    """
    Concrete implementation of LLMClient for the Gemini API.
    """

    def __init__(self, config: dict):
        """
        Initializes the GeminiLLMClient with an API key, model name, and optional generation settings.

        Args:
            config (dict): Configuration containing:
                - 'api_key': (optional) API key for Gemini (falls back to GEMINI_API_KEY env var)
                - 'model_name': (optional) the model to use (default 'gemini-2.0-flash')
                - 'generation_config': (optional) dict of GenerateContentConfig parameters
        """
        super().__init__(config)
        api_key = config.get("api_key") or os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError(
                "API key for Gemini must be provided in config['api_key'] or the GEMINI_API_KEY env var."
            )
        self.client = genai.Client(api_key=api_key)
        self.model_name = config.get("model_name", "gemini-2.0-flash")
        # Allow custom generation settings, falling back to sensible defaults.
        gen_conf = config.get("generation_config", {})
        self.generate_config = types.GenerateContentConfig(
            response_mime_type=gen_conf.get("response_mime_type", "text/plain"),
            temperature=gen_conf.get("temperature"),
            max_output_tokens=gen_conf.get("max_output_tokens"),
            top_p=gen_conf.get("top_p"),
            top_k=gen_conf.get("top_k"),
            # Any other GenerateContentConfig fields can be exposed here as well.
        )

    def call_api(self, prompt: str) -> str:
        """
        Call the Gemini API with the given prompt (non-streaming).

        Args:
            prompt (str): The input text for the API.

        Returns:
            str: The generated text from the Gemini API.
        """
        contents = [
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=prompt)],
            )
        ]

        # Non-streaming call returns a full response object
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=self.generate_config,
        )

        # response.text concatenates the text of all returned parts into one string.
        return response.text

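# A minimal usage sketch for GeminiLLMClient (an illustration, not part of the
# module; assumes GEMINI_API_KEY is set or "api_key" is passed in the config):
#
#   client = GeminiLLMClient({
#       "model_name": "gemini-2.0-flash",
#       "generation_config": {"temperature": 0.2, "max_output_tokens": 1024},
#   })
#   print(client.call_api("Say hello."))
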
class AIExtractor:
    def __init__(self, llm_client: LLMClient, prompt_template: str):
        """
        Initializes the AIExtractor with a specific LLM client and configuration.

        Args:
            llm_client (LLMClient): An instance of a class that implements the LLMClient interface.
            prompt_template (str): The template to use for generating prompts for the LLM.
            should contain placeholders for dynamic content. 
            e.g., "Extract the following information: {content} based on schema: {schema}"
        """
        self.llm_client = llm_client
        self.prompt_template = prompt_template

    def extract(self, content: str, schema: type[BaseModel]) -> str:
        """
        Extracts structured information from the given content based on the provided schema.

        Args:
            content (str): The raw content to extract information from.
            schema (type[BaseModel]): A Pydantic model class defining the structure of the expected output.

        Returns:
            str: The structured JSON object as a string.
        """
        prompt = self.prompt_template.format(content=content, schema=schema.model_json_schema())
        response = self.llm_client.call_api(prompt)
        return response
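

# A minimal end-to-end sketch. "Article" is a hypothetical schema defined here
# for illustration only; it assumes the GEMINI_API_KEY environment variable is set.
if __name__ == "__main__":
    class Article(BaseModel):
        title: str
        author: str

    extractor = AIExtractor(
        llm_client=GeminiLLMClient(
            {"generation_config": {"response_mime_type": "application/json"}}
        ),
        prompt_template=(
            "Extract the following information: {content} based on schema: {schema}"
        ),
    )
    # The extractor formats the prompt with the content and the schema's JSON
    # Schema, then returns the model's raw string response.
    print(extractor.extract("'Hello, World' was written by Jane Doe.", Article))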