John-Jiang's picture
init commit
5301c48
import asyncio
import math
from typing import Any, Dict, List, Optional, Union
from starfish import StructuredLLM
async def generate_topics(
    user_instruction: str,
    num_topics: int,
    model_name: str = "openai/gpt-4o-mini",
    model_kwargs: Optional[Dict[str, Any]] = None,
    existing_topics: Optional[List[str]] = None,
) -> List[str]:
    """Generate unique topics based on user instructions using a StructuredLLM model.

    Topics are requested in batches of five. Each request is told which topics
    already exist (caller-supplied plus those generated so far) so the model
    can avoid repeating them.

    Args:
        user_instruction: Subject the topics should be about.
        num_topics: Maximum number of topics to return; values <= 0 yield [].
        model_name: Model identifier handed to StructuredLLM.
        model_kwargs: Extra model options. The dict is copied, never mutated;
            "temperature" defaults to 1 when the caller did not set it.
        existing_topics: Topics that must not be generated again.

    Returns:
        At most ``num_topics`` distinct, non-empty topic strings, none of which
        appear in ``existing_topics``. Fewer may be returned if the model
        yields too few usable topics within the planned number of batches.
    """
    # Nothing to do: return before touching any argument or building the LLM.
    if num_topics <= 0:
        return []

    # Copy so the default temperature never leaks into the caller's dict.
    model_kwargs = dict(model_kwargs) if model_kwargs else {}
    model_kwargs.setdefault("temperature", 1)
    existing_topics = list(existing_topics) if existing_topics else []

    # Calculate batches needed (5 topics per batch)
    llm_batch_size = 5
    num_batches = math.ceil(num_topics / llm_batch_size)

    # The generator configuration is identical for every batch, so build it once.
    topic_generator = StructuredLLM(
        model_name=model_name,
        prompt="""Can you generate a list of topics about {{user_instruction}}
{% if existing_topics_str %}
Please do not generate topics that are already in the list: {{existing_topics_str}}
Make sure the topics are unique and vary from each other
{% endif %}
""",
        output_schema=[{"name": "topic", "type": "str"}],
        model_kwargs=model_kwargs,
    )

    generated_topics: List[str] = []
    seen_topics = set(existing_topics)
    for _ in range(num_batches):
        input_params = {
            "user_instruction": user_instruction,
            "num_records": min(llm_batch_size, num_topics - len(generated_topics)),
        }
        all_existing = existing_topics + generated_topics
        if all_existing:
            input_params["existing_topics_str"] = ",".join(all_existing)

        topic_response = await topic_generator.run(**input_params)
        for item in topic_response.data:
            topic = item.get("topic")
            # Drop missing/empty values and anything already known, so the
            # "unique topics" contract holds even if the model repeats itself.
            if not isinstance(topic, str) or not topic or topic in seen_topics:
                continue
            seen_topics.add(topic)
            generated_topics.append(topic)

        if len(generated_topics) >= num_topics:
            break

    # A batch may return more records than requested; never exceed num_topics.
    return generated_topics[:num_topics]
def _normalize_topic_inputs(topics: Any) -> List[tuple]:
    """Validate raw topic inputs and convert them to ``(name, count)`` pairs.

    String topics get a count of ``None`` (to be calculated later); single-key
    dict topics keep their explicit count. A name seen more than once keeps
    only its first occurrence.
    """
    if not isinstance(topics, list) or not topics:
        raise ValueError("topics must be a non-empty list")

    normalized = []
    seen_topics = set()
    for topic in topics:
        if isinstance(topic, str):
            if topic not in seen_topics:
                normalized.append((topic, None))
                seen_topics.add(topic)
        elif isinstance(topic, dict) and len(topic) == 1:
            topic_name = next(iter(topic))
            count = topic[topic_name]
            # The count is validated even for a repeated name, so bad input is
            # always reported.
            if not isinstance(count, int) or count < 0:
                raise ValueError(f"Topic '{topic_name}' has invalid count {count}")
            if topic_name not in seen_topics:
                normalized.append((topic_name, count))
                seen_topics.add(topic_name)
        else:
            raise ValueError("Topics must be strings or single-key dictionaries")
    return normalized


