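"""Integration tests for starfish data generation templates.

Covers listing the registered templates and running the
``starfish/generate_by_topic`` and ``starfish/generate_func_call_dataset``
templates end to end against OpenAI models.
"""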
import nest_asyncio
import pytest

from starfish import data_gen_template
from starfish.common.env_loader import load_env_file

# Allow re-entrant event loops (e.g. when running inside a notebook) and load
# API keys and other settings from the local .env file before any template runs.
nest_asyncio.apply()
load_env_file()
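# Note: the tests below are async and use @pytest.mark.asyncio, which assumes the
# pytest-asyncio plugin is installed (or that asyncio_mode = "auto" is configured).
# The "openai/..." model names also assume valid OpenAI credentials in the .env file.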
@pytest.mark.asyncio
async def test_list():
    """List the registered data generation templates.

    - Expected: at least one template is available.
    """
    result = data_gen_template.list()
    assert len(result) != 0
@pytest.mark.asyncio
async def test_list_detail():
    """List the registered templates with detailed metadata (is_detail=True).

    - Expected: at least one template is returned.
    """
    result = data_gen_template.list(is_detail=True)
    assert len(result) != 0
@pytest.mark.asyncio
async def test_get_generate_by_topic_success():
    """Run the starfish/generate_by_topic template end to end.

    - Input: a user instruction, a topic list (one topic with an explicit
      per-topic record count), and a custom output schema.
    - Expected: exactly num_records records are generated.
    """
    data_gen_template.list()  # refresh the template listing before looking one up by name
    topic_generator_temp = data_gen_template.get("starfish/generate_by_topic")

    num_records = 20
    input_data = {
        "user_instruction": "Generate Q&A pairs about machine learning concepts",
        "num_records": num_records,
        "records_per_topic": 5,
        "topics": [
            "supervised learning",
            "unsupervised learning",
            {"reinforcement learning": 3},  # generate 3 records for this topic
            "neural networks",
        ],
        "topic_model_name": "openai/gpt-4",
        "topic_model_kwargs": {"temperature": 0.7},
        "generation_model_name": "openai/gpt-4",
        "generation_model_kwargs": {"temperature": 0.8, "max_tokens": 200},
        "output_schema": [
            {"name": "question", "type": "str"},
            {"name": "answer", "type": "str"},
            {"name": "difficulty", "type": "str"},  # additional field beyond the default question/answer pair
        ],
        "data_factory_config": {"max_concurrency": 4, "task_runner_timeout": 60 * 2},
    }

    results = await topic_generator_temp.run(input_data)
    assert len(results) == num_records
@pytest.mark.asyncio
async def test_get_generate_func_call_dataset():
    """Run the starfish/generate_func_call_dataset template end to end.

    - Input: an API contract for a weather endpoint plus topic/generation model settings.
    - Expected: at least num_records function-calling records are generated.
    """
    data_gen_template.list()  # refresh the template listing before looking one up by name
    generate_func_call_dataset = data_gen_template.get("starfish/generate_func_call_dataset")

    input_data = {
        "num_records": 4,
        "api_contract": {
            "name": "weather_api.get_current_weather",
            "description": "Retrieves the current weather conditions for a specified location.",
            "parameters": {
                "location": {"type": "string", "description": "The name of the city or geographic location.", "required": True},
                "units": {"type": "string", "description": "The units for temperature measurement (e.g., 'Celsius', 'Fahrenheit').", "required": False},
            },
        },
        "topic_model_name": "openai/gpt-4",
        "topic_model_kwargs": {"temperature": 0.7},
        "generation_model_name": "openai/gpt-4o-mini",
        "generation_model_kwargs": {"temperature": 0.8, "max_tokens": 200},
        "data_factory_config": {"max_concurrency": 24, "task_runner_timeout": 60 * 2},
    }

    results = await generate_func_call_dataset.run(input_data)
    # Check a lower bound rather than an exact count.
    assert len(results) >= input_data["num_records"]
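# For ad-hoc debugging, the generated records can be inspected after a run, e.g.:
#
#     results = await generate_func_call_dataset.run(input_data)
#     for record in results[:3]:
#         print(record)  # each item is expected to be a dict-like generated record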