# Captured from a Hugging Face Space page (Space status at capture time: Sleeping).
| import io | |
| import json | |
| import fitz | |
| import streamlit as st | |
| import torch | |
| from PIL import Image, ImageGrab | |
| from transformers import pipeline | |
# --- Configuration and Setup ---
# Device index later handed to the transformers pipeline: 0 when CUDA is
# available, otherwise -1.
if torch.cuda.is_available():
    DEVICE = 0
else:
    DEVICE = -1

# Page-level Streamlit settings (title, icon, wide layout, collapsed sidebar).
_PAGE_SETTINGS = {
    "page_title": "Invoice AI | by Arif Dogan",
    "page_icon": "🧾",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)
# --- Styling ---
# Custom CSS injected into the page: constrains the app width, colours the
# buttons and progress bar, hides the footer and toolbar, and defines the
# .high / .medium / .low classes used for the confidence labels.
_CUSTOM_CSS = """
<style>
.stApp {max-width: 1200px; margin: 0 auto}
.stButton>button {background-color: #4CAF50; color: white; border-radius: 5px;}
.stProgress>div>div {background-color: #4CAF50}
footer {visibility: hidden}
.high {color: #4CAF50; font-weight: bold}
.medium {color: #FFA726; font-weight: bold}
.low {color: #EF5350; font-weight: bold}
div[data-testid="stToolbar"] {visibility: hidden; height: 0}
[data-testid="stExpanderContent"] {background-color: rgba(67, 76, 94, 0.5);}
.stTextInput>div>div {background-color: rgba(67, 76, 94, 0.5)}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# --- Functions ---
@st.cache_resource
def load_model():
    """Build the document question-answering pipeline used for extraction.

    The pipeline is created with the "faisalraza/layoutlm-invoices" model on
    the device selected by ``DEVICE``.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``; without it the model
    would be re-instantiated on each rerun.

    Returns:
        The object returned by ``transformers.pipeline``.
    """
    return pipeline(
        "document-question-answering",
        model="faisalraza/layoutlm-invoices",
        device=DEVICE,
    )
def process_pdf(pdf_file):
    """Render the first page of a PDF as a PIL image.

    Args:
        pdf_file: A readable binary file-like object containing the PDF.

    Returns:
        A ``(image, page_count)`` tuple: the first page rasterised with a
        300/72 zoom factor (i.e. 300 DPI for a 72-DPI page space), and the
        total number of pages in the document.

    Raises:
        ValueError: If the PDF contains no pages.
        Exception: Any error raised by PyMuPDF or PIL while opening or
            rendering the document propagates unchanged.
    """
    pdf_bytes = pdf_file.read()
    # The context manager closes the document; no explicit cleanup or
    # catch-and-re-raise is needed.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as pdf_document:
        page_count = pdf_document.page_count
        if page_count <= 0:
            raise ValueError("PDF has no pages")

        zoom = 300 / 72
        first_page = pdf_document[0]
        pix = first_page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        png_bytes = pix.tobytes("png")
        return Image.open(io.BytesIO(png_bytes)), page_count
def process_image(uploaded_file):
    """Turn an uploaded file into a ``(PIL image, page_count)`` pair.

    The file is rewound first. Files whose ``type`` is "application/pdf" are
    delegated to ``process_pdf``; anything else is opened directly with PIL
    and reported as a single page.
    """
    uploaded_file.seek(0)
    is_pdf = uploaded_file.type == "application/pdf"
    if not is_pdf:
        return Image.open(uploaded_file), 1
    return process_pdf(uploaded_file)
def get_clipboard_image():
    """Try to fetch an image from the clipboard.

    Returns:
        ``(image, 1)`` when the clipboard holds a PIL image, otherwise
        ``(None, 0)``. This includes the case where reading the clipboard
        raises.
    """
    # NOTE(review): ImageGrab runs in the Python process, so this presumably
    # reads the clipboard of the machine hosting the app, not the browser's.
    # Confirm this is the intended behaviour for a hosted deployment.
    try:
        grabbed = ImageGrab.grabclipboard()
        if isinstance(grabbed, Image.Image):
            return grabbed, 1
        return None, 0
    except Exception:
        return None, 0
def prepare_export_data(extracted_info, format_type):
    """Serialise the extracted field values for download.

    Args:
        extracted_info: Mapping of field name to a dict that has at least a
            ``"value"`` key. Confidence scores are not exported.
        format_type: ``"JSON"``, ``"CSV"``; any other value produces the
            plain-text ``field: value`` format.

    Returns:
        The export payload as a string:
        * JSON: an object of field -> value, indented by 2 spaces.
        * CSV: a header row followed by one value row, without a trailing
          newline. Values are always quoted and embedded double quotes are
          doubled; header cells are quoted only when they need it.
        * TXT: one ``field: value`` line per field.
    """
    values = {field: data["value"] for field, data in extracted_info.items()}

    if format_type == "JSON":
        return json.dumps(values, indent=2)

    if format_type == "CSV":
        import csv  # local import: not part of the module-level imports

        buffer = io.StringIO()
        # The csv module handles quoting and escaping, so values containing
        # quotes, commas or newlines no longer corrupt the row.
        csv.writer(
            buffer, quoting=csv.QUOTE_MINIMAL, lineterminator="\n"
        ).writerow(values.keys())
        csv.writer(
            buffer, quoting=csv.QUOTE_ALL, lineterminator="\n"
        ).writerow(values.values())
        # Drop the final line terminator to keep the historical shape
        # "header\nvalues".
        return buffer.getvalue()[:-1]

    # TXT (also the fallback for unknown formats)
    return "\n".join(f"{field}: {value}" for field, value in values.items())
def extract_information(model, image, questions, progress_bar, status_text):
    """Ask the model each question about the image and collect the answers.

    Args:
        model: Callable invoked as ``model(image=..., question=...)``. It is
            expected to return either a list of answer dicts or a single
            answer dict, each with ``"answer"`` and ``"score"`` keys.
        image: The document image passed through to ``model``.
        questions: Sequence of question strings.
        progress_bar: Object with a ``progress(fraction)`` method.
        status_text: Object with a ``text(message)`` method.

    Returns:
        Dict mapping a field name to ``{"value": answer, "confidence":
        score}``. The field name is the question with "What is the " and "?"
        removed, then title-cased. Answers that are blank or whose score is
        not above 0.1 are omitted. Questions whose processing raises are
        skipped (best effort) and the failure is logged.
    """
    import logging  # local import: not part of the module-level imports

    logger = logging.getLogger(__name__)
    total = len(questions)
    extracted_info = {}

    for position, question in enumerate(questions, start=1):
        # Update progress bar and status text
        progress_bar.progress(position / total)
        status_text.text(f"Processing: {question} ({position}/{total})")

        try:
            response = model(image=image, question=question)
            # Accept both a list of candidates and a bare answer dict.
            if isinstance(response, dict):
                candidates = [response]
            else:
                candidates = response or []
            if not candidates:
                continue

            top = candidates[0]
            answer = top.get("answer") or ""
            if not answer.strip():
                continue  # blank answer carries no information

            confidence = top["score"]
            if confidence > 0.1:
                field = question.replace("What is the ", "").replace("?", "").title()
                extracted_info[field] = {"value": answer, "confidence": confidence}
        except Exception:
            # Keep going with the remaining questions, but leave a trace
            # instead of hiding the failure completely.
            logger.warning(
                "Extraction failed for question %r", question, exc_info=True
            )
            continue

    return extracted_info
# --- Initialization ---
# Seed the per-session keys the rest of the script reads, without touching
# values that already exist from a previous rerun.
for _state_key, _initial_value in (
    ("processed_image", None),
    ("extracted_info", {}),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# --- UI Layout ---
# Centered page header.
_HEADER_HTML = """
<div style='text-align: center; padding: 1rem;'>
<h1>🧾 Invoice AI Extractor</h1>
<p style='font-size: 1.2em; color: #999;'>Powered by LayoutLM</p>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)

