Skip to content

Commit 08b1dcb

Browse files
committed
perf(loss): 批量计算
1 parent 75d2760 commit 08b1dcb

File tree

1 file changed

+130
-72
lines changed

1 file changed

+130
-72
lines changed

py/lib/models/multi_part_loss.py

+130-72
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,132 @@ def forward(self, preds, targets):
3737
:param targets: (N, S*S, (B*5+C))
3838
:return:
3939
"""
40+
# print('loss 1', self._process1(preds, targets))
41+
# print('loss 2', self._process2(preds, targets))
42+
return self._process3(preds, targets)
43+
44+
def _process1(self, preds, targets):
45+
N = preds.shape[0]
46+
total_loss = 0.0
47+
for pred, target in zip(preds, targets):
48+
"""
49+
逐个图像计算
50+
pred: [S*S, (B*5+C)]
51+
target: [S*S, (B*5+C)]
52+
"""
53+
# 分类概率
54+
# [S*S, C]
55+
pred_probs = pred[:, :self.C]
56+
target_probs = target[:, :self.C]
57+
# 置信度
58+
# [S*S, B]
59+
pred_confidences = pred[:, self.C:(self.C + self.B)]
60+
target_confidences = target[:, self.C:(self.C + self.B)]
61+
# 边界框坐标
62+
pred_bboxs = pred[:, (self.C + self.B):]
63+
target_bboxs = target[:, (self.C + self.B):]
64+
65+
for i in range(self.S * self.S):
66+
"""
67+
逐个网格计算
68+
"""
69+
pred_single_probs = pred_probs[i]
70+
target_single_probs = target_probs[i]
71+
72+
pred_single_confidences = pred_confidences[i]
73+
target_single_confidences = target_confidences[i]
74+
75+
pred_single_bboxs = pred_bboxs[i]
76+
target_single_bboxs = target_bboxs[i]
77+
78+
# 是否存在置信度(如果存在,则target的置信度必然大于0)
79+
is_obj = target_single_confidences[0] > 0
80+
# 计算置信度损失 假定该网格不存在对象
81+
total_loss += self.noobj * self.sum_squared_error(pred_single_confidences, target_single_confidences)
82+
if is_obj:
83+
# 如果存在
84+
# 计算分类损失
85+
total_loss += self.sum_squared_error(pred_single_probs, target_single_probs)
86+
87+
# 计算所有预测边界框和标注边界框的IoU
88+
pred_single_bboxs = pred_single_bboxs.reshape(-1, 4)
89+
target_single_bboxs = target_single_bboxs.reshape(-1, 4)
90+
91+
scores = self.iou(pred_single_bboxs, target_single_bboxs)
92+
# 提取IoU最大的下标
93+
bbox_idx = torch.argmax(scores)
94+
# 计算置信度损失
95+
total_loss += (1 - self.noobj) * \
96+
self.sum_squared_error(pred_single_confidences[bbox_idx],
97+
target_single_confidences[bbox_idx])
98+
# 计算边界框损失
99+
total_loss += self.coord * self.bbox_loss(pred_single_bboxs[bbox_idx].reshape(-1, 4),
100+
target_single_bboxs[bbox_idx].reshape(-1, 4))
101+
102+
return total_loss / N
103+
104+
def _process2(self, preds, targets):
105+
N = preds.shape[0]
106+
total_loss = 0.0
107+
for pred, target in zip(preds, targets):
108+
"""
109+
逐个图像计算
110+
pred: [S*S, (B*5+C)]
111+
target: [S*S, (B*5+C)]
112+
"""
113+
# 分类概率
114+
# [S*S, C]
115+
pred_probs = pred[:, :self.C]
116+
target_probs = target[:, :self.C]
117+
# 置信度
118+
# [S*S, B]
119+
pred_confidences = pred[:, self.C:(self.C + self.B)]
120+
target_confidences = target[:, self.C:(self.C + self.B)]
121+
# 边界框坐标
122+
# [S*S, B*4] -> [S*S, B, 4]
123+
pred_bboxs = pred[:, (self.C + self.B):].reshape(self.S * self.S, self.B, 4)
124+
target_bboxs = target[:, (self.C + self.B):].reshape(self.S * self.S, self.B, 4)
125+
126+
# 统一计算置信度损失
127+
total_loss += self.noobj * self.sum_squared_error(pred_confidences, target_confidences)
128+
# 计算每个网格预测边界框的IoU
129+
# Input: [S*S, B, 4] -> [S*S*B, 4]
130+
# Output: [S*S*B] -> [S*S, B]
131+
iou_scores = self.iou(pred_bboxs.reshape(-1, 4), target_bboxs.reshape(-1, 4)).reshape(self.S * self.S,
132+
self.B)
133+
# 计算其中最大IoU所属下标
134+
# [S*S]
135+
top_idxs = torch.argmax(iou_scores, dim=1)
136+
top_len = len(top_idxs)
137+
# 提取对应的边界框以及置信度
138+
# [S*S, 4]
139+
top_pred_bboxs = pred_bboxs[range(top_len), top_idxs]
140+
top_pred_confidences = pred_confidences[range(top_len), top_idxs]
141+
top_target_bboxs = target_bboxs[range(top_len), top_idxs]
142+
top_target_confidences = target_confidences[range(top_len), top_idxs]
143+
144+
# 计算网格中是否存在目标
145+
# [S*S, C] -> [S*S]
146+
obj_idxs = torch.sum(target_probs, dim=1) > 0
147+
# 提取对应的目标分类概率、置信度以及边界框坐标
148+
# [S*S, C]
149+
obj_pred_probs = pred_probs[obj_idxs]
150+
obj_pred_confidences = top_pred_confidences[obj_idxs]
151+
obj_pred_bboxs = top_pred_bboxs[obj_idxs]
152+
153+
obj_target_probs = target_probs[obj_idxs]
154+
obj_target_confidences = top_target_confidences[obj_idxs]
155+
obj_target_bboxs = top_target_bboxs[obj_idxs]
156+
157+
# 计算置信度损失
158+
total_loss += (1 - self.noobj) * self.sum_squared_error(obj_pred_confidences, obj_target_confidences)
159+
# 分类概率损失
160+
total_loss += self.sum_squared_error(obj_pred_probs, obj_target_probs)
161+
# 坐标损失
162+
total_loss += self.coord * self.bbox_loss(obj_pred_bboxs, obj_target_bboxs)
163+
return total_loss / N
164+
165+
def _process3(self, preds, targets):
40166
N = preds.shape[0]
41167
## 预测
42168
# 提取每个网格的分类概率
@@ -82,7 +208,7 @@ def forward(self, preds, targets):
82208
# print(top_pred_bboxs.shape)
83209

84210
# 选取存在目标的网格
85-
obj_idxs = torch.sum(target_probs, dim=1) == 1
211+
obj_idxs = torch.sum(target_probs, dim=1) > 0
86212
# print(obj_idxs)
87213

88214
obj_pred_confidences = top_pred_confidences[obj_idxs]
@@ -98,77 +224,12 @@ def forward(self, preds, targets):
98224
## 计算分类概率损失
99225
loss += self.sum_squared_error(obj_pred_probs, obj_target_probs)
100226
## 计算边界框坐标损失
101-
loss += self.sum_squared_error(obj_pred_bboxs[:, :2], obj_target_bboxs[:, :2])
102-
loss += self.sum_squared_error(torch.sqrt(obj_pred_bboxs[:, 2:]), torch.sqrt(obj_target_bboxs[:, 2:]))
227+
loss += self.coord * self.sum_squared_error(obj_pred_bboxs[:, :2], obj_target_bboxs[:, :2])
228+
loss += self.coord * self.sum_squared_error(torch.sqrt(obj_pred_bboxs[:, 2:]),
229+
torch.sqrt(obj_target_bboxs[:, 2:]))
103230

104231
return loss / N
105232

106-
# N = preds.shape[0]
107-
# total_loss = 0.0
108-
# print(preds.shape)
109-
# print(targets.shape)
110-
# for pred, target in zip(preds, targets):
111-
# """
112-
# 逐个图像计算
113-
# pred: [S*S, (B*5+C)]
114-
# target: [S*S, (B*5+C)]
115-
# """
116-
# # 分类概率
117-
# pred_probs = pred[:, :self.C]
118-
# target_probs = target[:, :self.C]
119-
# # 置信度
120-
# pred_confidences = pred[:, self.C:(self.C + self.B)]
121-
# target_confidences = target[:, self.C:(self.C + self.B)]
122-
# # 边界框坐标
123-
# pred_bboxs = pred[:, (self.C + self.B):]
124-
# target_bboxs = target[:, (self.C + self.B):]
125-
#
126-
# for i in range(self.S * self.S):
127-
# """
128-
# 逐个网格计算
129-
# """
130-
# pred_single_probs = pred_probs[i]
131-
# target_single_probs = target_probs[i]
132-
#
133-
# pred_single_confidences = pred_confidences[i]
134-
# target_single_confidences = target_confidences[i]
135-
#
136-
# pred_single_bboxs = pred_bboxs[i]
137-
# target_single_bboxs = target_bboxs[i]
138-
#
139-
# # 是否存在置信度(如果存在,则target的置信度必然大于0)
140-
# is_obj = target_single_confidences[0] > 0
141-
# # 计算置信度损失 假定该网格不存在对象
142-
# total_loss += self.noobj * self.sum_squared_error(pred_single_confidences, target_single_confidences)
143-
# print(total_loss)
144-
# if is_obj:
145-
# print('i = %d' % (i))
146-
# # 如果存在
147-
# # 计算分类损失
148-
# total_loss += self.sum_squared_error(pred_single_probs, target_single_probs)
149-
# print(total_loss)
150-
#
151-
# # 计算所有预测边界框和标注边界框的IoU
152-
# pred_single_bboxs = pred_single_bboxs.reshape(-1, 4)
153-
# target_single_bboxs = target_single_bboxs.reshape(-1, 4)
154-
#
155-
# scores = self.iou(pred_single_bboxs, target_single_bboxs)
156-
# # 提取IoU最大的下标
157-
# bbox_idx = torch.argmax(scores)
158-
# # 计算置信度损失
159-
# total_loss += (1 - self.noobj) * \
160-
# self.sum_squared_error(pred_single_confidences[bbox_idx],
161-
# target_single_confidences[bbox_idx])
162-
# print(total_loss)
163-
# # 计算边界框损失
164-
# total_loss += self.coord * self.bbox_loss(pred_single_bboxs[bbox_idx].reshape(-1, 4),
165-
# target_single_bboxs[bbox_idx].reshape(-1, 4))
166-
# print(total_loss)
167-
#
168-
# print('done')
169-
#
170-
# return total_loss / N
171-
172233
def sum_squared_error(self, preds, targets):
173234
return torch.sum((preds - targets) ** 2)
174235

@@ -241,9 +302,6 @@ def load_data(data_root_dir, cate_list, S=7, B=2, C=20):
241302
for inputs, labels in data_loader:
242303
inputs = inputs
243304
labels = labels
244-
print(inputs.shape)
245-
print(labels.shape)
246-
247305
with torch.set_grad_enabled(False):
248306
outputs = model(inputs)
249307
loss = criterion(outputs, labels)

0 commit comments

Comments
 (0)