# helloproject-ai/face_crop.py
#
# 165 lines
# 6.5 KiB
# Python
# Executable File

import torch.backends.cudnn
from torch import no_grad, backends
import settings
from os import makedirs, listdir, stat, utime, devnull
from os.path import join, exists
from tqdm import tqdm
from PIL import Image
from numpy import array, arctan2, pi
from aiofiles import open as a_open
from asyncio import gather, run
from multiprocessing import Queue, Process, get_start_method, set_start_method
from time import sleep
from io import BytesIO
from math import sqrt
# from insightface.app import FaceAnalysis
from retinaface.pre_trained_models import get_model
from retinaface.predict_single import Model
from contextlib import redirect_stdout, redirect_stderr
from collections import OrderedDict
# Output directory for the cropped faces and input directory with the source
# images, both located under the project data directory.
face_dir = join(settings.datadir(), 'face_cropped')
blog_images = join(settings.datadir(), 'blog_images')
# exist_ok=True replaces the former exists()/makedirs() pair: this module-level
# code may run in several processes at once, and the separate check could race
# with another process creating the directory.
makedirs(face_dir, exist_ok=True)
def truncate(landmark: list[tuple[float, float]]) -> tuple[tuple[int, int], float]:
    """Compute the face centre and tilt angle from five facial landmarks.

    Args:
        landmark: Exactly five ``(x, y)`` points in the order left eye,
            right eye, nose, left mouth corner, right mouth corner.

    Returns:
        ``((center_x, center_y), angle)``. The centre is the mean of the two
        eye and two mouth points, truncated to ints (the nose is not used).
        ``angle`` is in radians: ``arctan2`` of the horizontal offset of the
        eye midpoint from the mouth midpoint over the vertical distance from
        the eye midpoint down to the mouth midpoint. It is 0 when both
        midpoints share the same x and the mouth midpoint has the larger y.

    Raises:
        ValueError: If ``landmark`` does not contain exactly five entries.
    """
    left_eye, right_eye, _nose, left_mouth, right_mouth = landmark
    center_x = (left_eye[0] + right_eye[0] + left_mouth[0] + right_mouth[0]) / 4
    center_y = (left_eye[1] + right_eye[1] + left_mouth[1] + right_mouth[1]) / 4
    eye_center = (right_eye[0] + left_eye[0]) / 2, (right_eye[1] + left_eye[1]) / 2
    mouth_center = (right_mouth[0] + left_mouth[0]) / 2, (right_mouth[1] + left_mouth[1]) / 2
    angle = arctan2(eye_center[0] - mouth_center[0], mouth_center[1] - eye_center[1])
    # Convert the numpy scalar to a builtin float so the result matches the
    # annotated return type.
    return (int(center_x), int(center_y)), float(angle)
def load_image(basedir: str, queue: Queue, progress: tuple[Queue]) -> None:
    """Read every image below ``basedir`` and put it on ``queue``.

    ``basedir`` is scanned two levels deep: each entry of ``basedir`` is
    treated as a directory and each entry inside it as an image file. The
    files are read in chunks of 20 concurrent reads and each one is put on
    ``queue`` as ``(PIL image, (name, image_file))``.

    Args:
        basedir: Root directory to scan.
        queue: Output queue for ``(image, (name, image_file))`` tuples.
        progress: Queues whose sizes are displayed in the progress bar.
    """
    chunk_size = 20

    def list_up():
        for name in listdir(basedir):
            try:
                image_files = listdir(join(basedir, name))
            except NotADirectoryError:
                # A plain file next to the per-name directories used to
                # abort the whole loader; skip it instead.
                continue
            for image_file in image_files:
                yield name, image_file

    async def single_read(path: tuple[str, str]):
        async with a_open(join(basedir, *path), mode='rb') as f:
            return await f.read(), path

    async def parallel_read(paths: list[tuple[str, str]]):
        return await gather(*[single_read(path) for path in paths])

    file_list = list(list_up())
    # The context manager closes the progress bar once all files were queued.
    with tqdm(total=len(file_list)) as bar:
        for start in range(0, len(file_list), chunk_size):
            # Back-pressure: do not read further while the consumer lags behind.
            while queue.qsize() > 50:
                sleep(1e-3)
            chunk = file_list[start:start + chunk_size]
            for img_bin, path in run(parallel_read(chunk)):
                queue.put((Image.open(BytesIO(img_bin)), path))
                bar.update(1)
            # ``path`` is the last file of the chunk that was just queued.
            bar.set_postfix(OrderedDict(name=path[0], qsize=[q.qsize() for q in progress]))
def pre_process(q1: Queue, q2: Queue):
    """Convert queued images to numpy arrays, forever.

    Takes ``(image, path)`` items from ``q1``, converts ``image`` with
    ``numpy.array`` and puts ``(array, path)`` on ``q2``. The function never
    returns on its own.

    Args:
        q1: Input queue of ``(image, path)`` tuples.
        q2: Output queue receiving ``(array, path)`` tuples.
    """
    while True:
        # Back-pressure: wait while the consumer is more than 50 items behind.
        # Sleeping briefly instead of spinning keeps the wait from occupying
        # a full CPU core.
        while q2.qsize() > 50:
            sleep(1e-3)
        image, path = q1.get()
        q2.put((array(image), path))
