@@ -37,6 +37,132 @@ def forward(self, preds, targets):
37
37
:param targets: (N, S*S, (B*5+C))
38
38
:return:
39
39
"""
40
+ # print('loss 1', self._process1(preds, targets))
41
+ # print('loss 2', self._process2(preds, targets))
42
+ return self ._process3 (preds , targets )
43
+
44
def _process1(self, preds, targets):
    """
    Reference loss computation that visits every image and every grid cell in turn.

    Each per-image item is read with this column layout (taken from the slicing below):
    columns [0, C) are class probabilities, columns [C, C+B) are box confidences,
    and the remaining columns are box coordinates viewed as rows of 4 values.

    :param preds: batch of predictions; iterated image by image
    :param targets: batch of labels, sliced exactly like preds
    :return: accumulated loss divided by preds.shape[0]
    """
    batch_size = preds.shape[0]
    loss = 0.0
    for image_pred, image_target in zip(preds, targets):
        # Column boundary between the confidence slots and the box coordinates.
        conf_end = self.C + self.B

        for cell in range(self.S * self.S):
            pred_row = image_pred[cell]
            target_row = image_target[cell]

            pred_conf = pred_row[self.C:conf_end]
            target_conf = target_row[self.C:conf_end]

            # Every cell contributes a confidence term weighted by self.noobj,
            # as if the cell held no object.
            loss += self.noobj * self.sum_squared_error(pred_conf, target_conf)

            # A cell counts as containing an object when the first target confidence is positive.
            if not target_conf[0] > 0:
                continue

            # Classification term for cells that contain an object.
            loss += self.sum_squared_error(pred_row[:self.C], target_row[:self.C])

            # Score every predicted box against the labelled boxes of this cell
            # and keep the index with the highest score.
            pred_boxes = pred_row[conf_end:].reshape(-1, 4)
            target_boxes = target_row[conf_end:].reshape(-1, 4)
            best = torch.argmax(self.iou(pred_boxes, target_boxes))

            # Top up the confidence term of the selected box so that its total weight becomes 1.
            loss += (1 - self.noobj) * self.sum_squared_error(pred_conf[best], target_conf[best])

            # Coordinate term of the selected box, weighted by self.coord.
            loss += self.coord * self.bbox_loss(pred_boxes[best].reshape(-1, 4),
                                                target_boxes[best].reshape(-1, 4))

    return loss / batch_size
103
+
104
def _process2(self, preds, targets):
    """
    Loss computation that loops over images but handles all grid cells of an image at once.

    Each per-image item is read with this column layout (taken from the slicing below):
    columns [0, C) are class probabilities, columns [C, C+B) are box confidences,
    and the remaining columns are reshaped to [S*S, B, 4] box coordinates.

    :param preds: batch of predictions; iterated image by image
    :param targets: batch of labels, sliced exactly like preds
    :return: accumulated loss divided by preds.shape[0]
    """
    batch_size = preds.shape[0]
    loss = 0.0
    for image_pred, image_target in zip(preds, targets):
        num_cells = self.S * self.S
        conf_end = self.C + self.B

        # Class probabilities: [S*S, C]
        cls_pred = image_pred[:, :self.C]
        cls_target = image_target[:, :self.C]
        # Confidences: [S*S, B]
        conf_pred = image_pred[:, self.C:conf_end]
        conf_target = image_target[:, self.C:conf_end]
        # Box coordinates: [S*S, B*4] -> [S*S, B, 4]
        box_pred = image_pred[:, conf_end:].reshape(num_cells, self.B, 4)
        box_target = image_target[:, conf_end:].reshape(num_cells, self.B, 4)

        # Confidence term over every cell and every box slot, weighted by self.noobj.
        loss += self.noobj * self.sum_squared_error(conf_pred, conf_target)

        # Score all boxes in one call: [S*S*B, 4] in, result viewed as [S*S, B].
        scores = self.iou(box_pred.reshape(-1, 4), box_target.reshape(-1, 4))
        scores = scores.reshape(num_cells, self.B)
        # Per cell, the box slot with the highest score: [S*S]
        best_slot = scores.argmax(dim=1)
        rows = range(len(best_slot))

        # Cells whose target class probabilities sum to a positive value hold an object: [S*S]
        has_obj = cls_target.sum(dim=1) > 0

        # Pick the best slot of every cell, then keep only the cells that hold an object.
        obj_conf_pred = conf_pred[rows, best_slot][has_obj]
        obj_conf_target = conf_target[rows, best_slot][has_obj]
        obj_box_pred = box_pred[rows, best_slot][has_obj]
        obj_box_target = box_target[rows, best_slot][has_obj]

        # Top up the confidence term of the selected boxes so that their total weight becomes 1.
        loss += (1 - self.noobj) * self.sum_squared_error(obj_conf_pred, obj_conf_target)
        # Classification term for the cells that hold an object.
        loss += self.sum_squared_error(cls_pred[has_obj], cls_target[has_obj])
        # Coordinate term of the selected boxes, weighted by self.coord.
        loss += self.coord * self.bbox_loss(obj_box_pred, obj_box_target)

    return loss / batch_size
164
+
165
+ def _process3 (self , preds , targets ):
40
166
N = preds .shape [0 ]
41
167
## 预测
42
168
# 提取每个网格的分类概率
@@ -82,7 +208,7 @@ def forward(self, preds, targets):
82
208
# print(top_pred_bboxs.shape)
83
209
84
210
# 选取存在目标的网格
85
- obj_idxs = torch .sum (target_probs , dim = 1 ) == 1
211
+ obj_idxs = torch .sum (target_probs , dim = 1 ) > 0
86
212
# print(obj_idxs)
87
213
88
214
obj_pred_confidences = top_pred_confidences [obj_idxs ]
@@ -98,77 +224,12 @@ def forward(self, preds, targets):
98
224
## 计算分类概率损失
99
225
loss += self .sum_squared_error (obj_pred_probs , obj_target_probs )
100
226
## 计算边界框坐标损失
101
- loss += self .sum_squared_error (obj_pred_bboxs [:, :2 ], obj_target_bboxs [:, :2 ])
102
- loss += self .sum_squared_error (torch .sqrt (obj_pred_bboxs [:, 2 :]), torch .sqrt (obj_target_bboxs [:, 2 :]))
227
+ loss += self .coord * self .sum_squared_error (obj_pred_bboxs [:, :2 ], obj_target_bboxs [:, :2 ])
228
+ loss += self .coord * self .sum_squared_error (torch .sqrt (obj_pred_bboxs [:, 2 :]),
229
+ torch .sqrt (obj_target_bboxs [:, 2 :]))
103
230
104
231
return loss / N
105
232
106
- # N = preds.shape[0]
107
- # total_loss = 0.0
108
- # print(preds.shape)
109
- # print(targets.shape)
110
- # for pred, target in zip(preds, targets):
111
- # """
112
- # 逐个图像计算
113
- # pred: [S*S, (B*5+C)]
114
- # target: [S*S, (B*5+C)]
115
- # """
116
- # # 分类概率
117
- # pred_probs = pred[:, :self.C]
118
- # target_probs = target[:, :self.C]
119
- # # 置信度
120
- # pred_confidences = pred[:, self.C:(self.C + self.B)]
121
- # target_confidences = target[:, self.C:(self.C + self.B)]
122
- # # 边界框坐标
123
- # pred_bboxs = pred[:, (self.C + self.B):]
124
- # target_bboxs = target[:, (self.C + self.B):]
125
- #
126
- # for i in range(self.S * self.S):
127
- # """
128
- # 逐个网格计算
129
- # """
130
- # pred_single_probs = pred_probs[i]
131
- # target_single_probs = target_probs[i]
132
- #
133
- # pred_single_confidences = pred_confidences[i]
134
- # target_single_confidences = target_confidences[i]
135
- #
136
- # pred_single_bboxs = pred_bboxs[i]
137
- # target_single_bboxs = target_bboxs[i]
138
- #
139
- # # 是否存在置信度(如果存在,则target的置信度必然大于0)
140
- # is_obj = target_single_confidences[0] > 0
141
- # # 计算置信度损失 假定该网格不存在对象
142
- # total_loss += self.noobj * self.sum_squared_error(pred_single_confidences, target_single_confidences)
143
- # print(total_loss)
144
- # if is_obj:
145
- # print('i = %d' % (i))
146
- # # 如果存在
147
- # # 计算分类损失
148
- # total_loss += self.sum_squared_error(pred_single_probs, target_single_probs)
149
- # print(total_loss)
150
- #
151
- # # 计算所有预测边界框和标注边界框的IoU
152
- # pred_single_bboxs = pred_single_bboxs.reshape(-1, 4)
153
- # target_single_bboxs = target_single_bboxs.reshape(-1, 4)
154
- #
155
- # scores = self.iou(pred_single_bboxs, target_single_bboxs)
156
- # # 提取IoU最大的下标
157
- # bbox_idx = torch.argmax(scores)
158
- # # 计算置信度损失
159
- # total_loss += (1 - self.noobj) * \
160
- # self.sum_squared_error(pred_single_confidences[bbox_idx],
161
- # target_single_confidences[bbox_idx])
162
- # print(total_loss)
163
- # # 计算边界框损失
164
- # total_loss += self.coord * self.bbox_loss(pred_single_bboxs[bbox_idx].reshape(-1, 4),
165
- # target_single_bboxs[bbox_idx].reshape(-1, 4))
166
- # print(total_loss)
167
- #
168
- # print('done')
169
- #
170
- # return total_loss / N
171
-
172
233
def sum_squared_error(self, preds, targets):
    """
    Sum of the squared element-wise differences between preds and targets.

    :param preds: predicted values
    :param targets: reference values, subtracted from preds
    :return: torch.sum over every squared difference
    """
    residual = preds - targets
    return torch.sum(torch.pow(residual, 2))
174
235
@@ -241,9 +302,6 @@ def load_data(data_root_dir, cate_list, S=7, B=2, C=20):
241
302
for inputs , labels in data_loader :
242
303
inputs = inputs
243
304
labels = labels
244
- print (inputs .shape )
245
- print (labels .shape )
246
-
247
305
with torch .set_grad_enabled (False ):
248
306
outputs = model (inputs )
249
307
loss = criterion (outputs , labels )
0 commit comments