From 435adff46bfac813e60d8eb18d45663d1f30a783 Mon Sep 17 00:00:00 2001 From: yayoimizuha Date: Sun, 12 May 2024 18:05:41 +0900 Subject: [PATCH] update --- test_script/nvjpeg_infer.py | 31 +++++++++++++++++++++++++++++++ test_script/onnx_cacher.py | 4 ++-- 2 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 test_script/nvjpeg_infer.py diff --git a/test_script/nvjpeg_infer.py b/test_script/nvjpeg_infer.py new file mode 100644 index 0000000..608f6a7 --- /dev/null +++ b/test_script/nvjpeg_infer.py @@ -0,0 +1,31 @@ +import time +import tkinter + +# import cv2 +from nvjpeg_decoder import decode +from os import listdir, getcwd +from os.path import join +from numpy import array +# from matplotlib import pyplot, figure +# from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg +# import matplotlib_fontja +from onnxruntime import InferenceSession + +onnx_path = r"C:\Users\tomokazu\RustroverProjects\ameba_blog_downloader\src\retinaface\resnet_retinaface.onnx" +datadir = r"D:\helloproject-ai-data\blog_images" + +session = InferenceSession( + path_or_bytes=onnx_path, + # providers=[ + # # 'TensorrtExecutionProvider', + # # 'CUDAExecutionProvider', + # 'CPUExecutionProvider' + # ] +) +for member in listdir(datadir): + for file in listdir(join(datadir, member)): + with open(join(datadir, member, file), mode="rb") as f: + (data, (width, height)) = decode((f.read()), "imagenet") + print(width, height) + image_arr = array(data).reshape((1, 3, height, width)) # .transpose([1, 2, 0]) + session.run(input_feed={'input': image_arr}, output_names=['bbox', 'confidence', 'landmark']) diff --git a/test_script/onnx_cacher.py b/test_script/onnx_cacher.py index 5097e56..bb21d04 100644 --- a/test_script/onnx_cacher.py +++ b/test_script/onnx_cacher.py @@ -9,7 +9,7 @@ print(get_available_providers()) onnx_session = InferenceSession( path_or_bytes=r"retinaface.onnx", - providers=[ + providers=( 'CUDAExecutionProvider', ('TensorrtExecutionProvider', { 'trt_engine_cache_enable': True, @@ -18,7 +18,7 @@ onnx_session = InferenceSession( }), 'DmlExecutionProvider', 'CPUExecutionProvider' - ] + ) ) print(__version__) image_arr = numpy.expand_dims(numpy.array(