Spaces:
Running
Running
| import gradio as gr | |
| import json | |
| import tempfile | |
| import os | |
| import re # For parsing conversation | |
| from typing import Union, Optional, Dict, Tuple # Import Dict and Tuple | |
| # Import the actual functions from synthgen | |
| from synthgen import ( | |
| generate_synthetic_text, | |
| generate_prompts, | |
| generate_synthetic_conversation, | |
| generate_corpus_content # Import the new function | |
| ) | |
| # We no longer need to import api_key here or check it directly in app.py | |
# --- Helper Functions for File Generation ---


def create_json_file(data: object, base_filename: str) -> Union[str, None]:
    """Write *data* as pretty-printed UTF-8 JSON to a new temporary file.

    The file is created with ``delete=False`` so it survives this call (it is
    handed to a Gradio ``File`` output for download); nothing here removes it
    afterwards.

    Args:
        data: Object to serialize with ``json.dump`` (must be JSON-serializable).
        base_filename: Human-readable name for the download, e.g.
            ``"text_samples.json"``. A trailing ``.json`` is dropped, every run
            of characters other than word characters and ``-`` becomes ``_``,
            and the result (capped at 60 characters) is used as the temp-file
            prefix, so downloads are named like ``text_samples_k3j9x2.json``
            instead of ``tmpk3j9x2.json``.

    Returns:
        Path of the written file, or ``None`` if serialization or writing
        failed (the error is printed and any partially written file removed).
    """
    name = base_filename or ""
    if name.lower().endswith(".json"):
        name = name[:-5]
    # Topics flow into base_filename, so strip path separators and other
    # characters that are unsafe in a file-name prefix.
    safe_stem = re.sub(r"[^\w-]+", "_", name).strip("_")[:60]
    prefix = f"{safe_stem}_" if safe_stem else "tmp"

    path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", prefix=prefix, suffix=".json", delete=False, encoding="utf-8"
        ) as temp_file:
            path = temp_file.name
            json.dump(data, temp_file, indent=4, ensure_ascii=False)
        return path
    except Exception as e:  # UI boundary: report and signal failure with None.
        print(f"Error creating JSON file {base_filename}: {e}")
        if path is not None:
            # delete=False means a failed dump would otherwise leave a
            # truncated file behind in the temp directory.
            try:
                os.remove(path)
            except OSError:
                pass
        return None
def create_text_file(data: str, base_filename: str) -> Union[str, None]:
    """Write *data* to a new temporary UTF-8 ``.txt`` file and return its path.

    The file is created with ``delete=False`` so it outlives this call and can
    be served by a Gradio ``File`` output; nothing here deletes it later.

    Args:
        data: Text written verbatim.
        base_filename: Human-readable download name, e.g. ``"topic_Article.txt"``.
            A trailing ``.txt`` is dropped, every run of characters other than
            word characters and ``-`` becomes ``_``, and the result (capped at
            60 characters) is used as the temp-file prefix. The extension is
            always ``.txt``.

    Returns:
        Path of the written file, or ``None`` on failure (the error is printed
        and any partially written file is removed).
    """
    name = base_filename or ""
    if name.lower().endswith(".txt"):
        name = name[:-4]
    # Topics flow into base_filename, so strip path separators and other
    # characters that are unsafe in a file-name prefix.
    safe_stem = re.sub(r"[^\w-]+", "_", name).strip("_")[:60]
    prefix = f"{safe_stem}_" if safe_stem else "tmp"

    path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", prefix=prefix, suffix=".txt", delete=False, encoding="utf-8"
        ) as temp_file:
            path = temp_file.name
            temp_file.write(data)
        return path
    except Exception as e:  # UI boundary: report and signal failure with None.
        print(f"Error creating text file {base_filename}: {e}")
        if path is not None:
            # Don't leave an empty/partial file behind (delete=False).
            try:
                os.remove(path)
            except OSError:
                pass
        return None
# Matches a speaker line such as "User: hi" or "assistant:" (case-insensitive,
# leading whitespace allowed). Group 1 is the role, group 2 the rest of the line.
_ROLE_LINE_RE = re.compile(r"^\s*(User|Assistant):\s*(.*)$", re.IGNORECASE)


