AliHassan00 committed on
Commit
9c84880
·
1 Parent(s): 4b492b9

fix: bug #245

Browse files
app/components/chat/Chat.client.tsx CHANGED
@@ -74,8 +74,14 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
74
  const textareaRef = useRef<HTMLTextAreaElement>(null);
75
 
76
  const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
77
- const [model, setModel] = useState(DEFAULT_MODEL);
78
- const [provider, setProvider] = useState(DEFAULT_PROVIDER);
 
 
 
 
 
 
79
 
80
  const { showChat } = useStore(chatStore);
81
 
@@ -216,6 +222,16 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
216
  }
217
  }, []);
218
 
 
 
 
 
 
 
 
 
 
 
219
  return (
220
  <BaseChat
221
  ref={animationScope}
@@ -228,9 +244,9 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
228
  promptEnhanced={promptEnhanced}
229
  sendMessage={sendMessage}
230
  model={model}
231
- setModel={setModel}
232
  provider={provider}
233
- setProvider={setProvider}
234
  messageRef={messageRef}
235
  scrollRef={scrollRef}
236
  handleInputChange={handleInputChange}
@@ -246,10 +262,16 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
246
  };
247
  })}
248
  enhancePrompt={() => {
249
- enhancePrompt(input, (input) => {
250
- setInput(input);
251
- scrollTextArea();
252
- });
 
 
 
 
 
 
253
  }}
254
  />
255
  );
 
74
  const textareaRef = useRef<HTMLTextAreaElement>(null);
75
 
76
  const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
77
+ const [model, setModel] = useState(() => {
78
+ const savedModel = Cookies.get('selectedModel');
79
+ return savedModel || DEFAULT_MODEL;
80
+ });
81
+ const [provider, setProvider] = useState(() => {
82
+ const savedProvider = Cookies.get('selectedProvider');
83
+ return savedProvider || DEFAULT_PROVIDER;
84
+ });
85
 
86
  const { showChat } = useStore(chatStore);
87
 
 
222
  }
223
  }, []);
224
 
225
+ const handleModelChange = (newModel: string) => {
226
+ setModel(newModel);
227
+ Cookies.set('selectedModel', newModel, { expires: 30 });
228
+ };
229
+
230
+ const handleProviderChange = (newProvider: string) => {
231
+ setProvider(newProvider);
232
+ Cookies.set('selectedProvider', newProvider, { expires: 30 });
233
+ };
234
+
235
  return (
236
  <BaseChat
237
  ref={animationScope}
 
244
  promptEnhanced={promptEnhanced}
245
  sendMessage={sendMessage}
246
  model={model}
247
+ setModel={handleModelChange}
248
  provider={provider}
249
+ setProvider={handleProviderChange}
250
  messageRef={messageRef}
251
  scrollRef={scrollRef}
252
  handleInputChange={handleInputChange}
 
262
  };
263
  })}
264
  enhancePrompt={() => {
265
+ enhancePrompt(
266
+ input,
267
+ (input) => {
268
+ setInput(input);
269
+ scrollTextArea();
270
+ },
271
+ model,
272
+ provider,
273
+ apiKeys
274
+ );
275
  }}
276
  />
277
  );
app/lib/hooks/usePromptEnhancer.ts CHANGED
@@ -12,41 +12,55 @@ export function usePromptEnhancer() {
12
  setPromptEnhanced(false);
13
  };
14
 
