Researcher / app.py
mgbam's picture
Update app.py
3dbb4eb verified
raw
history blame
6.21 kB
import os
import tempfile

import networkx as nx
import openai
import streamlit as st
from pyvis.network import Network
from transformers import pipeline
# ---------------------------
# Model Loading & Caching
# ---------------------------
@st.cache_resource(show_spinner=False)
def load_summarizer(model_name: str = "facebook/bart-large-cnn"):
    """Load and cache a Hugging Face summarization pipeline.

    Args:
        model_name: Hub model id to load. Defaults to facebook/bart-large-cnn,
            matching the original hard-coded choice, so existing callers are
            unaffected.

    Returns:
        A transformers summarization pipeline. Cached by st.cache_resource so
        the model is downloaded/loaded only once per distinct model_name.
    """
    return pipeline("summarization", model=model_name)
@st.cache_resource(show_spinner=False)
def load_text_generator():
    """Load and cache a lightweight text-generation pipeline.

    Uses GPT-2 as a small, fast demo model; st.cache_resource keeps a single
    instance alive across Streamlit reruns.
    """
    text_gen = pipeline("text-generation", model="gpt2")
    return text_gen
# Instantiate both cached pipelines once at import time; st.cache_resource
# makes repeated script reruns reuse the same loaded models.
summarizer = load_summarizer()
generator = load_text_generator()
# ---------------------------
# OpenAI Based Idea Generation (Streaming)
# ---------------------------
def generate_ideas_with_openai(prompt, api_key):
    """Stream research-idea text from OpenAI chat completions into the UI.

    Args:
        prompt: Full user prompt (abstract + summary + instruction).
        api_key: OpenAI API key supplied by the user in the sidebar.

    Returns:
        The complete generated text after streaming finishes.
    """
    # openai>=1.0 removed the module-level openai.ChatCompletion API and
    # dict-style response access; use a client instance so the app runs on
    # current SDK releases (also avoids mutating global openai.api_key).
    client = openai.OpenAI(api_key=api_key)
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    output_text = ""
    placeholder = st.empty()  # updated in place as each chunk arrives
    for chunk in stream:
        if chunk.choices:
            # Streamed deltas may carry content=None (e.g. role-only or
            # final chunks); skip those.
            piece = chunk.choices[0].delta.content
            if piece:
                output_text += piece
                placeholder.text(output_text)
    return output_text
def generate_ideas_with_hf(prompt):
    """Generate research ideas with the local GPT-2 pipeline.

    Args:
        prompt: Full prompt text (abstract + summary + instruction).

    Returns:
        The generated text; GPT-2 pipelines echo the prompt followed by the
        continuation.
    """
    # The original used max_length=150, which counts the PROMPT tokens too —
    # an abstract plus its summary easily exceeds 150 tokens, leaving no room
    # for generated text. max_new_tokens bounds only the continuation, and
    # truncation=True keeps overlong prompts within the model's context.
    results = generator(
        prompt,
        max_new_tokens=150,
        num_return_sequences=1,
        truncation=True,
    )
    return results[0]['generated_text']
# ---------------------------
# Streamlit App Layout
# ---------------------------
st.title("Graph of AI Ideas Application")

# Sidebar: pick the generation backend and (optionally) supply an OpenAI key.
st.sidebar.header("Configuration")
generation_mode = st.sidebar.selectbox(
    "Select Idea Generation Mode",
    ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"],
)
openai_api_key = st.sidebar.text_input("OpenAI API Key (for GPT-3.5 Streaming)", type="password")

# --- Section 1: Research Paper Input and Idea Generation ---
st.header("Research Paper Input")
paper_abstract = st.text_area("Enter the research paper abstract:", height=200)

if st.button("Generate Ideas"):
    if not paper_abstract.strip():
        st.error("Please enter a research paper abstract.")
    else:
        st.subheader("Summarized Abstract")
        # Condense the abstract first so the idea prompt stays focused.
        summary_result = summarizer(paper_abstract, max_length=100, min_length=30, do_sample=False)
        condensed = summary_result[0]['summary_text']
        st.write(condensed)

        st.subheader("Generated Research Ideas")
        # The prompt pairs the raw abstract with its condensed form.
        idea_prompt = (
            "Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
            f"Paper Abstract:\n{paper_abstract}\n\n"
            f"Summary:\n{condensed}\n\n"
            "Research Ideas:"
        )
        use_openai = generation_mode == "OpenAI GPT-3.5 (Streaming)"
        if use_openai and not openai_api_key.strip():
            st.error("Please provide your OpenAI API Key in the sidebar.")
        elif use_openai:
            with st.spinner("Generating ideas using OpenAI GPT-3.5..."):
                st.write(generate_ideas_with_openai(idea_prompt, openai_api_key))
        else:
            with st.spinner("Generating ideas using Hugging Face open source model..."):
                st.write(generate_ideas_with_hf(idea_prompt))
# --- Section 2: Knowledge Graph Visualization ---
st.header("Knowledge Graph Visualization")
st.markdown(
    "Simulate a knowledge graph by entering paper details and their citation relationships. "
    "Enter details in CSV format: **PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';'). "
    "Example:\n\n`1,Paper A,2;3`\n`2,Paper B,`\n`3,Paper C,2`"
)
papers_csv = st.text_area("Enter paper details in CSV format:", height=150)
if st.button("Generate Knowledge Graph"):
    if papers_csv.strip():
        # Parse the lightweight CSV input. Lines with fewer than three
        # comma-separated fields are skipped. (The previous pandas/StringIO
        # imports here were unused and have been removed.)
        data = []
        for line in papers_csv.splitlines():
            parts = line.split(',')
            if len(parts) >= 3:
                paper_id = parts[0].strip()
                title = parts[1].strip()
                cited_list = [c.strip() for c in parts[2].split(';') if c.strip()]
                data.append({"paper_id": paper_id, "title": title, "cited": cited_list})
        if data:
            # Build a directed citation graph: edge u -> v means u cites v.
            G = nx.DiGraph()
            for paper in data:
                G.add_node(paper["paper_id"], title=paper["title"])
                for cited in paper["cited"]:
                    G.add_edge(paper["paper_id"], cited)
            st.subheader("Knowledge Graph")
            # Create an interactive visualization using Pyvis.
            net = Network(height="500px", width="100%", directed=True)
            for node, node_data in G.nodes(data=True):
                # A paper that is cited but never defined exists in G only via
                # add_edge and has no "title" attribute; fall back to the node
                # id instead of raising KeyError.
                net.add_node(node, label=node_data.get("title", node))
            for source, target in G.edges():
                net.add_edge(source, target)
            # Render to a temp HTML file, embed it, then remove the file so
            # repeated clicks don't leak temp files (the handle is closed
            # first so pyvis can rewrite the path on all platforms).
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
            temp_file.close()
            try:
                net.write_html(temp_file.name)
                with open(temp_file.name, 'r', encoding='utf-8') as f:
                    html_content = f.read()
            finally:
                os.unlink(temp_file.name)
            st.components.v1.html(html_content, height=500)
        else:
            # Previously this case fell through silently.
            st.error("No valid paper rows found. Use the format: PaperID,Title,CitedPaperIDs")
    else:
        st.error("Please enter paper details for the knowledge graph.")