"""Gradio demo for the Chinese COMET-ATOMIC T5 model ``svjack/comet-atomic-zh``.

For an input event sentence the app queries the model once per instruction
prefix (preconditions, effects, intents, reactions) and displays, as JSON,
each prefix together with the first element returned by ``Obj.predict``.
"""
import gradio as gr
import os
import subprocess
import sys

# Runtime install kept from the original (environment hack for the Space).
# ``sys.executable -m pip`` installs into the interpreter running this script,
# whereas a bare ``pip`` on PATH may belong to a different environment.
# check=False keeps it best-effort, like the original ``os.system`` call.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "huggingface_hub"],
    check=False,
)

# These imports intentionally come after the install step above.
from huggingface_hub import space_info  # noqa: E402
from predict import Obj  # noqa: E402  (only name used from predict)
from transformers import T5ForConditionalGeneration  # noqa: E402
from transformers import T5TokenizerFast as T5Tokenizer  # noqa: E402
import pandas as pd  # noqa: E402,F401

MODEL_NAME = "svjack/comet-atomic-zh"
device = "cpu"
# device = "cuda:0"

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME).to(device).eval()

# Instruction prefixes prepended to the event text (Chinese prompts):
# preconditions needed, likely effects, motivation/intent, and felt reaction.
NEED_PREFIX = '以下事件有哪些必要的先决条件:'
EFFECT_PREFIX = '下面的事件发生后可能会发生什么:'
INTENT_PREFIX = '以下事件的动机是什么:'
REACT_PREFIX = '以下事件发生后,你有什么感觉:'
# Order in which the prefixes are queried and reported.
PREFIXES = [NEED_PREFIX, EFFECT_PREFIX, INTENT_PREFIX, REACT_PREFIX]

obj = Obj(model, tokenizer, device)

text0 = "X吃到了一顿大餐。"
text1 = "X和Y一起搭了个积木。"
# Each example row is [event, do_sample], matching the Interface inputs.
example_sample = [
    [text0, False],
    [text1, False],
]


def demo_func(event, do_sample):
    """Run every relation prefix against *event* and collect the predictions.

    Args:
        event: Event sentence entered by the user, e.g. "X吃到了一顿大餐。".
        do_sample: Passed through to ``Obj.predict`` as ``do_sample``.

    Returns:
        ``{"Output": [{"PREFIX": prefix, "PRED": prediction}, ...]}`` with
        one entry per prefix, in ``PREFIXES`` order. ``prediction`` is the
        first element of what ``Obj.predict`` returns.
    """
    rows = [
        {
            "PREFIX": prefix,
            # Model input is the prefix immediately followed by the event text.
            "PRED": obj.predict(
                "{}{}".format(prefix, event), do_sample=do_sample
            )[0],
        }
        for prefix in PREFIXES
    ]
    return {"Output": rows}


markdown_exp_size = "##"
lora_repo = "svjack/chatglm3-few-shot"
lora_repo_link = "svjack/chatglm3-few-shot/?input_list_index=5"
try:
    emoji_info = space_info(lora_repo).__dict__["cardData"]["emoji"]
except Exception as err:  # decorative metadata only; never block app startup
    print(f"Warning: could not fetch emoji for {lora_repo}: {err!r}")
    emoji_info = ""
space_cnt = 1
task_name = "[---Chinese Comet Atomic---]"
# NOTE: this description is currently not passed to the Interface (see the
# commented-out ``description=description`` below).
description = f"{markdown_exp_size} {task_name} few shot prompt in ChatGLM3 Few Shot space repo (click submit to activate) : [{lora_repo_link}](https://huggingface.co/spaces/{lora_repo_link}) {emoji_info}"

demo = gr.Interface(
    fn=demo_func,
    inputs=[
        gr.Text(label="Event"),
        gr.Checkbox(label="do sample"),
    ],
    outputs="json",
    title="Chinese Comet Atomic 🐰 demonstration",
    description=(
        "This _example_ was **derived** from\n\n"
        "[https://github.com/svjack/COMET-ATOMIC-En-Zh]"
        "(https://github.com/svjack/COMET-ATOMIC-En-Zh)\n\n\n"
    ),
    # description=description,
    examples=example_sample if example_sample else None,
    cache_examples=False,
)

with demo:
    # Appends an HTML component whose content is just a newline, as in the
    # original; presumably a placeholder for extra HTML — confirm intent.
    gr.HTML("\n")

demo.launch(server_name=None, server_port=None)