This commit is contained in:
yayoimizuha 2024-12-07 23:23:36 +09:00
parent f3bc7e5529
commit 707f6fdf91
2 changed files with 67 additions and 28 deletions

View File

@ -46,15 +46,14 @@ pyplot.figure(figsize=(20, 20), dpi=150)
pyplot.imshow(yuv_plane[2, :, :]) pyplot.imshow(yuv_plane[2, :, :])
pyplot.show() pyplot.show()
pyplot.close("all") pyplot.close("all")
ycbcr_mat = yuv_plane.transpose((1, 2, 0)).reshape((-1, 3)) - [0, 128, 128] ycbcr_mat = yuv_plane.transpose((1, 2, 0)) - [0, 128, 128]
# print(ycbcr_mat) # print(ycbcr_mat)
transform_matrix = numpy.array([ transform_matrix = numpy.array([
[1.0, 0.0, 1.5748], [1, 0, 1.402],
[1.0, -0.1873, -0.4681], [1, -0.344136, -0.714136],
[1.0, 1.8556, 0.0] [1, 1.772, 0]
]) ])
rgb_plane = (numpy.clip(numpy.dot(ycbcr_mat, transform_matrix.T), 0, 255) rgb_plane = (numpy.clip(numpy.dot(ycbcr_mat, transform_matrix.T), 0, 255).astype(numpy.uint8))
.reshape(pitch_h, pitch_w, 3).astype(numpy.uint8))
pyplot.figure(figsize=(20, 20), dpi=150) pyplot.figure(figsize=(20, 20), dpi=150)
pyplot.imshow(rgb_plane) pyplot.imshow(rgb_plane)
pyplot.show() pyplot.show()

View File

