ruslanmv committed on
Commit
bb645d8
·
1 Parent(s): c58ec5c

First commit

Browse files
Files changed (1) hide show
  1. app.py +139 -0
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from huggingface_hub import InferenceClient
3
+ import gradio as gr
4
+
5
# Initialize the inference client with the new LLM.
# NOTE(review): Mixtral on the HF Inference API is gated/rate-limited —
# presumably an HF token is supplied via the environment; confirm in the
# Space's deployment settings.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Define the system prompt for enhancing user prompts.
# Sent once per request as the SYSTEM turn (see format_prompt below).
SYSTEM_PROMPT = (
    "You are a prompt enhancer and your work is to enhance the given prompt under 100 words "
    "without changing the essence, only write the enhanced prompt and nothing else."
)
13
+
14
def format_prompt(message):
    """Wrap *message* in Mixtral instruction tags, prefixed by the system prompt.

    The current UNIX time is appended to the user turn so that repeated
    submissions of identical text still yield a unique prompt (defeats
    response caching on the inference backend).
    """
    now = time.time()
    system_turn = f"<s>[INST] SYSTEM: {SYSTEM_PROMPT} [/INST]"
    user_turn = f"[INST] {message} {now} [/INST]"
    return system_turn + user_turn
24
+
25
def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.0):
    """Stream an enhanced prompt from the hosted LLM.

    Yields the accumulated text after every received token so the Gradio UI
    can render partial output; the last yielded value is the complete
    enhanced prompt.

    Args:
        message: The user prompt to enhance.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; clamped to a minimum of 0.01
            because the backend rejects non-positive temperatures.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Values > 1.0 discourage repeated tokens.
    """
    # Clamp temperature away from zero — the inference backend requires > 0.
    temperature = max(float(temperature), 1e-2)
    generate_kwargs = {
        "temperature": temperature,
        "max_new_tokens": int(max_new_tokens),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "do_sample": True,
    }
    formatted_prompt = format_prompt(message)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = ""
    for response in stream:
        output += response.token.text
        # BUG FIX: the original used output.strip('</s>'), which strips ANY of
        # the characters '<', '/', 's', '>' from BOTH ends of the text (e.g. a
        # genuine trailing 's' would be eaten). Remove only the literal
        # end-of-sequence marker instead.
        yield output.removesuffix("</s>")
    return output.removesuffix("</s>")
55
+
56
# Markdown texts for credits and best practices.
# Rendered at the top of the page (see the Blocks layout below).
CREDITS_MARKDOWN = """
# Prompt Enhancer
Credits: Instructions and design inspired by [ruslanmv.com](https://ruslanmv.com).
"""

# Rendered at the bottom of the page, below the input/output columns.
BEST_PRACTICES = """
**Best Practices**
- Be specific and clear in your input prompt
- Use temperature 0.0 for consistent, focused results
- Increase temperature up to 1.0 for more creative variations
- Review and iterate on engineered prompts for optimal results
"""
69
+
70
# Build the Gradio interface with the Ocean theme.
# Layout: credits + intro text on top, a two-column row (controls | output),
# and the best-practices note at the bottom.
with gr.Blocks(theme=gr.themes.Ocean(), css=".gradio-container { max-width: 800px; margin: auto; }") as demo:
    # Credits at the top
    gr.Markdown(CREDITS_MARKDOWN)

    gr.Markdown(
        "Enhance your prompt to under 100 words while preserving its essence. "
        "Adjust the generation parameters as needed."
    )

    with gr.Row():
        # Left column: user input plus the four generation-parameter sliders.
        # Slider defaults mirror the defaults of generate().
        with gr.Column(scale=1):
            input_prompt = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter your prompt here...",
                lines=4,
            )
            max_tokens_slider = gr.Slider(
                label="Max New Tokens",
                minimum=50,
                maximum=512,
                step=1,
                value=256,
            )
            temperature_slider = gr.Slider(
                label="Temperature",
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=0.9,
            )
            top_p_slider = gr.Slider(
                label="Top-p (nucleus sampling)",
                minimum=0.1,
                maximum=1.0,
                step=0.05,
                value=0.95,
            )
            repetition_penalty_slider = gr.Slider(
                label="Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                step=0.05,
                value=1.0,
            )
            generate_button = gr.Button("Enhance Prompt")
        # Right column: the streamed result. Marked interactive so the user
        # can tweak the enhanced prompt before copying it out.
        with gr.Column(scale=1):
            output_prompt = gr.Textbox(
                label="Enhanced Prompt",
                lines=10,
                interactive=True,
            )

    # Best practices message at the bottom
    gr.Markdown(BEST_PRACTICES)

    # Wire the button click to the generate function. generate() is a
    # generator, so Gradio streams each yielded value into output_prompt.
    generate_button.click(
        fn=generate,
        inputs=[
            input_prompt,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            repetition_penalty_slider,
        ],
        outputs=output_prompt,
    )

demo.launch()