File size: 1,055 Bytes
f745223
e8fb838
1ffd977
 
13a089e
e8fb838
 
f745223
cbcb343
f745223
 
 
d8a82cd
52c453e
f745223
13a089e
 
1ffd977
3988351
13a089e
 
 
 
 
 
 
 
 
 
1ffd977
 
f57923a
 
1ffd977
 
13a089e
 
1ffd977
13a089e
1ffd977
 
 
2334dc1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import torch
import gradio as gr

from strings import TITLE, ABSTRACT 
from gen import get_pretrained_models, get_output, setup_model_parallel

# Fake a single-process "distributed" setup (rank 0 of world size 1) so
# torch.distributed can initialise without a real multi-node launcher.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "50505"

# Initialise model parallelism and load the 7B checkpoint once at import time.
local_rank, world_size = setup_model_parallel()
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)

# Running transcript of the conversation. chat() appends to it on every turn,
# but nothing in this file reads it back (presumably for external use — TODO confirm).
history = []

def chat(user_input):
    """Run one chat turn.

    Queries the model for a full reply, records both sides of the exchange
    in the module-level ``history`` transcript, then streams the reply back
    one word at a time as ``[(user_input, partial_reply)]`` pairs for the
    Gradio Chatbot component.
    """
    reply = get_output(generator, user_input)[0]

    # Log the exchange in the shared transcript (never read in this file).
    history.append({"role": "user", "content": user_input})
    history.append({"role": "system", "content": reply})

    # Stream cumulative prefixes of the reply; each yield carries a trailing
    # space, matching the original word-by-word concatenation exactly.
    spoken = []
    for word in reply.split(" "):
        spoken.append(word)
        yield [(user_input, " ".join(spoken) + " ")]

# Build the Gradio UI: a markdown header, a chat window, and a one-line
# prompt box wired to the streaming chat() generator.
with gr.Blocks() as demo:
    gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")
    
    chatbot = gr.Chatbot()
    textbox = gr.Textbox(placeholder="Enter a prompt")

    # Pressing Enter in the textbox streams chat()'s yields into the chatbot.
    textbox.submit(chat, textbox, chatbot)

# queue() is required for generator (streaming) handlers;
# api_open=False disables the auto-generated public API endpoints.
demo.queue(api_open=False).launch()