from omegaconf import OmegaConf
from query import VectaraQuery
import os

from PIL import Image
import gradio as gr
from huggingface_hub import InferenceClient

import logging
logging.basicConfig(level=logging.DEBUG)

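# The stock huggingface_hub InferenceClient handler from the Gradio ChatInterface
# template is kept below, commented out, for reference.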

# """
# For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
# """
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


# def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p):
#     messages = [{"role": "system", "content": system_message}]

#     for val in history:
#         if val[0]:
#             messages.append({"role": "user", "content": val[0]})
#         if val[1]:
#             messages.append({"role": "assistant", "content": val[1]})

#     messages.append({"role": "user", "content": message})

#     response = ""

#     for message in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
#         token = message.choices[0].delta.content

#         response += token
#         yield response

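# Environment variables arrive as strings; treat "true" (case-insensitive) as True.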
def isTrue(x) -> bool:
    if isinstance(x, bool):
        return x
    return x.strip().lower() == 'true'

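# Build the runtime configuration from environment variables (e.g. Space secrets).
# 'corpus_ids' is a comma-separated list; 'streaming' and 'prompt_name' are optional.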
corpus_ids = str(os.environ['corpus_ids']).split(',')
cfg = OmegaConf.create({
    'customer_id': str(os.environ['customer_id']),
    'corpus_ids': corpus_ids,
    'api_key': str(os.environ['api_key']),
    'title': os.environ['title'],
    'description': os.environ['description'],
    'source_data_desc': os.environ['source_data_desc'],
    'streaming': isTrue(os.environ.get('streaming', False)),
    'prompt_name': os.environ.get('prompt_name', None)
})

cfg.description = f'''
                    <h4 style="text-align: center;">{cfg.description}</h4>
                  '''

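# Query client for the configured Vectara corpora (implemented in query.py).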
vq = VectaraQuery(cfg.api_key, cfg.customer_id, cfg.corpus_ids, cfg.prompt_name)


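# Chat handler for gr.ChatInterface: streams partial output chunk by chunk when
# cfg.streaming is set, otherwise yields the complete response once.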
def respond(message, history):
    if cfg.streaming:
        # Call stream response and stream output
        stream = vq.submit_query_streaming(message)

        outputs = ""
        for output in stream:
            outputs += output
            yield outputs
    else:
        # Call non-stream response and return message output
        response = vq.submit_query(message)
        logging.debug(f"Chat response: {response}")
        logging.debug(f"Chat response type: {type(response)}")
        yield response

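# Simple echo handler; not wired into the interface below.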
def random_fun(message, history):
    return message + '!'


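# Build the chat UI. Because respond is a generator, ChatInterface renders its
# yielded text incrementally.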
demo = gr.ChatInterface(respond, title=cfg.title, description=cfg.description)

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# demo = gr.ChatInterface(respond)


if __name__ == "__main__":
    demo.launch()