diff --git a/hls4ml/backends/oneapi/oneapi_backend.py b/hls4ml/backends/oneapi/oneapi_backend.py index 0c11c16d09..94f26c9f1c 100644 --- a/hls4ml/backends/oneapi/oneapi_backend.py +++ b/hls4ml/backends/oneapi/oneapi_backend.py @@ -19,7 +19,6 @@ Embedding, Layer, SimpleRNN, - Softmax, ) from hls4ml.model.optimizer import get_backend_passes, layer_optimizer from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType @@ -257,13 +256,6 @@ def init_activation(self, layer): if layer.get_attr('recurrent_activation') == 'tanh': layer.set_attr('recurrent_activation', 'dense_tanh') - @layer_optimizer(Softmax) - def init_softmax(self, layer): - if layer.model.config.get_config_value('IOType') == 'io_parallel': - assert len(layer.get_input_variable().shape) == 1, ( - 'Softmax with io_parallel strategy cannot be used on multidimensional tensors.' - ) - @layer_optimizer(Embedding) def init_embed(self, layer): if layer.attributes['n_in'] is None: diff --git a/hls4ml/backends/oneapi/passes/core_templates.py b/hls4ml/backends/oneapi/passes/core_templates.py index 9602b2d0fc..c6050dfb57 100644 --- a/hls4ml/backends/oneapi/passes/core_templates.py +++ b/hls4ml/backends/oneapi/passes/core_templates.py @@ -1,7 +1,10 @@ from hls4ml.backends.backend import get_backend from hls4ml.backends.oneapi.oneapi_template import StreamFunctionCallTemplate, TaskSequenceTemplate from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.types import FixedPrecisionType, RoundingMode, SaturationMode from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax +from hls4ml.utils.fixed_point_utils import FixedPointEmulator, ceil_log2, uint_to_binary +import numpy as np # Dense templates @@ -194,12 +197,26 @@ def format(self, node): softmax_config_template = """struct {type}_config{index} : nnet::activ_config {{ static constexpr unsigned n_in = {n_in}; - static constexpr unsigned table_size = {table_size}; + static constexpr unsigned exp_table_size = {exp_table_size}; + static constexpr unsigned inv_table_size = {inv_table_size}; static constexpr unsigned io_type = nnet::{iotype}; static constexpr unsigned reuse_factor = {reuse}; static constexpr nnet::softmax_implementation implementation = nnet::softmax_implementation::{implementation}; typedef {exp_table_t.name} exp_table_t; - typedef {inv_table_t.name} inv_table_t; + typedef {inv_table_t.name} inv_table_t;""" + +softmax_config_table_template = """ + + static constexpr const exp_table_t *exp_table = &{exp_table_name}[0]; + static constexpr const inv_table_t *invert_table = &{inv_table_name}[0]; +}};\n""" + +softmax_config_table_template_stable = """ + typedef {inv_inp_t.name} inv_inp_t; + typedef {inp_norm_t.name} inp_norm_t; + + static constexpr const exp_table_t *exp_table = &{exp_table_name}[0]; + static constexpr const inv_table_t *invert_table = &{inv_table_name}[0]; }};\n""" activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});' @@ -220,7 +237,58 @@ def __init__(self): def format(self, node): params = self._default_config_params(node) params['type'] = node.get_attr('activation') - + + if params['type'] == 'softmax': + + if 'exp_table_size' in params: + params['exp_table_size'] //= 2 + else: + params['exp_table_size'] = 1024 + + params['exp_table_t'].precision.width = ceil_log2(params['exp_table_size']) + params['exp_table_t'].precision.integer = 3 + params['exp_table_t'].precision.signed = False + + if 'inp_norm_t' not in params: + input_t = node.get_input_variable().type.precision + width, iwidth, signed = input_t.width, input_t.integer, input_t.signed # noqa: F841 + width, iwidth = width - signed, iwidth - signed + import copy + params['inp_norm_t'] = copy.deepcopy(params['exp_table_t']) #assign type,later override + + #this checks if table sizes will be default, if it is just use the table size to derive precision + if 'inv_table_size' not in params: + params['inp_norm_t'].precision.width = params['exp_table_t'].precision.width + 1 + params['inp_norm_t'].precision.integer = params['exp_table_t'].precision.integer + 1 + params['inp_norm_t'].precision.signed = True + params['inp_norm_t'].name = f'{node.name}_inp_norm_t' + else: + params['inp_norm_t'].name = f'ac_fixed<{width},{iwidth},{'true' if signed else 'false'},AC_RND,AC_SAT_SYM>' + + node.set_attr('inp_norm_t', params['inp_norm_t']) + + if 'inv_table_size' in params: + params['inv_table_size'] //= 2 + else: + params['inv_table_size'] = 1024 + + params['inv_table_t'].precision.width = ceil_log2(params['inv_table_size']) + params['inv_table_t'].precision.integer = 3 + params['inv_table_t'].precision.signed = False + + params['inv_inp_t'].precision.width = params['inv_table_t'].precision.width + 1 + params['inv_inp_t'].precision.integer = params['inv_table_t'].precision.integer + 1 + params['inv_inp_t'].precision.signed = True + + + if params['implementation'] == 'stable': + self.template += softmax_config_table_template_stable + else: + self.template += softmax_config_table_template + + params['exp_table_name'] = node.name + '_exp_table' + params['inv_table_name'] = node.name + '_inv_table' + return self.template.format(**params) diff --git a/hls4ml/converters/keras_v3/hgq2/multi_head_attention.py b/hls4ml/converters/keras_v3/hgq2/multi_head_attention.py index 7154d0c9ca..09723f5336 100644 --- a/hls4ml/converters/keras_v3/hgq2/multi_head_attention.py +++ b/hls4ml/converters/keras_v3/hgq2/multi_head_attention.py @@ -14,7 +14,7 @@ class QMultiHeadAttentionHandler(QLayerHandler): - handles = ('hgq.layers.multi_head_attention.QMultiHeadAttention',) + handles = ('hgq.layers.attn.mha.QMultiHeadAttention',) def handle( self, @@ -127,7 +127,7 @@ def _handle(self, layer, tensor_q, tensor_O, node_index, tensor_k, tensor_v): class QLinformerAttentionHandler(QMultiHeadAttentionHandler): - handles = ('hgq.layers.linformer_attention.QLinformerAttention',) + handles = ('hgq.layers.attn.linformer.QLinformerAttention',) def handle( self, diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation.h index f118ecb05c..385457204d 100644 --- a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation.h @@ -99,18 +99,21 @@ template void sigmoid(const data_ enum class softmax_implementation { latency = 0, legacy = 1, stable = 2, argmax = 3 }; -template inline unsigned softmax_stable_idx_from_real_val(const data_T x) { + +template inline unsigned softmax_stable_idx_from_real_val(const data_T x) { // Number of address bits for table - static constexpr int N = ceillog2::val; + static constexpr int N = ceillog2::val; // Slice the top N bits of the input [[intel::fpga_register]] ac_int y = x.template slc(x.width - N - 1); + // If x is the most negative value, the slice will be 0, so we need to set the 0-th bit to ensure correctness if (x != 0 && y == 0) y[0] = 1; return y.to_uint(); } + template inline unsigned softmax_latency_idx_from_real_val(const data_T x) { // Number of address bits for table static constexpr int N = ceillog2::val; @@ -120,10 +123,8 @@ template inline unsigned softmax_latency_idx_f return y.to_uint(); } + template void softmax_stable(const data_T &data, res_T &res) { -// Look-up tables -#include "activation_tables/exp_table.tb" -#include "activation_tables/invert_table.tb" // Find maximum Op_max op_max; @@ -131,8 +132,8 @@ template void softmax_stable(cons reduce>(data.data(), op_max); // For the diffs, use the same type as the input but force rounding and saturation - [[intel::fpga_register]] ac_fixed - d_xi_xmax[CONFIG_T::n_in]; + [[intel::fpga_register]] + typename CONFIG_T::inp_norm_t d_xi_xmax[CONFIG_T::n_in]; #pragma unroll for (unsigned i = 0; i < CONFIG_T::n_in; i++) { d_xi_xmax[i] = data[i] - x_max; @@ -142,23 +143,25 @@ template void softmax_stable(cons [[intel::fpga_register]] typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in]; #pragma unroll for (unsigned i = 0; i < CONFIG_T::n_in; i++) { - exp_res[i] = exp_table[softmax_stable_idx_from_real_val(d_xi_xmax[i])]; + exp_res[i] = CONFIG_T::exp_table[softmax_stable_idx_from_real_val(d_xi_xmax[i])]; //input_t, CONFIG_T } // Explicitly sum previously calculated exponentials with an adder tree Op_add op_add; - [[intel::fpga_register]] typename CONFIG_T::exp_table_t exp_sum = + [[intel::fpga_register]] typename CONFIG_T::inv_inp_t exp_sum = reduce>(exp_res, op_add); // Multiply previously calculated exponetials with the reciprocal of the sum [[intel::fpga_register]] typename CONFIG_T::inv_table_t inv_exp_sum = - invert_table[softmax_stable_idx_from_real_val(exp_sum)]; + CONFIG_T::invert_table[softmax_stable_idx_from_real_val(exp_sum)]; + #pragma unroll for (unsigned i = 0; i < CONFIG_T::n_in; i++) { res[i] = exp_res[i] * inv_exp_sum; } } + // TODO - Improve accuracy template void softmax_latency(const data_T &data, res_T &res) { #include "activation_tables/exp_table_latency.tb" @@ -265,6 +268,45 @@ template inline void softmax(cons } } +// ************************************************* +// Multidimensional Softmax +// ************************************************* + +// Helper to remap the config for the core softmax function +template struct softmax_multidim_slice_config : CONFIG_T { + static constexpr unsigned n_in = CONFIG_T::n_slice; +}; + +template inline void softmax_multidim(const data_T &data, res_T &res) { + using buffer_data_t = std::array; + using buffer_res_t = std::array; + using slice_config = softmax_multidim_slice_config; + + #pragma unroll + for (unsigned i = 0; i < CONFIG_T::n_outer; i++) { + #pragma unroll + for (unsigned k = 0; k < CONFIG_T::n_inner; k++) { + + [[intel::fpga_register]] buffer_data_t buffer_in; + [[intel::fpga_register]] buffer_res_t buffer_out; + + // Gather Phase + #pragma unroll + for (unsigned j = 0; j < CONFIG_T::n_slice; j++) { + unsigned idx = (i * CONFIG_T::n_slice * CONFIG_T::n_inner) + (j * CONFIG_T::n_inner) + k; + buffer_in[j] = data[idx]; + } + + nnet::softmax(buffer_in, buffer_out); + + #pragma unroll + for (unsigned j = 0; j < CONFIG_T::n_slice; j++) { + unsigned idx = (i * CONFIG_T::n_slice * CONFIG_T::n_inner) + (j * CONFIG_T::n_inner) + k; + res[idx] = buffer_out[j]; + } + } + } +} // ************************************************* // TanH Activation // ************************************************* diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation_stream.h index e860c38988..d640f89f7e 100644 --- a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation_stream.h +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation_stream.h @@ -271,64 +271,63 @@ template void softsign_stre // ************************************************* template void softmax_stable_stream() { -#include "activation_tables/exp_table.tb" -#include "activation_tables/invert_table.tb" + + using input_arr_t = typename ExtractPipeType::value_type; + using input_t = typename ExtractPipeType::value_type::value_type; + constexpr unsigned input_arr_size = std::tuple_size{}; + constexpr unsigned multiplier_limit = - DIV_ROUNDUP(std::tuple_size::value_type>{}, CONFIG_T::reuse_factor); - constexpr unsigned pipeline = std::tuple_size::value_type>{} / multiplier_limit; + DIV_ROUNDUP(input_arr_size, CONFIG_T::reuse_factor); + constexpr unsigned pipeline = input_arr_size / multiplier_limit; - [[intel::fpga_register]] typename ExtractPipeType::value_type::value_type - data_array[std::tuple_size::value_type>{}]; + + [[intel::fpga_register]] input_t data_array[input_arr_size]; SoftmaxArrayLoop: - [[intel::initiation_interval(pipeline)]] for (unsigned i = 0; - i < CONFIG_T::n_in / - std::tuple_size::value_type>{}; - i++) { + [[intel::initiation_interval(pipeline)]] + for (unsigned i = 0; i < CONFIG_T::n_in / input_arr_size; i++) { auto in_pack = data_pipe::read(); SoftmaxArrayPackLoop: #pragma unroll - for (unsigned j = 0; j < std::tuple_size::value_type>{}; j++) { + for (unsigned j = 0; j < input_arr_size; j++) { data_array[j] = in_pack[j]; } // Find the max and compute all delta(x_i, x_max) - Op_max::value_type::value_type> op_max; - [[intel::fpga_register]] typename ExtractPipeType::value_type::value_type x_max = - reduce::value_type::value_type, - std::tuple_size::value_type>{}, - Op_max::value_type::value_type>>(data_array, op_max); - - // For the diffs, use the same type as the input but force rounding and saturation - [[intel::fpga_register]] ac_fixed::value_type::value_type::width, - ExtractPipeType::value_type::value_type::i_width, true, AC_RND, AC_SAT> - d_xi_xmax[std::tuple_size::value_type>{}]; + Op_max op_max; + [[intel::fpga_register]] + input_t x_max = reduce>(data_array, op_max); + + [[intel::fpga_register]] + typename CONFIG_T::inp_norm_t d_xi_xmax[input_arr_size]; + #pragma unroll - for (unsigned j = 0; j < std::tuple_size::value_type>{}; j++) { + for (unsigned j = 0; j < input_arr_size; j++) { d_xi_xmax[j] = data_array[j] - x_max; } // Calculate all the e^x's [[intel::fpga_register]] - typename CONFIG_T::exp_table_t exp_res[std::tuple_size::value_type>{}]; + typename CONFIG_T::exp_table_t exp_res[input_arr_size]; + #pragma unroll - for (unsigned j = 0; j < std::tuple_size::value_type>{}; j++) { + for (unsigned j = 0; j < input_arr_size; j++) { exp_res[j] = - exp_table[softmax_stable_idx_from_real_val::value_type::value_type, - CONFIG_T>(d_xi_xmax[j])]; + CONFIG_T::exp_table[softmax_stable_idx_from_real_val(d_xi_xmax[j])]; } // Explicitly sum the results with an adder tree. // Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing Op_add op_add; - [[intel::fpga_register]] typename CONFIG_T::exp_table_t exp_sum = - reduce::value_type>{}, + [[intel::fpga_register]] typename CONFIG_T::inv_inp_t exp_sum = + reduce>(exp_res, op_add); [[intel::fpga_register]] typename CONFIG_T::inv_table_t inv_exp_sum = - invert_table[softmax_stable_idx_from_real_val(exp_sum)]; + CONFIG_T::invert_table[softmax_stable_idx_from_real_val(exp_sum)]; + typename ExtractPipeType::value_type out_pack; SoftmaxInvPackLoop: diff --git a/hls4ml/templates/oneapi/firmware/parameters.h b/hls4ml/templates/oneapi/firmware/parameters.h index 717059f1e8..ef4e5d26b9 100644 --- a/hls4ml/templates/oneapi/firmware/parameters.h +++ b/hls4ml/templates/oneapi/firmware/parameters.h @@ -6,6 +6,8 @@ #include "nnet_utils/nnet_code_gen.h" #include "nnet_utils/nnet_helpers.h" +// hls-fpga-machine-learning insert softmax tables + // hls-fpga-machine-learning insert includes // hls-fpga-machine-learning insert layer-config diff --git a/hls4ml/writer/oneapi_writer.py b/hls4ml/writer/oneapi_writer.py index 3c0a778c50..320afa74db 100644 --- a/hls4ml/writer/oneapi_writer.py +++ b/hls4ml/writer/oneapi_writer.py @@ -302,6 +302,14 @@ def write_parameters(self, model): config = layer.get_attr('config_cpp', None) if config: newline += config + '\n' + + elif '// hls-fpga-machine-learning insert softmax tables' in line: + newline = line + for layer in model.get_layers(): + if 'softmax' in layer.name: + newline += f'#include "nnet_utils/activation_tables/{layer.name}_exp_table.h"\n' + newline += f'#include "nnet_utils/activation_tables/{layer.name}_inv_table.h"\n' + else: newline = line fout.write(newline) @@ -549,16 +557,16 @@ def write_nnet_utils(self, model): dstpath = f'{model.config.get_output_dir()}/src/firmware/{dst}' copyfile(srcpath, dstpath) - def __get_table_size(self, model, activation): + def __get_table_size(self, model, activation, table_name='table_size'): for layer in model.get_layers(): if ( layer.get_attr('activation') == activation or layer.get_attr('recurrent_activation') == activation - ) and layer.get_attr('table_size') is not None: - return int(layer.get_attr('table_size')) + ) and layer.get_attr(table_name) is not None: + return int(layer.get_attr(table_name)) return 1024 - def __get_table_header(self, table_name, table_size): - table_header = f'static const typename CONFIG_T::table_t {table_name}[{table_size}] = {{' + def __get_table_header(self, table_name, table_size, table_type='table_t'): + table_header = f'static const typename CONFIG_T::{table_type} {table_name}[{table_size}] = {{' return table_header def __write_elu_table(self, model, path): @@ -687,94 +695,125 @@ def __write_selu_table(self, model, path): h_file.write('};\n') h_file.close() - def __write_exp_table(self, model, path): - table_name = 'exp_table' - table_size = self.__get_table_size(model, 'softmax') - - h_file = open(f'{path}/{table_name}.tb', 'w') - h_file.write(self.__get_table_header(table_name, table_size)) - - # Default fixed point precision - # 6 bits for integer part, 10 bits for decimal - total, 16 - fp_bits = 16 - fp_integer = 6 - fp_signed = True - - # Exp table should use the same precision as exp_table, as seen in Vivado code - # init_exp_table(exp_table); + def __get_table_precision(self, model, activation, table_name='table_precision'): for layer in model.get_layers(): - if layer.name == 'softmax': - ac_type = layer.get_input_variable().type - if ac_type is not None: - try: - fp_bits = ac_type.precision.integer + ac_type.precision.fractional - fp_integer = ac_type.precision.integer - fp_signed = ac_type.precision.signed - except Exception: - # FixedPrecisionType wasn't correctly stored in layer attributes, use default values - pass - if fp_signed is False: - raise Exception('Softmax types need to be signed') + if layer.get_attr('activation') == activation and layer.get_attr(table_name) is not None: + precision = layer.get_attr(table_name) + return precision.precision - sep = '' - N = ceil_log2(table_size) - for i in range(table_size): - f = FixedPointEmulator(fp_bits, fp_integer, signed=fp_signed) - b = uint_to_binary(i, N) - if i == 0: - b.insert(0, 0) - else: - b.insert(0, 1) - f.set_msb_bits(b) - real_val = f.exp_float() - h_file.write(sep + str(real_val)) - sep = ', ' + return None # fp_bits, fp_integer, fp_signed - h_file.write('};\n') - h_file.close() - def __write_invert_table(self, model, path): - table_name = 'invert_table' - table_size = self.__get_table_size(model, 'softmax') + def __write_exp_table(self, model, path): - h_file = open(f'{path}/{table_name}.tb', 'w') - h_file.write(self.__get_table_header(table_name, table_size)) + for layer in model.get_layers(): + + if 'softmax' in layer.name: + + table_name = layer.name + '_exp_table' + table_size = int(layer.get_attr('exp_table_size'))//2 if ( + layer.get_attr('activation') == 'softmax' or layer.get_attr('recurrent_activation') == 'softmax' + ) and layer.get_attr('exp_table_size') is not None else 1024 + + with open(f'{path}/{table_name}.h', 'w') as h_file: + + header_name = table_name + h_file.write(f'#ifndef {header_name.upper()}_H_\n') + h_file.write(f'#define {header_name.upper()}_H_\n\n') + + h_file.write(f'static constexpr {table_name}_t {table_name}[{table_size}] = {{') + + ac_type = layer.get_attr('inp_norm_t') + + if ac_type is not None: + try: + fp_bits = ac_type.precision.integer + ac_type.precision.fractional + fp_integer = ac_type.precision.integer + fp_signed = ac_type.precision.signed + except Exception: + # FixedPrecisionType wasn't correctly stored in layer attributes, use default values + fp_bits = 16 + fp_integer = 6 + fp_signed = True + + if fp_signed is False: + raise Exception('Softmax types need to be signed') + + else: + fp_bits = 16 + fp_integer = 6 + fp_signed = True + + sep = '' + N = ceil_log2(table_size) + for i in range(table_size): + f = FixedPointEmulator(fp_bits, fp_integer, signed=fp_signed) + b = uint_to_binary(i, N) + if i == 0: + b.insert(0, 0) + else: + b.insert(0, 1) + f.set_msb_bits(b) + real_val = f.exp_float() + h_file.write(sep + str(real_val)) + sep = ', ' + + h_file.write('};\n\n') + h_file.write('#endif') - # Default fixed point precision, in case values from layer attributes cannot be extracted - # 8 bits for integer part, 10 bits for decimal - total, 18 - fp_bits = 18 - fp_integer = 8 - fp_signed = True - # Invert table should use the same precision as exp_table, as seen in Vivado code - # init_invert_table(invert_table); + def __write_invert_table(self, model, path): for layer in model.get_layers(): - if layer.name == 'softmax': - ac_type = layer.get_attr('exp_table_t') - if ac_type is not None: - try: - fp_bits = ac_type.precision.integer + ac_type.precision.fractional - fp_integer = ac_type.precision.integer - fp_signed = ac_type.precision.signed - except Exception: - # FixedPrecisionType wasn't correctly stored in layer attributes, use default values - pass - if fp_signed is False: - raise Exception('Softmax types need to be signed') - - sep = '' - N = ceil_log2(table_size) - for i in range(table_size): - f = FixedPointEmulator(fp_bits, fp_integer, signed=fp_signed) - b = uint_to_binary(i, N) - b.insert(0, 0) - f.set_msb_bits(b) - real_val = f.inv_float() - h_file.write(sep + str(real_val)) - sep = ', ' + if 'softmax' in layer.name: + + table_name = layer.name + '_inv_table' + table_size = int(layer.get_attr('inv_table_size')) //2 if ( + layer.get_attr('activation') == 'softmax' or layer.get_attr('recurrent_activation') == 'softmax' + ) and layer.get_attr('inv_table_size') is not None else 1024 + + with open(f'{path}/{table_name}.h', 'w') as h_file: + + header_name = table_name + h_file.write(f'#ifndef {header_name.upper()}_H_\n') + h_file.write(f'#define {header_name.upper()}_H_\n\n') + + h_file.write(f'static constexpr {table_name}_t {table_name}[{table_size}] = {{') + + ac_type = layer.get_attr('inv_inp_t') + + if ac_type is not None: + try: + fp_bits = ac_type.precision.integer + ac_type.precision.fractional + fp_integer = ac_type.precision.integer + fp_signed = ac_type.precision.signed + except Exception: + # FixedPrecisionType wasn't correctly stored in layer attributes, use default values + fp_bits = 18 + fp_integer = 8 + fp_signed = True + + if fp_signed is False: + raise Exception('Softmax types need to be signed') + + else: + fp_bits = 18 + fp_integer = 8 + fp_signed = True + + sep = '' + N = ceil_log2(table_size) + for i in range(table_size): + f = FixedPointEmulator(fp_bits, fp_integer, signed=fp_signed) + b = uint_to_binary(i, N) + b.insert(0, 0) + f.set_msb_bits(b) + real_val = f.inv_float() + h_file.write(sep + str(real_val)) + sep = ', ' + + h_file.write('};\n\n') + h_file.write('#endif') - h_file.write('};\n') - h_file.close() def __write_exp_table_latency(self, model, path): table_name = 'exp_table_latency' @@ -994,3 +1033,6 @@ def write_hls(self, model): self.write_generated_code(model) self.write_yml(model) self.write_tar(model) + + +