hugging2021 commited on
Commit
757d2c1
·
verified ·
1 Parent(s): 1d98a7e

Update vector_store.py

Browse files
Files changed (1) hide show
  1. vector_store.py +2 -2
vector_store.py CHANGED
@@ -44,7 +44,7 @@ def apply_corrections(text):
44
 
45
  # --------------------------------
46
  # Load the embedding model
47
- def get_embeddings(model_name="intfloat/multilingual-e5-large-instruct", device="cuda"):
48
  return HuggingFaceEmbeddings(
49
  model_name=model_name,
50
  model_kwargs={'device': device},
@@ -111,7 +111,7 @@ if __name__ == "__main__":
111
  parser.add_argument("--folder", type=str, default="dataset", help="Path to the folder containing the documents")
112
  parser.add_argument("--save_path", type=str, default="vector_db", help="Path to save the vector store")
113
  parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
114
- parser.add_argument("--model_name", type=str, default="intfloat/multilingual-e5-large-instruct", help="Name of the embedding model")
115
  parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device to use ('cuda' or 'cpu')")
116
 
117
  args = parser.parse_args()
 
44
 
45
  # --------------------------------
46
  # Load the embedding model
47
+ def get_embeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", device="cuda"):
48
  return HuggingFaceEmbeddings(
49
  model_name=model_name,
50
  model_kwargs={'device': device},
 
111
  parser.add_argument("--folder", type=str, default="dataset", help="Path to the folder containing the documents")
112
  parser.add_argument("--save_path", type=str, default="vector_db", help="Path to save the vector store")
113
  parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
114
+ parser.add_argument("--model_name", type=str, default="sentence-transformers/all-MiniLM-L6-v2", help="Name of the embedding model")
115
  parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device to use ('cuda' or 'cpu')")
116
 
117
  args = parser.parse_args()