'''For specifying an LLM agent logic flow.'''
import enum
import json

import chromadb
import jinja2
import numpy as np
import openai
import pandas as pd

from . import ops
from .executors import one_by_one

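# These point at locally hosted OpenAI-compatible servers: one for chat
# completions, one for embeddings.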
chat_client = openai.OpenAI(base_url="http://localhost:8080/v1")
embedding_client = openai.OpenAI(base_url="http://localhost:7997/")
jinja = jinja2.Environment()
chroma_client = chromadb.Client()
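
# Responses from the chat and embedding endpoints are cached in memory, keyed
# by the serialized request, so identical calls are only sent once.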
LLM_CACHE = {}
ENV = 'LLM logic'
one_by_one.register(ENV)
op = ops.op_registration(ENV)
def chat(*args, **kwargs):
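    '''Returns the message content of each choice in the chat completion, with caching.'''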
    key = json.dumps({'method': 'chat', 'args': args, 'kwargs': kwargs})
    if key not in LLM_CACHE:
        completion = chat_client.chat.completions.create(*args, **kwargs)
        LLM_CACHE[key] = [c.message.content for c in completion.choices]
    return LLM_CACHE[key]

def embedding(*args, **kwargs):
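    '''Returns the embedding of a single input, with caching.'''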
    key = json.dumps({'method': 'embedding', 'args': args, 'kwargs': kwargs})
    if key not in LLM_CACHE:
        res = embedding_client.embeddings.create(*args, **kwargs)
        [data] = res.data
        LLM_CACHE[key] = data.embedding
    return LLM_CACHE[key]

@op("Input CSV")
def input_csv(*, filename: ops.PathStr, key: str):
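    '''Reads a CSV file and renames the given column to "text".'''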
    return pd.read_csv(filename).rename(columns={key: 'text'})

@op("Input document")
def input_document(*, filename: ops.PathStr):
    with open(filename) as f:
        return {'text': f.read()}

@op("Input chat")
def input_chat(*, chat: str):
    return {'text': chat}

@op("Split document")
def split_document(input, *, delimiter: str = '\\n\\n'):
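    '''Splits the document along the delimiter. Escape sequences in the delimiter (such as "\\n\\n") are interpreted.'''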
    delimiter = delimiter.encode().decode('unicode_escape')
    chunks = input['text'].split(delimiter)
    return pd.DataFrame(chunks, columns=['text'])

@ops.input_position(input="top")
@op("Build document graph")
def build_document_graph(input):
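    '''Connects each chunk to the next one, forming a simple chain over the document.'''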
    return [{'source': i, 'target': i + 1} for i in range(len(input) - 1)]

@ops.input_position(nodes="top", edges="top")
@op("Predict links")
def predict_links(nodes, edges):
    '''A placeholder for a real algorithm. For now just adds 2-hop neighbors.'''
    edge_map = {}  # source -> [targets]
    for edge in edges:
        edge_map.setdefault(edge['source'], [])
        edge_map[edge['source']].append(edge['target'])
    new_edges = []
    for edge in edges:
        for t in edge_map.get(edge['target'], []):
            new_edges.append({'source': edge['source'], 'target': t})
    return edges + new_edges

@ops.input_position(nodes="top", edges="top")
@op("Add neighbors")
def add_neighbors(nodes, edges, item):
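    '''Extends the item's RAG matches with the graph neighbors of each matched node.'''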
    nodes = pd.DataFrame(nodes)
    edges = pd.DataFrame(edges)
    matches = item['rag']
    additional_matches = []
    for m in matches:
        node = nodes[nodes['text'] == m].index[0]
        neighbors = edges[edges['source'] == node]['target'].to_list()
        additional_matches.extend(nodes.loc[neighbors, 'text'])
    return {**item, 'rag': matches + additional_matches}

@op("Create prompt")
def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
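    '''Renders a Jinja2 template with the item's fields and stores the result under "save_as".'''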
    assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
    t = jinja.from_string(template)
    prompt = t.render(**input)
    return {**input, save_as: prompt}

@op("Ask LLM")
def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
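    '''Sends the prompt to the LLM and emits one item per response.

    If "accepted_regex" is set, it is forwarded as a "guided_regex" extra body
    parameter, which constrains decoding on servers that support guided
    generation (e.g. vLLM's OpenAI-compatible API).
    '''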
    assert model, 'Please specify the model.'
    assert 'prompt' in input, 'Please create the prompt first.'
    options = {}
    if accepted_regex:
        options['extra_body'] = {
            'guided_regex': accepted_regex,
        }
    results = chat(
        model=model,
        max_tokens=max_tokens,
        messages=[
            {'role': 'user', 'content': input['prompt']},
        ],
        **options,
    )
    return [{**input, 'response': r} for r in results]

@op("View", view="table_view")
def view(input, *, _ctx: one_by_one.Context):
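    '''Accumulates the incoming items into a table, one row per item. Fields starting with "_" are hidden.'''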
    v = _ctx.last_result
    if v:
        columns = v['dataframes']['df']['columns']
        v['dataframes']['df']['data'].append([input[c] for c in columns])
    else:
        columns = [str(c) for c in input.keys() if not str(c).startswith('_')]
        v = {
            'dataframes': {'df': {
                'columns': columns,
                'data': [[input[c] for c in columns]],
            }}
        }
    return v

@ops.input_position(input="right")
@ops.output_position(output="left")
@op("Loop")
def loop(input, *, max_iterations: int = 3, _ctx: one_by_one.Context):
    '''Data can flow back here max_iterations-1 times.'''
    key = f'iterations-{_ctx.node.id}'
    input[key] = input.get(key, 0) + 1
    if input[key] < max_iterations:
        return input

@op('Branch', outputs=['true', 'false'])
def branch(input, *, expression: str):
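    '''Evaluates the expression with the item's fields in scope and routes the item to the "true" or "false" output.'''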
    res = eval(expression, input)
    return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)

class RagEngine(enum.Enum):
    Chroma = 'Chroma'
    Custom = 'Custom'

@ops.input_position(db="top")
@op('RAG')
def rag(
    input, db, *,
    engine: RagEngine = RagEngine.Chroma,
    input_field='text', db_field='text', num_matches: int = 10,
    _ctx: one_by_one.Context):
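    '''Returns the "num_matches" most similar db items for the input, using
    either a ChromaDB collection or a brute-force cosine-similarity search
    over embeddings.'''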
    if engine == RagEngine.Chroma:
        last = _ctx.last_result
        if last:
            # Reuse the collection built for the first item of this run.
            collection = last['_collection']
        else:
            collection_name = _ctx.node.id.replace(' ', '_')
            # Drop any stale collection from a previous run before rebuilding.
            for c in chroma_client.list_collections():
                if c.name == collection_name:
                    chroma_client.delete_collection(name=collection_name)
            collection = chroma_client.create_collection(name=collection_name)
            collection.add(
                documents=[r[db_field] for r in db],
                ids=[str(i) for i in range(len(db))],
            )
        results = collection.query(
            query_texts=[input[input_field]],
            n_results=num_matches,
        )
        results = [db[int(r)] for r in results['ids'][0]]
        return {**input, 'rag': results, '_collection': collection}
    if engine == RagEngine.Custom:
        model = 'google/gemma-2-2b-it'
        query = input[input_field]
        embeddings = [embedding(input=[r[db_field]], model=model) for r in db]
        q = embedding(input=[query], model=model)

        def cosine_similarity(a, b):
            return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

        scores = [(i, cosine_similarity(q, e)) for i, e in enumerate(embeddings)]
        scores.sort(key=lambda x: -x[1])
        matches = [db[i][db_field] for i, _ in scores[:num_matches]]
        return {**input, 'rag': matches}

@op('Run Python')
def run_python(input, *, template: str):
    '''TODO: Implement.'''
    return input