# Spaces:
# Sleeping
# Sleeping
# import gradio as gr | |
import os | |
import torch | |
from transformers import AutoProcessor, MllamaForConditionalGeneration, TextIteratorStreamer | |
from PIL import Image | |
import spaces | |
import tempfile | |
import requests | |
from PyPDF2 import PdfReader | |
from threading import Thread | |
from flask import Flask, request, jsonify | |
import io | |
import fitz | |
# Check if we're running in a Hugging Face Space and if SPACES_ZERO_GPU is enabled
# IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
# IS_SPACE = os.environ.get("SPACE_ID", None) is not None

# Determine the device (GPU if available, else CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): LOW_MEMORY is only printed below; nothing in the visible code acts on it.
LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"
print(f"Using device: {device}")
print(f"Low memory mode: {LOW_MEMORY}")

app = Flask(__name__)

# Hugging Face access token passed to from_pretrained (None when HF_TOKEN is unset).
HF_TOKEN = os.environ.get('HF_TOKEN')

# Load the model and processor once, at import time; the predict_* helpers reuse
# these module-level objects for every request.
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
    model_name,
    # `token` replaces the deprecated `use_auth_token` argument.
    token=HF_TOKEN,
    # bfloat16 halves memory on GPU; CPU inference stays in float32.
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    # On GPU let device_map place the weights; on CPU they load to CPU by default,
    # which is why no explicit model.to(device) call is needed.
    device_map="auto" if device == "cuda" else None,
)
processor = AutoProcessor.from_pretrained(model_name, token=HF_TOKEN)
def extract_image_from_pdf(pdf_url, dpi=75):
    """
    Render the first page of a remote PDF as an in-memory RGB image.

    Args:
        pdf_url (str): URL of the PDF to download.
        dpi (int): Rendering resolution. PDF user space is 72 units per inch,
            so the page is scaled by ``dpi / 72``.

    Returns:
        PIL.Image.Image | None: The first page as an image, or ``None`` if the
        download, parsing or rendering failed (the error is printed).
    """
    try:
        response = requests.get(pdf_url, timeout=30)
        response.raise_for_status()
        zoom = dpi / 72
        # The context manager closes the document even when rendering fails
        # (e.g. IndexError on a PDF with no pages); previously it leaked.
        with fitz.open(stream=response.content, filetype="pdf") as pdf_document:
            pix = pdf_document[0].get_pixmap(matrix=fitz.Matrix(zoom, zoom))
            # Build the PIL image while the document is still open.
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        return img
    except Exception as e:
        # Best-effort helper: callers receive None instead of an exception.
        print(f"Error extracting first page: {e}")
        return None
def predict_image(image_url, text, file_pref, *, max_new_tokens=4096):
    """
    Ask the vision model ``text`` about an image fetched from ``image_url``.

    Args:
        image_url (str): URL of an image, or of a PDF whose first page is used.
        text (str): User prompt sent alongside the image.
        file_pref (str): ``'img'`` to download ``image_url`` as an image; any
            other value renders the first page of a PDF instead.
        max_new_tokens (int): Generation budget (default 4096, as before).

    Returns:
        str: The generated reply, without the prompt and special tokens.

    Raises:
        ValueError: Wrapping any failure (download, rendering, generation).
    """
    try:
        if file_pref == 'img':
            response = requests.get(image_url, timeout=30)
            response.raise_for_status()  # Raise an error for invalid responses
            image = Image.open(io.BytesIO(response.content)).convert("RGB")
        else:
            image = extract_image_from_pdf(image_url)
            if image is None:
                # extract_image_from_pdf swallows its errors; fail loudly here
                # instead of feeding None into the processor.
                raise ValueError(f"could not render the first page of {image_url}")

        messages = [
            {"role": "user", "content": [
                {"type": "image"},  # Placeholder for the image passed to the processor
                {"type": "text", "text": text},
            ]}
        ]
        input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
        # The chat template already starts with the BOS token, so don't add another.
        inputs = processor(
            image, input_text, add_special_tokens=False, return_tensors="pt"
        ).to(device)

        # Generate synchronously: the result is only returned once complete, and a
        # background thread + streamer would hang forever if generate() raised.
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        prompt_len = inputs["input_ids"].shape[-1]
        return processor.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    except Exception as e:
        raise ValueError(f"Error during prediction: {str(e)}") from e
def extract_text_from_pdf(pdf_url):
    """
    Download a PDF and return the text of all of its pages.

    Page texts are concatenated in order with no separator between pages.

    Args:
        pdf_url (str): URL of the PDF.

    Returns:
        str: The extracted text.

    Raises:
        ValueError: If the download or text extraction fails.
    """
    try:
        response = requests.get(pdf_url, timeout=30)
        response.raise_for_status()
        # Parse from memory: no temporary file to leak when parsing fails.
        reader = PdfReader(io.BytesIO(response.content))
        # Defensive `or ""`: treat a page with no extractable text as empty.
        return "".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        raise ValueError(f"Error extracting text from PDF: {str(e)}") from e
