File size: 401 Bytes
54619a0
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from transformers import pipeline

def load_llm(model_name="mistralai/Mistral-7B-Instruct-v0.1", *, trust_remote_code=True):
    """Build a Hugging Face ``text-generation`` pipeline.

    Args:
        model_name: Model identifier or path handed to ``pipeline`` as
            ``model``. Defaults to the Mistral 7B Instruct v0.1 checkpoint,
            which is what this function always loaded before it became
            configurable.
        trust_remote_code: Forwarded unchanged to ``pipeline``. Keyword-only.
            Defaults to ``True`` to preserve the previous behaviour.

    Returns:
        Whatever ``transformers.pipeline`` returns for the
        ``"text-generation"`` task.
    """
    # NOTE(review): per the transformers docs, trust_remote_code=True allows
    # custom code shipped in the model repository to run locally. Pass
    # trust_remote_code=False for checkpoints that do not need it -- confirm
    # whether the default model actually requires it.
    return pipeline(
        "text-generation",
        model=model_name,
        device_map="auto",
        trust_remote_code=trust_remote_code,
    )

def get_response(pipe, prompt, max_new_tokens=256):
    """Run *prompt* through *pipe* and return the trimmed tail of the output.

    Args:
        pipe: Callable invoked as
            ``pipe(prompt, max_new_tokens=..., do_sample=True)``. Its result
            is indexed as ``result[0]["generated_text"]``.
        prompt: Text passed to ``pipe`` as the sole positional argument.
        max_new_tokens: Forwarded to ``pipe`` under the same keyword.

    Returns:
        The portion of the first result's ``generated_text`` that follows the
        last occurrence of ``"User:"``, with surrounding whitespace stripped.
        When the marker is absent, the whole text is returned (stripped).
    """
    results = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=True)
    generated = results[0]["generated_text"]
    # rpartition leaves the full string in the last slot when the marker is
    # missing, so both cases are handled by the same expression.
    _before, _marker, tail = generated.rpartition("User:")
    return tail.strip()