Spaces:
Paused
Paused
themissingCRAM
committed on
Commit
·
071e96e
1
Parent(s):
2f46a72
chroma db
Browse files
app.py
CHANGED
|
@@ -1,12 +1,11 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import os
|
| 3 |
-
from smolagents import Tool, CodeAgent, HfApiModel,
|
| 4 |
import spaces
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
import datasets
|
| 7 |
from langchain.docstore.document import Document
|
| 8 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 9 |
-
from langchain_community.retrievers import BM25Retriever
|
| 10 |
import chromadb
|
| 11 |
from chromadb.utils import embedding_functions
|
| 12 |
|
|
@@ -43,8 +42,8 @@ class RetrieverTool(Tool):
|
|
| 43 |
"""
|
| 44 |
inputs = {
|
| 45 |
"query": {
|
| 46 |
-
"type": "string",
|
| 47 |
-
"description": "The
|
| 48 |
}
|
| 49 |
}
|
| 50 |
output_type = "string"
|
|
@@ -65,14 +64,14 @@ class RetrieverTool(Tool):
|
|
| 65 |
embedding_function=embedding_func,
|
| 66 |
metadata={"hnsw:space": "cosine"},
|
| 67 |
)
|
| 68 |
-
collection.
|
| 69 |
documents=[doc.page_content for doc in docs],
|
| 70 |
ids=[f"id{i}" for i in range(len(docs))],
|
| 71 |
)
|
| 72 |
self.collection = collection
|
| 73 |
|
| 74 |
-
def forward(self, query: str) -> str:
|
| 75 |
-
assert isinstance(query, str), "Your search query must be a string"
|
| 76 |
docs = self.collection.query(query, n_results=5)
|
| 77 |
retrieved_text = "\nRetrieved documents:\n" + "".join(
|
| 78 |
[
|
|
@@ -127,4 +126,17 @@ if __name__ == "__main__":
|
|
| 127 |
verbosity_level=10,
|
| 128 |
)
|
| 129 |
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import os
|
| 3 |
+
from smolagents import Tool, CodeAgent, HfApiModel, stream_to_gradio
|
| 4 |
import spaces
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
import datasets
|
| 7 |
from langchain.docstore.document import Document
|
| 8 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
|
| 9 |
import chromadb
|
| 10 |
from chromadb.utils import embedding_functions
|
| 11 |
|
|
|
|
| 42 |
"""
|
| 43 |
inputs = {
|
| 44 |
"query": {
|
| 45 |
+
"type": "List[string]",
|
| 46 |
+
"description": "The python list of queries to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
|
| 47 |
}
|
| 48 |
}
|
| 49 |
output_type = "string"
|
|
|
|
| 64 |
embedding_function=embedding_func,
|
| 65 |
metadata={"hnsw:space": "cosine"},
|
| 66 |
)
|
| 67 |
+
collection.upsert(
|
| 68 |
documents=[doc.page_content for doc in docs],
|
| 69 |
ids=[f"id{i}" for i in range(len(docs))],
|
| 70 |
)
|
| 71 |
self.collection = collection
|
| 72 |
|
| 73 |
+
def forward(self, query: list[str]) -> str:
|
| 74 |
+
assert isinstance(query, list[str]), "Your search query must be a string"
|
| 75 |
docs = self.collection.query(query, n_results=5)
|
| 76 |
retrieved_text = "\nRetrieved documents:\n" + "".join(
|
| 77 |
[
|
|
|
|
| 126 |
verbosity_level=10,
|
| 127 |
)
|
| 128 |
|
| 129 |
+
def enter_message(new_message, conversation_history):
    """Stream one chat turn into the conversation history.

    The user's text is appended to ``conversation_history`` as a
    ``gr.ChatMessage`` with role ``"user"``; afterwards every message
    produced by ``stream_to_gradio(agent, new_message)`` is appended in
    turn.

    Args:
        new_message: Text submitted by the user.
        conversation_history: List of chat messages; mutated in place.

    Yields:
        A ``("", conversation_history)`` pair after each append: an empty
        string as the first output value and the updated history as the
        second.
    """
    user_turn = gr.ChatMessage(role="user", content=new_message)
    conversation_history.append(user_turn)
    # Emit the user's turn before the agent run is started.
    yield "", conversation_history

    for agent_turn in stream_to_gradio(agent, new_message):
        conversation_history.append(agent_turn)
        # Re-emit after every streamed message so progress is visible.
        yield "", conversation_history
|
| 135 |
+
|
| 136 |
+
# Chat UI: a message-style chatbot pane, a multi-line input box and a
# "reply" button that feeds the box's text through enter_message.
with gr.Blocks() as b:
    chatbot = gr.Chatbot(type="messages", height=1000)
    textbox = gr.Textbox(lines=3, label="")
    button = gr.Button("reply")
    # enter_message receives (textbox, chatbot) and its yielded pairs are
    # written back to the same two components.
    button.click(
        fn=enter_message,
        inputs=[textbox, chatbot],
        outputs=[textbox, chatbot],
    )
b.launch()
# Alternative left disabled by the author:
# GradioUI(agent).launch()
|