File size: 3,801 Bytes
fa529c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41f63c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import re
import string
from dataclasses import dataclass, field
from typing import Dict, List, Set, Optional

import yaml

@dataclass
class PromptTemplate:
    """
    A template class for managing and validating LLM prompts.

    This class handles:
    - Storing system and user prompts
    - Validating required template variables
    - Formatting prompts with provided variables

    Placeholders in ``user_template`` follow ``str.format`` syntax. Escaped
    braces (``{{`` / ``}}``) are literal text. Fields may carry conversions,
    format specs, attribute access or indexing (``{name!r}``,
    ``{count:>3}``, ``{user.name}``, ``{items[0]}``).

    Attributes:
        system_prompt (str): The system-level instructions for the LLM
        user_template (str): Template string with variables in {variable} format
        required_variables (Set[str]): Top-level variable names referenced by
            ``user_template``; computed once in ``__post_init__``.
    """
    system_prompt: str
    user_template: str

    def __post_init__(self):
        """Initialize the set of required variables from the template."""
        self.required_variables: Set[str] = self._get_required_variables()

    def _get_required_variables(self) -> Set[str]:
        """
        Extract required variable names from the template.

        Uses ``string.Formatter().parse``, the same parser ``str.format``
        uses, instead of a regex. As a result:
        - Escaped braces such as ``{{literal}}`` are not mistaken for
          variables.
        - Fields with conversions, format specs, attribute access or
          indexing are still detected.
        - For ``{user.name}`` or ``{items[0]}`` only the root name
          (``user`` / ``items``) is required.
        - Fields nested inside a format spec (``{value:{width}}``) are
          collected as well.
        - Auto-numbered positional fields (``{}``) are ignored.

        Returns:
            Set[str]: Set of variable names found in the template

        Example:
            Template "Write about {topic} in {style}" returns {'topic', 'style'}
        """
        names: Set[str] = set()
        formatter = string.Formatter()

        def collect(template: str) -> None:
            for _text, field_name, format_spec, _conv in formatter.parse(template):
                if field_name:
                    # Keep only the root name: "user.name" -> "user",
                    # "items[0]" -> "items".
                    root = re.split(r'[.\[]', field_name, maxsplit=1)[0]
                    if root:
                        names.add(root)
                if format_spec:
                    # Format specs may themselves contain replacement fields.
                    collect(format_spec)

        try:
            collect(self.user_template)
        except ValueError:
            # Malformed template (e.g. a lone '}'). Keep whatever parsed so
            # construction still succeeds; format() reports the error as a
            # ValueError when the template is actually used.
            pass
        return names

    def _validate_variables(self, provided_vars: Dict):
        """
        Ensure all required template variables are provided.

        Args:
            provided_vars: Dictionary of variable names and values

        Raises:
            ValueError: If any required variables are missing
        """
        provided_keys = set(provided_vars.keys())
        missing_vars = self.required_variables - provided_keys
        if missing_vars:
            # Sort names so the message is identical across runs (set
            # iteration order of strings changes with hash randomization).
            error_msg = (
                f"\nPrompt Template Error:\n"
                f"Missing required variables: {', '.join(sorted(missing_vars))}\n"
                f"Template requires: {', '.join(sorted(self.required_variables))}\n"
                f"You provided: {', '.join(sorted(provided_keys))}\n"
                f"Template string: '{self.user_template}'"
            )
            raise ValueError(error_msg)

    def format(self, **kwargs) -> List[Dict[str, str]]:
        """
        Format the prompt template with provided variables.

        Args:
            **kwargs: Key-value pairs for template variables

        Returns:
            List[Dict[str, str]]: A system message followed by the formatted
            user message, each a ``{"role": ..., "content": ...}`` dict.

        Raises:
            ValueError: If required variables are missing, or if
                ``str.format`` fails (malformed template, bad format spec,
                missing attribute/index). The underlying exception is chained
                as ``__cause__``.

        Example:
            template.format(topic="AI", style="academic")
        """
        self._validate_variables(kwargs)

        try:
            formatted_user_message = self.user_template.format(**kwargs)
        except Exception as e:
            # Broad on purpose: user-supplied values may raise anything from
            # __format__/__getattr__/__getitem__. Callers get a single
            # ValueError, with the original cause preserved for debugging.
            raise ValueError(f"Error formatting template: {str(e)}") from e

        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": formatted_user_message}
        ]


def load_prompt(yaml_path: str, version: Optional[str] = None) -> tuple[PromptTemplate, dict]:
    """
    Load prompt configuration from YAML file.

    The file's top level must be a mapping. Each version name maps to an
    entry containing ``system_prompt`` and ``user_template``, and optionally
    ``generation_params``. An optional top-level ``current_version`` names
    the default entry.

    Args:
        yaml_path: Path to YAML configuration file (read as UTF-8)
        version: Specific version to load (defaults to 'current_version')

    Returns:
        tuple: (PromptTemplate instance, generation parameters dictionary).
        The dictionary is empty when ``generation_params`` is absent or null.

    Raises:
        ValueError: If the file's top level is not a mapping (e.g. an empty
            file), or the selected version entry is not a mapping.
        KeyError: If no version is given and ``current_version`` is not set,
            if the requested version is not present, or if the entry lacks
            ``system_prompt`` / ``user_template``.

    Example:
        prompt, params = load_prompt('prompts.yaml', version='v2')
    """
    # Explicit encoding: prompt text is frequently non-ASCII, and the
    # platform default (e.g. cp1252 on Windows) would mis-decode it.
    with open(yaml_path, 'r', encoding='utf-8') as f:
        # safe_load only builds plain YAML types, never arbitrary objects.
        data = yaml.safe_load(f)

    # An empty file loads as None; a scalar or list top level is unusable too.
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a mapping of prompt versions in {yaml_path}, "
            f"got {type(data).__name__}"
        )

    # Use specified version or fall back to current_version
    version_to_use = version or data.get('current_version')
    if version_to_use is None:
        raise KeyError(
            f"No version specified and 'current_version' is not set in {yaml_path}"
        )
    if version_to_use not in data:
        raise KeyError(f"Version '{version_to_use}' not found in {yaml_path}")

    version_data = data[version_to_use]
    if not isinstance(version_data, dict):
        # e.g. version='current_version' selects the pointer string, not an entry.
        raise ValueError(
            f"Version '{version_to_use}' in {yaml_path} is not a mapping "
            f"(got {type(version_data).__name__})"
        )

    prompt = PromptTemplate(
        system_prompt=version_data['system_prompt'],
        user_template=version_data['user_template']
    )

    # `generation_params:` with no value loads as None; normalize to {} so
    # callers can always unpack it with **params.
    generation_params = version_data.get('generation_params') or {}

    return prompt, generation_params