Spaces:
Running
Running
File size: 3,640 Bytes
4ab53a5 d1b07b8 4ab53a5 f11c54d 4ab53a5 d1b07b8 9497543 f11c54d 9497543 4ab53a5 d1b07b8 4ab53a5 d1b07b8 4ab53a5 75c875f 35abaee 4ab53a5 75c875f d1b07b8 4ab53a5 75c875f 4ab53a5 9497543 75c875f 9497543 75c875f 4ab53a5 9497543 4ab53a5 75c875f f11c54d 9497543 d1b07b8 9497543 4ab53a5 ef25918 75c875f f11c54d 9497543 f11c54d 9497543 f11c54d d1b07b8 9497543 d1b07b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
'''For specifying an LLM agent logic flow.'''
from . import ops
import chromadb
import jinja2
import json
import openai
import pandas as pd
from .executors import one_by_one
# OpenAI-compatible chat client pointed at a local server on port 11434
# (NOTE(review): presumably Ollama's OpenAI-compatible endpoint — confirm).
client = openai.OpenAI(base_url="http://localhost:11434/v1")
# Shared Jinja2 environment used by create_prompt to compile prompt templates.
jinja = jinja2.Environment()
# Chroma client used by the RAG op to build one collection per node
# (default constructor; presumably in-memory/ephemeral — confirm).
chroma_client = chromadb.Client()
# Memoized chat results: JSON-encoded call arguments -> list of message texts.
# Process-local and never evicted.
LLM_CACHE = {}
# Name of the operation environment defined by this module.
ENV = 'LLM logic'
# NOTE(review): presumably selects the one-by-one executor for ENV — confirm.
one_by_one.register(ENV)
# Decorator factory, used below as @op("Name"), that registers operations in ENV.
op = ops.op_registration(ENV)
def chat(*args, **kwargs):
    '''Call the chat-completions endpoint, memoizing results in LLM_CACHE.

    All positional and keyword arguments are forwarded unchanged to
    `client.chat.completions.create`. They must be JSON-serializable, because
    the cache key is their JSON encoding (json.dumps raises TypeError otherwise).

    Returns:
        A list with the message content of every returned choice.
    '''
    # sort_keys makes the key canonical, so the same request issued with its
    # keyword arguments in a different order still hits the cache.
    key = json.dumps({'args': args, 'kwargs': kwargs}, sort_keys=True)
    if key not in LLM_CACHE:
        completion = client.chat.completions.create(*args, **kwargs)
        LLM_CACHE[key] = [c.message.content for c in completion.choices]
    return LLM_CACHE[key]
@op("Input")
def input(*, filename: ops.PathStr, key: str):
    '''Load a CSV file and expose the chosen column under the name 'text'.

    Args:
        filename: Path of the CSV file to read.
        key: Column to rename to 'text'. If no column has this name, pandas'
            rename leaves the frame unchanged.

    Returns:
        The loaded DataFrame with the `key` column renamed.
    '''
    table = pd.read_csv(filename)
    renamed = table.rename(columns={key: 'text'})
    return renamed
@op("Create prompt")
def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
    '''Render a Jinja2 template against the fields of `input`.

    Args:
        input: Mapping whose entries are available as template variables.
        save_as: Key under which the rendered text is stored.
        template: Jinja2 template source; must be non-empty.

    Returns:
        A new dict holding every entry of `input` plus `save_as` -> rendered text.
    '''
    assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
    rendered = jinja.from_string(template).render(**input)
    result = dict(input)
    result[save_as] = rendered
    return result
@op("Ask LLM")
def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
    '''Send input['prompt'] to the chat model and emit one row per returned choice.

    Args:
        input: Mapping that must contain a 'prompt' entry.
        model: Model name passed to the chat-completions call.
        accepted_regex: Optional regex, forwarded as `guided_regex` inside the
            request's `extra_body` (NOTE(review): only honored by servers that
            support guided decoding — confirm for the configured backend).
        max_tokens: Completion length limit forwarded to the API.

    Returns:
        A list of dicts, each a copy of `input` with 'response' set to one
        choice's message content.
    '''
    assert model, 'Please specify the model.'
    assert 'prompt' in input, 'Please create the prompt first.'
    extra = {'extra_body': {"guided_regex": accepted_regex}} if accepted_regex else {}
    conversation = [{"role": "user", "content": input['prompt']}]
    # Keyword order matters for chat()'s JSON cache key; keep it stable.
    responses = chat(model=model, max_tokens=max_tokens, messages=conversation, **extra)
    return [dict(input, response=text) for text in responses]
@op("View", view="table_view")
def view(input, *, _ctx: one_by_one.Context):
    '''Accumulate incoming rows into a single table for display.

    The first call builds the table and fixes its columns from the keys of
    `input`, skipping keys whose string form starts with '_'. Later calls
    append one row, in place, to the table held in `_ctx.last_result`, reading
    the same columns.

    Returns:
        {'dataframes': {'df': {'columns': [...], 'data': [[...], ...]}}}
    '''
    table = _ctx.last_result
    if not table:
        header = [str(name) for name in input.keys() if not str(name).startswith('_')]
        return {
            'dataframes': {
                'df': {
                    'columns': header,
                    'data': [[input[name] for name in header]],
                },
            },
        }
    frame = table['dataframes']['df']
    frame['data'].append([input[name] for name in frame['columns']])
    return table
