File size: 4,535 Bytes
0097859
 
 
8918695
 
0097859
45721f6
 
8918695
0097859
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45721f6
 
 
 
 
 
 
 
 
 
 
 
 
0097859
 
 
 
 
 
 
 
 
 
45721f6
0097859
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45721f6
 
 
 
0097859
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45721f6
0097859
 
 
6f83f05
57d9ac1
6f83f05
2295ad4
0097859
 
 
 
 
 
 
 
 
f3475ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import spaces
import torch
from collections.abc import Iterator
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from pyvis.network import Network
import networkx as nx
import os

# Markdown header rendered at the top of the chat UI (passed as `description`
# to gr.ChatInterface).
DESCRIPTION = """
# GWQ PREV
"""

# Upper bound and default value for the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompts longer than this many tokens are left-truncated in `generate`;
# overridable through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# NOTE(review): `device` appears unused in this file — the model is placed via
# device_map="auto" and `generate` moves inputs to `model.device` instead.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hugging Face Hub repository the tokenizer and weights are loaded from.
model_id = "prithivMLmods/GWQ2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Loaded once at import time in bfloat16; device_map="auto" delegates device
# placement to transformers/accelerate.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# NOTE(review): overrides the sliding-window attention size from the model's
# config; 4096 matches MAX_INPUT_TOKEN_LENGTH's default — confirm intentional.
model.config.sliding_window = 4096
# Switch to inference mode (e.g. disables dropout).
model.eval()

def create_knowledge_graph(text):
    """Build an undirected word-adjacency graph from *text*.

    Every whitespace-separated token becomes a node, and each pair of
    adjacent tokens is joined by an edge. Repeated words collapse into a
    single node; a word repeated back-to-back produces a self-loop.

    Args:
        text: Free-form text, e.g. the model's generated answer.

    Returns:
        A ``networkx.Graph``; it is empty only when *text* has no tokens.
    """
    G = nx.Graph()
    words = text.split()
    # Add nodes explicitly so single-word text still yields a one-node graph;
    # the edge pass alone would leave it empty.
    G.add_nodes_from(words)
    # zip(words, words[1:]) pairs each token with its immediate successor.
    G.add_edges_from(zip(words, words[1:]))
    return G

def visualize_knowledge_graph(graph, path="knowledge_graph.html"):
    """Render *graph* as an interactive pyvis HTML page and return its path.

    Args:
        graph: A networkx graph, e.g. from ``create_knowledge_graph``.
        path: Destination HTML file. Defaults to ``"knowledge_graph.html"``
            (relative to the current working directory). Pass a per-request
            path when several sessions may render at the same time; otherwise
            they all overwrite the same file.

    Returns:
        *path*, the file that was written.
    """
    # cdn_resources='in_line' embeds the JS/CSS in the page so the HTML file
    # is self-contained.
    net = Network(notebook=True, cdn_resources='in_line')
    net.from_nx(graph)
    # NOTE(review): with notebook=True, `show` writes the file and builds an
    # IPython IFrame that is discarded here; `write_html` may suit a server
    # app better — confirm against the installed pyvis version.
    net.show(path)
    return path

@spaces.GPU(duration=120)
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    visualize_graph: bool = False,
) -> Iterator[str]:
    """Stream a sampled chat completion for *message* given *chat_history*.

    Args:
        message: The new user turn.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts;
            not mutated.
        max_new_tokens: Upper bound on generated tokens.
        temperature, top_p, top_k, repetition_penalty: Sampling settings
            forwarded to ``model.generate``.
        visualize_graph: When true, also render a word-adjacency graph of the
            answer to an HTML file once generation finishes.

    Yields:
        The cumulative response text after each streamed chunk. When
        *visualize_graph* is set, one final update carries the full answer
        followed by a note with the path of the rendered graph.
    """
    conversation = chat_history.copy()
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens; the oldest context is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt=True streams only newly generated text; the timeout makes
    # iteration fail instead of blocking forever if generation stalls.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        # Coerce defensively in case the UI delivers slider values as floats;
        # transformers requires an integer top_k.
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        top_p=top_p,
        top_k=int(top_k),
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks until finished, so it runs in a worker thread while
    # this generator consumes the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generation ends; join so the worker
    # thread does not outlive this request.
    t.join()

    if visualize_graph:
        response = "".join(outputs)
        graph = create_knowledge_graph(response)
        graph_file = visualize_knowledge_graph(graph)
        # ChatInterface shows the last yielded value as the reply, so append the
        # note rather than replacing the generated answer with it.
        yield f"{response}\n\nKnowledge graph saved to {graph_file}"

# Chat UI: `additional_inputs` are passed to `generate` after
# (message, chat_history), in order, so they must match its parameter order.
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
        gr.Checkbox(label="Visualize Knowledge Graph", value=False),
    ],
    stop_btn=None,
    examples=[
        ["Write a Python function to reverses a string if it's length is a multiple of 4. def reverse_string(str1): if len(str1) % 4 == 0: return ''.join(reversed(str1)) return str1 print(reverse_string('abcd')) print(reverse_string('python')) "],
        # Raw string: the LaTeX backslashes (\overline, \perp, \text) must reach
        # the model verbatim; in a normal literal "\t" would become a TAB.
        [r"Rectangle $ABCD$ is the base of pyramid $PABCD$. If $AB = 10$, $BC = 5$, $\overline{PA}\perp \text{plane } ABCD$, and $PA = 8$, then what is the volume of $PABCD$?"],
        ["Difference between List comprehension and Lambda in Python lst  =  [x ** 2  for x in range (1, 11)   if  x % 2 == 1] print(lst)"],
        ["How Many R's in the Word 'STRAWBERRY' ?"],
    ],
    cache_examples=False,
    type="messages",
    description=DESCRIPTION,
    css_paths="style.css",
    fill_height=True,
)

if __name__ == "__main__":
    # Turn on request queuing (capped at 20 queued events), then serve the UI.
    demo.queue(max_size=20)
    demo.launch()