def parse_conversation_string(text: str) -> list[dict]:
    """Parse a "User:/Assistant:" transcript into chat-message dicts.

    Each line that starts with ``User:`` or ``Assistant:`` opens a new message.
    Following lines without a speaker label are appended to the current
    message, so multi-line and multi-paragraph turns are kept whole. Text
    before the first speaker line (e.g. a title) is ignored.

    Args:
        text: Raw conversation text.

    Returns:
        A list of ``{"role": "user" | "assistant", "content": str}`` dicts, with
        each content stripped of surrounding whitespace. The list is empty when
        no speaker line is found; a warning is printed if *text* was non-empty.
    """
    turns: list[tuple[str, list[str]]] = []
    for line in (text or "").splitlines():
        match = _ROLE_LINE_RE.match(line)
        if match:
            turns.append((match.group(1).lower(), [match.group(2)]))
        elif turns:
            # Continuation of the current speaker's turn.
            turns[-1][1].append(line)

    messages = [
        {"role": role, "content": "\n".join(lines).strip()} for role, lines in turns
    ]
    if not messages and text:
        print(f"Warning: Could not parse conversation structure for: '{text[:100]}...'")
    return messages
# Plain-text wrapper for text generation (no file export); API error handling
# lives in synthgen.generate_synthetic_text.
def run_generation(prompt: str, model: str, num_samples: int) -> str:
    """Generate *num_samples* completions for *prompt* and return a text log.

    Args:
        prompt: Prompt sent unchanged for every sample.
        model: OpenRouter model ID.
        num_samples: Number of samples. A float or ``None`` (as a ``gr.Number``
            field can deliver) is accepted and truncated to an int.

    Returns:
        A log with one section per sample. Per-sample failures appear inline as
        the backend's ``"Error:..."`` strings. A validation failure returns a
        single ``"Error: ..."`` message instead.
    """
    if not prompt:
        return "Error: Please enter a prompt."
    # range() needs an int, and ``None <= 0`` would raise TypeError.
    num_samples = int(num_samples or 0)
    if num_samples <= 0:
        return "Error: Number of samples must be positive."

    parts = [
        f"Generating {num_samples} samples using model '{model}'...\n",
        "=" * 20 + "\n\n",
    ]
    for i in range(1, num_samples + 1):
        # Returns the text or an error string starting with "Error:".
        generated_text = generate_synthetic_text(prompt, model)
        parts.append(f"--- Sample {i} ---\n")
        parts.append(generated_text + "\n\n")
    parts.append("=" * 20 + "\nGeneration complete (check results above for errors).")
    return "".join(parts)
# Plain-text wrapper: one conversation per prompt line, generated by
# synthgen.generate_synthetic_conversation.
def run_conversation_generation(system_prompts_text: str, model: str, num_turns: int) -> str:
    """Generate one conversation per non-blank line of *system_prompts_text*.

    Args:
        system_prompts_text: Prompts/topics, one per line; blank lines are skipped.
        model: OpenRouter model ID.
        num_turns: Approximate turns per conversation. A float or ``None`` from
            a ``gr.Number`` field is accepted and truncated to an int.

    Returns:
        A log with one section per conversation (backend error strings appear
        inline), or a single ``"Error: ..."`` message on invalid input.
    """
    if not system_prompts_text:
        return "Error: Please enter or generate at least one system prompt/topic."
    # ``None <= 0`` would raise TypeError; the backend expects an int.
    num_turns = int(num_turns or 0)
    if num_turns <= 0:
        return "Error: Number of turns must be positive."

    prompts = [line.strip() for line in system_prompts_text.splitlines() if line.strip()]
    if not prompts:
        return "Error: No valid prompts found in the input."

    total = len(prompts)
    parts = [
        f"Generating {total} conversations ({num_turns} turns each) using model '{model}'...\n",
        "=" * 40 + "\n\n",
    ]
    for i, prompt in enumerate(prompts, start=1):
        # The backend returns the conversation (title included) or an
        # "Error:..." string, so no try/except is needed here.
        conversation_text = generate_synthetic_conversation(prompt, model, num_turns)
        parts.append(f"--- Conversation {i}/{total} ---\n")
        parts.append(conversation_text + "\n\n")
    parts.append("=" * 40 + "\nGeneration complete (check results above for errors).")
    return "".join(parts)
