Spaces:
Sleeping
Sleeping
from huggingface_hub import InferenceClient | |
import gradio as gr | |
# Remote inference client for the CoEdit grammar-correction model on the
# Hugging Face Hub; every request made by ``generate`` goes through it.
MODEL_ID = "grammarly/coedit-large"
client = InferenceClient(MODEL_ID)
def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, top_k=50, repetition_penalty=1.0):
    """Ask the CoEdit model to fix grammatical errors in *prompt*.

    Used as the ``fn`` of a ``gr.ChatInterface``: Gradio passes the user's
    message and the chat history first, then the values of the
    ``additional_inputs`` sliders in order (temperature, max new tokens,
    top-p, top-k).

    Args:
        prompt: The user's sentence to correct.
        history: Chat history supplied by Gradio; unused, because each
            message is corrected independently.
        temperature: Sampling temperature. A one-element list is unwrapped
            and values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens; coerced to int and
            forced to be at least 1.
        top_p: Nucleus-sampling threshold. Values >= 1.0 disable nucleus
            filtering (sent as None); values below 0.01 are raised to 0.01.
        top_k: Number of candidate tokens. Values <= 0 disable top-k
            filtering (sent as None).
        repetition_penalty: Passed through to the endpoint as a float.

    Returns:
        The corrected text as a plain string.
    """
    # Slider values may arrive as floats (and temperature possibly as a
    # one-element list), so normalise every parameter before sending it.
    temperature = float(temperature[0]) if isinstance(temperature, list) else float(temperature)
    # TGI-backed endpoints reject a non-positive temperature; keep a small floor.
    temperature = max(temperature, 1e-2)
    # Asking for zero new tokens is meaningless (and rejected); request at least one.
    max_new_tokens = max(int(max_new_tokens), 1)
    top_p = float(top_p)
    # top_p must lie strictly inside (0, 1); 1.0 means "no nucleus filtering",
    # which is expressed by omitting the parameter (None).
    top_p = None if top_p >= 1.0 else max(top_p, 1e-2)
    top_k = int(top_k)
    # 0 conventionally means "top-k disabled"; send None instead of 0.
    top_k = top_k if top_k > 0 else None
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=float(repetition_penalty),
    )
    formatted_prompt = "Fix grammatical errors in this sentence: " + prompt
    print(f"PARAMS: {generate_kwargs}")
    print("\nPROMPT: \n\t" + formatted_prompt)
    # With details left at its default (False), text_generation returns a plain
    # str, which is what ChatInterface renders. return_full_text=False keeps the
    # instruction prefix out of the reply shown to the user.
    return client.text_generation(formatted_prompt, return_full_text=False, **generate_kwargs)
# Extra controls shown under the chat box. Gradio passes their values to
# ``generate`` positionally after (message, history), so the order here must
# match generate's parameters: temperature, max_new_tokens, top_p, top_k.
# Ranges are chosen so every selectable value is a valid generation
# parameter: at least 1 new token, top-p strictly below 1, top-k at least 1
# (text-generation-inference rejects 0 tokens, top_p outside (0, 1) and top_k <= 0).
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    # step=1 so the default of 150 is actually on the slider's grid.
    gr.Slider(
        label="Max new tokens",
        value=150,
        minimum=1,
        maximum=250,
        step=1,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.05,
        maximum=0.95,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Top-k",
        value=50,
        minimum=1,
        maximum=100,
        step=1,
        interactive=True,
        info="Limits the number of top-k tokens considered at each step",
    ),
]
# Build the chat UI around ``generate`` (with the slider controls as extra
# inputs) and start the server with the API docs page hidden.
# concurrency_limit caps how many chat requests Gradio runs at once.
demo = gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(
        show_label=False,
        show_share_button=False,
        show_copy_button=True,
        likeable=True,
        layout="panel",
    ),
    additional_inputs=additional_inputs,
    title="My Grammarly Space",
    concurrency_limit=20,
)
demo.launch(show_api=False)