def predict(q1: Queue, q2: Queue, gpu: int):
    """Run face detection on queued image arrays, forever.

    Takes ``(image, path)`` items from ``q1``, runs the detector and puts
    ``(faces, path)`` on ``q2``, where ``faces`` is a list of
    ``(landmarks, score, bbox)`` tuples. The function never returns on its
    own.

    Args:
        q1: Input queue of ``(image array, path)`` tuples.
        q2: Output queue receiving ``(faces, path)`` tuples.
        gpu: Worker index; delays start-up by ``gpu * 4`` seconds.
    """
    # Stagger the start-up of the individual workers.
    sleep(gpu * 4)
    torch.backends.cudnn.benchmark = True
    # stderr is discarded while the model is created; the devnull handle is
    # closed again when the block is left.
    with open(devnull, mode='w') as null_stream, redirect_stderr(null_stream):
        # NOTE(review): the device is always 'cuda' and does not depend on
        # ``gpu`` - confirm whether a per-worker device was intended.
        model: Model = get_model(model_name='resnet50_2020-07-20', max_size=512, device='cuda')
        model.eval()
    with no_grad():
        while True:
            image, path = q1.get()
            res = model.predict_jsons(image=image, confidence_threshold=0.4, nms_threshold=0.4)
            # A lone result with score -1 is dropped (presumably the
            # detector's "no face found" placeholder - verify against the
            # retinaface package).
            if len(res) == 1 and res[0]['score'] == -1:
                continue
            faces = [(face['landmarks'], face['score'], face['bbox']) for face in res]
            q2.put((faces, path))
def post_process(queue: Queue):
    """Crop every detected face out of its source image, forever.

    Takes ``(faces, (name, file))`` items from ``queue`` and re-opens the
    image from ``blog_images``. For each face it rotates the image around the
    face centre by the angle from ``truncate``, crops a square around that
    centre and saves it as ``<face_dir>/<name>/<stem>-<n><ext>``, with ``n``
    counting faces from 1. The access and modification times of the source
    file are copied to the saved crop. The function never returns on its own.

    Args:
        queue: Input queue of ``(faces, path)`` tuples, where each face is a
            ``(landmarks, score, bbox)`` tuple.
    """
    while True:
        res, path = queue.get()
        name, file = path
        source_path = join(blog_images, *path)
        # Split at the last dot so the "-<n>" suffix is inserted for every
        # extension. The previous replace('.jpg', ...) left other names
        # unchanged, so all faces of such an image were written to one file.
        stem, dot, extension = file.rpartition('.')
        if not dot:
            stem, extension = file, ''
        # The context manager releases the source file once all faces of this
        # image have been handled.
        with Image.open(source_path) as image:
            source_stat = stat(source_path)
            for order, face in enumerate(res):
                landmarks, _score, bbox = face
                face_width = bbox[2] - bbox[0]
                face_height = bbox[3] - bbox[1]
                # Half the side length of the square that is cropped.
                image_size = max(face_width, face_height) * sqrt(2) // 2
                # Small faces are filtered before the costly rotation.
                if image_size < 100:
                    continue
                center, angle = truncate(landmarks)
                # ``angle`` is in radians; Image.rotate expects degrees.
                rotated = image.rotate(angle=angle * 360 / (2 * pi), center=center)
                cropped = rotated.crop((center[0] - image_size, center[1] - image_size,
                                        center[0] + image_size, center[1] + image_size))
                makedirs(join(face_dir, name), exist_ok=True)
                saved_path = join(face_dir, name, f'{stem}-{order + 1}{dot}{extension}')
                cropped.save(saved_path)
                utime(path=saved_path, times=(source_stat.st_atime, source_stat.st_mtime))
if __name__ == '__main__':
    # Pipeline: load_image -> Load_Q -> pre_process -> PreProcess_Q -> predict
    # -> Predict_Q -> post_process. PostProcess_Q is only passed to the
    # loaders for the progress display and polled below.
    if get_start_method() == 'fork':
        set_start_method('spawn', force=True)
    Load_Q, PreProcess_Q, Predict_Q, PostProcess_Q = (Queue() for _ in range(4))
    all_queues = (Load_Q, PreProcess_Q, Predict_Q, PostProcess_Q)
    # NOTE(review): every loader process scans the complete ``blog_images``
    # tree, so more than one loader presumably queues each image repeatedly -
    # confirm against settings.FaceCropProcesses.load.
    Load_Processes = [Process(target=load_image, args=(blog_images, Load_Q, all_queues))
                      for _ in range(settings.FaceCropProcesses.load)]
    PreProcesses = [Process(target=pre_process, args=(Load_Q, PreProcess_Q))
                    for _ in range(settings.FaceCropProcesses.pre_process)]
    Predict_Process = [Process(target=predict, args=(PreProcess_Q, Predict_Q, gpu_id))
                       for gpu_id in range(settings.FaceCropProcesses.predict)]
    PostProcesses = [Process(target=post_process, args=(Predict_Q,))
                     for _ in range(settings.FaceCropProcesses.post_process)]
    workers = Load_Processes + PreProcesses + Predict_Process + PostProcesses
    try:
        for worker in workers:
            worker.start()
        # Initial grace period before the first idle check.
        sleep(20)
        # The downstream workers loop forever, so completion is detected by
        # polling. The pipeline counts as finished once all loaders have
        # exited and all queues were empty on two consecutive polls; the
        # second poll gives items still inside a worker time to be written.
        idle_polls = 0
        while idle_polls < 2:
            sleep(5)
            loaders_done = not any(p.is_alive() for p in Load_Processes)
            queues_empty = all(q.qsize() == 0 for q in all_queues)
            idle_polls = idle_polls + 1 if loaders_done and queues_empty else 0
    except KeyboardInterrupt as e:
        print(e)
    finally:
        # Always stop the workers, also when an unexpected exception occurred,
        # so that no child process is left running.
        for worker in workers:
            if worker.is_alive():
                worker.terminate()
                worker.join()