David Chu committed on
Commit
2874a2b
·
unverified ·
1 Parent(s): 7ccab60

Use chat moderation endpoint

Browse files
Files changed (1) hide show
  1. app.py +44 -71
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import asyncio
2
  import json
3
  import os
4
 
@@ -41,75 +40,49 @@ alinia_guardrail = httpx.AsyncClient(
41
  mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
42
 
43
 
44
- async def get_mistral_moderation(user_content, assistant_content):
45
- def sync_moderation(inputs):
46
- return mistral_client.classifiers.moderate_chat(
47
- model="mistral-moderation-latest", inputs=inputs
48
- )
49
-
50
- inputs_assistant = [
51
- {"role": "user", "content": user_content},
52
- {"role": "assistant", "content": assistant_content},
53
- ]
54
-
55
- inputs_user = [{"role": "user", "content": user_content}]
56
-
57
  try:
58
- response_full, response_user_only = await asyncio.gather(
59
- asyncio.to_thread(sync_moderation, inputs_assistant),
60
- asyncio.to_thread(sync_moderation, inputs_user),
61
  )
62
-
63
- return {
64
- "full_interaction": response_full.results,
65
- "user_only": response_user_only.results,
66
- }
 
 
 
67
  except Exception as e:
68
- print(f"Mistral moderation error: {e!s}")
69
- return {"error": str(e)}
 
 
70
 
71
 
72
- async def check_safety(message: str, metadata: dict) -> dict:
73
  try:
74
- user_content = (
75
- metadata["messages"][-2]["content"]
76
- if len(metadata.get("messages", [])) >= 2
77
- else ""
78
- )
79
- # Mistral moderation results
80
- try:
81
- mistral_response = await get_mistral_moderation(user_content, message)
82
- mistral_results = mistral_response
83
- except Exception as e:
84
- print(f"[Mistral moderation error]: {e!s}")
85
- mistral_results = None
86
-
87
  resp = await alinia_guardrail.post(
88
- "/moderations/",
89
  json={
90
- "input": message,
91
  "metadata": {
92
  "app": "slmdr",
93
  "app_environment": "stable",
94
- "chat_model_id": MODEL_ARGS["model"],
95
- "mistral_results": json.loads(
96
- json.dumps(mistral_results, default=str)
97
- ),
98
- }
99
- | metadata,
100
  "detection_config": {"safety": True},
101
  },
102
  )
103
  resp.raise_for_status()
104
- result = resp.json()
105
- selected_results = result["result"]["category_details"]["safety"]
106
- selected_results = {
107
- key.title(): value for key, value in selected_results.items()
108
- }
109
- return selected_results
110
  except Exception as e:
111
- print(f"Safety check error: {e!s}")
112
- return {"Error": str(e)}
 
 
113
 
114
 
115
  def user(message, chat_history):
@@ -118,16 +91,15 @@ def user(message, chat_history):
118
 
119
 
120
  async def assistant(chat_history, system_prompt, model_name):
121
- try:
122
- client = CHAT_CLIENTS[model_name]
123
-
124
- if chat_history[0]["role"] != "system":
125
- chat_history = [{"role": "system", "content": system_prompt}, *chat_history]
126
 
127
- chat_history.append({"role": "assistant", "content": ""})
 
128
 
129
- print(chat_history)
130
 
 
131
  stream = await client.chat.completions.create(
132
  **MODEL_ARGS, messages=chat_history
133
  )
@@ -135,17 +107,18 @@ async def assistant(chat_history, system_prompt, model_name):
135
  async for chunk in stream:
136
  if chunk.choices[0].delta.content is not None:
137
  chat_history[-1]["content"] += chunk.choices[0].delta.content
138
- yield chat_history, ""
139
-
140
- # metadata = {
141
- # "messages": chat_history + [{"role": "assistant", "content": full_response}]
142
- # }
143
- safety_results = await check_safety(chat_history[-1]["content"], {})
144
- yield chat_history, safety_results
145
 
 
 
 
 
 
 
146
  except Exception as e:
147
- chat_history.append({"role": "assistant", "content": f"Error occurred: {e!s}"})
148
- yield chat_history, ""
 
149
 
150
 
151
  with gr.Blocks(title="🦎 Salamandra & Oranguten") as demo:
@@ -187,12 +160,12 @@ with gr.Blocks(title="🦎 Salamandra & Oranguten") as demo:
187
  outputs=[chatbot, response_safety],
188
  )
189
 
 
190
  system_prompt_selector.change(
191
  lambda example_name: EXAMPLE_PROMPTS[example_name],
192
  inputs=system_prompt_selector,
193
  outputs=system_prompt,
194
  )
195
-
196
  system_prompt.change(lambda: [], outputs=chatbot)
197
 
