Prathamesh1420 commited on
Commit
eb7c94b
·
verified ·
1 Parent(s): d944b9e

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +11 -7
chatbot.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  import pickle
3
  import torch
4
- import pickle
5
  import matplotlib.pyplot as plt
6
  from langchain_community.document_loaders import TextLoader
7
  from datasets import load_dataset
@@ -9,6 +8,8 @@ from sentence_transformers import SentenceTransformer, util
9
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
10
  from transformers import BertModel, BertTokenizer
11
  from langchain_core.prompts import PromptTemplate
 
 
12
 
13
  os.environ['HUGGINGFACEHUB_API_TOKEN'] = "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"
14
 
@@ -45,6 +46,8 @@ class Chatbot:
45
  self.gpt2_model_name = "gpt2"
46
  self.gpt2_model = GPT2LMHeadModel.from_pretrained(self.gpt2_model_name)
47
  self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained(self.gpt2_model_name)
 
 
48
 
49
  def load_embeddings(self):
50
  if os.path.exists("embeddings_cache.pkl"):
@@ -102,11 +105,12 @@ class Chatbot:
102
  plt.axis('off')
103
  plt.show()
104
 
105
- @staticmethod
106
- def cos_sim(a, b):
107
- a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
108
- b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
109
- return torch.mm(a_norm.T, b_norm) # Reshape a_norm to (768, 1)
 
110
 
111
  def generate_response(self, query):
112
  # Process the user query and generate a response
@@ -119,4 +123,4 @@ class Chatbot:
119
  self.display_text_and_images(results_text)
120
 
121
  # Return both chatbot response and recommended products
122
- return chatbot_response,results_text
 
1
  import os
2
  import pickle
3
  import torch
 
4
  import matplotlib.pyplot as plt
5
  from langchain_community.document_loaders import TextLoader
6
  from datasets import load_dataset
 
8
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
9
  from transformers import BertModel, BertTokenizer
10
  from langchain_core.prompts import PromptTemplate
11
+ from transformers import BlipProcessor, BlipForConditionalGeneration
12
+ from PIL import Image
13
 
14
  os.environ['HUGGINGFACEHUB_API_TOKEN'] = "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"
15
 
 
46
  self.gpt2_model_name = "gpt2"
47
  self.gpt2_model = GPT2LMHeadModel.from_pretrained(self.gpt2_model_name)
48
  self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained(self.gpt2_model_name)
49
+ self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
50
+ self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
51
 
52
  def load_embeddings(self):
53
  if os.path.exists("embeddings_cache.pkl"):
 
105
  plt.axis('off')
106
  plt.show()
107
 
108
+ def generate_image_caption(self, image_path):
109
+ raw_image = Image.open(image_path).convert('RGB')
110
+ inputs = self.blip_processor(raw_image, return_tensors="pt")
111
+ out = self.blip_model.generate(**inputs)
112
+ caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
113
+ return caption
114
 
115
  def generate_response(self, query):
116
  # Process the user query and generate a response
 
123
  self.display_text_and_images(results_text)
124
 
125
  # Return both chatbot response and recommended products
126
+ return chatbot_response, results_text