Y-Mangoes committed on
Commit
ac90524
·
verified ·
1 Parent(s): 4e06373

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import uuid
2
  import chromadb
 
3
  from langchain.vectorstores import Chroma
4
  from langchain.embeddings import HuggingFaceEmbeddings
5
  from langchain.retrievers import ContextualCompressionRetriever
@@ -7,8 +8,15 @@ from langchain.retrievers.document_compressors import CrossEncoderReranker
7
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
8
  import gradio as gr
9
 
 
 
 
 
10
  # Initialize embedding model
11
- embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
 
 
12
 
13
  # Initialize ChromaDB client and collection
14
  chroma_client = chromadb.PersistentClient(path="./chroma_db")
@@ -19,7 +27,10 @@ vectorstore = Chroma(
19
  )
20
 
21
  # Initialize reranker
22
- reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
 
 
 
23
  compressor = CrossEncoderReranker(model=reranker, top_n=5)
24
  retriever = vectorstore.as_retriever(search_kwargs={"k": 10}) # Retrieve 2k initially
25
  compression_retriever = ContextualCompressionRetriever(
 
1
  import uuid
2
  import chromadb
3
+ import torch
4
  from langchain.vectorstores import Chroma
5
  from langchain.embeddings import HuggingFaceEmbeddings
6
  from langchain.retrievers import ContextualCompressionRetriever
 
8
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
9
  import gradio as gr
10
 
11
+ # Set device to GPU if available, else CPU
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ print(f"Using device: {device}")
14
+
15
  # Initialize embedding model
16
+ embedding_model = HuggingFaceEmbeddings(
17
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
18
+ model_kwargs={"device": device}
19
+ )
20
 
21
  # Initialize ChromaDB client and collection
22
  chroma_client = chromadb.PersistentClient(path="./chroma_db")
 
27
  )
28
 
29
  # Initialize reranker
30
+ reranker = HuggingFaceCrossEncoder(
31
+ model_name="BAAI/bge-reranker-base",
32
+ model_kwargs={"device": device}
33
+ )
34
  compressor = CrossEncoderReranker(model=reranker, top_n=5)
35
  retriever = vectorstore.as_retriever(search_kwargs={"k": 10}) # Retrieve 2k initially
36
  compression_retriever = ContextualCompressionRetriever(