model = load_model()

# Two input routes side by side: file upload (wider column) and clipboard.
col1, col2 = st.columns([2, 1])
with col1:
    uploaded_file = st.file_uploader(
        "Drop invoice (PDF, JPG, PNG)",
        type=["pdf", "jpg", "jpeg", "png"],
    )
with col2:
    st.write("Or paste from clipboard (Ctrl/Cmd + V)")
    check_clipboard = st.button("📎 Check Clipboard")
# --- Image Processing Logic ---
# Streamlit reruns the script on every interaction and the uploader keeps
# returning the same file. Remember which upload was last processed so an
# unchanged file neither wipes the extracted results (forcing the model to
# run again) nor blocks the clipboard branch.
if uploaded_file:
    upload_token = (
        getattr(uploaded_file, "file_id", None),
        uploaded_file.name,
        uploaded_file.size,
    )
else:
    upload_token = None
    # Forget the previous upload so that re-adding the same file is
    # processed again.
    st.session_state["last_upload_token"] = None

if uploaded_file and st.session_state.get("last_upload_token") != upload_token:
    try:
        image, _ = process_image(uploaded_file)
    except Exception as e:
        st.error(f"Error processing file: {e}")
    else:
        st.session_state.processed_image = image
        st.session_state.extracted_info = {}  # Reset on new upload
        st.session_state["last_upload_token"] = upload_token
