aditide committed
Commit 96320a1 · verified · 1 Parent(s): aa0ff34

Upload 2 files

Files changed (2)
  1. zia_model.pt +3 -0
  2. zia_model.py +44 -0
zia_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:90d29ff8c870a548ab1868f6a17b5c13d1d65df590e5096472e0b03981e7be69
+ size 4826444
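
The .pt file above is a Git LFS pointer to a ~4.8 MB serialized checkpoint. A minimal loading sketch, assuming the checkpoint stores a state_dict saved via torch.save(model.state_dict(), ...) for the ZIAModel defined below; if the commit saved the full pickled module instead, torch.load alone would return the model object:

import torch
from zia_model import ZIAModel

# Assumption: the checkpoint holds a state_dict, not a pickled nn.Module.
model = ZIAModel()
state_dict = torch.load("zia_model.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode
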
zia_model.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+ class ZIAModel(nn.Module):
+     def __init__(self, n_intents=10, d_model=128, nhead=8, num_layers=6, dim_feedforward=512):
+         super(ZIAModel, self).__init__()
+         self.d_model = d_model
+
+         # Modality-specific encoders
+         self.gaze_encoder = nn.Linear(2, d_model)
+         self.hr_encoder = nn.Linear(1, d_model)
+         self.eeg_encoder = nn.Linear(4, d_model)
+         self.context_encoder = nn.Linear(32 + 3 + 20, d_model)  # Time (32) + Location (3) + Usage (20)
+
+         # Transformer
+         encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout=0.1, batch_first=True)
+         self.transformer = TransformerEncoder(encoder_layer, num_layers)
+
+         # Output layer
+         self.fc = nn.Linear(d_model, n_intents)
+
+     def forward(self, gaze, hr, eeg, context):
+         # Encode modalities
+         gaze_emb = self.gaze_encoder(gaze)  # [batch, seq, d_model]
+         hr_emb = self.hr_encoder(hr.unsqueeze(-1))
+         eeg_emb = self.eeg_encoder(eeg)
+         context_emb = self.context_encoder(context)
+
+         # Fuse modalities
+         fused = (gaze_emb + hr_emb + eeg_emb + context_emb) / 4  # Simple averaging
+
+         # Transformer
+         output = self.transformer(fused)
+         output = output.mean(dim=1)  # Pool over sequence
+
+         # Predict intent
+         logits = self.fc(output)
+         return logits
+
+ # Example usage
+ if __name__ == "__main__":
+     model = ZIAModel()
+     print(model)
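
A quick smoke test for the module above; the tensor shapes are inferred from the encoder input widths in __init__ (gaze: 2 features, hr: scalar, eeg: 4 channels, context: 32 + 3 + 20 = 55), while the batch and sequence sizes are arbitrary illustration values:

import torch
from zia_model import ZIAModel

# Dummy batch; shapes follow the Linear encoder input dims above.
batch, seq = 8, 16
gaze = torch.randn(batch, seq, 2)
hr = torch.randn(batch, seq)       # unsqueezed to [batch, seq, 1] in forward()
eeg = torch.randn(batch, seq, 4)
context = torch.randn(batch, seq, 55)

model = ZIAModel()
logits = model(gaze, hr, eeg, context)
print(logits.shape)  # torch.Size([8, 10]) -- one score per intent class

Averaging the four modality embeddings keeps the fusion step parameter-free; a learned gate or cross-attention would be the usual extension if one modality should be weighted more heavily.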