# ajnjdnjpek / app.py — university-FAQ RAG chatbot (Flask + LangChain + Gemini)
# Snapshot of a Hugging Face Space by Adieee5 (commit ca8feaa, 2.39 kB).
from dotenv import load_dotenv
load_dotenv()
import os
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import PromptTemplate
from langchain.chains import RetrievalQA
# === Step 1: Load API Key ===
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise ValueError("GOOGLE_API_KEY not found in environment variables.")

# === Step 2: Initialize LLM (Gemini) ===
# convert_system_message_to_human folds any system prompt into the human
# turn, which the Gemini chat API expects.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-lite",
    google_api_key=GOOGLE_API_KEY,
    convert_system_message_to_human=True,
)
# === Step 3: Load Chroma Vector Store ===
# NOTE(review): the embedding model must match whatever model indexed
# ./chroma_store, or similarity search degrades silently — confirm against
# the ingestion script.
embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")

vectordb = Chroma(
    collection_name="pdf_search_chroma",
    persist_directory="./chroma_store",
    embedding_function=embedding_model,
)

# Fetch the 6 most similar chunks for each incoming question.
retriever = vectordb.as_retriever(search_kwargs={"k": 6})
# === Step 4: Custom Prompt Template ===
# {context} receives the retrieved chunks and {question} the user query;
# both placeholders are filled in by the "stuff" chain built below.
prompt_template = PromptTemplate.from_template("""
You are an intelligent assistant for students asking about their university.
If answer is not defined or not clearly understood, ask for clarification.
Answer clearly and helpfully based on the retrieved context. Do not make up information or suggestions.
Context:
{context}
Question:
{question}
Answer:
""")
# === Step 5: Create Retrieval-QA Chain ===
# chain_type="stuff" concatenates all k retrieved chunks into one prompt;
# adequate for k=6, but would overflow the context window for large corpora.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs={"prompt": prompt_template},
)
# === Step 6: Flask Setup ===
# Static assets are served from ./static and Jinja templates from ./templates.
app = Flask(__name__, template_folder="templates", static_folder="static")
# === Step 7: Serve Frontend ===
@app.route("/")
def index():
    """Serve the chat UI from templates/index.html."""
    return render_template("index.html")
@app.route('/get', methods=['POST'])
def chat():
    """Answer a student question through the retrieval-QA chain.

    Expects a JSON body ``{"message": "<question>"}``. Returns
    ``{"response": "<answer>"}`` on success, ``{"error": ...}`` with
    HTTP 400 for a missing/blank message or 500 if the chain fails.
    """
    # silent=True returns None for a missing/malformed JSON body instead of
    # raising, so bad requests reach the clean 400 path below rather than
    # crashing with a 500 (request.json raises on non-JSON content types).
    data = request.get_json(silent=True) or {}
    # `or ''` also guards against an explicit {"message": null}, which would
    # otherwise raise AttributeError on .strip().
    query = (data.get('message') or '').strip()
    if not query:
        return jsonify({"error": "No message provided."}), 400
    try:
        # NOTE(review): Chain.run is deprecated in newer LangChain releases;
        # prefer qa_chain.invoke({"query": query})["result"] on upgrade.
        response = qa_chain.run(query)
        return jsonify({"response": response})
    except Exception as e:
        # Top-level boundary: report the failure to the client as JSON.
        return jsonify({"error": str(e)}), 500
# === Step 8: Run the App ===
if __name__ == '__main__':
    # debug=False keeps the Werkzeug debugger disabled outside development.
    app.run(debug=False)