stillerman (HF Staff) committed

Commit ff1e468 · Parent: 8def8f4

streamable inference
src/components/game-component.tsx CHANGED
@@ -1,14 +1,19 @@
 "use client";
 
-import { useState, useEffect, useCallback, useMemo } from "react";
+import { useState, useEffect, useCallback } from "react";
 import { Card } from "@/components/ui/card";
 import { Button } from "@/components/ui/button";
-import { Flag, Clock, Hash, BarChart, ArrowRight, Bot } from "lucide-react";
-import inference from "@/lib/inference";
-import ReasoningTrace, { Run, Step } from "./reasoning-trace";
-import ForceDirectedGraph from "./force-directed-graph";
+import { Flag, Clock, Hash, ArrowRight, Bot } from "lucide-react";
+import { useInference } from "@/lib/inference";
+
 import { API_BASE } from "@/lib/constants";
 
+
+type Message = {
+  role: "user" | "assistant";
+  content: string;
+};
+
 const buildPrompt = (
   current: string,
   target: string,
@@ -69,24 +74,12 @@ export default function GameComponent({
     "playing"
   );
 
-  const [reasoningTrace, setReasoningTrace] = useState<Run | null>({
-    start_article: startPage,
-    destination_article: targetPage,
-    steps: [
-      {
-        type: "start",
-        article: startPage,
-        metadata: {
-          message: "Starting Node",
-        },
-      },
-    ],
-  });
-  const [isModelThinking, setIsModelThinking] = useState<boolean>(false);
+  const [convo, setConvo] = useState([]);
 
-  const runs = useMemo(() => {
-    return reasoningTrace ? [reasoningTrace] : [];
-  }, [reasoningTrace]);
+  const { status: modelStatus, partialText, inferenceResult, inference } = useInference({
+    apiKey:
+      window.localStorage.getItem("huggingface_access_token") || undefined,
+  });
 
   const fetchCurrentPageLinks = useCallback(async () => {
     setLinksLoading(true);
@@ -121,43 +114,15 @@
     }
   }, [currentPage, targetPage, hops, maxHops]);
 
-  const addStepToReasoningTrace = (step: Step) => {
-    setReasoningTrace((prev) => {
-      if (!prev)
-        return {
-          steps: [step],
-          start_article: startPage,
-          destination_article: targetPage,
-        };
-      return {
-        steps: [...prev.steps, step],
-        start_article: startPage,
-        destination_article: targetPage,
-      };
-    });
-  };
-
-  const handleLinkClick = (link: string, userClicked: boolean = true) => {
+  const handleLinkClick = (link: string) => {
     if (gameStatus !== "playing") return;
 
     setCurrentPage(link);
    setHops((prev) => prev + 1);
     setVisitedNodes((prev) => [...prev, link]);
-    if (userClicked) {
-      addStepToReasoningTrace({
-        type: "step",
-        article: link,
-        links: currentPageLinks,
-        metadata: {
-          message: "User clicked link",
-        },
-      });
-    }
   };
 
   const makeModelMove = async () => {
-    setIsModelThinking(true);
-
     const prompt = buildPrompt(
       currentPage,
       targetPage,
@@ -165,16 +130,25 @@ export default function GameComponent({
       currentPageLinks
     );
 
+    pushConvo({
+      role: "user",
+      content: prompt,
+    });
+
     const modelResponse = await inference({
-      apiKey:
-        window.localStorage.getItem("huggingface_access_token") || undefined,
       model: model,
       prompt,
       maxTokens: maxTokens,
     });
-    console.log("Model response", modelResponse.content);
 
-    const answer = modelResponse.content?.match(/<answer>(.*?)<\/answer>/)?.[1];
+    pushConvo({
+      role: "assistant",
+      content: modelResponse,
+    });
+
+    console.log("Model response", modelResponse);
+
+    const answer = modelResponse.match(/<answer>(.*?)<\/answer>/)?.[1];
     if (!answer) {
       console.error("No answer found in model response");
       return;
@@ -207,27 +181,7 @@ export default function GameComponent({
       currentPageLinks
     );
 
-    addStepToReasoningTrace({
-      type: "step",
-      article: selectedLink,
-      links: currentPageLinks,
-      metadata: {
-        message: "Model picked link",
-        conversation: [
-          {
-            role: "user",
-            content: prompt,
-          },
-          {
-            role: "assistant",
-            content: modelResponse.content || "",
-          },
-        ],
-      },
-    });
-
-    handleLinkClick(selectedLink, false);
-    setIsModelThinking(false);
+    handleLinkClick(selectedLink);
   };
 
   const handleGiveUp = () => {
@@ -240,6 +194,10 @@ export default function GameComponent({
     return `${mins}:${secs < 10 ? "0" : ""}${secs}`;
   };
 
+  const pushConvo = (message: Message) => {
+    setConvo((prev) => [...prev, message]);
+  };
+
   return (
     <div className="grid grid-cols-1 md:grid-cols-2 gap-4">
       <Card className="p-4 flex col-span-2">
@@ -280,7 +238,7 @@ export default function GameComponent({
           <div className="flex items-center gap-2">
             <Bot className="h-4 w-4 text-blue-500" />
             <span className="font-medium text-blue-700">
-              {model} {isModelThinking ? "is playing..." : "is playing"}
+              {model} {modelStatus === "thinking" ? "is thinking..." : "is playing"}
             </span>
           </div>
         </div>
@@ -310,7 +268,7 @@ export default function GameComponent({
                size="sm"
                className="justify-start overflow-hidden text-ellipsis whitespace-nowrap"
                onClick={() => handleLinkClick(link)}
-                disabled={player === "model" || isModelThinking}
+                disabled={player === "model" || modelStatus === "thinking"}
              >
                {link}
              </Button>
@@ -320,7 +278,7 @@ export default function GameComponent({
          {player === "model" && (
            <Button
              onClick={makeModelMove}
-              disabled={isModelThinking || linksLoading}
+              disabled={modelStatus === "thinking" || linksLoading}
            >
              Make Move
            </Button>
@@ -328,7 +286,7 @@ export default function GameComponent({
          </>
        )}
 
-        {player === "model" && isModelThinking && gameStatus === "playing" && (
+        {player === "model" && modelStatus === "thinking" && gameStatus === "playing" && (
          <div className="flex items-center gap-2 text-sm animate-pulse mb-4">
            <Bot className="h-4 w-4" />
            <span>{model} is thinking...</span>
@@ -367,12 +325,28 @@ export default function GameComponent({
          </div>
        )}
      </Card>
-      {/*
-      <Card className="p-4 flex flex-col max-h-[500px] overflow-y-auto">
-        <ReasoningTrace run={reasoningTrace} />
-      </Card> */}
 
      <Card className="p-4 flex flex-col max-h-[500px] overflow-y-auto">
+        <h2 className="text-xl font-bold">LLM Reasoning</h2>
+        {
+          convo.map((message, index) => (
+            <div key={index}>
+              <p>{message.role}</p>
+              <p>{message.content}</p>
+              <hr />
+            </div>
+          ))
+        }
+
+        { modelStatus === "thinking" && (
+          <div className="flex items-center gap-2 text-sm animate-pulse mb-4">
+            <Bot className="h-4 w-4" />
+            <p>{partialText}</p>
+          </div>
+        )}
+      </Card>
+
+      {/* <Card className="p-4 flex flex-col max-h-[500px] overflow-y-auto">
        <iframe
          src={`https://simple.wikipedia.org/wiki/${currentPage.replace(
            /\s+/g,
@@ -380,19 +354,6 @@ export default function GameComponent({
          )}`}
          className="w-full h-full"
        />
-      </Card>
-
-      {/* Right pane - Game stats and graph */}
-      {/* <Card className="p-4 flex flex-col overflow-y-auto">
-        <div className="flex-1 bg-muted/30 rounded-md p-4">
-          <div className="flex items-center gap-2 text-sm font-medium text-muted-foreground mb-2">
-            <BarChart className="h-4 w-4" /> Path Visualization
-          </div>
-
-          <div className="h-[500px]">
-            <ForceDirectedGraph runs={runs} runId={0} />
-          </div>
-        </div>
      </Card> */}
    </div>
  );
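Note on the answer parsing above: the model is expected to wrap the link it picks in <answer>...</answer> tags, and makeModelMove extracts it with the regex shown in the diff. A minimal sketch of that step, using the same regex (parseAnswer is an illustrative helper name, not part of the commit):

const parseAnswer = (modelResponse: string): string | undefined =>
  modelResponse.match(/<answer>(.*?)<\/answer>/)?.[1];

// e.g. parseAnswer("I'll go via <answer>Albert Einstein</answer>") returns "Albert Einstein";
// parseAnswer("no tags here") returns undefined, which triggers the "No answer found" branch.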
src/lib/inference.tsx CHANGED
@@ -1,4 +1,5 @@
 import { InferenceClient } from "@huggingface/inference";
+import { useState } from "react";
 
 export default async function inference({
   prompt,
@@ -33,7 +34,120 @@ export default async function inference({
     ],
     max_tokens: maxTokens,
   });
+
 
   console.log("Inference response", chatCompletion.choices[0].message);
   return chatCompletion.choices[0].message;
-}
+}
+
+export function useInferenceOld({ apiKey }) {
+  const [isLoading, setIsLoading] = useState(false);
+  const [partialText, setPartialText] = useState("");
+  const [inferenceResult, setInferenceResult] = useState("");
+  const [error, setError] = useState<string | null>(null);
+  const inferenceInternal = async ({
+    prompt,
+    model,
+    maxTokens,
+  }: {
+    prompt: string;
+    model: string;
+    maxTokens: number;
+  }) => {
+    setIsLoading(true);
+    setPartialText("boop boop partial text");
+
+    try {
+      const result = await inference({
+        prompt,
+        model,
+        apiKey,
+        maxTokens,
+      });
+
+      setInferenceResult(result.content);
+      setIsLoading(false);
+
+      return result.content;
+    } catch (error) {
+      console.error("Error in inference", error);
+      setError(error.message);
+      setIsLoading(false);
+      return null;
+    }
+  };
+
+  const status = isLoading ? "thinking" : error ? "error" : "done";
+
+  return {
+    status,
+    partialText,
+    inferenceResult,
+    error,
+    inference: inferenceInternal,
+  };
+}
+
+
+export function useInference({ apiKey }) {
+  const [isLoading, setIsLoading] = useState(false);
+  const [partialText, setPartialText] = useState("");
+  const [inferenceResult, setInferenceResult] = useState("");
+  const [error, setError] = useState<string | null>(null);
+  const inferenceInternal = async ({
+    prompt,
+    model,
+    maxTokens,
+  }: {
+    prompt: string;
+    model: string;
+    maxTokens: number;
+  }) => {
+    setIsLoading(true);
+    setPartialText("");
+
+    const client = new InferenceClient(apiKey);
+
+    try {
+      const stream = client.chatCompletionStream({
+        provider: "hyperbolic",
+        model,
+        maxTokens,
+        messages: [
+          {
+            role: "user",
+            content: prompt,
+          },
+        ],
+      });
+
+      let result = "";
+
+      for await (const chunk of stream) {
+        result += chunk.choices[0].delta.content;
+        setPartialText(result);
+      }
+
+      setIsLoading(false);
+
+      setInferenceResult(result);
+
+      return result;
+    } catch (error) {
+      console.error("Error in inference", error);
+      setError(error.message);
+      setIsLoading(false);
+      return null;
+    }
+  };
+
+  const status = isLoading ? "thinking" : error ? "error" : "done";
+
+  return {
+    status,
+    partialText,
+    inferenceResult,
+    error,
+    inference: inferenceInternal,
+  };
+}
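For reference, a minimal consumer of the new useInference hook might look like the sketch below. The hook's options and return shape (status, partialText, inferenceResult, inference) and the localStorage-based API key are taken from the code above; the component itself and the model id are illustrative, not part of the commit.

"use client";

import { useInference } from "@/lib/inference";

export default function StreamingDemo() {
  const { status, partialText, inferenceResult, inference } = useInference({
    apiKey:
      window.localStorage.getItem("huggingface_access_token") || undefined,
  });

  return (
    <div>
      <button
        disabled={status === "thinking"}
        onClick={() =>
          inference({
            model: "meta-llama/Llama-3.1-8B-Instruct", // illustrative model id
            prompt: "Name one link you would click next.",
            maxTokens: 128,
          })
        }
      >
        Ask the model
      </button>
      {/* partialText grows token by token while the request streams;
          inferenceResult holds the full reply once status is "done". */}
      <p>{status === "thinking" ? partialText : inferenceResult}</p>
    </div>
  );
}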