File size: 5,106 Bytes
6f5ab24
 
 
 
 
b884aa9
6f5ab24
b884aa9
6f5ab24
 
 
 
 
b884aa9
 
6f5ab24
b884aa9
52c1cb9
b927386
b884aa9
 
6f5ab24
 
b884aa9
6f5ab24
b884aa9
 
 
 
 
 
6f5ab24
 
 
 
 
 
 
 
 
 
 
 
 
b884aa9
6f5ab24
 
b884aa9
 
 
 
 
dc294fb
b884aa9
6f5ab24
b884aa9
dc294fb
6f5ab24
 
 
 
 
 
 
 
 
b884aa9
 
 
6f5ab24
 
 
b884aa9
6f5ab24
dc294fb
b884aa9
 
 
 
 
6f5ab24
b884aa9
6f5ab24
dc294fb
6f5ab24
b884aa9
 
 
 
 
 
6f5ab24
b884aa9
6f5ab24
b884aa9
b063a25
b884aa9
6f5ab24
 
 
 
 
 
 
 
b884aa9
dc294fb
6f5ab24
d6de8db
 
 
 
 
 
 
 
 
 
 
 
 
 
b884aa9
6f5ab24
b884aa9
 
62841b2
81fddf5
b884aa9
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio  # Interface handling
import spaces  # For GPU
import transformers  # LLM Loading
import langchain_community.vectorstores  # Vectorstore for publications
import langchain_huggingface  # Embeddings

# Greeting text (Markdown; contains a link to the Wikipedia article on RAG).
# The fragments are joined with single spaces into one sentence run.
GREETING = " ".join(
    [
        "Howdy! I'm an AI agent that uses",
        "[retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation)",
        "to answer questions about additive manufacturing research.",
        "I'm still improving, so bear with me if I make any mistakes.",
        "What can I help you with today?",
    ]
)

# Model identifiers and retrieval depth
EMBEDDING_MODEL_NAME = "all-MiniLM-L12-v2"  # model name handed to the embedding wrapper
LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"  # checkpoint used for tokenizer and chat model
PUBLICATIONS_TO_RETRIEVE = 10  # number of documents fetched per user query


def embedding(device: str = "cuda", normalize_embeddings: bool = False) -> langchain_huggingface.HuggingFaceEmbeddings:
    """Build the HuggingFace embedding wrapper for ``EMBEDDING_MODEL_NAME``.

    Args:
        device: Device string forwarded to the model via ``model_kwargs``.
        normalize_embeddings: Forwarded via ``encode_kwargs`` under the
            ``normalize_embeddings`` key.

    Returns:
        A configured ``HuggingFaceEmbeddings`` instance.
    """
    # Name the two option dicts first so the constructor call stays short.
    model_options = {"device": device}
    encode_options = {"normalize_embeddings": normalize_embeddings}

    embedder = langchain_huggingface.HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs=model_options,
        encode_kwargs=encode_options,
    )
    return embedder


def load_publication_vectorstore() -> "langchain_community.vectorstores.FAISS | None":
    """Load the FAISS publication index from disk on a best-effort basis.

    The index is read from the local ``publication_vectorstore`` folder and
    uses :func:`embedding` (with its defaults) as the embedding function.

    Returns:
        The loaded FAISS vectorstore, or ``None`` when loading fails for any
        reason. The failure is printed rather than raised, so callers must
        check for ``None`` before using the result.
    """
    try:
        # SECURITY: allow_dangerous_deserialization=True opts in to
        # pickle-based loading of the stored index. Keep folder_path pointing
        # only at index files you built yourself; never load untrusted data.
        return langchain_community.vectorstores.FAISS.load_local(
            folder_path="publication_vectorstore",
            embeddings=embedding(),
            allow_dangerous_deserialization=True,
        )
    except Exception as error:
        # Deliberate best-effort: report the problem and signal failure with
        # None instead of propagating the exception.
        print(f"Error loading vectorstore: {error}")
        return None


# Load vectorstore and models
# These statements run once at import time and create the module-level objects
# that preprocess() and reply() read.
# NOTE: load_publication_vectorstore() returns None when loading fails, so
# publication_vectorstore may be None here.
publication_vectorstore = load_publication_vectorstore()
# NOTE(review): trust_remote_code=True presumably allows code shipped with the
# model repository to run locally - confirm it is required for this checkpoint.
tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME, trust_remote_code=True)
# NOTE(review): `streamer` is not referenced by any other code in this file;
# reply() calls generate() without it - confirm whether streaming was intended.
streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Device placement and dtype are delegated to transformers via the "auto" settings.
chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME, device_map="auto", torch_dtype="auto", trust_remote_code=True
)


