File size: 2,318 Bytes
5e2e02c
1a75c78
5e2e02c
9f29f51
 
5e2e02c
9f29f51
 
1a75c78
 
 
 
 
 
9f29f51
5e2e02c
9f29f51
 
5e2e02c
9f29f51
 
5e2e02c
9f29f51
 
267005e
9f29f51
 
 
 
5e2e02c
 
 
 
267005e
9f29f51
 
5e2e02c
267005e
5e2e02c
 
267005e
5e2e02c
 
 
9f29f51
 
 
5e2e02c
267005e
5e2e02c
 
 
9f29f51
 
5e2e02c
9f29f51
 
5e2e02c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

# Cached loader for the RAG components. st.cache_resource ensures the
# (large) checkpoint and retrieval index are downloaded/initialized only
# once per server process instead of on every form submission.
@st.cache_resource(show_spinner=False)
def _load_rag_components():
    """Load and cache the RAG tokenizer and retriever-backed model.

    Returns:
        tuple: (tokenizer, model) where the retriever is already attached
        to the model so that model.generate() performs document retrieval
        internally.
    """
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-nq",
        index_name="exact",
        use_dummy_dataset=True,
        trust_remote_code=True,  # Allows loading the required dataset script
    )
    # Attaching the retriever is the documented way to use RAG: generate()
    # then retrieves context documents itself. (RagRetriever has no public
    # retrieve(input_ids) API — calling it with raw token ids would raise.)
    model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-token-nq", retriever=retriever
    )
    return tokenizer, model


# Function to generate response using RAG (Retrieval-Augmented Generation)
def generate_response_with_rag(txt):
    """Summarize *txt* using Facebook's RAG model.

    Args:
        txt: The user-supplied text to summarize.

    Returns:
        The generated summary string, or None if an error occurred
        (the error is reported to the UI via st.error).
    """
    try:
        tokenizer, model = _load_rag_components()

        # Tokenize the input text
        inputs = tokenizer(txt, return_tensors="pt")

        # Generate the output; retrieval happens inside generate() because
        # the retriever is attached to the model.
        generated = model.generate(input_ids=inputs["input_ids"])

        # Decode the generated token ids back to text.
        summary = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

        return summary
    except Exception as e:
        # Boundary handler: surface the failure in the UI rather than
        # crashing the Streamlit script.
        st.error(f"An error occurred during summarization: {str(e)}")
        return None

# ---- Page configuration -------------------------------------------------
st.set_page_config(page_title='πŸ¦œπŸ”— RAG Text Summarization App')
st.title('πŸ¦œπŸ”— RAG Text Summarization App')

# ---- User input ---------------------------------------------------------
# Free-form text box; the entered text is what gets summarized.
user_text = st.text_area('Enter your text', '', height=200)

# ---- Submission form ----------------------------------------------------
# The summary is produced only when the form is submitted with non-empty
# input; clear_on_submit resets the form afterwards.
summary_text = None
with st.form('summarize_form', clear_on_submit=True):
    was_submitted = st.form_submit_button('Submit')
    if was_submitted and user_text:
        with st.spinner('Summarizing with RAG...'):
            summary_text = generate_response_with_rag(user_text)

# ---- Result -------------------------------------------------------------
# Shown only when summarization succeeded (the generator returns None on
# error after reporting it to the UI).
if summary_text:
    st.info(summary_text)

# ---- About --------------------------------------------------------------
st.subheader("Hugging Face RAG Summarization")
st.write("""
This app uses Hugging Face's RAG model (Retrieval-Augmented Generation) to generate summaries with relevant external context.
RAG retrieves information from a set of documents and combines that with a generative model to produce more accurate summaries.
""")