Mina Parham committed
Commit 38366a7 · 1 Parent(s): 55aae7a

Initial commit 🚀

Files changed (2)
  1. app.py +136 -63
  2. requirements.txt +54 -1
app.py CHANGED
@@ -1,64 +1,137 @@
import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
-     demo.launch()
+ import requests
+ import asyncio
+ import aiohttp
+
+ # Models setup
+ models = {
+     "Mistral-7B-Instruct": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
+     "DeepSeek-7B-Instruct": "https://api-inference.huggingface.co/models/deepseek-ai/deepseek-llm-7b-instruct",
+     "Qwen-7B-Chat": "https://api-inference.huggingface.co/models/Qwen/Qwen-7B-Chat"
+ }
+
+ # Judge model (Mixtral-8x7B)
+ judge_model_url = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+ # Your Hugging Face API Token
+ API_TOKEN = "YOUR_HUGGINGFACE_API_TOKEN"
+ HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
+
+ # Async function to call a model
+ async def query_model(session, model_name, question):
+     payload = {"inputs": question, "parameters": {"max_new_tokens": 300}}
+     try:
+         async with session.post(models[model_name], headers=HEADERS, json=payload, timeout=60) as response:
+             result = await response.json()
+             if isinstance(result, list) and len(result) > 0:
+                 return model_name, result[0]["generated_text"]
+             elif isinstance(result, dict) and "generated_text" in result:
+                 return model_name, result["generated_text"]
+             else:
+                 return model_name, str(result)
+     except Exception as e:
+         return model_name, f"Error: {str(e)}"
+
+ # Async function to call all models
+ async def gather_model_answers(question):
+     async with aiohttp.ClientSession() as session:
+         tasks = [query_model(session, model_name, question) for model_name in models]
+         results = await asyncio.gather(*tasks)
+         return dict(results)
+
+ # Function to ask the judge
+ def judge_best_answer(question, answers):
+     # Format the prompt for the Judge
+     judge_prompt = f"""
+ You are a wise AI Judge. A user asked the following question:
+
+ Question:
+ {question}
+
+ Here are the answers provided by different models:
+
+ Answer 1 (Mistral-7B-Instruct):
+ {answers['Mistral-7B-Instruct']}
+
+ Answer 2 (DeepSeek-7B-Instruct):
+ {answers['DeepSeek-7B-Instruct']}
+
+ Answer 3 (Qwen-7B-Chat):
+ {answers['Qwen-7B-Chat']}
+
+ Please carefully read all three answers. Your job:
+ - Pick the best answer (Answer 1, Answer 2, or Answer 3).
+ - Explain briefly why you chose that answer.
+
+ Respond in this JSON format:
+ {{"best_answer": "Answer X", "reason": "Your reasoning here"}}
+ """.strip()
+
+     payload = {"inputs": judge_prompt, "parameters": {"max_new_tokens": 300}}
+     response = requests.post(judge_model_url, headers=HEADERS, json=payload)
+
+     if response.status_code == 200:
+         result = response.json()
+         # Try to extract JSON from response
+         import json
+         import re
+
+         # Attempt to extract JSON block
+         match = re.search(r"\{.*\}", str(result))
+         if match:
+             try:
+                 judge_decision = json.loads(match.group(0))
+                 return judge_decision
+             except json.JSONDecodeError:
+                 return {"best_answer": "Unknown", "reason": "Failed to parse judge output."}
+         else:
+             return {"best_answer": "Unknown", "reason": "No JSON found in judge output."}
+     else:
+         return {"best_answer": "Unknown", "reason": f"Judge API error: {response.status_code}"}
+
+ # Final app logic
+ def multi_model_qa(question):
+     answers = asyncio.run(gather_model_answers(question))
+     judge_decision = judge_best_answer(question, answers)
+
+     # Find the selected best answer
+     best_answer_key = judge_decision.get("best_answer", "")
+     best_answer_text = ""
+     if "1" in best_answer_key:
+         best_answer_text = answers["Mistral-7B-Instruct"]
+     elif "2" in best_answer_key:
+         best_answer_text = answers["DeepSeek-7B-Instruct"]
+     elif "3" in best_answer_key:
+         best_answer_text = answers["Qwen-7B-Chat"]
+     else:
+         best_answer_text = "Could not determine best answer."
+
+     return (
+         answers["Mistral-7B-Instruct"],
+         answers["DeepSeek-7B-Instruct"],
+         answers["Qwen-7B-Chat"],
+         best_answer_text,
+         judge_decision.get("reason", "No reasoning provided.")
+     )
+
+ # Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🧠 Multi-Model Answer Aggregator")
+     gr.Markdown("Ask any question. The system queries multiple models and the AI Judge selects the best answer.")
+
+     question_input = gr.Textbox(label="Enter your question", placeholder="Ask me anything...", lines=2)
+     submit_btn = gr.Button("Get Best Answer")
+
+     mistral_output = gr.Textbox(label="Mistral-7B-Instruct Answer")
+     deepseek_output = gr.Textbox(label="DeepSeek-7B-Instruct Answer")
+     qwen_output = gr.Textbox(label="Qwen-7B-Chat Answer")
+     best_answer_output = gr.Textbox(label="🏆 Best Answer Selected")
+     judge_reasoning_output = gr.Textbox(label="⚖️ Judge's Reasoning")
+
+     submit_btn.click(
+         multi_model_qa,
+         inputs=[question_input],
+         outputs=[mistral_output, deepseek_output, qwen_output, best_answer_output, judge_reasoning_output]
+     )
+
+ demo.launch()
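
Note: the committed app.py hardcodes a placeholder API token. A minimal sketch of loading it from the environment instead (the variable name HF_API_TOKEN and the failure behavior are assumptions, not part of this commit):

import os

# Hypothetical alternative to the hardcoded token above:
# read the Hugging Face token from an environment variable.
# HF_API_TOKEN is an assumed name, not defined by this commit.
API_TOKEN = os.environ.get("HF_API_TOKEN", "")
if not API_TOKEN:
    raise RuntimeError("Set HF_API_TOKEN before launching the app.")
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}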
requirements.txt CHANGED
@@ -1 +1,54 @@
- huggingface_hub==0.25.2
+ aiofiles==24.1.0
+ annotated-types==0.7.0
+ anyio==4.9.0
+ audioop-lts==0.2.1
+ certifi==2025.4.26
+ charset-normalizer==3.4.1
+ click==8.1.8
+ fastapi==0.115.12
+ ffmpy==0.5.0
+ filelock==3.18.0
+ fsspec==2025.3.2
+ gradio==5.27.0
+ gradio_client==1.9.0
+ groovy==0.1.2
+ h11==0.16.0
+ httpcore==1.0.9
+ httpx==0.28.1
+ huggingface-hub==0.30.2
+ idna==3.10
+ Jinja2==3.1.6
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ mdurl==0.1.2
+ numpy==2.2.5
+ orjson==3.10.16
+ packaging==25.0
+ pandas==2.2.3
+ pillow==11.2.1
+ pydantic==2.11.3
+ pydantic_core==2.33.1
+ pydub==0.25.1
+ Pygments==2.19.1
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.20
+ pytz==2025.2
+ PyYAML==6.0.2
+ requests==2.32.3
+ rich==14.0.0
+ ruff==0.11.7
+ safehttpx==0.1.6
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ starlette==0.46.2
+ tomlkit==0.13.2
+ tqdm==4.67.1
+ typer==0.15.2
+ typing-inspection==0.4.0
+ typing_extensions==4.13.2
+ tzdata==2025.2
+ urllib3==2.4.0
+ uvicorn==0.34.2
+ websockets==15.0.1
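
Note: app.py imports aiohttp, which does not appear in this pinned list, so the Space would fail at import time. Assuming the standard PyPI package, requirements.txt would also need a line like the following (the exact version pin is an assumption, not part of this commit):

aiohttp==3.11.18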