Skip to content

Commit 75d2760

Browse files
committed
perf(file): 更新目录检查和创建功能
1 parent b86fba7 commit 75d2760

File tree

3 files changed

+13
-11
lines changed

3 files changed

+13
-11
lines changed

py/batch_detect.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@
3131
dst_img_dir = os.path.join(dst_root_dir, 'imgs')
3232
tmp_json_dir = os.path.join(dst_root_dir, '.tmp_files')
3333

34-
file.check_dir(dst_root_dir)
35-
file.check_dir(dst_target_dir)
36-
file.check_dir(dst_pred_dir)
37-
file.check_dir(dst_img_dir)
38-
file.check_dir(tmp_json_dir)
34+
file.make_dir(dst_root_dir, is_rm=True)
35+
file.make_dir(dst_target_dir, is_rm=True)
36+
file.make_dir(dst_pred_dir, is_rm=True)
37+
file.make_dir(dst_img_dir, is_rm=True)
38+
file.make_dir(tmp_json_dir, is_rm=True)
3939

4040

4141
def get_transform():

py/lib/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ def train_model(data_loader, model, criterion, optimizer, lr_scheduler, num_epoc
9292
best_model_weights = copy.deepcopy(model.cpu().state_dict())
9393
model = model.to(device)
9494

95-
file.check_dir('../models')
9695
file.save_model(best_model_weights, '../models/checkpoint_yolo_v1_%d.pth' % (epoch))
9796
print('save model')
9897

@@ -118,6 +117,7 @@ def train_model(data_loader, model, criterion, optimizer, lr_scheduler, num_epoc
118117

119118
criterion = MultiPartLoss(S=S, B=B, C=C)
120119
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
121-
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.96)
120+
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.9)
122121

122+
file.make_dir('../models')
123123
train_model(data_loader, model, criterion, optimizer, lr_scheduler, num_epochs=50, device=device)

py/lib/utils/file.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
import glob
1717

1818

19-
def check_dir(data_dir):
20-
if os.path.exists(data_dir):
21-
shutil.rmtree(data_dir)
22-
os.mkdir(data_dir)
19+
def make_dir(data_dir, is_rm=False):
20+
if is_rm:
21+
if os.path.exists(data_dir):
22+
shutil.rmtree(data_dir)
23+
if not os.path.exists(data_dir):
24+
os.mkdir(data_dir)
2325

2426

2527
def parse_location_xml(xml_path):

0 commit comments

Comments
 (0)