update
continuous-integration/drone/push Build was killed Details

This commit is contained in:
yayoimizuha 2023-05-04 20:14:50 +09:00
parent 5f7fd1eb1f
commit 7ca0bfa03c
2 changed files with 5 additions and 3 deletions

View File

@ -22,4 +22,4 @@ steps:
- mkdir -p data - mkdir -p data
- $mount_command - $mount_command
- ls data/ - ls data/
- python resnet_finetune_vggface.py - CI=True python resnet_finetune_vggface.py

View File

@ -1,4 +1,4 @@
from os import makedirs from os import makedirs, environ
from torchvision.models import ResNet50_Weights, resnet50 from torchvision.models import ResNet50_Weights, resnet50
from torch.nn import Linear from torch.nn import Linear
from torchvision.transforms import Compose, RandomResizedCrop, RandomRotation, ToTensor, \ from torchvision.transforms import Compose, RandomResizedCrop, RandomRotation, ToTensor, \
@ -24,7 +24,9 @@ from datetime import datetime
from pytorch_metric_learning.losses import ArcFaceLoss from pytorch_metric_learning.losses import ArcFaceLoss
from pytorch_metric_learning.distances import CosineSimilarity from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.regularizers import RegularFaceRegularizer from pytorch_metric_learning.regularizers import RegularFaceRegularizer
from distutils.util import strtobool
CI = bool(strtobool(environ['CI']))
device = 'cuda' if is_available() else 'cpu' device = 'cuda' if is_available() else 'cpu'
transform = { transform = {
'train': Compose([ 'train': Compose([
@ -155,7 +157,7 @@ for epoch in range(epochs):
model_gpu.train() model_gpu.train()
makedirs(join(save_dir, 'pallets', str(epoch)), exist_ok=True) makedirs(join(save_dir, 'pallets', str(epoch)), exist_ok=True)
for count, (images, labels) in enumerate(tqdm(dataloader['train'])): for count, (images, labels) in enumerate(tqdm(dataloader['train'], disable=CI)):
if count == 1: if count == 1:
image_pallets = plot_dataset(dataloader=(images, labels), col_len=6, image_pallets = plot_dataset(dataloader=(images, labels), col_len=6,
label_text=image_folder['train'].classes) label_text=image_folder['train'].classes)