# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Provides object representation for the model that is conducive to code
generation using templates. """

from typing import Dict, List, Optional, Sequence
import string
import textwrap

from tflite_micro.codegen.operators import factory
from tflite_micro.codegen.operators import operator
from tflite_micro.codegen import tensor
from tflite_micro.codegen import utils
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb
from tflite_micro.tensorflow.lite.tools import visualize


class OpCode(object):
  """Wraps a schema operator code and derives the identifiers used in
  generated code from it."""

  def __init__(self, op_code: schema_fb.OperatorCodeT):
    self._op_code: schema_fb.OperatorCodeT = op_code

  @property
  def name(self) -> str:
    """The custom code when one is set, otherwise the builtin code's name."""
    if self._op_code.customCode:
      return self._op_code.customCode
    return visualize.BuiltinCodeToName(self._op_code.builtinCode)

  @property
  def register_function(self) -> str:
    """Qualified name of the registration function for this op code."""
    return "tflite::RegisterInference_{}".format(self.name)

  @property
  def enum_name(self) -> str:
    """Enumerator name: "k" followed by the PascalCase form of `name`."""
    return "k{}".format(utils.to_pascal_case(self.name))

  @property
  def full_enum_name(self) -> str:
    """`enum_name` qualified with the "OpCode::" scope."""
    return "OpCode::" + self.enum_name


