File size: 3,512 Bytes
a8d4e3d
 
 
 
8387173
f310b8b
 
 
106ef8f
f310b8b
 
2ad51f4
a8d4e3d
f310b8b
 
adf54e4
f310b8b
adf54e4
 
 
f310b8b
 
a8d4e3d
f310b8b
 
 
 
 
 
 
 
 
a8d4e3d
f310b8b
 
 
 
 
 
a8d4e3d
 
f310b8b
 
a8d4e3d
106ef8f
a8d4e3d
f310b8b
a8d4e3d
f310b8b
a8d4e3d
106ef8f
f310b8b
 
1034007
f310b8b
1034007
f310b8b
 
106ef8f
 
 
 
a8d4e3d
106ef8f
 
 
 
 
f310b8b
 
a8d4e3d
 
 
5727aa4
da318a3
 
 
f310b8b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from pathlib import Path

import numpy as np
import pandas as pd
import streamlit as st
from streamlit_tags import st_tags_sidebar

from src.Surveyor import Surveyor


@st.experimental_singleton(show_spinner=True, suppress_st_warning=True)
def get_surveyor_instance(_print_fn, _survey_print_fn):
    """Construct the ``Surveyor`` once and reuse it across Streamlit reruns.

    The parameters carry a leading underscore so Streamlit's caching layer
    does not try to hash them (they are callables, e.g. a column's ``write``).

    Args:
        _print_fn: callable forwarded to ``Surveyor`` as ``print_fn``.
        _survey_print_fn: callable forwarded to ``Surveyor`` as ``survey_print_fn``.

    Returns:
        A ``Surveyor`` instance created with ``high_gpu=True``.
    """
    with st.spinner('Loading The-Surveyor ...'):
        return Surveyor(print_fn=_print_fn, survey_print_fn=_survey_print_fn, high_gpu=True)


def run_survey(
    surveyor,
    download_placeholder,
    research_keywords=None,
    arxiv_ids=None,
    max_search=None,
    num_papers=None,
):
    """Run one survey and publish its output files as download buttons.

    ``research_keywords`` and ``arxiv_ids`` are passed positionally to
    ``surveyor.survey``; ``max_search`` and ``num_papers`` are passed as
    keyword arguments. The call must return a pair
    ``(zip_file_name, survey_file_name)``, which is handed to
    ``show_survey_download`` together with ``download_placeholder``.
    """
    survey_result = surveyor.survey(
        research_keywords,
        arxiv_ids,
        max_search=max_search,
        num_papers=num_papers,
    )
    zip_path, survey_path = survey_result
    show_survey_download(zip_path, survey_path, download_placeholder)


def show_survey_download(zip_file_name, survey_file_name, download_placeholder):
    """Render one download button per survey artifact inside *download_placeholder*.

    The placeholder is emptied first and then filled with a container that
    holds the buttons. When the placeholder comes from ``st.empty()``, this
    replaces whatever it showed before.

    Args:
        zip_file_name: path of the zip with extracted highlights, images and tables.
        survey_file_name: path of the generated survey file.
        download_placeholder: Streamlit element whose content is replaced by
            the download buttons.
    """
    downloads = (
        (zip_file_name, "Download extracted topic-clustered-highlights, images and tables as zip"),
        (survey_file_name, "Download detailed generated survey file"),
    )
    download_placeholder.empty()
    with download_placeholder.container():
        for file_path, label in downloads:
            with open(str(file_path), "rb") as file:
                st.download_button(
                    label=label,
                    data=file,
                    # Suggest only the base name: the server-side directory
                    # part of the path means nothing to the user's browser.
                    file_name=Path(str(file_path)).name,
                )


def survey_space(surveyor, download_placeholder):
    """Render the sidebar inputs and start a survey when the form is submitted.

    The user either types research keywords (plus search/selection limits)
    into the sidebar form, or enters arxiv ids through a tag widget. Keywords
    win when both are present. Submitting with neither shows a warning
    instead of starting a survey.

    Args:
        surveyor: object whose ``survey`` method is invoked via ``run_survey``.
        download_placeholder: Streamlit element that will receive the
            download buttons for the generated files.
    """
    form = st.sidebar.form(key='survey_form')
    research_keywords = form.text_input(
        "What would you like to research in today?", key='research_keywords', value=''
    )
    max_search = form.number_input(
        "num_papers_to_search",
        help="maximum number of papers to glance through - defaults to 10",
        min_value=1, max_value=50, value=10, step=1, key='max_search',
    )
    num_papers = form.number_input(
        "num_papers_to_select",
        help="maximum number of papers to select and analyse - defaults to 2",
        min_value=1, max_value=8, value=2, step=1, key='num_papers',
    )

    form.write('or')

    # Rendered directly in the sidebar (not through `form`).
    arxiv_ids = st_tags_sidebar(
        label='# Enter arxiv ids for your curated set of papers:',
        value=[],
        text='Press enter to add more (e.g. 1605.08386v1, ...)',
        maxtags=6,
        key='arxiv_ids',
    )

    submit = form.form_submit_button('Submit')
    if not submit:
        return

    # Treat whitespace-only input as "no keywords".
    keywords = research_keywords.strip()
    run_kwargs = {'surveyor': surveyor, 'download_placeholder': download_placeholder}
    if keywords:
        run_kwargs.update(research_keywords=keywords, max_search=max_search, num_papers=num_papers)
    elif arxiv_ids:
        run_kwargs.update(arxiv_ids=arxiv_ids)
    else:
        # Nothing to survey: don't call the surveyor with empty input.
        st.sidebar.warning('Enter research keywords or at least one arxiv id before submitting.')
        return
    run_survey(**run_kwargs)




if __name__ == '__main__':
    st.title('Auto-Research V0.1 - Automated Survey generation from research keywords')
    std_col, survey_col = st.columns(2)
    std_col.header('execution log:')
    survey_col.header('Generated_survey:')
    # Single-element slot in the survey column, directly under its header.
    # show_survey_download empties and refills it with the download buttons,
    # so it must be a placeholder rather than the column itself.
    download_placeholder = survey_col.empty()
    surveyor_obj = get_surveyor_instance(_print_fn=std_col.write, _survey_print_fn=survey_col.write)
    survey_space(surveyor_obj, download_placeholder)