async def prepare_topic(
    topics: Optional[List[Union[str, Dict[str, int]]]] = None,
    num_records: Optional[int] = None,
    records_per_topic: int = 20,
    user_instruction: Optional[str] = None,
    model_name: str = "openai/gpt-4o-mini",
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, str]]:
    """Split records into topics, generating topics if none are provided or if needed.

    Supported input formats:
    1. String list: ['topic1', 'topic2'] - Topics with equal or calculated distribution
    2. Dict list: [{'topic1': 20}, {'topic2': 30}] - Topics with specific counts
    3. Mixed: ['topic1', {'topic2': 30}] - Combination of both formats
    4. None: No topics provided, will generate based on user_instruction

    Args:
        topics: Optional list of topics, either strings or {topic: count} dicts
        num_records: Total number of records to split (required for dict topics or None topics)
        records_per_topic: Number of records per topic (default: 20)
        user_instruction: Topic generation instructions (required if topics is None)
        model_name: Model name for topic generation
        model_kwargs: Model kwargs for topic generation. The dict is copied,
            never mutated; "temperature" defaults to 1 when not set.

    Returns:
        List of {'topic': topic_name} dictionaries, with one entry per record.
        Every entry is its own dict object, so callers may mutate records
        independently.

    Raises:
        ValueError: If the inputs are inconsistent: missing or non-positive
            num_records / missing user_instruction when topics is None, a
            malformed topics list, dict topics without num_records, counts
            exceeding num_records, or a non-positive records_per_topic when
            additional topics have to be created.
    """
    # Copy so the default temperature never leaks into the caller's dict.
    model_kwargs = dict(model_kwargs) if model_kwargs else {}
    model_kwargs.setdefault("temperature", 1)

    # --- STEP 1: Input validation and normalization ---
    if topics is None:
        # Must have num_records and user_instruction if no topics provided
        if not num_records or num_records <= 0:
            raise ValueError("num_records must be positive when topics are not provided")
        if not user_instruction:
            raise ValueError("user_instruction required when topics are not provided")
        topic_assignments = []
    else:
        topic_assignments = _normalize_topic_inputs(topics)

    # --- STEP 2: Calculate or validate counts for provided topics ---
    result = []  # ordered list of {topic_name: record_count}
    assigned_count = 0
    topic_names = []  # Track all assigned topic names

    if topic_assignments:
        string_topics = [name for name, count in topic_assignments if count is None]
        dict_topics = [(name, count) for name, count in topic_assignments if count is not None]

        if string_topics and num_records is None:
            # No overall total: every string topic gets records_per_topic.
            for name in string_topics:
                result.append({name: records_per_topic})
                topic_names.append(name)
                assigned_count += records_per_topic
        elif string_topics:
            # An overall total exists: string topics share what the dict
            # topics leave over, as evenly as possible.
            remaining = num_records - sum(count for _, count in dict_topics)
            if remaining < 0:
                raise ValueError("Dict topic counts exceed num_records")
            if remaining > 0:
                base, extra = divmod(remaining, len(string_topics))
                for i, name in enumerate(string_topics):
                    # The first `extra` topics absorb the remainder.
                    count = base + (1 if i < extra else 0)
                    if count > 0:
                        result.append({name: count})
                        topic_names.append(name)
                        assigned_count += count

        # Add dictionary topics with predefined counts
        for name, count in dict_topics:
            if count > 0:
                result.append({name: count})
                topic_names.append(name)
                assigned_count += count

        # Validate total count for dictionary topics
        if dict_topics and num_records is None:
            raise ValueError("num_records required when using dictionary topics")
        if num_records is not None and assigned_count > num_records:
            raise ValueError(f"Total assigned count ({assigned_count}) exceeds num_records ({num_records})")

    # --- STEP 3: Generate topics for remaining records if needed ---
    remaining_records = 0 if num_records is None else num_records - assigned_count
    if remaining_records > 0:
        if records_per_topic <= 0:
            raise ValueError("records_per_topic must be positive when generating topics")

        # Generate topics with LLM if instructions provided
        if user_instruction:
            topics_needed = math.ceil(remaining_records / records_per_topic)
            generated = await generate_topics(
                user_instruction=user_instruction,
                num_topics=topics_needed,
                model_name=model_name,
                model_kwargs=model_kwargs,
                existing_topics=topic_names,
            )
            # Assign counts to generated topics
            for topic in generated:
                if topic in topic_names:  # Skip if duplicate (shouldn't happen with proper LLM)
                    print(f"Skipping duplicate generated topic: {topic}")
                    continue
                count = min(records_per_topic, remaining_records)
                if count <= 0:
                    break
                result.append({topic: count})
                topic_names.append(topic)
                remaining_records -= count
                assigned_count += count

        # Generate auto-topics for any still-remaining records (no instruction
        # given, or the LLM produced too few usable topics).
        auto_index = 1
        while remaining_records > 0:
            # Find next available auto_topic name
            auto_name = f"auto_topic{auto_index}"
            while auto_name in topic_names:
                auto_index += 1
                auto_name = f"auto_topic{auto_index}"
            count = min(records_per_topic, remaining_records)
            result.append({auto_name: count})
            topic_names.append(auto_name)
            remaining_records -= count
            assigned_count += count
            auto_index += 1

    # Final validation
    if num_records is not None and assigned_count != num_records:
        print(f"Warning: Assigned {assigned_count} records, expected {num_records}")

    # Expand {topic: count} into one record per count. Build a fresh dict for
    # every record: list repetition would alias a single dict `count` times.
    flatten_topic_list = []
    for item in result:
        for key, count in item.items():
            flatten_topic_list.extend({"topic": key} for _ in range(count))
    return flatten_topic_list
if __name__ == "__main__":
    # Demo driver. Each entry pairs the banner printed before a run with the
    # keyword arguments passed to prepare_topic for that run:
    #   1. dictionary topics with additional generation
    #   2. string topics distributed evenly over num_records
    #   3. mixed string and dict topics
    #   4. string topics with a fixed per-topic count
    #   5. no topics at all, everything generated
    demo_runs = [
        (
            "\nExample 1: Dictionary topics + generation",
            {
                "topics": [{"topic1": 20}, {"topic2": 30}],
                "num_records": 100,
                "records_per_topic": 25,
                "user_instruction": "some context",
            },
        ),
        (
            "\nExample 2: String topics with distribution",
            {"topics": ["topicA", "topicB", "topicC"], "num_records": 10},
        ),
        (
            "\nExample 3: Mixed string/dict topics",
            {"topics": ["topicX", {"topicY": 10}], "num_records": 30, "user_instruction": "mixed topics"},
        ),
        (
            "\nExample 4: String topics with fixed count",
            {"topics": ["apple", "banana", "cherry"], "records_per_topic": 15},
        ),
        (
            "\nExample 5: No topics, generate all",
            {"topics": None, "num_records": 10, "records_per_topic": 5, "user_instruction": "cloud computing"},
        ),
    ]

    print("--- Running Examples ---")
    for banner, call_kwargs in demo_runs:
        print(banner)
        records = asyncio.run(prepare_topic(**call_kwargs))
        print(f"Result: {records}")
        print(f"Total: {len(records)}")
    print("\n--- Examples Finished ---")