Spaces:
Running
Running
Commit
·
046eafc
1
Parent(s):
56980e9
fix secrets handling
Browse files
app.py
CHANGED
|
@@ -50,6 +50,8 @@ if "verbose" not in st.session_state:
|
|
| 50 |
st.session_state.verbose = verbose
|
| 51 |
if "max_tokens" not in st.session_state:
|
| 52 |
st.session_state.max_tokens = max_tokens
|
|
|
|
|
|
|
| 53 |
if "temperature" not in st.session_state:
|
| 54 |
st.session_state.temperature = temperature
|
| 55 |
if "next_prompts" not in st.session_state:
|
|
@@ -272,6 +274,8 @@ try:
|
|
| 272 |
num_turns=st.session_state.num_turns,
|
| 273 |
temperature=st.session_state.temperature,
|
| 274 |
max_tokens=st.session_state.max_tokens,
|
|
|
|
|
|
|
| 275 |
verbose=st.session_state.verbose,
|
| 276 |
)
|
| 277 |
chunk = next(st.session_state.generator)
|
|
|
|
| 50 |
st.session_state.verbose = verbose
|
| 51 |
if "max_tokens" not in st.session_state:
|
| 52 |
st.session_state.max_tokens = max_tokens
|
| 53 |
+
if "seed" not in st.session_state:
|
| 54 |
+
st.session_state.seed = 0
|
| 55 |
if "temperature" not in st.session_state:
|
| 56 |
st.session_state.temperature = temperature
|
| 57 |
if "next_prompts" not in st.session_state:
|
|
|
|
| 274 |
num_turns=st.session_state.num_turns,
|
| 275 |
temperature=st.session_state.temperature,
|
| 276 |
max_tokens=st.session_state.max_tokens,
|
| 277 |
+
seed=st.session_state.seed,
|
| 278 |
+
secrets=st.session_state.secrets,
|
| 279 |
verbose=st.session_state.verbose,
|
| 280 |
)
|
| 281 |
chunk = next(st.session_state.generator)
|
cli.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import argparse
|
|
|
|
| 2 |
import time
|
| 3 |
|
| 4 |
from src.open_strawberry import get_defaults, manage_conversation
|
|
@@ -54,6 +55,7 @@ def go_cli():
|
|
| 54 |
temperature=args.temperature,
|
| 55 |
max_tokens=args.max_tokens,
|
| 56 |
seed=args.seed,
|
|
|
|
| 57 |
cli_mode=True)
|
| 58 |
response = ''
|
| 59 |
conversation_history = []
|
|
|
|
| 1 |
import argparse
|
| 2 |
+
import os
|
| 3 |
import time
|
| 4 |
|
| 5 |
from src.open_strawberry import get_defaults, manage_conversation
|
|
|
|
| 55 |
temperature=args.temperature,
|
| 56 |
max_tokens=args.max_tokens,
|
| 57 |
seed=args.seed,
|
| 58 |
+
secrets=dict(os.environ),
|
| 59 |
cli_mode=True)
|
| 60 |
response = ''
|
| 61 |
conversation_history = []
|
models.py
CHANGED
|
@@ -25,6 +25,7 @@ def get_anthropic(model: str,
|
|
| 25 |
max_tokens: int = 4096,
|
| 26 |
system: str = '',
|
| 27 |
chat_history: List[Dict] = None,
|
|
|
|
| 28 |
verbose=False) -> \
|
| 29 |
Generator[dict, None, None]:
|
| 30 |
model = model.replace('anthropic:', '')
|
|
@@ -32,7 +33,7 @@ def get_anthropic(model: str,
|
|
| 32 |
# https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
|
| 33 |
import anthropic
|
| 34 |
|
| 35 |
-
clawd_key =
|
| 36 |
clawd_client = anthropic.Anthropic(api_key=clawd_key) if clawd_key else None
|
| 37 |
|
| 38 |
if chat_history is None:
|
|
@@ -118,16 +119,16 @@ def get_openai(model: str,
|
|
| 118 |
max_tokens: int = 4096,
|
| 119 |
system: str = '',
|
| 120 |
chat_history: List[Dict] = None,
|
|
|
|
| 121 |
verbose=False) -> Generator[dict, None, None]:
|
| 122 |
-
|
| 123 |
-
if model in ollama:
|
| 124 |
model = model.replace('ollama:', '')
|
| 125 |
-
openai_key =
|
| 126 |
-
openai_base_url =
|
| 127 |
else:
|
| 128 |
model = model.replace('openai:', '')
|
| 129 |
-
openai_key =
|
| 130 |
-
openai_base_url =
|
| 131 |
|
| 132 |
from openai import OpenAI
|
| 133 |
|
|
@@ -206,12 +207,13 @@ def get_google(model: str,
|
|
| 206 |
max_tokens: int = 4096,
|
| 207 |
system: str = '',
|
| 208 |
chat_history: List[Dict] = None,
|
|
|
|
| 209 |
verbose=False) -> Generator[dict, None, None]:
|
| 210 |
model = model.replace('google:', '').replace('gemini:', '')
|
| 211 |
|
| 212 |
import google.generativeai as genai
|
| 213 |
|
| 214 |
-
gemini_key =
|
| 215 |
genai.configure(api_key=gemini_key)
|
| 216 |
# Create the model
|
| 217 |
generation_config = {
|
|
@@ -308,12 +310,13 @@ def get_groq(model: str,
|
|
| 308 |
max_tokens: int = 4096,
|
| 309 |
system: str = '',
|
| 310 |
chat_history: List[Dict] = None,
|
|
|
|
| 311 |
verbose=False) -> Generator[dict, None, None]:
|
| 312 |
model = model.replace('groq:', '')
|
| 313 |
|
| 314 |
from groq import Groq
|
| 315 |
|
| 316 |
-
groq_key =
|
| 317 |
client = Groq(api_key=groq_key)
|
| 318 |
|
| 319 |
if chat_history is None:
|
|
@@ -352,15 +355,16 @@ def get_openai_azure(model: str,
|
|
| 352 |
max_tokens: int = 4096,
|
| 353 |
system: str = '',
|
| 354 |
chat_history: List[Dict] = None,
|
|
|
|
| 355 |
verbose=False) -> Generator[dict, None, None]:
|
| 356 |
model = model.replace('azure:', '').replace('openai_azure:', '')
|
| 357 |
|
| 358 |
from openai import AzureOpenAI
|
| 359 |
|
| 360 |
-
azure_endpoint =
|
| 361 |
-
azure_key =
|
| 362 |
-
azure_deployment =
|
| 363 |
-
azure_api_version =
|
| 364 |
assert azure_endpoint is not None, "Azure OpenAI endpoint not set"
|
| 365 |
assert azure_key is not None, "Azure OpenAI API key not set"
|
| 366 |
assert azure_deployment is not None, "Azure OpenAI deployment not set"
|
|
@@ -420,15 +424,15 @@ def get_model_names(secrets, on_hf_spaces=False):
|
|
| 420 |
else:
|
| 421 |
anthropic_models = []
|
| 422 |
if secrets.get('OPENAI_API_KEY'):
|
| 423 |
-
if
|
| 424 |
-
openai_models = to_list(
|
| 425 |
else:
|
| 426 |
openai_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
|
| 427 |
else:
|
| 428 |
openai_models = []
|
| 429 |
if secrets.get('AZURE_OPENAI_API_KEY'):
|
| 430 |
-
if
|
| 431 |
-
azure_models = to_list(
|
| 432 |
else:
|
| 433 |
azure_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
|
| 434 |
else:
|
|
|
|
| 25 |
max_tokens: int = 4096,
|
| 26 |
system: str = '',
|
| 27 |
chat_history: List[Dict] = None,
|
| 28 |
+
secrets: Dict = {},
|
| 29 |
verbose=False) -> \
|
| 30 |
Generator[dict, None, None]:
|
| 31 |
model = model.replace('anthropic:', '')
|
|
|
|
| 33 |
# https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
|
| 34 |
import anthropic
|
| 35 |
|
| 36 |
+
clawd_key = secrets.get('ANTHROPIC_API_KEY')
|
| 37 |
clawd_client = anthropic.Anthropic(api_key=clawd_key) if clawd_key else None
|
| 38 |
|
| 39 |
if chat_history is None:
|
|
|
|
| 119 |
max_tokens: int = 4096,
|
| 120 |
system: str = '',
|
| 121 |
chat_history: List[Dict] = None,
|
| 122 |
+
secrets: Dict = {},
|
| 123 |
verbose=False) -> Generator[dict, None, None]:
|
| 124 |
+
if model.startswith('ollama:'):
|
|
|
|
| 125 |
model = model.replace('ollama:', '')
|
| 126 |
+
openai_key = secrets.get('OLLAMA_OPENAI_API_KEY')
|
| 127 |
+
openai_base_url = secrets.get('OLLAMA_OPENAI_BASE_URL', 'http://localhost:11434/v1/')
|
| 128 |
else:
|
| 129 |
model = model.replace('openai:', '')
|
| 130 |
+
openai_key = secrets.get('OPENAI_API_KEY')
|
| 131 |
+
openai_base_url = secrets.get('OPENAI_BASE_URL', 'https://api.openai.com/v1')
|
| 132 |
|
| 133 |
from openai import OpenAI
|
| 134 |
|
|
|
|
| 207 |
max_tokens: int = 4096,
|
| 208 |
system: str = '',
|
| 209 |
chat_history: List[Dict] = None,
|
| 210 |
+
secrets: Dict = {},
|
| 211 |
verbose=False) -> Generator[dict, None, None]:
|
| 212 |
model = model.replace('google:', '').replace('gemini:', '')
|
| 213 |
|
| 214 |
import google.generativeai as genai
|
| 215 |
|
| 216 |
+
gemini_key = secrets.get("GEMINI_API_KEY")
|
| 217 |
genai.configure(api_key=gemini_key)
|
| 218 |
# Create the model
|
| 219 |
generation_config = {
|
|
|
|
| 310 |
max_tokens: int = 4096,
|
| 311 |
system: str = '',
|
| 312 |
chat_history: List[Dict] = None,
|
| 313 |
+
secrets: Dict = {},
|
| 314 |
verbose=False) -> Generator[dict, None, None]:
|
| 315 |
model = model.replace('groq:', '')
|
| 316 |
|
| 317 |
from groq import Groq
|
| 318 |
|
| 319 |
+
groq_key = secrets.get("GROQ_API_KEY")
|
| 320 |
client = Groq(api_key=groq_key)
|
| 321 |
|
| 322 |
if chat_history is None:
|
|
|
|
| 355 |
max_tokens: int = 4096,
|
| 356 |
system: str = '',
|
| 357 |
chat_history: List[Dict] = None,
|
| 358 |
+
secrets: Dict = {},
|
| 359 |
verbose=False) -> Generator[dict, None, None]:
|
| 360 |
model = model.replace('azure:', '').replace('openai_azure:', '')
|
| 361 |
|
| 362 |
from openai import AzureOpenAI
|
| 363 |
|
| 364 |
+
azure_endpoint = secrets.get("AZURE_OPENAI_ENDPOINT") # e.g. https://project.openai.azure.com
|
| 365 |
+
azure_key = secrets.get("AZURE_OPENAI_API_KEY")
|
| 366 |
+
azure_deployment = secrets.get("AZURE_OPENAI_DEPLOYMENT") # i.e. deployment name with some models deployed
|
| 367 |
+
azure_api_version = secrets.get('AZURE_OPENAI_API_VERSION', '2024-07-01-preview')
|
| 368 |
assert azure_endpoint is not None, "Azure OpenAI endpoint not set"
|
| 369 |
assert azure_key is not None, "Azure OpenAI API key not set"
|
| 370 |
assert azure_deployment is not None, "Azure OpenAI deployment not set"
|
|
|
|
| 424 |
else:
|
| 425 |
anthropic_models = []
|
| 426 |
if secrets.get('OPENAI_API_KEY'):
|
| 427 |
+
if secrets.get('OPENAI_MODEL_NAME'):
|
| 428 |
+
openai_models = to_list(secrets.get('OPENAI_MODEL_NAME'))
|
| 429 |
else:
|
| 430 |
openai_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
|
| 431 |
else:
|
| 432 |
openai_models = []
|
| 433 |
if secrets.get('AZURE_OPENAI_API_KEY'):
|
| 434 |
+
if secrets.get('AZURE_OPENAI_MODEL_NAME'):
|
| 435 |
+
azure_models = to_list(secrets.get('AZURE_OPENAI_MODEL_NAME'))
|
| 436 |
else:
|
| 437 |
azure_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
|
| 438 |
else:
|
open_strawberry.py
CHANGED
|
@@ -290,7 +290,8 @@ def manage_conversation(model: str,
|
|
| 290 |
temperature: float = 0.3,
|
| 291 |
max_tokens: int = 4096,
|
| 292 |
seed: int = 1234,
|
| 293 |
-
|
|
|
|
| 294 |
) -> Generator[Dict, None, list]:
|
| 295 |
if seed == 0:
|
| 296 |
seed = random.randint(0, 1000000)
|
|
@@ -344,7 +345,9 @@ def manage_conversation(model: str,
|
|
| 344 |
thinking_time = time.time()
|
| 345 |
response_text = ''
|
| 346 |
for chunk in get_model_func(model, prompt, system=system, chat_history=chat_history,
|
| 347 |
-
temperature=temperature, max_tokens=max_tokens,
|
|
|
|
|
|
|
| 348 |
if 'text' in chunk and chunk['text']:
|
| 349 |
response_text += chunk['text']
|
| 350 |
yield {"role": "assistant", "content": chunk['text'], "streaming": True, "chat_history": chat_history,
|
|
|
|
| 290 |
temperature: float = 0.3,
|
| 291 |
max_tokens: int = 4096,
|
| 292 |
seed: int = 1234,
|
| 293 |
+
secrets: Dict = {},
|
| 294 |
+
verbose: bool = False,
|
| 295 |
) -> Generator[Dict, None, list]:
|
| 296 |
if seed == 0:
|
| 297 |
seed = random.randint(0, 1000000)
|
|
|
|
| 345 |
thinking_time = time.time()
|
| 346 |
response_text = ''
|
| 347 |
for chunk in get_model_func(model, prompt, system=system, chat_history=chat_history,
|
| 348 |
+
temperature=temperature, max_tokens=max_tokens,
|
| 349 |
+
secrets=secrets,
|
| 350 |
+
verbose=verbose):
|
| 351 |
if 'text' in chunk and chunk['text']:
|
| 352 |
response_text += chunk['text']
|
| 353 |
yield {"role": "assistant", "content": chunk['text'], "streaming": True, "chat_history": chat_history,
|