apply PyTorch AMP
continuous-integration/drone/push Build was killed Details

This commit is contained in:
yayoimizuha 2023-10-15 22:43:07 +09:00
parent 143510a83f
commit 4d9595fe5e
1 changed files with 2 additions and 2 deletions

View File

@ -126,7 +126,7 @@ optimizer = Adam(params=[
{'params': model_gpu[1].parameters(), 'lr': 1e-3},
])
scaler = GradScaler()
scaler = GradScaler(init_scale=4096)
# model, optimizer = optimize(model=model, optimizer=optimizer)
scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.9)
@ -158,7 +158,7 @@ for epoch in range(epochs):
label_text=image_folder['train'].classes)
image_pallets.save(join(save_dir, 'pallets', str(epoch) + '_train.jpg'))
optimizer.zero_grad()
with autocast(dtype=float16,enabled=True):
with autocast(dtype=float16, enabled=True):
images = images.to(device)
labels = labels.to(device)
outputs = model(images)