apply PyTorch AMP
continuous-integration/drone/push Build was killed Details

This commit is contained in:
yayoimizuha 2023-10-15 22:25:29 +09:00
parent b455a5bf3e
commit eeec9ba831
3 changed files with 99 additions and 7 deletions

View File

@ -24,4 +24,4 @@ steps:
- mkdir -p data - mkdir -p data
- $mount_command - $mount_command
- ls data/ - ls data/
- # CI=True python finetune/facenet_transfer_learning.py - CI=True python finetune/facenet_transfer_learning.py

View File

@ -13,12 +13,13 @@ from numpy import arange, ndarray, ceil, full, uint8
from torch.optim import SGD, Adam, lr_scheduler from torch.optim import SGD, Adam, lr_scheduler
from torchvision.datasets import ImageFolder from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm from tqdm import tqdm
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from settings import datadir from settings import datadir
from os.path import join from os.path import join
from torch.cuda import is_available from torch.cuda import is_available
from torch import no_grad, save, Tensor, load, device from torch import no_grad, save, Tensor, load, device, float16
from datetime import datetime from datetime import datetime
from distutils.util import strtobool from distutils.util import strtobool
@ -125,6 +126,8 @@ optimizer = Adam(params=[
{'params': model_gpu[1].parameters(), 'lr': 1e-3}, {'params': model_gpu[1].parameters(), 'lr': 1e-3},
]) ])
scaler = GradScaler()
# model, optimizer = optimize(model=model, optimizer=optimizer) # model, optimizer = optimize(model=model, optimizer=optimizer)
scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.9) scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.9)
epochs = 100 epochs = 100
@ -155,6 +158,7 @@ for epoch in range(epochs):
label_text=image_folder['train'].classes) label_text=image_folder['train'].classes)
image_pallets.save(join(save_dir, 'pallets', str(epoch) + '_train.jpg')) image_pallets.save(join(save_dir, 'pallets', str(epoch) + '_train.jpg'))
optimizer.zero_grad() optimizer.zero_grad()
with autocast(dtype=float16):
images = images.to(device) images = images.to(device)
labels = labels.to(device) labels = labels.to(device)
@ -163,8 +167,12 @@ for epoch in range(epochs):
loss = criterion(outputs, labels) loss = criterion(outputs, labels)
train_loss += loss.item() train_loss += loss.item()
loss.backward() scaler.scale(loss).backward()
optimizer.step() scaler.unscale_(optimizer=optimizer)
# loss.backward()
scaler.step(optimizer=optimizer)
scaler.update()
# optimizer.step()
predicted = outputs.max(1)[1] predicted = outputs.max(1)[1]
train_acc += (predicted == labels).sum() train_acc += (predicted == labels).sum()

84
inference_all.py Normal file
View File

@ -0,0 +1,84 @@
import time
from os import makedirs
from os.path import join, exists, basename
from shutil import rmtree, copyfile
from more_itertools import chunked
from torch import load, no_grad, device, randn, jit, float64, float32, float16
from torch.cuda import is_available
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, Resize, CenterCrop
from torchinfo import summary
from tqdm import tqdm
from settings import datadir
from concurrent.futures import ThreadPoolExecutor
from pandas import DataFrame
from seaborn import heatmap, color_palette, set_palette
from matplotlib import pyplot
from japanize_matplotlib import japanize
from torch_tensorrt import compile
# Batch-classify every image under ``face_cropped`` with the saved model
# (run through a Torch-TensorRT compiled module) and copy each file into
# ``infer_all/<predicted class name>/``.

# Use the GPU when one is visible.  NOTE: this rebinds the imported
# ``torch.device`` factory to a device *instance* for the rest of the script.
device = device('cuda' if is_available() else 'cpu')
# device = 'cpu'
print(f'device: {device}')

model_path: str = join(datadir(), 'artifact', 'facenet-tl_2023-10-15 07:08:44.537055', 'model.pth')
print(f'model path: {model_path}')
# Side length (pixels) of the square input fed to the network.
input_shape: int = 256
batch_size = 64
source_dir = join(datadir(), 'face_cropped')
print(f'judge file: {source_dir}')
dest_dir = join(datadir(), 'infer_all')

# Class names come from the training set layout; the model's output index is
# looked up in this list.
image_class = ImageFolder(root=join(datadir(), 'dataset', 'train')).classes
with open(join(datadir(), 'class_text'), mode='w') as f:
    f.write(str(image_class))

