import os
import logging

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import logging as hf_logging

# Write application logs to a file so they can be inspected after the Space runs.
logging.basicConfig(
    filename="/tmp/app.log",
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s: %(message)s",
)

logging.info("Starting app.py logging")

# Route transformers' own logging to a separate file at debug verbosity.
hf_logging.set_verbosity_debug()
hf_logging.enable_default_handler()
hf_logging.enable_explicit_format()
hf_logging.add_handler(logging.FileHandler("/tmp/transformers.log"))


model_id = "futurehouse/ether0" |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_id) |
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_id, |
|
device_map="auto", |
|
torch_dtype=torch.float16 |
|
) |
|
|
|
@spaces.GPU
def chat_fn(prompt, max_tokens=512):
    # Cap the requested generation length so a stray input can't exhaust the GPU.
    max_tokens = min(int(max_tokens), 32_000)

    # Format the prompt with the model's chat template before tokenizing.
    messages = [{"role": "user", "content": prompt}]
    chat_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens, dropping the echoed prompt.
    generated_text = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return generated_text


# Minimal Gradio UI: a prompt box and a token-limit field, returning plain text.
gr.Interface(
    fn=chat_fn,
    inputs=[
        gr.Textbox(label="prompt"),
        gr.Number(label="max_tokens", value=512, precision=0),
    ],
    outputs="text",
    title="Ether0",
).launch(ssr_mode=False)