# Call_model / app.py
# Author: disLodge · commit 6acf34c ("fx", verified)
import gradio as gr
import requests
from pdfminer.high_level import extract_text
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, ChatHuggingFace
from langchain_core.runnables import RunnablePassthrough, Runnable
from io import BytesIO
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from huggingface_hub import InferenceClient
import time
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import logging
import os
# Hugging Face API token, read from the environment only. Never embed a token
# literal in source code: anyone who can read this file could use it.
# NOTE(review): a token was previously committed here; revoke it on huggingface.co.
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not HF_TOKEN:
    # Presumably InferenceClient then falls back to its own token discovery
    # (cached login / HF_TOKEN) — verify against the huggingface_hub docs.
    logging.getLogger(__name__).warning(
        "HUGGINGFACEHUB_API_TOKEN is not set; calling the Inference API without an explicit token."
    )

# Chat-completion client for the hosted Mixtral instruct model.
client = InferenceClient(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    token=HF_TOKEN,
)
class HuggingFaceInferenceClientRunnable(Runnable):
    """LangChain ``Runnable`` backed by a Hugging Face ``InferenceClient``.

    ``invoke`` converts the incoming prompt into chat-completion messages,
    streams a completion from ``client.chat_completion`` and returns the
    concatenated text. Sampling settings are stored on the instance and can be
    changed between calls with ``update_params``.

    Args:
        client: object exposing ``chat_completion(messages, max_tokens=...,
            stream=True, temperature=..., top_p=...)`` whose stream yields
            chunks with ``.choices[i].delta.content`` (an ``InferenceClient``
            in this file).
        max_tokens: maximum number of tokens to generate.
        temperature: sampling temperature.
        top_p: nucleus-sampling probability mass.
    """

    # LangChain message ``type`` -> chat-completion ``role``.
    _ROLE_BY_TYPE = {"human": "user", "ai": "assistant", "system": "system"}

    def __init__(self, client, max_tokens=512, temperature=0.7, top_p=0.95):
        self.client = client
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p

    def _to_chat_messages(self, prompt):
        """Convert a prompt value (or a plain string) into chat-completion messages.

        Every message of the prompt is forwarded, with its role mapped through
        ``_ROLE_BY_TYPE``, so templates containing system or multi-turn
        messages are sent in full rather than truncated to the first message.
        Message types not in the map use the message's own ``role`` attribute
        if it has one, otherwise ``"user"``.
        """
        if isinstance(prompt, str):
            return [{"role": "user", "content": prompt}]
        return [
            {
                "role": self._ROLE_BY_TYPE.get(message.type, getattr(message, "role", "user")),
                "content": message.content,
            }
            for message in prompt.to_messages()
        ]

    # NOTE(review): only requests' ConnectionError/Timeout trigger a retry;
    # confirm these are what the installed huggingface_hub raises on network
    # failures. A retry re-runs the whole request, discarding partial output.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((requests.exceptions.ConnectionError, requests.exceptions.Timeout)),
    )
    def invoke(self, input, config=None, **kwargs):
        """Return the model's complete reply to ``input``.

        Args:
            input: a LangChain prompt value (anything with ``to_messages()``,
                e.g. the output of a ``ChatPromptTemplate``) or a plain string.
            config: accepted for the ``Runnable`` interface; unused.
            **kwargs: accepted for the ``Runnable`` interface; unused.

        Returns:
            The streamed completion text joined into a single string.
        """
        messages = self._to_chat_messages(input)
        stream = self.client.chat_completion(
            messages,
            max_tokens=self.max_tokens,
            stream=True,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        pieces = []
        for chunk in stream:
            # A chunk may carry zero or more choices; delta.content may be None/empty.
            for choice in chunk.choices:
                token = choice.delta.content
                if token:
                    pieces.append(token)
        return "".join(pieces)

    def update_params(self, max_tokens, temperature, top_p):
        """Replace the sampling settings used by subsequent ``invoke`` calls."""
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
def extract_pdf_text(url: str, timeout: float = 60.0) -> str:
    """Download the PDF at ``url`` and return its extracted plain text.

    Args:
        url: HTTP(S) URL of a PDF document.
        timeout: seconds to wait for the server, passed to ``requests.get``.
            Without a timeout the request could block indefinitely.

    Returns:
        The text pdfminer extracts from the downloaded bytes.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status. This
            surfaces the real failure instead of passing an error page to the
            PDF parser.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return extract_text(BytesIO(response.content))
# ---------------------------------------------------------------------------
# Knowledge base: download one PDF at import time and index it in Chroma.
# ---------------------------------------------------------------------------
pdf_url = "https://arxiv.org/pdf/2408.09869"
text = extract_pdf_text(pdf_url)

# The whole PDF starts as a single Document. It is then split by token count
# (tiktoken), targeting 7500-token chunks that overlap by 100 tokens.
docs_list = [Document(page_content=text)]
text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=7500,
    chunk_overlap=100,
)
docs_splits = text_splitter.split_documents(docs_list)

# NOTE(review): chunks this large may exceed the embedding model's maximum
# input length — confirm retrieval quality, or lower chunk_size.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)
vectorstore = Chroma.from_documents(
    embedding=embeddings,
    documents=docs_splits,
    collection_name="rag-chroma",
)
retriever = vectorstore.as_retriever()

# One LLM wrapper shared by every request; process_query updates its sampling
# parameters before each chain invocation.
llm = HuggingFaceInferenceClientRunnable(client)
# Prompt for the summary step that runs after retrieval ("after RAG").
# Template variables: {role} (used twice) and {context} (filled from the retriever).
after_rag_template = (
    "You are a {role}. Summarize the following content for yourself and speak in terms of first person.\n"
    "Only include content relevant to that role like a resume summary.\n"
    "Context:\n"
    "{context}\n"
    "Question: Give a one paragraph summary of the key skills a {role} can have from this document.\n"
)
after_rag_prompt = ChatPromptTemplate.from_template(after_rag_template)
def format_query(input_dict):
    """Return the retrieval query for the role stored under ``input_dict["role"]``.

    The wording is the same as the "Question:" line of the summary prompt.

    Raises:
        KeyError: if ``input_dict`` has no ``"role"`` key.
    """
    query_template = "Give a one paragraph summary of the key skills a {} can have from this document."
    return query_template.format(input_dict["role"])
def _format_docs(docs):
    """Join the ``page_content`` of retrieved documents into one context string.

    Without this step the prompt's ``{context}`` slot would receive ``str()`` of
    the retriever's list of ``Document`` objects, i.e. their reprs, including
    metadata and escaped newlines, rather than the document text itself.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# RAG chain: {"role": ...} -> retrieve context for the role's query -> fill the
# summary prompt -> call the LLM -> plain string.
after_rag_chain = (
    {
        "context": format_query | retriever | _format_docs,
        "role": lambda x: x["role"],
    }
    | after_rag_prompt
    | llm
    | StrOutputParser()
)
def process_query(role, system_message, max_tokens, temperature, top_p):
    """Gradio submit handler: run the RAG summary chain for ``role``.

    Args:
        role: role name inserted into the retrieval query and prompt.
        system_message: accepted because the UI passes it, but currently unused.
            The summary prompt has no system-message slot.
        max_tokens: generation limit applied to the shared ``llm``.
        temperature: sampling temperature applied to the shared ``llm``.
        top_p: nucleus-sampling value applied to the shared ``llm``.

    Returns:
        A "**RAG Summary**" header line followed by the chain's output.
    """
    # NOTE(review): this mutates the module-level llm shared by all requests;
    # if handlers run concurrently their settings may interleave — confirm.
    llm.update_params(max_tokens, temperature, top_p)
    summary = after_rag_chain.invoke(dict(role=role))
    return "**RAG Summary**\n{}".format(summary)
def _clear_output():
    """Clear-button handler: show a "Chat cleared!" toast and blank the output box.

    The handler has exactly one output component, so it must return a single
    value. Returning a tuple such as ``("", gr.Info(...))`` would make Gradio
    assign the whole tuple to the textbox instead of clearing it.
    """
    gr.Info("Chat cleared!")
    return ""


with gr.Blocks() as demo:
    # NOTE(review): the heading says "Zephyr", but the client in this file
    # targets Mixtral-8x7B-Instruct — confirm the intended label.
    gr.Markdown("## Zephyr Chatbot Controls")

    # Inputs forwarded to process_query, in this order.
    role_dropdown = gr.Dropdown(choices=["SDE", "BA"], label="Select Role", value="SDE")
    # NOTE(review): process_query currently ignores this value.
    system_message = gr.Textbox(value="You are a friendly chatbot.", label="System message")
    max_tokens = gr.Slider(1, 2048, value=512, label="Max tokens")
    temperature = gr.Slider(0.1, 4.0, value=0.7, label="Temperature", step=0.1)
    top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p", step=0.05)

    output = gr.Textbox(label="Output", lines=20)
    submit_btn = gr.Button("Submit")
    clear_btn = gr.Button("Clear")

    submit_btn.click(
        fn=process_query,
        inputs=[role_dropdown, system_message, max_tokens, temperature, top_p],
        outputs=output,
    )
    clear_btn.click(fn=_clear_output, outputs=[output])
def main():
    """Start the Gradio ``demo`` server."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()