Spaces:
Running
Running
'''For specifying an LLM agent logic flow.''' | |
import dataclasses
import inspect
import json
import re
import traceback

import openai
import pandas as pd

from . import ops
from . import workspace
# OpenAI-compatible client pointed at a server on localhost port 11434.
# NOTE(review): no api_key is passed, so the client presumably relies on the
# OPENAI_API_KEY environment variable being set -- confirm for deployments.
client = openai.OpenAI(base_url="http://localhost:11434/v1")
# Memoized chat() results, keyed by the JSON-serialized call arguments.
CACHE = {}
# Name of this operation environment; execute() uses it to look up the catalog.
ENV = 'LLM logic'
# Presumably the registration decorator for operations in ENV.
# NOTE(review): it is not applied to any function visible in this section --
# confirm that the operations below are registered somewhere.
op = ops.op_registration(ENV)
@dataclasses.dataclass
class Context:
    '''Passed to operation functions as "_ctx" if they have such a parameter.

    Declared as a dataclass so that it can be built as ``Context(node)``,
    which is how execute() constructs it.
    '''
    # The workspace node whose operation is being run.
    node: workspace.WorkspaceNode
    # Most recent value returned by the node's operation; execute() assigns it
    # after every task. Left unannotated on purpose so it stays a class-level
    # default and does not become a constructor argument.
    last_result = None
@dataclasses.dataclass
class Output:
    '''Return this to send values to specific outputs of a node.

    Declared as a dataclass so that it can be built positionally as
    ``Output(output_handle, value)``, which is how branch() constructs it.
    '''
    # Name of the node output that should receive the value.
    output_handle: str
    # The payload to deliver on that output.
    value: dict
def chat(*args, **kwargs):
    '''Call the chat completion endpoint, memoizing the results in CACHE.

    All arguments are forwarded unchanged to ``client.chat.completions.create``
    and must be JSON-serializable, because the cache key is their JSON dump.

    Returns:
        A list with the message content of every returned choice.
    '''
    # sort_keys makes the key independent of keyword-argument order, so the
    # same request never produces two different cache entries.
    key = json.dumps({'args': args, 'kwargs': kwargs}, sort_keys=True)
    if key not in CACHE:
        completion = client.chat.completions.create(*args, **kwargs)
        CACHE[key] = [c.message.content for c in completion.choices]
    # Hand out a copy so callers cannot modify the cached list.
    return list(CACHE[key])
def input(*, filename: ops.PathStr, key: str):
    '''Read the CSV at ``filename`` and rename its ``key`` column to "text".

    Returns the resulting DataFrame; all other columns keep their names.
    '''
    table = pd.read_csv(filename)
    renamed = table.rename(columns={key: 'text'})
    return renamed
def create_prompt(input, *, template: ops.LongStr):
    '''Fill ``template`` with the values of ``input``.

    Every occurrence of an upper-cased key of ``input`` in the template is
    replaced by ``str(value)``. The substitution happens in a single pass with
    longer names tried first, so "TEXT_ID" is not clobbered by "TEXT" and
    text that was just inserted is never scanned for placeholders again.

    Args:
        input: Mapping of column names to values.
        template: Prompt text that refers to columns by upper-cased name.

    Returns:
        The filled-in template as a string.

    Raises:
        AssertionError: If ``template`` is empty.
    '''
    # NOTE(review): this returns a bare string, while ask_llm() expects a
    # mapping with a 'prompt' entry -- confirm how the two are meant to connect.
    assert template, 'Please specify the template. Refer to columns using their names in uppercase.'
    values = {}
    for k, v in input.items():
        name = str(k).upper()
        if name:
            # The first key wins when two names collide after upper-casing.
            values.setdefault(name, str(v))
    if not values:
        return template
    # Longest names first so a name is never shadowed by one of its prefixes.
    names = sorted(values, key=len, reverse=True)
    pattern = re.compile('|'.join(re.escape(name) for name in names))
    return pattern.sub(lambda m: values[m.group(0)], template)
