97 lines
3.6 KiB
Python
97 lines
3.6 KiB
Python
import concurrent.futures
|
|
import os.path
|
|
from concurrent.futures.process import ProcessPoolExecutor
|
|
from itertools import chain
|
|
from PIL import Image
|
|
from more_itertools import chunked
|
|
from os import listdir
|
|
from torchvision import transforms
|
|
from torchvision.io import decode_jpeg
|
|
from tqdm import tqdm
|
|
from edgeface.backbones import get_model
|
|
import torch
|
|
import numpy
|
|
from edgeface.face_alignment import align
|
|
|
|
# Root folder of the cropped face images; it is listed one sub-directory at a time.
CROPPED_DIR = r"D:\helloproject-ai-data\face_cropped"
# Backbone name handed to get_model(); also the stem of the checkpoint file name.
MODEL_NAME = "edgeface_s_gamma_05"
# Number of file names grouped into one chunk / forward pass.
CHUNK_SIZE = 64
# Device that receives both the model and the preprocessed image tensors.
DEVICE = torch.device("cuda")
# Target size passed to the Resize step of the preprocessing pipeline.
INPUT_SIZE = 112
# Tag embedded in the output file names (embeddings_<TYPE>.npy, embeddings_<TYPE>_label.npy).
TYPE = "edgeface"

# Preprocessing applied to every aligned face: convert to a tensor, resize with
# bilinear interpolation, then normalize each of the three channels with
# mean 0.5 and std 0.5.
# Disabled alternative kept for reference:
#   transforms.Resize(size=int(INPUT_SIZE * 1.2), interpolation=BILINEAR)
#   followed by transforms.CenterCrop(size=INPUT_SIZE)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(
            size=INPUT_SIZE,
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
|
|
|
|
|
|
def align_edgeface(p: str):
    """Align one cropped face image and return it as a preprocessed tensor.

    *p* is the image file name. The sub-directory it lives in (under
    CROPPED_DIR) is the part of the name before the first "=".

    Returns the result of ``transform`` moved to DEVICE, or None when
    ``align.get_aligned_face`` returns None for the image.
    """
    folder = p.split("=")[0]
    image_path = os.path.join(CROPPED_DIR, folder, p)
    face = align.get_aligned_face(image_path)
    return None if face is None else transform(face).to(DEVICE)
|
|
|
|
|
|
def main() -> None:
    """Embed every not-yet-processed cropped face and save the results.

    Loads the model named by MODEL_NAME, resumes from previously saved
    ``embeddings_<TYPE>.npy`` / ``embeddings_<TYPE>_label.npy`` files when both
    exist, runs the remaining images through the model in chunks of
    CHUNK_SIZE, and writes the combined embeddings and their file-name labels
    back to those two files. Row i of the embeddings array corresponds to
    label i.
    """
    embeddings_path = f"embeddings_{TYPE}.npy"
    labels_path = f"embeddings_{TYPE}_label.npy"

    model: torch.nn.Module = get_model(name=MODEL_NAME)
    checkpoint_path = os.path.join(
        os.path.dirname(__file__), "edgeface", "checkpoints", f"{MODEL_NAME}.pt")
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects. Only use checkpoints from a trusted source, or confirm that the
    # file loads with weights_only=True and switch to that.
    model.load_state_dict(torch.load(checkpoint_path, weights_only=False))
    model = model.eval().cuda(device=DEVICE)

    # Embeddings are gathered as a list of arrays and concatenated once at the
    # end; concatenating after every chunk re-copies everything each time.
    embedding_chunks: list[numpy.ndarray] = []
    labels: list[str] = []
    # Resume only when both files are present, because both are loaded.
    if os.path.exists(labels_path) and os.path.exists(embeddings_path):
        embedding_chunks.append(numpy.load(embeddings_path))
        labels = numpy.load(labels_path).tolist()
    already_embedded = set(labels)

    all_cropped_list = list(chain.from_iterable(
        listdir(os.path.join(CROPPED_DIR, name)) for name in listdir(CROPPED_DIR)))

    # tqdm stores its description with a ": " suffix, so comparing pbar.desc
    # with the bare name never matches; track the shown name locally instead
    # to avoid a forced refresh on every file.
    shown_dir_name = None

    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
        with tqdm(total=len(all_cropped_list)) as pbar:
            for chk in chunked(all_cropped_list, n=CHUNK_SIZE):
                # Names and tensors are appended together so each embedding row
                # keeps the label of the file it came from, even when other
                # files in the chunk are skipped or fail alignment.
                batch_names: list[str] = []
                batch_tensors = []
                for file_name in chk:
                    pbar.update(1)
                    sub_dir_name = file_name.split("=")[0]
                    if sub_dir_name != shown_dir_name:
                        pbar.set_description(sub_dir_name)
                        shown_dir_name = sub_dir_name
                    if file_name in already_embedded:
                        continue
                    aligned = align_edgeface(file_name)
                    if aligned is None:
                        continue
                    batch_names.append(file_name)
                    batch_tensors.append(aligned)

                if not batch_tensors:
                    continue
                res = model(torch.stack(batch_tensors))
                embedding_chunks.append(res.cpu().numpy())
                labels.extend(batch_names)

    if not embedding_chunks:
        # Nothing was loaded or embedded; saving None would write an unusable
        # object array, so leave the output files untouched.
        print("No embeddings to save.")
        return

    numpy.save(embeddings_path, numpy.concatenate(embedding_chunks, axis=0))
    numpy.save(labels_path, numpy.array(labels))


if __name__ == '__main__':
    main()
|