|
import logging
import os
import tempfile
from pathlib import Path

import google.generativeai as genai
import streamlit as st
from dotenv import load_dotenv
from llama_index.core import (
    PromptTemplate,
    QueryBundle,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    get_response_synthesizer,
    load_index_from_storage,
)
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
from llama_index.core.retrievers import QueryFusionRetriever, VectorIndexRetriever
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.llms.gemini import Gemini
|
|
|
# Load variables from a local .env file before any of them are read below.
load_dotenv()

logging.basicConfig(level=logging.INFO)

# Read the key once so that every Google client below receives the same value.
# NOTE(review): os.getenv returns None when GOOGLE_API_KEY is unset; the
# clients are then built without an explicit key -- confirm whether the app
# should fail fast in that case.
_GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

genai.configure(api_key=_GOOGLE_API_KEY)

model_gemini_pro_vision = "gemini-pro-vision"
model_gemini_pro = "gemini-pro"

# The LlamaIndex Gemini wrapper selects the model through ``model=`` and
# expects the fully qualified "models/<name>" form. The previous ``models=``
# keyword was a typo, so the configured name never reached the wrapper.
Settings.llm = Gemini(model=f"models/{model_gemini_pro}", api_key=_GOOGLE_API_KEY)
Settings.embed_model = GeminiEmbedding(
    model_name="models/embedding-001",
    api_key=_GOOGLE_API_KEY,
)
|
|
|
|
|
|
|
def create_semantic_splitter_node_parser():
    """Build the semantic node parser used to chunk documents.

    Returns:
        SemanticSplitterNodeParser: A parser configured with a buffer size of
        1 and a breakpoint percentile threshold of 95, embedding text with the
        globally configured ``Settings.embed_model``.
    """
    splitter_options = {
        "buffer_size": 1,
        "breakpoint_percentile_threshold": 95,
        "embed_model": Settings.embed_model,
    }
    return SemanticSplitterNodeParser(**splitter_options)
|
|
|
|
|
def load_and_index_pdf(pdf_path):
    """Load a PDF and build a vector index over its semantic chunks.

    Args:
        pdf_path (str): The path to the PDF file.

    Returns:
        llama_index.core.VectorStoreIndex | None: The vector index, or None
        when the reader yields no documents or any step raises.
    """
    try:
        logging.info(f"Loading PDF document from: {pdf_path}")
        documents = SimpleDirectoryReader(input_files=[pdf_path]).load_data()
        if not documents:
            logging.warning("No documents found in the PDF")
            return None

        logging.info("Creating vector store index")
        # The splitter has to be supplied through ``transformations``; a
        # ``node_parser=`` keyword is not a from_documents parameter, so the
        # semantic splitter was previously never applied.
        return VectorStoreIndex.from_documents(
            documents,
            transformations=[create_semantic_splitter_node_parser()],
        )
    except Exception as e:
        # Best-effort contract: callers check for None, so log (with the
        # traceback) instead of propagating.
        logging.exception(f"Error loading and indexing PDF: {e}")
        return None
|
|
|
|
|
def create_rag_pipeline(index):
    """Create the RAG query pipeline used for translation.

    The pipeline retrieves the 5 most similar nodes, reranks them down to 3
    with a sentence-transformer cross-encoder, synthesizes the answer in
    "refine" mode, and rewrites each incoming query with HyDE first.

    Args:
        index (llama_index.core.VectorStoreIndex): The vector index.

    Returns:
        llama_index.core.query_engine.TransformQueryEngine: A query engine
        (exposing ``.query``) that wraps the underlying RetrieverQueryEngine.
    """
    logging.info("Initializing RAG Pipeline components")

    retriever = index.as_retriever(similarity_top_k=5)
    reranker = SentenceTransformerRerank(top_n=3, model="BAAI/bge-reranker-base")
    response_synthesizer = get_response_synthesizer(response_mode="refine")

    base_engine = RetrieverQueryEngine(
        retriever=retriever,
        response_synthesizer=response_synthesizer,
        node_postprocessors=[reranker],
    )

    # RetrieverQueryEngine has no ``query_transform`` parameter (passing one
    # raises TypeError); a query transform is applied by wrapping the engine.
    hyde_query_transform = HyDEQueryTransform(llm=Settings.llm)
    query_engine = TransformQueryEngine(
        base_engine,
        query_transform=hyde_query_transform,
    )

    logging.info("RAG Pipeline is configured.")
    return query_engine
