85 lines
3.6 KiB
Python
85 lines
3.6 KiB
Python
import time
|
|
from os import makedirs
|
|
from os.path import join, exists, basename
|
|
from shutil import rmtree, copyfile
|
|
from more_itertools import chunked
|
|
from torch import load, no_grad, device, randn, jit, float64, float32, float16
|
|
from torch.cuda import is_available
|
|
from torch.utils.data import DataLoader
|
|
from torchvision.datasets import ImageFolder
|
|
from torchvision.transforms import Compose, ToTensor, Resize, CenterCrop
|
|
from torchinfo import summary
|
|
from tqdm import tqdm
|
|
from settings import datadir
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from pandas import DataFrame
|
|
from seaborn import heatmap, color_palette, set_palette
|
|
from matplotlib import pyplot
|
|
from japanize_matplotlib import japanize
|
|
from torch_tensorrt import compile
|
|
|
|
device = device('cuda' if is_available() else 'cpu')
|
|
# device = 'cpu'
|
|
print(f'device: {device}')
|
|
model_path: str = join(datadir(), 'artifact', 'facenet-tl_2023-10-15 07:08:44.537055', 'model.pth')
|
|
print(f'model path: {model_path}')
|
|
input_shape: int = 256
|
|
batch_size = 64
|
|
source_dir = join(datadir(), 'face_cropped')
|
|
print(f'judge file: {source_dir}')
|
|
dest_dir = join(datadir(), 'infer_all')
|
|
image_class = ImageFolder(root=join(datadir(), 'dataset', 'train')).classes
|
|
with open(join(datadir(), 'class_text'), mode='w') as f:
|
|
f.write(str(image_class))
|
|
rmtree(dest_dir, ignore_errors=True)
|
|
makedirs(dest_dir)
|
|
|
|
transform = Compose([Resize(size=256), ToTensor()])
|
|
image_folder = ImageFolder(root=source_dir, transform=transform)
|
|
dataloader = DataLoader(image_folder, batch_size=batch_size, shuffle=False, num_workers=8)
|
|
|
|
model = load(f=model_path)
|
|
model = model.to(device)
|
|
model.eval()
|
|
for layer in model.parameters():
|
|
layer.requires_grad = False
|
|
|
|
# summary(model=model, input_size=(batch_size, 3, input_shape, input_shape), device=device)
|
|
|
|
if exists(join(datadir(), 'infer_all_torch_trt.ts')):
|
|
trt_model = jit.load(join(datadir(), 'infer_all_torch_trt.ts'))
|
|
else:
|
|
|
|
example_input = randn(size=[batch_size, 3, 256, 256]).float().cuda()
|
|
traced_script_module = jit.trace(model, example_inputs=[example_input])
|
|
trt_model = compile(module=traced_script_module, inputs=[example_input],
|
|
enabled_precisions={float32, float16},
|
|
truncate_long_and_double=True)
|
|
jit.save(trt_model, join(datadir(), 'infer_all_torch_trt.ts'))
|
|
|
|
# heatmap_df = DataFrame(index=image_class, columns=image_folder.classes).fillna(0)
|
|
begin = time.time()
|
|
with ThreadPoolExecutor(max_workers=60) as executor, no_grad():
|
|
for (images, labels), fileinfo in zip(tqdm(dataloader), chunked(image_folder.imgs, n=batch_size)):
|
|
# print(labels, fileinfo)
|
|
res = trt_model(images.to(device))
|
|
for name, (filename, person) in zip(res.to(device).max(1).indices.tolist(), fileinfo):
|
|
if not exists(join(dest_dir, image_class[name])):
|
|
makedirs(join(dest_dir, image_class[name]), exist_ok=True)
|
|
# print(name, filename, person)
|
|
# copyfile(src=filename,
|
|
# dst=join(dest_dir, image_folder.classes[name], basename(filename)))
|
|
# if image_class[name] != image_folder.classes[person]:
|
|
# heatmap_df[image_folder.classes[person]][image_class[name]] += 1
|
|
executor.submit(copyfile, filename, join(dest_dir, image_class[name], basename(filename)))
|
|
print(f"{time.time() - begin:5f}sec")
|
|
# print(heatmap_df)
|
|
# set_palette('Blues')
|
|
# pyplot.figure(figsize=(40, 40))
|
|
# heat_img = heatmap(heatmap_df, cmap='Blues', linewidths=1)
|
|
# japanize()
|
|
# heatmap_df.max()
|
|
# pyplot.savefig(join(dest_dir, 'confusion_matrix.png'))
|
|
# print(f'acc: {1 - heatmap_df.to_numpy().flatten().sum() / image_folder.__len__()}')
|
|
print(image_folder.classes)
|