Mina Parham committed on
Commit 38366a7 · 1 Parent(s): 55aae7a
Initial commit

Files changed:
- app.py +136 -63
- requirements.txt +54 -1
app.py
CHANGED
@@ -1,64 +1,137 @@
 import gradio as gr
-"""
(the other 62 removed lines of the old file were blank)
+import requests
+import asyncio
+import aiohttp
+
+# Models setup
+models = {
+    "Mistral-7B-Instruct": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
+    "DeepSeek-7B-Instruct": "https://api-inference.huggingface.co/models/deepseek-ai/deepseek-llm-7b-instruct",
+    "Qwen-7B-Chat": "https://api-inference.huggingface.co/models/Qwen/Qwen-7B-Chat"
+}
+
+# Judge model (Mixtral-8x7B)
+judge_model_url = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+# Your Hugging Face API Token
+API_TOKEN = "YOUR_HUGGINGFACE_API_TOKEN"
+HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
+
+# Async function to call a model
+async def query_model(session, model_name, question):
+    payload = {"inputs": question, "parameters": {"max_new_tokens": 300}}
+    try:
+        async with session.post(models[model_name], headers=HEADERS, json=payload, timeout=60) as response:
+            result = await response.json()
+            if isinstance(result, list) and len(result) > 0:
+                return model_name, result[0]["generated_text"]
+            elif isinstance(result, dict) and "generated_text" in result:
+                return model_name, result["generated_text"]
+            else:
+                return model_name, str(result)
+    except Exception as e:
+        return model_name, f"Error: {str(e)}"
+
+# Async function to call all models
+async def gather_model_answers(question):
+    async with aiohttp.ClientSession() as session:
+        tasks = [query_model(session, model_name, question) for model_name in models]
+        results = await asyncio.gather(*tasks)
+        return dict(results)
+
+# Function to ask the judge
+def judge_best_answer(question, answers):
+    # Format the prompt for the Judge
+    judge_prompt = f"""
+You are a wise AI Judge. A user asked the following question:
+
+Question:
+{question}
+
+Here are the answers provided by different models:
+
+Answer 1 (Mistral-7B-Instruct):
+{answers['Mistral-7B-Instruct']}
+
+Answer 2 (DeepSeek-7B-Instruct):
+{answers['DeepSeek-7B-Instruct']}
+
+Answer 3 (Qwen-7B-Chat):
+{answers['Qwen-7B-Chat']}
+
+Please carefully read all three answers. Your job:
+- Pick the best answer (Answer 1, Answer 2, or Answer 3).
+- Explain briefly why you chose that answer.
+
+Respond in this JSON format:
+{{"best_answer": "Answer X", "reason": "Your reasoning here"}}
+""".strip()
+
+    payload = {"inputs": judge_prompt, "parameters": {"max_new_tokens": 300}}
+    response = requests.post(judge_model_url, headers=HEADERS, json=payload)
+
+    if response.status_code == 200:
+        result = response.json()
+        # Try to extract JSON from response
+        import json
+        import re
+
+        # Attempt to extract JSON block
+        match = re.search(r"\{.*\}", str(result))
+        if match:
+            try:
+                judge_decision = json.loads(match.group(0))
+                return judge_decision
+            except json.JSONDecodeError:
+                return {"best_answer": "Unknown", "reason": "Failed to parse judge output."}
+        else:
+            return {"best_answer": "Unknown", "reason": "No JSON found in judge output."}
+    else:
+        return {"best_answer": "Unknown", "reason": f"Judge API error: {response.status_code}"}
+
+# Final app logic
+def multi_model_qa(question):
+    answers = asyncio.run(gather_model_answers(question))
+    judge_decision = judge_best_answer(question, answers)
+
+    # Find the selected best answer
+    best_answer_key = judge_decision.get("best_answer", "")
+    best_answer_text = ""
+    if "1" in best_answer_key:
+        best_answer_text = answers["Mistral-7B-Instruct"]
+    elif "2" in best_answer_key:
+        best_answer_text = answers["DeepSeek-7B-Instruct"]
+    elif "3" in best_answer_key:
+        best_answer_text = answers["Qwen-7B-Chat"]
+    else:
+        best_answer_text = "Could not determine best answer."
+
+    return (
+        answers["Mistral-7B-Instruct"],
+        answers["DeepSeek-7B-Instruct"],
+        answers["Qwen-7B-Chat"],
+        best_answer_text,
+        judge_decision.get("reason", "No reasoning provided.")
+    )
+
+# Gradio UI
+with gr.Blocks() as demo:
+    gr.Markdown("# Multi-Model Answer Aggregator")
+    gr.Markdown("Ask any question. The system queries multiple models and the AI Judge selects the best answer.")
+
+    question_input = gr.Textbox(label="Enter your question", placeholder="Ask me anything...", lines=2)
+    submit_btn = gr.Button("Get Best Answer")
+
+    mistral_output = gr.Textbox(label="Mistral-7B-Instruct Answer")
+    deepseek_output = gr.Textbox(label="DeepSeek-7B-Instruct Answer")
+    qwen_output = gr.Textbox(label="Qwen-7B-Chat Answer")
+    best_answer_output = gr.Textbox(label="Best Answer Selected")
+    judge_reasoning_output = gr.Textbox(label="Judge's Reasoning")
+
+    submit_btn.click(
+        multi_model_qa,
+        inputs=[question_input],
+        outputs=[mistral_output, deepseek_output, qwen_output, best_answer_output, judge_reasoning_output]
+    )
+
+demo.launch()
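Two hardening notes on the committed app.py, with a minimal sketch below. First, the API token is hard-coded; on Hugging Face Spaces a token is normally stored as a Space secret and read from the environment. Second, re.search(r"\{.*\}", ...) runs without re.DOTALL, so a judge reply whose JSON spans multiple lines would not match. The sketch assumes an environment variable named HF_API_TOKEN; that name, and the helper extract_judge_json, are illustrative and not part of this commit.

import os
import json
import re

# Read the token from a Space secret / environment variable instead of
# hard-coding it (HF_API_TOKEN is an assumed name, not from the commit).
API_TOKEN = os.environ.get("HF_API_TOKEN", "")
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}

def extract_judge_json(raw: str) -> dict:
    # re.DOTALL lets "." match newlines, so a pretty-printed JSON object
    # in the judge's reply is still captured as a single block.
    match = re.search(r"\{.*\}", raw, re.DOTALL)
    if not match:
        return {"best_answer": "Unknown", "reason": "No JSON found in judge output."}
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return {"best_answer": "Unknown", "reason": "Failed to parse judge output."}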
requirements.txt
CHANGED
@@ -1 +1,54 @@
-
(the old file's single removed line was blank)
+aiofiles==24.1.0
+annotated-types==0.7.0
+anyio==4.9.0
+audioop-lts==0.2.1
+certifi==2025.4.26
+charset-normalizer==3.4.1
+click==8.1.8
+fastapi==0.115.12
+ffmpy==0.5.0
+filelock==3.18.0
+fsspec==2025.3.2
+gradio==5.27.0
+gradio_client==1.9.0
+groovy==0.1.2
+h11==0.16.0
+httpcore==1.0.9
+httpx==0.28.1
+huggingface-hub==0.30.2
+idna==3.10
+Jinja2==3.1.6
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+mdurl==0.1.2
+numpy==2.2.5
+orjson==3.10.16
+packaging==25.0
+pandas==2.2.3
+pillow==11.2.1
+pydantic==2.11.3
+pydantic_core==2.33.1
+pydub==0.25.1
+Pygments==2.19.1
+python-dateutil==2.9.0.post0
+python-multipart==0.0.20
+pytz==2025.2
+PyYAML==6.0.2
+requests==2.32.3
+rich==14.0.0
+ruff==0.11.7
+safehttpx==0.1.6
+semantic-version==2.10.0
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+starlette==0.46.2
+tomlkit==0.13.2
+tqdm==4.67.1
+typer==0.15.2
+typing-inspection==0.4.0
+typing_extensions==4.13.2
+tzdata==2025.2
+urllib3==2.4.0
+uvicorn==0.34.2
+websockets==15.0.1
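One gap worth flagging: app.py imports aiohttp, but no aiohttp pin appears in the list above, so installing these requirements into a fresh environment would leave the app failing at import time. A pin along the lines below would close the gap; the exact version is illustrative, not from the commit.

aiohttp==3.9.5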