# ether0 / app.py
import gradio as gr
import spaces
import torch
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import logging as hf_logging

# Log app events to /tmp for debugging inside the Space container.
logging.basicConfig(
    filename="/tmp/app.log",
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s: %(message)s"
)
logging.info("Starting app.py logging")

# Route transformers' own debug logs to a separate file.
hf_logging.set_verbosity_debug()
hf_logging.enable_default_handler()
hf_logging.enable_explicit_format()
hf_logging.add_handler(logging.FileHandler("/tmp/transformers.log"))

# Load the ether0 tokenizer and model once at startup.
model_id = "futurehouse/ether0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",          # place weights on the available device(s) automatically
    torch_dtype=torch.float16   # half precision to reduce memory use
)
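
# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of each call
# to the decorated function and releases it when the call returns.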
@spaces.GPU
def chat_fn(prompt, max_tokens=512):
    # Cap the requested generation length at a hard upper bound.
    max_tokens = min(int(max_tokens), 32_000)
    # Wrap the raw prompt in the model's chat template.
    messages = [{"role": "user", "content": prompt}]
    chat_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)
    # Sample with a low temperature for near-deterministic output.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.1,
        pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    generated_text = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return generated_text
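
# A quick local sanity check; the prompt below is a hypothetical example, and
# the call is commented out so running this file only starts the Gradio app:
# if __name__ == "__main__":
#     print(chat_fn("Give a SMILES string for aspirin.", max_tokens=128))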

# Minimal single-turn UI; ssr_mode=False disables Gradio's server-side rendering.
gr.Interface(
    fn=chat_fn,
    inputs=[
        gr.Textbox(label="prompt"),
        gr.Number(label="max_tokens", value=512, precision=0)
    ],
    outputs="text",
    title="Ether0"
).launch(ssr_mode=False)
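
# Assumed Space dependencies (a likely requirements.txt, inferred from the
# imports above): gradio, spaces, torch, transformers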