nnsohamnn committed
Commit 591ec58 · verified · 1 Parent(s): 1377962

Upload 4 files

Files changed (4)
  1. Conv_GPT.pth +3 -0
  2. app.py +56 -0
  3. model.py +93 -0
  4. requirements.txt +4 -0
Conv_GPT.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d07006505c691bae29120861fbc9dfe9ad3b75d4964e38b8445020991d4d6b17
+ size 358490096
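Conv_GPT.pth is stored through Git LFS, so only its pointer (hash and ~358 MB size) appears in the diff. A minimal sketch of fetching the real checkpoint with huggingface_hub (which is listed in requirements.txt); the repo_id below is a placeholder, not the actual repository name:

# Download the LFS-backed checkpoint from the Hub.
# NOTE: "<user>/<repo>" is a placeholder -- substitute the repository this commit belongs to,
# and add repo_type="space" if it lives in a Space rather than a model repo.
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(repo_id="<user>/<repo>", filename="Conv_GPT.pth")
print(checkpoint_path)  # local cache path, usable with torch.load()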
app.py ADDED
@@ -0,0 +1,56 @@
+ import gradio as gr
+ import torch
+ from transformers import GPT2Tokenizer
+ from model import TransformerModel
+
+ # Load tokenizer
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Load model
+ model = TransformerModel(
+     vocab_size=tokenizer.vocab_size,
+     hidden_size=512,
+     num_layers=12,
+     num_heads=16,
+     dropout=0.1
+ )
+ model.load_state_dict(torch.load("Conv_GPT.pth", map_location=torch.device('cpu')))
+ model.eval()
+
+ # Define generation function
+ def generate_text(prompt, max_new_tokens=50):
+     input_ids = tokenizer.encode(prompt, return_tensors='pt')
+     # Ensure input sequence length does not exceed 512 (model's max_seq_len)
+     if input_ids.size(1) > 512:
+         input_ids = input_ids[:, :512]
+     generated_ids = input_ids
+     with torch.no_grad():
+         for _ in range(max_new_tokens):
+             logits = model(generated_ids)
+             next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)
+             generated_ids = torch.cat([generated_ids, next_token], dim=1)
+             # Truncate if exceeding 512 tokens
+             if generated_ids.size(1) > 512:
+                 generated_ids = generated_ids[:, -512:]
+             if tokenizer.decode(next_token.item()) == '\n':
+                 break
+     return tokenizer.decode(generated_ids[0, len(input_ids[0]):]).strip()
+
+ # Chat function for Gradio
+ def chat(message, history):
+     prompt = f"User: {message}\nAssistant:"
+     response = generate_text(prompt)
+     return response
+
+ # Create Gradio interface
+ interface = gr.ChatInterface(
+     fn=chat,
+     title="Conv_GPT Chatbot",
+     description="Chat with Conv_GPT, a custom transformer trained on DailyDialog! Enter your message below.",
+     theme="default",
+     examples=["Hi, how are you?", "What's your favorite food?", "Tell me about your day."]
+ )
+
+ # Launch the app
+ interface.launch(server_name="0.0.0.0", server_port=7860)
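generate_text() above decodes greedily, appending one argmax token per step and stopping at a newline or after max_new_tokens. A minimal standalone sketch of that loop outside Gradio, assuming Conv_GPT.pth and model.py sit in the working directory (the prompt string is only an example):

# Standalone greedy-decoding smoke test mirroring generate_text() in app.py.
# Assumes Conv_GPT.pth and model.py are in the current working directory.
import torch
from transformers import GPT2Tokenizer
from model import TransformerModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TransformerModel(vocab_size=tokenizer.vocab_size, hidden_size=512,
                         num_layers=12, num_heads=16, dropout=0.1)
model.load_state_dict(torch.load("Conv_GPT.pth", map_location="cpu"))
model.eval()

ids = tokenizer.encode("User: Hi, how are you?\nAssistant:", return_tensors="pt")
with torch.no_grad():
    for _ in range(20):  # generate at most 20 tokens
        next_id = model(ids)[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick, shape (1, 1)
        ids = torch.cat([ids, next_id], dim=1)
        if tokenizer.decode(next_id.item()) == "\n":  # same newline stop rule as app.py
            break
print(tokenizer.decode(ids[0]))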
model.py ADDED
@@ -0,0 +1,93 @@
+ # model.py
+ import torch
+ import torch.nn as nn
+ import math
+
+ class MultiHeadSelfAttention(nn.Module):
+     def __init__(self, hidden_size, num_heads, dropout=0.1):
+         super().__init__()
+         assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
+         self.hidden_size = hidden_size
+         self.num_heads = num_heads
+         self.head_dim = hidden_size // num_heads
+
+         self.query = nn.Linear(hidden_size, hidden_size)
+         self.key = nn.Linear(hidden_size, hidden_size)
+         self.value = nn.Linear(hidden_size, hidden_size)
+         self.out = nn.Linear(hidden_size, hidden_size)
+
+         self.dropout = nn.Dropout(dropout)
+         self.scale = math.sqrt(self.head_dim)
+
+     def forward(self, x, mask=None, padding_mask=None):
+         batch_size, seq_len, _ = x.size()
+
+         q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
+
+         if mask is not None:
+             scores = scores.masked_fill(mask == 1, -1e4)  # Adjusted for FP16 compatibility
+         if padding_mask is not None:
+             padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)
+             scores = scores.masked_fill(padding_mask, -1e4)  # Adjusted for FP16 compatibility
+
+         attn = torch.softmax(scores, dim=-1)
+         attn = self.dropout(attn)
+
+         out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
+         out = self.out(out)
+         return out
+
+ class TransformerLayer(nn.Module):
+     def __init__(self, hidden_size, num_heads, dropout=0.1):
+         super().__init__()
+         self.attn = MultiHeadSelfAttention(hidden_size, num_heads, dropout)
+         self.ffn = nn.Sequential(
+             nn.Linear(hidden_size, 4 * hidden_size),
+             nn.ReLU(),
+             nn.Linear(4 * hidden_size, hidden_size),
+             nn.Dropout(dropout)
+         )
+         self.ln1 = nn.LayerNorm(hidden_size)
+         self.ln2 = nn.LayerNorm(hidden_size)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, mask=None, padding_mask=None):
+         x = self.ln1(x)
+         attn_out = self.attn(x, mask, padding_mask)
+         x = x + self.dropout(attn_out)
+
+         x = self.ln2(x)
+         ffn_out = self.ffn(x)
+         x = x + self.dropout(ffn_out)
+         return x
+
+ class TransformerModel(nn.Module):
+     def __init__(self, vocab_size, hidden_size=512, num_layers=6, num_heads=8, dropout=0.1):
+         super().__init__()
+         self.token_embedding = nn.Embedding(vocab_size, hidden_size)
+         self.pos_embedding = nn.Embedding(512, hidden_size)  # Fixed max_seq_len=512
+         self.layers = nn.ModuleList([
+             TransformerLayer(hidden_size, num_heads, dropout) for _ in range(num_layers)
+         ])
+         self.final_ln = nn.LayerNorm(hidden_size)
+         self.head = nn.Linear(hidden_size, vocab_size)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, input_ids, padding_mask=None):
+         batch_size, seq_len = input_ids.size()
+         positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0).expand_as(input_ids)
+         x = self.token_embedding(input_ids) + self.pos_embedding(positions)
+         x = self.dropout(x)
+
+         causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1).bool()
+
+         for layer in self.layers:
+             x = layer(x, causal_mask, padding_mask)
+
+         x = self.final_ln(x)
+         logits = self.head(x)
+         return logits
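TransformerModel combines learned token and position embeddings (max 512 positions), a stack of pre-norm decoder blocks under a causal mask, and a linear head back to the vocabulary. A minimal shape check under assumed toy inputs (vocab size 50257 matches the GPT-2 tokenizer; the random ids and all-False padding mask are illustrative only):

# Sanity-check the forward pass: token ids in, per-position vocabulary logits out.
import torch
from model import TransformerModel

model = TransformerModel(vocab_size=50257, hidden_size=512, num_layers=12, num_heads=16)
input_ids = torch.randint(0, 50257, (2, 16))           # toy batch: 2 sequences of 16 tokens
padding_mask = torch.zeros(2, 16, dtype=torch.bool)    # all False = nothing is padding
logits = model(input_ids, padding_mask)
print(logits.shape)  # expected: torch.Size([2, 16, 50257])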
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gradio
+ torch
+ transformers
+ huggingface_hub