198
  new_chat.click(
 
 
1
  import json
2
  import os
3
 
 
40
  mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
41
 
42
 
43
async def mistral_moderate(chat_history):
    """Score the latest exchange with Mistral's chat moderation classifier.

    Two requests are made against ``mistral-moderation-latest``:

    * ``full_interaction`` -- the last two messages of ``chat_history``.
    * ``user_only`` -- the second-to-last message on its own.

    Args:
        chat_history: Sequence of chat messages. The caller in this file
            appends the (streamed) assistant reply before calling, so the
            last item is the assistant turn and the one before it is the
            user turn. Presumably each item is a ``{"role", "content"}``
            dict -- confirm against what the Mistral client accepts.

    Returns:
        On success, a dict that may contain ``"full_interaction"`` and/or
        ``"user_only"``, each holding the ``category_scores`` of the first
        result of the corresponding request; a key is omitted when its
        request returned no results (or, for ``"user_only"``, when there is
        no second-to-last message). On failure, the error message string
        that was also printed. Exceptions are never propagated.
    """
    try:
        response_full = await mistral_client.classifiers.moderate_chat_async(
            model="mistral-moderation-latest", inputs=chat_history[-2:]
        )

        # The assistant reply is the LAST message, so the user turn is the
        # one right before it. Slicing with [-1:] here would re-moderate the
        # assistant reply and mislabel it as "user_only".
        user_turn = chat_history[-2:-1]
        response_user_only = None
        if user_turn:
            response_user_only = (
                await mistral_client.classifiers.moderate_chat_async(
                    model="mistral-moderation-latest", inputs=user_turn
                )
            )

        result = {}
        if response_full.results:
            result["full_interaction"] = response_full.results[0].category_scores
        if response_user_only is not None and response_user_only.results:
            result["user_only"] = response_user_only.results[0].category_scores
    except Exception as e:
        # Best-effort: moderation failure must not break the chat flow, so
        # the error text is returned in place of the scores.
        message = f"Mistral moderate failed: {e!s}"
        print(message)
        result = message
    return result
61
 
62
 
63
async def alinia_moderate(chat_history, chat_model_id, mistral_moderation) -> dict:
    """Ask the Alinia guardrail service for safety scores on the last exchange.

    Posts the last two messages of ``chat_history`` to ``/chat/moderations``
    together with app metadata, the chat model id and the JSON-serialised
    Mistral moderation output, with safety detection enabled.

    Args:
        chat_history: Sequence of chat messages; only the last two are sent.
        chat_model_id: Identifier of the chat model, forwarded as metadata.
        mistral_moderation: Output of the Mistral moderation step; it is
            serialised with ``json.dumps(..., default=str)`` before sending.

    Returns:
        The ``result.category_details.safety`` mapping from the response with
        every key title-cased, or ``{"Error": 1}`` if anything fails (the
        failure is printed, not raised).
    """
    try:
        payload = {
            "messages": chat_history[-2:],
            "metadata": {
                "app": "slmdr",
                "app_environment": "stable",
                "chat_model_id": chat_model_id,
                "mistral_results": json.dumps(mistral_moderation, default=str),
            },
            "detection_config": {"safety": True},
        }
        response = await alinia_guardrail.post("/chat/moderations", json=payload)
        response.raise_for_status()

        safety_scores = response.json()["result"]["category_details"]["safety"]
        return {
            category.title(): score for category, score in safety_scores.items()
        }
    except Exception as e:
        # Any failure (building the payload, transport, HTTP status, or an
        # unexpected response shape) collapses to a sentinel result.
        print(f"Alinia moderate failed: {e!s}")
        return {"Error": 1}
86
 
87
 
88
  def user(message, chat_history):
 
91
 
92
 
93
  async def assistant(chat_history, system_prompt, model_name):
94
+ client = CHAT_CLIENTS[model_name]
95
+ alinia_moderation = {}
 
 
 
96
 
97
+ if chat_history[0]["role"] != "system":
98
+ chat_history = [{"role": "system", "content": system_prompt}, *chat_history]
99
 
100
+ chat_history.append({"role": "assistant", "content": ""})
101
 
102
+ try:
103
  stream = await client.chat.completions.create(
104
  **MODEL_ARGS, messages=chat_history
105
  )
 
107
  async for chunk in stream:
108
  if chunk.choices[0].delta.content is not None:
109
  chat_history[-1]["content"] += chunk.choices[0].delta.content
110
+ yield chat_history, alinia_moderation
 
 
 
 
 
 
111
 
112
+ mistral_moderation = await mistral_moderate(chat_history)
113
+ alinia_moderation = await alinia_moderate(
114
+ chat_history,
115
+ chat_model_id=model_name,
116
+ mistral_moderation=mistral_moderation,
117
+ )
118
  except Exception as e:
119
+ chat_history[-1]["content"] = f"Error occurred: {e!s}"
120
+
121
+ yield chat_history, alinia_moderation
122
 
123
 
124
  with gr.Blocks(title="🦎 Salamandra & Oranguten") as demo:
 
160
  outputs=[chatbot, response_safety],
161
  )
162
 
163
+ model_selector.change(lambda: [], outputs=chatbot)
164
  system_prompt_selector.change(
165
  lambda example_name: EXAMPLE_PROMPTS[example_name],
166
  inputs=system_prompt_selector,
167
  outputs=system_prompt,
168
  )
 
169
  system_prompt.change(lambda: [], outputs=chatbot)
170
 
171
  new_chat.click(