| import os | |
| import time | |
| import requests | |
| import gradio as gr | |
| from huggingface_hub import get_inference_endpoint | |
# Deployment configuration, read once from the environment at import time.
# os.environ.get yields None for any variable that is not set.
endpoint_name = os.environ.get('ENDPOINT_NAME')
endpoint_url = os.environ.get('ENDPOINT_URL')
personal_secret_token = os.environ.get('PERSONAL_HF_TOKEN')

# Strings used to format the prompt: the separator joined between turns and
# the role markers prepended to turns.
turn_breaker = os.environ.get('TURN_BREAKER')
system_symbol = os.environ.get('SYSTEM_SYMBOL')
user_symbol = os.environ.get('USER_SYMBOL')
assistant_symbol = os.environ.get('ASSISTANT_SYMBOL')

# HTTP headers sent with every JSON request to the inference endpoint; the
# personal token is passed as a bearer credential.
headers = {
    "Accept": "application/json",
    "Authorization": "Bearer {}".format(personal_secret_token),
    "Content-Type": "application/json",
}
def query(payload, timeout=300):
    """POST ``payload`` as JSON to the inference endpoint and return the decoded reply.

    Args:
        payload: JSON-serialisable request body, e.g.
            ``{"inputs": ..., "parameters": ...}``.
        timeout: Seconds to wait for the server, forwarded to ``requests.post``.
            Without a timeout ``requests`` can block indefinitely on a stalled
            connection, freezing the chat UI. Pass ``None`` to wait forever
            (the previous behaviour).

    Returns:
        The reply body decoded with ``response.json()``.

    Raises:
        requests.RequestException: on connection problems, including
            ``requests.Timeout`` when ``timeout`` elapses.
        ValueError: (a subclass of it) if the reply body is not valid JSON.
    """
    response = requests.post(
        endpoint_url, headers=headers, json=payload, timeout=timeout
    )
    return response.json()
def get_status():
    """Return the status of the configured Inference Endpoint (e.g. ``"running"``).

    The endpoint named by ``endpoint_name`` is looked up again on every call,
    authenticated with ``personal_secret_token``.
    """
    return get_inference_endpoint(endpoint_name, token=personal_secret_token).status
# Upper bound on how long respond() waits for a sleeping endpoint to reach
# the "running" state before giving up.
_WAKE_UP_TIMEOUT_S = 600
# Delay between status polls while waiting for the endpoint.
_POLL_INTERVAL_S = 1


def _wait_until_running(progress):
    """Wake the endpoint if it is not running and block until it is.

    Raises:
        gr.Error: if the status is still not "running" after _WAKE_UP_TIMEOUT_S.
    """
    if get_status() == "running":
        return
    # The ping only exists to nudge the endpoint awake; its reply is never
    # used. While the endpoint boots it may answer with an error or a
    # non-JSON body, so a failure here must not abort the response.
    # Readiness is judged by the polling loop below instead.
    try:
        query({"inputs": "wake up!"})
    except (requests.RequestException, ValueError):
        pass
    progress(0.25, desc="Waking up model")
    deadline = time.monotonic() + _WAKE_UP_TIMEOUT_S
    while (status := get_status()) != "running":
        # Without a deadline a failed/paused endpoint would spin here forever.
        if time.monotonic() >= deadline:
            raise gr.Error(
                f"Inference endpoint '{endpoint_name}' is not running after "
                f"{_WAKE_UP_TIMEOUT_S} s (last status: {status!r})."
            )
        time.sleep(_POLL_INTERVAL_S)


def _build_prompt(message, history, system_message):
    """Flatten the system message, prior turns and new message into one prompt.

    Each turn is prefixed with its role marker and the turns are joined with
    ``turn_breaker``. Unset role markers (None) are treated as empty strings.
    """
    system_prefix = system_symbol or ""
    user_prefix = user_symbol or ""
    assistant_prefix = assistant_symbol or ""
    turns = [system_prefix + system_message]
    for turn in history:
        user_text, assistant_text = turn[0], turn[1]
        # Either side of a past turn may be empty/None; skip those.
        if user_text:
            turns.append(user_prefix + user_text)
        if assistant_text:
            turns.append(assistant_prefix + assistant_text)
    # NOTE(review): the prompt ends with the user's turn and no assistant
    # marker. If the model's chat format expects a trailing assistant prefix
    # to start its reply, append it here; confirm against the model template.
    turns.append(user_prefix + message)
    return turn_breaker.join(turns)


def _extract_text(response):
    """Turn the endpoint's JSON reply into the text ChatInterface displays.

    Hugging Face text-generation endpoints typically reply with
    ``[{"generated_text": "..."}]`` (or a bare dict); that field is unwrapped.
    A dict carrying an ``"error"`` key is raised as ``gr.Error`` so the UI
    shows it. Any other shape is returned unchanged.
    """
    first = response[0] if isinstance(response, list) and response else response
    if isinstance(first, dict) and "generated_text" in first:
        return first["generated_text"]
    if isinstance(response, dict) and "error" in response:
        raise gr.Error(f"Inference endpoint returned an error: {response['error']}")
    return response


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_new_tokens,
    temperature,
    top_p,
    progress=gr.Progress(),
):
    """Produce the assistant's reply for one chat turn (gr.ChatInterface callback).

    Wakes the inference endpoint if needed, builds a single prompt string from
    the system message, prior turns and the new message, sends it with the
    chosen generation settings, and returns the generated text.

    Args:
        message: The user's new message.
        history: Prior (user, assistant) pairs; falsy entries are skipped.
        system_message: Instruction text placed first in the prompt.
        max_new_tokens: Generation length cap forwarded to the endpoint.
        temperature: Sampling temperature; sampling is off when it is <= 0.
        top_p: Nucleus-sampling probability mass.
        progress: Gradio progress tracker injected by the framework.

    Returns:
        The ``generated_text`` of the reply when present, otherwise the
        decoded JSON reply unchanged.

    Raises:
        gr.Error: if the endpoint does not come up in time, or if it replies
            with an error payload.
    """
    progress(0, desc="Starting")
    _wait_until_running(progress)
    progress(0.5, desc="Generating")
    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        temperature=temperature,
    )
    response = query({
        "inputs": _build_prompt(message, history, system_message),
        "parameters": generation_kwargs,
    })
    progress(1, desc="Generating")
    return _extract_text(response)
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI driven by respond(). The additional inputs are passed to respond()
# after (message, history), in list order, so they must line up with its
# parameters: system_message, max_new_tokens, temperature, top_p.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        # Default persona (Chinese): play a cheerful, positive character named
        # He Yingxu, 26 years old, a programmer, chatting with friends.
        gr.Textbox(
            label="System message",
            value="请你扮演一个开心,积极的角色,名叫贺英旭,今年26岁,工作是程序员。你需要以这个身份和朋友们进行对话。",
        ),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=64),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.7,
        ),
    ],
    show_progress="full",
)
# Start the web server only when executed as a script, not on import.
if "__main__" == __name__:
    demo.launch()