我用 VS Code + PlatformIO 开发,使用的是通过 PlatformIO 导入的库,但库自带的算子里没有 REDUCE_PROD。这个算子是在我把模型转换成 tflite 格式后产生的,没有办法直接在模型上解决,所以决定自己在库里补上这个算子。

1. 算子定义和注册

要修改的文件是 tensorflow/lite/micro/micro_mutable_op_resolver.h(在 libdeps 目录下找;库在 platformio.ini 里导入,编译之后才会在 .pio 目录下生成这个目录)。直接复制 AddReduceMax(),粘贴在它后面,并修改成 AddReduceProd():


  // Registers the builtin REDUCE_PROD operator with this resolver, pairing the
  // kernel returned by Register_REDUCE_PROD() with the ParseReducer option
  // parser. Returns the status reported by AddBuiltin().
  TfLiteStatus AddReduceProd() {
    return AddBuiltin(BuiltinOperator_REDUCE_PROD,
                      tflite::ops::micro::Register_REDUCE_PROD(), ParseReducer);
  }

2. REDUCE_PROD算子实现

实现的代码是 AI 写的,我测了一下暂时没问题,但建议自己对照 TFLite 的算子实现核对一遍。这里只是为了临时解决 bug 记的笔记,并不是稳定版本。
在 AddReduceMax() 里对 Register_REDUCE_MAX() 右键跳转到定义,在 namespace reduce 内添加新的 EvalProd 实现:

// Eval function for REDUCE_PROD: multiplies the elements of input 0 along the
// axes listed in input 1 and writes the result to output 0.
//
// Supported element types:
//   * float32 - plain product, seeded with 1.0f.
//   * int8    - simplified path. Input and output must share scale and zero
//               point; the raw stored values are multiplied in 32 bits and
//               every partial product is saturated to the int8 range, so the
//               result is NOT an exact quantized product.
//
// Returns kTfLiteOk on success and kTfLiteError for any other element type.
// NOTE(review): the TF_LITE_ENSURE* checks presumably return an error status
// when they fail - confirm against the macro definitions.
TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  // The reduction never converts element types, so input and output must match.
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  TfLiteReducerParams* params =
      static_cast<TfLiteReducerParams*>(node->builtin_data);
  OpData* op_data = static_cast<OpData*>(node->user_data);

  // Number of axes to reduce over = element count of the axis tensor.
  int num_axis = static_cast<int>(ElementCount(*axis->dims));
  // Scratch areas handed to ReduceGeneric, looked up by the indices stored in
  // op_data. NOTE(review): presumably requested by the prepare function this
  // kernel is registered with - confirm.
  int* temp_buffer = static_cast<int*>(
      context->GetScratchBuffer(context, op_data->temp_buffer_idx));
  int* resolved_axis = static_cast<int*>(
      context->GetScratchBuffer(context, op_data->resolved_axis_idx));

  switch (input->type) {
    case kTfLiteFloat32:
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<float>(
              tflite::micro::GetTensorData<float>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<float>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_buffer, resolved_axis,
              1.0f,  // multiplicative identity used as the initial value
              [](const float current, const float in) -> float {
                return current * in;
              }));
      break;
    case kTfLiteInt8:
      // Simplified quantized version: only performs an integer product, and
      // only when input and output share the same scale and zero point (it has
      // overflow / range problems).
      TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
                        static_cast<double>(op_data->output_scale));
      TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<int8_t>(
              tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_buffer, resolved_axis,
              static_cast<int8_t>(1),  // initial value (raw stored 1)
              [](const int8_t current, const int8_t in) -> int8_t {
                // Note: a direct multiply can overflow int8, and the
                // quantization semantics are not exact.
                int32_t prod = static_cast<int32_t>(current) *
                               static_cast<int32_t>(in);
                // Simple saturation to the int8_t range.
                if (prod > std::numeric_limits<int8_t>::max()) {
                  prod = std::numeric_limits<int8_t>::max();
                } else if (prod < std::numeric_limits<int8_t>::lowest()) {
                  prod = std::numeric_limits<int8_t>::lowest();
                }
                return static_cast<int8_t>(prod);
              }));
      break;
    default:
      // Any other element type is rejected with a logged error.
      TF_LITE_KERNEL_LOG(context,
                         "Only float32 and int8 types are supported for PROD.\n");
      return kTfLiteError;
  }
  return kTfLiteOk;
}

Register_REDUCE_MAX()后添加注册函数

// Builds the registration for the REDUCE_PROD kernel. It uses
// reduce::InitReduce and reduce::PrepareMax for init/prepare and
// reduce::EvalProd as the eval function.
TfLiteRegistration Register_REDUCE_PROD() {
  return tflite::micro::RegisterOp(reduce::InitReduce, reduce::PrepareMax,
                                   reduce::EvalProd);
}

在 micro_ops.h 中的 TfLiteRegistration Register_REDUCE_MAX(); 后加上

// Declaration of the registration function for the REDUCE_PROD kernel.
TfLiteRegistration Register_REDUCE_PROD();

最后在all_ops_resolver.cpp里加上,加在 AddReduceMax();后一行方便快速定位

  // Register the REDUCE_PROD operator with this resolver.
  AddReduceProd();

3. 编译选项

最后我们把整个TensorFlowLite_ESP32包保存备用,后续可以直接放在lib目录下,不会被编译覆盖掉
然后修改platformio.ini,把lib_deps里原先引入的包注释掉,并在build_flags里添加

; Add the locally saved TensorFlowLite_ESP32 copy to the compiler include path.
build_flags = 
	-I lib/TensorFlowLite_ESP32
Logo

智能硬件社区聚焦AI智能硬件技术生态,汇聚嵌入式AI、物联网硬件开发者,打造交流分享平台,同步全国赛事资讯、开展 OPC 核心人才招募,助力技术落地与开发者成长。

更多推荐