# Hugging Face Space (status: Running) - file size: 4,356 bytes, revision bb645d8.
import time
from huggingface_hub import InferenceClient
import gradio as gr
# Shared Hugging Face inference client; every generation request in this
# module is sent through it to the Mixtral instruct model named below.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
# Instruction text placed ahead of every user prompt; it tells the model to
# return only the enhanced prompt, kept under 100 words.
SYSTEM_PROMPT = (
    "You are a prompt enhancer and your work is to enhance the given prompt "
    "under 100 words without changing the essence, "
    "only write the enhanced prompt and nothing else."
)
def format_prompt(message):
    """
    Build the instruction-formatted prompt string sent to the model.

    The result is two ``[INST] ... [/INST]`` segments: the first carries
    ``SYSTEM_PROMPT``, the second carries ``message`` followed by the current
    ``time.time()`` value, so two calls with the same message produce
    different prompt strings.
    """
    system_segment = f"<s>[INST] SYSTEM: {SYSTEM_PROMPT} [/INST]"
    user_segment = f"[INST] {message} {time.time()} [/INST]"
    return system_segment + user_segment
def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.0):
    """
    Stream an enhanced version of ``message`` from the hosted LLM.

    The prompt is built with ``format_prompt`` and sent through the
    module-level ``client``. This is a generator: after each streamed token it
    yields the text accumulated so far, with a trailing ``</s>`` end-of-sequence
    marker removed if present.

    Args:
        message: The user's prompt to enhance.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature (coerced to ``float``); values below
            0.01 are raised to 0.01.
        top_p: Nucleus-sampling probability mass (coerced to ``float``).
        repetition_penalty: Repetition penalty (coerced to ``float``).

    Yields:
        The cumulative generated text after each received token.

    Returns:
        The final cleaned text (only reachable via ``StopIteration.value``).
    """
    eos_marker = "</s>"

    def _without_eos(text):
        # Remove only a literal trailing end-of-sequence marker.
        # str.strip("</s>") must not be used here: it treats its argument as a
        # *set of characters* and would eat any leading/trailing '<', '/', 's'
        # or '>' from the real text (e.g. "cats" -> "cat").
        if text.endswith(eos_marker):
            return text[: -len(eos_marker)]
        return text

    generate_kwargs = {
        # Clamp to a small positive floor instead of passing (near-)zero.
        "temperature": max(float(temperature), 1e-2),
        "max_new_tokens": int(max_new_tokens),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "do_sample": True,
    }

    stream = client.text_generation(
        format_prompt(message),
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    output = ""
    for response in stream:
        output += response.token.text
        yield _without_eos(output)
    return _without_eos(output)
# Markdown heading and attribution line rendered at the top of the page.
CREDITS_MARKDOWN = (
    "\n"
    "# Prompt Enhancer\n"
    "Credits: Instructions and design inspired by [ruslanmv.com](https://ruslanmv.com).\n"
)
# Markdown usage tips rendered below the interface. The temperature advice is
# kept within what the Temperature slider can actually select (its minimum is
# 0.1), so users are not told to pick an unreachable value.
BEST_PRACTICES = """
**Best Practices**
- Be specific and clear in your input prompt
- Use a low temperature (around 0.1) for consistent, focused results
- Increase temperature up to 1.0 for more creative variations
- Review and iterate on engineered prompts for optimal results
"""
# Page-level CSS: cap the app width at 800px and center it horizontally.
PAGE_CSS = ".gradio-container { max-width: 800px; margin: auto; }"

# Assemble the Gradio page (Ocean theme): credits, a two-column work area,
# usage tips, and the button wiring.
with gr.Blocks(theme=gr.themes.Ocean(), css=PAGE_CSS) as demo:
    # Header: credits followed by a one-line description of the tool.
    gr.Markdown(CREDITS_MARKDOWN)
    gr.Markdown(
        "Enhance your prompt to under 100 words while preserving its essence. "
        "Adjust the generation parameters as needed."
    )

    with gr.Row():
        # Left column: the prompt to enhance plus the sampling controls.
        with gr.Column(scale=1):
            input_prompt = gr.Textbox(
                lines=4,
                label="Input Prompt",
                placeholder="Enter your prompt here...",
            )
            max_tokens_slider = gr.Slider(
                minimum=50,
                maximum=512,
                value=256,
                step=1,
                label="Max New Tokens",
            )
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.9,
                step=0.1,
                label="Temperature",
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            )
            repetition_penalty_slider = gr.Slider(
                minimum=1.0,
                maximum=2.0,
                value=1.0,
                step=0.05,
                label="Repetition Penalty",
            )
            generate_button = gr.Button("Enhance Prompt")

        # Right column: the enhanced prompt; left editable so it can be tweaked.
        with gr.Column(scale=1):
            output_prompt = gr.Textbox(
                lines=10,
                label="Enhanced Prompt",
                interactive=True,
            )

    # Usage tips at the bottom of the page.
    gr.Markdown(BEST_PRACTICES)

    # Components are passed to `generate` positionally, in its parameter order:
    # message, max_new_tokens, temperature, top_p, repetition_penalty.
    generation_inputs = [
        input_prompt,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
        repetition_penalty_slider,
    ]
    # `generate` is a generator function, so each value it yields is pushed to
    # the output textbox as it arrives.
    generate_button.click(
        fn=generate,
        inputs=generation_inputs,
        outputs=output_prompt,
    )

demo.launch()
|