def predict_text(text, *, max_new_tokens=2048):
    """
    Run the model on a text-only chat prompt and return its reply.

    Args:
        text (str): The user message.
        max_new_tokens (int): Generation budget (default 2048, as before).

    Returns:
        str: The generated reply, without the prompt and special tokens.

    Exceptions from the processor or model propagate unchanged.
    """
    messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    # Use the module's device (not a hard-coded "cuda") so the CPU path works too.
    # The chat template already starts with the BOS token, so don't add another.
    inputs = processor(
        text=input_text, add_special_tokens=False, return_tensors="pt"
    ).to(device)

    # Generate synchronously: the result is only returned once complete, and a
    # background thread + streamer would hang forever if generate() raised.
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    prompt_len = inputs["input_ids"].shape[-1]
    return processor.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
# Prompt applied to a course document's text: one quoted "Label:" line per field
# the model should fill in, then an instruction to return nothing else.
PROMPT = "".join(
    [
        "Extract the following information as per this format:\n",
        *(
            f"'{label}:'\n"
            for label in (
                "Course Code",
                "Course Name",
                "Course Description",
                "Course Credits",
                "Course Learning Outcomes",
                "Delivery Method",
                "Prerequisite(s)",
                "Co-requisite(s)",
                "Materials",
                "Topical Outline",
            )
        ),
        "Do not add anything else except the required information from this text.",
    ]
)

# Follow-up prompt asking for skills grouped into three tiers.
PROMPT_SKILLS = "\n".join(
    (
        "Provide skills based on the Lightcast Open Skills Taxonomy in categories as:",
        "'Primary Skills' (the degree program or certification),",
        "'Secondary Skills', and",
        "'Tertiary Skills'.",
    )
)
# PROMPT_IMAGE = ( | |
# "You are a highly intelligent assistant designed to analyze images and extract structured information from them. " | |
# "Your task is to analyze the given image of a student's academic record and generate a response in the exact JSON format provided below. " | |
# "If any specific information is missing or unavailable in the image, replace the corresponding field with null. " | |
# "Ensure the format is consistent, strictly adhering to the structure shown below.\n\n" | |
# "Required JSON Format:\n\n" | |
# "{\n" | |
# ' "student": {\n' | |
# ' "name": "string",\n' | |
# ' "id": "string",\n' | |
# ' "dob": "string",\n' | |
# ' "original_start_date": "string",\n' | |
# ' "cumulative_gpa": "string",\n' | |
# ' "program": "string",\n' | |
# ' "status": "string"\n' | |
# ' },\n' | |
# ' "courses": [\n' | |
# ' {\n' | |
# ' "transfer_institution": "string",\n' | |
# ' "course_code": "string",\n' | |
# ' "course_name": "string",\n' | |
# ' "credits_attempted": number,\n' | |
# ' "credits_earned": number,\n' | |
# ' "grade": "string",\n' | |
# ' "quality_points": number,\n' | |
# ' "semester_code": "string",\n' | |
# ' "semester_dates": "string"\n' | |
# ' }\n' | |
# " // Additional courses can be added here\n" | |
# " ]\n" | |
# "}\n\n" | |
# "Instructions:\n\n" | |
# "1. Extract the student information and course details as displayed in the image.\n" | |
# "2. Use null for any missing or unavailable information.\n" | |
# "3. Format the extracted data exactly as shown above. Do not deviate from this structure.\n" | |
# "4. Use accurate field names and ensure proper nesting of data (e.g., 'student' and 'courses' sections).\n" | |
# "5. The values for numeric fields like credits_attempted, credits_earned, and quality_points should be numbers (not strings).\n" | |
# ) | |
# Opening text shared verbatim by both academic-record image prompts below.
_IMAGE_PROMPT_PREAMBLE = (
    "You are a highly intelligent assistant designed to analyze images and extract structured information from them. "
    "Your task is to analyze the given image of a student's academic record and generate a response in the exact JSON format provided below. "
    "If any specific information is missing or unavailable in the image, replace the corresponding field with null. "
    "Ensure the format is consistent, strictly adhering to the structure shown below.\n\n"
    "Required JSON Format:\n\n"
)


def _build_image_prompt(schema, subject, section):
    """Assemble an image-extraction prompt: preamble, JSON schema, numbered rules."""
    rules = [
        f"1. Extract the {subject} as displayed in the image.\n",
        "2. Use null for any missing or unavailable information.\n",
        "3. Format the extracted data exactly as shown above. Do not deviate from this structure.\n",
        "4. Ensure accurate field names and proper nesting.\n",
        f"5. Return only the '{section}' section as JSON.\n",
    ]
    return _IMAGE_PROMPT_PREAMBLE + schema + "Instructions:\n\n" + "".join(rules)


