# QalamV0.2 / app.py
# gagan3012's picture
# Update app.py
# 97cea74
# raw
# history blame
# 7.28 kB
import os
import re
from io import BytesIO

import openai
import pytesseract
import streamlit as st
import torch
from PIL import Image
from streamlit_cropper import st_cropper
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, DonutProcessor, NougatProcessor
def predict_arabic(img, model_name="UBC-NLP/Qalam"):
    """Transcribe the text in an image with a TrOCR processor/model pair.

    The processor and the model are both loaded from ``model_name`` on every
    call.

    Args:
        img: Image object exposing ``convert("RGB")`` (the app passes the
            cropped PIL image).
        model_name: Checkpoint identifier handed to ``from_pretrained`` for
            both the processor and the model.

    Returns:
        The first decoded sequence, with special tokens skipped.
    """
    ocr_processor = TrOCRProcessor.from_pretrained(model_name)
    ocr_model = VisionEncoderDecoderModel.from_pretrained(model_name)

    encoded = ocr_processor(img.convert("RGB"), return_tensors="pt")
    token_ids = ocr_model.generate(encoded.pixel_values, max_length=256)

    decoded = ocr_processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]
def predict_english(img, model_name="naver-clova-ix/donut-base-finetuned-cord-v2"):
    """Run a Donut checkpoint on an image and return its output as plain text.

    The model is moved to CUDA when available, otherwise it stays on the CPU.
    Generation is primed with the ``<s_cord-v2>`` task prompt; afterwards the
    EOS/pad tokens and every ``<...>`` tag are stripped from the decoded
    sequence.

    Args:
        img: Image object exposing ``convert("RGB")`` (the app passes the
            cropped PIL image).
        model_name: Checkpoint identifier handed to ``from_pretrained`` for
            both the processor and the model.

    Returns:
        The cleaned, whitespace-stripped text of the first generated sequence.
    """
    donut_processor = DonutProcessor.from_pretrained(model_name)
    donut_model = VisionEncoderDecoderModel.from_pretrained(model_name)

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    donut_model.to(target_device)

    tokenizer = donut_processor.tokenizer
    prompt_ids = tokenizer(
        "<s_cord-v2>", add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixels = donut_processor(img.convert("RGB"), return_tensors="pt").pixel_values

    generation = donut_model.generate(
        pixels.to(target_device),
        decoder_input_ids=prompt_ids.to(target_device),
        max_length=donut_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Forbid the unknown token from being generated.
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    raw_text = donut_processor.batch_decode(generation.sequences)[0]
    for marker in (tokenizer.eos_token, tokenizer.pad_token):
        raw_text = raw_text.replace(marker, "")
    # Drop every remaining <tag> so only the recognised text is left.
    return re.sub(r"<.*?>", "", raw_text).strip()
def predict_nougat(img, model_name="facebook/nougat-small"):
    """Transcribe an image with a Nougat processor/model pair on the CPU.

    The processor's markdown post-processing step is not applied; the decoded
    text is returned as generated.

    Args:
        img: Image object exposing ``convert("RGB")`` (the app passes the
            cropped PIL image).
        model_name: Checkpoint identifier handed to ``from_pretrained`` for
            both the processor and the model.

    Returns:
        The first decoded sequence, with special tokens skipped.
    """
    run_device = "cpu"
    nougat_processor = NougatProcessor.from_pretrained(model_name)
    nougat_model = VisionEncoderDecoderModel.from_pretrained(model_name)

    features = nougat_processor(
        img.convert("RGB"), return_tensors="pt", data_format="channels_first"
    )

    # Generation is capped at 1500 new tokens and may not emit the unknown token.
    generated = nougat_model.generate(
        features.pixel_values.to(run_device),
        min_length=1,
        max_new_tokens=1500,
        bad_words_ids=[[nougat_processor.tokenizer.unk_token_id]],
    )

    pages = nougat_processor.batch_decode(generated, skip_special_tokens=True)
    return pages[0]
def predict_tesseract(img):
    """Extract text from an image with Tesseract via ``pytesseract``.

    Args:
        img: Either an already-loaded PIL image (what the app passes after
            cropping) or anything ``Image.open`` accepts, such as a file path
            or a binary file object.

    Returns:
        The text returned by ``pytesseract.image_to_string``.
    """
    # Image.open() expects a path or file object; handing it an in-memory PIL
    # image fails, so loaded images go to Tesseract directly.
    if isinstance(img, Image.Image):
        return pytesseract.image_to_string(img)

    # Paths / file objects are opened here and closed again once OCR is done.
    with Image.open(img) as opened_image:
        return pytesseract.image_to_string(opened_image)
# Streamlit page setup: tab title and icon, wide layout, expanded sidebar and
# the custom menu entries.
# NOTE(review): "deprecation.showfileUploaderEncoding" looks like an option
# from older Streamlit releases - confirm the pinned version still accepts it.
st.set_option("deprecation.showfileUploaderEncoding", False)
st.set_page_config(
    page_title="Ex-stream-ly Cool App",
    page_icon="🖊️",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        "Get Help": "https://www.extremelycoolapp.com/help",
        "Report a bug": "https://www.extremelycoolapp.com/bug",
        "About": "# This is a header. This is an *extremely* cool app!",
    },
)
# Page header plus the sidebar controls: image upload, crop options and the
# OCR language selector. The widgets are created in display order.
st.header("Qalam: A Multilingual OCR System")
st.sidebar.header("Configuration and Image Upload")

st.sidebar.subheader("Adjust Image Enhancement Options")
img_file = st.sidebar.file_uploader(label="Upload a file", type=["png", "jpg"])
realtime_update = st.sidebar.checkbox(label="Update in Real Time", value=True)

# Maps each aspect-ratio label to the value given to the cropper; only a
# free-form crop (None) is offered.
aspect_dict = {"Free": None}
aspect_choice = st.sidebar.radio(label="Aspect Ratio", options=list(aspect_dict))
aspect_ratio = aspect_dict[aspect_choice]

st.sidebar.subheader("Select OCR Language and Model")
Lng = st.sidebar.selectbox(
    label="Language",
    options=["English", "Arabic", "French", "Korean", "Chinese"],
)

# Display name of the model shown for each selectable language.
Models = {
    "Arabic": "Qalam",
    "English": "Nougat",
    "French": "Tesseract",
    "Korean": "Donut",
    "Chinese": "Donut",
}
st.sidebar.markdown(f"### Selected Model: {Models[Lng]}")
# --- Crop / OCR flow -------------------------------------------------------
# Nothing in this section renders until a file has been uploaded.
if img_file:
    img = Image.open(img_file)
    if not realtime_update:
        st.write("Double click to save crop")

    col1, col2, col3 = st.columns(3)

    with col1:
        st.subheader("Input: Upload and Crop Your Image")
        # Get a cropped image from the frontend
        cropped_img = st_cropper(
            img,
            realtime_update=realtime_update,
            box_color="#FF0000",
            aspect_ratio=aspect_ratio,
            should_resize_image=True,
        )

    with col2:
        st.subheader("Output: Preview and Analyze")
        st.image(cropped_img)
        button = st.button("Run OCR")

    with col3:
        if button:
            # OCR function used for each selectable language.
            ocr_backends = {
                "Arabic": predict_arabic,
                "English": predict_nougat,
                "French": predict_tesseract,
                "Korean": predict_english,
                "Chinese": predict_english,
            }
            with st.spinner('Running OCR...'):
                ocr_text = ocr_backends[Lng](cropped_img)

            # Streamlit reruns the whole script on every interaction and the
            # button is only True on the run right after the click, so the
            # result is kept in session state for the chat section below.
            st.session_state["ocr_text"] = ocr_text

            st.subheader(f"OCR Results for {Lng}")
            st.write(ocr_text)
            text_file = BytesIO(ocr_text.encode())
            st.download_button('Download Text', text_file, file_name='ocr_text.txt')

# --- Chat about the OCR result ---------------------------------------------
# The key is read from the environment instead of being hard-coded; the empty
# string fallback matches the previous value when the variable is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")

if "openai_model" not in st.session_state:
    st.session_state["openai_model"] = "gpt-3.5-turbo"
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far; it is lost from the page on every rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("How can I help?"):
    # Use the most recent OCR output, or nothing if OCR has not been run yet.
    # Reading it from session state avoids a NameError on runs where the
    # "Run OCR" button was not the widget that triggered the rerun.
    ocr_context = st.session_state.get("ocr_text", "")
    st.session_state.messages.append({"role": "user", "content": ocr_context + prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        for response in openai.ChatCompletion.create(
            model=st.session_state["openai_model"],
            messages=[
                {"role": m["role"], "content": m["content"]}
                for m in st.session_state.messages
            ],
            stream=True,
        ):
            # Each streamed chunk carries an optional piece of the reply.
            full_response += response.choices[0].delta.get("content", "")
            # Trailing cursor glyph while the reply is still streaming.
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)
    st.session_state.messages.append({"role": "assistant", "content": full_response})