"""Utilities for building and loading TensorRT engines from ONNX models."""
import tensorrt as trt
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from ..exceptions import (
|
|
InvalidInputException,
|
|
ModelLoadingException,
|
|
)
|
|
from contextlib import ExitStack
|
|
|
|
# Shared TensorRT logger; severity ERROR keeps builder/runtime output quiet.
LOGGER = trt.Logger(trt.Logger.Severity.ERROR)

# Per-model maximum batch sizes used when the corresponding
# "<name>_max_batch_size" entry is absent from settings
# (see _get_max_batch_size).
DEFAULT_DETECTION_MAX_BATCH_SIZE = 1  # detection does not support batching yet
DEFAULT_QUALITY_MAX_BATCH_SIZE = 4
DEFAULT_LANDMARKS_MAX_BATCH_SIZE = 4
DEFAULT_RECOGNITION_MAX_BATCH_SIZE = 4
DEFAULT_ATTRIBUTES_MAX_BATCH_SIZE = 4
DEFAULT_MASK_MAX_BATCH_SIZE = 4
# All engine inputs are 3-channel (RGB) images.
NUM_CHANNELS_RGB = 3
# Builder workspace memory budget: 1 << 28 bytes (256 MiB).
MAX_WORKSPACE_SIZE = 1 << 28

# Register TensorRT's built-in plugins so ONNX models that depend on them
# can be parsed (must happen before any OnnxParser use).
trt.init_libnvinfer_plugins(LOGGER, "")

# NOTE(review): not referenced elsewhere in this module — presumably consumed
# by importers; confirm before removing.
PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list
|
def _get_max_batch_size(name, settings):
    """Return the maximum batch size configured for the model *name*.

    Args:
        name: Model identifier — one of "detection", "landmarks",
            "recognition", "attributes", "mask" or "quality".
        settings: Mapping of configuration overrides; the value is looked up
            under the key ``"<name>_max_batch_size"``.

    Returns:
        The configured (or default) maximum batch size for the model.

    Raises:
        InvalidInputException: If *name* is not a known model.
    """
    if name == "detection":
        # Batching is not enabled for detection yet, so the value is fixed
        # and cannot be overridden via settings.
        return DEFAULT_DETECTION_MAX_BATCH_SIZE

    # Every batched model follows the same "<name>_max_batch_size" settings
    # key pattern, so a dispatch table replaces the previous if/elif chain.
    defaults = {
        "landmarks": DEFAULT_LANDMARKS_MAX_BATCH_SIZE,
        "recognition": DEFAULT_RECOGNITION_MAX_BATCH_SIZE,
        "attributes": DEFAULT_ATTRIBUTES_MAX_BATCH_SIZE,
        "mask": DEFAULT_MASK_MAX_BATCH_SIZE,
        "quality": DEFAULT_QUALITY_MAX_BATCH_SIZE,
    }
    if name not in defaults:
        # Same message as before for unknown model names.
        raise InvalidInputException("Batch size is not specified")

    return settings.get(f"{name}_max_batch_size", defaults[name])