elif check_clipboard:
    clipboard_image, _ = get_clipboard_image()
    if clipboard_image:
        st.session_state.processed_image = clipboard_image
        st.session_state.extracted_info = {}
        st.success("Image loaded from clipboard")
    else:
        st.warning("No image found in clipboard")
# --- Display and Information Extraction ---
# Media types for the download button; "text/json" and "text/txt" are not
# valid MIME types.
_EXPORT_MIME_TYPES = {
    "JSON": "application/json",
    "CSV": "text/csv",
    "TXT": "text/plain",
}

if st.session_state.processed_image is not None:
    try:
        image = st.session_state.processed_image.convert("RGB")
        col1, col2 = st.columns([1, 1])

        # Left: the document itself.
        with col1:
            st.image(image, caption="Document", use_container_width=True)

        # Right: extracted fields and export controls.
        with col2:
            st.markdown("### 📊 Extracted Information")

            # Only query the model when no results are stored for this image.
            if not st.session_state.extracted_info:
                questions = [
                    "What is the invoice number?",
                    "What is the invoice date?",
                    "What is the total amount?",
                    "What is the company name?",
                    "What is the due date?",
                    "What is the tax amount?",
                ]
                # Create progress bar and status text elements
                progress_bar = st.progress(0)
                status_text = st.empty()
                st.session_state.extracted_info = extract_information(
                    model, image, questions, progress_bar, status_text
                )
                # Clear both progress widgets after completion
                status_text.empty()
                progress_bar.empty()

            if st.session_state.extracted_info:
                for field, data in st.session_state.extracted_info.items():
                    conf_col, val_col = st.columns([1, 4])
                    with val_col:
                        # Per-field key keeps the widget identities unique.
                        st.text_input(
                            field, data["value"], disabled=True, key=f"input_{field}"
                        )
                    with conf_col:
                        confidence = data["confidence"]
                        # Map the score onto the CSS classes defined in the
                        # page styling (.high / .medium / .low).
                        if confidence > 0.7:
                            css_class = "high"
                        elif confidence > 0.4:
                            css_class = "medium"
                        else:
                            css_class = "low"
                        st.markdown(
                            f"<p class='{css_class}'>{confidence:.1%}</p>",
                            unsafe_allow_html=True,
                        )

                st.markdown("### 📥 Export")
                export_format = st.selectbox("Format", ["JSON", "CSV", "TXT"])
                export_data = prepare_export_data(
                    st.session_state.extracted_info, export_format
                )
                file_extension = export_format.lower()
                st.download_button(
                    "Download",
                    export_data,
                    file_name=f"invoice_data.{file_extension}",
                    mime=_EXPORT_MIME_TYPES.get(export_format, "text/plain"),
                )
            else:
                st.warning(
                    "Could not extract information. Please ensure the document is clear."
                )
    except Exception as e:
        st.error(f"Error during processing: {e}")
# --- Footer ---
# Horizontal rule followed by centered author credits.
_FOOTER_HTML = """
<div style='text-align: center'>
<p>Created by <a href='https://github.com/doganarif' target='_blank'>Arif Dogan</a> |
<a href='https://huggingface.co/arifdogan' target='_blank'>🤗 Hugging Face</a></p>
</div>
"""
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)