# -*- coding: utf-8 -*-
from ctypes import *
import numpy as np
from .base import _Cnetwork, _Ctensor, _lib, _LiteCObjBase
from .struct import *
from .tensor import *
class LiteOptions(Structure):
    """
    the inference options which can optimize the network forwarding
    performance

    Attributes:
        weight_preprocess: is the option which optimize the inference performance
            with processing the weights of the network ahead
        fuse_preprocess: fuse preprocess pattern, like astype + pad_channel +
            dimshuffle
        fake_next_exec: whether only to perform non-computing tasks (like
            memory allocation and queue initialization) for next exec. This will be
            reset to false when the graph is executed.
        var_sanity_check_first_run: Disable var sanity check on the first run.
            Var sanity check is enabled on the first-time execution by default, and can
            be used to find some potential memory access errors in the operator
        const_shape: used to reduce memory usage and improve performance since some
            static inference data structures can be omitted and some operators can be
            compute before forwarding
        force_dynamic_alloc: force dynamic allocate memory for all vars
        force_output_dynamic_alloc: force dynamic allocate memory for output tensor
            which are used as the input of CallbackCaller Operator
        force_output_use_user_specified_memory: flag passed through to the C
            library; defaults to False
        no_profiling_on_shape_change: do not re-profile to select best implement
            algo when input shape changes (use previous algo)
        jit_level: Execute supported operators with JIT, please check with MGB_JIT_BACKEND
            for more details, this value indicates JIT level:
            level 1: for JIT execute with basic elemwise operator
            level 2: for JIT execute elemwise and reduce operators
        comp_node_seq_record_level: flags to optimize the inference performance with
            record the kernel tasks in first run, hereafter the inference all need is to
            execute the recorded tasks.
            level = 0 means the normal inference
            level = 1 means use record inference
            level = 2 means record inference with free the extra memory
        graph_opt_level: network optimization level:
            0: disable
            1: level-1: inplace arith transformations during graph construction
            2: level-2: level-1, plus global optimization before graph compiling
            3: also enable JIT
        async_exec_level: level of dispatch on separate threads for different comp_node.
            0: do not perform async dispatch
            1: dispatch async if there are more than one comp node with limited queue
            mask 0b10: async if there are multiple comp nodes
            mask 0b100: always async
        enable_nchw44, enable_nchw44_dot, enable_nchw88, enable_nhwcd4,
        enable_nchw4, enable_nchw32, enable_nchw64: layout transform flags,
            presumably enabling conversion of the graph to the named tensor
            layout (NOTE(review): confirm against the C API); all default to False

    Examples:

        .. code-block::

            from megenginelite import *
            options = LiteOptions()
            options.weight_preprocess = True
            options.comp_node_seq_record_level = 1
            options.fuse_preprocess = True
    """

    _fields_ = [
        ("weight_preprocess", c_int),
        ("fuse_preprocess", c_int),
        ("fake_next_exec", c_int),
        ("var_sanity_check_first_run", c_int),
        ("const_shape", c_int),
        ("force_dynamic_alloc", c_int),
        ("force_output_dynamic_alloc", c_int),
        ("force_output_use_user_specified_memory", c_int),
        ("no_profiling_on_shape_change", c_int),
        ("jit_level", c_int),
        ("comp_node_seq_record_level", c_int),
        ("graph_opt_level", c_int),
        ("async_exec_level", c_int),
        # layout transform options
        ("enable_nchw44", c_int),
        ("enable_nchw44_dot", c_int),
        ("enable_nchw88", c_int),
        ("enable_nhwcd4", c_int),
        ("enable_nchw4", c_int),
        ("enable_nchw32", c_int),
        ("enable_nchw64", c_int),
    ]

    def __init__(self):
        self.weight_preprocess = False
        self.fuse_preprocess = False
        self.fake_next_exec = False
        self.var_sanity_check_first_run = True
        self.const_shape = False
        self.force_dynamic_alloc = False
        self.force_output_dynamic_alloc = False
        self.force_output_use_user_specified_memory = False
        self.no_profiling_on_shape_change = False
        self.jit_level = 0
        self.comp_node_seq_record_level = 0
        self.graph_opt_level = 2
        self.async_exec_level = 1
        # layout transform options; ctypes already zero-initializes them, set
        # explicitly so every default is visible in one place
        self.enable_nchw44 = False
        self.enable_nchw44_dot = False
        self.enable_nchw88 = False
        self.enable_nhwcd4 = False
        self.enable_nchw4 = False
        self.enable_nchw32 = False
        self.enable_nchw64 = False

    def __repr__(self):
        data = {
            "weight_preprocess": bool(self.weight_preprocess),
            "fuse_preprocess": bool(self.fuse_preprocess),
            "fake_next_exec": bool(self.fake_next_exec),
            "var_sanity_check_first_run": bool(self.var_sanity_check_first_run),
            "const_shape": bool(self.const_shape),
            "force_dynamic_alloc": bool(self.force_dynamic_alloc),
            "force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc),
            "force_output_use_user_specified_memory": bool(
                self.force_output_use_user_specified_memory
            ),
            "no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change),
            "jit_level": self.jit_level,
            "comp_node_seq_record_level": self.comp_node_seq_record_level,
            "graph_opt_level": self.graph_opt_level,
            "async_exec_level": self.async_exec_level,
            # layout transform options were previously omitted from the repr
            "enable_nchw44": bool(self.enable_nchw44),
            "enable_nchw44_dot": bool(self.enable_nchw44_dot),
            "enable_nchw88": bool(self.enable_nchw88),
            "enable_nhwcd4": bool(self.enable_nhwcd4),
            "enable_nchw4": bool(self.enable_nchw4),
            "enable_nchw32": bool(self.enable_nchw32),
            "enable_nchw64": bool(self.enable_nchw64),
        }
        return data.__repr__()
