File size: 2,710 Bytes
9982036
 
 
90dff4c
 
 
ba71a21
 
e2ba4d1
ba71a21
8767911
ed78ab3
 
 
 
 
9ff3b6e
 
 
 
8767911
 
 
b19f5b3
 
 
 
8767911
 
7f99115
 
ba71a21
423fb82
5c3a233
423fb82
5c3a233
 
 
 
423fb82
5c3a233
423fb82
fff19e6
5c3a233
 
 
fff19e6
8767911
970166d
8767911
 
 
5c3a233
e1a48ca
 
25a28bf
e1a48ca
 
3d6f217
 
 
 
 
5c3a233
9982036
1b39588
 
 
 
 
 
7f99115
 
 
e0f7776
5c3a233
 
 
0cff7d6
 
 
 
 
 
 
b19f5b3
0cff7d6
fff19e6
0cff7d6
 
 
8767911
0cff7d6
 
 
 
 
ba71a21
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import base64
import io
import logging

from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont

import gradio as gr
from datasets import load_dataset
from datasets import DownloadMode, VerificationMode

# Custom CSS handed to gr.Blocks(css=STYLES). The selectors match the
# elem_id / elem_classes set on components in the UI definition:
#   #container  -> the outer gr.Column
#   #gallery    -> the gr.Gallery (fixed 500px height)
#   .center     -> centred Markdown (page heading, intro, story title)
#   .small-big  -> 12pt Markdown text (intro and story body)
STYLES = """
#container {
  margin: auto;
  width: 50%;
}

#gallery {
  height: 500px !important;
}

.center {
  text-align: center;
}

.small-big {
  font-size: 12pt !important;
}
"""

# Parallel lists, index-aligned with the images returned by get_gallery():
# titles[i] / stories[i] describe gallery item i. get_gallery() rebinds
# them (via `global`) and gallery_select() reads them by event index.
titles = []
stories = []

def add_title(image, title):
    """Draw ``title`` as a centred black caption on a white band.

    The band's top edge sits 80 px above the bottom of the image and is as
    tall as the rendered text, so the caption stays legible over any picture.

    Args:
        image: PIL image to annotate. It is modified in place.
        title: Caption text.

    Returns:
        The same ``image`` object (convenient for building lists inline).

    Raises:
        OSError: if the ``arial_bold.ttf`` font file cannot be loaded.
    """
    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype('arial_bold.ttf', 30)

    # textbbox() returns (left, top, right, bottom) relative to the anchor
    # point -- not a width/height pair -- so compute the real text width.
    left, _top, right, bottom = draw.textbbox((0, 0), title, font=font)
    text_width = right - left

    band_top = image.height - 80
    # `bottom` is where the glyphs end relative to the draw origin, so a band
    # from band_top to band_top + bottom covers the text drawn at band_top.
    draw.rectangle(
        [(0, band_top), (image.width, band_top + bottom)],
        fill="white",
        outline="white",
    )
    # Shift by `left` so the visible glyphs (not the draw origin) are centred.
    text_x = (image.width - text_width) / 2 - left
    draw.text((text_x, band_top), title, font=font, fill=(0, 0, 0))

    return image

def gallery_select(gallery, evt: gr.SelectData):
    """Reveal the title and story of the gallery item the user clicked.

    Args:
        gallery: Current gallery value. Gradio passes it because the event is
            wired with ``inputs=[gallery]``; it is not used here.
        evt: Selection event data. ``evt.index`` is used as the position into
            the module-level ``titles`` / ``stories`` lists filled by
            ``get_gallery``.

    Returns:
        A two-element list of ``gr.update`` objects for the title and story
        Markdown components. Both are hidden when the index does not match
        the current lists.
    """
    # Debug details go to logging rather than unconditional stdout prints.
    logging.getLogger(__name__).debug(
        "gallery select: value=%r index=%r target=%r",
        evt.value, evt.index, evt.target,
    )

    index = evt.index
    in_range = isinstance(index, int) and 0 <= index < min(len(titles), len(stories))
    if not in_range:
        # Presumably the client is still showing a gallery rendered before the
        # last get_gallery() refresh. Keep the panel hidden instead of
        # raising IndexError inside the event handler.
        return [gr.update(visible=False), gr.update(visible=False)]

    return [
        gr.update(value=f"## {titles[index]}", visible=True),
        gr.update(value=stories[index], visible=True),
    ]

def _decode_row_image(row):
    """Decode a dataset row's base64 ``image`` field into a loaded PIL image.

    Falls back to ``placeholder.png`` when the field is missing, not valid
    base64, or not a readable image, so a single bad row cannot break the
    whole gallery.
    """
    try:
        image = Image.open(io.BytesIO(base64.b64decode(row['image'])))
        # Image.open() only parses the header; force the pixel decode here so
        # corrupt data is caught now rather than later inside add_title().
        image.load()
        return image
    except Exception:
        # Best-effort per row. Unlike the former bare `except:`, this does not
        # swallow KeyboardInterrupt / SystemExit.
        return Image.open('placeholder.png')


def get_gallery():
    """Re-download the stories dataset and build the captioned gallery images.

    Side effect: rebinds the module-level ``titles`` and ``stories`` lists so
    that ``titles[i]`` / ``stories[i]`` describe the i-th returned image.

    Returns:
        A list of PIL images, one per row of the dataset's ``train`` split,
        each captioned with its title via ``add_title``.
    """
    global titles, stories

    dataset = load_dataset(
        "chansung/llama2-stories",
        download_mode=DownloadMode.FORCE_REDOWNLOAD,
        verification_mode=VerificationMode.NO_CHECKS,
    )

    images, new_titles, new_stories = [], [], []
    for row in dataset['train']:
        image = _decode_row_image(row)
        new_titles.append(row['title'])
        new_stories.append(row['story'])
        images.append(add_title(image, row['title']))

    # Publish only complete, index-aligned lists. Previously the globals were
    # emptied before the (slow, forced) download, so a click during a refresh
    # saw empty or half-built lists.
    titles, stories = new_titles, new_stories
    return images

# Page layout: heading, intro text, the periodically refreshed gallery, and a
# hidden title/story panel that gallery_select() fills in and reveals.
with gr.Blocks(css=STYLES) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown(value="## LLaMA2 Story Showcase", elem_classes=['center'])

        intro_md = (
            "This space is where community shares generated stories by [chansung/co-write-with-llama2](https://huggingface.co/spaces/chansung/co-write-with-llama2) space. "
            "Generated stories are archived in [chansung/llama2-stories](https://huggingface.co/datasets/chansung/llama2-stories) dataset repository. The gallery will be "
            "regularly updated in a daily basis."
        )
        gr.Markdown(value=intro_md, elem_classes=['small-big', 'center'])

        # get_gallery is passed as a callable so Gradio computes the value;
        # every=3000 asks Gradio to re-run it every 3000 seconds.
        # NOTE(review): 3000 s is ~50 minutes, not "daily" as the intro says.
        gallery = gr.Gallery(
            value=get_gallery,
            every=3000,
            columns=5,
            container=False,
            elem_id="gallery",
        )

        with gr.Column():
            title = gr.Markdown(value="title", visible=False, elem_classes=['center'])
            story = gr.Markdown(value="stories goes here...", visible=False, elem_classes=['small-big'])

        # Clicking an item calls gallery_select(gallery_value, evt) and
        # writes its two updates into the title and story components.
        gallery.select(gallery_select, [gallery], [title, story])

demo.launch()