# UI wrapper around synthgen.generate_prompts for the "Generate Prompts" button.
def generate_prompts_ui(
    num_prompts: int,
    model: str,
    temperature: float,
    top_p: float,
    max_tokens: int,
) -> str:
    """Ask the backend for *num_prompts* conversation prompts, one per line.

    Args:
        num_prompts: How many prompts to request (1-50). A float or ``None``
            from a ``gr.Number`` field is accepted and truncated to an int.
        model: OpenRouter model ID.
        temperature: Sampling temperature. ``0``/``None`` sends ``None`` (API default).
        top_p: Nucleus-sampling mass. Values outside ``(0, 1]`` send ``None``.
        max_tokens: Completion token cap. ``0``/``None`` falls back to 200.

    Returns:
        The prompts joined by newlines, ready for the prompts Textbox, or an
        error message for invalid input, a ``ValueError`` from the backend, or
        any other exception.
    """
    if not model:
        return "Error: Please select a model for prompt generation."
    # gr.Number can deliver a float or None; comparisons and the API want ints.
    num_prompts = int(num_prompts or 0)
    if num_prompts <= 0:
        return "Error: Number of prompts to generate must be positive."
    if num_prompts > 50:
        return "Error: Cannot generate more than 50 prompts at a time."

    temp_val = temperature if temperature and temperature > 0 else None
    top_p_val = top_p if top_p and 0 < top_p <= 1 else None
    # NOTE(review): 200 tokens may be tight if generate_prompts produces all
    # prompts in a single completion; confirm in synthgen.
    max_tokens_val = int(max_tokens) if max_tokens and max_tokens > 0 else 200
    print(f"Generating prompts with settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val}")  # Debug print

    try:
        prompts_list = generate_prompts(
            num_prompts,
            model,
            temperature=temp_val,
            top_p=top_p_val,
            max_tokens=max_tokens_val,
        )
        return "\n".join(prompts_list)
    except ValueError as e:
        # generate_prompts reports API and parsing failures as ValueError.
        return f"Error generating prompts: {e}"
    except Exception as e:
        print(f"Unexpected error in generate_prompts_ui: {e}")
        return f"An unexpected error occurred: {e}"
# --- Generation Wrappers with file export ---
# Each returns a (textbox_update, file_update) tuple, in the same order as the
# outputs=[Textbox, File] list of the button it is wired to.
def run_generation_and_prepare_json(
    prompt: str,
    model: str,
    num_samples: int,
    temperature: float,
    top_p: float,
    max_tokens: int,
) -> Tuple[dict, dict]:
    """Generate *num_samples* completions for *prompt* and export them as JSON.

    Args:
        prompt: Prompt sent unchanged for every sample.
        model: OpenRouter model ID.
        num_samples: Number of samples. A float or ``None`` from a
            ``gr.Number`` field is accepted and truncated to an int.
        temperature: Sampling temperature. ``0``/``None`` means API default.
        top_p: Nucleus-sampling mass. Values outside ``(0, 1]`` mean API default.
        max_tokens: Completion token cap. ``0``/``None`` means API default.

    Returns:
        ``(textbox_update, file_update)``: the textbox gets a log of every
        sample, including per-sample error strings. The file gets the path of a
        JSON list of the successful samples, or ``None`` on invalid input or if
        the file could not be written.
    """
    if not prompt:
        return (gr.update(value="Error: Please enter a prompt."), gr.update(value=None))
    # range() and the API need ints; ``None <= 0`` would raise TypeError.
    num_samples = int(num_samples or 0)
    if num_samples <= 0:
        return (gr.update(value="Error: Number of samples must be positive."), gr.update(value=None))

    temp_val = temperature if temperature and temperature > 0 else None
    top_p_val = top_p if top_p and 0 < top_p <= 1 else None
    max_tokens_val = int(max_tokens) if max_tokens and max_tokens > 0 else None

    log_parts = [
        f"Generating {num_samples} samples using model '{model}'...\n",
        f"(Settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val})\n",
        "=" * 20 + "\n\n",
    ]
    samples = []
    for i in range(1, num_samples + 1):
        generated_text = generate_synthetic_text(
            prompt, model, temperature=temp_val, top_p=top_p_val, max_tokens=max_tokens_val
        )
        log_parts.append(f"--- Sample {i} ---\n")
        log_parts.append(generated_text + "\n\n")
        # The backend signals failures with an "Error:" prefix; keep those
        # out of the exported dataset.
        if not generated_text.startswith("Error:"):
            samples.append(generated_text)
    log_parts.append("=" * 20 + "\nGeneration complete (check results above for errors).")

    json_filepath = create_json_file(samples, "text_samples.json")
    return (gr.update(value="".join(log_parts)), gr.update(value=json_filepath))
# Wrapper for conversation generation + JSON preparation; returns a tuple of
# (textbox_update, file_update).


def _conversation_record(prompt: str, conversation_text: str) -> dict:
    """Build the JSON record for one conversation returned by the backend.

    Returns ``{"prompt", "messages"}`` on success. Otherwise returns
    ``{"prompt", "error"}`` when the backend returned an ``"Error:"`` string,
    or ``{"prompt", "error", "raw_text"}`` when the text was empty or contained
    no speaker turns.
    """
    if conversation_text.startswith("Error:"):
        return {"prompt": prompt, "error": conversation_text}
    if not conversation_text:
        return {"prompt": prompt, "error": "Could not extract content.", "raw_text": conversation_text}
    # Parse the whole text rather than cutting at the first blank line. The
    # parser only starts a message at a "User:"/"Assistant:" line, so a leading
    # title is skipped anyway, and a response without a title no longer loses
    # its first turn.
    messages = parse_conversation_string(conversation_text)
    if messages:
        return {"prompt": prompt, "messages": messages}
    return {"prompt": prompt, "error": "Failed to parse structure.", "raw_text": conversation_text}


