'''For specifying an LLM agent logic flow.'''
from . import ops
import chromadb
import enum
import jinja2
import json
import openai
import numpy as np
import pandas as pd
from .executors import one_by_one

# Any OpenAI-compatible servers work; these URLs assume local deployments
# (a chat server on port 8080 and a separate embedding server on port 7997).
chat_client = openai.OpenAI(base_url="http://localhost:8080/v1")
embedding_client = openai.OpenAI(base_url="http://localhost:7997/")
jinja = jinja2.Environment()
chroma_client = chromadb.Client()
LLM_CACHE = {}  # Process-wide cache of chat/embedding responses, keyed by the full request.
ENV = 'LLM logic'
one_by_one.register(ENV)
op = ops.op_registration(ENV)

def chat(*args, **kwargs):
  '''Cached chat completion; returns the text content of every choice.'''
  key = json.dumps({'method': 'chat', 'args': args, 'kwargs': kwargs})
  if key not in LLM_CACHE:
    completion = chat_client.chat.completions.create(*args, **kwargs)
    LLM_CACHE[key] = [c.message.content for c in completion.choices]
  return LLM_CACHE[key]
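# Example (a sketch; the model name is hypothetical, use whatever the local
# server hosts). A repeated call with identical arguments is answered from
# LLM_CACHE without another round-trip:
#   [answer] = chat(
#     model='some-chat-model',
#     messages=[{'role': 'user', 'content': 'Say hi.'}],
#   )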

def embedding(*args, **kwargs):
  '''Cached embedding lookup for a single input text.'''
  key = json.dumps({'method': 'embedding', 'args': args, 'kwargs': kwargs})
  if key not in LLM_CACHE:
    res = embedding_client.embeddings.create(*args, **kwargs)
    [data] = res.data  # Exactly one embedding is expected per call.
    LLM_CACHE[key] = data.embedding
  return LLM_CACHE[key]
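# Example (again with a hypothetical model name):
#   vec = embedding(input=['hello world'], model='some-embedding-model')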

@op("Input CSV")
def input_csv(*, filename: ops.PathStr, key: str):
  return pd.read_csv(filename).rename(columns={key: 'text'})

@op("Input document")
def input_document(*, filename: ops.PathStr):
  with open(filename) as f:
    return {'text': f.read()}

@op("Input chat")
def input_chat(*, chat: str):
  return {'text': chat}

@op("Split document")
def split_document(input, *, delimiter: str = '\\n\\n'):
  delimiter = delimiter.encode().decode('unicode_escape')
  chunks = input['text'].split(delimiter)
  return pd.DataFrame(chunks, columns=['text'])
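# The encode/decode pair above turns escape sequences typed in the UI into
# real characters. With the default delimiter, the four literal characters
# '\n\n' decode to a blank line, so e.g.
#   split_document({'text': 'one\n\ntwo'})
# returns a two-row DataFrame holding 'one' and 'two'.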

@ops.input_position(input="top")
@op("Build document graph")
def build_document_graph(input):
  '''Links the chunks in their original order, forming a simple chain.'''
  return [{'source': i, 'target': i+1} for i in range(len(input)-1)]

@ops.input_position(nodes="top", edges="top")
@op("Predict links")
def predict_links(nodes, edges):
  '''A placeholder for a real algorithm. For now just adds 2-hop neighbors.'''
  edge_map = {} # Source -> [Targets]
  for edge in edges:
    edge_map.setdefault(edge['source'], [])
    edge_map[edge['source']].append(edge['target'])
  new_edges = []
  for edge in edges:
    for t in edge_map.get(edge['target'], []):
      new_edges.append({'source': edge['source'], 'target': t})
  return edges + new_edges
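# Worked example: from edges 0->1 and 1->2 the loop above derives 0->2, so
# the result is [0->1, 1->2, 0->2].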

@ops.input_position(nodes="top", edges="top")
@op("Add neighbors")
def add_neighbors(nodes, edges, item):
  '''Extends the RAG matches of an item with the graph neighbors of the matched nodes.'''
  nodes = pd.DataFrame(nodes)
  edges = pd.DataFrame(edges)
  matches = item['rag']
  additional_matches = []
  for m in matches:
    node = nodes[nodes['text'] == m].index[0]
    neighbors = edges[edges['source'] == node]['target'].to_list()
    additional_matches.extend(nodes.loc[neighbors, 'text'])
  return {**item, 'rag': matches + additional_matches}

@op("Create prompt")
def create_prompt(input, *, save_as='prompt', template: ops.LongStr):
  assert template, 'Please specify the template. Refer to columns using the Jinja2 syntax.'
  t = jinja.from_string(template)
  prompt = t.render(**input)
  return {**input, save_as: prompt}
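# Example template, assuming the item carries 'text' and 'rag' fields (the
# names depend on the upstream ops):
#   Answer the question using only this context: {{ rag }}
#   Question: {{ text }}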

@op("Ask LLM")
def ask_llm(input, *, model: str, accepted_regex: str = None, max_tokens: int = 100):
  assert model, 'Please specify the model.'
  assert 'prompt' in input, 'Please create the prompt first.'
  options = {}
  if accepted_regex:
    options['extra_body'] = {
      "guided_regex": accepted_regex,
    }
  results = chat(
    model=model,
    max_tokens=max_tokens,
    messages=[
      {"role": "user", "content": input['prompt']},
    ],
    **options,
  )
  return [{**input, 'response': r} for r in results]
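# 'guided_regex' is a guided-decoding option in the style of vLLM's
# OpenAI-compatible server; a vanilla OpenAI endpoint will ignore or reject
# it. Example, constraining the answer to yes/no (hypothetical model name):
#   ask_llm({'prompt': 'Is the sky blue?'}, model='some-chat-model',
#           accepted_regex='yes|no')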

@op("View", view="table_view")
def view(input, *, _ctx: one_by_one.Context):
  '''Collects every incoming item as a row of a table view.'''
  v = _ctx.last_result
  if v:
    # Extend the table started by the previous items.
    columns = v['dataframes']['df']['columns']
    v['dataframes']['df']['data'].append([input[c] for c in columns])
  else:
    # First item: derive the columns, skipping internal fields like '_collection'.
    columns = [str(c) for c in input.keys() if not str(c).startswith('_')]
    v = {
      'dataframes': { 'df': {
        'columns': columns,
        'data': [[input[c] for c in columns]],
      }}
    }
  return v

@ops.input_position(input="right")
@ops.output_position(output="left")
@op("Loop")
def loop(input, *, max_iterations: int = 3, _ctx: one_by_one.Context):
  '''Data can flow back here max_iterations-1 times.'''
  key = f'iterations-{_ctx.node.id}'
  input[key] = input.get(key, 0) + 1
  if input[key] < max_iterations:
    return input
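# E.g. with max_iterations=3 an item passes through on its first two visits
# (counter values 1 and 2) and is swallowed on the third, breaking the cycle.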

@op('Branch', outputs=['true', 'false'])
def branch(input, *, expression: str):
  res = eval(expression, input)
  return one_by_one.Output(output_handle=str(bool(res)).lower(), value=input)
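# Example: expression="response == 'yes'" routes each item to the 'true' or
# 'false' output handle. eval() sees the item's fields as variables, so only
# use expressions from trusted workspace authors.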

class RagEngine(enum.Enum):
  Chroma = 'Chroma'
  Custom = 'Custom'

@ops.input_position(db="top")
@op('RAG')
def rag(
  input, db, *,
  engine: RagEngine = RagEngine.Chroma,
  input_field='text', db_field='text', num_matches: int = 10,
  _ctx: one_by_one.Context):
  if engine == RagEngine.Chroma:
    last = _ctx.last_result
    if last:
      # The collection was already built while processing an earlier item.
      collection = last['_collection']
    else:
      # First item on this node: drop any stale collection, then rebuild it
      # from the documents on the "db" input.
      collection_name = _ctx.node.id.replace(' ', '_')
      for c in chroma_client.list_collections():
        if c.name == collection_name:
          chroma_client.delete_collection(name=collection_name)
      collection = chroma_client.create_collection(name=collection_name)
      collection.add(
        documents=[r[db_field] for r in db],
        ids=[str(i) for i in range(len(db))],
      )
    results = collection.query(
      query_texts=[input[input_field]],
      n_results=num_matches,
    )
    results = [db[int(r)] for r in results['ids'][0]]
    return {**input, 'rag': results, '_collection': collection}
  if engine == RagEngine.Custom:
    # Brute-force alternative: embed every DB row plus the query and rank by
    # cosine similarity. Exact, but linear in the size of the DB.
    model = 'google/gemma-2-2b-it'
    query = input[input_field]
    embeddings = [embedding(input=[r[db_field]], model=model) for r in db]
    q = embedding(input=[query], model=model)
    def cosine_similarity(a, b):
      return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    scores = [(i, cosine_similarity(q, e)) for i, e in enumerate(embeddings)]
    scores.sort(key=lambda x: -x[1])
    matches = [db[i][db_field] for i, _ in scores[:num_matches]]
    return {**input, 'rag': matches}

@op('Run Python')
def run_python(input, *, template: str):
  '''TODO: Implement.'''
  return input
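# A possible body, purely a sketch (the intended design may differ): evaluate
# 'template' as a Python expression over the item's fields and merge the
# result back:
#   result = eval(template, dict(input))
#   return {**input, 'result': result}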