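"""TensorRT engine wrapper for the liveness model.

Loads a cached serialized engine when one exists, builds one from the ONNX
model otherwise, and runs batched inference over depth images.

Usage sketch (paths and settings are illustrative):

    engine = Engine("/path/to/model_dir", {"max_batch_size": 4})
    live_probs = engine.predict(depth_imgs)
"""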
import os

import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # noqa
import tensorrt as trt

from .utils import do_inference, allocate_buffers, GiB
from ..utils import _read_spec_value

LOGGER = trt.Logger(trt.Logger.Severity.ERROR)
DEFAULT_MAX_BATCH_SIZE = 1


class Engine(object):
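    """Liveness inference on top of a TensorRT engine.

    Loads or builds the engine, creates an execution context, and
    pre-allocates I/O buffers for repeated inference calls.
    """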

    def __init__(self, model_path, settings):
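        """Load (or build) the engine and pre-allocate inference buffers."""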
        self.stream = cuda.Stream()
        self.input_shape = _read_spec_value(model_path, "input_shape")
        self.engine = self._load_engine(model_path, settings)
        self.context = self.engine.create_execution_context()
        (self.inputs, self.outputs, self.data, self.bindings) = allocate_buffers(
            self.engine
        )

    def _load_engine(self, model_path, settings):
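        """Deserialize a cached engine if one exists, else build from ONNX."""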
        engine_dirpath = model_path
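        # If model_path points at the bundled paravision_models liveness
        # package, look for the engine in its dedicated TensorRT engine
        # directory instead.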
        try:
            import paravision_models.liveness

            if paravision_models.liveness.location() == model_path:
                engine_dirpath = paravision_models.liveness.TRT_ENGINE_PATH
        except (ImportError, NameError, AttributeError):
            pass

        engine_path = os.path.join(engine_dirpath, "liveness.engine")
        if not os.path.isfile(engine_path):
            return self._build_engine(model_path, engine_path, settings)

        runtime = trt.Runtime(LOGGER)
        with open(engine_path, "rb") as f:
            return runtime.deserialize_cuda_engine(f.read())

    def _build_engine(self, model_path, engine_path, settings):
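        """Build an engine from liveness.onnx and cache it at engine_path."""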
        model_file = os.path.join(model_path, "liveness.onnx")

        max_batch_size = settings.get("max_batch_size", DEFAULT_MAX_BATCH_SIZE)
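
        # TensorRT 7+ parses ONNX only into an explicit-batch network, so the
        # batch dimension is baked into the input shape; TensorRT 6 uses the
        # older implicit-batch mode and leaves the batch dimension out.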
        trt_version = int(trt.__version__.split(".")[0])
        if trt_version >= 7:
            input_shape = [max_batch_size, 3] + list(self.input_shape)
            net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        elif trt_version == 6:
            input_shape = [3] + list(self.input_shape)
            net_flags = 0
        else:
            raise Exception("TensorRT version 6 or higher required to build engine")

        if not os.path.isfile(model_file):
            raise Exception("No model found at {}".format(model_file))

        with open(model_file, "rb") as f:
            model = f.read()

        with trt.Builder(LOGGER) as builder, builder.create_network(
            net_flags
        ) as network, trt.OnnxParser(network, LOGGER) as parser:
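            # Cap the builder's scratch memory at 1 GiB; max_batch_size is
            # only honored by implicit-batch (TensorRT 6) engines.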
            builder.max_workspace_size = GiB(1)
            builder.max_batch_size = max_batch_size

            if not parser.parse(model):
                raise Exception("Cannot parse liveness model.")

            network.get_input(0).shape = input_shape
            engine = builder.build_cuda_engine(network)
            if engine is None:
                raise Exception("Cannot build engine")

            serialized = engine.serialize()
            if serialized is None:
                raise Exception("Cannot serialize engine")

            with open(engine_path, "wb") as f:
                f.write(serialized)

            return engine

    def predict(self, exp_bb_depth_imgs):
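        """Return a live probability for each depth image, in input order."""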
        max_batch_size = self.engine.max_batch_size
        live_probs = []
        for i in range(0, len(exp_bb_depth_imgs), max_batch_size):
            batch = exp_bb_depth_imgs[i : i + max_batch_size]
            probs_batch = self._batch_predict(batch)
            live_probs.extend(probs_batch)

        return live_probs

    def _batch_predict(self, np_imgs):
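        """Run a single batch through the TensorRT engine."""
        # The network expects 3-channel CHW input, so each single-channel
        # depth image is replicated across the channel axis.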
        stacked = [np.stack([np_img for _ in range(3)], axis=0) for np_img in np_imgs]
        np_imgs = np.asarray(stacked, dtype=np.float32)
        results = do_inference(
            self.context,
            bindings=self.bindings,
            inputs=self.inputs,
            input_data=[np_imgs.ravel()],
            outputs=self.outputs,
            output_data=self.data,
            stream=self.stream,
        )

        # The output holds two values per image; the live probability is the
        # first of each pair, so grab every other value.
        return results[0][0 : 2 * len(np_imgs) : 2]