Update rag/RAG.py
rag/RAG.py (+19 -9)
--- a/rag/RAG.py
+++ b/rag/RAG.py
@@ -1,6 +1,7 @@
 from google import genai
 from google.genai import types
 import numpy as np
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 import os
 from dotenv import load_dotenv
@@ -26,23 +27,32 @@ class RAG:
         except Exception as e:
             raise ValueError(f"an error occurred: {e}")

-    def generate_embedding(self,text,task_type=None):
+    def generate_embedding(self, text, task_type=None):
         try:
-            if …
+            if not task_type:
                 task_type = self.TASK_TYPE
-
+
             chunks = self.split_text(text)
-            for i in range(0,len(chunks),self.MAX_BATCH_SIZE):
+            batches = [chunks[i:i + self.MAX_BATCH_SIZE] for i in range(0, len(chunks), self.MAX_BATCH_SIZE)]
+
+            def embed_batch(batch):
                 response = client.models.embed_content(
                     model=self.MODEL,
-                    contents=…
+                    contents=batch,
                     config=types.EmbedContentConfig(task_type=task_type)
                 )
-                for …
-                …
+                return [embedding.values for embedding in response.embeddings]
+
+            embeddings = []
+            with ThreadPoolExecutor(max_workers=100) as executor:
+                futures = [executor.submit(embed_batch, batch) for batch in batches]
+                for future in as_completed(futures):
+                    embeddings.extend(future.result())
+
             return {"embeddings": embeddings, "chunks": chunks}, 200
-
-
+
+        except Exception as e:
+            return {"an error occurred": str(e)}, 500
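The rewrite swaps a sequential loop over batches for a ThreadPoolExecutor, so each embed_content call handles one batch of up to MAX_BATCH_SIZE chunks on its own thread. One caveat worth noting: as_completed yields futures in completion order, not submission order, so the flattened embeddings list is not guaranteed to stay aligned with chunks. Below is a minimal, self-contained sketch of the same batch-and-thread-pool pattern that preserves order by using executor.map instead; fake_embed is a hypothetical stand-in for the embed_content call so the sketch runs without credentials, and the MAX_BATCH_SIZE and max_workers values are assumptions, not values taken from this repo.

# A minimal, order-preserving sketch of the batch + thread-pool pattern.
# fake_embed is a hypothetical stand-in for client.models.embed_content;
# MAX_BATCH_SIZE and max_workers are assumed values, not from the RAG class.
from concurrent.futures import ThreadPoolExecutor

MAX_BATCH_SIZE = 100

def fake_embed(batch):
    # Pretend each chunk embeds to a one-dimensional vector.
    return [[float(len(chunk))] for chunk in batch]

def embed_in_parallel(chunks, max_workers=8):
    batches = [chunks[i:i + MAX_BATCH_SIZE]
               for i in range(0, len(chunks), MAX_BATCH_SIZE)]
    embeddings = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map runs batches concurrently but yields results in
        # submission order, so embeddings stays aligned with chunks.
        for batch_result in executor.map(fake_embed, batches):
            embeddings.extend(batch_result)
    return embeddings

if __name__ == "__main__":
    chunks = [f"chunk {n}" for n in range(250)]
    vectors = embed_in_parallel(chunks)
    assert len(vectors) == len(chunks)  # order and count both preserved

If completion-order processing via as_completed is preferred, submitting (index, batch) pairs and writing each result into a preallocated slot restores alignment; and a pool far smaller than the commit's max_workers=100 is usually enough to saturate a rate-limited embedding API.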