"""
License:
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
In no event shall the authors or copyright holders be liable
for any claim, damages or other liability, whether in an action of contract,otherwise,
arising from, out of or in connection with the software or the use or 
other dealings in the software.

Copyright (c) 2024 pi19404. All rights reserved.

Authors:
    pi19404 <[email protected]>
"""


"""
Gradio Interface for Shield Gemma LLM Evaluator

This module provides a Gradio interface to interact with the Shield Gemma LLM Evaluator.
It allows users to input JSON data and select various options to evaluate the content
for policy violations.

Functions:
    my_inference_function: The main inference function to process input data and return results.
"""

import json
import os
import threading

import gradio as gr
from gradio_client import Client

# Hugging Face token for the remote worker Space, supplied via environment.
API_TOKEN = os.getenv("API_TOKEN")

# Serialize requests to the remote worker: gradio_client calls are issued
# one at a time under this lock.
lock = threading.Lock()
client = Client("pi19404/ai-worker", hf_token=API_TOKEN)

def my_inference_function(input_data, output_data, mode, max_length, max_new_tokens, model_size):
    """
    The main inference function to process input data and return results.

    Args:
        input_data (str or dict): The input data in JSON format.
        output_data (str): The model response text to be evaluated alongside
            the input (the "Response Text" field in the UI).
        mode (str): The mode of operation ("scoring" or "generative").
        max_length (int): The maximum length of the input prompt.
        max_new_tokens (int): The maximum number of new tokens to generate.
        model_size (str): The size of the model to be used ("2B", "9B" or "27B").

    Returns:
        str: The output data in JSON format.
    """
    with lock:
        try:
            # Forward the request to the remote ai-worker Space.
            result = client.predict(
                input_data=input_data,
                output_data=output_data,
                mode=mode,
                max_length=max_length,
                max_new_tokens=max_new_tokens,
                model_size=model_size,
                api_name="/my_inference_function",
            )
            return result
        except json.JSONDecodeError:
            return json.dumps({"error": "Invalid JSON input"})
        except KeyError:
            return json.dumps({"error": "Missing 'input' key in JSON"})
        except ValueError as e:
            return json.dumps({"error": str(e)})
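
# Example direct call with illustrative values (a valid API_TOKEN for the
# worker Space is required when this module is imported):
#
#     result = my_inference_function(
#         input_data="How do I pick a lock?",
#         output_data="Sorry, I can't help with that.",
#         mode="scoring",
#         max_length=150,
#         max_new_tokens=1024,
#         model_size="2B",
#     )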

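# Build the UI: a single tab pairing prompt/response inputs with evaluation
# options and an output box for the evaluator's verdict.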
with gr.Blocks() as demo:
    gr.Markdown("## LLM Safety Evaluation")
    
    with gr.Tab("ShieldGemma2"):
        input_text = gr.Textbox(label="Input Text")
        output_text = gr.Textbox(
            label="Response Text",
            lines=5,
            max_lines=10,
            show_copy_button=True,
            elem_classes=["wrap-text"]
        )
        mode_input = gr.Dropdown(choices=["scoring", "generative"], label="Prediction Mode")
        max_length_input = gr.Number(label="Max Length", value=150)
        max_new_tokens_input = gr.Number(label="Max New Tokens", value=1024)
        model_size_input = gr.Dropdown(choices=["2B", "9B", "27B"], label="Model Size")
        response_text = gr.Textbox(
            label="Output Text",
            lines=10,
            max_lines=20,
            show_copy_button=True,
            elem_classes=["wrap-text"]
        )
        text_button = gr.Button("Submit")
        text_button.click(
            fn=my_inference_function,
            inputs=[
                input_text,
                output_text,
                mode_input,
                max_length_input,
                max_new_tokens_input,
                model_size_input,
            ],
            outputs=response_text,
        )
    
    # with gr.Tab("API Input"):
    #     api_input = gr.JSON(label="Input JSON")
    #     mode_input_api = gr.Dropdown(choices=["scoring", "generative"], label="Mode")
    #     max_length_input_api = gr.Number(label="Max Length", value=150)
    #     max_new_tokens_input_api = gr.Number(label="Max New Tokens", value=None)
    #     model_size_input_api = gr.Dropdown(choices=["2B", "9B", "27B"], label="Model Size")
    #     api_output = gr.JSON(label="Output JSON")
    #     api_button = gr.Button("Submit")
    #     api_button.click(fn=my_inference_function, inputs=[api_input, api_output,mode_input_api, max_length_input_api, max_new_tokens_input_api, model_size_input_api], outputs=api_output)

demo.launch(share=True)
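
# Once the app is running, another process can drive it with gradio_client.
# A minimal sketch, assuming the default local port and that Gradio exposes
# the endpoint under the function name (confirm via the page's "Use via API"
# link):
#
#     from gradio_client import Client
#
#     remote = Client("http://127.0.0.1:7860/")
#     print(remote.predict(
#         input_data="How do I pick a lock?",
#         output_data="",
#         mode="scoring",
#         max_length=150,
#         max_new_tokens=1024,
#         model_size="2B",
#         api_name="/my_inference_function",
#     ))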