apply PyTorch AMP
continuous-integration/drone/push Build was killed Details

This commit is contained in:
yayoimizuha 2023-10-15 22:58:30 +09:00
parent 1d943d029b
commit af154fed8d
1 changed files with 3 additions and 1 deletions

View File

@ -1,6 +1,6 @@
from os import makedirs, environ from os import makedirs, environ
from torchinfo import summary from torchinfo import summary
from torch.nn import Linear, Sequential, Dropout, CrossEntropyLoss, Identity, ReLU from torch.nn import Linear, Sequential, Dropout, CrossEntropyLoss, Identity, ReLU, utils
from torchvision.transforms import Compose, RandomResizedCrop, RandomRotation, ToTensor, \ from torchvision.transforms import Compose, RandomResizedCrop, RandomRotation, ToTensor, \
RandomHorizontalFlip, \ RandomHorizontalFlip, \
Resize, RandomAutocontrast, InterpolationMode, RandomErasing, \ Resize, RandomAutocontrast, InterpolationMode, RandomErasing, \
@ -167,6 +167,8 @@ for epoch in range(epochs):
scaler.scale(loss).backward() scaler.scale(loss).backward()
scaler.unscale_(optimizer=optimizer) scaler.unscale_(optimizer=optimizer)
utils.clip_grad_norm_(model.parameters(), max_norm=.5)
# loss.backward() # loss.backward()
scaler.step(optimizer=optimizer) scaler.step(optimizer=optimizer)
scaler.update() scaler.update()