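# PDF question-answering demo: Gemini (via LangChain) answers questions about an
# uploaded PDF, and a Mistral-NeMo-Minitron model proposes a follow-up question.
# Requires the GOOGLE_API_KEY environment variable to be set.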
import os
import gradio as gr
from langchain_core.prompts import PromptTemplate
from langchain_community.document_loaders import PyPDFLoader
from langchain_google_genai import ChatGoogleGenerativeAI
import google.generativeai as genai
from langchain.chains.question_answering import load_qa_chain
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Configure Gemini API
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
# Load the Mistral-NeMo-Minitron tokenizer and choose device/dtype
model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
mistral_tokenizer = AutoTokenizer.from_pretrained(model_path)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16
# Load the Mistral model, falling back gracefully if loading fails
try:
    mistral_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map=device
    )
    print(f"Mistral model loaded successfully on {device}")
except Exception as e:
    print(f"Error loading Mistral model: {str(e)}")
    mistral_model = None
def initialize(file_path, question):
    try:
        # Check that the Gemini API key is set
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            return "Error: GOOGLE_API_KEY environment variable is not set."

        model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3)
        prompt_template = """Answer the question as precisely as possible using the provided context. If the answer is
        not contained in the context, say "answer not available in context" \n\n
        Context: \n {context}\n
        Question: \n {question} \n
        Answer:
        """
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

        if os.path.exists(file_path):
            # Load the PDF and split it into pages
            pdf_loader = PyPDFLoader(file_path)
            pages = pdf_loader.load_and_split()
            if not pages:
                return "Error: The PDF file appears to be empty or could not be processed."

            context = "\n".join(str(page.page_content) for page in pages[:30])

            # Generate the Gemini answer with a "stuff" question-answering chain
            stuff_chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
            stuff_answer = stuff_chain(
                {"input_documents": pages, "question": question, "context": context},
                return_only_outputs=True
            )
            gemini_answer = stuff_answer['output_text']

            # Use the Mistral model to generate a follow-up question
            if mistral_model is not None:
                mistral_prompt = f"Based on this answer: {gemini_answer}\nGenerate a follow-up question:"
                mistral_inputs = mistral_tokenizer.encode(mistral_prompt, return_tensors='pt').to(device)
                with torch.no_grad():
                    mistral_outputs = mistral_model.generate(
                        mistral_inputs,
                        max_length=200,    # maximum total length (prompt + generated tokens)
                        min_length=20,     # minimum total length
                        do_sample=True,    # enable sampling
                        top_p=0.95,        # nucleus (top-p) sampling
                        temperature=0.7    # temperature for creativity
                    )
                mistral_output = mistral_tokenizer.decode(mistral_outputs[0], skip_special_tokens=True)
                # Keep only the text after the prompt marker as the follow-up question
                if "Generate a follow-up question:" in mistral_output:
                    mistral_output = mistral_output.split("Generate a follow-up question:")[1].strip()
                combined_output = f"Gemini Answer: {gemini_answer}\n\nMistral Follow-up: {mistral_output}"
            else:
                combined_output = f"Gemini Answer: {gemini_answer}\n\n(Mistral model unavailable)"
            return combined_output
        else:
            return f"Error: File not found at path '{file_path}'. Please ensure the PDF file is valid."
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return f"An error occurred: {str(e)}\n\nDetails: {error_details}"
# Gradio callback with input validation and error handling
def pdf_qa(file, question):
    if file is None:
        return "Please upload a PDF file first."
    if not question or question.strip() == "":
        return "Please enter a question about the document."
    try:
        return initialize(file.name, question)
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return f"Error processing request: {str(e)}\n\nDetails: {error_details}"
# Create the Gradio interface
demo = gr.Interface(
    fn=pdf_qa,
    inputs=[
        gr.File(label="Upload PDF File", file_types=[".pdf"]),
        gr.Textbox(label="Ask about the document", placeholder="What is the main topic of this document?")
    ],
    outputs=gr.Textbox(label="Answer - Combined Gemini and Mistral"),
    title="RAG Knowledge Retrieval using Gemini API and Mistral Model",
    description="Upload a PDF file and ask questions about the content. The system uses Gemini for answering and Mistral for generating follow-up questions.",
    examples=[
        [None, "What are the main findings in this document?"],
        [None, "Summarize the key points discussed in this paper."]
    ],
    allow_flagging="never"
)
# Launch the app (public share link, debug mode)
if __name__ == "__main__":
    demo.launch(share=True, debug=True)