Spaces:
Running
Running
import asyncio | |
import math | |
from typing import Any, Dict, List, Optional, Union | |
from starfish import StructuredLLM | |
async def generate_topics(
    user_instruction: str,
    num_topics: int,
    model_name: str = "openai/gpt-4o-mini",
    model_kwargs: Optional[Dict[str, Any]] = None,
    existing_topics: Optional[List[str]] = None,
) -> List[str]:
    """Generate up to ``num_topics`` unique topics with a StructuredLLM model.

    Topics are requested in batches of at most five. Each request tells the
    model which topics already exist (the caller-supplied ones plus those
    generated so far) and asks it not to repeat them.

    Args:
        user_instruction: Subject the topics should be about; rendered into
            the prompt.
        num_topics: Number of topics wanted. Values <= 0 yield an empty list
            without contacting the model.
        model_name: Model identifier forwarded to ``StructuredLLM``.
        model_kwargs: Extra model arguments. ``temperature`` defaults to 1
            when absent. The caller's dict is never modified.
        existing_topics: Topics that must not be generated again.

    Returns:
        A list of at most ``num_topics`` distinct, non-empty topic strings,
        none of which appear in ``existing_topics``. The list may be shorter
        than requested when the model returns too few usable topics.
    """
    if num_topics <= 0:
        return []

    # Work on copies so the caller's arguments are never mutated.
    model_kwargs = dict(model_kwargs or {})
    model_kwargs.setdefault("temperature", 1)
    existing_topics = list(existing_topics or [])

    # Calculate batches needed (5 topics per batch)
    llm_batch_size = 5
    num_batches = math.ceil(num_topics / llm_batch_size)

    # Every constructor argument is loop-invariant, so build the generator once.
    topic_generator = StructuredLLM(
        model_name=model_name,
        prompt="""Can you generate a list of topics about {{user_instruction}}
                {% if existing_topics_str %}
                Please do not generate topics that are already in the list: {{existing_topics_str}}
                Make sure the topics are unique and vary from each other
                {% endif %}
                """,
        output_schema=[{"name": "topic", "type": "str"}],
        model_kwargs=model_kwargs,
    )

    generated_topics: List[str] = []
    seen = set(existing_topics)
    for _ in range(num_batches):
        input_params = {
            "user_instruction": user_instruction,
            "num_records": min(llm_batch_size, num_topics - len(generated_topics)),
        }
        all_existing = existing_topics + generated_topics
        if all_existing:
            input_params["existing_topics_str"] = ",".join(all_existing)

        topic_response = await topic_generator.run(**input_params)
        for item in topic_response.data:
            topic = item.get("topic")
            # Drop missing/empty values (they would break the ",".join above on
            # the next batch) and anything that is already known.
            if not isinstance(topic, str) or not topic or topic in seen:
                continue
            seen.add(topic)
            generated_topics.append(topic)

        if len(generated_topics) >= num_topics:
            break

    # A batch may return more records than requested; never exceed num_topics.
    return generated_topics[:num_topics]
def _normalize_topic_assignments(topics):
    """Validate user-supplied topics and return a de-duplicated ``[(name, count)]`` list.

    String topics get a count of ``None`` (to be calculated later); single-key
    dict topics keep their explicit count. Only the first occurrence of a
    topic name is kept.
    """
    if not isinstance(topics, list) or not topics:
        raise ValueError("topics must be a non-empty list")

    assignments = []
    seen_topics = set()
    for topic in topics:
        if isinstance(topic, str):
            topic_name, count = topic, None
        elif isinstance(topic, dict) and len(topic) == 1:
            topic_name, count = next(iter(topic.items()))
            if not isinstance(count, int) or count < 0:
                raise ValueError(f"Topic '{topic_name}' has invalid count {count}")
        else:
            raise ValueError("Topics must be strings or single-key dictionaries")
        if topic_name not in seen_topics:
            seen_topics.add(topic_name)
            assignments.append((topic_name, count))
    return assignments


