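# Image to Story: caption an uploaded image with CLIP Interrogator, then ask a
# hosted Llama2 Space to turn that caption into a short fictional story.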
import os

import gradio as gr
from gradio_client import Client

from share_btn import community_icon_html, loading_icon_html, share_js

# Token for the Llama2 Space; set HF_TOKEN in this Space's secrets.
hf_token = os.environ.get('HF_TOKEN')

# Remote Spaces used as backends: one serving Llama2, one serving CLIP Interrogator.
client = Client("https://fffiloni-test-llama-api-debug.hf.space/", hf_token=hf_token)
clipi_client = Client("https://fffiloni-clip-interrogator-2.hf.space/")
def get_text_after_colon(input_text):
    """Strip the chatty preamble (e.g. "Sure, here's a story:") that Llama2
    tends to prepend, keeping only the text after the first colon."""
    colon_index = input_text.find(":")
    if colon_index != -1:
        return input_text[colon_index + 1:].strip()
    return input_text
def infer(image_input, audience, clip_caption, llama_prompt):
    # Step 1: caption the image with the CLIP Interrogator Space.
    gr.Info('Calling CLIP Interrogator ...')
    clipi_result = clipi_client.predict(
        image_input,
        "best",
        4,
        api_name="/clipi2"
    )
    clip_caption = clipi_result[0]

    # Step 2: ask Llama2 for a story that fits the caption.
    # Fill the "{audience}" placeholder in the prompt with the selected audience.
    prompt = llama_prompt.replace('{audience}', audience)
    llama_q = f"{prompt} Here's the image description: '{clip_caption}'"
    gr.Info('Calling Llama2 ...')
    result = client.predict(
        llama_q,
        "I2S",
        api_name="/predict"
    )

    # Drop the preamble and give each paragraph its own blank-line separation.
    result = get_text_after_colon(result)
    formatted_text = '\n\n'.join(result.split('\n'))

    # Return the story and reveal the share-button group.
    return formatted_text, gr.Group.update(visible=True)
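
# Custom CSS: centers the main column and styles the community share button,
# which stays hidden until a story has been generated.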
css = """
#col-container {max-width: 910px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
    animation: spin 1s linear infinite;
}
@keyframes spin {
    from {
        transform: rotate(0deg);
    }
    to {
        transform: rotate(360deg);
    }
}
#share-btn-container {
    display: flex;
    padding-left: 0.5rem !important;
    padding-right: 0.5rem !important;
    background-color: #000000;
    justify-content: center;
    align-items: center;
    border-radius: 9999px !important;
    max-width: 13rem;
}
div#share-btn-container > div {
    flex-direction: row;
    background: black;
    align-items: center;
}
#share-btn-container:hover {
    background-color: #060606;
}
#share-btn {
    all: initial;
    color: #ffffff;
    font-weight: 600;
    cursor: pointer;
    font-family: 'IBM Plex Sans', sans-serif;
    margin-left: 0.5rem !important;
    padding-top: 0.5rem !important;
    padding-bottom: 0.5rem !important;
    right: 0;
}
#share-btn * {
    all: unset;
}
#share-btn-container div:nth-child(-n+2) {
    width: auto !important;
    min-height: 0px !important;
}
#share-btn-container .wrap {
    display: none !important;
}
#share-btn-container.hidden {
    display: none !important;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            <h1 style="text-align: center">Image to Story</h1>
            <p style="text-align: center">Upload an image, get a story made by Llama2!</p>
            """
        )
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(label="Image input", type="filepath", elem_id="image-in", height=420)
                audience = gr.Radio(label="Target Audience", choices=["Children", "Adult"], value="Children")
                # gr.Textbox takes `value`, not `default`.
                clip_caption = gr.Textbox(label="CLIP Generated Caption", value="")
                llama_prompt = gr.Textbox(
                    label="Llama2 Prompt",
                    value="I'll give you a simple image caption, please provide a fictional story for a {audience} audience that would fit well with the image. Please be creative, do not worry and only generate a cool fictional story."
                )
                submit_btn = gr.Button('Tell me a story')
            with gr.Column():
                story = gr.Textbox(label="Generated Story", elem_id="story")
                # Hidden until a story has been generated (see infer()).
                with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
                    community_icon = gr.HTML(community_icon_html)
                    loading_icon = gr.HTML(loading_icon_html)
                    share_button = gr.Button("Share to community", elem_id="share-btn")

        gr.Examples(
            examples=[["./examples/crabby.png", "Children"], ["./examples/hopper.jpeg", "Adult"]],
            fn=infer,
            inputs=[image_in, audience, clip_caption, llama_prompt],
            outputs=[story, share_group],
            cache_examples=True
        )
    submit_btn.click(fn=infer, inputs=[image_in, audience, clip_caption, llama_prompt], outputs=[story, share_group])
    # Run the client-side share script when the share button is clicked.
    share_button.click(None, [], [], _js=share_js)

demo.queue().launch()