def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
    '''Send ``input['prompt']`` to the model as a single user message.

    Args:
        input: Mapping that must contain a 'prompt' entry.
        model: Model name passed to chat().
        accepted_regex: If set, forwarded to chat() in ``extra_body`` under
            "guided_regex".
        max_tokens: Token limit passed to chat().

    Returns:
        One dict per returned completion: a copy of ``input`` with the
        completion text added under 'response'.

    Raises:
        AssertionError: If ``model`` is empty or ``input`` has no 'prompt'.
    '''
    assert model, 'Please specify the model.'
    assert 'prompt' in input, 'Please create the prompt first.'
    # Keyword order is kept as model, max_tokens, messages, extra_body.
    request = {
        'model': model,
        'max_tokens': max_tokens,
        'messages': [{"role": "user", "content": input['prompt']}],
    }
    if accepted_regex:
        request['extra_body'] = {"guided_regex": accepted_regex}
    rows = []
    for reply in chat(**request):
        rows.append({**input, 'response': reply})
    return rows
def view(input, *, _ctx: Context):
    '''Accumulate incoming rows into a table for display.

    The first call creates the table from the keys of ``input`` that do not
    start with an underscore. Later calls find the table in
    ``_ctx.last_result`` and append one more row to it, in place.

    Args:
        input: Mapping of column names to values for one row.
        _ctx: Context whose ``last_result`` holds the table built so far.

    Returns:
        ``{'dataframes': {'df': {'columns': [...], 'data': [[...], ...]}}}``
        where every entry of 'data' is one row.

    Raises:
        KeyError: If a later row lacks one of the table's columns.
    '''
    # Column names are strings, so look values up by the stringified key.
    by_name = {str(k): val for k, val in input.items()}
    table = _ctx.last_result
    if table:
        frame = table['dataframes']['df']
        frame['data'].append([by_name[c] for c in frame['columns']])
        return table
    columns = [c for c in by_name if not c.startswith('_')]
    return {
        'dataframes': {'df': {
            'columns': columns,
            # A list of rows: the first row is wrapped so that it has the same
            # shape as the rows appended by later calls.
            'data': [[by_name[c] for c in columns]],
        }}
    }
def loop(input, *, max_iterations: int = 3, _ctx: Context):
    '''Data can flow back here max_iterations-1 times.

    Each pass increments a counter stored in the data under a key derived
    from this node's id.

    Args:
        input: Mapping carrying the data (and possibly an earlier counter).
        max_iterations: Pass number at which the data is no longer forwarded.
        _ctx: Context giving access to the node id.

    Returns:
        A copy of ``input`` with the updated counter while the counter is
        below ``max_iterations``; otherwise None.
    '''
    key = f'iterations-{_ctx.node.id}'
    count = input.get(key, 0) + 1
    if count < max_iterations:
        # Return a copy: the same input object may also be delivered to other
        # nodes, which must not see this node's counter appear in their data.
        return {**input, key: count}
    return None
def branch(input, *, expression: str):
    '''Evaluate ``expression`` against ``input`` and pick an output by its truth value.

    Args:
        input: Mapping whose entries are visible as variables in the expression.
        expression: Python expression to evaluate.

    Returns:
        ``Output('true', input)`` or ``Output('false', input)``.
    '''
    # SECURITY: eval() runs arbitrary Python. Only use this with expressions
    # written by trusted workspace authors.
    # Evaluate against a copy: eval() inserts a '__builtins__' entry into the
    # globals dict it is given, which would otherwise leak into the data.
    scope = dict(input)
    res = eval(expression, scope)
    return Output(str(bool(res)).lower(), input)
def rag(input, db, *, closest_n: int=10):
    '''Placeholder: returns ``input`` unchanged; ``db`` and ``closest_n`` are not used yet.'''
    return input