async def prepare_topic(
    topics: Optional[List[Union[str, Dict[str, int]]]] = None,
    num_records: Optional[int] = None,
    records_per_topic: int = 20,
    user_instruction: Optional[str] = None,
    model_name: str = "openai/gpt-4o-mini",
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, str]]:
    """Split records into topics, generating topics if none are provided or if needed.

    Supported input formats:
        1. String list: ['topic1', 'topic2'] - Topics with equal or calculated distribution
        2. Dict list: [{'topic1': 20}, {'topic2': 30}] - Topics with specific counts
        3. Mixed: ['topic1', {'topic2': 30}] - Combination of both formats
        4. None: No topics provided, will generate based on user_instruction

    Allocation rules:
        * String topics without ``num_records`` each receive ``records_per_topic``.
        * String topics with ``num_records`` share whatever is left after the
          dict topics, as evenly as possible (earlier topics get the extra
          record); string topics that end up with zero records are dropped.
        * Records still unassigned afterwards go to LLM-generated topics when
          ``user_instruction`` is given, then to ``auto_topicN`` placeholders,
          in chunks of at most ``records_per_topic``.

    Args:
        topics: Optional list of topics, either strings or {topic: count} dicts
        num_records: Total number of records to split (required for dict topics or None topics)
        records_per_topic: Number of records per topic (default: 20)
        user_instruction: Topic generation instructions (required if topics is None)
        model_name: Model name for topic generation
        model_kwargs: Model kwargs for topic generation; ``temperature``
            defaults to 1 when absent. The caller's dict is never modified.

    Returns:
        List of {'topic': topic_name} dictionaries, with one entry per record.
        String topics come first, then dict topics, then generated and auto
        topics. Every entry is a distinct dict object.

    Raises:
        ValueError: If the inputs are inconsistent (missing ``num_records`` or
            ``user_instruction``, malformed topics, counts exceeding
            ``num_records``, or a non-positive ``records_per_topic`` when
            additional topics must be created).
    """
    # Copy so the default temperature is never written into the caller's dict.
    model_kwargs = dict(model_kwargs or {})
    model_kwargs.setdefault("temperature", 1)

    # --- STEP 1: Input validation and normalization ---
    if topics is None:
        # Must have num_records and user_instruction if no topics provided
        if not num_records or num_records <= 0:
            raise ValueError("num_records must be positive when topics are not provided")
        if not user_instruction:
            raise ValueError("user_instruction required when topics are not provided")
        topic_assignments = []
    else:
        topic_assignments = _normalize_topic_assignments(topics)

    # --- STEP 2: Calculate or validate counts for provided topics ---
    string_topics = [name for name, count in topic_assignments if count is None]
    dict_topics = [(name, count) for name, count in topic_assignments if count is not None]

    # Explicit counts are meaningless without a total; fail before doing any work.
    if dict_topics and num_records is None:
        raise ValueError("num_records required when using dictionary topics")

    allocations = []  # ordered (topic_name, record_count) pairs
    if string_topics:
        if num_records is None:
            # No total given: every string topic gets records_per_topic.
            allocations.extend((name, records_per_topic) for name in string_topics)
        else:
            remaining = num_records - sum(count for _, count in dict_topics)
            if remaining < 0:
                raise ValueError("Dict topic counts exceed num_records")
            # Spread the remainder evenly; the first `extra` topics get one more.
            base, extra = divmod(remaining, len(string_topics))
            for i, name in enumerate(string_topics):
                count = base + (1 if i < extra else 0)
                if count > 0:
                    allocations.append((name, count))

    # Add dictionary topics with predefined counts
    allocations.extend((name, count) for name, count in dict_topics if count > 0)

    assigned_count = sum(count for _, count in allocations)
    if num_records is not None and assigned_count > num_records:
        raise ValueError(f"Total assigned count ({assigned_count}) exceeds num_records ({num_records})")

    # --- STEP 3: Generate topics for remaining records if needed ---
    used_names = {name for name, _ in allocations}  # set => O(1) duplicate checks
    remaining_records = 0 if num_records is None else num_records - assigned_count
    if remaining_records > 0:
        if records_per_topic <= 0:
            raise ValueError("records_per_topic must be positive when generating topics")

        # Generate topics with LLM if instructions provided
        if user_instruction:
            topics_needed = math.ceil(remaining_records / records_per_topic)
            generated = await generate_topics(
                user_instruction=user_instruction,
                num_topics=topics_needed,
                model_name=model_name,
                model_kwargs=model_kwargs,
                existing_topics=[name for name, _ in allocations],
            )
            # Assign counts to generated topics
            for topic in generated:
                if not isinstance(topic, str) or not topic:
                    # A missing/empty topic must not become a record key.
                    continue
                if topic in used_names:  # Skip if duplicate (shouldn't happen with proper LLM)
                    print(f"Skipping duplicate generated topic: {topic}")
                    continue
                count = min(records_per_topic, remaining_records)
                if count <= 0:
                    break
                allocations.append((topic, count))
                used_names.add(topic)
                remaining_records -= count
                assigned_count += count

        # Generate auto-topics for any still-remaining records
        auto_index = 1
        while remaining_records > 0:
            # Find next available auto_topic name
            auto_name = f"auto_topic{auto_index}"
            while auto_name in used_names:
                auto_index += 1
                auto_name = f"auto_topic{auto_index}"
            count = min(records_per_topic, remaining_records)
            allocations.append((auto_name, count))
            used_names.add(auto_name)
            remaining_records -= count
            assigned_count += count
            auto_index += 1

    # Final validation
    if num_records is not None and assigned_count != num_records:
        print(f"Warning: Assigned {assigned_count} records, expected {num_records}")

    # Build one fresh dict per record. `[{"topic": name}] * count` would repeat
    # a single shared dict, so mutating one record would change all of them.
    flatten_topic_list = []
    for name, count in allocations:
        flatten_topic_list.extend({"topic": name} for _ in range(count))
    return flatten_topic_list
