# img2text / app.py
# Hugging Face file-viewer header (not code): Nerva1228 - "Update app.py",
# commit d4e2f52 (verified), raw / history / blame, 8.15 kB
# import gradio as gr
# import spaces
# from PIL import Image
# import torch
# from transformers import AutoModelForCausalLM, AutoProcessor
# import requests
# import json
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True).to(device)
# processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True)
# SERVER_URL = 'http://43.156.72.113:8188'
# FETCH_TASKS_URL = SERVER_URL + '/fetch/'
# UPDATE_TASK_STATUS_URL = SERVER_URL + '/update/'
# def fetch_task(category, fetch_all=False):
# params = {'fetch_all': 'true' if fetch_all else 'false'}
# response = requests.post(FETCH_TASKS_URL + category, params=params)
# if response.status_code == 200:
# return response.json()
# else:
# print(f"Failed to fetch tasks: {response.status_code} - {response.text}")
# return None
# def update_task_status(category, task_id, status, result=None):
# data = {'status': status}
# if result:
# data['result'] = result
# response = requests.post(UPDATE_TASK_STATUS_URL + category + f'/{task_id}', json=data)
# if response.status_code == 200:
# print(f"Task {task_id} updated successfully: {json.dumps(response.json(), indent=4)}")
# else:
# print(f"Failed to update task {task_id}: {response.status_code} - {response.text}")
# @spaces.GPU(duration=200)
# def infer(prompt, image, request: gr.Request):
# if request:
# print("请求头字典:", request.headers)
# print("IP 地址:", request.client.host)
# print("查询参数:", dict(request.query_params))
# print("会话哈希:", request.session_hash)
# max_size = 256
# width, height = image.size
# if width > height:
# new_width = max_size
# new_height = int((new_width / width) * height)
# else:
# new_height = max_size
# new_width = int((new_height / height) * width)
# image = image.resize((new_width, new_height), Image.LANCZOS)
# inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
# generated_ids = model.generate(
# input_ids=inputs["input_ids"],
# pixel_values=inputs["pixel_values"],
# max_new_tokens=1024,
# do_sample=False,
# num_beams=3
# )
# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
# parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
# return parsed_answer
# css = """
# #col-container {
# margin: 0 auto;
# max-width: 800px;
# }
# """
# with gr.Blocks(css=css) as app:
# with gr.Column(elem_id="col-container"):
# gr.Markdown(f"""# Tag The Image
# Get tag based on images using the Florence-2-base-PromptGen-v1.5 model.
# """)
# with gr.Row():
# prompt = gr.Text(
# label="Prompt",
# show_label=False,
# max_lines=1,
# placeholder="Enter your prompt or blank here.",
# container=False,
# )
# image_input = gr.Image(
# label="Image",
# type="pil",
# show_label=False,
# container=False,
# )
# run_button = gr.Button("Run", scale=0)
# result = gr.Textbox(label="Generated Text", show_label=False)
# gr.on(
# triggers=[run_button.click, prompt.submit],
# fn=infer,
# inputs=[prompt, image_input],
# outputs=[result]
# )
# app.queue()
# app.launch(show_error=True)
import json
from io import BytesIO

# NOTE: `spaces` is deliberately kept ahead of `torch`/`transformers`;
# ZeroGPU Spaces expect it to be imported before CUDA is initialised.
import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import requests
# Hugging Face repository that supplies both the model weights and the processor.
MODEL_ID = "MiaoshouAI/Florence-2-base-PromptGen-v1.5"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# trust_remote_code=True lets transformers load the modelling code that ships
# inside the model repository itself.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Task server base address and the endpoint prefixes derived from it; the
# task category (and task id, for updates) is appended by the callers.
SERVER_URL = 'http://43.156.72.113:8188'
FETCH_TASKS_URL = f'{SERVER_URL}/fetch/'
UPDATE_TASK_STATUS_URL = f'{SERVER_URL}/update/'
def fetch_task(category, fetch_all=False, timeout=30):
    """Fetch tasks of one category from the task server.

    Args:
        category: Category name appended to ``FETCH_TASKS_URL``.
        fetch_all: Sent to the server as the ``fetch_all`` query parameter
            (``'true'`` or ``'false'``).
        timeout: Seconds to wait for the server before giving up, so a
            stalled server cannot hang the caller indefinitely.

    Returns:
        The decoded JSON body when the server answers HTTP 200. ``None`` when
        the request cannot be completed, the server answers with any other
        status, or the body is not valid JSON; the reason is printed.
    """
    params = {'fetch_all': 'true' if fetch_all else 'false'}
    try:
        response = requests.post(FETCH_TASKS_URL + category, params=params, timeout=timeout)
    except requests.RequestException as exc:
        # Connection errors and timeouts follow the same "print and return
        # None" contract as a non-200 answer instead of raising.
        print(f"Failed to fetch tasks: {exc}")
        return None

    if response.status_code != 200:
        print(f"Failed to fetch tasks: {response.status_code} - {response.text}")
        return None

    try:
        return response.json()
    except ValueError as exc:
        # requests raises a ValueError subclass when the body is not JSON.
        print(f"Failed to fetch tasks: invalid JSON response - {exc}")
        return None
