From eeec9ba8313597521226c3efb8059f26f3a4c358 Mon Sep 17 00:00:00 2001 From: yayoimizuha Date: Sun, 15 Oct 2023 22:25:29 +0900 Subject: [PATCH] apply PyTorch AMP --- .drone.yml | 2 +- finetune/facenet_transfer_learning.py | 20 +++++-- inference_all.py | 84 +++++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 7 deletions(-) create mode 100644 inference_all.py diff --git a/.drone.yml b/.drone.yml index b544777..2ca7c02 100644 --- a/.drone.yml +++ b/.drone.yml @@ -24,4 +24,4 @@ steps: - mkdir -p data - $mount_command - ls data/ - - # CI=True python finetune/facenet_transfer_learning.py \ No newline at end of file + - CI=True python finetune/facenet_transfer_learning.py \ No newline at end of file diff --git a/finetune/facenet_transfer_learning.py b/finetune/facenet_transfer_learning.py index c4003d0..b543b5b 100644 --- a/finetune/facenet_transfer_learning.py +++ b/finetune/facenet_transfer_learning.py @@ -1,9 +1,9 @@ from os import makedirs, environ from torchinfo import summary -from torch.nn import Linear, Sequential, Dropout, CrossEntropyLoss, Identity, ReLU +from torch.nn import Linear, Sequential, Dropout, CrossEntropyLoss, Identity, ReLU from torchvision.transforms import Compose, RandomResizedCrop, RandomRotation, ToTensor, \ RandomHorizontalFlip, \ - Resize, RandomAutocontrast, InterpolationMode, RandomErasing, \ + Resize, RandomAutocontrast, InterpolationMode, RandomErasing, \ RandomEqualize, RandomPosterize, RandomPerspective, RandomGrayscale import matplotlib @@ -13,12 +13,13 @@ from numpy import arange, ndarray, ceil, full, uint8 from torch.optim import SGD, Adam, lr_scheduler from torchvision.datasets import ImageFolder from torch.utils.data import DataLoader +from torch.cuda.amp import GradScaler, autocast from tqdm import tqdm from PIL import Image, ImageDraw, ImageFont from settings import datadir from os.path import join from torch.cuda import is_available -from torch import no_grad, save, Tensor, load, device +from torch import no_grad, save, Tensor, load, device, float16 from datetime import datetime from distutils.util import strtobool @@ -125,6 +126,8 @@ optimizer = Adam(params=[ {'params': model_gpu[1].parameters(), 'lr': 1e-3}, ]) +scaler = GradScaler() + # model, optimizer = optimize(model=model, optimizer=optimizer) scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.9) epochs = 100 @@ -155,7 +158,8 @@ for epoch in range(epochs): label_text=image_folder['train'].classes) image_pallets.save(join(save_dir, 'pallets', str(epoch) + '_train.jpg')) optimizer.zero_grad() - images = images.to(device) + with autocast(dtype=float16): + images = images.to(device) labels = labels.to(device) outputs = model(images) @@ -163,8 +167,12 @@ for epoch in range(epochs): loss = criterion(outputs, labels) train_loss += loss.item() - loss.backward() - optimizer.step() + scaler.scale(loss).backward() + scaler.unscale_(optimizer=optimizer) + # loss.backward() + scaler.step(optimizer=optimizer) + scaler.update() + # optimizer.step() predicted = outputs.max(1)[1] train_acc += (predicted == labels).sum() diff --git a/inference_all.py b/inference_all.py new file mode 100644 index 0000000..9d93c1f --- /dev/null +++ b/inference_all.py @@ -0,0 +1,84 @@ +import time +from os import makedirs +from os.path import join, exists, basename +from shutil import rmtree, copyfile +from more_itertools import chunked +from torch import load, no_grad, device, randn, jit, float64, float32, float16 +from torch.cuda import is_available +from torch.utils.data import DataLoader +from torchvision.datasets import ImageFolder +from torchvision.transforms import Compose, ToTensor, Resize, CenterCrop +from torchinfo import summary +from tqdm import tqdm +from settings import datadir +from concurrent.futures import ThreadPoolExecutor +from pandas import DataFrame +from seaborn import heatmap, color_palette, set_palette +from matplotlib import pyplot +from japanize_matplotlib import japanize +from torch_tensorrt import compile + +device = device('cuda' if is_available() else 'cpu') +# device = 'cpu' +print(f'device: {device}') +model_path: str = join(datadir(), 'artifact', 'facenet-tl_2023-10-15 07:08:44.537055', 'model.pth') +print(f'model path: {model_path}') +input_shape: int = 256 +batch_size = 64 +source_dir = join(datadir(), 'face_cropped') +print(f'judge file: {source_dir}') +dest_dir = join(datadir(), 'infer_all') +image_class = ImageFolder(root=join(datadir(), 'dataset', 'train')).classes +with open(join(datadir(), 'class_text'), mode='w') as f: + f.write(str(image_class)) +rmtree(dest_dir, ignore_errors=True) +makedirs(dest_dir) + +transform = Compose([Resize(size=256), ToTensor()]) +image_folder = ImageFolder(root=source_dir, transform=transform) +dataloader = DataLoader(image_folder, batch_size=batch_size, shuffle=False, num_workers=8) + +model = load(f=model_path) +model = model.to(device) +model.eval() +for layer in model.parameters(): + layer.requires_grad = False + +# summary(model=model, input_size=(batch_size, 3, input_shape, input_shape), device=device) + +if exists(join(datadir(), 'infer_all_torch_trt.ts')): + trt_model = jit.load(join(datadir(), 'infer_all_torch_trt.ts')) +else: + + example_input = randn(size=[batch_size, 3, 256, 256]).float().cuda() + traced_script_module = jit.trace(model, example_inputs=[example_input]) + trt_model = compile(module=traced_script_module, inputs=[example_input], + enabled_precisions={float32, float16}, + truncate_long_and_double=True) + jit.save(trt_model, join(datadir(), 'infer_all_torch_trt.ts')) + +# heatmap_df = DataFrame(index=image_class, columns=image_folder.classes).fillna(0) +begin = time.time() +with ThreadPoolExecutor(max_workers=60) as executor, no_grad(): + for (images, labels), fileinfo in zip(tqdm(dataloader), chunked(image_folder.imgs, n=batch_size)): + # print(labels, fileinfo) + res = trt_model(images.to(device)) + for name, (filename, person) in zip(res.to(device).max(1).indices.tolist(), fileinfo): + if not exists(join(dest_dir, image_class[name])): + makedirs(join(dest_dir, image_class[name]), exist_ok=True) + # print(name, filename, person) + # copyfile(src=filename, + # dst=join(dest_dir, image_folder.classes[name], basename(filename))) + # if image_class[name] != image_folder.classes[person]: + # heatmap_df[image_folder.classes[person]][image_class[name]] += 1 + executor.submit(copyfile, filename, join(dest_dir, image_class[name], basename(filename))) +print(f"{time.time() - begin:5f}sec") +# print(heatmap_df) +# set_palette('Blues') +# pyplot.figure(figsize=(40, 40)) +# heat_img = heatmap(heatmap_df, cmap='Blues', linewidths=1) +# japanize() +# heatmap_df.max() +# pyplot.savefig(join(dest_dir, 'confusion_matrix.png')) +# print(f'acc: {1 - heatmap_df.to_numpy().flatten().sum() / image_folder.__len__()}') +print(image_folder.classes)