Spaces:
Sleeping
Sleeping
| import json | |
| import requests | |
# Japanese preamble for the grader model: it states the grader role, asks for
# a 1-5 score emitted as a bare number, and lists the base scoring criteria
# and the base deduction items.
PREREQUISITE_PROMPT = "\n".join(
    (
        "あなたは採点者です。",
        "問題, 採点基準, 回答 が与えられます。",
        "回答を1,2,3,4,5の5段階で採点し、数字のみを出力してください。",
        "# 採点基準",
        "基本的な採点基準",
        "- 1点: 誤っている、 指示に従えていない",
        "- 2点: 誤っているが、方向性は合っている",
        "- 3点: 部分的に誤っている、 部分的に合っている",
        "- 4点: 合っている",
        "- 5点: 役に立つ",
        "基本的な減点項目",
        "- 不自然な日本語: -1点",
        "- 部分的に事実と異なる内容を述べている: -1点",
        "",  # keeps the trailing newline of the original literal
    )
)
| def evaluation_prompt( | |
| input: str, output: str, eval_aspect: str | None, target: str | None | |
| ) -> str: | |
| return f"""\ | |
| 回答を1,2,3,4,5の5段階で採点し、数字のみを出力してください。 | |
| # 問題: {input} | |
| {f"# 正解例: {target}" if target is not None else ""} | |
| {f"# 採点基準: {eval_aspect}" if eval_aspect is not None else ""} | |
| # 回答: {output} | |
| """ | |
# Requests made through the Gemini SDK from Gradio never finished, so the
# REST API is called directly instead.
def evaluate(
    results: list[dict],
    api_key: str,
    batch_size: int = 10,
    timeout: float = 120.0,
) -> list[dict]:
    """Score model outputs with Gemini 1.5 Flash, one HTTP request per batch.

    Args:
        results: Items to grade. Each dict must contain "input" and
            "output"; "eval_aspect" and "target" are optional.
        api_key: Gemini API key.
        batch_size: Number of items graded per request. Must be >= 1.
        timeout: Seconds to wait on each HTTP request before giving up, so a
            stalled connection cannot block forever.

    Returns:
        One dict per graded item, in input order, with the keys "input",
        "output", "eval_aspect", "target" and "score".

    Raises:
        ValueError: If ``batch_size`` is smaller than 1, or if the model
            returns a different number of scores than items in the batch.
        Exception: If the API answers with a non-200 status code.
    """
    if batch_size < 1:
        # A negative step would make the batching loop silently do nothing.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
    # The key travels in a header rather than the query string so it does not
    # end up in URLs quoted by connection-error messages or logs.
    headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}

    evaluations = []
    for start in range(0, len(results), batch_size):
        batch_results = results[start : start + batch_size]
        prompts = [
            evaluation_prompt(
                result["input"],
                result["output"],
                result.get("eval_aspect"),
                result.get("target"),
            )
            for result in batch_results
        ]
        # NOTE(review): only the per-item prompts are sent; the module-level
        # PREREQUISITE_PROMPT rubric is not part of the request. Confirm this
        # is intended.
        data = {
            "contents": [{"parts": [{"text": "\n".join(prompts)}]}],
            # Ask for a JSON array of numbers: one score per prompt.
            "generationConfig": {
                "response_mime_type": "application/json",
                "response_schema": {"type": "ARRAY", "items": {"type": "NUMBER"}},
            },
        }
        response = requests.post(
            url, headers=headers, data=json.dumps(data), timeout=timeout
        )
        if response.status_code != 200:
            raise Exception(
                f"API request failed with status code {response.status_code}: {response.text}"
            )

        # The scores arrive as JSON text inside the first candidate's first part.
        response_data = response.json()
        scores = json.loads(
            response_data["candidates"][0]["content"]["parts"][0]["text"]
        )
        if len(scores) != len(batch_results):
            # Pairing mismatched lists would drop items or misattribute scores.
            raise ValueError(
                f"Expected {len(batch_results)} scores but received {len(scores)}"
            )

        evaluations.extend(
            {
                "input": result["input"],
                "output": result["output"],
                "eval_aspect": result.get("eval_aspect"),
                "target": result.get("target"),
                "score": score,
            }
            for result, score in zip(batch_results, scores)
        )
    return evaluations
def report(tasks: list[dict]) -> str:
    """Render the given tasks as a standalone HTML report page.

    The tasks are serialized to JSON and inlined into a ``<script>`` element;
    the page's JavaScript then builds one section per task showing its
    ``input`` and ``output`` under a divider labelled with its ``task_id``
    (or its list index when ``task_id`` is absent).

    Args:
        tasks: JSON-serializable task dicts. The page reads the keys
            "task_id", "input" and "output".

    Returns:
        The complete HTML document as a string.
    """
    # The JSON is inlined into a <script> element, where a literal "</script>"
    # (or "<!--") inside any task text would terminate or corrupt the script.
    # These characters can only occur inside JSON strings, so replacing them
    # with \uXXXX escapes keeps the parsed data identical.
    messages_json = (
        json.dumps(tasks)
        .replace("&", "\\u0026")
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
    )
    return (
        """\
<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>レポート</title>
<style>
body {
background-color: #f8f9fa;
}
.container {
width: 80%; /* 可変幅 */
margin: 20px auto;
background-color: #ffffff;
border-radius: 8px;
}
.divider {
position: relative;
padding: 16px 0;
align-items: center;
justify-content: center;
}
.divider .line {
height: 1px;
background-color: #ddd;
}
.divider .taskName {
position: absolute;
margin: -8px;
left: 50%;
transform: translateX(-50%);
padding: 0 10px;
font-size: 14px;
font-weight: 900;
text-align: center;
border: 1px solid #ddd;
border-radius: 9999px;
background-color: #ffffff;
white-space: nowrap;
}
.message {
padding: 8px;
}
.content {
font-size: 14px;
font-weight: 400;
}
.from {
font-size: 14px;
font-weight: 900;
}
</style>
</head>
<body>
<div class="container" id="container"></div>
<script>
const messages = """
        + messages_json
        + """;
// taskName: str
const createDivider = (taskName) => {
const divider = document.createElement('div');
divider.classList.add('divider');
const line = document.createElement('div');
line.classList.add('line');
const taskNameLabel = document.createElement('div');
taskNameLabel.classList.add('taskName');
taskNameLabel.textContent = taskName;
divider.appendChild(line);
divider.appendChild(taskNameLabel);
return divider;
};
// text: string, name: 'input' | 'output' | string
// return: HTMLDivElement
const createMessage = (text, name) => {
const message = document.createElement('div');
message.classList.add('message');
const from = document.createElement('div');
from.classList.add('from');
from.textContent = name;
const content = document.createElement('div');
content.classList.add('content');
content.textContent = text;
message.appendChild(from);
message.appendChild(content);
return message;
};
const container = document.getElementById('container');
messages.forEach((message, i) => {
const task = document.createElement('div');
task.classList.add('task');
task.appendChild(createDivider(message.task_id != null ? "Task ID: " + message.task_id : "Task Index: " + i));
task.appendChild(createMessage(message.input, 'input'));
task.appendChild(createMessage(message.output, 'output'));
container.appendChild(task);
});
</script>
</body>
</html>
"""
    )