# import gradio as gr | |
# import spaces | |
# from PIL import Image | |
# import torch | |
# from transformers import AutoModelForCausalLM, AutoProcessor | |
# import requests | |
# import json | |
# device = "cuda" if torch.cuda.is_available() else "cpu" | |
# model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True).to(device) | |
# processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True) | |
# SERVER_URL = 'http://43.156.72.113:8188' | |
# FETCH_TASKS_URL = SERVER_URL + '/fetch/' | |
# UPDATE_TASK_STATUS_URL = SERVER_URL + '/update/' | |
# def fetch_task(category, fetch_all=False): | |
# params = {'fetch_all': 'true' if fetch_all else 'false'} | |
# response = requests.post(FETCH_TASKS_URL + category, params=params) | |
# if response.status_code == 200: | |
# return response.json() | |
# else: | |
# print(f"Failed to fetch tasks: {response.status_code} - {response.text}") | |
# return None | |
# def update_task_status(category, task_id, status, result=None): | |
# data = {'status': status} | |
# if result: | |
# data['result'] = result | |
# response = requests.post(UPDATE_TASK_STATUS_URL + category + f'/{task_id}', json=data) | |
# if response.status_code == 200: | |
# print(f"Task {task_id} updated successfully: {json.dumps(response.json(), indent=4)}") | |
# else: | |
# print(f"Failed to update task {task_id}: {response.status_code} - {response.text}") | |
# @spaces.GPU(duration=200) | |
# def infer(prompt, image, request: gr.Request): | |
# if request: | |
# print("请求头字典:", request.headers) | |
# print("IP 地址:", request.client.host) | |
# print("查询参数:", dict(request.query_params)) | |
# print("会话哈希:", request.session_hash) | |
# max_size = 256 | |
# width, height = image.size | |
# if width > height: | |
# new_width = max_size | |
# new_height = int((new_width / width) * height) | |
# else: | |
# new_height = max_size | |
# new_width = int((new_height / height) * width) | |
# image = image.resize((new_width, new_height), Image.LANCZOS) | |
# inputs = processor(text=prompt, images=image, return_tensors="pt").to(device) | |
# generated_ids = model.generate( | |
# input_ids=inputs["input_ids"], | |
# pixel_values=inputs["pixel_values"], | |
# max_new_tokens=1024, | |
# do_sample=False, | |
# num_beams=3 | |
# ) | |
# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] | |
# parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height)) | |
# return parsed_answer | |
# css = """ | |
# #col-container { | |
# margin: 0 auto; | |
# max-width: 800px; | |
# } | |
# """ | |
# with gr.Blocks(css=css) as app: | |
# with gr.Column(elem_id="col-container"): | |
# gr.Markdown(f"""# Tag The Image | |
# Get tag based on images using the Florence-2-base-PromptGen-v1.5 model. | |
# """) | |
# with gr.Row(): | |
# prompt = gr.Text( | |
# label="Prompt", | |
# show_label=False, | |
# max_lines=1, | |
# placeholder="Enter your prompt or blank here.", | |
# container=False, | |
# ) | |
# image_input = gr.Image( | |
# label="Image", | |
# type="pil", | |
# show_label=False, | |
# container=False, | |
# ) | |
# run_button = gr.Button("Run", scale=0) | |
# result = gr.Textbox(label="Generated Text", show_label=False) | |
# gr.on( | |
# triggers=[run_button.click, prompt.submit], | |
# fn=infer, | |
# inputs=[prompt, image_input], | |
# outputs=[result] | |
# ) | |
# app.queue() | |
# app.launch(show_error=True) | |
import json
import os
from io import BytesIO

# Keep `spaces` ahead of torch/transformers: Hugging Face ZeroGPU expects it
# to be imported before CUDA is initialised.
import gradio as gr
import requests
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Single source of truth for the checkpoint shared by the model and its processor.
MODEL_ID = "MiaoshouAI/Florence-2-base-PromptGen-v1.5"
# trust_remote_code=True lets transformers execute the modeling/processing code
# downloaded with the checkpoint.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Task-server endpoints. The host can be overridden through the TASK_SERVER_URL
# environment variable; without it the original address is used. A trailing
# slash is stripped so the paths below never contain "//".
SERVER_URL = os.environ.get('TASK_SERVER_URL', 'http://43.156.72.113:8188').rstrip('/')
FETCH_TASKS_URL = SERVER_URL + '/fetch/'
UPDATE_TASK_STATUS_URL = SERVER_URL + '/update/'
def fetch_task(category, fetch_all=False, timeout=30):
    """Fetch tasks of *category* from the task server.

    Args:
        category: Task category appended to FETCH_TASKS_URL (e.g. 'img2text').
        fetch_all: Sent as the 'fetch_all' query parameter ('true' / 'false').
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, requests can block indefinitely on an unresponsive host.

    Returns:
        The decoded JSON body when the server answers HTTP 200, otherwise
        None. Connection failures and non-JSON bodies are reported on stdout
        and also yield None, so callers only need to handle one failure value.
    """
    params = {'fetch_all': 'true' if fetch_all else 'false'}
    try:
        response = requests.post(FETCH_TASKS_URL + category, params=params, timeout=timeout)
    except requests.RequestException as e:
        print(f"Failed to fetch tasks: {e}")
        return None
    if response.status_code != 200:
        print(f"Failed to fetch tasks: {response.status_code} - {response.text}")
        return None
    try:
        return response.json()
    except ValueError as e:  # covers json.JSONDecodeError / requests' JSONDecodeError
        print(f"Failed to fetch tasks: invalid JSON response ({e})")
        return None
