# Header captured from the Hugging Face file viewer (not part of the program):
# author: rahideer · commit b77a775 ("Create app.py", verified) · 1.8 kB
import streamlit as st
from datasets import load_dataset
from transformers import (
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenForGeneration,
    RagTokenizer,
    XLMRobertaForSequenceClassification,
    XLMRobertaTokenizer,
)
# NOTE(review): the XNLI validation split that used to be loaded here was dropped.
# Nothing in the app read it, it was re-downloaded on every Streamlit rerun, and
# the call named no language config (XNLI defines one per language), which
# `datasets` is expected to reject. Re-add it with an explicit config if needed.

# Hub checkpoint shared by the tokenizer, retriever and generator. All three must
# come from the same RAG checkpoint: the question encoder only understands ids
# produced by its own tokenizer, and the answer must be decoded with the
# generator's tokenizer (RagTokenizer wraps both).
RAG_CHECKPOINT = "facebook/rag-token-nq"


@st.cache_resource
def _load_rag_components():
    """Load the RAG tokenizer, retriever and generator once per server process.

    Streamlit re-executes this script on every widget interaction. Caching the
    components keeps the model weights and retrieval index from being reloaded
    on each rerun.

    Returns:
        tuple: ``(tokenizer, retriever, model)`` built from ``RAG_CHECKPOINT``.
    """
    rag_tokenizer = RagTokenizer.from_pretrained(RAG_CHECKPOINT)
    # index_name="compressed" selects a prebuilt canonical index. `passages_path`
    # is only consulted when index_name="custom", so the placeholder path that was
    # passed here had no effect and has been removed.
    # NOTE(review): the canonical index is a very large download; for a quick
    # demo consider use_dummy_dataset=True or a custom index.
    rag_retriever = RagRetriever.from_pretrained(RAG_CHECKPOINT, index_name="compressed")
    # A *rag-token* checkpoint pairs with RagTokenForGeneration. Attaching the
    # retriever lets generate() perform retrieval itself.
    rag_model = RagTokenForGeneration.from_pretrained(RAG_CHECKPOINT, retriever=rag_retriever)
    return rag_tokenizer, rag_retriever, rag_model


tokenizer, retriever, model = _load_rag_components()

# Define Streamlit app
st.title('Multilingual RAG Translator/Answer Bot')
# NOTE(review): facebook/rag-token-nq was trained on English Natural Questions
# over English Wikipedia, so answers to Urdu/Hindi/French queries are likely to
# be poor. A multilingual retriever/generator is needed to live up to this text.
st.markdown("This app uses a multilingual RAG model to answer your questions in the language of the query. Ask questions in languages like Urdu, Hindi, or French!")

# User input for query; whitespace-only input is treated as "no question yet".
user_query = st.text_input("Ask a question in Urdu, Hindi, or French:")
query = (user_query or "").strip()

if query:
    # Tokenize with the question-encoder tokenizer of the same RAG checkpoint.
    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # With a retriever attached, generate() encodes the question, retrieves
    # model.config.n_docs passages and conditions the generator on them.
    generated_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask)
    # RagTokenizer.decode delegates to the generator tokenizer.
    answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Display the answer
    st.write(f"Answer: {answer}")

    # Repeat the retrieval explicitly so the passages can be shown. For the same
    # question this should return the passages generate() conditioned on.
    # retrieve() expects the question encoder's output as a NumPy array and
    # returns (doc_embeds, doc_ids, doc_dicts), with one dict per question.
    question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]
    _, _, doc_dicts = retriever.retrieve(
        question_hidden_states.detach().cpu().numpy(), n_docs=model.config.n_docs
    )

    # Display the most relevant documents
    st.subheader("Relevant Documents:")
    for passage in doc_dicts[0]["text"]:
        st.write(passage[:300] + '...')  # Display first 300 characters of each doc