Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import uuid
|
2 |
import chromadb
|
|
|
3 |
from langchain.vectorstores import Chroma
|
4 |
from langchain.embeddings import HuggingFaceEmbeddings
|
5 |
from langchain.retrievers import ContextualCompressionRetriever
|
@@ -7,8 +8,15 @@ from langchain.retrievers.document_compressors import CrossEncoderReranker
|
|
7 |
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
8 |
import gradio as gr
|
9 |
|
|
|
|
|
|
|
|
|
10 |
# Initialize embedding model
|
11 |
-
embedding_model = HuggingFaceEmbeddings(
|
|
|
|
|
|
|
12 |
|
13 |
# Initialize ChromaDB client and collection
|
14 |
chroma_client = chromadb.PersistentClient(path="./chroma_db")
|
@@ -19,7 +27,10 @@ vectorstore = Chroma(
|
|
19 |
)
|
20 |
|
21 |
# Initialize reranker
|
22 |
-
reranker = HuggingFaceCrossEncoder(
|
|
|
|
|
|
|
23 |
compressor = CrossEncoderReranker(model=reranker, top_n=5)
|
24 |
retriever = vectorstore.as_retriever(search_kwargs={"k": 10}) # Retrieve 10 candidates (2 x the reranker's top_n) for the reranker to narrow down
|
25 |
compression_retriever = ContextualCompressionRetriever(
|
|
|
1 |
import uuid
|
2 |
import chromadb
|
3 |
+
import torch
|
4 |
from langchain.vectorstores import Chroma
|
5 |
from langchain.embeddings import HuggingFaceEmbeddings
|
6 |
from langchain.retrievers import ContextualCompressionRetriever
|
|
|
8 |
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
9 |
import gradio as gr
|
10 |
|
11 |
+
# Set device to GPU if available, else CPU
|
12 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
+
print(f"Using device: {device}")
|
14 |
+
|
15 |
# Initialize embedding model
|
16 |
+
embedding_model = HuggingFaceEmbeddings(
|
17 |
+
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
18 |
+
model_kwargs={"device": device}
|
19 |
+
)
|
20 |
|
21 |
# Initialize ChromaDB client and collection
|
22 |
chroma_client = chromadb.PersistentClient(path="./chroma_db")
|
|
|
27 |
)
|
28 |
|
29 |
# Initialize reranker
|
30 |
+
reranker = HuggingFaceCrossEncoder(
|
31 |
+
model_name="BAAI/bge-reranker-base",
|
32 |
+
model_kwargs={"device": device}
|
33 |
+
)
|
34 |
compressor = CrossEncoderReranker(model=reranker, top_n=5)
|
35 |
retriever = vectorstore.as_retriever(search_kwargs={"k": 10}) # Retrieve 10 candidates (2 x the reranker's top_n) for the reranker to narrow down
|
36 |
compression_retriever = ContextualCompressionRetriever(
|