import asyncio
import logging
import os

import gradio as gr
from huggingface_hub import InferenceClient
from .prompt import build_prompt

# ---------------------------------------------------------------------
# model / client initialisation
# ---------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = os.getenv("MODEL_ID", "meta-llama/Meta-Llama-3-8B-Instruct")
MAX_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))

if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN env-var missing. Set it to a Hugging Face access token "
        "so the InferenceClient can authenticate."
    )

client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)

# ---------------------------------------------------------------------
# Core generation function for both Gradio UI and MCP
# ---------------------------------------------------------------------
async def _call_llm(prompt: str) -> str:
    """
    Try text_generation first (for models/providers that still support it);
    fall back to chat_completion when the provider is chat-only (Novita, etc.).
    """
    try:
        # Preferred path: plain text_generation (hf-inference provider)
        return await asyncio.to_thread(
            client.text_generation,
            prompt,
            max_new_tokens=MAX_TOKENS,
            temperature=TEMPERATURE,
        )
    except ValueError as e:
        if "Supported task: conversational" not in str(e):
            raise                              # genuine error → bubble up

        # Chat-only providers (e.g. Novita) reject text_generation; retry via chat_completion
        messages = [{"role": "user", "content": prompt}]
        completion = await asyncio.to_thread(
            client.chat_completion,
            messages=messages,
            model=MODEL_ID,
            max_tokens=MAX_TOKENS,
            temperature=TEMPERATURE,
        )
        return completion.choices[0].message.content.strip()

async def rag_generate(query: str, context: str) -> str:
    """
    Generate an answer to a query using provided context through RAG.
    
    This function takes a user query and relevant context, then uses a language model
    to generate a comprehensive answer based on the provided information.
    
    Args:
        query (str): The user's question or query
        context (str): The relevant context/documents to use for answering
        
    Returns:
        str: The generated answer based on the query and context
    """
    if not query.strip():
        return "Error: Query cannot be empty"
    
    if not context.strip():
        return "Error: Context cannot be empty"
    
    prompt = build_prompt(query, context)
    try:
        answer = await _call_llm(prompt)
        return answer
    except Exception as e:
        logging.exception("Generation failed")
        return f"Error: {str(e)}"

# ---------------------------------------------------------------------
# Gradio Interface with MCP support
# ---------------------------------------------------------------------
ui = gr.Interface(
    fn=rag_generate,
    inputs=[
        gr.Textbox(
            label="Query", 
            lines=2, 
            placeholder="What would you like to know?",
            info="Enter your question here"
        ),
        gr.Textbox(
            label="Context", 
            lines=8, 
            placeholder="Paste relevant documents or context here...",
            info="Provide the context/documents to use for answering"
        ),
    ],
    outputs=gr.Textbox(
        label="Generated Answer",
        lines=6,
        show_copy_button=True
    ),
    title="RAG Generation Service",
    description="Ask questions and get answers based on your provided context. This service is also available as an MCP server for integration with AI applications.",
    examples=[
        [
            "What is the main benefit mentioned?",
            "Machine learning has revolutionized many industries. The main benefit is increased efficiency and accuracy in data processing."
        ],
        [
            "Who is the CEO?",
            "Company ABC was founded in 2020. The current CEO is Jane Smith, who has led the company to significant growth."
        ]
    ]
)

# Launch with MCP server enabled
if __name__ == "__main__":
    ui.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
        show_error=True
    )
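
# ---------------------------------------------------------------------
# Hedged usage sketch (not executed by this module): with the app running
# locally, the endpoint can be exercised through gradio_client. The
# api_name below assumes Gradio's default "/predict" route for a
# gr.Interface; adjust it if your Gradio version exposes a different name.
# ---------------------------------------------------------------------
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860/")
#   result = client.predict(
#       "What is the main benefit mentioned?",
#       "Machine learning has revolutionized many industries. The main "
#       "benefit is increased efficiency and accuracy in data processing.",
#       api_name="/predict",
#   )
#   print(result)
#
# With mcp_server=True, Gradio also exposes rag_generate as an MCP tool;
# the SSE endpoint is typically http://localhost:7860/gradio_api/mcp/sse.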