Skip to content

Commit a59ac19

Browse files
committed
Add args.validation_interval to config. Use args.validation_interval to skip validation and stats logging at every epoch.
1 parent 9fd9789 commit a59ac19

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

rfdetr/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ class TrainConfig(BaseModel):
130130
ema_tau: int = 100
131131
lr_drop: int = 100
132132
checkpoint_interval: int = 10
133+
validation_interval: int = 1
133134
warmup_epochs: float = 0.0
134135
lr_vit_layer_decay: float = 0.8
135136
lr_component_decay: float = 0.7

rfdetr/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,11 @@ def lr_lambda(current_step: int):
368368

369369
utils.save_on_master(weights, checkpoint_path)
370370

371+
# Run validation and log stats every validation_interval
372+
if (epoch + 1) % args.validation_interval != 0:
373+
print(f"Skipping epoch {epoch} for evaluation. validation_interval = {args.validation_interval}")
374+
continue
375+
371376
with torch.inference_mode():
372377
test_stats, coco_evaluator = evaluate(
373378
model, criterion, postprocess, data_loader_val, base_ds, device, args=args
@@ -946,6 +951,7 @@ def populate_args(
946951
output_dir='output',
947952
dont_save_weights=False,
948953
checkpoint_interval=10,
954+
validation_interval=1,
949955
seed=42,
950956
resume='',
951957
start_epoch=0,
@@ -1053,6 +1059,7 @@ def populate_args(
10531059
output_dir=output_dir,
10541060
dont_save_weights=dont_save_weights,
10551061
checkpoint_interval=checkpoint_interval,
1062+
validation_interval=validation_interval,
10561063
seed=seed,
10571064
resume=resume,
10581065
start_epoch=start_epoch,

0 commit comments

Comments
 (0)