import os
from contextlib import ExitStack
from pathlib import Path

import tensorrt as trt

from ..exceptions import (
    InvalidInputException,
    ModelLoadingException,
)

LOGGER = trt.Logger(trt.Logger.Severity.ERROR)

DEFAULT_DETECTION_MAX_BATCH_SIZE = 1
DEFAULT_QUALITY_MAX_BATCH_SIZE = 4
DEFAULT_LANDMARKS_MAX_BATCH_SIZE = 4
DEFAULT_RECOGNITION_MAX_BATCH_SIZE = 4
DEFAULT_ATTRIBUTES_MAX_BATCH_SIZE = 4
DEFAULT_MASK_MAX_BATCH_SIZE = 4

NUM_CHANNELS_RGB = 3
MAX_WORKSPACE_SIZE = 1 << 28  # 256 MiB of builder scratch memory

# Register the built-in TensorRT plugins so the ONNX parser can resolve
# any plugin-backed ops in the models.
trt.init_libnvinfer_plugins(LOGGER, "")
PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list


def _get_max_batch_size(name, settings):
    """Return the configured maximum batch size for the model ``name``."""
    if name == "detection":
        # Batching is not enabled for detection yet.
        return DEFAULT_DETECTION_MAX_BATCH_SIZE
    if name == "landmarks":
        size = settings.get(
            "landmarks_max_batch_size", DEFAULT_LANDMARKS_MAX_BATCH_SIZE
        )
    elif name == "recognition":
        size = settings.get(
            "recognition_max_batch_size", DEFAULT_RECOGNITION_MAX_BATCH_SIZE
        )
    elif name == "attributes":
        size = settings.get(
            "attributes_max_batch_size", DEFAULT_ATTRIBUTES_MAX_BATCH_SIZE
        )
    elif name == "mask":
        size = settings.get("mask_max_batch_size", DEFAULT_MASK_MAX_BATCH_SIZE)
    elif name == "quality":
        size = settings.get(
            "quality_max_batch_size", DEFAULT_QUALITY_MAX_BATCH_SIZE
        )
    else:
        raise InvalidInputException(
            f"No maximum batch size is specified for model '{name}'"
        )
    return size


def build_engine(name, models_dir, models_type, engine_path, settings, shape):
    """Build a TensorRT engine from the ONNX model for ``name``.

    Returns the serialized engine bytes (also cached at ``engine_path``),
    or ``None`` if the ONNX file does not exist.
    """
    if name == "mask":
        model_file = os.path.join(models_dir, models_type, f"{name}.onnx")
    else:
        model_file = os.path.join(models_dir, name, models_type, f"{name}.onnx")

    batch_size = _get_max_batch_size(name, settings)

    trt_version = int(trt.__version__.split(".")[0])
    if trt_version >= 8:
        # -1 marks the batch dimension as dynamic. Dynamic batching does not
        # work for the detection model yet, so it keeps a fixed batch size.
        input_shape = [
            batch_size if name == "detection" else -1,
            NUM_CHANNELS_RGB,
        ] + list(shape)
        net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    else:
        raise ModelLoadingException(
            "TensorRT version 8 or higher is required to build the engine"
        )

    if not os.path.isfile(model_file):
        return None

    with ExitStack() as stack:
        builder = stack.enter_context(trt.Builder(LOGGER))
        config = stack.enter_context(builder.create_builder_config())
        network = stack.enter_context(builder.create_network(net_flags))
        parser = stack.enter_context(trt.OnnxParser(network, LOGGER))

        if not parser.parse_from_file(model_file):
            errors = "; ".join(
                parser.get_error(i).desc() for i in range(parser.num_errors)
            )
            raise ModelLoadingException(f"Cannot parse {name} model: {errors}")

        # max_batch_size / max_workspace_size are the TensorRT 8 builder API;
        # newer releases replace the latter with config.set_memory_pool_limit.
        builder.max_batch_size = batch_size
        config.max_workspace_size = MAX_WORKSPACE_SIZE

        profile = _create_opt_profile(builder, network, batch_size)
        config.add_optimization_profile(profile)

        network.get_input(0).shape = input_shape

        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            raise ModelLoadingException(f"Cannot serialize {name} engine.")

        # Cache the engine on disk so later loads can skip the build step.
        engine_dir = Path(engine_path).parent
        engine_dir.mkdir(parents=True, exist_ok=True)
        with open(engine_path, "wb") as f:
            f.write(serialized_engine)

        return serialized_engine


def _create_opt_profile(builder, network, max_batch_size):
    """Create an optimization profile covering batch sizes 1..max_batch_size."""
    profile = builder.create_optimization_profile()
    if network.num_inputs <= 0:
        return profile
    input_ = network.get_input(0)
    min_shape = trt.Dims(input_.shape)
    min_shape[0] = 1
    opt_shape = trt.Dims(input_.shape)
    opt_shape[0] = max_batch_size
    max_shape = trt.Dims(input_.shape)
    max_shape[0] = max_batch_size
    profile.set_shape(input_.name, min_shape, opt_shape, max_shape)
    return profile


def load_engine(name, engine_path, models_dir, models_type, settings, input_shape):
    """Deserialize a cached engine, building it from ONNX first if missing."""
    if not os.path.isfile(engine_path):
        serialized_engine = build_engine(
            name, models_dir, models_type, engine_path, settings, input_shape
        )
    else:
        with open(engine_path, "rb") as f:
            serialized_engine = f.read()
    if not serialized_engine:
        raise ModelLoadingException(f"Cannot build {name} engine.")
    runtime = trt.Runtime(LOGGER)
    return runtime.deserialize_cuda_engine(serialized_engine)
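

# A minimal usage sketch, not part of the module API. All concrete values
# below (directory names, precision subdirectory, engine path, input size)
# are illustrative assumptions, not requirements of this module; the only
# fixed contract is the ONNX layout <models_dir>/<name>/<models_type>/<name>.onnx
# used by build_engine above.
if __name__ == "__main__":
    settings = {"recognition_max_batch_size": 8}  # overrides the default of 4
    engine = load_engine(
        name="recognition",
        engine_path="engines/recognition.plan",  # hypothetical cache location
        models_dir="models",  # hypothetical models root
        models_type="fp16",  # hypothetical precision subdirectory
        settings=settings,
        input_shape=(112, 112),  # (height, width); batch and RGB dims are added internally
    )
    print("Engine loaded:", engine is not None)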