File size: 11,328 Bytes
5301c48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import asyncio
import math
from typing import Any, Dict, List, Optional, Union

from starfish import StructuredLLM


async def generate_topics(
    user_instruction: str,
    num_topics: int,
    model_name: str = "openai/gpt-4o-mini",
    model_kwargs: Optional[Dict[str, Any]] = None,
    existing_topics: Optional[List[str]] = None,
) -> List[str]:
    """Generate up to ``num_topics`` unique topics for ``user_instruction`` using a StructuredLLM.

    Topics are requested in batches of at most five. Each request tells the model
    which topics are already taken (``existing_topics`` plus everything generated
    so far) and asks it not to repeat them.

    Args:
        user_instruction: Subject matter inserted into the prompt.
        num_topics: Number of topics wanted; values <= 0 yield an empty list.
        model_name: Model identifier passed to StructuredLLM.
        model_kwargs: Extra model arguments. The dict is copied, never modified;
            ``temperature`` defaults to 1 when the caller does not set it.
        existing_topics: Topics that must not be generated again.

    Returns:
        At most ``num_topics`` topic strings, without duplicates and without any
        entry from ``existing_topics``. The list can be shorter than requested
        when the model returns too few usable topics within the batch budget.
    """
    if num_topics <= 0:
        return []

    # Copy before applying the default so the caller's dict is left untouched.
    llm_kwargs = dict(model_kwargs) if model_kwargs else {}
    llm_kwargs.setdefault("temperature", 1)
    known_topics = list(existing_topics) if existing_topics else []

    # Calculate batches needed (5 topics per batch)
    llm_batch_size = 5
    num_batches = math.ceil(num_topics / llm_batch_size)

    # The generator configuration is identical for every batch, so build it once.
    topic_generator = StructuredLLM(
        model_name=model_name,
        prompt="""Can you generate a list of topics about {{user_instruction}}
                  {% if existing_topics_str %}
                  Please do not generate topics that are already in the list: {{existing_topics_str}}
                  Make sure the topics are unique and vary from each other
                  {% endif %}
                """,
        output_schema=[{"name": "topic", "type": "str"}],
        model_kwargs=llm_kwargs,
    )

    generated_topics: List[str] = []
    seen = set(known_topics)

    for _ in range(num_batches):
        input_params = {
            "user_instruction": user_instruction,
            "num_records": min(llm_batch_size, num_topics - len(generated_topics)),
        }

        all_existing = known_topics + generated_topics
        if all_existing:
            input_params["existing_topics_str"] = ",".join(all_existing)

        topic_response = await topic_generator.run(**input_params)
        for item in topic_response.data:
            topic = item.get("topic")
            # Drop records without a topic and anything already taken.
            if not topic or topic in seen:
                continue
            seen.add(topic)
            generated_topics.append(topic)

        if len(generated_topics) >= num_topics:
            break

    # A batch may return more records than requested; never exceed num_topics.
    return generated_topics[:num_topics]


