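# NOTE: the following lines appear to be page chrome copied from a Hugging Face
# file view; they are kept as comments so this file parses as Python.
# shisa / app.py
# lhl
# testing pipelines
# 812f70a
# raw
# history blame
# 2.42 kB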
# https://www.gradio.app/guides/using-hugging-face-integrations
from transformers import pipeline
import gradio as gr
# Create a Hugging Face "translation" pipeline (the model id indicates an
# English -> Spanish Marian checkpoint), then let Gradio build an Interface
# directly from that pipeline object and serve it.
pipe = pipeline(
    task="translation",
    model="Helsinki-NLP/opus-mt-en-es",
)
demo = gr.Interface.from_pipeline(pipe)
demo.launch()
# Disabled draft of a "Shisa 7B" chat UI. Everything between the triple
# quotes below is a bare string literal that is evaluated and discarded, so
# none of this code runs. Review notes for whoever re-enables it:
# - `model` is bound to a Mistral id, immediately rebound to a TinyLlama 1.1B
#   id, and later rebound to the loaded model object. The "Shisa 7B"
#   title/description do not match the TinyLlama checkpoint actually loaded.
# - `chat` uses a mutable default (`history=[]`), and it overwrites `history`
#   without ever reading it, so earlier turns are never fed to the model.
# - `chat` returns `(response, history)`, where `response` is a list of
#   strings. NOTE(review): gr.ChatInterface expects fn(message, history) to
#   return the reply text; confirm against the installed Gradio version.
# - Splitting on "<|endoftext|>" assumes that token separates turns in the
#   decoded output; TODO confirm for the chosen tokenizer. `max_length=4000`
#   also counts the prompt tokens.
# - The inner '''...''' section is an alternative multi-turn draft. It is
#   likewise inert text inside this string.
# - The Japanese placeholder/example strings appear to be mojibake (UTF-8
#   text decoded with the wrong code page). Re-enter them before use.
# - NOTE(review): `undo_btn` / `clear_btn` are not accepted by ChatInterface
#   in newer Gradio releases; verify against the pinned version.
# - Indentation inside the string has been lost (e.g. the body of `chat`),
#   so the code would not parse if pasted back out as-is.
"""
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch
model = "mistralai/Mistral-7B-Instruct-v0.1"
model = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
# Gradio
title = "Shisa 7B"
description = "Test out Shisa 7B in either English or Japanese."
placeholder = "Type Here / ここにε…₯εŠ›γ—γ¦γγ γ•γ„"
examples = [
"Hello, how are you?",
"γ“γ‚“γ«γ‘γ―γ€ε…ƒζ°—γ§γ™γ‹οΌŸ",
"γŠγ£γ™γ€ε…ƒζ°—οΌŸ",
"γ“γ‚“γ«γ‘γ―γ€γ„γ‹γŒγŠιŽγ”γ—γ§γ™γ‹οΌŸ",
]
tokenizer = AutoTokenizer.from_pretrained(model)
model = AutoModelForCausalLM.from_pretrained(model)
def chat(input, history=[]):
input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
history = model.generate(
input_ids, max_length=4000, pad_token_id=tokenizer.eos_token_id
).tolist()
# convert the tokens to text, and then split the responses into lines
response = tokenizer.decode(history[0]).split("<|endoftext|>")
'''
# tokenize the new input sentence
new_user_input_ids = tokenizer.encode(
input + tokenizer.eos_token, return_tensors="pt"
)
# append the new user input tokens to the chat history
bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
# generate a response
history = model.generate(
bot_input_ids, max_length=4000, pad_token_id=tokenizer.eos_token_id
).tolist()
# convert the tokens to text, and then split the responses into lines
response = tokenizer.decode(history[0]).split("<|endoftext|>")
# print('decoded_response-->>'+str(response))
response = [
(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)
] # convert to tuples of list
# print('response-->>'+str(response))
'''
return response, history
gr.ChatInterface(
chat,
chatbot=gr.Chatbot(height=400),
textbox=gr.Textbox(placeholder=placeholder, container=False, scale=7),
title=title,
description=description,
theme="soft",
examples=examples,
cache_examples=False,
undo_btn="Delete Previous",
clear_btn="Clear",
).queue().launch()
"""