import torch
from transformers import pipeline

def get_tinyllama():
    # Load TinyLlama as a text-generation pipeline; device_map="auto"
    # places the model on a GPU when one is available.
    tinyllama = pipeline(
        "text-generation",
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return tinyllama

def response_tinyllama(
        model=None,
        messages=None
        ):
    # Start the conversation with a fixed system prompt.
    messages_dict = [
        {
            "role": "system",
            "content": "You are a friendly and helpful chatbot",
        }
    ]
    # Rebuild the chat history: each step is a [user, assistant] pair,
    # where the assistant reply may still be missing for the latest turn.
    for step in messages:
        messages_dict.append({'role': 'user', 'content': step[0]})
        if len(step) >= 2 and step[1] is not None:
            messages_dict.append({'role': 'assistant', 'content': step[1]})

    # Render the messages with the model's chat template and generate a reply.
    prompt = model.tokenizer.apply_chat_template(messages_dict, tokenize=False, add_generation_prompt=True)
    outputs = model(prompt, max_new_tokens=32, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)

    # The pipeline returns the prompt plus the completion; keep only the text
    # after the last <|assistant|> tag, i.e. the newly generated reply.
    return outputs[0]['generated_text'].split('<|assistant|>')[-1].strip()
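

# Usage sketch (an illustrative assumption, not part of the original file):
# the history format below mirrors the [user_message, assistant_reply] pairs
# that response_tinyllama expects, e.g. as produced by a Gradio chat component.
if __name__ == "__main__":
    tinyllama = get_tinyllama()
    history = [
        ["Hello, who are you?", "I am a friendly chatbot."],
        ["What can you help me with?", None],  # latest user turn, no reply yet
    ]
    print(response_tinyllama(model=tinyllama, messages=history))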