From c303ecaf21aa8ae0d7d854ede7aacf05856624f3 Mon Sep 17 00:00:00 2001 From: yayoimizuha Date: Thu, 4 May 2023 19:45:59 +0900 Subject: [PATCH] update --- .drone.yml | 2 +- resnet_finetune_vggface.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/.drone.yml b/.drone.yml index b1e1c2f..b2624eb 100644 --- a/.drone.yml +++ b/.drone.yml @@ -15,7 +15,7 @@ steps: from_secret: mount_command commands: - python -m pip install --upgrade pip - - pip install torchsummary matplotlib + - pip install torchsummary matplotlib pytorch-metric-learning - ls ./ - mkdir -p data - $mount_command diff --git a/resnet_finetune_vggface.py b/resnet_finetune_vggface.py index 4f4e0e9..20891a0 100644 --- a/resnet_finetune_vggface.py +++ b/resnet_finetune_vggface.py @@ -5,6 +5,7 @@ from torchvision.transforms import Compose, RandomResizedCrop, RandomRotation, T RandomHorizontalFlip, \ Resize, CenterCrop, RandomAffine, GaussianBlur, RandomAutocontrast import matplotlib + matplotlib.use('Agg') import matplotlib.pyplot as plt from numpy import arange, ndarray, ceil, full, uint8 @@ -20,11 +21,14 @@ from os.path import join from torch.cuda import is_available from torch import no_grad, save, Tensor from datetime import datetime +from pytorch_metric_learning.losses import ArcFaceLoss +from pytorch_metric_learning.distances import CosineSimilarity +from pytorch_metric_learning.regularizers import RegularFaceRegularizer device = 'cuda' if is_available() else 'cpu' transform = { 'train': Compose([ - Resize(448), + Resize(350), CenterCrop(224), RandomHorizontalFlip(p=0.1), GaussianBlur(kernel_size=3), @@ -35,7 +39,7 @@ transform = { RandomResizedCrop(size=224, scale=(0.7, 1.0), ratio=(1.0, 1.0), antialias=True) ]), 'val': Compose([ - Resize(448), + Resize(350), CenterCrop(224), ToTensor(), RandomAffine(scale=(0.8, 0.8), degrees=(0, 0)), @@ -48,8 +52,8 @@ image_folder = { } dataloader = { - 'train': DataLoader(image_folder['train'], batch_size=64, shuffle=True, num_workers=3), - 'val': DataLoader(image_folder['val'], batch_size=64, shuffle=False, num_workers=3) + 'train': DataLoader(image_folder['train'], batch_size=32, shuffle=True, num_workers=3), + 'val': DataLoader(image_folder['val'], batch_size=32, shuffle=False, num_workers=3) } @@ -153,7 +157,8 @@ for epoch in range(epochs): for count, (images, labels) in enumerate(tqdm(dataloader['train'])): if count == 1: - image_pallets = plot_dataset(dataloader=(images, labels), col_len=8, label_text=image_folder['train'].classes) + image_pallets = plot_dataset(dataloader=(images, labels), col_len=6, + label_text=image_folder['train'].classes) image_pallets.save(join(save_dir, 'pallets', str(epoch), 'pallet.jpg')) optimizer.zero_grad() images = images.to(device)