# Start from an empty destination tree on every run.
rmtree(dest_dir, ignore_errors=True)
makedirs(dest_dir)

# ``Resize`` with an int only scales the shorter edge, so a centre crop is
# needed to guarantee a fixed (input_shape x input_shape) tensor; without it
# non-square images cannot be stacked into a batch or fed to the fixed-shape
# TensorRT module.  For already-square images the crop is a no-op.
transform = Compose([Resize(size=input_shape), CenterCrop(size=input_shape), ToTensor()])
image_folder = ImageFolder(root=source_dir, transform=transform)
# shuffle=False keeps batches aligned with ``image_folder.imgs`` order, which
# the main loop relies on to recover file names.
dataloader = DataLoader(image_folder, batch_size=batch_size, shuffle=False, num_workers=8)

# NOTE(security): ``torch.load`` unpickles arbitrary objects -- only point
# ``model_path`` at artifacts produced by a trusted training run.
model = load(f=model_path)
model = model.to(device)
model.eval()
for layer in model.parameters():
    layer.requires_grad = False
# summary(model=model, input_size=(batch_size, 3, input_shape, input_shape), device=device)

# Reuse a previously compiled TensorRT module when present; otherwise trace,
# compile and cache it.
trt_cache_path = join(datadir(), 'infer_all_torch_trt.ts')
if exists(trt_cache_path):
    trt_model = jit.load(trt_cache_path)
else:
    # TensorRT compilation needs a CUDA tensor, hence the explicit ``.cuda()``.
    example_input = randn(size=[batch_size, 3, input_shape, input_shape]).float().cuda()
    traced_script_module = jit.trace(model, example_inputs=[example_input])
    trt_model = compile(module=traced_script_module, inputs=[example_input],
                        enabled_precisions={float32, float16},
                        truncate_long_and_double=True)
    jit.save(trt_model, trt_cache_path)

# heatmap_df = DataFrame(index=image_class, columns=image_folder.classes).fillna(0)
begin = time.time()
copy_jobs = []  # futures of the submitted copies, inspected after the pool drains
ready_dirs = set()  # destination folders already created in this run
with ThreadPoolExecutor(max_workers=60) as executor, no_grad():
    for (images, labels), fileinfo in zip(tqdm(dataloader), chunked(image_folder.imgs, n=batch_size)):
        # print(labels, fileinfo)
        # The module was compiled from a single example tensor with
        # ``batch_size`` rows, so a short final batch is zero-padded up to
        # that size and the padding rows are dropped from the output.
        real_count = images.shape[0]
        if real_count < batch_size:
            padded = images.new_zeros((batch_size, *images.shape[1:]))
            padded[:real_count] = images
            images = padded
        res = trt_model(images.to(device))[:real_count]
        for name, (filename, person) in zip(res.to(device).max(1).indices.tolist(), fileinfo):
            target_dir = join(dest_dir, image_class[name])
            if target_dir not in ready_dirs:
                makedirs(target_dir, exist_ok=True)
                ready_dirs.add(target_dir)
            # print(name, filename, person)
            # copyfile(src=filename,
            # dst=join(dest_dir, image_folder.classes[name], basename(filename)))
            # if image_class[name] != image_folder.classes[person]:
            # heatmap_df[image_folder.classes[person]][image_class[name]] += 1
            copy_jobs.append(executor.submit(copyfile, filename, join(target_dir, basename(filename))))
# Leaving the ``with`` block waits for all copies, so the timing includes them.
print(f"{time.time() - begin:5f}sec")

# Surface copy errors instead of letting them vanish with the futures.
failed_copies = 0
for job in copy_jobs:
    error = job.exception()
    if error is not None:
        failed_copies += 1
        print(f'copy failed: {error!r}')
if failed_copies:
    print(f'{failed_copies} of {len(copy_jobs)} copies failed')

# print(heatmap_df)
# set_palette('Blues')
# pyplot.figure(figsize=(40, 40))
# heat_img = heatmap(heatmap_df, cmap='Blues', linewidths=1)
# japanize()
# heatmap_df.max()
# pyplot.savefig(join(dest_dir, 'confusion_matrix.png'))
# print(f'acc: {1 - heatmap_df.to_numpy().flatten().sum() / image_folder.__len__()}')
print(image_folder.classes)