alxd committed on
Commit
c59b529
·
1 Parent(s): 320f15a

added openai and max_tokens

Files changed (2)
  1. requirements.txt +3 -1
  2. scoutLLM.py +106 -40
requirements.txt CHANGED
@@ -1,7 +1,7 @@
 gradio==3.40.0
 langchain-community==0.0.19
 langchain_core==0.1.22
-langchain-openai==0.0.5
+#langchain-openai==0.0.5
 faiss-cpu==1.7.3
 huggingface-hub==0.24.7
 google-generativeai==0.3.2
@@ -56,3 +56,5 @@ google-auth-oauthlib
 google-auth-httplib2
 
 pyperclip
+
+openai==0.28
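For context on the new pin: openai==0.28 is the last release of the legacy OpenAI SDK, which exposes the module-level openai.ChatCompletion.create call and the openai.error exception classes that scoutLLM.py now imports. A minimal sketch of that legacy call style (the model id, prompt, and OPENAI_API_KEY handling here are illustrative, not taken from the commit):

import os
import openai
from openai.error import RateLimitError

# The legacy (pre-1.0) SDK reads a module-level API key.
openai.api_key = os.getenv("OPENAI_API_KEY")

try:
    # Module-level chat completion call, the style used by generate_response below.
    resp = openai.ChatCompletion.create(
        model="gpt-4o-mini",  # illustrative model id
        messages=[{"role": "user", "content": "Say hello"}],
        temperature=0.7,
        max_tokens=64,
    )
    print(resp["choices"][0]["message"]["content"])
except RateLimitError as exc:
    print(f"Rate limited by the API: {exc}")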
scoutLLM.py CHANGED
@@ -20,7 +20,8 @@ from googleapiclient.discovery import build
 import base64
 from google.oauth2.credentials import Credentials
 from google.auth.transport.requests import Request
-
+import openai  # Correct OpenAI import
+from openai.error import RateLimitError  # Import rate limit error handling
 
 # ------------------------------
 # Helper functions and globals
@@ -28,6 +29,7 @@ from google.auth.transport.requests import Request
 sheet_data = None
 file_name = None
 sheet = None
+slider_max_tokens = None
 
 def debug_print(message: str):
     print(f"[{datetime.datetime.now().isoformat()}] {message}", flush=True)
@@ -49,40 +51,95 @@ def count_tokens(text: str) -> int:
         return len(text.split())
     return len(text.split())
 
-def generate_response(prompt: str, model_name: str, sheet_data: str) -> str:
-    full_prompt = f"{prompt}\n\nSheet Data:\n{sheet_data}"  # Append sheet data to prompt
+def get_model_max_tokens(model_name: str) -> int:
+    """Return the max context length for the selected model."""
+    model_token_limits = {
+        "GPT-3.5": 16385,
+        "GPT-4o": 128000,
+        "GPT-4o mini": 128000,
+        "Meta-Llama-3": 4096,   # Adjust based on actual limits
+        "Mistral-API": 128000   # Adjust based on actual limits
+    }
+    for key in model_token_limits:
+        if key in model_name:
+            return model_token_limits[key]
+    return 4096  # Default safety limit
+
+def get_model_max_tokens(model_name: str) -> int:
+    """Return the max context length for the selected model."""
+    model_token_limits = {
+        "GPT-3.5": 16385,
+        "GPT-4o": 128000,
+        "GPT-4o mini": 128000,
+        "Meta-Llama-3": 4096,
+        "Mistral-API": 4096
+    }
+    for key in model_token_limits:
+        if key in model_name:
+            return model_token_limits[key]
+    return 4096  # Default safety limit
+
+
+def generate_response(prompt: str, model_name: str, sheet_data: str = "") -> str:
+    global slider_max_tokens
 
-    if "Mistral" in model_name:
-        mistral_api_key = os.getenv("MISTRAL_API_KEY")
-        if not mistral_api_key:
-            raise ValueError("MISTRAL_API_KEY environment variable not set.")
-        mistral_client = Mistral(api_key=mistral_api_key)
-        response = mistral_client.chat.complete(
-            model="mistral-small-latest",
-            messages=[{"role": "user", "content": full_prompt}],
-            temperature=0.7,
-            top_p=0.95
-        )
-        return response.choices[0].message.content
-
-    elif "Meta-Llama" in model_name:
-        hf_api_token = os.getenv("HF_API_TOKEN")
-        if not hf_api_token:
-            raise ValueError("HF_API_TOKEN environment variable not set.")
-        client = InferenceClient(token=hf_api_token)
-        response = client.text_generation(
-            full_prompt,
-            model="meta-llama/Meta-Llama-3-8B-Instruct",
-            temperature=0.7,
-            top_p=0.95,
-            max_new_tokens=512
-        )
-        return response
+    full_prompt = f"{prompt}\n\nSheet Data:\n{sheet_data}" if sheet_data else prompt
+    max_context_tokens = get_model_max_tokens(model_name)
+    max_tokens = min(slider_max_tokens, max_context_tokens)
 
-    else:
-        raise ValueError("Invalid model selection. Please choose either 'Mistral-API' or 'Meta-Llama-3'.")
-
+    try:
+        if "Mistral" in model_name:
+            mistral_api_key = os.getenv("MISTRAL_API_KEY")
+            if not mistral_api_key:
+                raise ValueError("MISTRAL_API_KEY environment variable not set.")
+            mistral_client = Mistral(api_key=mistral_api_key)
+            response = mistral_client.chat.complete(
+                model="mistral-small-latest",
+                messages=[{"role": "user", "content": full_prompt[:max_tokens]}],
+                temperature=0.7,
+                top_p=0.95
+            )
+            return f"[Model: {model_name}]" + response.choices[0].message.content
+
+        elif "Meta-Llama" in model_name:
+            hf_api_token = os.getenv("HF_API_TOKEN")
+            if not hf_api_token:
+                raise ValueError("HF_API_TOKEN environment variable not set.")
+            client = InferenceClient(token=hf_api_token)
+            response = client.text_generation(
+                full_prompt[:max_tokens],
+                model="meta-llama/Meta-Llama-3-8B-Instruct",
+                temperature=0.7,
+                top_p=0.95,
+                max_new_tokens=max_tokens
+            )
+            return f"[Model: {model_name}]" + response
+
+        elif any(model in model_name for model in ["GPT-3.5", "GPT-4o", "GPT-4o mini"]):
+            model_map = {
+                "GPT-3.5": "gpt-3.5-turbo",
+                "GPT-4o": "gpt-4o",
+                "GPT-4o mini": "gpt-4o-mini"
+            }
+            model = next((model_map[key] for key in model_map if key in model_name), None)
+
+            if not model:
+                raise ValueError(f"Unsupported OpenAI model: {model_name}")
+
+            response = openai.ChatCompletion.create(
+                model=model,
+                messages=[{"role": "user", "content": full_prompt[:max_tokens]}],
+                temperature=0.7,
+                max_tokens=max_tokens
+            )
+            return f"[Model: {model_name}]" + response["choices"][0]["message"]["content"]
+
+    except Exception as e:
+        debug_print(f"❌ Error generating response: {str(e)}")
+        return f"[Model: {model_name}][Error] {str(e)}"
 
+
+
 def process_query(prompt: str, model_name: str):
     global sheet_data
 
@@ -103,9 +160,6 @@ def process_query(prompt: str, model_name: str):
     # Return the response along with token counts
     return response, f"Input tokens: {input_tokens}", f"Output tokens: {output_tokens}"
 
-def ui_process_query(prompt, model_name):
-    return process_query(prompt, model_name)
-
 # ------------------------------
 # Global variables for background jobs
 # ------------------------------
@@ -182,10 +236,12 @@ def process_in_background(job_id, func, args):
     debug_print(f"Job {job_id} finished processing in background.")
 
 
-def submit_query_async(query, model_choice=None):
+def submit_query_async(query, model_choice, max_tokens_slider):
    """Asynchronous version of submit_query_updated to prevent timeouts."""
    global last_job_id
    global sheet_data
+    global slider_max_tokens
+    slider_max_tokens = max_tokens_slider
 
    if not query:
        return ("Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list())
@@ -197,6 +253,7 @@ def submit_query_async(query, model_choice=None):
    if sheet_data is None:
        sheet_data = get_sheet_data()
 
+
    query = f"{query}\n\nSheet Data:\n{sheet_data}"  # Append sheet data to prompt
 
    # Start background thread to process the query
@@ -510,11 +567,21 @@ with gr.Blocks() as app:
     with gr.Column(scale=1):
         gr.Markdown("### 🚀 Submit Query")
         gr.Markdown("Enter your prompt below and choose a model. Your query will be processed in the background.")
+        # Update the model dropdown in the Gradio UI
+        # Update the model dropdown in the Gradio UI
         model_dropdown = gr.Dropdown(
-            choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
-            value="🇪🇺 Mistral-API",  # Default model set to Mistral
+            choices=[
+                "🇺🇸 GPT-3.5",
+                "🇺🇸 GPT-4o",
+                "🇺🇸 GPT-4o mini",
+                "🇺🇸 Remote Meta-Llama-3",
+                "🇪🇺 Mistral-API",
+            ],
+            value="🇺🇸 GPT-4o mini",  # Default model
             label="Select Model"
         )
+        max_tokens_slider = gr.Slider(minimum=50, maximum=4096, value=512, label="🔢 Max Tokens", step=50)
+
         prompt_input = gr.Textbox(label="Enter your prompt", value=default_prompt, lines=6)
         with gr.Row():
             auto_refresh_checkbox = gr.Checkbox(
@@ -562,7 +629,6 @@ with gr.Blocks() as app:
     def load_file(file, sheet_name):
         global sheet_data
         global file_name
-        global sheet
         file_name = file
         sheet = sheet_name
 
@@ -585,7 +651,7 @@ with gr.Blocks() as app:
     # When submitting a query asynchronously
     submit_button.click(
         fn=submit_query_async,
-        inputs=[prompt_input, model_dropdown],
+        inputs=[prompt_input, model_dropdown, max_tokens_slider],
         outputs=[
             response_output, token_info,
             input_tokens_display, output_tokens_display,
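Usage note on the new max_tokens flow: the gr.Slider value is passed as the third input to submit_query_async, stored in the module-level slider_max_tokens, and capped against the per-model limit inside generate_response, where it is used both to truncate the prompt and as the completion budget. A standalone sketch of that capping step follows (names mirror the diff; the prompt and slider value are placeholders, and note that full_prompt[:max_tokens] in the diff slices by characters, not tokens):

def get_model_max_tokens(model_name: str) -> int:
    # Same lookup shape as in the diff: return a context-length ceiling for the model.
    limits = {"GPT-3.5": 16385, "GPT-4o": 128000, "GPT-4o mini": 128000,
              "Meta-Llama-3": 4096, "Mistral-API": 4096}
    for key, value in limits.items():
        if key in model_name:
            return value
    return 4096  # default safety limit

slider_max_tokens = 512            # placeholder: value coming from the gr.Slider
model_name = "🇺🇸 GPT-4o mini"     # placeholder: value coming from the gr.Dropdown
max_tokens = min(slider_max_tokens, get_model_max_tokens(model_name))

full_prompt = "Summarize the sheet"  # placeholder prompt
# Reproduces the truncation used in the diff: a character slice, not a token-aware cut.
truncated = full_prompt[:max_tokens]
print(max_tokens, len(truncated))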