|
|
|
def translate_text(french_text, query_engine):
    """Translate French text to Yipunu through the RAG query engine.

    The French text is inserted into a fixed instruction prompt and that
    prompt is sent to ``query_engine.query``.

    Args:
        french_text (str): The French text to translate.
        query_engine: Any object exposing ``query(str)`` whose result has a
            ``response`` attribute (e.g. a LlamaIndex query engine).

    Returns:
        str: The Yipunu translation, or a message starting with
        "Error during translation:" if the query raised.
    """
    try:
        logging.info(f"Initiating translation of: {french_text}")

        template = (
            "Tu es un excellent traducteur du français vers le yipunu. Tu traduis le texte sans donner d'explication. "
            "Texte: {french_text} "
            "Traduction:"
        )
        # Fill the placeholder explicitly. QueryBundle has no ``custom_prompt``
        # field, so the previous call raised TypeError and the template never
        # reached the model. str.format only parses the template, so braces
        # inside ``french_text`` are inserted verbatim.
        prompt = template.format(french_text=french_text)

        response = query_engine.query(prompt)
        translation = response.response
        logging.info(f"Translation Result: {translation}")
        return translation
    except Exception as e:
        # Keep the original contract of returning an error string to the UI,
        # but record the traceback for debugging.
        logging.exception(f"Error during translation: {e}")
        return f"Error during translation: {str(e)}"
|
|
|
|
|
|
|
# Session-state slot holding the query engines already built in this session.
_ENGINES_STATE_KEY = "query_engines"


def _build_engine(pdf_path):
    """Index the PDF at ``pdf_path`` and return its query engine, or None on failure."""
    index = load_and_index_pdf(pdf_path)
    if index is None:
        return None
    return create_rag_pipeline(index)


def _engine_from_upload(uploaded_file):
    """Write the uploaded PDF to a unique temp file, build its engine, then delete the file."""
    # getvalue() returns the whole buffer regardless of the current stream
    # position, unlike read(), which depends on where a previous read stopped.
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
        tmp.write(uploaded_file.getvalue())
        temp_path = tmp.name
    try:
        return _build_engine(temp_path)
    finally:
        # Always clean up, even if building the pipeline raised.
        os.remove(temp_path)


def _cached_engine(cache_key, build):
    """Return the engine stored under ``cache_key``, calling ``build()`` on first use."""
    if _ENGINES_STATE_KEY not in st.session_state:
        st.session_state[_ENGINES_STATE_KEY] = {}
    engines = st.session_state[_ENGINES_STATE_KEY]
    if cache_key not in engines:
        engine = build()
        if engine is None:
            # Failures are not cached so that a later rerun can retry.
            return None
        engines[cache_key] = engine
    return engines[cache_key]


def _render_translation_ui(query_engine):
    """Show the text box and Translate button, and display the translation on click."""
    french_text = st.text_area("Enter French Text:", "Ni vosi yipunu")
    if st.button("Translate"):
        translation = translate_text(french_text, query_engine)
        st.success(f"Yipunu Translation: {translation}")


def main():
    """Main function for the Streamlit app.

    Uses the bundled grammar PDF when it exists; otherwise asks the user to
    upload one. Streamlit re-executes the script on every widget interaction,
    so the query engine is kept in session state and built only once per PDF
    instead of being re-indexed on each button click.
    """
    st.title("French to Yipunu Translation App")

    default_pdf_path = Path("data/parlons_yipunu.pdf")

    if default_pdf_path.exists():
        pdf_source = str(default_pdf_path)
        query_engine = _cached_engine(
            ("default", pdf_source),
            lambda: _build_engine(pdf_source),
        )
    else:
        uploaded_file = st.file_uploader(
            "Upload a PDF file containing the Punu grammar:", type="pdf"
        )
        if uploaded_file is None:
            st.info("Please upload a pdf containing the punu grammar.")
            return
        # NOTE(review): name + size is used as the identity of an upload;
        # confirm this is distinctive enough for the expected inputs.
        query_engine = _cached_engine(
            ("upload", uploaded_file.name, uploaded_file.size),
            lambda: _engine_from_upload(uploaded_file),
        )

    if query_engine is None:
        # Previously a failed load left the page empty with no explanation.
        st.error("The PDF could not be loaded and indexed. See the application log for details.")
        return

    _render_translation_ui(query_engine)


if __name__ == "__main__":
    main()