import os, asyncio, logging

import gradio as gr
from huggingface_hub import InferenceClient

# Absolute import: a relative import (`from .prompt import ...`) would fail when
# this file is run directly as a script (e.g. `python app.py` on Spaces).
from prompt import build_prompt
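
# `prompt` is a sibling module not shown here. For reference, a minimal sketch of
# the contract this file relies on (illustrative only; the real template in
# prompt.py may differ):
#
#     def build_prompt(query: str, context: str) -> str:
#         return (
#             "Answer the question using only the context below.\n\n"
#             f"Context:\n{context}\n\n"
#             f"Question: {query}\nAnswer:"
#         )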
# ---------------------------------------------------------------------
# model / client initialisation
# ---------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = os.getenv("MODEL_ID", "meta-llama/Meta-Llama-3-8B-Instruct")
MAX_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
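
# Typical environment setup before launching (values below are placeholders,
# not real credentials):
#
#     export HF_TOKEN=hf_xxx            # required
#     export MODEL_ID=meta-llama/Meta-Llama-3-8B-Instruct
#     export MAX_NEW_TOKENS=512
#     export TEMPERATURE=0.2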
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN environment variable is missing. "
        "Set it to a Hugging Face access token with Inference API access."
    )
client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)

# ---------------------------------------------------------------------
# Core generation function for both Gradio UI and MCP
# ---------------------------------------------------------------------
async def _call_llm(prompt: str) -> str:
    """
    Try text_generation first (for models/providers that still support it);
    fall back to chat_completion when the provider is chat-only (Novita, etc.).
    """
    try:
        # Preferred path: the classic text-generation endpoint (hf-inference).
        # InferenceClient is synchronous, so run it in a worker thread to keep
        # the event loop responsive.
        return await asyncio.to_thread(
            client.text_generation,
            prompt,
            max_new_tokens=MAX_TOKENS,
            temperature=TEMPERATURE,
        )
    except ValueError as e:
        if "Supported task: conversational" not in str(e):
            raise  # genuine error → bubble up
        # Fallback for chat-only providers (e.g. Novita): wrap the prompt in a
        # single user message and use the chat-completion API instead.
        messages = [{"role": "user", "content": prompt}]
        completion = await asyncio.to_thread(
            client.chat_completion,
            messages=messages,
            model=MODEL_ID,
            max_tokens=MAX_TOKENS,
            temperature=TEMPERATURE,
        )
        return completion.choices[0].message.content.strip()

async def rag_generate(query: str, context: str) -> str:
    """
    Generate an answer to a query using provided context through RAG.

    This function takes a user query and relevant context, then uses a language
    model to generate a comprehensive answer based on the provided information.

    Args:
        query (str): The user's question or query
        context (str): The relevant context/documents to use for answering

    Returns:
        str: The generated answer based on the query and context
    """
    if not query.strip():
        return "Error: Query cannot be empty"
    if not context.strip():
        return "Error: Context cannot be empty"

    prompt = build_prompt(query, context)
    try:
        answer = await _call_llm(prompt)
        return answer
    except Exception as e:
        logging.exception("Generation failed")
        return f"Error: {str(e)}"

# ---------------------------------------------------------------------
# Gradio Interface with MCP support
# ---------------------------------------------------------------------
ui = gr.Interface(
    fn=rag_generate,
    inputs=[
        gr.Textbox(
            label="Query",
            lines=2,
            placeholder="What would you like to know?",
            info="Enter your question here",
        ),
        gr.Textbox(
            label="Context",
            lines=8,
            placeholder="Paste relevant documents or context here...",
            info="Provide the context/documents to use for answering",
        ),
    ],
    outputs=gr.Textbox(
        label="Generated Answer",
        lines=6,
        show_copy_button=True,
    ),
    title="RAG Generation Service",
    description=(
        "Ask questions and get answers based on your provided context. "
        "This service is also available as an MCP server for integration "
        "with AI applications."
    ),
    examples=[
        [
            "What is the main benefit mentioned?",
            "Machine learning has revolutionized many industries. The main benefit is increased efficiency and accuracy in data processing.",
        ],
        [
            "Who is the CEO?",
            "Company ABC was founded in 2020. The current CEO is Jane Smith, who has led the company to significant growth.",
        ],
    ],
)

# Launch with the MCP server enabled
if __name__ == "__main__":
    ui.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
        show_error=True,
    )
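
# With mcp_server=True, Gradio also serves rag_generate as an MCP tool, deriving
# the tool schema from the function's docstring and type hints. An MCP client that
# supports SSE transports would typically be configured roughly like this (the
# endpoint path follows current Gradio docs and may change between versions):
#
#     {
#       "mcpServers": {
#         "rag-generation": {
#           "url": "http://localhost:7860/gradio_api/mcp/sse"
#         }
#       }
#     }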