import os
import gradio as gr
from langchain_core.prompts import PromptTemplate
from langchain_community.document_loaders import PyPDFLoader
from langchain_google_genai import ChatGoogleGenerativeAI
import google.generativeai as genai
from langchain.chains.question_answering import load_qa_chain
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configure Gemini API
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

# Load Mistral tokenizer and model
model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16

# Load with error handling so the app can still serve Gemini answers if this fails
try:
    mistral_tokenizer = AutoTokenizer.from_pretrained(model_path)
    mistral_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map=device
    )
    print(f"Mistral model loaded successfully on {device}")
except Exception as e:
    print(f"Error loading Mistral model: {str(e)}")
    mistral_tokenizer = None
    mistral_model = None

def initialize(file_path, question):
    try:
        # Check if API key is set
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            return "Error: GOOGLE_API_KEY environment variable is not set."
            
        model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3)
        prompt_template = """Answer the question as precise as possible using the provided context. If the answer is
                              not contained in the context, say "answer not available in context" \n\n
                              Context: \n {context}?\n
                              Question: \n {question} \n
                              Answer:
                            """
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        
        if os.path.exists(file_path):
            # Load and process PDF
            pdf_loader = PyPDFLoader(file_path)
            pages = pdf_loader.load_and_split()
            
            if not pages:
                return "Error: The PDF file appears to be empty or could not be processed."
                
            context = "\n".join(str(page.page_content) for page in pages[:30])
            
            # Generate Gemini answer
            stuff_chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
            stuff_answer = stuff_chain(
                {"input_documents": pages, "question": question, "context": context}, 
                return_only_outputs=True
            )
            gemini_answer = stuff_answer['output_text']
            
            # Use Mistral model for additional text generation
            if mistral_model is not None:
                mistral_prompt = f"Based on this answer: {gemini_answer}\nGenerate a follow-up question:"
                mistral_inputs = mistral_tokenizer.encode(mistral_prompt, return_tensors='pt').to(device)
                
                with torch.no_grad():
                    mistral_outputs = mistral_model.generate(
                        mistral_inputs,
                        max_new_tokens=100,   # Cap generated tokens regardless of prompt length
                        min_new_tokens=10,    # Avoid trivially short completions
                        do_sample=True,       # Sample instead of greedy decoding
                        top_p=0.95,           # Nucleus (top-p) sampling
                        temperature=0.7,      # Moderate randomness
                        pad_token_id=mistral_tokenizer.eos_token_id
                    )
                mistral_output = mistral_tokenizer.decode(mistral_outputs[0], skip_special_tokens=True)
                # Clean up the output to get just the follow-up question
                if "Generate a follow-up question:" in mistral_output:
                    mistral_output = mistral_output.split("Generate a follow-up question:")[1].strip()
                
                combined_output = f"Gemini Answer: {gemini_answer}\n\nMistral Follow-up: {mistral_output}"
            else:
                combined_output = f"Gemini Answer: {gemini_answer}\n\n(Mistral model unavailable)"
                
            return combined_output
        else:
            return f"Error: File not found at path '{file_path}'. Please ensure the PDF file is valid."
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return f"An error occurred: {str(e)}\n\nDetails: {error_details}"

# Gradio callback: validate inputs before running the QA pipeline
def pdf_qa(file, question):
    if file is None:
        return "Please upload a PDF file first."
    
    if not question or question.strip() == "":
        return "Please enter a question about the document."
    
    try:
        # gr.File may return a filepath string or a tempfile-like object depending on the Gradio version
        file_path = file if isinstance(file, str) else file.name
        return initialize(file_path, question)
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return f"Error processing request: {str(e)}\n\nDetails: {error_details}"

# Create Gradio Interface with additional options
demo = gr.Interface(
    fn=pdf_qa,
    inputs=[
        gr.File(label="Upload PDF File", file_types=[".pdf"]),
        gr.Textbox(label="Ask about the document", placeholder="What is the main topic of this document?")
    ],
    outputs=gr.Textbox(label="Answer - Combined Gemini and Mistral"),
    title="RAG Knowledge Retrieval using Gemini API and Mistral Model",
    description="Upload a PDF file and ask questions about the content. The system uses Gemini for answering and Mistral for generating follow-up questions.",
    examples=[
        [None, "What are the main findings in this document?"],
        [None, "Summarize the key points discussed in this paper."]
    ],
    allow_flagging="never"
)

# Launch the app with additional parameters
if __name__ == "__main__":
    demo.launch(share=True, debug=True)