File size: 4,417 Bytes
d423fba
 
 
 
cd0a0fa
d423fba
 
 
 
 
 
 
 
 
 
 
 
 
cd0a0fa
d423fba
 
cd0a0fa
 
 
d423fba
 
cd0a0fa
d423fba
cd0a0fa
 
 
d423fba
 
cd0a0fa
d423fba
 
 
 
 
 
 
 
cd0a0fa
d423fba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd0a0fa
 
d423fba
 
cd0a0fa
d423fba
 
 
 
 
 
 
cd0a0fa
d423fba
 
 
 
cd0a0fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
from share_btn import community_icon_html, loading_icon_html, share_js
import re
import os 

# Hugging Face access token, read from the environment (None if unset);
# forwarded to the Llama2 Space client below for authenticated calls.
hf_token = os.environ.get('HF_TOKEN')
from gradio_client import Client
# Remote Space client used for story generation (Llama2). NOTE: this opens a
# connection at import time, so the module has network side effects on load.
client = Client("https://fffiloni-test-llama-api-debug.hf.space/", hf_token=hf_token)
# Remote Space client used to caption the uploaded image (CLIP Interrogator).
clipi_client = Client("https://fffiloni-clip-interrogator-2.hf.space/")

def get_text_after_colon(input_text):
    """Return the text following the first colon in *input_text*, stripped.

    If the string contains no colon, it is returned unchanged.
    """
    _, sep, tail = input_text.partition(":")
    return tail.strip() if sep else input_text

def infer(image_input, audience, clip_caption, llama_prompt):
    """Caption *image_input* with CLIP Interrogator, then ask Llama2 for a story.

    Args:
        image_input: filepath of the uploaded image (Gradio Image, type="filepath").
        audience: target audience string ("Children" or "Adult").
        clip_caption: current value of the caption textbox (a plain string —
            Gradio passes component *values* to event handlers, not components).
        llama_prompt: prompt template; may contain a literal "{audience}"
            placeholder that is substituted below.

    Returns:
        A tuple of (formatted story text, update making the share group visible).
    """
    gr.Info('Calling CLIP Interrogator ...')
    clipi_result = clipi_client.predict(
    				image_input,
    				"best",
    				4,
    				api_name="/clipi2"
    )
    # BUG FIX: the original assigned to clip_caption.value, but event handlers
    # receive plain strings, so attribute assignment would raise AttributeError.
    # Keep the caption in a local variable instead.
    caption = clipi_result[0]

    # BUG FIX: `audience` was previously ignored even though the default prompt
    # contains a "{audience}" placeholder. Use str.replace (not str.format) so
    # any other braces in a user-edited prompt are left intact.
    prompt = llama_prompt.replace("{audience}", audience)

    llama_q = f"{prompt} " + \
              f"Here's the image description: '{caption}'"
              
    gr.Info('Calling Llama2 ...')
    result = client.predict(
    				llama_q,
                    "I2S",
    				api_name="/predict"
    )
    result = get_text_after_colon(result)
    # Re-join paragraphs with blank lines for nicer display in the textbox.
    paragraphs = result.split('\n')
    formatted_text = '\n\n'.join(paragraphs)
    return formatted_text, gr.Group.update(visible=True)

# Custom CSS injected into gr.Blocks: page width, link styling, the spinner
# animation used by the share button, and layout for #share-btn-container.
# NOTE(review): the `a {text-decoration-line...}` rule appears twice on
# consecutive lines — harmless duplication, could be deduped.
css = """
#col-container {max-width: 910px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
  animation: spin 1s linear infinite;
}
@keyframes spin {
  from {
      transform: rotate(0deg);
  }
  to {
      transform: rotate(360deg);
  }
}
#share-btn-container {
  display: flex; 
  padding-left: 0.5rem !important; 
  padding-right: 0.5rem !important; 
  background-color: #000000; 
  justify-content: center; 
  align-items: center; 
  border-radius: 9999px !important; 
  max-width: 13rem;
}
div#share-btn-container > div {
    flex-direction: row;
    background: black;
    align-items: center;
}
#share-btn-container:hover {
  background-color: #060606;
}
#share-btn {
  all: initial; 
  color: #ffffff;
  font-weight: 600; 
  cursor:pointer; 
  font-family: 'IBM Plex Sans', sans-serif; 
  margin-left: 0.5rem !important; 
  padding-top: 0.5rem !important; 
  padding-bottom: 0.5rem !important;
  right:0;
}
#share-btn * {
  all: unset;
}
#share-btn-container div:nth-child(-n+2){
  width: auto !important;
  min-height: 0px !important;
}
#share-btn-container .wrap {
  display: none !important;
}
#share-btn-container.hidden {
  display: none!important;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            <h1 style="text-align: center">Image to Story</h1>
            <p style="text-align: center">Upload an image, get a story made by Llama2 !</p>
            """
        )
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(label="Image input", type="filepath", elem_id="image-in", height=420)
                audience = gr.Radio(label="Target Audience", choices=["Children", "Adult"], value="Children")
                clip_caption = gr.Textbox(label="CLIP Generated Caption", default="")
                llama_prompt = gr.Textbox(label="Llama2 Prompt", default="I'll give you a simple image caption, please provide a fictional story for a {audience} audience that would fit well with the image. Please be creative, do not worry and only generate a cool fictional story.")
                submit_btn = gr.Button('Tell me a story')
            with gr.Column():
                story = gr.Textbox(label="Generated Story", elem_id="story")
                with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
                    community_icon = gr.HTML(community_icon_html)
                    loading_icon = gr.HTML(loading_icon_html)
                    share_button = gr.Button("Share to community", elem_id="share-btn")
        
        gr.Examples(examples=[["./examples/crabby.png", "Children"],["./examples/hopper.jpeg", "Adult"]],
                    fn=infer,
                    inputs=[image_in, audience, clip_caption, llama_prompt],
                    outputs=[story, share_group],
                    cache_examples=True
                   )
        
    submit_btn.click(fn=infer, inputs=[image_in, audience, clip_caption, llama_prompt], outputs