15
- const enhancePrompt = async (input: string, setInput: (value: string) => void) => {
 
 
 
 
 
 
16
  setEnhancingPrompt(true);
17
  setPromptEnhanced(false);
18
-
 
 
 
 
 
 
 
 
 
 
19
  const response = await fetch('/api/enhancer', {
20
  method: 'POST',
21
- body: JSON.stringify({
22
- message: input,
23
- }),
24
  });
25
-
26
  const reader = response.body?.getReader();
27
-
28
  const originalInput = input;
29
-
30
  if (reader) {
31
  const decoder = new TextDecoder();
32
-
33
  let _input = '';
34
  let _error;
35
-
36
  try {
37
  setInput('');
38
-
39
  while (true) {
40
  const { value, done } = await reader.read();
41
-
42
  if (done) {
43
  break;
44
  }
45
-
46
  _input += decoder.decode(value);
47
-
48
  logger.trace('Set input', _input);
49
-
50
  setInput(_input);
51
  }
52
  } catch (error) {
@@ -56,10 +70,10 @@ export function usePromptEnhancer() {
56
  if (_error) {
57
  logger.error(_error);
58
  }
59
-
60
  setEnhancingPrompt(false);
61
  setPromptEnhanced(true);
62
-
63
  setTimeout(() => {
64
  setInput(_input);
65
  });
 
12
  setPromptEnhanced(false);
13
  };
14
 
15
+ const enhancePrompt = async (
16
+ input: string,
17
+ setInput: (value: string) => void,
18
+ model: string,
19
+ provider: string,
20
+ apiKeys?: Record<string, string>
21
+ ) => {
22
  setEnhancingPrompt(true);
23
  setPromptEnhanced(false);
24
+
25
+ const requestBody: any = {
26
+ message: input,
27
+ model,
28
+ provider,
29
+ };
30
+
31
+ if (apiKeys) {
32
+ requestBody.apiKeys = apiKeys;
33
+ }
34
+
35
  const response = await fetch('/api/enhancer', {
36
  method: 'POST',
37
+ body: JSON.stringify(requestBody),
 
 
38
  });
39
+
40
  const reader = response.body?.getReader();
41
+
42
  const originalInput = input;
43
+
44
  if (reader) {
45
  const decoder = new TextDecoder();
46
+
47
  let _input = '';
48
  let _error;
49
+
50
  try {
51
  setInput('');
52
+
53
  while (true) {
54
  const { value, done } = await reader.read();
55
+
56
  if (done) {
57
  break;
58
  }
59
+
60
  _input += decoder.decode(value);
61
+
62
  logger.trace('Set input', _input);
63
+
64
  setInput(_input);
65
  }
66
  } catch (error) {
 
70
  if (_error) {
71
  logger.error(_error);
72
  }
73
+
74
  setEnhancingPrompt(false);
75
  setPromptEnhanced(true);
76
+
77
  setTimeout(() => {
78
  setInput(_input);
79
  });
app/routes/api.enhancer.ts CHANGED
@@ -2,6 +2,7 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare';
2
  import { StreamingTextResponse, parseStreamPart } from 'ai';
3
  import { streamText } from '~/lib/.server/llm/stream-text';
4
  import { stripIndents } from '~/utils/stripIndent';
 
5
 
6
  const encoder = new TextEncoder();
7
  const decoder = new TextDecoder();
@@ -11,14 +12,34 @@ export async function action(args: ActionFunctionArgs) {
11
  }
12
 
13
  async function enhancerAction({ context, request }: ActionFunctionArgs) {
14
- const { message } = await request.json<{ message: string }>();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  try {
17
  const result = await streamText(
18
  [
19
  {
20
  role: 'user',
21
- content: stripIndents`
22
  I want you to improve the user prompt that is wrapped in \`<original_prompt>\` tags.
23
 
24
  IMPORTANT: Only respond with the improved prompt and nothing else!
@@ -30,28 +51,42 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
30
  },
31
  ],
32
  context.cloudflare.env,
 
 
33
  );
34
 
35
  const transformStream = new TransformStream({
36
  transform(chunk, controller) {
37
- const processedChunk = decoder
38
- .decode(chunk)
39
- .split('\n')
40
- .filter((line) => line !== '')
41
- .map(parseStreamPart)
42
- .map((part) => part.value)
43
- .join('');
44
-
45
- controller.enqueue(encoder.encode(processedChunk));
 
 
 
 
 
46
  },
47
  });
48
 
49
  const transformedStream = result.toAIStream().pipeThrough(transformStream);
50
 
51
  return new StreamingTextResponse(transformedStream);
52
- } catch (error) {
53
  console.log(error);
54
 
 
 
 
 
 
 
 
55
  throw new Response(null, {
56
  status: 500,
57
  statusText: 'Internal Server Error',
 
2
  import { StreamingTextResponse, parseStreamPart } from 'ai';
3
  import { streamText } from '~/lib/.server/llm/stream-text';
4
  import { stripIndents } from '~/utils/stripIndent';
5
+ import type { StreamingOptions } from '~/lib/.server/llm/stream-text';
6
 
7
  const encoder = new TextEncoder();
8
  const decoder = new TextDecoder();
 
12
  }
13
 
14
  async function enhancerAction({ context, request }: ActionFunctionArgs) {
15
+ const { message, model, provider, apiKeys } = await request.json<{
16
+ message: string;
17
+ model: string;
18
+ provider: string;
19
+ apiKeys?: Record<string, string>;
20
+ }>();
21
+
22
+ // Validate 'model' and 'provider' fields
23
+ if (!model || typeof model !== 'string') {
24
+ throw new Response('Invalid or missing model', {
25
+ status: 400,
26
+ statusText: 'Bad Request'
27
+ });
28
+ }
29
+
30
+ if (!provider || typeof provider !== 'string') {
31
+ throw new Response('Invalid or missing provider', {
32
+ status: 400,
33
+ statusText: 'Bad Request'
34
+ });
35
+ }
36
 
37
  try {
38
  const result = await streamText(
39
  [
40
  {
41
  role: 'user',
42
+ content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n` + stripIndents`
43
  I want you to improve the user prompt that is wrapped in \`<original_prompt>\` tags.
44
 
45
  IMPORTANT: Only respond with the improved prompt and nothing else!
 
51
  },
52
  ],
53
  context.cloudflare.env,
54
+ undefined,
55
+ apiKeys
56
  );
57
 
58
  const transformStream = new TransformStream({
59
  transform(chunk, controller) {
60
+ const text = decoder.decode(chunk);
61
+ const lines = text.split('\n').filter(line => line.trim() !== '');
62
+
63
+ for (const line of lines) {
64
+ try {
65
+ const parsed = parseStreamPart(line);
66
+ if (parsed.type === 'text') {
67
+ controller.enqueue(encoder.encode(parsed.value));
68
+ }
69
+ } catch (e) {
70
+ // Skip invalid JSON lines
71
+ console.warn('Failed to parse stream part:', line);
72
+ }
73
+ }
74
  },
75
  });
76
 
77
  const transformedStream = result.toAIStream().pipeThrough(transformStream);
78
 
79
  return new StreamingTextResponse(transformedStream);
80
+ } catch (error: unknown) {
81
  console.log(error);
82
 
83
+ if (error instanceof Error && error.message?.includes('API key')) {
84
+ throw new Response('Invalid or missing API key', {
85
+ status: 401,
86
+ statusText: 'Unauthorized'
87
+ });
88
+ }
89
+
90
  throw new Response(null, {
91
  status: 500,
92
  statusText: 'Internal Server Error',