apply PyTorch AMP
continuous-integration/drone/push Build was killed Details

This commit is contained in:
yayoimizuha 2023-10-15 22:31:41 +09:00
parent eeec9ba831
commit 143510a83f
2 changed files with 9 additions and 10 deletions

View File

@ -20,8 +20,8 @@ steps:
- pip install -q torchsummary matplotlib pytorch-metric-learning torchinfo torch torchvision tqdm Pillow facenet-pytorch - pip install -q torchsummary matplotlib pytorch-metric-learning torchinfo torch torchvision tqdm Pillow facenet-pytorch
- apt update -qq - apt update -qq
- apt install fonts-noto-cjk-extra -y -qq - apt install fonts-noto-cjk-extra -y -qq
- ls ./ # - ls ./
- mkdir -p data - mkdir -p data
- $mount_command - $mount_command
- ls data/ # - ls data/
- CI=True python finetune/facenet_transfer_learning.py - CI=True python finetune/facenet_transfer_learning.py

View File

@ -158,13 +158,11 @@ for epoch in range(epochs):
label_text=image_folder['train'].classes) label_text=image_folder['train'].classes)
image_pallets.save(join(save_dir, 'pallets', str(epoch) + '_train.jpg')) image_pallets.save(join(save_dir, 'pallets', str(epoch) + '_train.jpg'))
optimizer.zero_grad() optimizer.zero_grad()
with autocast(dtype=float16): with autocast(dtype=float16,enabled=True):
images = images.to(device) images = images.to(device)
labels = labels.to(device) labels = labels.to(device)
outputs = model(images)
outputs = model(images) loss = criterion(outputs, labels)
loss = criterion(outputs, labels)
train_loss += loss.item() train_loss += loss.item()
scaler.scale(loss).backward() scaler.scale(loss).backward()
@ -187,8 +185,9 @@ for epoch in range(epochs):
image_pallets = plot_dataset(dataloader=(images, labels), col_len=6, image_pallets = plot_dataset(dataloader=(images, labels), col_len=6,
label_text=image_folder['train'].classes) label_text=image_folder['train'].classes)
image_pallets.save(join(save_dir, 'pallets', str(epoch) + '_val.jpg')) image_pallets.save(join(save_dir, 'pallets', str(epoch) + '_val.jpg'))
images = images.to(device) with autocast(dtype=float16, enabled=True):
labels = labels.to(device) images = images.to(device)
labels = labels.to(device)
outputs = model_gpu(images) outputs = model_gpu(images)
loss = criterion(outputs, labels) loss = criterion(outputs, labels)
val_loss += loss.item() val_loss += loss.item()