Update chatbot.py
chatbot.py CHANGED (+11 -7)
@@ -1,7 +1,6 @@
 import os
 import pickle
 import torch
-import pickle
 import matplotlib.pyplot as plt
 from langchain_community.document_loaders import TextLoader
 from datasets import load_dataset
@@ -9,6 +8,8 @@ from sentence_transformers import SentenceTransformer, util
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from transformers import BertModel, BertTokenizer
 from langchain_core.prompts import PromptTemplate
+from transformers import BlipProcessor, BlipForConditionalGeneration
+from PIL import Image
 
 os.environ['HUGGINGFACEHUB_API_TOKEN'] = "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"
 
@@ -45,6 +46,8 @@ class Chatbot:
         self.gpt2_model_name = "gpt2"
         self.gpt2_model = GPT2LMHeadModel.from_pretrained(self.gpt2_model_name)
         self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained(self.gpt2_model_name)
+        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+        self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
 
     def load_embeddings(self):
         if os.path.exists("embeddings_cache.pkl"):
@@ -102,11 +105,12 @@ class Chatbot:
         plt.axis('off')
         plt.show()
 
-
-
-
-
-
+    def generate_image_caption(self, image_path):
+        raw_image = Image.open(image_path).convert('RGB')
+        inputs = self.blip_processor(raw_image, return_tensors="pt")
+        out = self.blip_model.generate(**inputs)
+        caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
+        return caption
 
     def generate_response(self, query):
         # Process the user query and generate a response
@@ -119,4 +123,4 @@ class Chatbot:
         self.display_text_and_images(results_text)
 
         # Return both chatbot response and recommended products
-        return chatbot_response,results_text
+        return chatbot_response, results_text
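For context, the new captioning path can be exercised on its own. The sketch below mirrors the logic this commit adds to Chatbot (BLIP large captioning checkpoint plus PIL image loading); it is a minimal standalone sketch, not the full class, and the path "product.jpg" is only an illustrative placeholder.

from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

# Same BLIP checkpoint the commit wires into Chatbot.__init__
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

def generate_image_caption(image_path):
    # Mirrors Chatbot.generate_image_caption from this commit
    raw_image = Image.open(image_path).convert("RGB")           # BLIP expects RGB input
    inputs = processor(raw_image, return_tensors="pt")          # preprocess image to tensors
    out = model.generate(**inputs)                               # generate caption token ids
    return processor.decode(out[0], skip_special_tokens=True)   # decode ids to plain text

# Example call; "product.jpg" is a placeholder path, not part of the repo
print(generate_image_caption("product.jpg"))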