def update_task_status(category, task_id, status, result=None, timeout=30):
    """Report a task's status, and optionally its result, to the task server.

    Args:
        category: Category name appended to ``UPDATE_TASK_STATUS_URL``.
        task_id: Identifier of the task; appended to the URL after the category.
        status: Status value sent under the ``'status'`` key.
        result: Optional payload sent under the ``'result'`` key. Omitted from
            the request only when it is ``None``.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``True`` when the server answered HTTP 200, ``False`` otherwise
        (request error or any other status). The outcome is also printed.
    """
    data = {'status': status}
    if result is not None:
        data['result'] = result

    url = UPDATE_TASK_STATUS_URL + category + f'/{task_id}'
    try:
        response = requests.post(url, json=data, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Failed to update task {task_id}: {exc}")
        return False

    if response.status_code != 200:
        print(f"Failed to update task {task_id}: {response.status_code} - {response.text}")
        return False

    # The update already succeeded at this point; a body that is not JSON must
    # not turn the success log line into an exception for the caller.
    try:
        body = json.dumps(response.json(), indent=4)
    except ValueError:
        body = response.text
    print(f"Task {task_id} updated successfully: {body}")
    return True
def _resize_to_fit(image, max_size=256):
    """Scale *image* so its longer side equals *max_size*, keeping the aspect ratio."""
    width, height = image.size
    if width > height:
        new_width = max_size
        new_height = int((new_width / width) * height)
    else:
        new_height = max_size
        new_width = int((new_height / height) * width)
    # A very elongated image would round the short side down to 0, which makes
    # the resize fail; clamp both sides to at least one pixel.
    return image.resize((max(1, new_width), max(1, new_height)), Image.LANCZOS)


def _caption_from_url(image_url, prompt, timeout=30):
    """Download the image at *image_url* and run the model on it with *prompt*."""
    image_response = requests.get(image_url, timeout=timeout)
    # Surface HTTP errors directly instead of letting PIL fail on an error page.
    image_response.raise_for_status()

    image = Image.open(BytesIO(image_response.content))
    # NOTE(review): presumably the processor expects 3-channel input, so
    # palette / alpha / greyscale images are normalised to RGB - confirm.
    image = _resize_to_fit(image.convert("RGB"))

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return processor.post_process_generation(
        generated_text, task=prompt, image_size=(image.width, image.height)
    )


@spaces.GPU(duration=150)
def infer(request: gr.Request):
    """Fetch all pending ``img2text`` tasks, caption each image and report back.

    For every fetched task the image at ``task['content']['url']`` is
    downloaded, resized, and passed to the model together with
    ``task['content']['prompt']``. The outcome is reported to the task server
    with ``update_task_status``. A failing task is reported as ``'Failed'``
    and does not stop the remaining tasks from being processed.

    Args:
        request: The Gradio request that triggered the run; its details are
            printed for diagnostics when present.

    Returns:
        A human-readable summary string: a "no tasks" message, a success
        count, or a count followed by one line per error.
    """
    if request:
        # NOTE(review): headers can contain cookies or tokens; consider
        # trimming what is written to the logs.
        print("请求头字典:", request.headers)  # request headers
        print("IP 地址:", request.client.host)  # client IP address
        print("查询参数:", dict(request.query_params))  # query parameters
        print("会话哈希:", request.session_hash)  # session hash

    img2text_tasks = fetch_task('img2text', fetch_all=True)
    if not img2text_tasks:
        return "No tasks found or failed to fetch tasks."

    errors = []
    processed = 0
    for task in img2text_tasks:
        task_id = task['id']
        processed += 1

        try:
            image_url = task['content']['url']
            prompt = task['content']['prompt']
            print(image_url)
            print(prompt)
            parsed_answer = _caption_from_url(image_url, prompt)
        except Exception as e:  # boundary: one bad task must not abort the batch
            message = f"Error processing task {task_id}: {e}"
            print(message)
            errors.append(message)
            status, result = 'Failed', None
        else:
            print(task_id)
            print(parsed_answer)
            # 'Successed' is the status value sent to the server; left
            # unchanged because the server-side spelling is not visible here.
            status, result = 'Successed', {"text": parsed_answer}

        # Reporting happens outside the inference try-block so that a
        # reporting problem cannot re-label a successful task as failed.
        try:
            update_task_status('img2text', task_id, status, result)
        except Exception as e:
            message = f"Error reporting status for task {task_id}: {e}"
            print(message)
            errors.append(message)

    if errors:
        return f"Processed {processed} task(s) with {len(errors)} error(s):\n" + "\n".join(errors)
    return f"Processed {processed} task(s) successfully."
# Page styling: centre the main column and cap its width at 800px.
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
"""

# Build the UI: a title, a "Run" button and a textbox that shows the result.
with gr.Blocks(css=css) as app:
    # elem_id matches the #col-container selector in the CSS above.
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# Tag The Image
Get tag based on images using the Florence-2-base-PromptGen-v1.5 model.
""")
        run_button = gr.Button("Run", scale=0)
        result = gr.Textbox(label="Generated Text", show_label=False)

    # Clicking "Run" calls infer() with no UI inputs and writes the string it
    # returns into the result textbox.
    gr.on(
        triggers=[run_button.click],
        fn=infer,
        inputs=[],
        outputs=[result]
    )

# Enable request queuing, then start the app.
app.queue()
app.launch(show_error=True)