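"""Gradio demo exploring the capabilities of Google's Gemma models.

Loads the first Gemma variant the provided Hugging Face token can access
(falling back to TinyLlama), then exposes tabs for creative writing,
brainstorming, code tasks, comprehension, and assorted utility tasks.
"""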
import os
import re  # Used to pull placeholder names out of prompt templates
import traceback  # Keep traceback for detailed error logging

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Helper function to handle empty values
def safe_value(value, default):
    """Return default if value is empty or None"""
    if value is None or value == "":
        return default
    return value
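
# Example: safe_value("", "[topic]") -> "[topic]"; safe_value("rain", "[topic]") -> "rain"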

# Get Hugging Face token from environment variable (as fallback)
DEFAULT_HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)

# Create global variables for model and tokenizer
global_model = None
global_tokenizer = None
model_loaded = False
loaded_model_name = "None"  # Keep track of which model was loaded


def load_model(hf_token):
    """Load the model with the provided token"""
    global global_model, global_tokenizer, model_loaded, loaded_model_name
    if not hf_token:
        model_loaded = False
        loaded_model_name = "None"
        return "⚠️ Please enter your Hugging Face token to use the model.", gr.update(visible=False)
    try:
        # Try different model versions from smallest to largest,
        # prioritizing instruction-tuned models
        model_options = [
            "google/gemma-2b-it",
            "google/gemma-7b-it",
            "google/gemma-2b",
            "google/gemma-7b",
            # Add a smaller, potentially public model as a last resort
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        ]
        print(f"Attempting to load models with token starting with: {hf_token[:5]}...")
        loaded_successfully = False
        for model_name in model_options:
            try:
                print(f"\n--- Attempting to load model: {model_name} ---")
                is_gemma = "gemma" in model_name.lower()
                is_fallback = "tinyllama" in model_name.lower()
                current_token = hf_token if is_gemma else None  # Only use the token for gated Gemma models
                # Load tokenizer
                print("Loading tokenizer...")
                global_tokenizer = AutoTokenizer.from_pretrained(
                    model_name,
                    token=current_token
                )
                print("Tokenizer loaded successfully.")
                # Load model
                print(f"Loading model {model_name}...")
                global_model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.bfloat16,  # bfloat16 for better performance/compatibility where available
                    device_map="auto",  # Let HF decide device placement
                    token=current_token
                )
                print(f"Model {model_name} loaded successfully!")
                model_loaded = True
                loaded_model_name = model_name
                loaded_successfully = True
                if is_fallback:
                    return f"✅ Fallback model '{model_name}' loaded successfully! Limited capabilities compared to Gemma.", gr.update(visible=True)
                else:
                    return f"✅ Model '{model_name}' loaded successfully!", gr.update(visible=True)
            except Exception as specific_e:
                print(f"Failed to load {model_name}: {specific_e}")
                # traceback.print_exc()  # Verbose; enable if needed for debugging
                err = str(specific_e)
                if is_gemma and ("401 Client Error" in err or "requires you to be logged in" in err):
                    print("Authentication error likely. Check token and license agreement.")
                # Don't fail immediately; continue to the next model option
                continue
        # If the loop finishes without loading any model
        if not loaded_successfully:
            model_loaded = False
            loaded_model_name = "None"
            print("Could not load any model version.")
            return "❌ Could not load any model. Please check your token (ensure it has read permissions and you've accepted Gemma's license on Hugging Face) and network connection.", gr.update(visible=False)
    except Exception as e:
        model_loaded = False
        loaded_model_name = "None"
        error_msg = str(e)
        print(f"Error in load_model: {error_msg}")
        traceback.print_exc()
        if "401 Client Error" in error_msg or "requires you to be logged in" in error_msg:
            return "❌ Authentication failed. Please check your Hugging Face token and ensure you have accepted the Gemma license agreement on the Hugging Face model page.", gr.update(visible=False)
        else:
            return f"❌ An unexpected error occurred during model loading: {error_msg}", gr.update(visible=False)
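
# load_model returns a (status_markdown, visibility_update) pair, e.g.
#   ("✅ Model 'google/gemma-2b-it' loaded successfully!", gr.update(visible=True))
# which handle_auth (defined in the UI section below) streams straight into
# the auth status line and the main tabs.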


def generate_prompt(task_type, **kwargs):
    """Generate appropriate prompts based on task type and parameters"""
    # Dictionary-based approach for cleaner prompt generation
    prompts = {
        "creative": "Write a {style} about {topic}. Be creative and engaging.",
        "informational": "Write an {format_type} about {topic}. Be clear, factual, and informative.",
        "summarize": "Summarize the following text concisely:\n\n{text}",
        "translate": "Translate the following text to {target_lang}:\n\n{text}",
        "qa": "Based on the following text:\n\n{text}\n\nAnswer this question: {question}",
        "code_generate": "Write {language} code to {task}. Include comments explaining the code.",
        "code_explain": "Explain the following {language} code in simple terms:\n\n```\n{code}\n```",
        "code_debug": "Identify and fix the potential bug(s) in the following {language} code. Explain the fix:\n\n```\n{code}\n```",
        "brainstorm": "Brainstorm {category} ideas about {topic}. Provide a diverse list.",
        "content_creation": "Create a {content_type} about {topic} targeting {audience}. Make it engaging.",
        "email_draft": "Draft a professional {email_type} email regarding the following:\n\n{context}",
        "document_edit": "Improve the following text for {edit_type}:\n\n{text}",
        "explain": "Explain {topic} clearly for a {level} audience.",
        "classify": "Classify the following text into one of these categories: {categories}\n\nText: {text}\n\nCategory:",
        "data_extract": "Extract the following data points ({data_points}) from the text below:\n\nText: {text}\n\nExtracted Data:",
    }
    prompt_template = prompts.get(task_type)
    if prompt_template:
        # Fill in a visible default for any placeholder missing from kwargs,
        # so .format() never raises KeyError (e.g. a missing {topic} renders
        # as the literal "[topic]").
        required_keys = re.findall(r"\{(\w+)\}", prompt_template)
        final_kwargs = {key: kwargs.get(key, f"[{key}]") for key in required_keys}
        # Add remaining kwargs that might not appear in the template explicitly
        final_kwargs.update(kwargs)
        return prompt_template.format(**final_kwargs)
    else:
        # Fallback for custom or undefined task types
        return kwargs.get("prompt", "Generate text based on the input.")
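
# Example (illustrative): generate_prompt("creative", style="poem", topic="rain")
# returns "Write a poem about rain. Be creative and engaging."; a task type not
# in the table falls back to kwargs["prompt"] or a generic instruction.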


def generate_text(prompt, max_new_tokens=1024, temperature=0.7, top_p=0.9):
    """Generate text using the loaded model"""
    global global_model, global_tokenizer, model_loaded, loaded_model_name
    print("\n--- Generating Text ---")
    print(f"Model: {loaded_model_name}")
    print(f"Params: max_new_tokens={max_new_tokens}, temp={temperature}, top_p={top_p}")
    print(f"Prompt (start): {prompt[:150]}...")
    if not model_loaded or global_model is None or global_tokenizer is None:
        print("Model not loaded error.")
        return "⚠️ Model not loaded. Please authenticate first."
    if not prompt:
        return "⚠️ Please enter a prompt or configure a task."
    try:
        # Instruction-tuned models expect a chat-formatted prompt. Use the
        # tokenizer's built-in chat template so the turn markers match the
        # model (Gemma-IT and TinyLlama-Chat use different formats).
        name = loaded_model_name.lower()
        if name.endswith("-it") or "chat" in name:
            messages = [{"role": "user", "content": prompt}]
            chat_prompt = global_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            add_special = False  # The template already includes BOS/special tokens
        else:
            # Base models take the raw prompt
            chat_prompt = prompt
            add_special = True
        inputs = global_tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=add_special).to(global_model.device)
        input_length = inputs.input_ids.shape[1]
        print(f"Input token length: {input_length}")
        # Cap generation length for stability: Gemma 2B/7B context is often
        # 8192 tokens and TinyLlama's is 2048, so 2048 new tokens is a safe cap.
        effective_max_new_tokens = min(max_new_tokens, 2048)
        generation_args = {
            "input_ids": inputs.input_ids,
            "attention_mask": inputs.attention_mask,  # Include the attention mask
            "max_new_tokens": effective_max_new_tokens,
            "do_sample": True,
            "temperature": float(temperature),  # Ensure float
            "top_p": float(top_p),  # Ensure float
            "pad_token_id": global_tokenizer.eos_token_id  # Use EOS token for padding
        }
        print(f"Generation args: {generation_args}")
        # Generate text
        with torch.no_grad():  # Disable gradient calculation for inference
            outputs = global_model.generate(**generation_args)
        # Decode only the newly generated tokens, skipping special tokens and the prompt
        generated_ids = outputs[0, input_length:]
        generated_text = global_tokenizer.decode(generated_ids, skip_special_tokens=True)
        print(f"Generated text length: {len(generated_text)}")
        print(f"Generated text (start): {generated_text[:150]}...")
        return generated_text.strip()  # Remove leading/trailing whitespace
    except Exception as e:
        error_msg = str(e)
        print(f"Generation error: {error_msg}")
        print(f"Error type: {type(e)}")
        traceback.print_exc()
        return f"❌ Error during text generation: {error_msg}\n\nPlease check the logs or try adjusting parameters (e.g., reduce Max New Tokens)."
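
# Sampling notes: temperature rescales the logits (lower -> more focused output),
# while top-p keeps only the smallest token set whose cumulative probability
# reaches top_p. Both are surfaced as sliders by create_parameter_ui() below.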


# Create parameters UI component (reusable function)
def create_parameter_ui():
    with gr.Accordion("✨ Generation Parameters", open=False):
        with gr.Row():
            # Named max_new_tokens for consistency with the HF generate API
            max_new_tokens = gr.Slider(
                minimum=64,
                maximum=2048,  # Reasonable upper limit
                value=512,  # Default to a moderate length
                step=64,
                label="Max New Tokens",
                info="Max number of tokens to generate.",
                elem_id="max_new_tokens_slider"
            )
            temperature = gr.Slider(
                minimum=0.1,  # Avoid 0, which disables sampling
                maximum=1.5,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls randomness. Lower is more focused, higher is more diverse.",
                elem_id="temperature_slider"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-P (Nucleus Sampling)",
                info="Considers tokens with cumulative probability >= top_p.",
                elem_id="top_p_slider"
            )
    return [max_new_tokens, temperature, top_p]
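
# Each tab calls create_parameter_ui() to get its own slider instances; a
# Gradio component can only be rendered in one place, so the returned list
# is unpacked into that tab's event inputs rather than shared between tabs.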


# --- Gradio Interface ---
# Use the Soft theme for a clean look
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
    # Header
    gr.Markdown(
        """
        <div style="text-align: center; margin-bottom: 20px;">
            <h1 style="font-size: 2.5em; font-weight: bold;">
                <span style="font-size: 1.5em;">🤖</span> Gemma Capabilities Demo
            </h1>
            <p style="font-size: 1.1em; color: #555;">
                Explore the text generation capabilities of Google's Gemma models (or a fallback).
            </p>
            <p style="font-size: 0.9em; color: #777;">
                Requires a Hugging Face token with access to Gemma models.
                <a href="https://huggingface.co/google/gemma-7b-it" target="_blank">[Accept Gemma License Here]</a>
            </p>
        </div>
        """
    )

    # --- Authentication Section ---
    with gr.Group():  # Visual grouping for the auth controls
        with gr.Row():
            with gr.Column(scale=4):
                hf_token = gr.Textbox(
                    label="Hugging Face Token",
                    placeholder="Paste your HF token here (hf_...)",
                    type="password",
                    value=DEFAULT_HF_TOKEN,
                    elem_id="hf_token_input"
                )
            with gr.Column(scale=1, min_width=150):
                auth_button = gr.Button("Load Model", variant="primary", elem_id="auth_button")
        auth_status = gr.Markdown("ℹ️ Enter your Hugging Face token and click 'Load Model'. This might take a minute.", elem_id="auth_status")

    # Define authentication flow (a generator, so the UI can show progress)
    def handle_auth(token):
        # Show a loading message immediately
        yield "⏳ Authenticating and loading model... Please wait.", gr.update(visible=False)
        # Call the actual model loading function
        status_message, tabs_update = load_model(token)
        yield status_message, tabs_update

    # NOTE: auth_button.click is wired up after the tab layout below, since
    # `tabs` must already exist before it can be referenced as an output.

    # --- Main Content Tabs (initially hidden until a model is loaded) ---
    with gr.Tabs(elem_id="main_tabs", visible=False) as tabs:
        # --- Text Generation Tab ---
        with gr.TabItem("📝 Creative & Informational", id="tab_text_gen"):
            with gr.Row():
                # Input Column
                with gr.Column(scale=1):
                    gr.Markdown("### Configure Task")
                    text_gen_type = gr.Radio(
                        ["Creative Writing", "Informational Writing", "Custom Prompt"],
                        label="Writing Type",
                        value="Creative Writing",
                        elem_id="text_gen_type"
                    )
                    # --- Dynamic Options ---
                    with gr.Group(visible=True, elem_id="creative_options") as creative_options:
                        style = gr.Dropdown(["short story", "poem", "script", "song lyrics", "joke", "dialogue"], label="Style", value="short story", elem_id="creative_style")
                        creative_topic = gr.Textbox(label="Topic", placeholder="e.g., a lonely astronaut on Mars", value="a robot discovering music", elem_id="creative_topic", lines=2)
                    with gr.Group(visible=False, elem_id="info_options") as info_options:
                        format_type = gr.Dropdown(["article", "summary", "explanation", "report", "comparison"], label="Format", value="article", elem_id="info_format")
                        info_topic = gr.Textbox(label="Topic", placeholder="e.g., the basics of quantum physics", value="the impact of AI on healthcare", elem_id="info_topic", lines=2)
                    with gr.Group(visible=False, elem_id="custom_prompt_group") as custom_prompt_group:
                        custom_prompt = gr.Textbox(label="Custom Prompt", placeholder="Enter your full prompt here...", lines=5, elem_id="custom_prompt")

                    # Show/hide option groups based on the selected writing type
                    def update_text_gen_visibility(choice):
                        return {
                            creative_options: gr.update(visible=choice == "Creative Writing"),
                            info_options: gr.update(visible=choice == "Informational Writing"),
                            custom_prompt_group: gr.update(visible=choice == "Custom Prompt")
                        }
                    text_gen_type.change(update_text_gen_visibility, inputs=text_gen_type, outputs=[creative_options, info_options, custom_prompt_group], queue=False)

                    # Parameters
                    text_gen_params = create_parameter_ui()
                    generate_text_btn = gr.Button("Generate Text", variant="primary", elem_id="generate_text_btn")
                # Output Column
                with gr.Column(scale=1):
                    gr.Markdown("### Generated Output")
                    text_output = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="text_output")

            # Handler
            def text_generation_handler(gen_type, style, creative_topic, format_type, info_topic, custom_prompt_text, max_tokens, temp, top_p_val):
                task_map = {
                    "Creative Writing": ("creative", {"style": style, "topic": creative_topic}),
                    "Informational Writing": ("informational", {"format_type": format_type, "topic": info_topic}),
                    "Custom Prompt": ("custom", {"prompt": custom_prompt_text})
                }
                task_type, kwargs = task_map.get(gen_type, ("custom", {"prompt": custom_prompt_text}))
                # Substitute visible defaults for any empty inputs
                for k, v in kwargs.items():
                    kwargs[k] = safe_value(v, f"[{k}]")
                final_prompt = generate_prompt(task_type, **kwargs)
                return generate_text(final_prompt, max_tokens, temp, top_p_val)

            generate_text_btn.click(
                text_generation_handler,
                inputs=[text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params],
                outputs=text_output
            )

            # Examples
            gr.Examples(
                examples=[
                    ["Creative Writing", "poem", "the sound of rain on a tin roof", "", "", "", 512, 0.7, 0.9],
                    ["Informational Writing", "", "", "explanation", "how photosynthesis works", "", 768, 0.6, 0.9],
                    ["Custom Prompt", "", "", "", "", "Write a short dialogue between a cat and a dog discussing their humans.", 512, 0.8, 0.95],
                ],
                inputs=[text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params],
                outputs=text_output,
                label="Try these examples...",
                fn=text_generation_handler  # Examples need the function in order to run
            )

        # --- Brainstorming Tab ---
        with gr.TabItem("🧠 Brainstorming", id="tab_brainstorm"):
            with gr.Row():
                # Input Column
                with gr.Column(scale=1):
                    gr.Markdown("### Brainstorming Setup")
                    brainstorm_category = gr.Dropdown(["project", "business", "creative", "solution", "content", "feature", "product name"], label="Idea Category", value="project", elem_id="brainstorm_category")
                    brainstorm_topic = gr.Textbox(label="Topic or Problem", placeholder="e.g., reducing plastic waste", value="unique mobile app ideas", elem_id="brainstorm_topic", lines=3)
                    brainstorm_params = create_parameter_ui()
                    brainstorm_btn = gr.Button("Generate Ideas", variant="primary", elem_id="brainstorm_btn")
                # Output Column
                with gr.Column(scale=1):
                    gr.Markdown("### Generated Ideas")
                    brainstorm_output = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="brainstorm_output")

            # Handler
            def brainstorm_handler(category, topic, max_tokens, temp, top_p_val):
                category = safe_value(category, "project")
                topic = safe_value(topic, "innovative concepts")
                prompt = generate_prompt("brainstorm", category=category, topic=topic)
                return generate_text(prompt, max_tokens, temp, top_p_val)

            brainstorm_btn.click(brainstorm_handler, inputs=[brainstorm_category, brainstorm_topic, *brainstorm_params], outputs=brainstorm_output)

            gr.Examples(
                examples=[
                    ["solution", "making online learning more engaging", 768, 0.8, 0.9],
                    ["business", "eco-friendly subscription boxes", 768, 0.75, 0.9],
                    ["creative", "themes for a fantasy novel", 512, 0.85, 0.95],
                ],
                inputs=[brainstorm_category, brainstorm_topic, *brainstorm_params],
                outputs=brainstorm_output,
                label="Try these examples...",
                fn=brainstorm_handler
            )

        # --- Code Capabilities Tab ---
        with gr.TabItem("💻 Code", id="tab_code"):
            # Map dropdown labels to gr.Code language identifiers (shared by all code sub-tabs)
            lang_map = {"Python": "python", "JavaScript": "javascript", "Java": "java", "C++": "cpp", "HTML": "html", "CSS": "css", "SQL": "sql", "Bash": "bash", "Rust": "rust"}
            with gr.Tabs() as code_tabs:
                # --- Code Generation ---
                with gr.TabItem("Generate Code", id="subtab_code_gen"):
                    with gr.Row():
                        # Input Column
                        with gr.Column(scale=1):
                            gr.Markdown("### Code Generation Setup")
                            code_language_gen = gr.Dropdown(["Python", "JavaScript", "Java", "C++", "HTML", "CSS", "SQL", "Bash", "Rust"], label="Language", value="Python", elem_id="code_language_gen")
                            code_task = gr.Textbox(label="Task Description", placeholder="e.g., function to calculate factorial", value="create a Python class for a basic calculator", lines=4, elem_id="code_task")
                            code_gen_params = create_parameter_ui()
                            code_gen_btn = gr.Button("Generate Code", variant="primary", elem_id="code_gen_btn")
                        # Output Column
                        with gr.Column(scale=1):
                            gr.Markdown("### Generated Code")
                            code_output = gr.Code(label="Result", language="python", lines=25, interactive=False, elem_id="code_output")

                    # Handler
                    def code_gen_handler(language, task, max_tokens, temp, top_p_val):
                        language = safe_value(language, "Python")
                        task = safe_value(task, "hello world program")
                        prompt = generate_prompt("code_generate", language=language, task=task)
                        result = generate_text(prompt, max_tokens, temp, top_p_val)
                        # Extract the first fenced code block if the answer is wrapped in markdown
                        if "```" in result:
                            parts = result.split("```")
                            if len(parts) > 1:
                                content = parts[1]
                                # Strip a leading language hint (e.g. "python" or "cpp") if present
                                first_line, _, rest = content.partition("\n")
                                if first_line.strip().lower() in (language.lower(), lang_map.get(language, "").lower()):
                                    content = rest
                                return content.strip()
                        return result  # Return the full result if no code block was found
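
                    # Example: an answer of "```python\nprint('hi')\n```" comes
                    # back as just "print('hi')", so gr.Code renders it cleanly.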

                    # Update the output component's syntax highlighting to match the dropdown
                    def update_code_language_display(lang):
                        return gr.Code(language=lang_map.get(lang))  # None -> no highlighting
                    code_language_gen.change(update_code_language_display, inputs=code_language_gen, outputs=code_output, queue=False)

                    code_gen_btn.click(code_gen_handler, inputs=[code_language_gen, code_task, *code_gen_params], outputs=code_output)

                    gr.Examples(
                        examples=[
                            ["JavaScript", "function to validate an email address using regex", 768, 0.6, 0.9],
                            ["SQL", "query to select users older than 30 from a 'users' table", 512, 0.5, 0.8],
                            ["HTML", "basic structure for a personal portfolio website", 1024, 0.7, 0.9],
                        ],
                        inputs=[code_language_gen, code_task, *code_gen_params],
                        outputs=code_output,
                        label="Try these examples...",
                        fn=code_gen_handler
                    )

                # --- Code Explanation ---
                with gr.TabItem("Explain Code", id="subtab_code_explain"):
                    with gr.Row():
                        # Input Column
                        with gr.Column(scale=1):
                            gr.Markdown("### Code Explanation Setup")
                            # Language selection provides context for the prompt and highlighting
                            code_language_explain = gr.Dropdown(["Python", "JavaScript", "Java", "C++", "HTML", "CSS", "SQL", "Bash", "Rust", "Other"], label="Code Language (for context)", value="Python", elem_id="code_language_explain")
                            code_to_explain = gr.Code(label="Paste Code Here", language="python", lines=15, elem_id="code_to_explain")
                            explain_code_params = create_parameter_ui()
                            explain_code_btn = gr.Button("Explain Code", variant="primary", elem_id="explain_code_btn")
                        # Output Column
                        with gr.Column(scale=1):
                            gr.Markdown("### Explanation")
                            code_explanation = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="code_explanation")

                    # Update the code input's syntax highlighting
                    def update_explain_language_display(lang):
                        return gr.Code(language=lang_map.get(lang))  # None -> no highlighting
                    code_language_explain.change(update_explain_language_display, inputs=code_language_explain, outputs=code_to_explain, queue=False)

                    # Handler
                    def explain_code_handler(language, code, max_tokens, temp, top_p_val):
                        code = safe_value(code, "# Add code here")
                        language = safe_value(language, "code")  # Selected language goes into the prompt
                        prompt = generate_prompt("code_explain", language=language, code=code)
                        return generate_text(prompt, max_tokens, temp, top_p_val)
                    explain_code_btn.click(explain_code_handler, inputs=[code_language_explain, code_to_explain, *explain_code_params], outputs=code_explanation)

                # --- Code Debugging ---
                with gr.TabItem("Debug Code", id="subtab_code_debug"):
                    with gr.Row():
                        # Input Column
                        with gr.Column(scale=1):
                            gr.Markdown("### Code Debugging Setup")
                            code_language_debug = gr.Dropdown(["Python", "JavaScript", "Java", "C++", "SQL", "Bash", "Other"], label="Code Language (for context)", value="Python", elem_id="code_language_debug")
                            code_to_debug = gr.Code(
                                label="Paste Potentially Buggy Code Here",
                                language="python",
                                lines=15,
                                value="def calculate_average(numbers):\n    sum = 0\n    for n in numbers:\n        sum += n\n    # Bug: potential division by zero if numbers is empty\n    return sum / len(numbers)",  # Example with a potential bug
                                elem_id="code_to_debug"
                            )
                            debug_code_params = create_parameter_ui()
                            debug_code_btn = gr.Button("Debug Code", variant="primary", elem_id="debug_code_btn")
                        # Output Column
                        with gr.Column(scale=1):
                            gr.Markdown("### Debugging Analysis & Fix")
                            debug_result = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="debug_result")

                    # Update the code input's syntax highlighting
                    def update_debug_language_display(lang):
                        return gr.Code(language=lang_map.get(lang))  # None -> no highlighting
                    code_language_debug.change(update_debug_language_display, inputs=code_language_debug, outputs=code_to_debug, queue=False)

                    # Handler
                    def debug_code_handler(language, code, max_tokens, temp, top_p_val):
                        code = safe_value(code, "# Add potentially buggy code here")
                        language = safe_value(language, "code")
                        prompt = generate_prompt("code_debug", language=language, code=code)
                        return generate_text(prompt, max_tokens, temp, top_p_val)
                    debug_code_btn.click(debug_code_handler, inputs=[code_language_debug, code_to_debug, *debug_code_params], outputs=debug_result)

        # --- Text Comprehension Tab (Summarize, QA, Translate) ---
        with gr.TabItem("📖 Comprehension", id="tab_comprehension"):
            with gr.Tabs() as comprehension_tabs:
                # --- Summarization ---
                with gr.TabItem("Summarize", id="subtab_summarize"):
                    with gr.Row():
                        # Input Column
                        with gr.Column(scale=1):
                            gr.Markdown("### Summarization Setup")
                            summarize_text = gr.Textbox(label="Text to Summarize", placeholder="Paste long text here...", lines=15, elem_id="summarize_text")
                            summarize_params = create_parameter_ui()
                            summarize_btn = gr.Button("Summarize Text", variant="primary", elem_id="summarize_btn")
                        # Output Column
                        with gr.Column(scale=1):
                            gr.Markdown("### Summary")
                            summary_output = gr.Textbox(label="Result", lines=15, interactive=False, elem_id="summary_output")

                    # Handler
                    def summarize_handler(text, max_tokens, temp, top_p_val):
                        text = safe_value(text, "Please provide text to summarize.")
                        # Summaries should be short; cap the token budget
                        max_tokens = min(max_tokens, 512)
                        prompt = generate_prompt("summarize", text=text)
                        return generate_text(prompt, max_tokens, temp, top_p_val)
                    summarize_btn.click(summarize_handler, inputs=[summarize_text, *summarize_params], outputs=summary_output)

                # --- Question Answering ---
                with gr.TabItem("Q & A", id="subtab_qa"):
                    with gr.Row():
                        # Input Column
                        with gr.Column(scale=1):
                            gr.Markdown("### Question Answering Setup")
                            qa_text = gr.Textbox(label="Context Text", placeholder="Paste the text containing the answer...", lines=10, elem_id="qa_text")
                            qa_question = gr.Textbox(label="Question", placeholder="Ask a question based on the text above...", elem_id="qa_question")
                            qa_params = create_parameter_ui()
                            qa_btn = gr.Button("Get Answer", variant="primary", elem_id="qa_btn")
                        # Output Column
                        with gr.Column(scale=1):
                            gr.Markdown("### Answer")
                            qa_output = gr.Textbox(label="Result", lines=10, interactive=False, elem_id="qa_output")

                    # Handler
                    def qa_handler(text, question, max_tokens, temp, top_p_val):
                        text = safe_value(text, "Please provide context text.")
                        question = safe_value(question, "What is the main point?")
                        # Answers should be short; cap the token budget
                        max_tokens = min(max_tokens, 256)
                        prompt = generate_prompt("qa", text=text, question=question)
                        return generate_text(prompt, max_tokens, temp, top_p_val)
                    qa_btn.click(qa_handler, inputs=[qa_text, qa_question, *qa_params], outputs=qa_output)

                # --- Translation ---
                with gr.TabItem("Translate", id="subtab_translate"):
                    with gr.Row():
                        # Input Column
                        with gr.Column(scale=1):
                            gr.Markdown("### Translation Setup")
                            translate_text = gr.Textbox(label="Text to Translate", placeholder="Enter text in any language...", lines=8, elem_id="translate_text")
                            target_lang = gr.Dropdown(
                                ["French", "Spanish", "German", "Japanese", "Chinese", "Russian", "Arabic", "Hindi", "Portuguese", "Italian"],
                                label="Translate To", value="French", elem_id="target_lang"
                            )
                            translate_params = create_parameter_ui()
                            translate_btn = gr.Button("Translate Text", variant="primary", elem_id="translate_btn")
                        # Output Column
                        with gr.Column(scale=1):
                            gr.Markdown("### Translation")
                            translation_output = gr.Textbox(label="Result", lines=8, interactive=False, elem_id="translation_output")

                    # Handler
                    def translate_handler(text, lang, max_tokens, temp, top_p_val):
                        text = safe_value(text, "Please enter text to translate.")
                        lang = safe_value(lang, "French")
                        prompt = generate_prompt("translate", text=text, target_lang=lang)
                        return generate_text(prompt, max_tokens, temp, top_p_val)
                    translate_btn.click(translate_handler, inputs=[translate_text, target_lang, *translate_params], outputs=translation_output)

        # --- More Tasks Tab (consolidating less common tasks) ---
        with gr.TabItem("🛠️ More Tasks", id="tab_more"):
            with gr.Tabs() as more_tasks_tabs:
                # --- Content Creation ---
                with gr.TabItem("Content Creation", id="tab_content"):
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown("### Content Setup")
                            content_type = gr.Dropdown(["blog post outline", "social media post (Twitter)", "social media post (LinkedIn)", "marketing email subject line", "product description", "press release intro"], label="Content Type", value="blog post outline", elem_id="content_type")
                            content_topic = gr.Textbox(label="Topic", placeholder="e.g., benefits of remote work", value="sustainable travel tips", elem_id="content_topic", lines=2)
                            content_audience = gr.Textbox(label="Target Audience", placeholder="e.g., small business owners", value="eco-conscious millennials", elem_id="content_audience")
                            content_params = create_parameter_ui()
                            content_btn = gr.Button("Generate Content", variant="primary", elem_id="content_btn")
                        with gr.Column(scale=1):
                            gr.Markdown("### Generated Content")
                            content_output = gr.Textbox(label="Result", lines=20, interactive=False, elem_id="content_output")

                    def content_handler(c_type, topic, audience, max_tok, temp, top_p_val):
                        c_type = safe_value(c_type, "text")
                        topic = safe_value(topic, "a given subject")
                        audience = safe_value(audience, "a general audience")
                        prompt = generate_prompt("content_creation", content_type=c_type, topic=topic, audience=audience)
                        return generate_text(prompt, max_tok, temp, top_p_val)
                    content_btn.click(content_handler, inputs=[content_type, content_topic, content_audience, *content_params], outputs=content_output)

                # --- Email Drafting ---
                with gr.TabItem("Email Drafting", id="tab_email"):
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown("### Email Setup")
                            email_type = gr.Dropdown(["job inquiry", "meeting request", "follow-up", "thank you note", "customer support response", "sales outreach"], label="Email Type", value="meeting request", elem_id="email_type")
                            email_context = gr.Textbox(label="Key Points / Context", placeholder="Provide bullet points or context...", lines=5, value="Request a brief meeting next week to discuss project X. Suggest Tuesday or Wednesday afternoon. Mention attached agenda.", elem_id="email_context")
                            email_params = create_parameter_ui()
                            email_btn = gr.Button("Generate Email Draft", variant="primary", elem_id="email_btn")
                        with gr.Column(scale=1):
                            gr.Markdown("### Generated Email")
                            email_output = gr.Textbox(label="Result", lines=20, interactive=False, elem_id="email_output")

                    def email_handler(e_type, context, max_tok, temp, top_p_val):
                        e_type = safe_value(e_type, "professional")
                        context = safe_value(context, "the following points")
                        prompt = generate_prompt("email_draft", email_type=e_type, context=context)
                        return generate_text(prompt, max_tok, temp, top_p_val)
                    email_btn.click(email_handler, inputs=[email_type, email_context, *email_params], outputs=email_output)

                # --- Document Editing ---
                with gr.TabItem("Document Editing", id="tab_edit"):
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown("### Editing Setup")
                            edit_text = gr.Textbox(label="Text to Edit", placeholder="Paste text here...", lines=10, elem_id="edit_text")
                            edit_type = gr.Dropdown(["improve clarity", "fix grammar & spelling", "make more concise", "make more formal", "make more casual", "simplify language"], label="Improve For", value="improve clarity", elem_id="edit_type")
                            edit_params = create_parameter_ui()
                            edit_btn = gr.Button("Edit Text", variant="primary", elem_id="edit_btn")
                        with gr.Column(scale=1):
                            gr.Markdown("### Edited Text")
                            edit_output = gr.Textbox(label="Result", lines=10, interactive=False, elem_id="edit_output")

                    def edit_handler(text, e_type, max_tok, temp, top_p_val):
                        text = safe_value(text, "Provide text to edit.")
                        e_type = safe_value(e_type, "clarity and grammar")
                        prompt = generate_prompt("document_edit", text=text, edit_type=e_type)
                        # Allow a larger budget, since edited text can be longer than the
                        # input (word count is only a rough proxy for token count)
                        return generate_text(prompt, max(max_tok, len(text.split()) + 128), temp, top_p_val)
                    edit_btn.click(edit_handler, inputs=[edit_text, edit_type, *edit_params], outputs=edit_output)

                # --- Classification ---
                with gr.TabItem("Classification", id="tab_classify"):
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown("### Classification Setup")
                            classify_text = gr.Textbox(label="Text to Classify", placeholder="Enter text...", lines=8, value="This new sci-fi movie explores themes of AI consciousness and interstellar travel.")
                            classify_categories = gr.Textbox(label="Categories (comma-separated)", placeholder="e.g., positive, negative, neutral", value="Technology, Entertainment, Science, Politics, Sports, Health")
                            classify_params = create_parameter_ui()
                            classify_btn = gr.Button("Classify Text", variant="primary")
                        with gr.Column(scale=1):
                            gr.Markdown("### Classification Result")
                            classify_output = gr.Textbox(label="Predicted Category", lines=2, interactive=False)

                    def classify_handler(text, cats, max_tok, temp, top_p_val):
                        text = safe_value(text, "Text to classify needed.")
                        cats = safe_value(cats, "category1, category2")
                        # Classification needs only a short output
                        max_tok = min(max_tok, 64)
                        prompt = generate_prompt("classify", text=text, categories=cats)
                        return generate_text(prompt, max_tok, temp, top_p_val)
                    classify_btn.click(classify_handler, inputs=[classify_text, classify_categories, *classify_params], outputs=classify_output)

                # --- Data Extraction ---
                with gr.TabItem("Data Extraction", id="tab_extract"):
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown("### Extraction Setup")
                            extract_text = gr.Textbox(label="Source Text", placeholder="Paste text containing data...", lines=10, value="Order #12345 placed on 2024-03-15 by Jane Doe ([email protected]). Total amount: $99.95. Shipping to 123 Main St, Anytown, USA.")
                            extract_data_points = gr.Textbox(label="Data to Extract (comma-separated)", placeholder="e.g., name, email, order number", value="order number, date, customer name, email, total amount, address")
                            extract_params = create_parameter_ui()
                            extract_btn = gr.Button("Extract Data", variant="primary")
                        with gr.Column(scale=1):
                            gr.Markdown("### Extracted Data")
                            extract_output = gr.Textbox(label="Result (e.g., JSON or key-value pairs)", lines=10, interactive=False)

                    def extract_handler(text, points, max_tok, temp, top_p_val):
                        text = safe_value(text, "Provide text for extraction.")
                        points = safe_value(points, "key information")
                        prompt = generate_prompt("data_extract", text=text, data_points=points)
                        return generate_text(prompt, max_tok, temp, top_p_val)
                    extract_btn.click(extract_handler, inputs=[extract_text, extract_data_points, *extract_params], outputs=extract_output)
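
    # Wire up the auth button now that `tabs` exists to receive the visibility
    # update; queue=True because model loading can take a while.
    auth_button.click(
        fn=handle_auth,
        inputs=[hf_token],
        outputs=[auth_status, tabs],
        queue=True
    )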

    # --- Footer ---
    footer_md = gr.Markdown(
        """
        ---
        <div style="text-align: center; font-size: 0.9em; color: #777;">
            <p>Powered by Google's Gemma models via Hugging Face 🤗 Transformers & Gradio.</p>
            <p>Remember to review generated content. Model outputs may be inaccurate or incomplete.</p>
            <p>Model Loaded: <strong>None</strong></p>
        </div>
        """,
        elem_id="footer_md"
    )

    # Refresh the footer whenever the auth status changes, so it reflects the loaded model
    def update_footer(_status):
        return f"""
        ---
        <div style="text-align: center; font-size: 0.9em; color: #777;">
            <p>Powered by Google's Gemma models via Hugging Face 🤗 Transformers & Gradio.</p>
            <p>Remember to review generated content. Model outputs may be inaccurate or incomplete.</p>
            <p>Model Loaded: <strong>{loaded_model_name if model_loaded else 'None'}</strong></p>
        </div>
        """
    auth_status.change(update_footer, inputs=auth_status, outputs=footer_md, queue=False)

# --- Launch App ---
demo.launch(share=False)