def run_conversation_generation_and_prepare_json(
    system_prompts_text: str,
    model: str,
    num_turns: int,
    temperature: float,
    top_p: float,
    max_tokens: int,
) -> Tuple[dict, dict]:
    """Generate one conversation per prompt line and export them as JSON.

    Args:
        system_prompts_text: Prompts/topics, one per line; blank lines are skipped.
        model: OpenRouter model ID.
        num_turns: Approximate turns per conversation. A float or ``None`` from
            a ``gr.Number`` field is accepted and truncated to an int.
        temperature: Sampling temperature. ``0``/``None`` means API default.
        top_p: Nucleus-sampling mass. Values outside ``(0, 1]`` mean API default.
        max_tokens: Completion token cap. ``0``/``None`` means API default.

    Returns:
        ``(textbox_update, file_update)``: the textbox gets the full log. The
        file gets the path of a JSON list with one record per prompt (see
        ``_conversation_record``), or ``None`` on invalid input or write failure.
    """
    if not system_prompts_text:
        return (gr.update(value="Error: Please enter or generate at least one system prompt/topic."), gr.update(value=None))
    # ``None <= 0`` would raise TypeError; the backend expects an int.
    num_turns = int(num_turns or 0)
    if num_turns <= 0:
        return (gr.update(value="Error: Number of turns must be positive."), gr.update(value=None))

    prompts = [line.strip() for line in system_prompts_text.splitlines() if line.strip()]
    if not prompts:
        return (gr.update(value="Error: No valid prompts found in the input."), gr.update(value=None))

    temp_val = temperature if temperature and temperature > 0 else None
    top_p_val = top_p if top_p and 0 < top_p <= 1 else None
    max_tokens_val = int(max_tokens) if max_tokens and max_tokens > 0 else None

    total = len(prompts)
    log_parts = [
        f"Generating {total} conversations ({num_turns} turns each) using model '{model}'...\n",
        f"(Settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val})\n",
        "=" * 40 + "\n\n",
    ]
    records = []
    for i, prompt in enumerate(prompts, start=1):
        conversation_text = generate_synthetic_conversation(
            prompt, model, num_turns, temperature=temp_val, top_p=top_p_val, max_tokens=max_tokens_val
        )
        log_parts.append(f"--- Conversation {i}/{total} ---\n")
        log_parts.append(conversation_text + "\n\n")
        records.append(_conversation_record(prompt, conversation_text))
    log_parts.append("=" * 40 + "\nGeneration complete (check results above for errors).")

    json_filepath = create_json_file(records, "conversations.json")
    return (gr.update(value="".join(log_parts)), gr.update(value=json_filepath))
# Per-content-type settings for the length input, listed in dropdown order:
# (content type, length-input label, default length value).
# The Short Story / Article defaults are meant to match the backend defaults.
_CONTENT_TYPE_LENGTH_SPECS = (
    ("Corpus Snippets", "# Snippets", 5),
    ("Short Story", "Approx Words", 1000),
    ("Article", "Approx Words", 1500),
)