class LiteConfig(Structure):
    """
    Configuration when load and compile a network

    Attributes:
        has_compression: flag whether the model is compressed, the compress
            method is stored in the model
        device_id: configure the device id of a network
        device_type: configure the device type of a network
        backend: configure the inference backend of a network, now only support
            megengine
        bare_model_cryption_name: is the bare model encryption method name, bare
            model is not packed with json information, this encryption method name is
            useful to decrypt the encrypted bare model
        options: configuration of Options
        auto_optimize_inference: lite will detect the device information and set
            the options heuristically
        discrete_input_name: configure which input is composed of discrete multiple tensors

    Examples:

        .. code-block::

            from megenginelite import *
            config = LiteConfig()
            config.has_compression = False
            config.device_type = LiteDeviceType.LITE_CPU
            config.backend = LiteBackend.LITE_DEFAULT
            config.bare_model_cryption_name = "AES_default".encode("utf-8")
            config.auto_optimize_inference = False
    """

    _fields_ = [
        ("has_compression", c_int),
        ("device_id", c_int),
        ("device_type", c_int),
        ("backend", c_int),
        ("_bare_model_cryption_name", c_char_p),
        ("options", LiteOptions),
        ("auto_optimize_inference", c_int),
        ("discrete_input_name", c_char_p),
    ]

    def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
        self.device_type = device_type
        self.options = option if option else LiteOptions()
        self._bare_model_cryption_name = c_char_p(b"")
        # NOTE(review): not a ctypes field, so it is never passed to the C
        # library; kept only in case callers read it
        self.use_loader_dynamic_param = 0
        self.has_compression = 0
        self.backend = LiteBackend.LITE_DEFAULT
        self.auto_optimize_inference = 0
        self.discrete_input_name = c_char_p(b"")

    @property
    def bare_model_cryption_name(self):
        """
        the bare model encryption method name as str ("" when unset/NULL)
        """
        name = self._bare_model_cryption_name
        # a NULL c_char_p field reads back as None instead of bytes
        return name.decode("utf-8") if name is not None else ""

    @bare_model_cryption_name.setter
    def bare_model_cryption_name(self, name):
        if isinstance(name, str):
            self._bare_model_cryption_name = name.encode("utf-8")
        else:
            assert isinstance(name, bytes), "name should be str or bytes type."
            self._bare_model_cryption_name = name

    def __repr__(self):
        data = {
            "has_compression": bool(self.has_compression),
            # device_id is a plain integer id, not a LiteDeviceType value
            "device_id": self.device_id,
            "device_type": LiteDeviceType(self.device_type),
            "backend": LiteBackend(self.backend),
            "bare_model_cryption_name": self.bare_model_cryption_name,
            "options": self.options,
            "auto_optimize_inference": self.auto_optimize_inference,
            "discrete_input_name": self.discrete_input_name,
        }
        return data.__repr__()
