File size: 4,356 Bytes
bb645d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import time
from huggingface_hub import InferenceClient
import gradio as gr

# Hosted inference endpoint used for every generation (Mixtral 8x7B Instruct).
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# System instruction sent with every request: rewrite the user's prompt in
# under 100 words, keep its meaning, and emit nothing but the rewritten prompt.
SYSTEM_PROMPT = (
    "You are a prompt enhancer and your work is to enhance the given prompt under 100 words "
    "without changing the essence, only write the enhanced prompt and nothing else."
)

def format_prompt(message):
    """Wrap *message* in Mixtral ``[INST]`` tags, prefixed by the system prompt.

    A ``time.time()`` timestamp is appended to the user turn so that every
    call produces a unique prompt (defeats response caching upstream).
    """
    now = time.time()
    system_turn = f"<s>[INST] SYSTEM: {SYSTEM_PROMPT} [/INST]"
    user_turn = f"[INST] {message} {now} [/INST]"
    return system_turn + user_turn

def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.0):
    """Stream an enhanced version of *message* from the hosted model.

    Yields the text accumulated so far after each streamed token, so a
    Gradio output component bound to this generator updates live.

    Args:
        message: Raw user prompt to enhance.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; values below 0.01 are clamped
            because the endpoint rejects (near-)zero temperatures.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Values > 1.0 discourage repeated tokens.
    """
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2  # endpoint requires a strictly positive temperature
    top_p = float(top_p)
    generate_kwargs = {
        "temperature": temperature,
        "max_new_tokens": int(max_new_tokens),
        "top_p": top_p,
        "repetition_penalty": float(repetition_penalty),
        "do_sample": True,
    }
    formatted_prompt = format_prompt(message)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = ""
    for response in stream:
        output += response.token.text
        # BUG FIX: the previous `output.strip('</s>')` treated '</s>' as a
        # *set* of characters, so any legitimate leading/trailing '<', '/',
        # 's' or '>' was also eaten. removesuffix drops only the literal
        # end-of-sequence marker.
        yield output.removesuffix("</s>")

# Static markdown shown above and below the controls.
CREDITS_MARKDOWN = """
# Prompt Enhancer  
Credits: Instructions and design inspired by [ruslanmv.com](https://ruslanmv.com).
"""

BEST_PRACTICES = """
**Best Practices**  
- Be specific and clear in your input prompt  
- Use temperature 0.0 for consistent, focused results  
- Increase temperature up to 1.0 for more creative variations  
- Review and iterate on engineered prompts for optimal results  
"""

# Assemble the Gradio UI: Ocean theme, centered 800px-wide container.
_CSS = ".gradio-container { max-width: 800px; margin: auto; }"
with gr.Blocks(theme=gr.themes.Ocean(), css=_CSS) as demo:
    # Credits banner at the top, followed by a short usage hint.
    gr.Markdown(CREDITS_MARKDOWN)
    gr.Markdown(
        "Enhance your prompt to under 100 words while preserving its essence. "
        "Adjust the generation parameters as needed."
    )

    with gr.Row():
        # Left column: input prompt plus all generation controls.
        with gr.Column(scale=1):
            prompt_in = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter your prompt here...",
                lines=4,
            )
            tokens_ctl = gr.Slider(
                label="Max New Tokens", minimum=50, maximum=512, step=1, value=256
            )
            temp_ctl = gr.Slider(
                label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.9
            )
            top_p_ctl = gr.Slider(
                label="Top-p (nucleus sampling)",
                minimum=0.1,
                maximum=1.0,
                step=0.05,
                value=0.95,
            )
            rep_ctl = gr.Slider(
                label="Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                step=0.05,
                value=1.0,
            )
            enhance_btn = gr.Button("Enhance Prompt")
        # Right column: the streamed result (editable so users can tweak it).
        with gr.Column(scale=1):
            prompt_out = gr.Textbox(
                label="Enhanced Prompt",
                lines=10,
                interactive=True,
            )

    # Best-practices notes at the bottom of the page.
    gr.Markdown(BEST_PRACTICES)

    # `generate` is a generator function, so Gradio streams each yielded
    # partial result into the output textbox.
    enhance_btn.click(
        fn=generate,
        inputs=[prompt_in, tokens_ctl, temp_ctl, top_p_ctl, rep_ctl],
        outputs=prompt_out,
    )

demo.launch()