Spaces:
Build error
Build error
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -374,52 +374,107 @@ async def start_questions(request: Request):
|
|
| 374 |
# -----------------------------
|
| 375 |
# 1. Define Custom BERT Embedding Model
|
| 376 |
# -----------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
class BERTEmbeddings(Embeddings):
|
| 378 |
-
def __init__(self, model_name='bert-base-uncased'):
|
|
|
|
| 379 |
self.tokenizer = BertTokenizer.from_pretrained(model_name)
|
| 380 |
self.model = BertModel.from_pretrained(model_name)
|
| 381 |
self.model.eval() # Set model to eval mode
|
|
|
|
|
|
|
| 382 |
|
| 383 |
def embed_documents(self, texts):
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
| 385 |
with torch.no_grad():
|
| 386 |
outputs = self.model(**inputs)
|
|
|
|
|
|
|
| 387 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
| 388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
return embeddings.cpu().numpy()
|
| 390 |
|
| 391 |
def embed_query(self, text):
|
|
|
|
| 392 |
return self.embed_documents([text])[0]
|
| 393 |
|
| 394 |
|
| 395 |
# -----------------------------
|
| 396 |
# 2. Initialize Embedding Model
|
| 397 |
# -----------------------------
|
| 398 |
-
embedding_model = BERTEmbeddings()
|
| 399 |
-
|
| 400 |
|
| 401 |
# -----------------------------
|
| 402 |
-
#
|
| 403 |
# -----------------------------
|
| 404 |
-
docs = [
|
| 405 |
-
Document(page_content="Mercedes Sosa released many albums between 2000 and 2009.", metadata={"id": 1}),
|
| 406 |
-
Document(page_content="She was a prominent Argentine folk singer.", metadata={"id": 2}),
|
| 407 |
-
Document(page_content="Her album 'Al Despertar' was released in 1998.", metadata={"id": 3}),
|
| 408 |
-
Document(page_content="She continued releasing music well into the 2000s.", metadata={"id": 4}),
|
| 409 |
-
]
|
| 410 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
# -----------------------------
|
| 413 |
-
#
|
| 414 |
# -----------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
vector_store = FAISS.from_documents(docs, embedding_model)
|
| 416 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
|
| 418 |
|
| 419 |
# -----------------------------
|
| 420 |
# 6. Create LangChain Retriever Tool
|
| 421 |
# -----------------------------
|
| 422 |
-
|
|
|
|
| 423 |
|
| 424 |
question_retriever_tool = create_retriever_tool(
|
| 425 |
retriever=retriever,
|
|
@@ -1052,6 +1107,8 @@ def process_all_tasks(tasks: list):
|
|
| 1052 |
## Langgraph
|
| 1053 |
|
| 1054 |
# Build graph function
|
|
|
|
|
|
|
| 1055 |
provider = "huggingface"
|
| 1056 |
|
| 1057 |
model_config = {
|
|
|
|
| 374 |
# -----------------------------
|
| 375 |
# 1. Define Custom BERT Embedding Model
|
| 376 |
# -----------------------------
|
| 377 |
+
import torch
|
| 378 |
+
import torch.nn.functional as F
|
| 379 |
+
from transformers import BertTokenizer, BertModel
|
| 380 |
+
from langchain.embeddings import Embeddings
|
| 381 |
+
|
| 382 |
class BERTEmbeddings(Embeddings):
|
| 383 |
+
def __init__(self, model_name='bert-base-uncased', device='cpu'):
|
| 384 |
+
# Initialize the tokenizer and model
|
| 385 |
self.tokenizer = BertTokenizer.from_pretrained(model_name)
|
| 386 |
self.model = BertModel.from_pretrained(model_name)
|
| 387 |
self.model.eval() # Set model to eval mode
|
| 388 |
+
self.device = device
|
| 389 |
+
self.model.to(self.device) # Move model to the specified device (CPU or GPU)
|
| 390 |
|
| 391 |
def embed_documents(self, texts):
|
| 392 |
+
# Tokenize the input texts
|
| 393 |
+
inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
| 394 |
+
inputs = {key: value.to(self.device) for key, value in inputs.items()} # Move inputs to the specified device
|
| 395 |
+
|
| 396 |
with torch.no_grad():
|
| 397 |
outputs = self.model(**inputs)
|
| 398 |
+
|
| 399 |
+
# Get the embeddings by averaging the last hidden state across tokens
|
| 400 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
| 401 |
+
|
| 402 |
+
# Normalize embeddings for cosine similarity
|
| 403 |
+
embeddings = F.normalize(embeddings, p=2, dim=1)
|
| 404 |
+
|
| 405 |
+
# Return the embeddings as numpy array
|
| 406 |
return embeddings.cpu().numpy()
|
| 407 |
|
| 408 |
def embed_query(self, text):
|
| 409 |
+
# Embed a single query (text)
|
| 410 |
return self.embed_documents([text])[0]
|
| 411 |
|
| 412 |
|
| 413 |
# -----------------------------
|
| 414 |
# 2. Initialize Embedding Model
|
| 415 |
# -----------------------------
|
|
|
|
|
|
|
| 416 |
|
| 417 |
# -----------------------------
|
| 418 |
+
# Create FAISS Vector Store
|
| 419 |
# -----------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
+
class MyVectorStore:
|
| 422 |
+
def __init__(self, index: faiss.Index):
|
| 423 |
+
self.index = index
|
| 424 |
+
|
| 425 |
+
def save_local(self, path: str):
|
| 426 |
+
# Save the FAISS index to the specified file
|
| 427 |
+
faiss.write_index(self.index, "/home/wendy/Downloads")
|
| 428 |
+
print(f"Index saved to {path}")
|
| 429 |
+
|
| 430 |
+
@classmethod
|
| 431 |
+
def load_local(cls, path: str):
|
| 432 |
+
# Load the FAISS index from the specified file
|
| 433 |
+
index = faiss.read_index(path)
|
| 434 |
+
return cls(index)
|
| 435 |
|
| 436 |
# -----------------------------
|
| 437 |
+
# 3. Prepare Documents
|
| 438 |
# -----------------------------
|
| 439 |
+
# Define the URL where the JSON file is hosted
|
| 440 |
+
url = "https://agents-course-unit4-scoring.hf.space/questions"
|
| 441 |
+
|
| 442 |
+
# Download the JSON file from the URL
|
| 443 |
+
response = requests.get(url)
|
| 444 |
+
response.raise_for_status() # Ensure that the request was successful
|
| 445 |
+
|
| 446 |
+
# Parse the JSON data
|
| 447 |
+
docs = json.loads(response.text)
|
| 448 |
+
|
| 449 |
+
# Assuming the JSON structure has a 'text' field for each document
|
| 450 |
+
texts = [doc['text'] for doc in docs] # Extract text from JSON
|
| 451 |
+
|
| 452 |
+
# Initialize the embedding model
|
| 453 |
+
embedding_model = BERTEmbeddings()
|
| 454 |
+
|
| 455 |
+
# Generate embeddings for each document
|
| 456 |
+
embeddings = [embedding_model.encode(text) for text in texts]
|
| 457 |
+
|
| 458 |
+
# Create the FAISS index
|
| 459 |
vector_store = FAISS.from_documents(docs, embedding_model)
|
| 460 |
+
|
| 461 |
+
# Save the FAISS index
|
| 462 |
+
vector_store = MyVectorStore(index)
|
| 463 |
+
vector_store.save_local("/home/wt/Downloads/faiss_index.index")
|
| 464 |
+
|
| 465 |
+
# Load the FAISS index later
|
| 466 |
+
loaded_vector_store = MyVectorStore.load_local("faiss_index.index")
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
|
| 471 |
|
| 472 |
|
| 473 |
# -----------------------------
|
| 474 |
# 6. Create LangChain Retriever Tool
|
| 475 |
# -----------------------------
|
| 476 |
+
|
| 477 |
+
retriever = FAISS.load_local("faiss_index.index", embedding_model).as_retriever()
|
| 478 |
|
| 479 |
question_retriever_tool = create_retriever_tool(
|
| 480 |
retriever=retriever,
|
|
|
|
| 1107 |
## Langgraph
|
| 1108 |
|
| 1109 |
# Build graph function
|
| 1110 |
+
vector_store = vector_store.save_local("faiss_index")
|
| 1111 |
+
|
| 1112 |
provider = "huggingface"
|
| 1113 |
|
| 1114 |
model_config = {
|