# LLM_Ariphes / api.py
# Last commit: 85b0ca0 "Update api.py" (author: Euryeth).
import hmac
import os

import torch
from flask import Flask, request, jsonify
from huggingface_hub import login
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# --- Hugging Face Hub authentication ---------------------------------------
# Token used to authenticate with the Hub when downloading the model.
_hub_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if _hub_token:
    # login(None) falls back to an interactive token prompt, which blocks or
    # fails on a headless server, so only log in when a token is configured.
    login(token=_hub_token)

# Shared secret that clients must send as "Authorization: Bearer <token>".
# If HF_API_TOKEN is unset this is None and the completions endpoint rejects
# every request.
API_TOKEN = os.getenv("HF_API_TOKEN")

# --- Model and loading config ----------------------------------------------
model_name = "cerebras/btlm-3b-8k-chat"
# NOTE: "main" is a moving branch, not a pin. Because trust_remote_code=True
# executes Python code from the model repo, set MODEL_REVISION (or change this
# default) to a specific commit SHA to actually pin the revision.
revision = os.getenv("MODEL_REVISION", "main")
# Use bfloat16 when a CUDA GPU is available, otherwise float32.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# huggingface_hub reads HF_HOME once, when it is first imported (above), so
# assigning it here does not by itself move the download cache. The variable
# is still exported for anything that reads it later, and the cache directory
# it implies (<HF_HOME>/hub) is passed explicitly to from_pretrained below.
os.environ["HF_HOME"] = "/tmp/cache"
cache_dir = os.path.join(os.environ["HF_HOME"], "hub")

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    revision=revision,
    cache_dir=cache_dir,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map="auto",
    trust_remote_code=True,
    revision=revision,
    cache_dir=cache_dir,
)

# Text-generation pipeline wrapping the already-loaded model and tokenizer.
# The EOS token doubles as the padding token.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
    torch_dtype=torch_dtype,
    pad_token_id=tokenizer.eos_token_id,
    trust_remote_code=True,
)
# Flask application that serves the HTTP endpoints defined below.
app = Flask(__name__)


@app.route("/")
def home():
    """Return a fixed status string so callers can check the server is reachable."""
    status_text = "API is running"
    return status_text
def _is_authorized(auth_header):
    """Return True if *auth_header* is exactly "Bearer <API_TOKEN>".

    Fails closed: returns False when API_TOKEN is unset or empty, when the
    scheme is not "Bearer", or when no token follows it.
    """
    scheme, _, token = auth_header.partition(" ")
    if scheme != "Bearer" or not token or not API_TOKEN:
        return False
    # Constant-time comparison, so response timing does not reveal how much of
    # the secret a guess got right. Compare bytes: compare_digest rejects
    # non-ASCII str arguments.
    return hmac.compare_digest(token.encode("utf-8"), API_TOKEN.encode("utf-8"))


def _build_prompt(messages):
    """Flatten chat *messages* into a plain "Role: content" transcript.

    Entries that are not dicts, or that lack a string role or non-empty
    content, are skipped. The transcript ends with "Assistant:" so the model
    continues as the assistant.
    """
    lines = []
    for msg in messages:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role", "")
        content = msg.get("content", "")
        if isinstance(role, str) and role and content:
            lines.append(f"{role.capitalize()}: {content}\n")
    lines.append("Assistant:")
    return "".join(lines)


def _extract_reply(generated):
    """Trim generated text down to the assistant's own turn.

    The prompt is a plain transcript, so the model may keep going and invent
    the next "User:" turn. Everything from that marker on is dropped.
    """
    reply, _, _ = generated.partition("\nUser:")
    return reply.strip()


@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """Handle an OpenAI-style chat completion request.

    Expects a JSON object with these optional fields:
        "messages": a list of {"role", "content"} dicts.
        "max_tokens": a positive int, default 256.
        "temperature": a number >= 0, default 0.7. Zero means greedy decoding.

    Returns:
        200 with {"choices": [{"message", "finish_reason", "index"}]}.
        401 if the bearer token is missing or wrong.
        400 if the body or a parameter is malformed.
    """
    if not _is_authorized(request.headers.get("Authorization", "")):
        return jsonify({"error": "Unauthorized"}), 401

    # silent=True: a missing or invalid JSON body yields None instead of
    # raising, so it is reported below as a 400 with a JSON error payload.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    messages = data.get("messages", [])
    if not isinstance(messages, list):
        return jsonify({"error": "'messages' must be a list"}), 400

    max_tokens = data.get("max_tokens", 256)
    # bool is a subclass of int, so it is excluded explicitly.
    if (
        isinstance(max_tokens, bool)
        or not isinstance(max_tokens, int)
        or max_tokens < 1
    ):
        return jsonify({"error": "'max_tokens' must be a positive integer"}), 400

    temperature = data.get("temperature", 0.7)
    if (
        isinstance(temperature, bool)
        or not isinstance(temperature, (int, float))
        or temperature < 0
    ):
        return jsonify({"error": "'temperature' must be a non-negative number"}), 400

    prompt = _build_prompt(messages)

    gen_kwargs = {
        "max_new_tokens": max_tokens,
        "repetition_penalty": 1.1,
        # Return only the newly generated text. Removing the prompt from the
        # full text afterwards breaks whenever decoding does not reproduce the
        # prompt exactly.
        "return_full_text": False,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
    else:
        # transformers rejects temperature 0 when sampling (do_sample=True).
        # Clients commonly send 0 to mean "deterministic", so use greedy
        # decoding instead.
        gen_kwargs["do_sample"] = False

    output = generator(prompt, **gen_kwargs)
    assistant_reply = _extract_reply(output[0]["generated_text"])

    return jsonify({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": assistant_reply
            },
            "finish_reason": "stop",
            "index": 0
        }]
    })
if __name__ == "__main__":
    # Run Flask's built-in server on all interfaces (0.0.0.0). The port
    # defaults to 8080 and can be overridden with the PORT environment variable.
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", "8080")))