ashwml commited on
Commit
bc82231
·
1 Parent(s): fbccd2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -56
app.py CHANGED
@@ -5,89 +5,94 @@ from fastapi import FastAPI,Response
5
  from sklearn.metrics import accuracy_score, f1_score
6
  import prometheus_client as prom
7
  import pandas as pd
8
- # from transformers import pipeline
9
-
 
10
 
11
 
12
  #model
13
- save_file_name="xgboost-model.pkl"
14
- loaded_model = pickle.load(open(save_file_name, 'rb'))
15
 
16
- app=FastAPI()
 
 
 
17
 
18
- # username="ashwml"
19
- # repo_name="prometheus_model"
20
- # model=username+'/'+repo_name
21
- test_data=pd.read_csv("test.csv")
22
 
23
 
24
- f1_metric = prom.Gauge('death_f1_score', 'F1 score for test samples')
25
 
26
  # Function for updating metrics
27
- def update_metrics():
28
- test = test_data.sample(20)
29
- X = test.iloc[:, :-1].values
30
- y = test['DEATH_EVENT'].values
31
 
32
- # test_text = test['Text'].values
33
- test_pred = loaded_model.predict(X)
34
- #pred_labels = [int(pred['label'].split("_")[1]) for pred in test_pred]
 
 
 
 
 
 
35
 
36
- f1 = f1_score( y , test_pred).round(3)
 
 
37
 
38
- #f1 = f1_score(test['labels'], pred_labels).round(3)
39
 
40
- f1_metric.set(f1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
 
43
 
44
- def predict_death_event(age, anaemia, creatinine_phosphokinase ,diabetes ,ejection_fraction, high_blood_pressure ,platelets ,serum_creatinine, serum_sodium, sex ,smoking ,time):
45
- input=[[age, anaemia, creatinine_phosphokinase ,diabetes ,ejection_fraction, high_blood_pressure ,platelets ,serum_creatinine, serum_sodium, sex ,smoking ,time]]
46
- result=loaded_model.predict(input)
47
 
48
- if result[0]==1:
49
- return 'Positive'
50
- else:
51
- return 'Negative'
52
- return result
53
 
54
 
55
- @app.get("/metrics")
56
- async def get_metrics():
57
- update_metrics()
58
- return Response(media_type="text/plain", content= prom.generate_latest())
59
 
 
 
 
 
60
 
61
 
62
- title = "Patient Survival Prediction"
63
- description = "Predict survival of patient with heart failure, given their clinical record"
64
 
65
- out_response = gr.components.Textbox(type="text", label='Death_event')
 
66
 
67
- iface = gr.Interface(fn=predict_death_event,
68
- inputs=[
69
- gr.Slider(18, 100, value=20, label="Age"),
70
- gr.Slider(0, 1, value=1, label="anaemia"),
71
- gr.Slider(100, 2000, value=20, label="creatinine_phosphokinase"),
72
- gr.Slider(0, 1, value=1, label="diabetes"),
73
- gr.Slider(18, 100, value=20, label="ejection_fraction"),
74
- gr.Slider(0, 1, value=1, label="high_blood_pressure"),
75
- gr.Slider(18, 400000, value=20, label="platelets"),
76
- gr.Slider(1, 10, value=20, label="serum_creatinine"),
77
- gr.Slider(100, 200, value=20, label="serum_sodium"),
78
- gr.Slider(0, 1, value=1, label="sex"),
79
- gr.Slider(0, 1, value=1, label="smoking"),
80
- gr.Slider(1, 10, value=20, label="time"),
81
- ],
82
- outputs = [out_response])
83
 
 
 
 
 
 
84
 
85
- app = gr.mount_gradio_app(app, iface, path="/")
86
 
87
- # iface.launch(server_name = "0.0.0.0", server_port = 8001)
88
 
 
89
 
90
- if __name__ == "__main__":
91
  # Use this for debugging purposes only
92
- import uvicorn
93
- uvicorn.run(app, host="0.0.0.0", port=8001)
 
5
  from sklearn.metrics import accuracy_score, f1_score
6
  import prometheus_client as prom
7
  import pandas as pd
8
+ import uvicorn
9
+ from transformers import VisionEncoderDecoderModel,pipeline
10
+ import torch
11
 
12
 
13
  #model
 
 
14
 
15
+ # loaded_model = pickle.load(open(save_file_name, 'rb'))
16
+
17
+ # app=FastAPI()
18
+
19
 
20
+ # test_data=pd.read_csv("test.csv")
 
 
 
21
 
22
 
23
+ # f1_metric = prom.Gauge('death_f1_score', 'F1 score for test samples')
24
 
25
  # Function for updating metrics
26
+ # def update_metrics():
27
+ # test = test_data.sample(20)
28
+ # X = test.iloc[:, :-1].values
29
+ # y = test['DEATH_EVENT'].values
30
 
31
+ # # test_text = test['Text'].values
32
+ # test_pred = loaded_model.predict(X)
33
+ # #pred_labels = [int(pred['label'].split("_")[1]) for pred in test_pred]
34
+
35
+ # f1 = f1_score( y , test_pred).round(3)
36
+
37
+ # #f1 = f1_score(test['labels'], pred_labels).round(3)
38
+
39
+ # f1_metric.set(f1)
40
 
41
+ vitgpt_processor = AutoImageProcessor.from_pretrained("model")
42
+ vitgpt_model = VisionEncoderDecoderModel.from_pretrained("model")
43
+ vitgpt_tokenizer = AutoTokenizer.from_pretrained("model")
44
 
45
+ device = "cuda" if torch.cuda.is_available() else "cpu"
46
 
47
+ vitgpt_model.to(device)
48
+
49
+ def generate_caption(processor, model, image, tokenizer=None):
50
+ inputs = processor(images=image, return_tensors="pt").to(device)
51
+
52
+ generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
53
+
54
+ if tokenizer is not None:
55
+ generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
56
+ else:
57
+ generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
58
+
59
+ return generated_caption
60
+
61
+ def predict_event(input):
62
+
63
 
64
 
65
+ caption_vitgpt = generate_caption(vitgpt_processor, vitgpt_model, image, vitgpt_tokenizer)
66
 
67
+ return caption_vitgpt
 
 
68
 
 
 
 
 
 
69
 
70
 
 
 
 
 
71
 
72
+ # @app.get("/metrics")
73
+ # async def get_metrics():
74
+ # update_metrics()
75
+ # return Response(media_type="text/plain", content= prom.generate_latest())
76
 
77
 
 
 
78
 
79
+ title = "capstone"
80
+ description = "final capstone"
81
 
82
+ out_response = gr.outputs.Textbox(label="Caption generated by ViT+GPT-2")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ iface = gr.Interface(fn=predict_event,
85
+ inputs=gr.inputs.Image(type="pil"),
86
+ outputs=out_response,
87
+ enable_queue=True)
88
+
89
 
 
90
 
91
+ # app = gr.mount_gradio_app(app, iface, path="/")
92
 
93
+ iface.launch(server_name = "0.0.0.0", server_port = 8001)
94
 
95
+ # if __name__ == "__main__":
96
  # Use this for debugging purposes only
97
+
98
+ # uvicorn.run(app, host="0.0.0.0", port=8001)