This commit is contained in:
yayoimizuha 2024-02-09 01:18:58 +09:00
parent df540c2d4f
commit a9a106f159
1 changed files with 14 additions and 5 deletions

View File

@ -38,7 +38,7 @@ py_model: Model = get_model(model_name='resnet50_2020-07-20', max_size=512)
print(py_model.predict_jsons(array(image)))
max_size = 512
example_input = randn(size=[10, 3, 256, 256]).float()
example_input = randn(size=[1, 3, 640, 640]).float()
retina_model = RetinaFace(
name="Resnet50",
@ -58,13 +58,22 @@ torch.onnx.export(
input_names=["input"],
output_names=["bbox", "confidence", "landmark"],
dynamic_axes={"input": {
0: "batch_size",
# 0: "batch_size",
2: "height",
3: "width"
},
"bbox": {0: "batch_size", 1: "length"},
"confidence": {0: "batch_size", 1: "length"},
"landmark": {0: "batch_size", 1: "length"},
"bbox": {
# 0: "batch_size",
1: "length"
},
"confidence": {
# 0: "batch_size",
1: "length"
},
"landmark": {
# 0: "batch_size",
1: "length"
},
},
)