import os

# Point the Hugging Face cache at /tmp before importing transformers/huggingface_hub,
# since the default cache location is resolved when those libraries are imported
os.environ["HF_HOME"] = "/tmp/cache"

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# Authenticate with the Hugging Face token from the environment
login(os.getenv("HUGGINGFACEHUB_API_TOKEN"))

# Use bfloat16 on GPU, fall back to float32 on CPU
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Load model and tokenizer (Cerebras BTLM-3B-8K chat variant)
model_name = "cerebras/btlm-3b-8k-chat"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map="auto",
    trust_remote_code=True  # BTLM ships custom modeling code
)

# Create the text-generation pipeline; the model is already placed and cast above,
# so only pad_token_id needs to be set explicitly for this model
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.eos_token_id  # Important for BTLM model
)

def generate_chat_completion(message: str, history: list = None):
    """
    If history is provided as list of {'role': str, 'content': str} dicts,
    reconstructs the full prompt and returns updated history.
    """
    history = history or []
    prompt = ""
    for msg in history:
        prompt += f"{msg['role'].capitalize()}: {msg['content']}\n"
    prompt += f"User: {message}\nAssistant:"

    output = generator(
        prompt,
        max_new_tokens=256,
        temperature=0.7,    # Slightly lower temp for more coherent replies
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True
    )
    # The pipeline returns the prompt plus the completion; keep only the new text
    reply = output[0]['generated_text'][len(prompt):].strip()

    # Append new interaction to history
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": reply})
    return history
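

# Minimal usage sketch (an assumption, not part of the original script): the
# questions below are purely illustrative, and running this will download the
# model weights on first use.
if __name__ == "__main__":
    chat_history = generate_chat_completion("What is the BTLM-3B-8K model?")
    print(chat_history[-1]["content"])

    # Pass the returned history back in to continue the same conversation
    chat_history = generate_chat_completion("How much memory does it need?", chat_history)
    print(chat_history[-1]["content"])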