# Spaces: Runtime error
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
import torch | |
import os | |
from huggingface_hub import login | |
from flask import Flask, request, jsonify | |
# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login() without a token falls back to an interactive prompt, which
# cannot be answered when this module is started as a headless server.
_hub_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if _hub_token:
    login(_hub_token)

# Bearer token that clients must present in the Authorization header.
API_TOKEN = os.getenv("HF_API_TOKEN")

# Model and loading config
model_name = "cerebras/btlm-3b-8k-chat"
# NOTE(review): "main" is a moving branch, not a pinned revision. Because the
# model is loaded with trust_remote_code=True, consider a commit hash here.
revision = "main"
# Use bfloat16 when a CUDA device is available, float32 otherwise.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# NOTE(review): this is assigned after transformers / huggingface_hub have
# already been imported; those libraries appear to resolve their cache paths
# at import time, so this may have no effect on them. Confirm, and if so set
# HF_HOME before the imports (or in the process environment).
os.environ['HF_HOME'] = '/tmp/cache'

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    revision=revision,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map="auto",
    trust_remote_code=True,
    revision=revision,
)

# Text-generation pipeline shared by all requests. The EOS token id is reused
# as the padding token id for generation.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
    torch_dtype=torch_dtype,
    pad_token_id=tokenizer.eos_token_id,
    trust_remote_code=True,
)

app = Flask(__name__)
# NOTE(review): no route registration for this handler is visible in the file,
# so it was unreachable. "/" is assumed to be the intended path -- confirm.
@app.route("/")
def home():
    """Return a plain-text message indicating that the API process is up."""
    return "API is running"
def _build_prompt(messages):
    """Render chat messages as ``Role: content`` lines ending with ``Assistant:``.

    Messages whose role or content is missing/empty are skipped. ``messages``
    must be an iterable of dicts.
    """
    turns = []
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content", "")
        if role and content:
            # str() keeps a non-string role from crashing .capitalize().
            turns.append(f"{str(role).capitalize()}: {content}\n")
    return "".join(turns) + "Assistant:"


# NOTE(review): no route registration for this handler is visible in the file,
# so it was unreachable. The path below is assumed from the handler name and
# the shape of the response -- confirm against the clients.
@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """Generate an assistant reply for a chat-style JSON request.

    Expects ``Authorization: Bearer <API_TOKEN>`` and a JSON object body with:
      * ``messages``: list of ``{"role": ..., "content": ...}`` dicts
      * ``max_tokens``: positive integer, default 256
      * ``temperature``: non-negative number, default 0.7

    Returns:
        A JSON response with a single-element ``choices`` list holding the
        assistant message, or ``({"error": ...}, 401)`` when the bearer token
        is missing/wrong, or ``({"error": ...}, 400)`` when the body is
        malformed.
    """
    import hmac  # local import: only needed for the constant-time comparison

    auth_header = request.headers.get("Authorization", "")
    supplied_token = ""
    if auth_header.startswith("Bearer "):
        supplied_token = auth_header[len("Bearer "):].strip()
    # Fail closed when no server token is configured or none was supplied, and
    # compare in constant time so the token cannot be probed via timing.
    if (
        not API_TOKEN
        or not supplied_token
        or not hmac.compare_digest(
            supplied_token.encode("utf-8"), API_TOKEN.encode("utf-8")
        )
    ):
        return jsonify({"error": "Unauthorized"}), 401

    # silent=True yields None for a missing or invalid JSON body instead of
    # raising, so every malformed body gets the same 400 response below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    messages = data.get("messages", [])
    if not isinstance(messages, list) or not all(
        isinstance(msg, dict) for msg in messages
    ):
        return jsonify({"error": "'messages' must be a list of objects"}), 400

    max_tokens = data.get("max_tokens")
    if max_tokens is None:
        max_tokens = 256
    # bool is a subclass of int, so it is excluded explicitly.
    if (
        isinstance(max_tokens, bool)
        or not isinstance(max_tokens, int)
        or max_tokens <= 0
    ):
        return jsonify({"error": "'max_tokens' must be a positive integer"}), 400

    temperature = data.get("temperature")
    if temperature is None:
        temperature = 0.7
    if (
        isinstance(temperature, bool)
        or not isinstance(temperature, (int, float))
        or not temperature >= 0
    ):
        return jsonify({"error": "'temperature' must be a non-negative number"}), 400

    prompt = _build_prompt(messages)

    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "repetition_penalty": 1.1,
    }
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
    else:
        # Sampling needs a strictly positive temperature; treat 0 as a request
        # for deterministic (greedy) decoding instead of failing.
        generation_kwargs["do_sample"] = False

    output = generator(prompt, **generation_kwargs)
    generated_text = output[0]['generated_text']

    # Strip the prompt only where it appears as the leading prefix, so reply
    # text that happens to repeat the prompt is not removed.
    if generated_text.startswith(prompt):
        generated_text = generated_text[len(prompt):]
    assistant_reply = generated_text.strip()

    return jsonify({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": assistant_reply
            },
            "finish_reason": "stop",
            "index": 0
        }]
    })
if __name__ == "__main__":
    # Listen on all interfaces. The port can be overridden through the PORT
    # environment variable and defaults to 8080 when it is not set.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))