File size: 2,083 Bytes
ac47d9a
 
 
 
b5e1894
ac47d9a
 
b5e1894
ac47d9a
b5e1894
 
 
ac47d9a
 
b5e1894
 
ac47d9a
0ab7ce5
ac47d9a
 
 
 
 
 
b5e1894
ac47d9a
b5e1894
ac47d9a
b5e1894
 
 
ac47d9a
b5e1894
ac47d9a
 
 
b5e1894
ac47d9a
b5e1894
ac47d9a
b5e1894
ac47d9a
 
b5e1894
 
 
 
 
 
 
ac47d9a
0ab7ce5
ac47d9a
 
b5e1894
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os

from transformers import T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone

# All torch models in this script are placed on this device.
device = 'cpu'

# Initialize Pinecone instance.
# SECURITY: the API key used to be hard-coded here and is therefore exposed in
# version control. It is now taken from the PINECONE_API_KEY environment
# variable; the old literal remains only as a fallback so existing setups keep
# working. TODO: rotate that key and delete the fallback.
pc = Pinecone(
    api_key=os.environ.get('PINECONE_API_KEY', '89eeb534-da10-4068-92f7-12eddeabe1e5')
)

# Connect to the existing index. This script does NOT create it — the index
# must already exist (and be populated) before queries are issued.
index_name = 'abstractive-question-answering'
index = pc.Index(index_name)

def load_models(
    retriever_name="flax-sentence-embeddings/all_datasets_v3_mpnet-base",
    generator_name="t5-base",
    tokenizer_name=None,
):
    """Load the query encoder, the T5 answer generator and its tokenizer.

    Args:
        retriever_name: Sentence-Transformers checkpoint used to embed queries.
        generator_name: T5 checkpoint used to generate answers.
        tokenizer_name: Checkpoint to load the tokenizer from. Defaults to
            ``generator_name`` so tokenizer and generator always match.

    Returns:
        A ``(retriever, generator, tokenizer)`` tuple. The generator is moved
        to the module-level ``device``; the retriever is left on whatever
        device SentenceTransformer picks by default.
    """
    print("Loading models...")

    retriever = SentenceTransformer(retriever_name)
    # The tokenizer used to come from 't5-small' while the generator was
    # 't5-base'. The original T5 checkpoints share one SentencePiece vocab, so
    # that happened to work, but loading both from the same checkpoint keeps
    # them in sync if the generator model is ever changed.
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_name or generator_name)
    generator = T5ForConditionalGeneration.from_pretrained(generator_name).to(device)

    return retriever, generator, tokenizer

# Load every model once at import time; process_query() reuses these globals.
(retriever, generator, tokenizer) = load_models()

def _extract_context(response):
    """Return the 'Output' metadata strings of all matches in a Pinecone response.

    Returns an empty list when the response has no 'matches' list. Matches
    whose metadata is missing or lacks an 'Output' field are skipped rather
    than raising.
    """
    if 'matches' not in response or not isinstance(response['matches'], list):
        return []

    passages = []
    for match in response['matches']:
        metadata = match['metadata'] or {}
        passage = metadata.get('Output')
        if passage is not None:
            passages.append(passage)
    return passages


def process_query(query, top_k=1):
    """Answer *query* from the student-manual passages stored in Pinecone.

    Args:
        query: The user's question.
        top_k: Number of passages to retrieve from the index (default 1).

    Returns:
        - the retrieved context verbatim, if it spans more than 5 lines;
        - "The topic is not covered in the student manual." if the retrieved
          context is the literal string "none" (case/whitespace-insensitive);
        - otherwise an answer generated by the T5 model from the question
          and the retrieved context.
    """
    # Encode as a one-item batch and take the single row: Pinecone's `vector`
    # argument expects a flat list of floats, not a nested [[...]] list.
    xq = retriever.encode([query])[0].tolist()
    xc = index.query(vector=xq, top_k=top_k, include_metadata=True)

    # Print the response to check the structure
    print("Pinecone response:", xc)

    context_str = " ".join(_extract_context(xc))
    formatted_query = f"answer the question: {query} context: {context_str}"

    # Long passages (more than 5 lines) are returned as-is instead of being
    # summarised by the generator.
    if len(context_str.splitlines()) > 5:
        return context_str

    # The index stores the literal "none" for topics the manual doesn't cover.
    if context_str.strip().lower() == "none":
        return "The topic is not covered in the student manual."

    # Generate answer using T5 model
    inputs = tokenizer.encode(formatted_query, return_tensors="pt", max_length=512, truncation=True).to(device)
    ids = generator.generate(inputs, num_beams=2, min_length=10, max_length=60, repetition_penalty=1.2)
    return tokenizer.decode(ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)