File size: 3,110 Bytes
8749bfc
6bc525e
c55b851
b7568df
70a0a43
8515673
b7568df
 
 
8515673
b7568df
70a0a43
4ee3479
bc82231
8749bfc
 
 
 
bc82231
 
 
 
8749bfc
bc82231
8749bfc
 
bc82231
8749bfc
 
bc82231
 
 
 
8749bfc
bc82231
 
 
 
 
 
 
 
 
8749bfc
70a0a43
53e97a5
 
 
8ce4fdb
 
cdafbc0
 
8ce4fdb
 
8749bfc
bc82231
9e78a3e
c55b851
000e1d3
 
 
d296284
c55b851
 
 
 
 
 
 
 
 
8dad981
c55b851
 
 
 
 
bc82231
c55b851
 
 
 
 
 
 
 
 
bc82231
c55b851
bc82231
5a4b53d
 
 
 
000e1d3
bc82231
8dad981
8749bfc
b205d7e
8749bfc
 
 
 
bc82231
 
 
 
8749bfc
 
 
bc82231
 
8749bfc
1180a02
dd83231
8749bfc
f70fd1c
b2ab744
356967f
f70fd1c
97f1c84
8749bfc
5a4b53d
83ed194
 
bc82231
8749bfc
5ef6085
5a4b53d
bc82231
8749bfc
bc82231
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr

from PIL import Image
# import pickle
import json
import numpy as np
# from fastapi import FastAPI,Response
# from sklearn.metrics import accuracy_score, f1_score
# import prometheus_client as prom
import pandas as pd
# import uvicorn
import os
from transformers import VisionEncoderDecoderModel,pipeline, ViTImageProcessor, AutoTokenizer
import torch


#model

# loaded_model = pickle.load(open(save_file_name, 'rb'))

# app=FastAPI()


# test_data=pd.read_csv("test.csv")


# f1_metric = prom.Gauge('death_f1_score', 'F1 score for test samples')

# Function for updating metrics
# def update_metrics():
#     test = test_data.sample(20)
#     X = test.iloc[:, :-1].values
#     y = test['DEATH_EVENT'].values
    
#     # test_text = test['Text'].values
#     test_pred = loaded_model.predict(X)
#     #pred_labels = [int(pred['label'].split("_")[1]) for pred in test_pred]

#     f1 = f1_score( y , test_pred).round(3)

#     #f1 = f1_score(test['labels'], pred_labels).round(3)

#     f1_metric.set(f1)



# Read the saved encoder-decoder configuration to find out which pretrained
# vision encoder and text decoder the captioning model is built from.
with open("model/config.json", encoding="utf-8") as f:
    n = json.load(f)
encoder_name_or_path = n["encoder"]["_name_or_path"]
decoder_name_or_path = n["decoder"]["_name_or_path"]


print(encoder_name_or_path, decoder_name_or_path)

# NOTE(review): this assembles a fresh model from the two base checkpoints, so
# any fine-tuned weights (including the encoder->decoder cross-attention layers)
# that may be stored alongside model/config.json are NOT loaded here. If model/
# holds a full save_pretrained() checkpoint, then
# VisionEncoderDecoderModel.from_pretrained("model") is probably what was
# intended. Confirm before switching.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_name_or_path, decoder_name_or_path
)


tokenizer = AutoTokenizer.from_pretrained(decoder_name_or_path)
# Presumably the decoder's tokenizer defines no pad token (GPT-2 style), so the
# unknown token is reused as padding. Confirm for the configured decoder.
tokenizer.pad_token = tokenizer.unk_token


feature_extractor = ViTImageProcessor.from_pretrained(encoder_name_or_path)


device = "cuda" if torch.cuda.is_available() else "cpu"

# Place the weights on the selected device. Inference inputs are moved to a
# device too, and leaving the model on CPU while inputs go to CUDA makes
# generation fail with a device-mismatch error on GPU hosts.
model.to(device)
# Inference only: make sure every submodule runs in evaluation mode (no dropout).
model.eval()

# def generate_caption(model, image, tokenizer=None):

    
#     generated_ids = model.generate(pixel_values=inputs.pixel_values)
#     print("generated_ids",generated_ids)

#     if tokenizer is not None:
#         print("tokenizer not null--",tokenizer)
#         generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
#     else:
#         print("tokenizer null--",tokenizer)
#         generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
   
#     return generated_caption





def predict_event(image):
    """Generate a caption for an uploaded image.

    Used as the Gradio callback for the interface's "image" input.

    Args:
        image: The uploaded image in whatever form the Gradio image component
            passes it, or ``None`` when the form is submitted without an image.

    Returns:
        The decoded caption as plain text, or an empty string when no image was
        provided.
    """
    # Gradio passes None when nothing was uploaded. The image processor would
    # raise on it, so return an empty caption instead.
    if image is None:
        return ""

    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    # Send the inputs to wherever the model's weights actually live. Using a
    # separately computed device string breaks when the model was never moved
    # to that device.
    pixel_values = pixel_values.to(model.device)

    generated_ids = model.generate(pixel_values)
    # Drop special tokens (e.g. the decoder's end-of-text marker) so they do not
    # appear verbatim in the caption.
    generated_caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Return plain text. ANSI colour escapes are terminal-only and show up as
    # stray characters (e.g. "[96m") in the Gradio web textbox.
    return generated_caption.strip()




# @app.get("/metrics")
# async def get_metrics():
#     update_metrics()
#     return Response(media_type="text/plain", content= prom.generate_latest())



title = "capstone"
description = "final capstone"


# Single image in, caption text out. The "image"/"text" string shortcuts create
# default Gradio components. title/description must be passed explicitly,
# otherwise the variables above are never shown in the UI.
iface = gr.Interface(
    predict_event,
    inputs=["image"],
    outputs=["text"],
    title=title,
    description=description,
)


iface.launch()


# app = gr.mount_gradio_app(app, iface, path="/")

# iface.launch(server_name = "0.0.0.0", server_port = 8001,share=True)

# if __name__ == "__main__":
    # Use this for debugging purposes only
 
    # uvicorn.run(app, host="0.0.0.0", port=8001)