traopia commited on
Commit
743eef3
·
1 Parent(s): 997d13b
Files changed (1) hide show
  1. src/use_llm.py +3 -8
src/use_llm.py CHANGED
@@ -31,17 +31,12 @@ def main_generate(prompt, model=DEFAULT_MODEL, system_prompt="You are a helpful
31
 
32
 
33
 
34
- HF_TOKEN = os.getenv("HF_TOKEN")
35
- MODEL_ID = "thenlper/gte-large" # embedding model
36
 
37
- client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)
38
 
39
  def get_embeddings(texts):
40
  if isinstance(texts, str):
41
  texts = [texts]
42
- embeddings = []
43
- for text in texts:
44
- response = client.text_to_vector(text)
45
- # response is usually a list of floats (the embedding vector)
46
- embeddings.append(response)
47
  return embeddings
 
31
 
32
 
33
 
34
+ from sentence_transformers import SentenceTransformer
 
35
 
36
+ model = SentenceTransformer("thenlper/gte-large") # downloaded from Hugging Face
37
 
38
  def get_embeddings(texts):
39
  if isinstance(texts, str):
40
  texts = [texts]
41
+ embeddings = model.encode(texts, convert_to_numpy=True)
 
 
 
 
42
  return embeddings