class LiteExtraConfig(Structure):
    """
    Extra configuration applied when loading and compiling the graph.

    Attributes:
        disable_configure_by_model_info: when set, the configuration dumped
            together with the model is not applied, so the user is expected
            to configure the network themselves.
    """

    _fields_ = [("disable_configure_by_model_info", c_int)]

    def __init__(self, disable_model_config=False):
        # stored into a c_int field; bool values become 0/1
        self.disable_configure_by_model_info = disable_model_config

    def __repr__(self):
        flag = bool(self.disable_configure_by_model_info)
        return repr({"disable_configure_by_model_info": flag})
class LiteIO(Structure):
    """
    config the network input and output item, the input and output tensor
    information will describe there

    Attributes:
        name: the tensor name in the graph corresponding to the IO
        is_host: Used to mark where the input tensor comes from and where the output
            tensor will copy to, if is_host is true, the input is from host and output copy
            to host, otherwise in device. Sometimes the input is from device and output no need
            copy to host, default is true.
        io_type: The IO type, it can be SHAPE or VALUE, when SHAPE is set, the input or
            output tensor value is invalid, only shape will be set, default is VALUE
        config_layout: The layout of the config from user, if other layout is set before
            forward or get after forward, this layout will be bypassed. if no other
            layout is set before forward, this layout will work. if this layout is
            not set, the model will forward with its origin layout. if in output, it
            will be used to check.

    Note:
        if other layout is set to input tensor before forwarding, this layout will not work

        if no layout is set before forwarding, the model will forward with its origin layout

        if layout is set in output tensor, it will be used to check whether the layout computed from the network is correct

    Examples:

        .. code-block::

            from megenginelite import *
            io = LiteIO(
                "data2",
                is_host=True,
                io_type=LiteIOType.LITE_IO_SHAPE,
                layout=LiteLayout([2, 4, 4]),
            )
    """

    _fields_ = [
        ("_name", c_char_p),
        ("is_host", c_int),
        ("io_type", c_int),
        ("config_layout", LiteLayout),
    ]

    def __init__(
        self, name, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
    ):
        """
        Args:
            name: the tensor name, str (encoded as utf-8) or bytes
            is_host: whether the data lives on host, see the class docstring
            io_type: a LiteIOType value, default LITE_IO_VALUE
            layout: optional LiteLayout; an empty LiteLayout is used when not given
        """
        if isinstance(name, str):
            name = name.encode("utf-8")
        self._name = c_char_p(name)
        self.config_layout = layout if layout else LiteLayout()
        self.is_host = is_host
        self.io_type = io_type

    @property
    def name(self):
        """
        get the name of IO item ("" when the underlying C string is NULL)
        """
        raw = self._name
        # a NULL c_char_p field reads back as None instead of bytes
        return raw.decode("utf-8") if raw is not None else ""

    @name.setter
    def name(self, name):
        """
        set the name of IO item
        """
        if isinstance(name, str):
            self._name = name.encode("utf-8")
        else:
            assert isinstance(name, bytes), "name should be str or bytes type."
            self._name = name

    def __repr__(self):
        data = {
            "name": self.name,
            "is_host": bool(self.is_host),
            "io_type": LiteIOType(self.io_type),
            "config_layout": self.config_layout,
        }
        return data.__repr__()

    def __hash__(self):
        # hashed by name only; equality stays identity-based (no __eq__), so
        # an input and an output sharing a name remain distinct dict keys
        return hash(self.name)
class _LiteNetworkIO(Structure):
    """
    C-layout description of a network's IO: two LiteIO arrays (given as
    pointers to their first element) plus their element counts.
    """

    _fields_ = [
        ("inputs", POINTER(LiteIO)),
        ("outputs", POINTER(LiteIO)),
        ("input_size", c_size_t),
        ("output_size", c_size_t),
    ]

    def __init__(self):
        # start out with NULL arrays and zero lengths
        self.inputs, self.outputs = POINTER(LiteIO)(), POINTER(LiteIO)()
        self.input_size = self.output_size = 0