class Subgraph(object):
  """Wraps one schema subgraph and produces the source fragments (names,
  arrays, node/tensor data and the invoke function) generated for it."""

  def __init__(self, model: schema_fb.ModelT, buffers: Sequence[tensor.Buffer],
               subgraph_idx: int, subgraph: schema_fb.SubGraphT):
    """
    Args:
      model: Model that owns the subgraph; supplies the operator codes.
      buffers: Buffer wrappers, indexed by each tensor's `buffer` field.
      subgraph_idx: Index of this subgraph; embedded in all generated names.
      subgraph: The schema subgraph being wrapped.
    """
    self._subgraph_idx: int = subgraph_idx
    self._subgraph: schema_fb.SubGraphT = subgraph
    self._op_codes: List[OpCode] = [
        OpCode(op_code) for op_code in model.operatorCodes
    ]
    self._tensors: List[tensor.Tensor] = [
        tensor.Tensor(buffers[t.buffer], t) for t in subgraph.tensors
    ]
    self._operators: List[operator.Operator] = [
        factory.create_operator(model.operatorCodes[op.opcodeIndex], op)
        for op in subgraph.operators
    ]

  @property
  def index(self) -> int:
    """Index of this subgraph as passed to the constructor."""
    return self._subgraph_idx

  @property
  def inputs(self) -> Sequence[int]:
    """The wrapped subgraph's `inputs` field."""
    return self._subgraph.inputs

  @property
  def outputs(self) -> Sequence[int]:
    """The wrapped subgraph's `outputs` field."""
    return self._subgraph.outputs

  @property
  def operators(self) -> Sequence[operator.Operator]:
    """Operator wrappers, in the order they appear in the subgraph."""
    return self._operators

  @property
  def tensors(self) -> Sequence[tensor.Tensor]:
    """Tensor wrappers, in the order they appear in the subgraph."""
    return self._tensors

  @property
  def needs_zero_length_int_array(self) -> bool:
    """True if any tensor of this subgraph reports needing one."""
    return any(t.needs_zero_length_int_array for t in self.tensors)

  @property
  def invoke_fn_name(self) -> str:
    """Name of the generated invoke function for this subgraph."""
    return f"InvokeSubgraph{self.index}"

  @property
  def inputs_array_name(self) -> str:
    """Name of the generated array holding this subgraph's inputs."""
    return f"kSubgraph{self.index}Inputs"

  @property
  def outputs_array_name(self) -> str:
    """Name of the generated array holding this subgraph's outputs."""
    return f"kSubgraph{self.index}Outputs"

  @property
  def nodes_array(self) -> str:
    """Name of the generated array holding this subgraph's nodes."""
    return f"subgraph{self.index}_nodes_"

  def nodes_element(self, operator_idx: int) -> str:
    """Expression indexing `nodes_array` at `operator_idx`."""
    return self.nodes_array + f"[{operator_idx}]"

  def node_data_type(self, operator_idx: int) -> str:
    """Type name used for the node data of the given operator."""
    return f"Node{self.index}_{operator_idx}"

  def node_data_name(self, operator_idx: int) -> str:
    """Variable name used for the node data of the given operator."""
    return f"node_{self.index}_{operator_idx}"

  def generate_c_node_data(self, indent: str) -> str:
    """Joins every operator's node data with blank lines between entries and
    prefixes each line with `indent`."""
    node_data_strs: List[str] = [
        op.generate_c_node_data(self.node_data_type(op_idx),
                                self.node_data_name(op_idx))
        for op_idx, op in enumerate(self.operators)
    ]
    return textwrap.indent("\n\n".join(node_data_strs), indent)

  def generate_c_node_init(self, indent: str) -> str:
    """Joins every operator's node initialization, one entry per line, and
    prefixes each line with `indent`."""
    node_init_strs: List[str] = [
        op.generate_c_node_init(self.nodes_element(op_idx),
                                self.node_data_name(op_idx))
        for op_idx, op in enumerate(self.operators)
    ]
    return textwrap.indent("\n".join(node_init_strs), indent)

  def generate_c_invoke(self, indent: str) -> str:
    """Generates the invoke function for this subgraph.

    The function body contains one `op_table[...].invoke` call per operator,
    in operator order, each addressing `nodes[<operator index>]`. Every line of
    the result is prefixed with `indent`.
    """
    function_template = string.Template(
        "TfLiteStatus ${function_name}(TfLiteContext* context,\n"
        " tflite::Span nodes) {\n"
        " TFLITE_DCHECK(nodes.size() == ${num_nodes});\n"
        "${body}\n"
        " return kTfLiteOk;\n"
        "}")

    body_template = string.Template(
        " TF_LITE_ENSURE_OK(\n"
        " context, op_table[${op_code}].invoke(context, &${node}));\n")

    # Each operator is dispatched through the op table entry selected by the
    # enum value of its operator code.
    invoke_strs: List[str] = [
        body_template.substitute(
            op_code=self._op_codes[op.op_code_index].full_enum_name,
            node=f"nodes[{op_idx}]")
        for op_idx, op in enumerate(self.operators)
    ]

    invoke = function_template.substitute(function_name=self.invoke_fn_name,
                                          num_nodes=len(self.operators),
                                          body="".join(invoke_strs))
    return textwrap.indent(invoke, indent)

  def generate_c_input_array(self, indent: str) -> str:
    """Generates the "size_t" array named `inputs_array_name` from `inputs`
    via `utils.generate_c_int_array`."""
    return utils.generate_c_int_array(indent, "size_t", self.inputs_array_name,
                                      self.inputs)

  def generate_c_output_array(self, indent: str) -> str:
    """Generates the "size_t" array named `outputs_array_name` from `outputs`
    via `utils.generate_c_int_array`."""
    return utils.generate_c_int_array(indent, "size_t",
                                      self.outputs_array_name, self.outputs)

  def generate_c_subgraph_init(self, indent: str) -> str:
    """Generates the designated initializer for this subgraph.

    Each of inputs, outputs, nodes and tensors is emitted as a pointer to the
    first element of the corresponding array plus its element count, followed
    by the address of the invoke function. Every line of the result is
    prefixed with `indent`.
    """
    init_template = string.Template(
        "{.inputs = {&${input_array}[0], ${input_size}},\n"
        " .outputs = {&${output_array}[0], ${output_size}},\n"
        " .nodes = {&${node_array}[0], ${node_size}},\n"
        " .tensors = {&${tensor_array}[0], ${tensor_size}},\n"
        " .invoke = &${invoke}},")
    return textwrap.indent(
        init_template.substitute(input_array=self.inputs_array_name,
                                 input_size=len(self.inputs),
                                 output_array=self.outputs_array_name,
                                 output_size=len(self.outputs),
                                 node_array=self.nodes_array,
                                 node_size=len(self.operators),
                                 tensor_array=self.tensors_array,
                                 tensor_size=len(self.tensors),
                                 invoke=self.invoke_fn_name), indent)

  @property
  def tensors_array(self) -> str:
    """Name of the generated array holding this subgraph's tensors."""
    return f"subgraph{self.index}_tensors_"

  def tensors_element(self, tensor_idx: int) -> str:
    """Expression indexing `tensors_array` at `tensor_idx`."""
    return self.tensors_array + f"[{tensor_idx}]"

  def tensor_data_type(self, tensor_idx: int) -> str:
    """Type name used for the data of the given tensor."""
    return f"Tensor{self.index}_{tensor_idx}"

  def tensor_data_name(self, tensor_idx: int) -> str:
    """Variable name used for the data of the given tensor."""
    return f"tensor{self.index}_{tensor_idx}"

  def generate_c_tensor_data(self, indent: str) -> str:
    """Joins every tensor's dims definition with blank lines between entries
    and prefixes each line with `indent`."""
    # The loop variable is deliberately not called `tensor`: that name is the
    # imported module and shadowing it invites mistakes.
    tensor_dims_strs: List[str] = [
        t.generate_c_tensor_dims(self.tensor_data_type(tensor_idx),
                                 self.tensor_data_name(tensor_idx))
        for tensor_idx, t in enumerate(self.tensors)
    ]
    return textwrap.indent("\n\n".join(tensor_dims_strs), indent)

  def generate_c_tensor_init(self, indent: str) -> str:
    """Joins every tensor's initialization, one entry per line, and prefixes
    each line with `indent`."""
    tensor_init_strs: List[str] = [
        t.generate_c_tensor_init(self.tensors_element(tensor_idx),
                                 self.tensor_data_name(tensor_idx))
        for tensor_idx, t in enumerate(self.tensors)
    ]
    return textwrap.indent("\n".join(tensor_init_strs), indent)


