Commit
·
9c84880
1
Parent(s):
4b492b9
fix: bug #245
Browse files- app/components/chat/Chat.client.tsx +30 -8
- app/lib/hooks/usePromptEnhancer.ts +31 -17
- app/routes/api.enhancer.ts +47 -12
app/components/chat/Chat.client.tsx
CHANGED
@@ -74,8 +74,14 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
|
|
74 |
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
75 |
|
76 |
const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
|
77 |
-
const [model, setModel] = useState(
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
const { showChat } = useStore(chatStore);
|
81 |
|
@@ -216,6 +222,16 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
|
|
216 |
}
|
217 |
}, []);
|
218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
return (
|
220 |
<BaseChat
|
221 |
ref={animationScope}
|
@@ -228,9 +244,9 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
|
|
228 |
promptEnhanced={promptEnhanced}
|
229 |
sendMessage={sendMessage}
|
230 |
model={model}
|
231 |
-
setModel={
|
232 |
provider={provider}
|
233 |
-
setProvider={
|
234 |
messageRef={messageRef}
|
235 |
scrollRef={scrollRef}
|
236 |
handleInputChange={handleInputChange}
|
@@ -246,10 +262,16 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
|
|
246 |
};
|
247 |
})}
|
248 |
enhancePrompt={() => {
|
249 |
-
enhancePrompt(
|
250 |
-
|
251 |
-
|
252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
}}
|
254 |
/>
|
255 |
);
|
|
|
74 |
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
75 |
|
76 |
const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
|
77 |
+
const [model, setModel] = useState(() => {
|
78 |
+
const savedModel = Cookies.get('selectedModel');
|
79 |
+
return savedModel || DEFAULT_MODEL;
|
80 |
+
});
|
81 |
+
const [provider, setProvider] = useState(() => {
|
82 |
+
const savedProvider = Cookies.get('selectedProvider');
|
83 |
+
return savedProvider || DEFAULT_PROVIDER;
|
84 |
+
});
|
85 |
|
86 |
const { showChat } = useStore(chatStore);
|
87 |
|
|
|
222 |
}
|
223 |
}, []);
|
224 |
|
225 |
+
const handleModelChange = (newModel: string) => {
|
226 |
+
setModel(newModel);
|
227 |
+
Cookies.set('selectedModel', newModel, { expires: 30 });
|
228 |
+
};
|
229 |
+
|
230 |
+
const handleProviderChange = (newProvider: string) => {
|
231 |
+
setProvider(newProvider);
|
232 |
+
Cookies.set('selectedProvider', newProvider, { expires: 30 });
|
233 |
+
};
|
234 |
+
|
235 |
return (
|
236 |
<BaseChat
|
237 |
ref={animationScope}
|
|
|
244 |
promptEnhanced={promptEnhanced}
|
245 |
sendMessage={sendMessage}
|
246 |
model={model}
|
247 |
+
setModel={handleModelChange}
|
248 |
provider={provider}
|
249 |
+
setProvider={handleProviderChange}
|
250 |
messageRef={messageRef}
|
251 |
scrollRef={scrollRef}
|
252 |
handleInputChange={handleInputChange}
|
|
|
262 |
};
|
263 |
})}
|
264 |
enhancePrompt={() => {
|
265 |
+
enhancePrompt(
|
266 |
+
input,
|
267 |
+
(input) => {
|
268 |
+
setInput(input);
|
269 |
+
scrollTextArea();
|
270 |
+
},
|
271 |
+
model,
|
272 |
+
provider,
|
273 |
+
apiKeys
|
274 |
+
);
|
275 |
}}
|
276 |
/>
|
277 |
);
|
app/lib/hooks/usePromptEnhancer.ts
CHANGED
@@ -12,41 +12,55 @@ export function usePromptEnhancer() {
|
|
12 |
setPromptEnhanced(false);
|
13 |
};
|
14 |
|
15 |
-
const enhancePrompt = async (
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
setEnhancingPrompt(true);
|
17 |
setPromptEnhanced(false);
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
const response = await fetch('/api/enhancer', {
|
20 |
method: 'POST',
|
21 |
-
body: JSON.stringify(
|
22 |
-
message: input,
|
23 |
-
}),
|
24 |
});
|
25 |
-
|
26 |
const reader = response.body?.getReader();
|
27 |
-
|
28 |
const originalInput = input;
|
29 |
-
|
30 |
if (reader) {
|
31 |
const decoder = new TextDecoder();
|
32 |
-
|
33 |
let _input = '';
|
34 |
let _error;
|
35 |
-
|
36 |
try {
|
37 |
setInput('');
|
38 |
-
|
39 |
while (true) {
|
40 |
const { value, done } = await reader.read();
|
41 |
-
|
42 |
if (done) {
|
43 |
break;
|
44 |
}
|
45 |
-
|
46 |
_input += decoder.decode(value);
|
47 |
-
|
48 |
logger.trace('Set input', _input);
|
49 |
-
|
50 |
setInput(_input);
|
51 |
}
|
52 |
} catch (error) {
|
@@ -56,10 +70,10 @@ export function usePromptEnhancer() {
|
|
56 |
if (_error) {
|
57 |
logger.error(_error);
|
58 |
}
|
59 |
-
|
60 |
setEnhancingPrompt(false);
|
61 |
setPromptEnhanced(true);
|
62 |
-
|
63 |
setTimeout(() => {
|
64 |
setInput(_input);
|
65 |
});
|
|
|
12 |
setPromptEnhanced(false);
|
13 |
};
|
14 |
|
15 |
+
const enhancePrompt = async (
|
16 |
+
input: string,
|
17 |
+
setInput: (value: string) => void,
|
18 |
+
model: string,
|
19 |
+
provider: string,
|
20 |
+
apiKeys?: Record<string, string>
|
21 |
+
) => {
|
22 |
setEnhancingPrompt(true);
|
23 |
setPromptEnhanced(false);
|
24 |
+
|
25 |
+
const requestBody: any = {
|
26 |
+
message: input,
|
27 |
+
model,
|
28 |
+
provider,
|
29 |
+
};
|
30 |
+
|
31 |
+
if (apiKeys) {
|
32 |
+
requestBody.apiKeys = apiKeys;
|
33 |
+
}
|
34 |
+
|
35 |
const response = await fetch('/api/enhancer', {
|
36 |
method: 'POST',
|
37 |
+
body: JSON.stringify(requestBody),
|
|
|
|
|
38 |
});
|
39 |
+
|
40 |
const reader = response.body?.getReader();
|
41 |
+
|
42 |
const originalInput = input;
|
43 |
+
|
44 |
if (reader) {
|
45 |
const decoder = new TextDecoder();
|
46 |
+
|
47 |
let _input = '';
|
48 |
let _error;
|
49 |
+
|
50 |
try {
|
51 |
setInput('');
|
52 |
+
|
53 |
while (true) {
|
54 |
const { value, done } = await reader.read();
|
55 |
+
|
56 |
if (done) {
|
57 |
break;
|
58 |
}
|
59 |
+
|
60 |
_input += decoder.decode(value);
|
61 |
+
|
62 |
logger.trace('Set input', _input);
|
63 |
+
|
64 |
setInput(_input);
|
65 |
}
|
66 |
} catch (error) {
|
|
|
70 |
if (_error) {
|
71 |
logger.error(_error);
|
72 |
}
|
73 |
+
|
74 |
setEnhancingPrompt(false);
|
75 |
setPromptEnhanced(true);
|
76 |
+
|
77 |
setTimeout(() => {
|
78 |
setInput(_input);
|
79 |
});
|
app/routes/api.enhancer.ts
CHANGED
@@ -2,6 +2,7 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare';
|
|
2 |
import { StreamingTextResponse, parseStreamPart } from 'ai';
|
3 |
import { streamText } from '~/lib/.server/llm/stream-text';
|
4 |
import { stripIndents } from '~/utils/stripIndent';
|
|
|
5 |
|
6 |
const encoder = new TextEncoder();
|
7 |
const decoder = new TextDecoder();
|
@@ -11,14 +12,34 @@ export async function action(args: ActionFunctionArgs) {
|
|
11 |
}
|
12 |
|
13 |
async function enhancerAction({ context, request }: ActionFunctionArgs) {
|
14 |
-
const { message } = await request.json<{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
try {
|
17 |
const result = await streamText(
|
18 |
[
|
19 |
{
|
20 |
role: 'user',
|
21 |
-
content: stripIndents`
|
22 |
I want you to improve the user prompt that is wrapped in \`<original_prompt>\` tags.
|
23 |
|
24 |
IMPORTANT: Only respond with the improved prompt and nothing else!
|
@@ -30,28 +51,42 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
|
|
30 |
},
|
31 |
],
|
32 |
context.cloudflare.env,
|
|
|
|
|
33 |
);
|
34 |
|
35 |
const transformStream = new TransformStream({
|
36 |
transform(chunk, controller) {
|
37 |
-
const
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
46 |
},
|
47 |
});
|
48 |
|
49 |
const transformedStream = result.toAIStream().pipeThrough(transformStream);
|
50 |
|
51 |
return new StreamingTextResponse(transformedStream);
|
52 |
-
} catch (error) {
|
53 |
console.log(error);
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
throw new Response(null, {
|
56 |
status: 500,
|
57 |
statusText: 'Internal Server Error',
|
|
|
2 |
import { StreamingTextResponse, parseStreamPart } from 'ai';
|
3 |
import { streamText } from '~/lib/.server/llm/stream-text';
|
4 |
import { stripIndents } from '~/utils/stripIndent';
|
5 |
+
import type { StreamingOptions } from '~/lib/.server/llm/stream-text';
|
6 |
|
7 |
const encoder = new TextEncoder();
|
8 |
const decoder = new TextDecoder();
|
|
|
12 |
}
|
13 |
|
14 |
async function enhancerAction({ context, request }: ActionFunctionArgs) {
|
15 |
+
const { message, model, provider, apiKeys } = await request.json<{
|
16 |
+
message: string;
|
17 |
+
model: string;
|
18 |
+
provider: string;
|
19 |
+
apiKeys?: Record<string, string>;
|
20 |
+
}>();
|
21 |
+
|
22 |
+
// Validate 'model' and 'provider' fields
|
23 |
+
if (!model || typeof model !== 'string') {
|
24 |
+
throw new Response('Invalid or missing model', {
|
25 |
+
status: 400,
|
26 |
+
statusText: 'Bad Request'
|
27 |
+
});
|
28 |
+
}
|
29 |
+
|
30 |
+
if (!provider || typeof provider !== 'string') {
|
31 |
+
throw new Response('Invalid or missing provider', {
|
32 |
+
status: 400,
|
33 |
+
statusText: 'Bad Request'
|
34 |
+
});
|
35 |
+
}
|
36 |
|
37 |
try {
|
38 |
const result = await streamText(
|
39 |
[
|
40 |
{
|
41 |
role: 'user',
|
42 |
+
content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n` + stripIndents`
|
43 |
I want you to improve the user prompt that is wrapped in \`<original_prompt>\` tags.
|
44 |
|
45 |
IMPORTANT: Only respond with the improved prompt and nothing else!
|
|
|
51 |
},
|
52 |
],
|
53 |
context.cloudflare.env,
|
54 |
+
undefined,
|
55 |
+
apiKeys
|
56 |
);
|
57 |
|
58 |
const transformStream = new TransformStream({
|
59 |
transform(chunk, controller) {
|
60 |
+
const text = decoder.decode(chunk);
|
61 |
+
const lines = text.split('\n').filter(line => line.trim() !== '');
|
62 |
+
|
63 |
+
for (const line of lines) {
|
64 |
+
try {
|
65 |
+
const parsed = parseStreamPart(line);
|
66 |
+
if (parsed.type === 'text') {
|
67 |
+
controller.enqueue(encoder.encode(parsed.value));
|
68 |
+
}
|
69 |
+
} catch (e) {
|
70 |
+
// Skip invalid JSON lines
|
71 |
+
console.warn('Failed to parse stream part:', line);
|
72 |
+
}
|
73 |
+
}
|
74 |
},
|
75 |
});
|
76 |
|
77 |
const transformedStream = result.toAIStream().pipeThrough(transformStream);
|
78 |
|
79 |
return new StreamingTextResponse(transformedStream);
|
80 |
+
} catch (error: unknown) {
|
81 |
console.log(error);
|
82 |
|
83 |
+
if (error instanceof Error && error.message?.includes('API key')) {
|
84 |
+
throw new Response('Invalid or missing API key', {
|
85 |
+
status: 401,
|
86 |
+
statusText: 'Unauthorized'
|
87 |
+
});
|
88 |
+
}
|
89 |
+
|
90 |
throw new Response(null, {
|
91 |
status: 500,
|
92 |
statusText: 'Internal Server Error',
|