@ops.input_position(input="right")
@ops.output_position(output="left")
@op("Loop")
def loop(input, *, max_iterations: int = 3, _ctx: one_by_one.Context):
    '''Data can flow back here max_iterations-1 times.

    Each pass increments a per-node counter carried in the data itself under
    'iterations-<node id>'. While the counter is below `max_iterations`, the
    data is passed on. Once it reaches the limit, nothing is returned
    (presumably the executor then stops forwarding this item — confirm).

    Returns:
        A copy of `input` with the updated counter, or None at the limit.
    '''
    key = f'iterations-{_ctx.node.id}'
    count = input.get(key, 0) + 1
    if count >= max_iterations:
        return None
    # Return a copy instead of mutating `input`: the incoming dict may also be
    # referenced elsewhere (e.g. as another node's result), and the other ops
    # in this module likewise build new dicts with {**input, ...}.
    return {**input, key: count}
@op('Branch', outputs=['true', 'false'])
def branch(input, *, expression: str):
    '''Route `input` to the 'true' or 'false' output by evaluating `expression`.

    The fields of `input` are available as variables inside the expression,
    e.g. "len(text) > 100".

    Returns:
        one_by_one.Output on handle 'true' or 'false', carrying `input` unchanged.
    '''
    # eval() inserts '__builtins__' into the globals dict it is given, so
    # evaluate against a copy. Passing `input` directly would leak that key
    # into the data forwarded downstream. The copy is also a real dict, which
    # eval requires for globals.
    # NOTE(review): eval runs arbitrary code from the flow definition; only use
    # with trusted expressions.
    res = eval(expression, dict(input))
    return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)
@ops.input_position(db="top")
@op('RAG')
def rag(input, db, *, input_field='text', db_field='text', num_matches: int=10, _ctx: one_by_one.Context):
    '''Attach the `num_matches` entries of `db` most similar to the input text.

    On the first call for this node, the entries of `db` are indexed into a
    fresh Chroma collection named after the node id; any existing collection
    with that name is deleted first. The collection is returned under
    '_collection', and later calls reuse it from `_ctx.last_result` instead of
    re-indexing.

    Args:
        input: Mapping holding the query text under `input_field`.
        db: Indexable sequence of mappings holding document text under `db_field`.
        input_field: Key of the query text in `input`.
        db_field: Key of the document text in each `db` entry.
        num_matches: Number of nearest entries to request.

    Returns:
        A copy of `input` with 'rag' (the matching `db` entries, in the order
        the Chroma query returns them) and '_collection'.
    '''
    last = _ctx.last_result
    if last:
        collection = last['_collection']
    else:
        collection_name = _ctx.node.id.replace(' ', '_')
        # chromadb < 0.6 returns Collection objects from list_collections(),
        # while >= 0.6 returns plain names. Accept either, so a stale
        # collection from a previous run is always dropped.
        existing = {getattr(c, 'name', c) for c in chroma_client.list_collections()}
        if collection_name in existing:
            chroma_client.delete_collection(name=collection_name)
        collection = chroma_client.create_collection(name=collection_name)
        collection.add(
            documents=[r[db_field] for r in db],
            # Ids are positions in `db`, so query hits map straight back to entries.
            ids=[str(i) for i in range(len(db))],
        )
    results = collection.query(
        query_texts=[input[input_field]],
        n_results=num_matches,
    )
    results = [db[int(r)] for r in results['ids'][0]]
    return {**input, 'rag': results, '_collection': collection}
@op('Run Python')
def run_python(input, *, template: str):
    '''Fill in `template` with values from `input`.

    Every occurrence of an input key written in uppercase is replaced by str()
    of that key's value. For example, with input {'text': 'hi'} the template
    "print('TEXT')" becomes "print('hi')".

    NOTE(review): despite the op name, the substituted code is only returned,
    not executed.

    Returns:
        The template with all placeholders substituted.
    '''
    import re
    assert template, 'Please specify the template. Refer to columns using their names in uppercase.'
    # Substitute all placeholders in one pass, preferring longer names.
    # Chained str.replace() calls were order-dependent: a key that is a prefix
    # of another (TEXT vs TEXT_LEN) clobbered the longer placeholder, a
    # substituted value could be rewritten by a later replacement, and an
    # empty key inserted its value between every character.
    values = {}
    for k, v in input.items():
        # The first key to claim an uppercase name wins, as with the old replace order.
        values.setdefault(str(k).upper(), str(v))
    names = sorted((name for name in values if name), key=len, reverse=True)
    if not names:
        return template
    pattern = re.compile('|'.join(map(re.escape, names)))
    # A callable replacement avoids backslash-escape processing of the values.
    return pattern.sub(lambda m: values[m.group(0)], template)
|