File size: 2,927 Bytes
c35b520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import Callable, Optional

import gradio as gr
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Zilliz
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.llms import OpenAI

# Module-level question-answering chain shared by the UI callbacks below:
# web_loader() assigns it after indexing the documents, and query() calls it.
# It stays None until web_loader() has run to completion.
chain: Optional[Callable] = None


def web_loader(url_list, openai_key, zilliz_uri, user, password):
    """Fetch web pages, index them in Zilliz and build the global QA chain.

    Args:
        url_list: Whitespace-separated URLs (spaces or newlines) to load.
        openai_key: OpenAI API key, used for both the embeddings and the LLM.
        zilliz_uri: URI of the Zilliz instance to connect to.
        user: User name for the Zilliz connection.
        password: Password for the Zilliz connection.

    Returns:
        A status string: "please enter url list" when no URL was supplied,
        otherwise "success to load data" once the global ``chain`` is set.
    """
    # Split before validating so that whitespace-only input (or None) is
    # rejected here instead of reaching the loader with an empty URL list.
    urls = (url_list or "").split()
    if not urls:
        return "please enter url list"

    loader = WebBaseLoader(urls)
    docs = loader.load()

    # Cut the pages into non-overlapping chunks before embedding them.
    text_splitter = CharacterTextSplitter(chunk_size=1024, chunk_overlap=0)
    docs = text_splitter.split_documents(docs)
    embeddings = OpenAIEmbeddings(model="ada", openai_api_key=openai_key)

    docsearch = Zilliz.from_documents(
        docs,
        embedding=embeddings,
        connection_args={
            "uri": zilliz_uri,
            "user": user,
            "password": password,
            "secure": True,
        },
    )

    # Publish the chain through the module-level variable so that query()
    # can use it on later requests.
    global chain
    chain = RetrievalQAWithSourcesChain.from_chain_type(
        OpenAI(temperature=0, openai_api_key=openai_key),
        chain_type="map_reduce",
        retriever=docsearch.as_retriever(),
    )
    return "success to load data"


def query(question):
    """Answer a question with the chain built by ``web_loader``.

    Example question: "What is milvus?"

    Returns the chain's "answer" field, "fail to get answer" when the chain
    output has no such field, or "please load the data first" when no chain
    has been created yet.
    """
    if chain:
        result = chain(inputs={"question": question}, return_only_outputs=True)
        return result.get("answer", "fail to get answer")
    return "please load the data first"


if __name__ == "__main__":
    # Build the UI: one section to load and index URLs, one to ask questions.
    with gr.Blocks() as demo:
        gr.Markdown("<h1><center>Langchain And Zilliz Cloud Demo</center></h1>")

        # --- Data loading section -------------------------------------
        url_list_text = gr.Textbox(
            label="Url list",
            lines=3,
        )
        # The API key is a secret: mask it like the Zilliz password below
        # instead of displaying it in clear text.
        openai_key_text = gr.Textbox(label="openai api key", type="password")
        with gr.Row():
            zilliz_uri_text = gr.Textbox(label="zilliz cloud uri")
            user_text = gr.Textbox(label="user")
            password_text = gr.Textbox(label="password", type="password")
        loader_output = gr.Textbox(label="Load Status")
        loader_btn = gr.Button("WebLoader")
        # Input order must match the parameter order of web_loader().
        loader_btn.click(
            fn=web_loader,
            inputs=[
                url_list_text,
                openai_key_text,
                zilliz_uri_text,
                user_text,
                password_text,
            ],
            outputs=loader_output,
            api_name="web_load",
        )

        # --- Question answering section -------------------------------
        question_text = gr.Textbox(
            label="question",
            lines=3,
        )
        query_output = gr.Textbox(label="question answer", lines=3)
        query_btn = gr.Button("Generate")
        query_btn.click(
            fn=query,
            inputs=[question_text],
            outputs=query_output,
            api_name="generate_answer",
        )

    # Launch once the layout context has been closed, with request queuing
    # enabled. The server listens on all interfaces (0.0.0.0) and no public
    # share link is created.
    demo.queue().launch(server_name="0.0.0.0", share=False)