import io
import base64
import functools
import logging

from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import gradio as gr
from datasets import load_dataset
from datasets import DownloadMode, VerificationMode

logger = logging.getLogger(__name__)

STYLES = """
#container {
    margin: auto;
    width: 50%;
}

#gallery {
    height: 500px !important;
}

.center {
    text-align: center;
}

.small-big {
    font-size: 12pt !important;
}
"""

TITLE_FONT_PATH = "arial_bold.ttf"
TITLE_FONT_SIZE = 30
# Distance in pixels from the bottom edge of an image to the top of its title band.
TITLE_BAND_OFFSET = 80
PLACEHOLDER_IMAGE_PATH = "placeholder.png"

# Titles and stories of the images currently shown in the gallery, in gallery
# order; `gallery_select` looks them up by the index of the clicked image.
titles = []
stories = []


@functools.lru_cache(maxsize=1)
def _title_font():
    """Load the title font once, falling back to Pillow's built-in font if it is missing."""
    try:
        return ImageFont.truetype(TITLE_FONT_PATH, TITLE_FONT_SIZE)
    except OSError:
        logger.warning("Could not load %s; using Pillow's default font", TITLE_FONT_PATH)
        return ImageFont.load_default()


def add_title(image, title):
    """Draw *title* centred on a white band near the bottom of *image*.

    The band starts ``TITLE_BAND_OFFSET`` pixels above the bottom edge and is as
    tall as the rendered text.

    Args:
        image: PIL image to annotate. It is drawn on in place when it is already
            RGB/RGBA; other modes are converted to a new image first.
        title: Text to render.

    Returns:
        The annotated image (a converted copy if a mode conversion was needed).
    """
    if image.mode not in ("RGB", "RGBA"):
        # Single-band modes such as "L" reject the 3-tuple text fill below, so
        # normalise to RGB, keeping an alpha channel if the source had one.
        has_alpha = image.mode in ("LA", "PA") or "transparency" in image.info
        image = image.convert("RGBA" if has_alpha else "RGB")

    draw = ImageDraw.Draw(image)
    font = _title_font()
    # The bbox is measured from the origin, so right/bottom act as width/height.
    _, _, text_width, text_height = draw.textbbox((0, 0), title, font=font)
    top = image.height - TITLE_BAND_OFFSET
    draw.rectangle(
        [(0, top), (image.width, top + text_height)],
        fill="white",
        outline="white",
    )
    draw.text(((image.width - text_width) / 2, top), title, font=font, fill=(0, 0, 0))
    return image


def gallery_select(gallery, evt: gr.SelectData):
    """Reveal the title and story of the gallery image the user clicked.

    Args:
        gallery: Current gallery value; required by the event wiring but unused.
        evt: Gradio select event; ``evt.index`` is the clicked image's position.

    Returns:
        Updates for the title and story Markdown components. Both are hidden if
        the index no longer matches the loaded data (e.g. mid-refresh).
    """
    # Snapshot the module-level lists so a concurrent refresh cannot swap them
    # out between the bounds check and the lookup.
    current_titles, current_stories = titles, stories
    index = evt.index
    if not 0 <= index < min(len(current_titles), len(current_stories)):
        return [gr.update(visible=False), gr.update(visible=False)]

    return [
        gr.update(value=f"## {current_titles[index]}", visible=True),
        gr.update(value=current_stories[index], visible=True),
    ]


def _decode_image(row):
    """Decode a row's base64 ``image`` field, falling back to the placeholder image."""
    try:
        image = Image.open(io.BytesIO(base64.b64decode(row["image"])))
        # Image.open is lazy; force decoding here so corrupt data is caught now
        # rather than when the title is drawn.
        image.load()
    except Exception:
        logger.warning(
            "Could not decode image for %r; using placeholder",
            row.get("title"),
            exc_info=True,
        )
        image = Image.open(PLACEHOLDER_IMAGE_PATH)
    return image


def get_gallery():
    """Re-download the story dataset and build the captioned gallery images.

    Also replaces the module-level ``titles`` and ``stories`` lists with the
    freshly loaded rows, in the same order as the returned images.

    Returns:
        List of PIL images, one per row of the ``train`` split, each with its
        title drawn on it.
    """
    global titles, stories

    dataset = load_dataset(
        "chansung/llama2-stories",
        download_mode=DownloadMode.FORCE_REDOWNLOAD,
        verification_mode=VerificationMode.NO_CHECKS,
    )

    new_titles = []
    new_stories = []
    images = []
    for row in dataset["train"]:
        image = _decode_image(row)
        new_titles.append(row["title"])
        new_stories.append(row["story"])
        images.append(add_title(image, row["title"]))

    # Publish the new lists only once they are complete, so a click handled
    # while the dataset is being reloaded never sees half-filled lists.
    titles, stories = new_titles, new_stories
    return images


with gr.Blocks(css=STYLES) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("## LLaMA2 Story Showcase", elem_classes=['center'])
        gr.Markdown(
            "This space is where community shares generated stories by "
            "[chansung/co-write-with-llama2](https://huggingface.co/spaces/chansung/co-write-with-llama2) space. "
            "Generated stories are archived in "
            "[chansung/llama2-stories](https://huggingface.co/datasets/chansung/llama2-stories) dataset repository. "
            "The gallery will be "
            "regularly updated in a daily basis.",
            elem_classes=['small-big', 'center'],
        )

        # `every` is in seconds: while a page is open, get_gallery (including a
        # forced dataset re-download) re-runs every 3000 s, i.e. 50 minutes.
        gallery = gr.Gallery(get_gallery, every=3000, columns=5, container=False, elem_id="gallery")

        with gr.Column():
            title = gr.Markdown("title", visible=False, elem_classes=['center'])
            story = gr.Markdown("stories goes here...", visible=False, elem_classes=['small-big'])

        gallery.select(
            fn=gallery_select,
            inputs=[gallery],
            outputs=[title, story],
        )

demo.launch()