async def prepare_topic(
    topics: Optional[List[Union[str, Dict[str, int]]]] = None,
    num_records: Optional[int] = None,
    records_per_topic: int = 20,
    user_instruction: Optional[str] = None,
    model_name: str = "openai/gpt-4o-mini",
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, str]]:
    """Split records into topics, generating topics if none are provided or if needed.

    Supported input formats:
    1. String list: ['topic1', 'topic2'] - Topics with equal or calculated distribution
    2. Dict list: [{'topic1': 20}, {'topic2': 30}] - Topics with specific counts
    3. Mixed: ['topic1', {'topic2': 30}] - Combination of both formats
    4. None: No topics provided, will generate based on user_instruction

    Allocation rules:
    - Repeated topic names are ignored after their first occurrence.
    - String topics without ``num_records`` each receive ``records_per_topic`` records.
    - String topics with ``num_records`` share whatever the dict topics leave over,
      as evenly as possible (earlier topics receive the extra records).
    - Records still unassigned afterwards go to LLM-generated topics when
      ``user_instruction`` is given, and finally to ``auto_topicN`` placeholders.

    Args:
        topics: Optional list of topics, either strings or {topic: count} dicts
        num_records: Total number of records to split (required for dict topics or None topics)
        records_per_topic: Number of records per topic (default: 20)
        user_instruction: Topic generation instructions (required if topics is None)
        model_name: Model name for topic generation
        model_kwargs: Model kwargs for topic generation. The dict is copied, never
            modified; ``temperature`` defaults to 1 when not set.

    Returns:
        List of {'topic': topic_name} dictionaries, with one entry per record.
        Every entry is a separate dict object, so callers may modify records
        independently.

    Raises:
        ValueError: If the inputs are inconsistent: missing ``num_records`` or
            ``user_instruction`` when ``topics`` is None, an empty or malformed
            ``topics`` list, an invalid count, dict topics without ``num_records``,
            counts exceeding ``num_records``, or a non-positive
            ``records_per_topic`` when more topics have to be created.
    """
    # Copy before applying the default so the caller's dict is left untouched.
    llm_kwargs = dict(model_kwargs) if model_kwargs else {}
    llm_kwargs.setdefault("temperature", 1)

    # --- STEP 1: Input validation and normalization ---
    # topic_assignments holds (topic_name, count) pairs; count is None for
    # string topics, whose share is calculated in step 2.
    topic_assignments = []
    if topics is None:
        # Must have num_records and user_instruction if no topics provided
        if not num_records or num_records <= 0:
            raise ValueError("num_records must be positive when topics are not provided")
        if not user_instruction:
            raise ValueError("user_instruction required when topics are not provided")
    else:
        if not isinstance(topics, list) or not topics:
            raise ValueError("topics must be a non-empty list")

        seen_topics = set()
        for topic in topics:
            if isinstance(topic, str):
                topic_name, count = topic, None
            elif isinstance(topic, dict) and len(topic) == 1:
                topic_name, count = next(iter(topic.items()))
                if not isinstance(count, int) or count < 0:
                    raise ValueError(f"Topic '{topic_name}' has invalid count {count}")
            else:
                raise ValueError("Topics must be strings or single-key dictionaries")

            if topic_name not in seen_topics:
                seen_topics.add(topic_name)
                topic_assignments.append((topic_name, count))

    # --- STEP 2: Calculate or validate counts for provided topics ---
    # Insertion-ordered {topic_name: record_count}. Names are unique by
    # construction, so the dict doubles as an O(1) "name already used" lookup.
    allocation: Dict[str, int] = {}

    if topic_assignments:
        string_names = [name for name, count in topic_assignments if count is None]
        dict_topics = [(name, count) for name, count in topic_assignments if count is not None]

        if dict_topics and num_records is None:
            raise ValueError("num_records required when using dictionary topics")

        if string_names:
            if num_records is None:
                # No total given: every string topic gets the fixed per-topic amount.
                for name in string_names:
                    allocation[name] = records_per_topic
            else:
                remaining = num_records - sum(count for _, count in dict_topics)
                if remaining < 0:
                    raise ValueError("Dict topic counts exceed num_records")

                # Spread the leftover evenly; the first `extra` topics get one more.
                base, extra = divmod(remaining, len(string_names))
                for i, name in enumerate(string_names):
                    count = base + (1 if i < extra else 0)
                    if count > 0:
                        allocation[name] = count

        # Dictionary topics keep their predefined counts (zero-count topics are dropped).
        for name, count in dict_topics:
            if count > 0:
                allocation[name] = count

        assigned_count = sum(allocation.values())
        if num_records is not None and assigned_count > num_records:
            raise ValueError(f"Total assigned count ({assigned_count}) exceeds num_records ({num_records})")

    # --- STEP 3: Generate topics for remaining records if needed ---
    assigned_count = sum(allocation.values())
    remaining_records = 0 if num_records is None else num_records - assigned_count

    if remaining_records > 0:
        if records_per_topic <= 0:
            raise ValueError("records_per_topic must be positive when generating topics")

        # Generate topics with LLM if instructions provided
        if user_instruction:
            topics_needed = math.ceil(remaining_records / records_per_topic)

            generated = await generate_topics(
                user_instruction=user_instruction,
                num_topics=topics_needed,
                model_name=model_name,
                model_kwargs=llm_kwargs,
                existing_topics=list(allocation),
            )

            for topic in generated:
                if not topic:
                    # A response record without a usable topic; the auto-topic
                    # fallback below covers any records left unassigned.
                    continue
                if topic in allocation:  # Skip if duplicate (shouldn't happen with proper LLM)
                    print(f"Skipping duplicate generated topic: {topic}")
                    continue

                count = min(records_per_topic, remaining_records)
                if count <= 0:
                    break

                allocation[topic] = count
                remaining_records -= count

        # Generate auto-topics for any still-remaining records
        auto_index = 1
        while remaining_records > 0:
            # Find next available auto_topic name
            auto_name = f"auto_topic{auto_index}"
            while auto_name in allocation:
                auto_index += 1
                auto_name = f"auto_topic{auto_index}"

            count = min(records_per_topic, remaining_records)
            allocation[auto_name] = count
            remaining_records -= count
            auto_index += 1

    # Final validation
    assigned_count = sum(allocation.values())
    if num_records is not None and assigned_count != num_records:
        print(f"Warning: Assigned {assigned_count} records, expected {num_records}")

    # Build a fresh dict per record: `[{"topic": name}] * count` would repeat one
    # shared dict object, so editing a single record would change all of them.
    return [{"topic": name} for name, count in allocation.items() for _ in range(count)]


if __name__ == "__main__":
    print("--- Running Examples ---")

    # Each entry pairs the header line to print with the keyword arguments for
    # prepare_topic; entries run in order, each in its own event loop.
    demo_cases = [
        # Example 1: Dictionary topics with additional generation
        (
            "\nExample 1: Dictionary topics + generation",
            {"topics": [{"topic1": 20}, {"topic2": 30}], "num_records": 100, "records_per_topic": 25, "user_instruction": "some context"},
        ),
        # Example 2: String topics with even distribution
        (
            "\nExample 2: String topics with distribution",
            {"topics": ["topicA", "topicB", "topicC"], "num_records": 10},
        ),
        # Example 3: Mixed string and dict topics
        (
            "\nExample 3: Mixed string/dict topics",
            {"topics": ["topicX", {"topicY": 10}], "num_records": 30, "user_instruction": "mixed topics"},
        ),
        # Example 4: String topics with fixed count
        (
            "\nExample 4: String topics with fixed count",
            {"topics": ["apple", "banana", "cherry"], "records_per_topic": 15},
        ),
        # Example 5: No topics, generate all
        (
            "\nExample 5: No topics, generate all",
            {"topics": None, "num_records": 10, "records_per_topic": 5, "user_instruction": "cloud computing"},
        ),
    ]

    for header, call_kwargs in demo_cases:
        print(header)
        outcome = asyncio.run(prepare_topic(**call_kwargs))
        print(f"Result: {outcome}")
        print(f"Total: {len(outcome)}")

    print("\n--- Examples Finished ---")