class LiteNetworkIO(object):
    """
    the input and output information when load the network for user
    the NetworkIO will remain in the network until the network is destroyed.

    Attributes:
        inputs: The all input tensors information that will configure to the network
        outputs: The all output tensors information that will configure to the network

    Examples:

        .. code-block::

            from megenginelite import *
            input_io = LiteIO("data", is_host=False, io_type=LiteIOType.LITE_IO_VALUE)
            io = LiteNetworkIO()
            io.add_input(input_io)
            output_io = LiteIO("out", is_host=True, layout=LiteLayout([1, 1000]))
            io.add_output(output_io)
    """

    def __init__(self, inputs=None, outputs=None):
        """
        Args:
            inputs: optional iterable whose items are LiteIO objects or
                lists/tuples of LiteIO constructor arguments
            outputs: same as inputs, for the network outputs
        """
        self.inputs = []
        self.outputs = []
        for item in inputs or []:
            self.inputs.append(self._to_lite_io(item))
        for item in outputs or []:
            self.outputs.append(self._to_lite_io(item))

    @staticmethod
    def _to_lite_io(item):
        """convert a LiteIO or a list/tuple of LiteIO constructor args to a LiteIO"""
        if isinstance(item, (list, tuple)):
            return LiteIO(*item)
        assert isinstance(
            item, LiteIO
        ), "the param to construct LiteNetworkIO must be list of the LiteIO member or the LiteIO."
        return item

    def add_input(
        self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
    ):
        """
        add input information into LiteNetworkIO

        Args:
            obj: a LiteIO, or the input tensor name used to build one
            is_host, io_type, layout: forwarded to LiteIO when obj is a name
        """
        if isinstance(obj, LiteIO):
            self.inputs.append(obj)
        else:
            self.inputs.append(LiteIO(obj, is_host, io_type, layout))

    def add_output(
        self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
    ):
        """
        add output information into LiteNetworkIO

        Args:
            obj: a LiteIO, or the output tensor name used to build one
            is_host, io_type, layout: forwarded to LiteIO when obj is a name
        """
        if isinstance(obj, LiteIO):
            self.outputs.append(obj)
        else:
            self.outputs.append(LiteIO(obj, is_host, io_type, layout))

    def _create_network_io(self):
        """
        build the C-side _LiteNetworkIO; the backing ctypes arrays are stored
        on self so they stay alive while the C struct points into them
        """
        network_io = _LiteNetworkIO()
        # allocate at least one element so taking element [0] is always valid
        self.c_inputs = (LiteIO * max(len(self.inputs), 1))(*self.inputs)
        self.c_outputs = (LiteIO * max(len(self.outputs), 1))(*self.outputs)
        network_io.inputs = pointer(self.c_inputs[0])
        network_io.outputs = pointer(self.c_outputs[0])
        network_io.input_size = len(self.inputs)
        network_io.output_size = len(self.outputs)
        return network_io

    def __repr__(self):
        data = {"inputs": list(self.inputs), "outputs": list(self.outputs)}
        return data.__repr__()
# ctypes prototypes of the callbacks accepted by the lite C library.
# The async callback takes no arguments and returns int.
LiteAsyncCallback = CFUNCTYPE(c_int)
# Start and finish callbacks share one signature: (ios, tensors, count) -> int.
# CFUNCTYPE caches prototypes by signature, so both names denote the same type.
LiteStartCallback = LiteFinishCallback = CFUNCTYPE(
    c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t
)
def wrap_async_callback(func):
    """
    Wrap a zero-argument Python callable into a ctypes ``c_int()`` callback.

    The resulting callback object is also bound to the module-level global
    ``wrapper``. Per the ctypes documentation, the caller must keep a
    reference to the returned object for as long as C code may invoke it.
    """
    global wrapper

    def _invoke():
        return func()

    wrapper = CFUNCTYPE(c_int)(_invoke)
    return wrapper
