# helloproject-ai/generate_embeddings_edgefac...
#
# 97 lines
# 3.6 KiB
# Python

import concurrent.futures
import os.path
from concurrent.futures.process import ProcessPoolExecutor
from itertools import chain
from PIL import Image
from more_itertools import chunked
from os import listdir
from torchvision import transforms
from torchvision.io import decode_jpeg
from tqdm import tqdm
from edgeface.backbones import get_model
import torch
import numpy
from edgeface.face_alignment import align
# Root directory of the cropped face images; each sub-directory is listed by
# the main script and its files are fed to ``align_edgeface``.
CROPPED_DIR = r"D:\helloproject-ai-data\face_cropped"

# Backbone name passed to ``get_model`` and used to locate ``<name>.pt``.
MODEL_NAME = "edgeface_s_gamma_05"

# Number of file names handled per loop iteration / model call.
CHUNK_SIZE = 64

# Device that receives both the model and the preprocessed face tensors.
DEVICE = torch.device("cuda")

# Target size handed to ``transforms.Resize``.
INPUT_SIZE = 112

# Tag embedded in the output file names (``embeddings_<TYPE>.npy`` ...).
TYPE = "edgeface"

# Preprocessing applied to every aligned face before it reaches the model.
# Disabled alternatives left by the original author: manual
# ``x.to(torch.float32) / 255.`` scaling, and resizing to
# ``int(INPUT_SIZE * 1.2)`` followed by ``CenterCrop(INPUT_SIZE)``.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(
            size=INPUT_SIZE,
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        # Per-channel (x - 0.5) / 0.5.
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
def align_edgeface(p: str):
    """Align one cropped face image and preprocess it for the model.

    Args:
        p: File name of the cropped image. The text before the first ``=``
            is used as the name of the sub-directory of ``CROPPED_DIR`` that
            contains the file.

    Returns:
        ``transform(aligned)`` moved to ``DEVICE``, or ``None`` when
        ``align.get_aligned_face`` returns ``None`` for the image.
    """
    # Text before the first "=" (the whole name when there is no "=").
    folder, _, _ = p.partition("=")
    image_path = os.path.join(CROPPED_DIR, folder, p)
    face = align.get_aligned_face(image_path)
    return None if face is None else transform(face).to(DEVICE)
def align_pending(chunk, done, align_fn, on_file=None):
    """Align every file of *chunk* that has not been embedded yet.

    Names and aligned faces are collected together, so a skipped or failed
    file can never shift the pairing between a face and its label.

    Args:
        chunk: File names to consider, in processing order.
        done: Container of names that already have an embedding; these are
            skipped without calling *align_fn*.
        align_fn: Callable mapping a file name to an aligned face, or to
            ``None`` when alignment fails.
        on_file: Optional callable invoked with every file name of *chunk*
            (skipped ones included) before the file is examined, e.g. to
            advance a progress bar.

    Returns:
        A ``(names, faces)`` tuple of equally long lists where ``faces[i]``
        is ``align_fn(names[i])``. Skipped files and failed alignments appear
        in neither list.
    """
    names = []
    faces = []
    for file_name in chunk:
        if on_file is not None:
            on_file(file_name)
        if file_name in done:
            continue
        face = align_fn(file_name)
        if face is None:
            continue
        names.append(file_name)
        faces.append(face)
    return names, faces


def main() -> None:
    """Embed every not-yet-processed cropped face and save the results.

    Loads the ``MODEL_NAME`` checkpoint from ``edgeface/checkpoints`` next to
    this file, resumes from ``embeddings_<TYPE>.npy`` /
    ``embeddings_<TYPE>_label.npy`` when the label file exists, runs the
    model over the remaining files in batches of at most ``CHUNK_SIZE``, and
    writes both arrays back once all batches are done.
    """
    model: torch.nn.Module = get_model(name=MODEL_NAME)
    checkpoint_path = os.path.join(
        os.path.dirname(__file__), "edgeface", "checkpoints", f"{MODEL_NAME}.pt")
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only point this at a checkpoint you trust.
    model.load_state_dict(torch.load(checkpoint_path, weights_only=False))
    model = model.eval().cuda(device=DEVICE)

    embeddings_path = f"embeddings_{TYPE}.npy"
    labels_path = f"embeddings_{TYPE}_label.npy"

    # Embedding arrays in label order: the previously saved array (if any)
    # followed by one array per processed batch. They are concatenated once
    # at the end instead of re-copying the growing array on every batch.
    batches: list[numpy.ndarray] = []
    labels: list[str] = []
    if os.path.exists(labels_path):
        batches.append(numpy.load(embeddings_path))
        labels = numpy.load(labels_path).tolist()
    done = set(labels)

    all_cropped_list = list(chain.from_iterable(
        listdir(os.path.join(CROPPED_DIR, name)) for name in listdir(CROPPED_DIR)))

    with torch.no_grad(), \
            torch.autocast(device_type='cuda', dtype=torch.float16), \
            tqdm(total=len(all_cropped_list)) as pbar:
        shown_prefix = None

        def report(file_name: str) -> None:
            """Advance the bar and show the file's ``=`` prefix."""
            nonlocal shown_prefix
            pbar.update(1)
            prefix = file_name.split("=")[0]
            # tqdm stores the description with a ": " suffix, so comparing
            # against pbar.desc never matches; remember the raw prefix.
            if prefix != shown_prefix:
                shown_prefix = prefix
                pbar.set_description(prefix)

        for chk in chunked(all_cropped_list, n=CHUNK_SIZE):
            names, faces = align_pending(chk, done, align_edgeface, on_file=report)
            if not faces:
                continue
            res = model(torch.stack(faces))
            batches.append(res.cpu().numpy())
            # Extend labels only after the batch succeeded so labels and
            # embeddings always have the same length.
            labels.extend(names)

    if not batches:
        # Nothing loaded and nothing embedded: do not write a pickled ``None``
        # that a later numpy.load would refuse to read.
        return
    numpy.save(embeddings_path, numpy.concatenate(batches, axis=0))
    numpy.save(labels_path, numpy.array(labels))


if __name__ == '__main__':
    main()