# Module-level lookups shared by the UI and the wrapper functions.
content_type_labels = {name: label for name, label, _default in _CONTENT_TYPE_LENGTH_SPECS}
content_type_defaults = {name: default for name, _label, default in _CONTENT_TYPE_LENGTH_SPECS}
# Wrapper for single-topic Corpus/Content Generation with file export.
def run_corpus_generation_and_prepare_file(
    topic: str,
    content_type: str,
    length_param: int,
    model: str,
    temperature: float,
    top_p: float,
    max_tokens: int,
) -> Tuple[dict, dict]:
    """Generate one piece of content about *topic* and prepare it for download.

    "Corpus Snippets" are saved as JSON ``{"topic": ..., "snippets": [...]}``.
    Other content types are saved as a plain ``.txt`` file.

    Args:
        topic: Subject of the content.
        content_type: One of the keys of ``content_type_labels``.
        length_param: Snippet count or approximate word count, depending on
            *content_type*. A float or ``None`` from a ``gr.Number`` field is
            accepted and truncated to an int.
        model: OpenRouter model ID.
        temperature: Sampling temperature. ``0``/``None`` means API default.
        top_p: Nucleus-sampling mass. Values outside ``(0, 1]`` mean API default.
        max_tokens: Completion token cap. ``0``/``None`` means API default.

    Returns:
        ``(textbox_update, file_update)``. The file path is ``None`` on invalid
        input, when the backend returns an ``"Error:"`` string, or when the file
        could not be written.
    """
    label_for_error = content_type_labels.get(content_type, "Length Param")
    if not topic:
        return (gr.update(value="Error: Please enter a topic."), gr.update(value=None))
    if not content_type:
        return (gr.update(value="Error: Please select a content type."), gr.update(value=None))
    # ``None <= 0`` would raise TypeError, and a float would leak into the prompt.
    length_param = int(length_param or 0)
    if length_param <= 0:
        return (gr.update(value=f"Error: Please enter a positive value for '{label_for_error}'."), gr.update(value=None))

    temp_val = temperature if temperature and temperature > 0 else None
    top_p_val = top_p if top_p and 0 < top_p <= 1 else None
    max_tokens_val = int(max_tokens) if max_tokens and max_tokens > 0 else None

    print(f"Generating {content_type} about '{topic}'...")
    header = (
        f"Generating {content_type} about '{topic}' using model '{model}'...\n"
        f"(Settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val})\n"
        + "=" * 40 + "\n\n"
    )
    generated_content = generate_corpus_content(
        topic=topic, content_type=content_type, length_param=length_param, model=model,
        temperature=temp_val, top_p=top_p_val, max_tokens=max_tokens_val,
    )

    file_path = None
    if not generated_content.startswith("Error:"):
        # Drop everything up to the first blank line, where the backend is
        # expected to put a title. NOTE(review): if synthgen returns content
        # without a title, this discards the first paragraph.
        _, sep, rest = generated_content.partition("\n\n")
        core_content = rest if sep else generated_content
        if content_type == "Corpus Snippets":
            # Snippets are expected to be "---"-separated. Without any separator,
            # fall back to blank-line paragraphs. (Splitting on a missing "---"
            # would just return the whole text as one snippet.)
            separator = "---" if "---" in core_content else "\n\n"
            snippets = [s.strip() for s in core_content.split(separator) if s.strip()]
            corpus_data = {"topic": topic, "snippets": snippets}
            file_path = create_json_file(corpus_data, f"{topic}_corpus.json")
        else:
            file_path = create_text_file(core_content, f"{topic}_{content_type.replace(' ', '_')}.txt")

    return (gr.update(value=header + generated_content), gr.update(value=file_path))
# Keeps the length input's label and default in sync with the chosen content type.
def update_length_param_ui(content_type: str) -> gr.update:
    """Return a component update that relabels and resets the length input.

    Unknown content types get the label "Length Param" and the value 5.
    """
    return gr.update(
        label=content_type_labels.get(content_type, "Length Param"),
        value=content_type_defaults.get(content_type, 5),
    )
# --- Topic generation ---

# Leading list marker the model may add despite being told not to: a "-", "*"
# or "•" bullet, or "1." / "1)" numbering, followed by whitespace. The
# whitespace requirement keeps topics like "2.5 million years ..." intact.
_TOPIC_LIST_MARKER_RE = re.compile(r"^\s*(?:[-*\u2022]|\d+[.)])\s+")


def _parse_topic_lines(text: str) -> list[str]:
    """Split model output into one topic per non-blank line, minus list markers."""
    topics = []
    for line in text.strip().splitlines():
        topic = _TOPIC_LIST_MARKER_RE.sub("", line).strip()
        if topic:
            topics.append(topic)
    return topics


def generate_topics_ui(
    num_topics: int,
    model: str,
    temperature: float,
    top_p: float,
    max_tokens: int,
) -> str:
    """Ask the model for *num_topics* diverse topics, returned one per line.

    Args:
        num_topics: How many topics to request (1-50). A float or ``None`` from
            a ``gr.Number`` field is accepted and truncated to an int.
        model: OpenRouter model ID.
        temperature: Sampling temperature. ``0``/``None`` means API default.
        top_p: Nucleus-sampling mass. Values outside ``(0, 1]`` mean API default.
        max_tokens: Completion token cap. ``0``/``None`` uses a budget that
            scales with *num_topics* (at least 150).

    Returns:
        At most *num_topics* newline-separated topics for the topics Textbox,
        or an error message.
    """
    if not model:
        return "Error: Please select a model for topic generation."
    # gr.Number can deliver a float or None; comparisons and the prompt want ints.
    num_topics = int(num_topics or 0)
    if num_topics <= 0:
        return "Error: Number of topics to generate must be positive."
    if num_topics > 50:  # Keep limit reasonable
        return "Error: Cannot generate more than 50 topics at a time."

    temp_val = temperature if temperature and temperature > 0 else None
    top_p_val = top_p if top_p and 0 < top_p <= 1 else None
    # All topics come back in one completion, so the fallback budget grows with
    # the count. A flat 150 tokens cannot fit ~50 short phrases.
    max_tokens_val = int(max_tokens) if max_tokens and max_tokens > 0 else max(150, 20 * num_topics)
    print(f"Generating {num_topics} topics with settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val}")

    instruction = (
        f"Generate exactly {num_topics} diverse and interesting topics suitable for generating content like articles, stories, or corpus snippets. "
        "Each topic should be concise (a few words to a short phrase). "
        "Present each topic on a new line, with no other introductory or concluding text or numbering."
        "\n\nExamples:\n"
        "The future of renewable energy\n"
        "The history of the Silk Road\n"
        "The impact of social media on mental health"
    )
    system_msg = "You are an expert topic generator. Follow the user's instructions precisely."

    try:
        generated_text = generate_synthetic_text(
            instruction,
            model,
            system_message=system_msg,
            temperature=temp_val,
            top_p=top_p_val,
            max_tokens=max_tokens_val,
        )
        if generated_text.startswith("Error:"):
            raise ValueError(generated_text)  # Surface the backend's error string.
        topics_list = _parse_topic_lines(generated_text)
        if not topics_list:
            print(f"Warning: Failed to parse topics from generated text. Raw text:\n{generated_text}")
            raise ValueError("AI failed to generate topics in the expected format.")
        return "\n".join(topics_list[:num_topics])  # The model may overshoot.
    except ValueError as e:
        return f"Error generating topics: {e}"
    except Exception as e:
        print(f"Unexpected error in generate_topics_ui: {e}")
        return f"An unexpected error occurred: {e}"
