From d78e67cbb4b779a84ddd43fac436c891896fe83e Mon Sep 17 00:00:00 2001 From: yayoimizuha Date: Sun, 15 Oct 2023 23:18:55 +0900 Subject: [PATCH] apply PyTorch AMP --- finetune/facenet_transfer_learning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/finetune/facenet_transfer_learning.py b/finetune/facenet_transfer_learning.py index 7a25861..76054c7 100644 --- a/finetune/facenet_transfer_learning.py +++ b/finetune/facenet_transfer_learning.py @@ -19,7 +19,7 @@ from PIL import Image, ImageDraw, ImageFont from settings import datadir from os.path import join from torch.cuda import is_available -from torch import no_grad, save, Tensor, load, device, float16 +from torch import no_grad, save, Tensor, load, device, float16, float32 from datetime import datetime from distutils.util import strtobool @@ -126,7 +126,7 @@ optimizer = Adam(params=[ {'params': model_gpu[1].parameters(), 'lr': 1e-3}, ]) -scaler = GradScaler(init_scale=4096) +scaler = GradScaler() # model, optimizer = optimize(model=model, optimizer=optimizer) scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.9) @@ -158,7 +158,7 @@ for epoch in range(epochs): label_text=image_folder['train'].classes) image_pallets.save(join(save_dir, 'pallets', str(epoch) + '_train.jpg')) optimizer.zero_grad() - with autocast(dtype=float16, enabled=True): + with autocast(dtype=float32, enabled=True): images = images.to(device) labels = labels.to(device) outputs = model(images)