Commit
·
aa968b3
1
Parent(s):
cbb363e
Update cross_visual.py
Browse files- cross_visual.py +1 -1
cross_visual.py
CHANGED
|
@@ -794,4 +794,4 @@ class CrossVisionModel(nn.Module):
|
|
| 794 |
|
| 795 |
def forward(self, images):
|
| 796 |
enc = self.vit(images)
|
| 797 |
-
return enc + self.pos_embed.unsqueeze(0)
|
|
|
|
| 794 |
|
| 795 |
def forward(self, images):
|
| 796 |
enc = self.vit(images)
|
| 797 |
+
return enc + self.pos_embed.to(enc.device).unsqueeze(0)
|