import html
import os
import re
import tempfile

import fitz  # PyMuPDF for reading PDFs
import gradio as gr
import numpy as np
import pandas as pd
import umap
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.plotting import figure, output_file, save
from bokeh.resources import CDN
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

# Initialize the model globally so it is loaded once per process, not per request.
model = SentenceTransformer('all-MiniLM-L6-v2')

# A sentence boundary is terminal punctuation followed by whitespace. This is
# still heuristic (abbreviations like "e.g. " will split), but unlike the old
# ". " split it also handles "!" / "?" and never yields empty sentences.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")

# UMAP requires n_neighbors >= 2, hence at least 3 points (2 sentences + query).
_MIN_SENTENCES = 2

_PLOT_TITLE = "UMAP Projection of Sentences"


def process_pdf(pdf_path):
    """Return the text of every page of the PDF at *pdf_path*, joined by spaces."""
    # The context manager closes the document even if text extraction raises,
    # so repeated requests do not leak file handles.
    with fitz.open(pdf_path) as doc:
        return " ".join(page.get_text() for page in doc)


def create_embeddings(text):
    """Split *text* into sentences and embed each one with the global model.

    Returns:
        ``(embeddings, sentences)``: ``embeddings`` is the array produced by
        ``model.encode`` (one row per sentence) and ``sentences`` is the list
        of non-empty sentence strings in document order. For text with no
        sentences, ``embeddings`` is an empty ``(0, dim)`` array.
    """
    # PDF extraction inserts hard line breaks mid-sentence; collapse every
    # whitespace run so it does not leak into sentences or the hover text.
    normalized = " ".join(text.split())
    sentences = [s for s in _SENTENCE_BOUNDARY.split(normalized) if s]
    if not sentences:
        return np.empty((0, model.get_sentence_embedding_dimension())), sentences
    embeddings = model.encode(sentences)
    return embeddings, sentences


def _resolve_path(pdf_file):
    """Return a filesystem path for a Gradio upload (a path string or a tempfile-like object)."""
    # Newer Gradio passes gr.File values as path strings; older versions pass
    # a tempfile wrapper exposing ``.name``. Accept both.
    if isinstance(pdf_file, (str, os.PathLike)):
        return os.fspath(pdf_file)
    return pdf_file.name


def generate_plot(query, pdf_file, *, top_k=5, n_neighbors=15):
    """Render a UMAP scatter plot of the PDF's sentences to a temporary HTML file.

    The ``top_k`` sentences most cosine-similar to ``query`` are drawn in red
    and all others in blue. The query is included in the UMAP fit so the
    layout accounts for it, but it is not drawn.

    Args:
        query: Free-text query to compare the sentences against.
        pdf_file: The uploaded PDF, as a path string or an object with ``.name``.
        top_k: Number of most-similar sentences to highlight.
        n_neighbors: Requested UMAP neighbourhood size; clamped to what the
            number of points allows.

    Returns:
        Path of a standalone Bokeh HTML file. The caller owns the file and is
        responsible for deleting it.

    Raises:
        gr.Error: If no PDF or query is given, or the PDF has too little text.
    """
    if pdf_file is None:
        raise gr.Error("Please upload a PDF file.")
    if not query or not query.strip():
        raise gr.Error("Please enter a query.")

    # Generate embeddings for the query
    query_embedding = model.encode([query])[0]

    # Process the PDF and create embeddings
    text = process_pdf(_resolve_path(pdf_file))
    embeddings, sentences = create_embeddings(text)
    if len(sentences) < _MIN_SENTENCES:
        raise gr.Error(
            f"The PDF must contain at least {_MIN_SENTENCES} sentences of "
            "extractable text (scanned PDFs without a text layer are not supported)."
        )

    # The query is appended as the last row so it shapes the projection.
    all_embeddings = np.vstack([embeddings, query_embedding])
    n_points = len(all_embeddings)

    # n_neighbors must be in [2, n_points - 1]. Spectral initialisation can
    # fail on very small inputs, so fall back to random init there.
    umap_transform = umap.UMAP(
        n_neighbors=max(2, min(n_neighbors, n_points - 1)),
        min_dist=0.0,
        n_components=2,
        init="spectral" if n_points > 10 else "random",
        random_state=42,
    )
    umap_embeddings = umap_transform.fit_transform(all_embeddings)

    # Higher cosine similarity means closer to the query; take the top_k.
    similarities = cosine_similarity([query_embedding], embeddings)[0]
    closest = set(np.argsort(similarities)[::-1][:top_k].tolist())

    # Slice off the final row (the query point) so only sentences are drawn.
    source = ColumnDataSource({
        'x': umap_embeddings[:-1, 0],
        'y': umap_embeddings[:-1, 1],
        'content': sentences,
        'color': ['red' if i in closest else 'blue' for i in range(len(sentences))],
    })

    p = figure(title=_PLOT_TITLE, width=700, height=700)
    p.scatter('x', 'y', color='color', source=source)
    p.add_tools(HoverTool(tooltips=[("Content", "@content")]))

    # Close the handle immediately and let Bokeh write the file itself. Passing
    # filename/resources explicitly avoids output_file()'s global state.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as tmp:
        plot_path = tmp.name
    save(p, filename=plot_path, resources=CDN, title=_PLOT_TITLE)
    return plot_path


def gradio_interface(pdf_file, query):
    """Gradio callback: build the plot and return it as embeddable HTML.

    Returns:
        An ``<iframe>`` whose ``srcdoc`` is the standalone Bokeh page.
    """
    plot_path = generate_plot(query, pdf_file)
    try:
        with open(plot_path, "r", encoding="utf-8") as f:
            html_content = f.read()
    finally:
        # The temp file exists only to hand HTML from Bokeh to us; remove it.
        os.remove(plot_path)
    # gr.HTML inserts markup via innerHTML, and browsers do not execute
    # <script> tags inserted that way, so Bokeh's JS would never run. Hosting
    # the full document in an iframe's srcdoc lets its scripts execute.
    return (
        f'<iframe srcdoc="{html.escape(html_content, quote=True)}" '
        'style="width: 100%; height: 760px; border: none;"></iframe>'
    )


iface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.File(label="Upload PDF"), gr.Textbox(label="Query")],
    outputs=gr.HTML(label="Visualization"),
    title="PDF Content Visualizer",
    description="Upload a PDF and enter a query to visualize the content.",
)

if __name__ == "__main__":
    iface.launch()