import os
from abc import ABC, abstractmethod
from typing import Optional, Type

from google import genai
from google.genai import types
from pydantic import BaseModel


class LLMClient(ABC):
    """
    Abstract base class for calling LLM APIs.
    """

    def __init__(self, config: Optional[dict] = None):
        """
        Initializes the LLMClient with a configuration dictionary.

        Args:
            config (dict, optional): Configuration settings for the LLM client.
        """
        self.config = config or {}

    @abstractmethod
    def call_api(self, prompt: str) -> str:
        """
        Call the underlying LLM API with the given prompt.

        Args:
            prompt (str): The prompt or input text for the LLM.

        Returns:
            str: The response from the LLM.
        """
        pass
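

# A minimal concrete subclass (illustrative addition, not part of the original
# module): an offline stub that satisfies the LLMClient interface, handy for
# unit tests that should not call a real API.
class EchoLLMClient(LLMClient):
    """Stub client that returns the prompt unchanged."""

    def call_api(self, prompt: str) -> str:
        return prompt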


class GeminiLLMClient(LLMClient):
    """
    Concrete implementation of LLMClient for the Gemini API.
    """

    def __init__(self, config: dict):
        """
        Initializes the GeminiLLMClient with an API key, model name, and
        optional generation settings.

        Args:
            config (dict): Configuration containing:
                - 'api_key': (optional) API key for Gemini (falls back to the
                  GEMINI_API_KEY environment variable)
                - 'model_name': (optional) the model to use
                  (default 'gemini-2.0-flash')
                - 'generation_config': (optional) dict of GenerateContentConfig
                  parameters
        """
        super().__init__(config)  # store the config on the base class
        api_key = config.get("api_key") or os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError(
                "API key for Gemini must be provided in config['api_key'] "
                "or the GEMINI_API_KEY environment variable."
            )
        self.client = genai.Client(api_key=api_key)
        self.model_name = config.get("model_name", "gemini-2.0-flash")
        # Allow custom generation settings; fields left as None fall back to
        # the SDK defaults.
        gen_conf = config.get("generation_config", {})
        self.generate_config = types.GenerateContentConfig(
            response_mime_type=gen_conf.get("response_mime_type", "text/plain"),
            temperature=gen_conf.get("temperature"),
            max_output_tokens=gen_conf.get("max_output_tokens"),
            top_p=gen_conf.get("top_p"),
            top_k=gen_conf.get("top_k"),
            # Add any other GenerateContentConfig fields you want to expose.
        )

    def call_api(self, prompt: str) -> str:
        """
        Call the Gemini API with the given prompt (non-streaming).

        Args:
            prompt (str): The input text for the API.

        Returns:
            str: The generated text from the Gemini API.
        """
        contents = [
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=prompt)],
            )
        ]
        # A non-streaming call returns a full response object.
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=self.generate_config,
        )
        # response.text joins the text parts of the response into one string.
        return response.text
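
# Usage sketch (comments only; the config values shown here are illustrative,
# they simply exercise the keys read in __init__ above):
#
#   client = GeminiLLMClient({
#       "model_name": "gemini-2.0-flash",
#       "generation_config": {"temperature": 0.2, "max_output_tokens": 512},
#   })
#   print(client.call_api("Summarize the Gemini API in one sentence."))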


class AIExtractor:
    def __init__(self, llm_client: LLMClient, prompt_template: str):
        """
        Initializes the AIExtractor with a specific LLM client and prompt template.

        Args:
            llm_client (LLMClient): An instance of a class that implements the
                LLMClient interface.
            prompt_template (str): The template used to generate prompts for the
                LLM. It should contain placeholders for dynamic content, e.g.,
                "Extract the following information: {content} based on schema: {schema}"
        """
        self.llm_client = llm_client
        self.prompt_template = prompt_template

    def extract(self, content: str, schema: Type[BaseModel]) -> str:
        """
        Extracts structured information from the given content based on the
        provided schema.

        Args:
            content (str): The raw content to extract information from.
            schema (Type[BaseModel]): A Pydantic model class defining the
                structure of the expected output.

        Returns:
            str: The structured JSON object as a string.
        """
        prompt = self.prompt_template.format(
            content=content, schema=schema.model_json_schema()
        )
        return self.llm_client.call_api(prompt)
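

# End-to-end example (illustrative: the `Article` schema and sample content
# are hypothetical stand-ins, not part of the original module). Requires
# GEMINI_API_KEY in the environment, or an 'api_key' entry in the config.
if __name__ == "__main__":
    class Article(BaseModel):
        title: str
        author: str

    extractor = AIExtractor(
        llm_client=GeminiLLMClient({"model_name": "gemini-2.0-flash"}),
        prompt_template=(
            "Extract the following information: {content} "
            "based on schema: {schema}"
        ),
    )
    print(extractor.extract("'Attention Is All You Need' by Ashish Vaswani et al.", Article))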