update
continuous-integration/drone/push Build is failing
Details
continuous-integration/drone/push Build is failing
Details
This commit is contained in:
parent
fb38a2bb39
commit
c303ecaf21
|
|
@ -15,7 +15,7 @@ steps:
|
||||||
from_secret: mount_command
|
from_secret: mount_command
|
||||||
commands:
|
commands:
|
||||||
- python -m pip install --upgrade pip
|
- python -m pip install --upgrade pip
|
||||||
- pip install torchsummary matplotlib
|
- pip install torchsummary matplotlib pytorch-metric-learning
|
||||||
- ls ./
|
- ls ./
|
||||||
- mkdir -p data
|
- mkdir -p data
|
||||||
- $mount_command
|
- $mount_command
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from torchvision.transforms import Compose, RandomResizedCrop, RandomRotation, T
|
||||||
RandomHorizontalFlip, \
|
RandomHorizontalFlip, \
|
||||||
Resize, CenterCrop, RandomAffine, GaussianBlur, RandomAutocontrast
|
Resize, CenterCrop, RandomAffine, GaussianBlur, RandomAutocontrast
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
|
||||||
matplotlib.use('Agg')
|
matplotlib.use('Agg')
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from numpy import arange, ndarray, ceil, full, uint8
|
from numpy import arange, ndarray, ceil, full, uint8
|
||||||
|
|
@ -20,11 +21,14 @@ from os.path import join
|
||||||
from torch.cuda import is_available
|
from torch.cuda import is_available
|
||||||
from torch import no_grad, save, Tensor
|
from torch import no_grad, save, Tensor
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pytorch_metric_learning.losses import ArcFaceLoss
|
||||||
|
from pytorch_metric_learning.distances import CosineSimilarity
|
||||||
|
from pytorch_metric_learning.regularizers import RegularFaceRegularizer
|
||||||
|
|
||||||
device = 'cuda' if is_available() else 'cpu'
|
device = 'cuda' if is_available() else 'cpu'
|
||||||
transform = {
|
transform = {
|
||||||
'train': Compose([
|
'train': Compose([
|
||||||
Resize(448),
|
Resize(350),
|
||||||
CenterCrop(224),
|
CenterCrop(224),
|
||||||
RandomHorizontalFlip(p=0.1),
|
RandomHorizontalFlip(p=0.1),
|
||||||
GaussianBlur(kernel_size=3),
|
GaussianBlur(kernel_size=3),
|
||||||
|
|
@ -35,7 +39,7 @@ transform = {
|
||||||
RandomResizedCrop(size=224, scale=(0.7, 1.0), ratio=(1.0, 1.0), antialias=True)
|
RandomResizedCrop(size=224, scale=(0.7, 1.0), ratio=(1.0, 1.0), antialias=True)
|
||||||
]),
|
]),
|
||||||
'val': Compose([
|
'val': Compose([
|
||||||
Resize(448),
|
Resize(350),
|
||||||
CenterCrop(224),
|
CenterCrop(224),
|
||||||
ToTensor(),
|
ToTensor(),
|
||||||
RandomAffine(scale=(0.8, 0.8), degrees=(0, 0)),
|
RandomAffine(scale=(0.8, 0.8), degrees=(0, 0)),
|
||||||
|
|
@ -48,8 +52,8 @@ image_folder = {
|
||||||
}
|
}
|
||||||
|
|
||||||
dataloader = {
|
dataloader = {
|
||||||
'train': DataLoader(image_folder['train'], batch_size=64, shuffle=True, num_workers=3),
|
'train': DataLoader(image_folder['train'], batch_size=32, shuffle=True, num_workers=3),
|
||||||
'val': DataLoader(image_folder['val'], batch_size=64, shuffle=False, num_workers=3)
|
'val': DataLoader(image_folder['val'], batch_size=32, shuffle=False, num_workers=3)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -153,7 +157,8 @@ for epoch in range(epochs):
|
||||||
|
|
||||||
for count, (images, labels) in enumerate(tqdm(dataloader['train'])):
|
for count, (images, labels) in enumerate(tqdm(dataloader['train'])):
|
||||||
if count == 1:
|
if count == 1:
|
||||||
image_pallets = plot_dataset(dataloader=(images, labels), col_len=8, label_text=image_folder['train'].classes)
|
image_pallets = plot_dataset(dataloader=(images, labels), col_len=6,
|
||||||
|
label_text=image_folder['train'].classes)
|
||||||
image_pallets.save(join(save_dir, 'pallets', str(epoch), 'pallet.jpg'))
|
image_pallets.save(join(save_dir, 'pallets', str(epoch), 'pallet.jpg'))
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
images = images.to(device)
|
images = images.to(device)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue