wt002 committed on
Commit
9ac015d
·
verified ·
1 Parent(s): dae11a5

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +72 -15
agent.py CHANGED
@@ -374,52 +374,107 @@ async def start_questions(request: Request):
374
  # -----------------------------
375
  # 1. Define Custom BERT Embedding Model
376
  # -----------------------------
 
 
 
 
 
377
  class BERTEmbeddings(Embeddings):
378
- def __init__(self, model_name='bert-base-uncased'):
 
379
  self.tokenizer = BertTokenizer.from_pretrained(model_name)
380
  self.model = BertModel.from_pretrained(model_name)
381
  self.model.eval() # Set model to eval mode
 
 
382
 
383
  def embed_documents(self, texts):
384
- inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
 
 
 
385
  with torch.no_grad():
386
  outputs = self.model(**inputs)
 
 
387
  embeddings = outputs.last_hidden_state.mean(dim=1)
388
- embeddings = F.normalize(embeddings, p=2, dim=1) # Normalize for cosine similarity
 
 
 
 
389
  return embeddings.cpu().numpy()
390
 
391
  def embed_query(self, text):
 
392
  return self.embed_documents([text])[0]
393
 
394
 
395
  # -----------------------------
396
  # 2. Initialize Embedding Model
397
  # -----------------------------
398
- embedding_model = BERTEmbeddings()
399
-
400
 
401
  # -----------------------------
402
- # 3. Prepare Documents
403
  # -----------------------------
404
- docs = [
405
- Document(page_content="Mercedes Sosa released many albums between 2000 and 2009.", metadata={"id": 1}),
406
- Document(page_content="She was a prominent Argentine folk singer.", metadata={"id": 2}),
407
- Document(page_content="Her album 'Al Despertar' was released in 1998.", metadata={"id": 3}),
408
- Document(page_content="She continued releasing music well into the 2000s.", metadata={"id": 4}),
409
- ]
410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
 
412
  # -----------------------------
413
- # 4. Create FAISS Vector Store
414
  # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  vector_store = FAISS.from_documents(docs, embedding_model)
416
- vector_store.save_local("faiss_index")
 
 
 
 
 
 
 
 
 
 
417
 
418
 
419
  # -----------------------------
420
  # 6. Create LangChain Retriever Tool
421
  # -----------------------------
422
- retriever = vector_store.as_retriever()
 
423
 