if __name__ == "__main__":
    # Demo driver: runs prepare_topic against each supported input format and
    # prints the flattened records together with their total count.
    print("--- Running Examples ---")

    def _run_and_report(**call_kwargs):
        """Run one prepare_topic call in its own event loop and print the outcome."""
        records = asyncio.run(prepare_topic(**call_kwargs))
        print(f"Result: {records}")
        print(f"Total: {len(records)}")

    demo_cases = [
        # Example 1: Dictionary topics with additional generation
        (
            "\nExample 1: Dictionary topics + generation",
            {
                "topics": [{"topic1": 20}, {"topic2": 30}],
                "num_records": 100,
                "records_per_topic": 25,
                "user_instruction": "some context",
            },
        ),
        # Example 2: String topics with even distribution
        (
            "\nExample 2: String topics with distribution",
            {"topics": ["topicA", "topicB", "topicC"], "num_records": 10},
        ),
        # Example 3: Mixed string and dict topics
        (
            "\nExample 3: Mixed string/dict topics",
            {
                "topics": ["topicX", {"topicY": 10}],
                "num_records": 30,
                "user_instruction": "mixed topics",
            },
        ),
        # Example 4: String topics with fixed count
        (
            "\nExample 4: String topics with fixed count",
            {"topics": ["apple", "banana", "cherry"], "records_per_topic": 15},
        ),
        # Example 5: No topics, generate all
        (
            "\nExample 5: No topics, generate all",
            {
                "topics": None,
                "num_records": 10,
                "records_per_topic": 5,
                "user_instruction": "cloud computing",
            },
        ),
    ]

    for banner, call_kwargs in demo_cases:
        print(banner)
        _run_and_report(**call_kwargs)

    print("\n--- Examples Finished ---")