# Wrapper for bulk Corpus/Content Generation (one result per topic line).


def _bulk_result_entry(topic: str, content_type: str, generated_content_full: str) -> dict:
    """Turn one backend response into the JSON record stored for *topic*.

    Records have the keys ``topic``, ``content_type``, ``status``
    (``"success"``/``"error"``), ``error_message`` and ``content``. For
    "Corpus Snippets" ``content`` is a list of snippets, for other types it is
    the text itself, and on error it is ``None``.
    """
    entry = {"topic": topic, "content_type": content_type}
    if generated_content_full.startswith("Error:"):
        entry.update(status="error", error_message=generated_content_full, content=None)
        return entry

    # Drop everything up to the first blank line, where the backend is expected
    # to put a title. NOTE(review): if synthgen returns content without a title,
    # this discards the first paragraph/snippet.
    _, sep, rest = generated_content_full.partition("\n\n")
    core_content = rest if sep else generated_content_full
    if content_type == "Corpus Snippets":
        # Snippets are expected to be "---"-separated. Without any separator,
        # fall back to blank-line paragraphs. (Splitting on a missing "---"
        # would just return the whole text as one snippet.)
        separator = "---" if "---" in core_content else "\n\n"
        content = [s.strip() for s in core_content.split(separator) if s.strip()]
    else:
        content = core_content
    entry.update(status="success", error_message=None, content=content)
    return entry


def run_bulk_content_generation_and_prepare_json(
    topics_text: str,
    content_type: str,
    length_param: int,
    model: str,
    temperature: float,
    top_p: float,
    max_tokens: int,
) -> Tuple[dict, dict]:
    """Generate content for every non-blank topic line and export one JSON file.

    Args:
        topics_text: Topics, one per line; blank lines are skipped.
        content_type: One of the keys of ``content_type_labels``.
        length_param: Snippet count or approximate word count. A float or
            ``None`` from a ``gr.Number`` field is accepted and truncated to an int.
        model: OpenRouter model ID.
        temperature: Sampling temperature. ``0``/``None`` means API default.
        top_p: Nucleus-sampling mass. Values outside ``(0, 1]`` mean API default.
        max_tokens: Completion token cap. ``0``/``None`` means API default.

    Returns:
        ``(textbox_update, file_update)``: the textbox gets the full log. The
        file gets the path of a JSON list of per-topic records (see
        ``_bulk_result_entry``), or ``None`` on invalid input or write failure.
    """
    # --- Input Validation ---
    if not topics_text:
        return (gr.update(value="Error: Please enter or generate at least one topic."), gr.update(value=None))
    if not content_type:
        return (gr.update(value="Error: Please select a content type."), gr.update(value=None))
    topics = [line.strip() for line in topics_text.splitlines() if line.strip()]
    if not topics:
        return (gr.update(value="Error: No valid topics found in the input."), gr.update(value=None))
    label_for_error = content_type_labels.get(content_type, "Length Param")
    # ``None <= 0`` would raise TypeError, and a float would leak into the prompt.
    length_param = int(length_param or 0)
    if length_param <= 0:
        return (gr.update(value=f"Error: Please enter a positive value for '{label_for_error}'."), gr.update(value=None))

    temp_val = temperature if temperature and temperature > 0 else None
    top_p_val = top_p if top_p and 0 < top_p <= 1 else None
    max_tokens_val = int(max_tokens) if max_tokens and max_tokens > 0 else None

    total = len(topics)
    log_parts = [
        f"Generating {content_type} for {total} topics using model '{model}'...\n",
        f"(Settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val})\n" + "=" * 40 + "\n\n",
    ]
    bulk_results = []
    for i, topic in enumerate(topics, start=1):
        print(f"Generating {content_type} for topic {i}/{total}: '{topic}'...")
        log_parts.append(f"--- Topic {i}/{total}: '{topic}' ---\n")
        # Returns the content (title included) or an "Error:..." string.
        generated_content_full = generate_corpus_content(
            topic=topic, content_type=content_type, length_param=length_param, model=model,
            temperature=temp_val, top_p=top_p_val, max_tokens=max_tokens_val,
        )
        log_parts.append(generated_content_full + "\n\n")
        bulk_results.append(_bulk_result_entry(topic, content_type, generated_content_full))

    log_parts.append("=" * 40 + f"\nBulk generation complete for {total} topics.")
    json_filepath = create_json_file(bulk_results, f"{content_type.replace(' ', '_')}_bulk_results.json")
    return (gr.update(value="".join(log_parts)), gr.update(value=json_filepath))
