# PhiGPT / app.py
# Source: Hugging Face Space piyushgrover/PhiGPT (commit abaf27c,
# merge of branch 'main' of https://huggingface.co/spaces/piyushgrover/PhiGPT)
"""Model setup: load microsoft/phi-2, merge in the QLoRA adapter, build the
text-generation pipeline used by the Gradio handlers below."""
import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = "microsoft/phi-2"

# Load the base model in full precision (FP32) — this Space runs on CPU, where
# float16 inference is poorly supported — so the LoRA weights can be merged in.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True,
    # device_map=device_map,  # intentionally unset: run on the default (CPU) device
)

# Merge the fine-tuned QLoRA adapter into the base weights and drop the PEFT
# wrapper so inference runs on a single plain model.
new_model = "piyushgrover/phi-2-qlora-adapter-custom"
model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()
model.eval()

# Tokenizer: pad with EOS on the right, matching the fine-tuning configuration.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Generation pipeline; max_length=300 caps prompt + completion tokens combined.
gen = pipeline('text-generation', model=model, tokenizer=tokenizer, max_length=300)
def fn_query_on_load():
    """Return the default prompt pre-filled into the textbox on app load."""
    # Named `default_prompt` rather than `input` to avoid shadowing the builtin.
    default_prompt = "Explain nuclear physics to a five year old kid."
    return default_prompt
def generate_response(input):
    """Generate the assistant's reply for a user prompt.

    Wraps *input* in the '### Human / ### Assistant' template used during
    fine-tuning, runs the module-level `gen` pipeline, strips the echoed
    prompt, and keeps only the text before the next '###' marker (the model
    may otherwise start hallucinating a new conversation turn).

    Returns a dict keyed by the module-level `output` component, the form
    Gradio's `.click()` accepts for dict-style outputs.
    """
    prompt = f"### Human: {input}\n\n### Assistant: "
    result = gen(prompt)
    resp = result[0]['generated_text']
    # The pipeline echoes the full prompt at the start of generated_text;
    # drop it, then cut at the first '###' so only this turn's reply remains.
    final_resp = resp.replace(prompt, '').split('###')[0]
    return {
        output: final_resp
    }
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # PhiGPT - Ask Me Anything (AI Assistant)
            ### Phi2 Model (2Billion parameters) Fine-tuned on OpenAssistant/oasst1 Dataset :)
            #### [Please be patient as it's running on CPU & not GPU]
            """)

    with gr.Row(visible=True):
        # `value=fn_query_on_load` (a callable) makes Gradio re-evaluate the
        # default prompt each time the page loads.
        search_text = gr.Textbox(value=fn_query_on_load, placeholder='Enter prompt..', label='Enter Prompt')

    with gr.Row():
        submit_btn = gr.Button("Submit", variant='primary')
        clear_btn = gr.ClearButton()

    # Single row for the response (the original nested a redundant gr.Row()
    # inside another gr.Row(), which added no layout value).
    with gr.Row():
        output = gr.Textbox(lines=15, interactive=False, label='Response ')

    def clear_data():
        """Reset both the response box and the prompt box to empty."""
        return {
            output: None,
            search_text: None
        }

    clear_btn.click(clear_data, None, [output, search_text])
    submit_btn.click(
        generate_response,
        search_text,
        output
    )

# Launch the app. queue() serializes incoming requests so the CPU-bound
# model handles one generation at a time.
app.queue().launch()