-
Notifications
You must be signed in to change notification settings - Fork 6
/
functions.py
68 lines (41 loc) · 1.49 KB
/
functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from torch.autograd import Function
import torch.nn as nn
import torch
class ReverseLayerF(Function):
    """Autograd function that reverses (and scales) gradients.

    The forward pass hands the input through unchanged; the backward pass
    multiplies the incoming gradient by ``-p``.
    """

    @staticmethod
    def forward(ctx, x, p):
        """Return a same-shaped view of ``x`` and remember ``p`` for backward.

        Args:
            ctx: Autograd context object; ``p`` is stashed on it.
            x: Input tensor.
            p: Scale factor applied to the reversed gradient in ``backward``.
        """
        ctx.p = p
        # NOTE(review): presumably a view is returned (rather than ``x``
        # itself) so autograd sees a distinct output tensor -- confirm.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(-grad_output * p, None)``.

        The second entry is ``None`` because no gradient is produced for
        ``p``.
        """
        reversed_grad = (-grad_output) * ctx.p
        return reversed_grad, None
class MSE(nn.Module):
    """Mean squared error averaged over every element of the difference."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, real):
        """Return ``sum((real - pred) ** 2) / numel`` as a scalar tensor.

        Args:
            pred: Predicted values.
            real: Target values.
        """
        error = torch.add(real, -pred)
        squared_total = error.pow(2).sum()
        return squared_total / error.numel()
class SIMSE(nn.Module):
    """Squared sum of element-wise differences, divided by ``numel ** 2``."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, real):
        """Return ``sum(real - pred) ** 2 / numel ** 2`` as a scalar tensor.

        Differences are summed *before* squaring, so positive and negative
        errors cancel each other out.

        Args:
            pred: Predicted values.
            real: Target values.
        """
        error = torch.add(real, -pred)
        count = error.numel()
        summed_error = error.sum()
        return summed_error.pow(2) / (count ** 2)
class DiffLoss(nn.Module):
def __init__(self):
super(DiffLoss, self).__init__()
def forward(self, input1, input2):
batch_size = input1.size(0)
input1 = input1.view(batch_size, -1)
input2 = input2.view(batch_size, -1)
input1_l2_norm = torch.norm(input1, p=2, dim=1, keepdim=True).detach()
input1_l2 = input1.div(input1_l2_norm.expand_as(input1) + 1e-6)
input2_l2_norm = torch.norm(input2, p=2, dim=1, keepdim=True).detach()
input2_l2 = input2.div(input2_l2_norm.expand_as(input2) + 1e-6)
diff_loss = torch.mean((input1_l2.t().mm(input2_l2)).pow(2))
return diff_loss