# --- Gradio Interface Definition ---
with gr.Blocks() as demo:
    gr.Markdown("# Synthetic Data Generator using OpenRouter")
    gr.Markdown(
        "Generate synthetic text samples, conversations, or other content using various models"
    )

    # OpenRouter model IDs offered in every tab's model dropdown.
    # TODO: consider fetching these dynamically from OpenRouter.
    model_choices = [
        "deepseek/deepseek-chat-v3-0324:free",  # Example free model
        "meta-llama/llama-3.3-70b-instruct:free",
        "deepseek/deepseek-r1:free",
        "google/gemini-2.5-pro-exp-03-25:free",
        "qwen/qwen-2.5-72b-instruct:free",
        "featherless/qwerky-72b:free",
        "google/gemma-3-27b-it:free",
        "mistralai/mistral-small-24b-instruct-2501:free",
        "deepseek/deepseek-r1-distill-llama-70b:free",
        "sophosympatheia/rogue-rose-103b-v0.2:free",
        "nvidia/llama-3.1-nemotron-70b-instruct:free",
        "microsoft/phi-3-medium-128k-instruct:free",
        "undi95/toppy-m-7b:free",
        "huggingfaceh4/zephyr-7b-beta:free",
        "openrouter/quasar-alpha",
        # Add more model IDs as needed
    ]
    default_model = model_choices[0] if model_choices else None

    # --- Shared Model Settings ---
    # One set of sampling controls, passed as inputs to every generation
    # handler below. 0 means "not set": most handlers then send None (API
    # default), while the prompt/topic generators substitute their own budget.
    with gr.Accordion("Model Settings (Optional)", open=False):
        temperature_slider = gr.Slider(
            minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature",
            info="Controls randomness. Higher values are more creative, lower are more deterministic. 0 means use API default.",
        )
        top_p_slider = gr.Slider(
            minimum=0.0, maximum=1.0, value=0.0, step=0.05, label="Top-P (Nucleus Sampling)",
            info="Considers only tokens with cumulative probability mass >= top_p. 0 means use API default.",
        )
        # A Number box despite the *_slider name. precision=0 makes Gradio
        # deliver an int, which is what the API's max_tokens expects.
        max_tokens_slider = gr.Number(
            value=0, minimum=0, maximum=8192, step=64, precision=0, label="Max Tokens",
            info="Maximum number of tokens to generate in the completion. 0 means use API default.",
        )

    with gr.Tabs():
        # --- Text Generation Tab ---
        with gr.TabItem("Text Generation"):
            with gr.Row():
                prompt_input_text = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here (e.g., Generate a short product description for a sci-fi gadget)",
                    lines=3,
                )
            with gr.Row():
                model_input_text = gr.Dropdown(
                    label="OpenRouter Model ID",
                    choices=model_choices,
                    value=default_model,
                )
                # precision=0: the handler passes this to range(), which needs an int.
                num_samples_input_text = gr.Number(
                    label="Number of Samples", value=3, minimum=1, maximum=20, step=1, precision=0
                )
            generate_button_text = gr.Button("Generate Text Samples")
            output_text = gr.Textbox(label="Generated Samples", lines=15, show_copy_button=True)
            download_file_text = gr.File(label="Download Samples as JSON")

            generate_button_text.click(
                fn=run_generation_and_prepare_json,
                inputs=[
                    prompt_input_text, model_input_text, num_samples_input_text,
                    temperature_slider, top_p_slider, max_tokens_slider,
                ],
                outputs=[output_text, download_file_text],
            )

        # --- Conversation Generation Tab ---
        with gr.TabItem("Conversation Generation"):
            gr.Markdown("Enter one system prompt/topic per line below, or use the 'Generate Prompts' button.")
            with gr.Row():
                prompt_input_conv = gr.Textbox(
                    label="Prompts (one per line)",
                    lines=5,
                    placeholder="Enter prompts here, one per line...\ne.g., Act as a pirate discussing treasure maps.\nDiscuss the future of space travel.",
                )
            with gr.Row():
                num_prompts_input_conv = gr.Number(
                    label="Number of Prompts to Generate", value=5, minimum=1, maximum=20, step=1, precision=0
                )
                generate_prompts_button = gr.Button("Generate Prompts using AI")
            with gr.Row():
                # Used for both prompt generation and conversation generation.
                model_input_conv = gr.Dropdown(
                    label="OpenRouter Model ID (for generation)",
                    choices=model_choices,
                    value=default_model,
                )
            with gr.Row():
                num_turns_input_conv = gr.Number(
                    label="Number of Turns per Conversation (approx)", value=5, minimum=1, maximum=20, step=1, precision=0
                )
            generate_conversations_button = gr.Button("Generate Conversations")
            output_conv = gr.Textbox(label="Generated Conversations", lines=15, show_copy_button=True)
            download_file_conv = gr.File(label="Download Conversations as JSON")

            # Fill the prompts textbox with AI-generated prompts.
            generate_prompts_button.click(
                fn=generate_prompts_ui,
                inputs=[
                    num_prompts_input_conv, model_input_conv,
                    temperature_slider, top_p_slider, max_tokens_slider,
                ],
                outputs=prompt_input_conv,
            )
            # Generate one conversation per prompt line; log + JSON download.
            generate_conversations_button.click(
                fn=run_conversation_generation_and_prepare_json,
                inputs=[
                    prompt_input_conv, model_input_conv, num_turns_input_conv,
                    temperature_slider, top_p_slider, max_tokens_slider,
                ],
                outputs=[output_conv, download_file_conv],
            )

        # --- Bulk Content Generation Tab ---
        with gr.TabItem("Bulk Content Generation"):
            output_content = gr.Textbox(label="Generated Content (Log)", lines=15, show_copy_button=True)
            # Bulk results are always exported as a single JSON file.
            download_file_content = gr.File(label="Download Results as JSON")
            gr.Markdown("Enter one topic per line below, or use the 'Generate Topics' button.")
            with gr.Row():
                topic_input_content = gr.Textbox(
                    label="Topics (one per line)",
                    lines=5,
                    placeholder="Enter topics here, one per line...\ne.g., The future of renewable energy\nThe history of the Silk Road",
                )

            # --- Topic Generation ---
            with gr.Accordion("Topic Generation Options", open=False):
                with gr.Row():
                    num_topics_input = gr.Number(
                        label="# Topics to Generate", value=5, minimum=1, maximum=50, step=1, precision=0
                    )
                    # Uses this tab's model selector and the shared settings.
                    generate_topics_button = gr.Button("Generate Topics using AI")

            # --- Generation Settings ---
            with gr.Row():
                content_type_choices = list(content_type_labels.keys())
                content_type_input = gr.Dropdown(
                    label="Content Type", choices=content_type_choices, value=content_type_choices[0]
                )
                default_length_label = content_type_labels[content_type_choices[0]]
                default_length_value = content_type_defaults[content_type_choices[0]]
                length_param_input = gr.Number(
                    label=default_length_label, value=default_length_value, minimum=1, step=1, precision=0
                )
            with gr.Row():
                model_input_content = gr.Dropdown(label="Model", choices=model_choices, value=default_model)
            generate_content_button = gr.Button("Generate Bulk Content")

            # --- Event Listeners ---
            # Relabel/reset the length input whenever the content type changes.
            content_type_input.change(
                fn=update_length_param_ui,
                inputs=content_type_input,
                outputs=length_param_input,
            )
            # Fill the topics textbox with AI-generated topics.
            generate_topics_button.click(
                fn=generate_topics_ui,
                inputs=[
                    num_topics_input, model_input_content,
                    temperature_slider, top_p_slider, max_tokens_slider,
                ],
                outputs=topic_input_content,
            )
            # Generate content for every topic line; log + JSON download.
            generate_content_button.click(
                fn=run_bulk_content_generation_and_prepare_json,
                inputs=[
                    topic_input_content, content_type_input, length_param_input,
                    model_input_content,
                    temperature_slider, top_p_slider, max_tokens_slider,
                ],
                outputs=[output_content, download_file_content],
            )
# Entry point: start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    for startup_message in (
        "Launching Gradio App...",
        "Make sure the OPENROUTER_API_KEY environment variable is set.",
    ):
        print(startup_message)
    # For a temporary public link while testing locally, use demo.launch(share=True).
    demo.launch()