class Graph(object):
  """Wraps a whole model: one `Subgraph` per schema subgraph, all sharing a
  single list of buffer wrappers."""

  def __init__(self, model: schema_fb.ModelT):
    # Buffers are named after their position in the model's buffer list.
    buffers: List[tensor.Buffer] = [
        tensor.Buffer("buffer_{}".format(idx), buffer)
        for idx, buffer in enumerate(model.buffers)
    ]
    self._subgraphs: List[Subgraph] = [
        Subgraph(model, buffers, idx, subgraph)
        for idx, subgraph in enumerate(model.subgraphs)
    ]

  @property
  def subgraphs(self) -> Sequence[Subgraph]:
    """Subgraph wrappers, in model order."""
    return self._subgraphs

  @property
  def buffers(self) -> Sequence[tensor.Buffer]:
    """The `buffer` of every tensor, ordered by subgraph and then by tensor.

    One entry is produced per tensor, so a buffer referenced by several
    tensors is listed once for each of them.
    """
    return [t.buffer for subgraph in self.subgraphs for t in subgraph.tensors]

  @property
  def needs_zero_length_int_array(self) -> bool:
    """True if any subgraph reports needing a zero length int array."""
    return any(subgraph.needs_zero_length_int_array
               for subgraph in self.subgraphs)


class OpCodeTable(object):
  """The distinct operator codes used by a collection of models."""

  def __init__(self, models: Sequence[schema_fb.ModelT]):
    """Collects the operator codes of `models`, keeping one per name.

    `OpCode` defines neither `__eq__` nor `__hash__`, so putting freshly
    created wrappers into a `set` compares them by identity: nothing is ever
    removed and the resulting order changes from run to run. Keying on
    `OpCode.name` (which is what the generated enumerator is derived from)
    removes the duplicates, and the insertion-ordered dict keeps the
    first-seen order so the generated output is reproducible.

    Args:
      models: Models whose `operatorCodes` are merged into the table.
    """
    unique_op_codes: Dict[str, OpCode] = {}
    for model in models:
      for raw_op_code in model.operatorCodes:
        candidate = OpCode(raw_op_code)
        unique_op_codes.setdefault(candidate.name, candidate)

    self._op_codes: List[OpCode] = list(unique_op_codes.values())

  @property
  def op_codes(self) -> Sequence[OpCode]:
    """Distinct operator codes, in the order they were first encountered."""
    return self._op_codes