apply PyTorch AMP
continuous-integration/drone/push Build is passing Details

This commit is contained in:
yayoimizuha 2023-10-15 23:45:40 +09:00
parent d78e67cbb4
commit 1e7cf01e0d
1 changed files with 2 additions and 2 deletions

View File

@ -166,9 +166,9 @@ for epoch in range(epochs):
train_loss += loss.item() train_loss += loss.item()
scaler.scale(loss).backward() scaler.scale(loss).backward()
scaler.unscale_(optimizer=optimizer)
utils.clip_grad_norm_(model.parameters(), max_norm=.1) # scaler.unscale_(optimizer=optimizer)
# utils.clip_grad_norm_(model.parameters(), max_norm=.1)
# loss.backward() # loss.backward()
scaler.step(optimizer=optimizer) scaler.step(optimizer=optimizer)
scaler.update() scaler.update()