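"""Gradio demo: chat with a phi-2 model fine-tuned on the OpenAssistant/oasst1 dataset."""
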
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_path = "gupta1912/phi-2-custom-oasst1"

# Load the fine-tuned model and its tokenizer once at startup.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Build the text-generation pipeline once so it is reused across requests
# instead of being rebuilt on every call.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

def generate_text(prompt, response_length):
    prompt = str(prompt)
    max_len = int(response_length)

    # Wrap the prompt in the [INST] ... [/INST] template the model was fine-tuned on.
    result = generator(f"<s>[INST] {prompt} [/INST]", max_length=max_len)

    # Keep only the text after the closing [/INST] tag; fall back to the full
    # generation instead of raising an IndexError if the tag is missing.
    output_msg = result[0]["generated_text"].split("[/INST]")[-1].strip()
    return output_msg

def gradio_fn(prompt, response_length):
    return generate_text(prompt, response_length)

markdown_description = """
- This is a Gradio app that answers the queries you ask it.
- It uses the **microsoft/phi-2** model fine-tuned on the **OpenAssistant/oasst1** dataset.
"""

demo = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Textbox(info="How may I help you? Please enter your prompt here..."),
        gr.Slider(value=50, minimum=50, maximum=300,
                  info="Choose a response length (min=50, max=300 tokens)"),
    ],
    outputs=gr.Textbox(),
    title="Custom-trained phi-2 - Dialog Partner",
    description=markdown_description,
)

# Queue incoming requests and launch the app; share=True creates a temporary public link.
demo.queue().launch(share=True, debug=True)
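
# Minimal usage sketch (hypothetical prompt; bypasses the Gradio UI):
#   print(generate_text("What is the capital of France?", 100))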