diff --git a/test_script/retinaface_pure_impl.py b/test_script/retinaface_pure_impl.py index 083e33f..dd6f1ca 100644 --- a/test_script/retinaface_pure_impl.py +++ b/test_script/retinaface_pure_impl.py @@ -38,7 +38,7 @@ py_model: Model = get_model(model_name='resnet50_2020-07-20', max_size=512) print(py_model.predict_jsons(array(image))) max_size = 512 -example_input = randn(size=[10, 3, 256, 256]).float() +example_input = randn(size=[1, 3, 640, 640]).float() retina_model = RetinaFace( name="Resnet50", @@ -58,13 +58,22 @@ torch.onnx.export( input_names=["input"], output_names=["bbox", "confidence", "landmark"], dynamic_axes={"input": { - 0: "batch_size", + # 0: "batch_size", 2: "height", 3: "width" }, - "bbox": {0: "batch_size", 1: "length"}, - "confidence": {0: "batch_size", 1: "length"}, - "landmark": {0: "batch_size", 1: "length"}, + "bbox": { + # 0: "batch_size", + 1: "length" + }, + "confidence": { + # 0: "batch_size", + 1: "length" + }, + "landmark": { + # 0: "batch_size", + 1: "length" + }, }, )