-
Notifications
You must be signed in to change notification settings - Fork 43
/
losses.py
34 lines (24 loc) · 1016 Bytes
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# encoding:utf-8
from torch.nn import CrossEntropyLoss
from torch.nn import BCEWithLogitsLoss
# Public API of this module: the names exported by `from losses import *`.
__all__ = ['CrossEntropy', 'BCEWithLogLoss', 'SpanLoss']
class CrossEntropy:
    """Multi-class cross-entropy criterion.

    Wraps ``torch.nn.CrossEntropyLoss`` so it can be called as
    ``criterion(output, target)``.

    Args:
        ignore_index: target value whose positions are excluded from the
            loss (default ``-1``).
    """

    def __init__(self, ignore_index=-1):
        self.loss_f = CrossEntropyLoss(ignore_index=ignore_index)

    def __call__(self, output, target):
        """Return the cross-entropy of logits ``output`` against ``target``."""
        criterion = self.loss_f
        return criterion(output, target)
class BCEWithLogLoss:
    """Binary cross-entropy computed directly on raw logits.

    Callable wrapper around a default-constructed
    ``torch.nn.BCEWithLogitsLoss``.
    """

    def __init__(self):
        self.loss_fn = BCEWithLogitsLoss()

    def __call__(self, output, target):
        """Return the loss between logits ``output`` and targets ``target``."""
        criterion = self.loss_fn
        return criterion(output, target)
class SpanLoss:
    """Cross-entropy restricted to the positions selected by a mask.

    Args:
        ignore_index: target value that is additionally excluded from the
            loss even at masked-in positions (default ``-100``).
    """

    def __init__(self, ignore_index=-100):
        self.loss_fn = CrossEntropyLoss(ignore_index=ignore_index)

    def __call__(self, output, target, mask):
        """Return the mean cross-entropy over positions where ``mask == 1``.

        Args:
            output: logits whose last dimension is the class dimension,
                e.g. ``(N, C)`` or ``(batch, seq, C)``.
            target: class indices with as many elements as
                ``output.shape[:-1]``, e.g. ``(N,)`` or ``(batch, seq)``.
            mask: same number of elements as ``target``; positions equal
                to ``1`` are kept.

        Returns:
            Scalar loss tensor. If the mask selects no positions, a zero
            that is still connected to ``output``'s autograd graph is
            returned instead of the NaN an empty mean would produce.
        """
        # Flatten to one row per position so the 1-D boolean mask lines up
        # with dim 0 of both logits and labels. Already-flat (N, C) / (N,)
        # inputs pass through unchanged. ``reshape`` (unlike ``view``) also
        # accepts non-contiguous tensors.
        logits = output.reshape(-1, output.size(-1))
        labels = target.reshape(-1)
        active = mask.reshape(-1) == 1
        active_logits = logits[active]
        active_labels = labels[active]
        if active_labels.numel() == 0:
            # Summing an empty selection gives exactly 0 but keeps grad_fn.
            return active_logits.sum()
        return self.loss_fn(active_logits, active_labels)