424
  question_retriever_tool = create_retriever_tool(
425
  retriever=retriever,
@@ -1052,6 +1107,8 @@ def process_all_tasks(tasks: list):
1052
  ## Langgraph
1053
 
1054
  # Build graph function
 
 
1055
  provider = "huggingface"
1056
 
1057
  model_config = {
 
374
  # -----------------------------
375
  # 1. Define Custom BERT Embedding Model
376
  # -----------------------------
377
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
from langchain.embeddings import Embeddings


class BERTEmbeddings(Embeddings):
    """Embedding adapter that turns texts into L2-normalized BERT vectors.

    Each text is tokenized, run through the BERT encoder, and reduced to a
    single vector by averaging the last hidden state over its real
    (non-padding) tokens.
    """

    def __init__(self, model_name='bert-base-uncased', device='cpu'):
        """Load the tokenizer and encoder and place the encoder on `device`.

        Args:
            model_name: Name or path passed to `from_pretrained`.
            device: Torch device string the model and inputs are moved to.
        """
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.model.eval()  # Set model to eval mode
        self.device = device
        self.model.to(self.device)  # Move model to the specified device (CPU or GPU)

    def embed_documents(self, texts):
        """Embed a batch of texts.

        Args:
            texts: Sequence of strings to embed.

        Returns:
            A 2-D numpy array with one L2-normalized row per input text.
            An empty input yields an array with zero rows.
        """
        texts = list(texts)
        if not texts:
            # The tokenizer rejects an empty batch; return an empty matrix instead.
            return torch.empty((0, self.model.config.hidden_size)).numpy()

        # Tokenize the input texts, truncating to BERT's 512-token limit
        inputs = self.tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512,
        )
        # Move inputs to the same device as the model
        inputs = {key: value.to(self.device) for key, value in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        hidden = outputs.last_hidden_state

        # Average only over real tokens. A plain mean over dim=1 would also
        # average the padding positions, so a text's vector would change with
        # whatever it happened to be batched with (and embed_query, which is
        # never padded, would disagree with embed_documents).
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
        embeddings = (hidden * mask).sum(dim=1) / token_counts

        # Normalize embeddings for cosine similarity
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # Return the embeddings as numpy array
        return embeddings.cpu().numpy()

    def embed_query(self, text):
        """Embed a single query string and return its vector."""
        return self.embed_documents([text])[0]
411
 
412
 
413
  # -----------------------------
414
  # 2. Initialize Embedding Model
415
  # -----------------------------
 
 
416
 
417
  # -----------------------------
418
+ # Create FAISS Vector Store
419
  # -----------------------------
 
 
 
 
 
 
420
 
421
class MyVectorStore:
    """Minimal wrapper that saves and loads a raw FAISS index file."""

    def __init__(self, index: faiss.Index):
        """Wrap an existing FAISS index."""
        self.index = index

    def save_local(self, path: str):
        """Write the wrapped index to `path`.

        The previous version ignored `path` and always wrote to a hard-coded
        directory, so the printed location never matched the real one.
        """
        faiss.write_index(self.index, path)
        print(f"Index saved to {path}")

    @classmethod
    def load_local(cls, path: str):
        """Read a FAISS index from `path` and return it wrapped in this class."""
        index = faiss.read_index(path)
        return cls(index)
435
 
436
  # -----------------------------
437
+ # 3. Prepare Documents
438
  # -----------------------------
439
# Endpoint that serves the question set as a JSON array
url = "https://agents-course-unit4-scoring.hf.space/questions"

# Download the payload; the timeout keeps start-up from hanging on a dead endpoint
response = requests.get(url, timeout=30)
response.raise_for_status()  # Ensure that the request was successful

# Parse the JSON data into a list of records
records = response.json()

# NOTE(review): the original code assumed each record has a 'text' field.
# 'question' is accepted as a fallback on the assumption that this endpoint
# names the field that way -- confirm against the live payload.
texts = [record.get("text") or record.get("question") or "" for record in records]
texts = [text for text in texts if text.strip()]
if not texts:
    raise ValueError(f"No document text found in the payload from {url}")

# FAISS.from_documents needs Document objects, not raw JSON dicts
docs = [
    Document(page_content=text, metadata={"id": position})
    for position, text in enumerate(texts, start=1)
]

# Initialize the embedding model
embedding_model = BERTEmbeddings()

# Build the vector store. from_documents embeds the texts itself, so the
# earlier separate pass (which called a non-existent `encode` method) is gone.
vector_store = FAISS.from_documents(docs, embedding_model)

# Persist in LangChain's on-disk layout: a folder holding index.faiss / index.pkl.
# One relative path is used for both saving and loading.
index_dir = "faiss_index.index"
vector_store.save_local(index_dir)

# Reload the raw FAISS index from the file LangChain just wrote
loaded_vector_store = MyVectorStore.load_local(f"{index_dir}/index.faiss")

# -----------------------------
# 6. Create LangChain Retriever Tool
# -----------------------------

# Use the in-memory store directly; it is identical to what was just saved,
# and this avoids unpickling the docstore from disk.
retriever = vector_store.as_retriever()
478
 
479
  question_retriever_tool = create_retriever_tool(
480
  retriever=retriever,
 
1107
  ## Langgraph
1108
 
1109
  # Build graph function
1110
# Persist the vector store. save_local returns None, so the result must not be
# assigned back to `vector_store` (that would discard the store).
vector_store.save_local("faiss_index")
1111
+
1112
  provider = "huggingface"
1113
 
1114
  model_config = {