import json

import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel
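
# Gradio demo: caption an uploaded image with a ViT encoder + GPT-2 decoder
# (VisionEncoderDecoderModel) configured from the local "model" directory.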
# The exported config.json records which encoder and decoder checkpoints
# the captioning model was built from.
with open("model/config.json") as f:
    config = json.load(f)
encoder_name_or_path = config["encoder"]["_name_or_path"]
decoder_name_or_path = config["decoder"]["_name_or_path"]
print(encoder_name_or_path, decoder_name_or_path)

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_name_or_path, decoder_name_or_path
)
tokenizer = AutoTokenizer.from_pretrained(decoder_name_or_path)
# GPT-2 has no pad token by default, so reuse the unk token for padding.
tokenizer.pad_token = tokenizer.unk_token
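# Note: from_encoder_decoder_pretrained() re-initializes the encoder-decoder
# cross-attention from the base checkpoints. If the "model" directory holds a
# full fine-tuned checkpoint (an assumption, not verified here), loading it
# directly would keep the trained weights instead:
#   model = VisionEncoderDecoderModel.from_pretrained("model")
#   tokenizer = AutoTokenizer.from_pretrained("model")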
# Image processor for the ViT encoder; run on GPU when one is available.
feature_extractor = ViTImageProcessor.from_pretrained(encoder_name_or_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
def predict_event(image):
    # Gradio delivers a PIL image (type="pil"); make sure it has 3 channels.
    img = image.convert("RGB")
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values.to(device)
    generated_ids = model.generate(pixel_values)
    # Decode the first generated sequence and drop special tokens.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
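
# Quick local sanity check (hypothetical file name; assumes an image on disk):
#   print(predict_event(Image.open("example.jpg")))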
title = "capstone"
description = "final capstone"

iface = gr.Interface(
    fn=predict_event,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Caption generated by ViT+GPT-2"),
    title=title,
    description=description,
)

# server_name="0.0.0.0" exposes the app on all interfaces, which is required
# when running inside a container such as a Hugging Face Space.
iface.launch(server_name="0.0.0.0", server_port=8001)