def preprocess(query: str, k: int) -> str:
    """Build the retrieval-augmented prompt for ``query``.

    Runs a similarity search against the module-level publication vectorstore
    and embeds the matching excerpts, followed by the query itself, into a
    fixed prompt template.

    Args:
        query: The user's question; used for retrieval and echoed in the prompt.
        k: Number of documents to request from the vectorstore.

    Returns:
        The full prompt text: instructions, quoted excerpts, the query, and a
        trailing answer header.

    Raises:
        RuntimeError: If the publication vectorstore is ``None`` (it failed
            to load at startup).
    """
    if publication_vectorstore is None:
        # load_publication_vectorstore() returns None on failure. Fail with an
        # explicit message instead of an AttributeError on `None.search`.
        raise RuntimeError(
            "Publication vectorstore is not available; cannot retrieve research excerpts."
        )

    documents = publication_vectorstore.search(query, k=k, search_type="similarity")
    research_excerpts = [f'"... {doc.page_content}..."' for doc in documents]

    # Prompt template
    prompt_template = (
        "You are an AI assistant who enjoys helping users learn about research. "
        "Answer the following question on additive manufacturing research using the RESEARCH_EXCERPTS. "
        "Provide a concise ANSWER based on these excerpts. Avoid listing references.\n\n"
        "===== RESEARCH_EXCERPTS =====:\n{research_excerpts}\n\n"
        "===== USER_QUERY =====:\n{query}\n\n"
        "===== ANSWER =====:\n"
    )

    # Values passed to format() are inserted verbatim, so braces inside the
    # query or the excerpts are not interpreted as placeholders.
    prompt = prompt_template.format(
        research_excerpts="\n\n".join(research_excerpts), query=query
    )

    print(prompt)  # Useful for debugging prompt content
    return prompt


@spaces.GPU
def reply(message: str, history: list[list[str]]) -> str:
    """Generate the assistant's answer to the user's latest message.

    The message is first expanded into a retrieval-augmented prompt by
    ``preprocess``; prior turns are replayed as chat messages ahead of it.

    Args:
        message: The user's latest message.
        history: Previous turns as ``[user_text, assistant_text]`` pairs.
            Either element of a pair may be ``None`` and is then skipped.

    Returns:
        Only the newly generated completion text (the prompt is not echoed).
    """
    # Expand the raw question into the full RAG prompt.
    prompt = preprocess(message, PUBLICATIONS_TO_RETRIEVE)

    # Flatten the (user, assistant) pairs into role-tagged chat messages.
    conversation = []
    for turn in history:
        for role, content in zip(("user", "assistant"), turn):
            if content is not None:
                conversation.append({"role": role, "content": content})
    conversation.append({"role": "user", "content": prompt})

    # Tokenize and prepare model input
    text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to("cuda")

    # Generate response directly
    output_tokens = chatmodel.generate(**model_inputs, max_new_tokens=512)

    # generate() returns the prompt tokens followed by the new tokens for
    # causal LMs. Drop the prompt part so the reply does not repeat the
    # instructions, excerpts, and conversation history back to the user.
    prompt_length = model_inputs["input_ids"].shape[1]
    completion_tokens = output_tokens[0][prompt_length:]

    return tokenizer.decode(completion_tokens, skip_special_tokens=True)


# Example Queries for Interface
# Written one query per line; the surrounding newlines are stripped and the
# text is split into a list of strings.
EXAMPLE_QUERIES = """
What is multi-material 3D printing?
How is additive manufacturing being applied in aerospace?
Tell me about innovations in metal 3D printing techniques.
What are some sustainable materials for 3D printing?
What are the biggest challenges with support structures in additive manufacturing?
How is 3D printing impacting the medical field?
What are some common applications of additive manufacturing in industry?
What are the benefits and limitations of using polymers in 3D printing?
Tell me about the environmental impacts of additive manufacturing.
What are the primary limitations of current 3D printing technologies?
How are researchers improving the speed of 3D printing processes?
What are the best practices for managing post-processing in additive manufacturing?
""".strip().splitlines()

# Run the Gradio Interface
# Build the chat display first, with label, share, and copy controls turned off.
chat_window = gradio.Chatbot(
    show_label=False,
    show_share_button=False,
    show_copy_button=False,
    bubble_full_width=False,
)

# Wire reply() and the example queries into the chat interface.
demo = gradio.ChatInterface(
    reply,
    examples=EXAMPLE_QUERIES,
    cache_examples=False,
    chatbot=chat_window,
)

# Start the app with debug enabled.
demo.launch(debug=True)