from transformers import T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone, ServerlessSpec
device = 'cpu'  # switch to 'cuda' if a GPU is available
# Initialize the Pinecone client (in practice, load the API key from an
# environment variable rather than hard-coding it)
pc = Pinecone(api_key='89eeb534-da10-4068-92f7-12eddeabe1e5')
# Connect to the index, creating it first if it does not exist
index_name = 'abstractive-question-answering'
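# Create the index if it is missing. A minimal sketch: the dimension (768)
# matches the mpnet-base retriever loaded below, but the ServerlessSpec
# cloud/region values are assumptions; adjust them to your Pinecone project.
if index_name not in pc.list_indexes().names():
    pc.create_index(
        name=index_name,
        dimension=768,  # embedding size of all_datasets_v3_mpnet-base
        metric='cosine',
        spec=ServerlessSpec(cloud='aws', region='us-east-1'),  # assumed region
    )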
index = pc.Index(index_name)
def load_models():
    print("Loading models...")
    retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base")
    # Use the tokenizer that matches the generator checkpoint (t5-base, not t5-small)
    tokenizer = T5Tokenizer.from_pretrained('t5-base')
    generator = T5ForConditionalGeneration.from_pretrained('t5-base').to(device)
    return retriever, generator, tokenizer
retriever, generator, tokenizer = load_models()
def process_query(query):
    # Embed the query and retrieve the best-matching passage from Pinecone.
    # encode() on a plain string returns a flat vector, which is the shape
    # index.query() expects (encoding [query] yields a nested list and fails).
    xq = retriever.encode(query).tolist()
    xc = index.query(vector=xq, top_k=1, include_metadata=True)
    # Print the response to check the structure
    print("Pinecone response:", xc)
    # Check that 'matches' exists and is a list before reading metadata
    if 'matches' in xc and isinstance(xc['matches'], list):
        context = [m['metadata']['Output'] for m in xc['matches']]
        context_str = " ".join(context)
    else:
        # 'matches' is missing or malformed; fall back to an empty context
        context_str = ""
    formatted_query = f"answer the question: {query} context: {context_str}"
    # Short-circuit: if the retrieved context is already a long multi-line
    # answer, return it verbatim instead of summarizing it with T5
    output_text = context_str
    if len(output_text.splitlines()) > 5:
        return output_text
    if output_text.lower() == "none":
        return "The topic is not covered in the student manual."
    # Otherwise, generate an abstractive answer with the T5 model
    inputs = tokenizer.encode(formatted_query, return_tensors="pt", max_length=512, truncation=True).to(device)
    ids = generator.generate(inputs, num_beams=2, min_length=10, max_length=60, repetition_penalty=1.2)
    answer = tokenizer.decode(ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return answer
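# Minimal usage sketch; the question is a made-up example, assuming the index
# stores student-manual passages under an 'Output' metadata field.
if __name__ == "__main__":
    sample_question = "What is the attendance policy?"  # hypothetical query
    print(process_query(sample_question))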