import streamlit as st
import streamlit.components.v1 as components
from transformers import pipeline
import networkx as nx
from pyvis.network import Network
import tempfile
import openai
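
# Dependency sketch: the package list follows directly from the imports above, but the
# version pin is an assumption, not something taken from the Space itself. The
# openai<1.0 pin matters because openai.ChatCompletion was removed in the 1.x SDK.
# A hypothetical requirements.txt:
#   streamlit
#   transformers
#   torch          # backend assumed for the transformers pipelines
#   networkx
#   pyvis
#   openai<1.0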
# ---------------------------
# Model Loading & Caching
# ---------------------------
@st.cache_resource
def load_summarizer():
    # Load a summarization pipeline from Hugging Face (facebook/bart-large-cnn).
    # st.cache_resource keeps the model in memory across Streamlit reruns.
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    return summarizer

@st.cache_resource
def load_text_generator():
    # For a quick demo, use a smaller text-generation model (GPT-2).
    generator = pipeline("text-generation", model="gpt2")
    return generator

summarizer = load_summarizer()
generator = load_text_generator()
# ---------------------------
# OpenAI Based Idea Generation (Streaming)
# ---------------------------
def generate_ideas_with_openai(prompt, api_key):
    # Uses the legacy (pre-1.0) openai SDK interface.
    openai.api_key = api_key
    output_text = ""
    # Create a chat completion request with streaming output
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
            {"role": "user", "content": prompt}
        ],
        stream=True,
    )
    st_text = st.empty()  # Placeholder that is updated as chunks arrive
    for chunk in response:
        if 'choices' in chunk and len(chunk['choices']) > 0:
            delta = chunk['choices'][0]['delta']
            if 'content' in delta:
                text_piece = delta['content']
                output_text += text_piece
                st_text.text(output_text)
    return output_text
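
# Hedged alternative: the same streaming call written against the 1.x openai SDK
# (openai>=1.0), where openai.ChatCompletion was replaced by a client object. This
# sketch is not wired into the UI below and the function name is illustrative, not
# part of the original app.
def generate_ideas_with_openai_v1(prompt, api_key):
    from openai import OpenAI  # lazy import so the legacy path above still works on openai<1.0
    client = OpenAI(api_key=api_key)
    output_text = ""
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    st_text = st.empty()
    for chunk in stream:
        # Chunks carry incremental deltas; append any new text and refresh the placeholder
        if chunk.choices and chunk.choices[0].delta.content:
            output_text += chunk.choices[0].delta.content
            st_text.text(output_text)
    return output_text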
def generate_ideas_with_hf(prompt):
    # Use a Hugging Face text-generation pipeline for demo purposes.
    # (GPT-2 is far less capable than GPT-3.5 for this task.)
    # max_new_tokens bounds only the generated continuation, so a long prompt
    # does not exhaust the length budget the way max_length would.
    results = generator(prompt, max_new_tokens=150, num_return_sequences=1)
    idea_text = results[0]['generated_text']
    return idea_text
# ---------------------------
# Streamlit App Layout
# ---------------------------
st.title("Graph of AI Ideas Application")
st.sidebar.header("Configuration")
generation_mode = st.sidebar.selectbox(
    "Select Idea Generation Mode",
    ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"],
)
openai_api_key = st.sidebar.text_input("OpenAI API Key (for GPT-3.5 Streaming)", type="password")
# --- Section 1: Research Paper Input and Idea Generation ---
st.header("Research Paper Input")
paper_abstract = st.text_area("Enter the research paper abstract:", height=200)

if st.button("Generate Ideas"):
    if paper_abstract.strip():
        st.subheader("Summarized Abstract")
        # Summarize the abstract to capture its essential points; truncation=True
        # guards against inputs longer than the model's maximum input length.
        summary = summarizer(paper_abstract, max_length=100, min_length=30, do_sample=False, truncation=True)
        summary_text = summary[0]['summary_text']
        st.write(summary_text)
        st.subheader("Generated Research Ideas")
        # Build a prompt that combines the abstract and its summary
        prompt = (
            f"Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
            f"Paper Abstract:\n{paper_abstract}\n\n"
            f"Summary:\n{summary_text}\n\n"
            f"Research Ideas:"
        )
        if generation_mode == "OpenAI GPT-3.5 (Streaming)":
            if not openai_api_key.strip():
                st.error("Please provide your OpenAI API Key in the sidebar.")
            else:
                with st.spinner("Generating ideas using OpenAI GPT-3.5..."):
                    ideas = generate_ideas_with_openai(prompt, openai_api_key)
                    st.write(ideas)
        else:
            with st.spinner("Generating ideas using Hugging Face open source model..."):
                ideas = generate_ideas_with_hf(prompt)
                st.write(ideas)
    else:
        st.error("Please enter a research paper abstract.")
# --- Section 2: Knowledge Graph Visualization ---
st.header("Knowledge Graph Visualization")
st.markdown(
    "Simulate a knowledge graph by entering paper details and their citation relationships. "
    "Enter details in CSV format: **PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';').\n\n"
    "Example:\n"
    "```\n"
    "1,Paper A,2;3\n"
    "2,Paper B,\n"
    "3,Paper C,2\n"
    "```"
)
papers_csv = st.text_area("Enter paper details in CSV format:", height=150)
if st.button("Generate Knowledge Graph"):
    if papers_csv.strip():
        # Parse the CSV-style text input (no pandas needed for this simple format)
        data = []
        for line in papers_csv.splitlines():
            parts = line.split(',')
            if len(parts) >= 2:
                paper_id = parts[0].strip()
                title = parts[1].strip()
                cited = parts[2].strip() if len(parts) >= 3 else ""
                cited_list = [c.strip() for c in cited.split(';') if c.strip()]
                data.append({"paper_id": paper_id, "title": title, "cited": cited_list})
        if data:
            # Build a directed citation graph
            G = nx.DiGraph()
            for paper in data:
                G.add_node(paper["paper_id"], title=paper["title"])
                for cited in paper["cited"]:
                    G.add_edge(paper["paper_id"], cited)
            st.subheader("Knowledge Graph")
            # Create an interactive visualization using Pyvis
            net = Network(height="500px", width="100%", directed=True)
            for node, node_data in G.nodes(data=True):
                # Fall back to the node ID for cited papers that have no row of their own
                net.add_node(node, label=node_data.get("title", node))
            for source, target in G.edges():
                net.add_edge(source, target)
            # Write the network to a temporary HTML file and embed it in Streamlit
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
            temp_file.close()
            net.write_html(temp_file.name)
            with open(temp_file.name, 'r', encoding='utf-8') as f:
                html_content = f.read()
            components.html(html_content, height=500)
    else:
        st.error("Please enter paper details for the knowledge graph.")