@ -1,7 +1,12 @@
import ctypes
import inspect
import json import json
import math
import os import os
import warnings import warnings
from numpy.f2py.auxfuncs import throw_error
warnings.filterwarnings("ignore", lineno=6, category=UserWarning) warnings.filterwarnings("ignore", lineno=6, category=UserWarning)
from concurrent.futures.process import ProcessPoolExecutor from concurrent.futures.process import ProcessPoolExecutor
from itertools import chain from itertools import chain
@ -16,7 +21,7 @@ import tqdm
from PIL import Image from PIL import Image
from uuid import uuid4 from uuid import uuid4
from onnxruntime import InferenceSession, SessionOptions, GraphOptimizationLevel from onnxruntime import InferenceSession, SessionOptions, GraphOptimizationLevel
from torch import tensor from torch import tensor, as_strided
import aiofiles import aiofiles
import numpy import numpy
import torch import torch
@ -24,6 +29,7 @@ from torchvision.io import decode_jpeg
from asyncio import run, gather, Semaphore from asyncio import run, gather, Semaphore
from site import getsitepackages from site import getsitepackages
from rust_retinaface_post_processor import resnet_post_process from rust_retinaface_post_processor import resnet_post_process
from test_ext import decode as qsv_decode
USE_OPENVINO = True USE_OPENVINO = True
if USE_OPENVINO: if USE_OPENVINO:
@ -39,7 +45,7 @@ files = []
files_data: dict[str, numpy.ndarray | None] = {} files_data: dict[str, numpy.ndarray | None] = {}
chunk_size = 16 chunk_size = 16
image_size = 640 image_size = 640
device = torch.device("cpu") if torch.xpu.is_available() else exit(-1) device = torch.device("xpu") if torch.xpu.is_available() else exit(-1)
async def async_read(path: str, semaphore: Semaphore): async def async_read(path: str, semaphore: Semaphore):
@ -78,6 +84,7 @@ def post_processor_shm(shm_name, sizes, batch_size, image_size):
def dec_jpg(f, fn): def dec_jpg(f, fn):
# print("USE PILLOW")
_decoded_image = tensor(numpy.array(Image.open(BytesIO(f.tobytes()))).transpose([2, 0, 1])) _decoded_image = tensor(numpy.array(Image.open(BytesIO(f.tobytes()))).transpose([2, 0, 1]))
_decoded_image = _decoded_image.to(device, torch.float16) / 255 _decoded_image = _decoded_image.to(device, torch.float16) / 255
_decoded_image = fn[2](_decoded_image) _decoded_image = fn[2](_decoded_image)
@ -85,6 +92,31 @@ def dec_jpg(f, fn):
return fn[1](_decoded_image_resized) return fn[1](_decoded_image_resized)
def dec_jpg_qsv(f, fn):
    """Decode a JPEG through the QSV extension and preprocess it for inference.

    The bytes of ``f`` are handed to ``qsv_decode``, which returns the address
    of a raw surface plus ``height``, ``width`` and the row ``pitch`` in bytes.
    This function reads that surface as ``pitch_h`` rows of luma followed by
    ``pitch_h // 2`` rows of interleaved, half-resolution Cb/Cr samples
    (an NV12-style layout), upsamples the chroma by pixel repetition, converts
    to RGB with the full-range BT.601 (JFIF) coefficients and scales to [0, 1].

    Args:
        f: Object exposing ``tobytes()`` that yields the encoded JPEG
            (a ``numpy.ndarray`` in the visible caller).
        fn: Indexable pack of three callables. ``fn[2]`` is applied first,
            then ``fn[0]``, then ``fn[1]``; the visible caller passes
            ``[longest_max_size, pad_to, normalize]``.

    Returns:
        Whatever ``fn[1]`` returns for the preprocessed image.

    NOTE(review): the intermediate image here is float32, whereas ``dec_jpg``
    feeds float16 into the same pipeline -- confirm which dtype is intended.
    NOTE(review): the native surface behind ``ptr`` is never released here;
    confirm that ``qsv_decode`` keeps ownership of it.
    """
    ptr, height, width, pitch = qsv_decode(f.tobytes())

    # Chroma is subsampled 2x in both directions, so the planes are laid out
    # for dimensions rounded up to the next even number.
    pitch_h = math.ceil(height / 2) * 2
    pitch_w = math.ceil(width / 2) * 2
    chroma_h = pitch_h // 2
    total_rows = pitch_h + chroma_h
    total_bytes = total_rows * pitch

    # Wrap the native surface once and move it to the device in a single copy
    # instead of mapping two overlapping ctypes arrays and copying twice.
    raw = (ctypes.c_uint8 * total_bytes).from_address(ptr)
    surface = torch.frombuffer(raw, dtype=torch.uint8, count=total_bytes).to(device)
    rows = surface.view(total_rows, pitch)

    # Each row carries `pitch` bytes of which only the first `pitch_w` are pixels.
    y_plane = rows[:pitch_h, :pitch_w]
    uv_plane = rows[pitch_h:, :pitch_w].reshape(chroma_h, pitch_w // 2, 2)

    # Nearest-neighbour upsample of both chroma channels to luma resolution.
    chroma = uv_plane.repeat_interleave(2, dim=0).repeat_interleave(2, dim=1)
    ycbcr = torch.cat([y_plane.unsqueeze(-1), chroma], dim=-1)

    # Drop the row/column that exists only to make the dimensions even, so an
    # odd-sized image comes out at its real size, as it does on the Pillow path.
    # Assumes `height`/`width` are the true image dimensions -- TODO confirm
    # against the extension.
    ycbcr = ycbcr[:height, :width]

    chroma_offset = torch.tensor([0.0, 128.0, 128.0], dtype=torch.float32, device=device)
    transform_matrix = torch.tensor([
        [1, 0, 1.402],
        [1, -0.344136, -0.714136],
        [1, 1.772, 0],
    ], dtype=torch.float32, device=device)

    rgb_float = torch.matmul(ycbcr - chroma_offset, transform_matrix.T)
    # Truncate to 8-bit before scaling so values are quantised like a decoded
    # 8-bit image, then map to [0, 1].
    rgb_plane = torch.clip(rgb_float, 0, 255).to(torch.uint8) / 255

    # HWC -> CHW, then run the preprocessing callables in the caller's order.
    _decoded_image = fn[2](rgb_plane.permute((2, 0, 1)))
    _decoded_image_resized = fn[0](_decoded_image)
    return fn[1](_decoded_image_resized)
if __name__ == '__main__': if __name__ == '__main__':
from kornia.augmentation import LongestMaxSize, PadTo, Normalize from kornia.augmentation import LongestMaxSize, PadTo, Normalize
from kornia.constants import Resample from kornia.constants import Resample
@ -93,10 +125,14 @@ if __name__ == '__main__':
pad_to = PadTo(size=(640, 640), pad_value=1.) pad_to = PadTo(size=(640, 640), pad_value=1.)
normalize = Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) normalize = Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
if USE_OPENVINO: if USE_OPENVINO:
for ov_device in ov_core.get_available_devices():
device_name = ov_core.get_property(ov_device, "FULL_DEVICE_NAME")
print(f"{ov_device}: {device_name}")
onnx_model = ov_core.read_model(model_path) onnx_model = ov_core.read_model(model_path)
onnx_model.reshape([chunk_size, 3, image_size, image_size]) onnx_model.reshape([chunk_size, 3, image_size, image_size])
onnx_model = ov_core.compile_model(onnx_model, device_name='GPU') onnx_model = ov_core.compile_model(onnx_model, device_name='GPU')
# print(onnx_model)
else: else:
session_options = SessionOptions() session_options = SessionOptions()
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
@ -104,26 +140,17 @@ if __name__ == '__main__':
session = InferenceSession( session = InferenceSession(
path_or_bytes=model_path, path_or_bytes=model_path,
providers=[ providers=[
('TensorrtExecutionProvider', {
'trt_engine_cache_enable': True,
'trt_engine_cache_path': 'trt_cache',
'trt_fp16_enable': True,
'trt_profile_min_shapes': f'input:1x3x{image_size}x{image_size}',
'trt_profile_max_shapes': f'input:{chunk_size}x3x{image_size}x{image_size}',
'trt_profile_opt_shapes': f'input:{chunk_size}x3x{image_size}x{image_size}',
}),
('OpenVINOExecutionProvider', { ('OpenVINOExecutionProvider', {
'device_type': 'GPU.0', 'device_type': 'GPU.0',
'precision': 'FP16', 'precision': 'FP16',
'cache_dir': 'openvino_cache' 'cache_dir': 'openvino_cache'
}), }),
'CUDAExecutionProvider',
'CPUExecutionProvider' 'CPUExecutionProvider'
], ],
sess_options=session_options sess_options=session_options
) )
if os.path.exists("faces.jsonl"): if os.path.exists("faces_qsv.jsonl"):
with open(file="faces.jsonl", mode="r", encoding="utf-8") as fp: with open(file="faces_qsv.jsonl", mode="r", encoding="utf-8") as fp:
already = {list(msgspec.json.decode(line).keys())[0] for line in fp.read().removesuffix("\n").split("\n")} already = {list(msgspec.json.decode(line).keys())[0] for line in fp.read().removesuffix("\n").split("\n")}
else: else:
already = set() already = set()
@ -134,7 +161,7 @@ if __name__ == '__main__':
# exit(0) # exit(0)
for name in listdir(root_dir): for name in listdir(root_dir):
with (ProcessPoolExecutor(max_workers=16) as executor): with (ProcessPoolExecutor(max_workers=12) as executor):
pbar.set_description_str(desc=name, refresh=True) pbar.set_description_str(desc=name, refresh=True)
if name != "ブログ": if name != "ブログ":
# continue # continue
@ -157,12 +184,26 @@ if __name__ == '__main__':
if USE_OPENVINO: if USE_OPENVINO:
fn_pack = [longest_max_size, pad_to, normalize] fn_pack = [longest_max_size, pad_to, normalize]
submits = [] submits = []
# for file, dat in cnk:
# submits.append(executor.submit(dec_jpg_qsv, dat, fn_pack))
# names.append(file)
# for submit in submits:
# try:
# stack.append(submit.result().to(device).squeeze())
# except Exception as e:
# print(e)
# stack.append(dec_jpg(dat, fn_pack).squeeze())
for file, dat in cnk: for file, dat in cnk:
submits.append(executor.submit(dec_jpg, dat, fn_pack)) try:
raise Exception
stack.append(dec_jpg_qsv(dat, fn_pack).squeeze())
except Exception as e:
# print(e)
stack.append(dec_jpg(dat, fn_pack).squeeze())
names.append(file) names.append(file)
for submit in submits:
stack.append(submit.result().squeeze())
else: else:
print("fallback", inspect.currentframe().f_lineno)
for file, dat in cnk: for file, dat in cnk:
try: try:
decoded_image = decode_jpeg(tensor(dat), device=device) decoded_image = decode_jpeg(tensor(dat), device=device)
@ -181,10 +222,9 @@ if __name__ == '__main__':
stacked = torch.stack(stack).contiguous() stacked = torch.stack(stack).contiguous()
# print(stacked.shape) # print(stacked.shape)
if USE_OPENVINO: if USE_OPENVINO:
_outputs = onnx_model([stacked]) _outputs = onnx_model([stacked.cpu()])
# print(_outputs[onnx_model.output(0)]) # print(_outputs[onnx_model.output(0)])
outputs = [_outputs[onnx_model.output(i)] for i in range(2, -1, -1)] outputs = [_outputs[onnx_model.output(i)] for i in range(2, -1, -1)]
# print(outputs)
else: else:
io_binding = session.io_binding() io_binding = session.io_binding()
io_binding.bind_input( io_binding.bind_input(
@ -200,7 +240,7 @@ if __name__ == '__main__':
io_binding.bind_output("bbox") io_binding.bind_output("bbox")
session.run_with_iobinding(iobinding=io_binding) session.run_with_iobinding(iobinding=io_binding)
outputs: list[numpy.ndarray] = io_binding.copy_outputs_to_cpu() outputs: list[numpy.ndarray] = io_binding.copy_outputs_to_cpu()
print("fallback", inspect.currentframe().f_lineno)
# [numpy.memmap(filename=path.join("memmap", tmp_file_name + str(order)), dtype=numpy.float16, # [numpy.memmap(filename=path.join("memmap", tmp_file_name + str(order)), dtype=numpy.float16,
# mode="w+", shape=output.shape) for order, output in enumerate(outputs)] # mode="w+", shape=output.shape) for order, output in enumerate(outputs)]
uuid = uuid4().__str__() uuid = uuid4().__str__()
@ -218,7 +258,7 @@ if __name__ == '__main__':
# exit(0) # exit(0)
pbar.update(n=cnk.__len__()) pbar.update(n=cnk.__len__())
# result_dict = dict() # result_dict = dict()
with open("faces.jsonl", mode="a", encoding="utf-8") as fp: with open("faces_qsv.jsonl", mode="a", encoding="utf-8") as fp:
futures_results = [future.result() for future in futures] futures_results = [future.result() for future in futures]
# pprint(futures_results) # pprint(futures_results)
for names, futures_result in zip(namess, futures_results): for names, futures_result in zip(namess, futures_results):