| | import argparse |
| | import os |
| |
|
| | import numpy as np |
| | import onnx |
| | import onnxruntime |
| | import torch |
| | from monai.networks.nets import FlexibleUNet |
| |
|
# Device used for loading the checkpoint and running the verification pass:
# prefer GPU when available, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model_and_export(
    modelname, outname, out_channels, height, width, multigpu=False, in_channels=3, backbone="efficientnet-b0"
):
    """
    Load a trained FlexibleUNet checkpoint, export it to ONNX, and verify the export.

    The exported graph is checked with ``onnx.checker`` and then run through
    ONNX Runtime; its output is compared numerically against the PyTorch output.

    Args:
        modelname: a whole path name of the model that need to be loaded.
        outname: a name for output onnx model.
        out_channels: output channels, which usually equals to 1 + class_number.
        height: input images' height.
        width: input images' width.
        multigpu: if the pre-trained model trained on a multigpu environment.
        in_channels: input images' channel number.
        backbone: a name of backbone used by the flexible unet.

    Raises:
        FileNotFoundError: if ``modelname`` does not exist.
    """
    if not os.path.exists(modelname):
        raise FileNotFoundError(f"The specified model to load does not exist: {modelname}")

    model = FlexibleUNet(
        in_channels=in_channels,
        out_channels=out_channels,
        backbone=backbone,
        is_pad=False,
        pretrained=False,
        dropout=None,
    )

    if multigpu:
        # Checkpoints saved from DataParallel carry "module."-prefixed keys,
        # so the model must be wrapped before load_state_dict.
        model = torch.nn.DataParallel(model)
    # Bug fix: previously only the multigpu branch moved the model, and it used
    # .cuda() unconditionally, which fails on CPU-only hosts and left the
    # single-GPU path with a model-on-CPU / input-on-GPU mismatch.
    model = model.to(device)
    # NOTE(review): torch.load on an untrusted checkpoint unpickles arbitrary
    # objects; consider weights_only=True on torch >= 1.13.
    model.load_state_dict(torch.load(modelname, map_location=device))
    model = model.eval()

    np.random.seed(0)
    # NCHW dummy input. Bug fix: use in_channels (was hard-coded 3) and the
    # conventional (height, width) spatial order (was (width, height)).
    x = np.random.random((1, in_channels, height, width))
    x = torch.tensor(x, dtype=torch.float32, device=device)
    torch_out = model(x)
    input_names = ["INPUT__0"]
    output_names = ["OUTPUT__0"]

    # Export the unwrapped module so the graph does not include the
    # DataParallel scatter/gather machinery.
    model_trans = model.module if multigpu else model
    torch.onnx.export(
        model_trans,
        x,
        outname,
        export_params=True,
        verbose=True,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        opset_version=15,
        dynamic_axes={"INPUT__0": {0: "batch_size"}, "OUTPUT__0": {0: "batch_size"}},
    )
    onnx_model = onnx.load(outname)
    onnx.checker.check_model(onnx_model, full_check=True)
    ort_session = onnxruntime.InferenceSession(outname)

    def to_numpy(tensor):
        # Detach when grad is tracked; always move to CPU before converting.
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    ort_outs = ort_session.run(["OUTPUT__0"], ort_inputs)
    numpy_torch_out = to_numpy(torch_out)

    np.testing.assert_allclose(numpy_torch_out, ort_outs[0], rtol=1e-03, atol=1e-05)
    print("Exported model has been tested with ONNXRuntime, and the result looks good!")
| |
|
| |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model", type=str, default=r"/workspace/models/model.pt", help="Input an existing model weight"
    )

    parser.add_argument(
        "--outpath", type=str, default=r"/workspace/models/model.onnx", help="A path to save the onnx model."
    )

    parser.add_argument("--width", type=int, default=736, help="Width for exporting onnx model.")

    parser.add_argument("--height", type=int, default=480, help="Height for exporting onnx model.")

    parser.add_argument(
        "--out_channels", type=int, default=2, help="Number of expected out_channels in model for exporting to onnx."
    )

    # Bug fix: type=bool converts ANY non-empty string to True (so
    # "--multigpu False" enabled multigpu). store_true gives a proper flag
    # with the same False default.
    parser.add_argument("--multigpu", action="store_true", help="If loading model trained with multi gpu.")

    args = parser.parse_args()
    modelname = args.model
    outname = args.outpath
    out_channels = args.out_channels
    height = args.height
    width = args.width
    multigpu = args.multigpu

    # Refuse to clobber an existing export rather than overwriting it silently.
    if os.path.exists(outname):
        raise FileExistsError(
            "The specified outpath already exists! Change the outpath to avoid overwriting your saved model. "
        )
    # load_model_and_export returns None; the old "model =" assignment was dead.
    load_model_and_export(modelname, outname, out_channels, height, width, multigpu)
|