import os
import warnings

# TF_CPP_MIN_LOG_LEVEL must be set before TensorFlow is imported, otherwise
# its C++ startup logging has already been emitted by import time
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import re
import traceback

import tensorflow as tf
from flask import Flask, render_template, request
from langchain_core.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import T5Tokenizer, T5ForConditionalGeneration, pipeline

from src.helper import download_hugging_face_embeddings, load_faiss_with_metadata

app = Flask(__name__)

# Suppress remaining TensorFlow and deprecation warnings
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)


# Initialize components
embeddings = download_hugging_face_embeddings()
vectorstore = load_faiss_with_metadata(embeddings)
retriever = vectorstore.as_retriever(search_kwargs={'k': 5})
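# search_kwargs={'k': 5} asks FAISS for the five nearest chunks per query;
# a larger k adds context at the cost of a longer prompt and slower generation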

# Load the fine-tuned model and tokenizer once and wrap them in a transformers
# pipeline, so HuggingFacePipeline reuses these weights instead of loading the
# checkpoint a second time
model_path = "flan-t5-finetuned-new"
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)

llm = HuggingFacePipeline(
    pipeline=pipeline(
        task="text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=50,
        min_length=15,
        top_k=25,
        temperature=0.4,
        top_p=0.85,
        repetition_penalty=1.7,
        do_sample=True,
        num_beams=3,  # reduced beams for speed
    )
)
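
# These generation settings trade quality for latency: sampling (do_sample,
# top_k, top_p, temperature) adds variety, repetition_penalty discourages
# loops, and num_beams=3 keeps beam search cheap enough for interactive use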

PROMPT = PromptTemplate(
    template=(
        "You are a helpful and expert healthcare assistant. Analyze the context and provide a detailed, accurate, and structured response to the question.\n\n"
        "If you don't know the answer, just say \"I don't know.\"\n\n"
        "Context: {context}\n\n"
        "Question: {question}\n"
        "Answer with explanations where applicable:"
    ),
    input_variables=["context", "question"]
)
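
# input_variables must name exactly the placeholders the chain fills:
# {context} receives the retrieved chunks, {question} the incoming query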


qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
    chain_type_kwargs={"prompt": PROMPT}
)
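
# With chain_type="stuff", the chain embeds the query, fetches the k nearest
# chunks from FAISS, concatenates ("stuffs") them into {context}, and runs
# the single filled prompt through the flan-T5 pipeline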

# Post-process the model output: collapse immediate word/phrase repetitions,
# normalise sentence casing, and keep at most the first three sentences
def post_process_response(response):
    # Collapse a word or short phrase (up to three words) that immediately
    # repeats itself, e.g. "disease disease" -> "disease"
    response = re.sub(r'\b(\w+(?:\s+\w+){0,2})(\s+\1\b)+', r'\1', response)
    response = re.sub(r'(state of being )+', 'state of being ', response, flags=re.IGNORECASE)
    sentences = [sentence.strip().capitalize() for sentence in response.split('.') if sentence.strip()]
    return ". ".join(sentences[:3]) + "." if sentences else "I don't know."

# Main function for getting output
def get_output(input_text):
    # The RetrievalQA chain performs retrieval itself and stuffs the retrieved
    # documents into PROMPT's {context}, so it only needs the raw question;
    # retrieving here as well would run the search twice on a polluted query
    result = qa.invoke({"query": input_text})
    return post_process_response(result['result'])

@app.route("/")
def index():
    return render_template('chat.html')

@app.route("/get", methods=['GET', 'POST'])
def chat():
    msg = request.form["msg"]
    print(f"User input: {msg}")
    input_text = msg
    try:
        if "and" in input_text:
            sub_queries = input_text.lower().split("and")
            responses = [get_output(sub_query.strip()) for sub_query in sub_queries]
            return str("\n".join(responses)) 
        else: 
            return str(get_output(input_text)) 
        

    except Exception as e:
        print(traceback.format_exc())
        return f"Error processing the request: {e}"

if __name__ == '__main__':
    app.run(debug=True)