# NOTE: the lines "Spaces: / Sleeping / Sleeping" were Hugging Face Spaces
# page-status residue captured with this file, not application code.
import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Base model and the fine-tuned QLoRA adapter that gets merged into it.
model_name = "microsoft/phi-2"
new_model = "piyushgrover/phi-2-qlora-adapter-custom"

# Load the base model on CPU. NOTE: the original comment claimed FP16, but
# torch.float32 is what is actually requested here.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True,
    # device_map=device_map,
)

# Apply the LoRA adapter and fold its weights into the base model, so the
# result is a plain CausalLM that no longer needs peft at inference time.
model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()
model.eval()

# Tokenizer: pad with EOS on the right-hand side (matches fine-tuning setup).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Text-generation pipeline used by generate_response(); max_length caps the
# combined prompt + completion at 300 tokens.
gen = pipeline('text-generation', model=model, tokenizer=tokenizer, max_length=300)
def fn_query_on_load():
    """Return the default prompt pre-filled in the search box on app load."""
    # Renamed from `input`, which shadowed the builtin of the same name.
    default_prompt = "Explain nuclear physics to a five year old kid."
    return default_prompt
def generate_response(input):
    """Generate an assistant reply for *input* with the merged Phi-2 model.

    The prompt uses the "### Human: ... ### Assistant:" format the adapter
    was fine-tuned on. The echoed prompt is stripped from the pipeline
    output and the text is cut at the first "###" marker so any spurious
    follow-on turns the model hallucinates are discarded.

    Note: the parameter is named ``input`` (shadowing the builtin) — kept
    for backward compatibility with existing callers.

    Returns:
        dict: maps the module-level ``output`` Textbox component to the
        reply text, the mapping form Gradio accepts from event handlers.
    """
    prompt = f"### Human: {input}\n\n### Assistant: "
    result = gen(prompt)
    resp = result[0]['generated_text']
    print(resp)
    # Remove the echoed prompt, then split on '###' to isolate the first
    # assistant turn only.
    resp_arr = resp.replace(prompt, '').split('###')
    print(resp_arr)
    final_resp = resp_arr[0]
    return {
        output: final_resp
    }
# Gradio UI: a prompt box, Submit/Clear buttons, and a read-only response box.
with gr.Blocks() as app:
    # Header / description.
    with gr.Row():
        gr.Markdown(
            """
            # PhiGPT - Ask Me Anything (AI Assistant)
            ### Phi2 Model (2Billion parameters) Fine-tuned on OpenAssistant/oasst1 Dataset :)
            #### [Please be patient as it's running on CPU & not GPU]
            """)

    # Prompt input, pre-filled via fn_query_on_load when the page loads.
    with gr.Row(visible=True):
        search_text = gr.Textbox(value=fn_query_on_load, placeholder='Enter prompt..', label='Enter Prompt')

    # Action buttons.
    with gr.Row():
        submit_btn = gr.Button("Submit", variant='primary')
        clear_btn = gr.ClearButton()

    # Model response area.
    with gr.Row():
        with gr.Row():
            output = gr.Textbox(lines=15, interactive=False, label='Response ')

    def clear_data():
        """Reset both the response box and the prompt box to empty."""
        return {
            output: None,
            search_text: None
        }

    # Wire the buttons: Clear resets both textboxes; Submit runs inference.
    clear_btn.click(clear_data, None, [output, search_text])

    submit_btn.click(
        generate_response,
        search_text,
        output
    )

'''
Launch the app
'''
# queue() serializes requests so the CPU-bound model is not hit concurrently.
app.queue().launch()