Spaces:
Build error
Build error
File size: 1,773 Bytes
3f7c440 4db110c f86827c 4db110c 3f7c440 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import gradio as gr
# make function using import pip to install torch
import pip
pip.main(['install', 'torch'])
pip.main(['install', 'transformers'])
import torch
import transformers
# saved_model
def load_model(model_path):
saved_data = torch.load(
model_path,
map_location="cpu"
)
bart_best = saved_data["model"]
train_config = saved_data["config"]
tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')
## Load weights.
model = transformers.BartForConditionalGeneration.from_pretrained('gogamza/kobart-base-v1')
model.load_state_dict(bart_best)
return model, tokenizer
# main
def inference(prompt):
model_path = "./kobart-model-essay.pth"
model, tokenizer = load_model(
model_path=model_path
)
input_ids = tokenizer.encode(prompt)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.unsqueeze(0)
output = model.generate(input_ids)
output = tokenizer.decode(output[0], skip_special_tokens=True)
return output
demo = gr.Interface(
fn=inference,
inputs="text",
outputs="text", #return κ°
examples=[
"κΏ μμμ λλ λ§λ²μ μ²μΌλ‘ λ λκ² λμλ€. λ§λ²μ μ²μμ λλ λΉμ루λ₯Ό νκ³ λ μλ€λ
λ€. μ²μ λ μλ€λλ λμ€, λλ μ λΉλ‘μ΄ μ±μ λ°κ²¬νκ² λμλ€. κ·Έ μ± μμλ 무μμ΄ μμκΉ? λλ κ·Έ μ± μμΌλ‘ λ€μ΄κ°λ€. μ± μμλ λ§λ²μ¬κ° μ΄κ³ μμλλ°, λμκ² λ§λ²μ κ°λ₯΄μ³ μ£Όμλ€. κ·Έ λ§λ²μΌλ‘ λλ λ΄κ° μ’μνλ μμμ λ§λ€μ΄μ λ§μκ» λ¨Ήμλ€!"
]
).launch() # launch(share=True)λ₯Ό μ€μ νλ©΄ μΈλΆμμ μ μ κ°λ₯ν λ§ν¬κ° μμ±λ¨
demo.launch() |