|
def build_engine(name, models_dir, models_type, engine_path, settings, shape):
    """Build a serialized TensorRT engine from the model's ONNX file.

    Args:
        name: Model identifier (e.g. "detection", "mask", ...).
        models_dir: Root directory containing the ONNX model files.
        models_type: Sub-directory selecting the model variant.
        engine_path: Destination file path for the serialized engine.
        settings: Configuration mapping consulted for max batch sizes.
        shape: Spatial input dimensions, appended after (batch, channels).

    Returns:
        The serialized engine bytes, or None if the ONNX file is missing.

    Raises:
        ModelLoadingException: If TensorRT is older than version 8, the
            ONNX model cannot be parsed, or engine serialization fails.
    """
    # The "mask" model lives directly under <models_dir>/<models_type>/;
    # every other model under <models_dir>/<name>/<models_type>/.
    if name == "mask":
        model_file = os.path.join(models_dir, models_type, f"{name}.onnx")
    else:
        model_file = os.path.join(models_dir, name, models_type, f"{name}.onnx")

    batch_size = _get_max_batch_size(name, settings)

    # Only the major version matters for the explicit-batch API check below.
    trt_version = int(trt.__version__.split(".")[0])
    if trt_version >= 8:
        # -1 indicates dynamic batching. Does not work for detection model currently
        input_shape = [
            batch_size if name == "detection" else -1,
            NUM_CHANNELS_RGB,
        ] + list(shape)
        net_flags = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    else:
        raise ModelLoadingException(
            "TensorRT version 8 or higher required to build engine"
        )

    # Missing ONNX file is a soft failure: caller decides how to react.
    if not os.path.isfile(model_file):
        return None

    # ExitStack keeps builder/config/network/parser lifetimes correctly
    # nested and releases them in reverse order on exit.
    with ExitStack() as stack:
        builder = stack.enter_context(trt.Builder(LOGGER))
        config = stack.enter_context(builder.create_builder_config())
        network = stack.enter_context(builder.create_network(net_flags))
        parser = stack.enter_context(trt.OnnxParser(network, LOGGER))

        success = parser.parse_from_file(model_file)
        if not success:
            raise ModelLoadingException(f"Cannot parse {name} model.")

        # NOTE(review): max_batch_size / max_workspace_size are deprecated in
        # TensorRT 8.x and removed in newer releases — verify against the
        # pinned TensorRT version before upgrading.
        builder.max_batch_size = batch_size
        config.max_workspace_size = MAX_WORKSPACE_SIZE

        profile = _create_opt_profile(builder, network, batch_size)
        config.add_optimization_profile(profile)

        # NOTE(review): the optimization profile above is derived from the
        # network's parsed ONNX shape *before* this override is applied —
        # presumably the ONNX input dims already match `shape`; confirm.
        network.get_input(0).shape = input_shape
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            raise ModelLoadingException(f"Cannot serialize {name} engine.")

        # Persist the engine so subsequent loads skip the (slow) build step.
        engine_dir = Path(engine_path).parent
        engine_dir.mkdir(parents=True, exist_ok=True)
        with open(engine_path, "wb") as f:
            f.write(serialized_engine)

    return serialized_engine
def _create_opt_profile(builder, network, max_batch_size):
    """Create an optimization profile for the network's first input.

    The profile ranges the batch dimension from 1 to max_batch_size
    (optimum equal to maximum), leaving all other dimensions unchanged.
    Returns an empty profile when the network has no inputs.
    """
    profile = builder.create_optimization_profile()

    if network.num_inputs <= 0:
        # Nothing to constrain — hand back the empty profile as-is.
        return profile

    first_input = network.get_input(0)

    # Build three copies of the input shape that differ only in the batch
    # dimension: minimum 1, optimum and maximum both at max_batch_size.
    shapes = []
    for batch_dim in (1, max_batch_size, max_batch_size):
        dims = trt.Dims(first_input.shape)
        dims[0] = batch_dim
        shapes.append(dims)

    profile.set_shape(first_input.name, *shapes)
    return profile
|
def load_engine(name, engine_path, models_dir, models_type, settings, input_shape):
    """Load a serialized engine from disk, building it first if absent.

    Reads the cached engine at engine_path when it exists; otherwise builds
    one from the ONNX model via build_engine. The serialized bytes are then
    deserialized into a CUDA engine.

    Raises:
        ModelLoadingException: If no serialized engine could be obtained.
    """
    if os.path.isfile(engine_path):
        # A previously built engine is cached on disk — reuse it.
        with open(engine_path, "rb") as engine_file:
            serialized_engine = engine_file.read()
    else:
        # No cache yet: build (and persist) the engine from the ONNX model.
        serialized_engine = build_engine(
            name, models_dir, models_type, engine_path, settings, input_shape
        )

    if not serialized_engine:
        raise ModelLoadingException(f"Cannot build {name} engine.")

    return trt.Runtime(LOGGER).deserialize_cuda_engine(serialized_engine)