# Header captured with this file: Wootang01 · "Update app.py" · 365113e
import streamlit as st
from multiprocessing import Process
import json
import requests
import time
import os
def start_server():
    """Run the uvicorn inference server (``InferenceServer:app``) in the foreground.

    Serves on host 0.0.0.0, port 8080, with two workers, and blocks until
    uvicorn exits. It is therefore meant to be the target of a separate
    process (``load_models`` launches it in a ``multiprocessing.Process``).

    Returns:
        uvicorn's exit status.

    Raises:
        FileNotFoundError: if the ``uvicorn`` executable cannot be found.
    """
    import subprocess

    # An argument list (no shell) rather than os.system's shell string:
    # nothing is re-parsed by a shell, and the exit status is returned to the
    # caller instead of being silently discarded.
    command = [
        "uvicorn", "InferenceServer:app",
        "--port", "8080",
        "--host", "0.0.0.0",
        "--workers", "2",
    ]
    return subprocess.run(command, check=False).returncode
def load_models():
    """Make sure the inference server on port 8080 is up, starting it if needed.

    If something already accepts connections on port 8080, the server is
    assumed to be running. Otherwise ``start_server`` is launched in a daemon
    ``multiprocessing.Process``, and the port is polled once per second until
    it accepts connections. On success, ``st.session_state['models_loaded']``
    is set to True so later reruns of the script skip this step.

    Raises:
        RuntimeError: if the server process exits before the port opens
            (e.g. uvicorn is missing or crashes during start-up). The session
            flag is left unset in that case, so the next rerun tries again.
    """
    port = 8080  # must match the --port passed by start_server()
    if is_port_in_use(port):
        st.success("The model has loaded.")
    else:
        with st.spinner(text="The model is loading."):
            proc = Process(target=start_server, args=(), daemon=True)
            proc.start()
            while not is_port_in_use(port):
                # Once the child has exited, nothing will ever open the port:
                # fail loudly instead of polling forever.
                if not proc.is_alive():
                    raise RuntimeError(
                        "The model server exited before it started listening "
                        f"on port {port} (exit code {proc.exitcode})."
                    )
                time.sleep(1)
            st.success("Model server started.")
    st.session_state['models_loaded'] = True
def is_port_in_use(port, host="127.0.0.1"):
    """Return True if a TCP server accepts connections on *host*:*port*.

    The check probes loopback by default. ``0.0.0.0`` is a bind-to-all
    address, not a portable connect destination: connecting to it fails on
    Windows. A server bound to 0.0.0.0 is reachable on 127.0.0.1.

    Args:
        port: TCP port number to probe.
        host: IPv4 address to connect to; defaults to loopback.

    Returns:
        True if the connection succeeds, False otherwise (refused, timed out,
        or unreachable).
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # Keep a probe of a non-local host from blocking for the OS default
        # connect timeout; on timeout connect_ex returns a nonzero errno.
        probe.settimeout(1.0)
        return probe.connect_ex((host, port)) == 0
# Per-session flag recording whether load_models() has already confirmed the
# inference server is reachable. It is seeded with False only when the key is
# absent, so a True stored by an earlier run of this session is kept.
st.session_state.setdefault('models_loaded', False)
def get_correction(input_text, timeout=60):
    """Send *input_text* to the local inference server and render its correction.

    Writes a "Corrected text" heading, then queries the server's ``/restore``
    endpoint on 127.0.0.1:8080. It renders the returned
    ``corrected_sentence`` with words that do not appear in *input_text*
    highlighted (via ``diff_strings``). Request or response problems are shown
    with ``st.error`` instead of crashing the page.

    Args:
        input_text: Raw text typed or selected by the user.
        timeout: Seconds to wait for the server's response before giving up.
    """
    import html

    st.markdown('##### Corrected text:')
    st.write('')
    try:
        with st.spinner('Wait for it...'):
            # params= percent-encodes the text, so '&', '#', '+' and similar
            # characters in user input cannot split or truncate the query.
            response = requests.get(
                "http://127.0.0.1:8080/restore",
                params={"input_sentence": input_text},
                timeout=timeout,
            )
            response.raise_for_status()
        corrected_sentence = response.json()["corrected_sentence"]
    except (requests.RequestException, ValueError, KeyError) as err:
        # ValueError covers an undecodable JSON body; KeyError covers a body
        # without "corrected_sentence".
        st.error(f"Could not get a correction from the model server: {err!r}")
        return
    # The diff is rendered with unsafe_allow_html=True, so escape both texts
    # first; escaping is applied identically to each word, so the comparison
    # in diff_strings is unaffected.
    result = diff_strings(html.escape(corrected_sentence), html.escape(input_text))
    st.markdown(result, unsafe_allow_html=True)
def diff_strings(output_text, input_text):
    """Return *output_text* as HTML with words absent from *input_text* highlighted.

    Both strings are split on single spaces. Each word of *output_text* that
    does not occur anywhere in *input_text* is wrapped in a bold, green
    ``<span>``; all other words are kept as-is. Every word, highlighted or
    not, is followed by one space, so the result always ends with a space.

    Words are inserted into the HTML verbatim. Escape untrusted text before
    calling if the result will be rendered as HTML.

    Args:
        output_text: The corrected text to display.
        input_text: The original text the correction is compared against.

    Returns:
        The HTML fragment as a string.
    """
    # Build the lookup set once instead of re-splitting input_text for every
    # output word (which was quadratic in the text length).
    known_words = set(input_text.split(" "))
    pieces = []
    for word in output_text.split(" "):
        if word in known_words:
            pieces.append(word)
        else:
            pieces.append(
                '<span style="font-weight:bold; color:rgb(142, 208, 129);">' + word + '</span>'
            )
    # Space-joining plus one trailing space reproduces "piece + ' '" for
    # every piece, exactly as the original concatenation did.
    return " ".join(pieces) + " "
if __name__ == "__main__":
    # Page header and instructions.
    st.title('Punctuation and Capitalization Corrector -- BERT')
    st.markdown("Choose an example or input your own text. The machine will attempt to correct the text's punctuation and capitalization.")

    # Lower-case sample inputs lacking sentence punctuation, for the model to repair.
    sample_texts = [
        "my name is clara and i live in berkeley california",
        "in 2018 cornell researchers built a high-powered detector",
        "lorem ipsum has been the industrys standard dummy text ever since the 1500s when an unknown printer took a galley of type and scrambled it to make a type specimen book",
    ]

    # Start (or confirm) the inference server until this session records success.
    if not st.session_state['models_loaded']:
        load_models()

    # The chosen example only pre-fills the text box; the user may overwrite it.
    chosen_example = st.selectbox(label="Examples", options=sample_texts)
    user_text = st.text_input(label="Write or paste text", value=chosen_example)

    # Skip the server round-trip for empty or whitespace-only input.
    if user_text.strip():
        get_correction(user_text)