def run_python(input, *, template: str):
    '''Fill ``template`` with the values of ``input``.

    Despite its name, this does not execute anything: it only returns the
    template with every upper-cased key of ``input`` replaced by
    ``str(value)``. The substitution happens in a single pass with longer
    names tried first, so "TEXT_ID" is not clobbered by "TEXT" and text that
    was just inserted is never scanned for placeholders again.

    Args:
        input: Mapping of column names to values.
        template: Text that refers to columns by upper-cased name.

    Returns:
        The filled-in template as a string.

    Raises:
        AssertionError: If ``template`` is empty.
    '''
    assert template, 'Please specify the template. Refer to columns using their names in uppercase.'
    values = {}
    for k, v in input.items():
        name = str(k).upper()
        if name:
            # The first key wins when two names collide after upper-casing.
            values.setdefault(name, str(v))
    if not values:
        return template
    # Longest names first so a name is never shadowed by one of its prefixes.
    names = sorted(values, key=len, reverse=True)
    pattern = re.compile('|'.join(re.escape(name) for name in names))
    return pattern.sub(lambda m: values[m.group(0)], template)
def execute(ws):
    '''Run the workspace by pushing tasks along its edges until none are left.

    Nodes whose operation declares no inputs are started first. Each value an
    operation returns becomes a task for every node its node has an edge to;
    returned lists and DataFrames are split into one task per element or row.
    A None result produces no task, which is how an operation stops the flow.

    Side effects on each node's ``data``:
        error: Reset to None at the start; set to the exception message if the
            operation raises. A failing node forwards nothing.
        display: Set to the collected results for operations of type
            'visualization' or 'table_view'.

    Args:
        ws: Workspace providing ``nodes`` and ``edges``.
    '''
    catalog = ops.CATALOGS[ENV]
    nodes = {n.id: n for n in ws.nodes}
    contexts = {n.id: Context(n) for n in ws.nodes}
    edges = {n.id: [] for n in ws.nodes}
    for e in ws.edges:
        edges[e.source].append(e.target)
    tasks = {}
    NO_INPUT = object()  # Marker for initial tasks.
    for node in ws.nodes:
        node.data.error = None
        # Start tasks for nodes that have no inputs.
        if not catalog[node.data.title].inputs:
            tasks[node.id] = [NO_INPUT]
    # Run the rest until we run out of tasks.
    while tasks:
        node_id, pending = tasks.popitem()
        node = nodes[node_id]
        data = node.data
        operation = catalog[data.title]
        params = {**data.params}
        if has_ctx(operation):
            params['_ctx'] = contexts[node.id]
        results = []
        for task in pending:
            try:
                if task is NO_INPUT:
                    result = operation(**params)
                else:
                    # TODO: Tasks with multiple inputs?
                    result = operation(task, **params)
            except Exception as e:
                # Report the failure on the node and stop processing it.
                traceback.print_exc()
                data.error = str(e)
                break
            contexts[node.id].last_result = result
            if result is None:
                # Nothing was produced, so there is nothing to pass on.
                continue
            # NOTE(review): Output results are forwarded like any other value;
            # their output_handle is not used for routing here.
            # Returned lists and DataFrames are considered multiple tasks.
            if isinstance(result, pd.DataFrame):
                result = df_to_list(result)
            elif not isinstance(result, list):
                result = [result]
            results.extend(result)
        else:  # Finished all tasks without errors.
            if operation.type in ('visualization', 'table_view'):
                data.display = results
            for target in edges[node.id]:
                tasks.setdefault(target, []).extend(results)
def df_to_list(df):
    '''Convert a DataFrame into a list of ``{column: value}`` dicts, one per row.

    Uses ``DataFrame.to_dict`` so that each column keeps its own type; going
    through ``df.values`` would first coerce all columns to one common dtype
    (for example turning integers into floats next to a float column).
    '''
    return df.to_dict(orient='records')
def has_ctx(op):
    '''Tell whether the function behind ``op`` accepts a "_ctx" parameter.'''
    return '_ctx' in inspect.signature(op.func).parameters