apply PyTorch AMP
continuous-integration/drone/push Build was killed
Details
continuous-integration/drone/push Build was killed
Details
This commit is contained in:
parent
1d943d029b
commit
af154fed8d
|
|
@ -1,6 +1,6 @@
|
||||||
from os import makedirs, environ
|
from os import makedirs, environ
|
||||||
from torchinfo import summary
|
from torchinfo import summary
|
||||||
from torch.nn import Linear, Sequential, Dropout, CrossEntropyLoss, Identity, ReLU
|
from torch.nn import Linear, Sequential, Dropout, CrossEntropyLoss, Identity, ReLU, utils
|
||||||
from torchvision.transforms import Compose, RandomResizedCrop, RandomRotation, ToTensor, \
|
from torchvision.transforms import Compose, RandomResizedCrop, RandomRotation, ToTensor, \
|
||||||
RandomHorizontalFlip, \
|
RandomHorizontalFlip, \
|
||||||
Resize, RandomAutocontrast, InterpolationMode, RandomErasing, \
|
Resize, RandomAutocontrast, InterpolationMode, RandomErasing, \
|
||||||
|
|
@ -167,6 +167,8 @@ for epoch in range(epochs):
|
||||||
|
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
scaler.unscale_(optimizer=optimizer)
|
scaler.unscale_(optimizer=optimizer)
|
||||||
|
|
||||||
|
utils.clip_grad_norm_(model.parameters(), max_norm=.5)
|
||||||
# loss.backward()
|
# loss.backward()
|
||||||
scaler.step(optimizer=optimizer)
|
scaler.step(optimizer=optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue