Skip to content

Commit 28b2c09

Browse files
committed
perf(train): 更新训练参数
1 parent 4583c12 commit 28b2c09

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

py/lib/train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ def train_model(data_loader, model, criterion, optimizer, lr_scheduler, num_epoc
115115
model = YOLO_v1(S=S, B=B, C=C)
116116
model = model.to(device)
117117

118-
criterion = MultiPartLoss(S=S, B=B, C=C)
118+
criterion = MultiPartLoss(448, 448, S=S, B=B, C=C)
119119
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
120-
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.9)
120+
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.96)
121121

122122
file.make_dir('../models')
123123
train_model(data_loader, model, criterion, optimizer, lr_scheduler, num_epochs=50, device=device)

0 commit comments

Comments
 (0)