mgbam commited on
Commit
228cd9c
·
verified ·
1 Parent(s): 8a19371

Update medgemma.py

Browse files
Files changed (1) hide show
  1. medgemma.py +39 -29
medgemma.py CHANGED
@@ -12,61 +12,71 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- # MedGemma endpoint
16
  import requests
 
17
  from auth import create_credentials, get_access_token_refresh_if_needed
18
- import os
19
  from cache import cache
20
 
21
- _endpoint_url = os.environ.get('GCP_MEDGEMMA_ENDPOINT')
22
 
23
  # Create credentials
24
- secret_key_json = os.environ.get('GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY')
25
  medgemma_credentials = create_credentials(secret_key_json)
26
 
27
  # https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations.endpoints.chat/completions
28
  @cache.memoize()
29
  def medgemma_get_text_response(
30
- messages: list,
31
  temperature: float = 0.1,
32
  max_tokens: int = 4096,
33
  stream: bool = False,
34
- top_p: float | None = None,
35
- seed: int | None = None,
36
- stop: list[str] | str | None = None,
37
- frequency_penalty: float | None = None,
38
- presence_penalty: float | None = None,
39
- model: str="tgi"
40
- ):
41
  """
42
- Makes a chat completion request to the configured LLM API (OpenAI-compatible).
43
  """
44
  headers = {
45
  "Authorization": f"Bearer {get_access_token_refresh_if_needed(medgemma_credentials)}",
46
  "Content-Type": "application/json",
47
  }
48
 
49
- # Based on the openai format
50
  payload = {
51
- "messages": messages,
52
- "max_tokens": max_tokens
53
- }
54
-
55
-
56
- if temperature is not None: payload["temperature"] = temperature
57
- if top_p is not None: payload["top_p"] = top_p
58
- if seed is not None: payload["seed"] = seed
59
- if stop is not None: payload["stop"] = stop
60
- if frequency_penalty is not None: payload["frequency_penalty"] = frequency_penalty
61
- if presence_penalty is not None: payload["presence_penalty"] = presence_penalty
62
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- response = requests.post(_endpoint_url, headers=headers, json=payload, stream=stream, timeout=60)
 
 
 
 
 
 
65
  try:
66
  response.raise_for_status()
67
  return response.json()["choices"][0]["message"]["content"]
68
  except requests.exceptions.JSONDecodeError:
69
- # Log the problematic response for easier debugging in the future.
70
- print(f"Error: Failed to decode JSON from MedGemma. Status: {response.status_code}, Response: {response.text}")
71
- # Re-raise the exception so the caller knows something went wrong.
 
72
  raise
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import os
16
  import requests
17
+ from typing import Optional, Union, List
18
  from auth import create_credentials, get_access_token_refresh_if_needed
 
19
  from cache import cache
20
 
21
+ _endpoint_url = os.environ.get("GCP_MEDGEMMA_ENDPOINT")
22
 
23
  # Create credentials
24
+ secret_key_json = os.environ.get("GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY")
25
  medgemma_credentials = create_credentials(secret_key_json)
26
 
27
  # https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations.endpoints.chat/completions
28
  @cache.memoize()
29
  def medgemma_get_text_response(
30
+ messages: List[dict],
31
  temperature: float = 0.1,
32
  max_tokens: int = 4096,
33
  stream: bool = False,
34
+ top_p: Optional[float] = None,
35
+ seed: Optional[int] = None,
36
+ stop: Optional[Union[List[str], str]] = None,
37
+ frequency_penalty: Optional[float] = None,
38
+ presence_penalty: Optional[float] = None,
39
+ model: str = "tgi"
40
+ ) -> str:
41
  """
42
+ Makes a chat completion request to the configured LLM API (Vertex AI/OpenAI-compatible).
43
  """
44
  headers = {
45
  "Authorization": f"Bearer {get_access_token_refresh_if_needed(medgemma_credentials)}",
46
  "Content-Type": "application/json",
47
  }
48
 
 
49
  payload = {
50
+ "messages": messages,
51
+ "max_tokens": max_tokens
52
+ }
 
 
 
 
 
 
 
 
53
 
54
+ if temperature is not None:
55
+ payload["temperature"] = temperature
56
+ if top_p is not None:
57
+ payload["top_p"] = top_p
58
+ if seed is not None:
59
+ payload["seed"] = seed
60
+ if stop is not None:
61
+ payload["stop"] = stop
62
+ if frequency_penalty is not None:
63
+ payload["frequency_penalty"] = frequency_penalty
64
+ if presence_penalty is not None:
65
+ payload["presence_penalty"] = presence_penalty
66
 
67
+ response = requests.post(
68
+ _endpoint_url,
69
+ headers=headers,
70
+ json=payload,
71
+ stream=stream,
72
+ timeout=60
73
+ )
74
  try:
75
  response.raise_for_status()
76
  return response.json()["choices"][0]["message"]["content"]
77
  except requests.exceptions.JSONDecodeError:
78
+ print(
79
+ f"Error: Failed to decode JSON from MedGemma. "
80
+ f"Status: {response.status_code}, Response: {response.text}"
81
+ )
82
  raise