apply PyTorch AMP
continuous-integration/drone/push Build was killed
Details
continuous-integration/drone/push Build was killed
Details
This commit is contained in:
parent
1d943d029b
commit
af154fed8d
|
|
@ -1,6 +1,6 @@
|
|||
from os import makedirs, environ
|
||||
from torchinfo import summary
|
||||
from torch.nn import Linear, Sequential, Dropout, CrossEntropyLoss, Identity, ReLU
|
||||
from torch.nn import Linear, Sequential, Dropout, CrossEntropyLoss, Identity, ReLU, utils
|
||||
from torchvision.transforms import Compose, RandomResizedCrop, RandomRotation, ToTensor, \
|
||||
RandomHorizontalFlip, \
|
||||
Resize, RandomAutocontrast, InterpolationMode, RandomErasing, \
|
||||
|
|
@ -167,6 +167,8 @@ for epoch in range(epochs):
|
|||
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(optimizer=optimizer)
|
||||
|
||||
utils.clip_grad_norm_(model.parameters(), max_norm=.5)
|
||||
# loss.backward()
|
||||
scaler.step(optimizer=optimizer)
|
||||
scaler.update()
|
||||
|
|
|
|||
Loading…
Reference in New Issue