def update_task_status(category, task_id, status, result=None, timeout=30):
    """Report *status* (and optionally *result*) for one task to the server.

    Best-effort: failures are printed rather than raised, so a flaky server
    cannot abort the caller's processing loop.

    Args:
        category: Task category used in the URL path (e.g. 'img2text').
        task_id: Identifier of the task, appended to the URL path.
        status: Status string sent as the 'status' field of the JSON body.
        result: Optional payload sent as 'result'; omitted when falsy.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True if the server answered HTTP 200, False otherwise.
    """
    data = {'status': status}
    if result:
        data['result'] = result
    url = UPDATE_TASK_STATUS_URL + category + f'/{task_id}'
    try:
        response = requests.post(url, json=data, timeout=timeout)
    except requests.RequestException as e:
        print(f"Failed to update task {task_id}: {e}")
        return False
    if response.status_code != 200:
        print(f"Failed to update task {task_id}: {response.status_code} - {response.text}")
        return False
    # Pretty-print the server's JSON reply; fall back to raw text if it is not JSON.
    try:
        body = json.dumps(response.json(), indent=4)
    except ValueError:
        body = response.text
    print(f"Task {task_id} updated successfully: {body}")
    return True
def _resize_to_max_side(image, max_size=256):
    """Return *image* scaled so its longer side equals *max_size*, keeping aspect ratio.

    Smaller images are upscaled. Each side is clamped to at least 1 px so
    extreme aspect ratios never produce a zero-sized target, which PIL rejects.
    """
    width, height = image.size
    if width > height:
        new_width = max_size
        new_height = max(1, int((new_width / width) * height))
    else:
        new_height = max_size
        new_width = max(1, int((new_height / height) * width))
    return image.resize((new_width, new_height), Image.LANCZOS)


def _generate_caption(prompt, image):
    """Run the Florence-2 model on (prompt, image) and return the post-processed answer."""
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    # Special tokens are kept (as in the Florence-2 usage example) so that
    # post_process_generation can parse the task-tagged output.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return processor.post_process_generation(
        generated_text, task=prompt, image_size=(image.width, image.height)
    )


def _process_task(task, timeout=30):
    """Download the task's image, caption it with the task's prompt, return the answer."""
    image_url = task['content']['url']
    prompt = task['content']['prompt']
    print(image_url)
    print(prompt)
    image_response = requests.get(image_url, timeout=timeout)
    # Surface HTTP errors directly instead of a confusing "cannot identify image" later.
    image_response.raise_for_status()
    # Normalise palette/alpha/greyscale images to 3-channel RGB before
    # preprocessing (assumed requirement of the Florence-2 processor).
    image = Image.open(BytesIO(image_response.content)).convert("RGB")
    image = _resize_to_max_side(image)
    return _generate_caption(prompt, image)


def infer(request: gr.Request):
    """Fetch all 'img2text' tasks, caption each image, and report per-task status.

    Every fetched task is attempted: a failing task is marked 'Failed' on the
    server and processing continues with the next one, instead of abandoning
    the remainder of the fetched batch.

    NOTE(review): the commented-out version of this file decorated infer with
    @spaces.GPU(duration=200); `spaces` is still imported but unused here.
    Confirm whether dropping the decorator was intentional on ZeroGPU hardware.

    Args:
        request: The Gradio request (injected by Gradio); its metadata is logged.

    Returns:
        A human-readable summary for the output textbox.
    """
    if request:
        # Log request metadata: headers, client IP, query params, session hash.
        print("请求头字典:", request.headers)
        print("IP 地址:", request.client.host)
        print("查询参数:", dict(request.query_params))
        print("会话哈希:", request.session_hash)

    img2text_tasks = fetch_task('img2text', fetch_all=True)
    if not img2text_tasks:
        return "No tasks found or failed to fetch tasks."

    failures = []
    for task in img2text_tasks:
        # .get so a malformed task without 'id' cannot raise inside the handler.
        task_id = task.get('id')
        try:
            parsed_answer = _process_task(task)
            print(task_id)
            print(parsed_answer)
            update_task_status('img2text', task_id, 'Successed', {"text": parsed_answer})
        except Exception as e:  # per-task boundary: record the failure and move on
            message = f"Error processing task {task_id}: {e}"
            print(message)
            update_task_status('img2text', task_id, 'Failed')
            failures.append(message)

    if failures:
        return (
            f"{len(failures)} of {len(img2text_tasks)} task(s) failed:\n"
            + "\n".join(failures)
        )
    return f"Processed {len(img2text_tasks)} task(s) successfully."
# Page styling: centre the main column and cap its width.
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
"""

# Header text shown above the Run button.
_TITLE_MARKDOWN = """# Tag The Image
Get tag based on images using the Florence-2-base-PromptGen-v1.5 model.
"""

with gr.Blocks(css=css) as app:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(_TITLE_MARKDOWN)
        run_button = gr.Button("Run", scale=0)
        result = gr.Textbox(label="Generated Text", show_label=False)
    # infer's only parameter is the gr.Request that Gradio injects, so the
    # event has no input components; its return value fills the textbox.
    gr.on(
        triggers=[run_button.click],
        fn=infer,
        inputs=[],
        outputs=[result],
    )

# Enable the request queue, then start the server with errors shown in the UI.
app.queue()
app.launch(show_error=True)