def start_finish_callback(func):
    """
    Wrap a Python callable taking a ``{LiteIO: LiteTensor}`` dict into a
    ctypes callback with the (ios, tensors, count) -> int signature.

    For each of the ``count`` entries, a LiteTensor view is built around the
    raw tensor handle and keyed by the corresponding LiteIO. The callback
    object is also bound to the module-level global ``wrapper``. The caller
    must keep a reference to the returned object while C code may invoke it.
    """
    global wrapper

    @CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
    def wrapper(c_ios, c_tensors, size):
        io_to_tensor = {}
        for idx in range(size):
            lite_tensor = LiteTensor(physic_construct=False)
            lite_tensor._tensor = c_void_p(c_tensors[idx])
            lite_tensor.update()
            io_to_tensor[c_ios[idx]] = lite_tensor
        return func(io_to_tensor)

    return wrapper
class _NetworkAPI(_LiteCObjBase):
    """
    get the network api from the lib

    ``_api_`` lists ``(symbol_name, argtypes)`` pairs for the network-related
    C functions. It is presumably consumed by ``_LiteCObjBase`` to look up and
    type each symbol (NOTE(review): confirm in base.py). Callers in this
    module reach the functions through ``_NetworkAPI()._lib``.
    """

    _api_ = [
        # network creation / model loading / destruction
        ("LITE_make_default_network", [POINTER(_Cnetwork)]),
        ("LITE_make_network", [POINTER(_Cnetwork), LiteConfig, _LiteNetworkIO]),
        ("LITE_load_model_from_mem", [_Cnetwork, c_void_p, c_size_t]),
        ("LITE_load_model_from_path", [_Cnetwork, c_char_p]),
        ("LITE_shared_weight_with_network", [_Cnetwork, _Ctensor]),
        ("LITE_destroy_network", [_Cnetwork]),
        # execution
        ("LITE_forward", [_Cnetwork]),
        ("LITE_wait", [_Cnetwork]),
        # IO tensors and names; name getters write into caller-provided c_char_p
        ("LITE_get_io_tensor", [_Cnetwork, c_char_p, c_int, POINTER(_Ctensor)]),
        ("LITE_get_input_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_output_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_all_input_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        ("LITE_get_all_output_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        # device / threading settings
        ("LITE_is_cpu_inplace_mode", [_Cnetwork, POINTER(c_int)]),
        ("LITE_get_cpu_threads_number", [_Cnetwork, POINTER(c_size_t)]),
        ("LITE_get_device_id", [_Cnetwork, POINTER(c_int)]),
        ("LITE_set_device_id", [_Cnetwork, c_int]),
        ("LITE_set_cpu_inplace_mode", [_Cnetwork]),
        ("LITE_use_tensorrt", [_Cnetwork]),
        ("LITE_set_cpu_threads_number", [_Cnetwork, c_size_t]),
        ("LITE_set_stream_id", [_Cnetwork, c_int]),
        ("LITE_get_stream_id", [_Cnetwork, POINTER(c_int)]),
        # algorithm selection
        ("LITE_set_network_algo_policy", [_Cnetwork, c_int]),
        ("LITE_set_network_algo_fastrun_config", [_Cnetwork, c_int, c_int]),
        ("LITE_set_network_algo_workspace_limit", [_Cnetwork, c_size_t]),
        # NOTE(review): "memroy" spelling presumably matches the exported C
        # symbol; do not correct it without checking the library
        ("LITE_share_runtime_memroy", [_Cnetwork, _Cnetwork]),
        # debugging / profiling / dumping
        ("LITE_enable_profile_performance", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
        # callbacks
        ("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
        ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
        ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
        ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
        # global layout transform
        ("LITE_enable_global_layout_transform", [_Cnetwork]),
        ("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
        # model IO inspection without creating a network
        (
            "LITE_get_model_io_info_by_path",
            [c_char_p, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
        (
            "LITE_get_model_io_info_by_memory",
            [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
        ("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
        (
            "LITE_get_discrete_tensor",
            [_Cnetwork, c_char_p, c_size_t, c_int, POINTER(_Ctensor)],
        ),
    ]
class LiteNetwork(object):
    """
    the network to load a model and forward

    Examples:

        .. code-block::

            from megenginelite import *
            config = LiteConfig()
            config.device_type = LiteDeviceType.LITE_CPU
            network = LiteNetwork(config)
            network.load("model_path")

            input_name = network.get_input_name(0)
            input_tensor = network.get_io_tensor(input_name)
            output_name = network.get_output_name(0)
            output_tensor = network.get_io_tensor(output_name)

            input_tensor.set_data_by_copy(input_data)

            network.forward()
            network.wait()
    """

    _api = _NetworkAPI()._lib

    def __init__(self, config=None, io=None):
        """
        create a network with config and networkio

        Args:
            config: the LiteConfig used to create the network, a default
                LiteConfig is used when not given
            io: the LiteNetworkIO describing the network IO, an empty
                LiteNetworkIO is used when not given
        """
        self._network = _Cnetwork()
        # ctypes callback objects handed to the C library. They must stay
        # referenced for as long as the library may call them, otherwise they
        # can be garbage-collected while C still holds the function pointer.
        self._async_callback = None
        self._start_callback = None
        self._finish_callback = None

        self.config = config if config else LiteConfig()
        self.network_io = io if io else LiteNetworkIO()

        c_network_io = self.network_io._create_network_io()
        self._api.LITE_make_network(byref(self._network), self.config, c_network_io)

    def __repr__(self):
        data = {"config": self.config, "IOs": self.network_io}
        return data.__repr__()

    def __del__(self):
        # if __init__ failed before the C network was created, the handle is
        # still NULL. NOTE(review): assumes _Cnetwork is a ctypes simple type
        # (e.g. c_void_p), which is falsy when NULL; for other types this check
        # is always true and behaves as before.
        network = getattr(self, "_network", None)
        if network:
            self._api.LITE_destroy_network(network)

    @staticmethod
    def _to_c_name(name):
        """convert a str (utf-8 encoded) or bytes tensor name to c_char_p"""
        if isinstance(name, str):
            name = name.encode("utf-8")
        return c_char_p(name)

    def load(self, path):
        """
        load network from given path

        Args:
            path: the model file path (str)
        """
        c_path = c_char_p(path.encode("utf-8"))
        self._api.LITE_load_model_from_path(self._network, c_path)

    def forward(self):
        """
        forward the network with filled input data and fill the output data
        to the output tensor
        """
        self._api.LITE_forward(self._network)

    def wait(self):
        """
        wait until the forward finishes
        """
        self._api.LITE_wait(self._network)

    def is_cpu_inplace_mode(self):
        """
        whether the network runs in cpu inplace mode

        Returns:
            if use inplace mode return True, else return False
        """
        inplace = c_int()
        self._api.LITE_is_cpu_inplace_mode(self._network, byref(inplace))
        return bool(inplace.value)

    def enable_cpu_inplace_mode(self):
        """
        set cpu forward in inplace mode with which cpu forward only create one
        thread

        Note:
            this must be set before the network loaded
        """
        self._api.LITE_set_cpu_inplace_mode(self._network)

    def use_tensorrt(self):
        """
        use TensorRT

        Note:
            this must be set before the network loaded
        """
        self._api.LITE_use_tensorrt(self._network)

    @property
    def device_id(self):
        """
        get the device id

        Returns:
            the device id of current network used
        """
        device_id = c_int()
        self._api.LITE_get_device_id(self._network, byref(device_id))
        return device_id.value

    @device_id.setter
    def device_id(self, device_id):
        """
        set the device id

        Note:
            this must be set before the network loaded
        """
        self._api.LITE_set_device_id(self._network, device_id)

    @property
    def stream_id(self):
        """
        get the stream id

        Returns:
            the value of stream id set for network
        """
        stream_id = c_int()
        self._api.LITE_get_stream_id(self._network, byref(stream_id))
        return stream_id.value

    @stream_id.setter
    def stream_id(self, stream_id):
        """
        set the stream id

        Note:
            this must be set before the network loaded
        """
        self._api.LITE_set_stream_id(self._network, stream_id)

    @property
    def threads_number(self):
        """
        get the thread number of the network

        Returns:
            the number of thread set in the network
        """
        nr_thread = c_size_t()
        self._api.LITE_get_cpu_threads_number(self._network, byref(nr_thread))
        return nr_thread.value

    @threads_number.setter
    def threads_number(self, nr_threads):
        """
        set the network forward in multithread mode, and the thread number

        Note:
            this must be set before the network loaded
        """
        self._api.LITE_set_cpu_threads_number(self._network, nr_threads)

    def get_io_tensor(self, name, phase=LiteTensorPhase.LITE_IO):
        """
        get input or output tensor by its name

        Args:
            name: the name of io tensor (str or bytes)
            phase: the type of LiteTensor, this is useful to separate input or output tensor with the same name

        Returns:
            the tensor with given name and type
        """
        c_name = self._to_c_name(name)
        tensor = LiteTensor(physic_construct=False)
        self._api.LITE_get_io_tensor(
            self._network, c_name, phase, byref(tensor._tensor)
        )
        tensor.update()
        return tensor

    def get_discrete_tensor(self, name, n_idx, phase=LiteTensorPhase.LITE_INPUT):
        """
        get the n_idx'th tensor in the network input tensors whose
        input consists of discrete multiple tensors and tensor name is name

        Args:
            name: the name of input tensor (str or bytes)
            n_idx: the tensor index
            phase: the type of LiteTensor, this is useful to separate input tensor with the same name

        Returns:
            the tensors with given name and type
        """
        c_name = self._to_c_name(name)
        tensor = LiteTensor(physic_construct=False)
        self._api.LITE_get_discrete_tensor(
            self._network, c_name, n_idx, phase, byref(tensor._tensor)
        )
        tensor.update()
        return tensor

    def get_input_name(self, index):
        """
        get the input name by the index in the network

        Args:
            index: the index of the input name

        Returns:
            the name of input tensor with given index
        """
        c_name = c_char_p()
        self._api.LITE_get_input_name(self._network, index, byref(c_name))
        return c_name.value.decode("utf-8")

    def get_output_name(self, index):
        """
        get the output name by the index in the network

        Args:
            index: the index of the output name

        Returns:
            the name of output tensor with given index
        """
        c_name = c_char_p()
        self._api.LITE_get_output_name(self._network, index, byref(c_name))
        return c_name.value.decode("utf-8")

    def _get_all_names(self, getter):
        """query the count, then fill a c_char_p array via one LITE_get_all_*_name function"""
        count = c_size_t()
        getter(self._network, byref(count), None)
        if count.value == 0:
            return []
        names = (c_char_p * count.value)()
        getter(self._network, None, names)
        return [name.decode("utf-8") for name in names]

    def get_all_input_name(self):
        """
        get all the input tensor name in the network

        Returns:
            the names of all input tensor in the network ([] when there are none)
        """
        return self._get_all_names(self._api.LITE_get_all_input_name)

    def get_all_output_name(self):
        """
        get all the output tensor name in the network

        Returns:
            the names of all output tensor in the network ([] when there are none)
        """
        return self._get_all_names(self._api.LITE_get_all_output_name)

    def extra_configure(self, extra_config):
        """
        apply a LiteExtraConfig to the network through LITE_extra_configure

        Args:
            extra_config: the LiteExtraConfig to apply
        """
        assert isinstance(extra_config, LiteExtraConfig)
        self._api.LITE_extra_configure(self._network, extra_config)

    def share_weights_with(self, src_network):
        """
        share weights with the loaded network

        Args:
            src_network: the network to share weights
        """
        assert isinstance(src_network, LiteNetwork)
        self._api.LITE_shared_weight_with_network(self._network, src_network._network)

    def share_runtime_memroy(self, src_network):
        """
        share runtime memory with the source network

        Args:
            src_network: the network to share runtime memory
        """
        assert isinstance(src_network, LiteNetwork)
        self._api.LITE_share_runtime_memroy(self._network, src_network._network)

    def share_runtime_memory(self, src_network):
        """
        correctly spelled alias of share_runtime_memroy

        Args:
            src_network: the network to share runtime memory
        """
        self.share_runtime_memroy(src_network)

    def async_with_callback(self, async_callback):
        """
        set the network forwarding in async mode and set the AsyncCallback callback
        function

        Args:
            async_callback: the callback to set for network
        """
        callback = wrap_async_callback(async_callback)
        # keep the ctypes object alive while the C library may call it
        self._async_callback = callback
        self._api.LITE_set_async_callback(self._network, callback)

    def set_start_callback(self, start_callback):
        """
        when the network start forward, the callback will be called,
        the start_callback with param mapping from LiteIO to the corresponding
        LiteTensor

        Args:
            start_callback: the callback to set for network
        """
        callback = start_finish_callback(start_callback)
        # keep the ctypes object alive while the C library may call it
        self._start_callback = callback
        self._api.LITE_set_start_callback(self._network, callback)

    def set_finish_callback(self, finish_callback):
        """
        when the network finish forward, the callback will be called,
        the finish_callback with param mapping from LiteIO to the corresponding
        LiteTensor

        Args:
            finish_callback: the callback to set for network
        """
        callback = start_finish_callback(finish_callback)
        # keep the ctypes object alive while the C library may call it
        self._finish_callback = callback
        self._api.LITE_set_finish_callback(self._network, callback)

    def set_network_algo_workspace_limit(self, size_limit):
        """
        set the opr workspace limitation in the target network, some opr
        maybe use large of workspace to get good performance, set workspace limitation
        can save memory but may influence the performance

        Args:
            size_limit: the byte size of workspace limitation
        """
        self._api.LITE_set_network_algo_workspace_limit(self._network, size_limit)

    def set_network_algo_policy(
        self, policy, shared_batch_size=0, binary_equal_between_batch=False
    ):
        """
        set the network algorithm search policy for fast-run

        Args:
            policy: the algorithm search policy value passed to the C library
                as an int (NOTE(review): presumably a LiteAlgoSelectStrategy
                flag combination — confirm)
            shared_batch_size: the batch size used by fastrun,
                Non-zero value means that fastrun use this batch size
                regardless of the batch size of the model. Zero means
                fastrun use batch size of the model
            binary_equal_between_batch: if the content of each input batch is
                binary equal, whether the content of each output batch is
                promised to be equal
        """
        self._api.LITE_set_network_algo_policy(self._network, policy)
        self._api.LITE_set_network_algo_fastrun_config(
            self._network, shared_batch_size, binary_equal_between_batch
        )

    def io_txt_dump(self, txt_file):
        """
        dump all input/output tensor of all operators to the output file, in txt
        format, user can use this function to debug compute error

        Args:
            txt_file: the txt file
        """
        c_file = txt_file.encode("utf-8")
        self._api.LITE_enable_io_txt_dump(self._network, c_file)

    def io_bin_dump(self, bin_dir):
        """
        dump all input/output tensor of all operators to the output file, in
        binary format, user can use this function to debug compute error

        Args:
            bin_dir: the binary file directory
        """
        c_dir = bin_dir.encode("utf-8")
        self._api.LITE_enable_io_bin_dump(self._network, c_dir)

    def enable_profile_performance(self, profile_file):
        """
        call LITE_enable_profile_performance with the given file path;
        NOTE(review): presumably the profile result is written to profile_file

        Args:
            profile_file: the file path passed to the C library
        """
        c_file = profile_file.encode("utf-8")
        self._api.LITE_enable_profile_performance(self._network, c_file)

    def enable_global_layout_transform(self):
        """
        call LITE_enable_global_layout_transform on the network;
        NOTE(review): presumably must be set before the network is loaded
        """
        self._api.LITE_enable_global_layout_transform(self._network)

    def dump_layout_transform_model(self, model_file):
        """
        call LITE_dump_layout_transform_model with the given file path;
        NOTE(review): presumably dumps the layout-transformed model there

        Args:
            model_file: the file path passed to the C library
        """
        c_file = model_file.encode("utf-8")
        self._api.LITE_dump_layout_transform_model(self._network, c_file)

    def get_static_memory_alloc_info(self, log_dir="logs/test"):
        """
        get static peak memory info showed by Graph visualization

        Args:
            log_dir: the directory to save information log
        """
        c_log_dir = log_dir.encode("utf-8")
        self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir)
def get_model_io_info(model_path, config=None):
    """
    get the model io information before model loaded by model path.

    Args:
        model_path: the model path to get the model IO information
        config: the model configuration, a default LiteConfig is used when None

    Returns:
        a LiteNetworkIO holding the input and output information in the
        network configuration
    """
    api = _NetworkAPI()._lib
    c_path = c_char_p(model_path.encode("utf-8"))
    if config is None:
        config = LiteConfig()

    ios = _LiteNetworkIO()
    api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))

    ret_ios = LiteNetworkIO()
    # append the LiteIO entries directly instead of going through
    # add_input/add_output, so this does not depend on those helpers existing
    ret_ios.inputs.extend(ios.inputs[i] for i in range(ios.input_size))
    ret_ios.outputs.extend(ios.outputs[i] for i in range(ios.output_size))
    return ret_ios