# Prompt for the student's header information on an academic record.
PROMPT_IMAGE_STUDENT = _build_image_prompt(
    schema=(
        "{\n"
        ' "student": {\n'
        ' "name": "string",\n'
        ' "id": "string",\n'
        ' "dob": "string",\n'
        ' "original_start_date": "string",\n'
        ' "cumulative_gpa": "string",\n'
        ' "program": "string",\n'
        ' "status": "string"\n'
        ' }\n'
        "}\n\n"
    ),
    subject="student's general information",
    section="student",
)

# Prompt for the list of courses on an academic record.
PROMPT_IMAGE_COURSES = _build_image_prompt(
    schema=(
        "{\n"
        ' "courses": [\n'
        ' {\n'
        ' "transfer_institution": "string",\n'
        ' "course_code": "string",\n'
        ' "course_name": "string",\n'
        ' "credits_attempted": number,\n'
        ' "credits_earned": number,\n'
        ' "grade": "string",\n'
        ' "quality_points": number,\n'
        ' "semester_code": "string",\n'
        ' "semester_dates": "string"\n'
        ' }\n'
        " // Additional courses can be added here\n"
        " ]\n"
        "}\n\n"
    ),
    subject="course details",
    section="courses",
)
def home():
    """Return a JSON greeting that points clients at the /extract endpoint."""
    # NOTE(review): no @app.route decorator is visible on this view — confirm it is registered.
    greeting = "Welcome to the PDF Extraction API. Use the /extract endpoint to extract information."
    payload = {"message": greeting}
    return jsonify(payload)
def favicon():
    """Answer favicon requests with an empty body and HTTP 204 (No Content)."""
    empty_body, no_content = "", 204
    return empty_body, no_content
def extract_info():
    """
    Run the requested extractions and return the combined model output as JSON.

    The request body must be a non-empty JSON object. All keys are optional:
        url (str): PDF whose text is sent to the model with ``PROMPT``.
        skills (bool): when true and ``url`` produced a reply, that reply is
            sent again with ``PROMPT_SKILLS``.
        img_url (str): image (or PDF) sent to the vision model twice, once with
            the student prompt and once with the courses prompt.
        file_pref (str): forwarded to ``predict_image``; ``'img'`` treats
            ``img_url`` as an image, any other value (or absence) as a PDF.

    Returns:
        ``({"extracted_info": text}, 200)`` where ``text`` joins the PDF, skills
        and image replies with newlines (skipped parts are empty strings);
        ``({"error": ...}, 400)`` for a missing/empty/non-object body;
        ``({"error": ...}, 500)`` if any extraction step raises.
    """
    # NOTE(review): no @app.route decorator is visible on this view — confirm it is registered.
    # silent=True: a non-JSON or malformed body yields None (-> 400 below)
    # instead of Flask's own 415/400 error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not data:
        return jsonify({"error": "Please provide a PDF URL in the request body."}), 400
    try:
        # Every field is optional: read with .get() so an omitted key means
        # "skip this step" rather than a KeyError reported as a 500.
        response = ''
        pdf_url = data.get("url")
        if pdf_url is not None:
            pdf_text = extract_text_from_pdf(pdf_url)
            response = predict_text(f"{PROMPT}\n\n{pdf_text}")

        response_skills = ''
        # Compare with == (not truthiness) as before, so e.g. the string
        # "false" does not trigger the extra model call.
        if data.get("skills") == True and response:  # noqa: E712
            prompt_skills = f"{PROMPT_SKILLS} using this information only -- {response}"
            response_skills = predict_text(prompt_skills)

        response_image = ''
        img_url = data.get("img_url")
        if img_url is not None:
            file_pref = data.get("file_pref")
            response_student = predict_image(img_url, f"{PROMPT_IMAGE_STUDENT}\n", file_pref)
            response_courses = predict_image(img_url, f"{PROMPT_IMAGE_COURSES}\n", file_pref)
            # NOTE(review): two separately generated replies are concatenated with no
            # separator, so the result is generally not a single valid JSON document.
            response_image = response_student + response_courses

        return jsonify({"extracted_info": response + "\n" + response_skills + "\n" + response_image})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Serve on every interface at port 7860
    # (presumably the port the hosting Space expects — confirm).
    app.run(port=7860, host="0.0.0.0")
# # Define the Gradio interface | |
# interface = gr.Interface( | |
# fn=predict_text, | |
# inputs=[ | |
# # gr.Image(type="pil", label="Image Input"), # Image input with label | |
# gr.Textbox(label="Text Input") # Textbox input with label | |
# ], | |
# outputs=gr.Textbox(label="Generated Response"), # Output with a more descriptive label | |
# title="Llama 3.2 11B Vision Instruct Demo", # Title of the interface | |
# description="This demo uses Meta's Llama 3.2 11B Vision model to generate responses based on an image and text input.", # Short description | |
# theme="compact" # Using a compact theme for a cleaner look | |
# ) | |
# # Launch the interface | |
# interface.launch(debug=True) |