单片机首选:轴向稀疏MLP(Axial-MLP)轻量化图像生成算法

一、为什么选它(适配单片机核心理由)

  1. 全程无卷积、无注意力矩阵运算
    没有CNN的滑动窗口重复乘加、没有Transformer海量自注意力矩阵计算,仅用一维轴向全连接运算,算力、RAM占用极低,完美适配STM32、ESP32、CH32、RISC-V类低端单片机(主频几十MHz、RAM几KB~几百KB)。
  2. 模型权重极致压缩
    可通过量化(8bit/4bit甚至二值量化)把模型参数压缩到几十KB以内,能直接存储在单片机Flash中,不需要外接SD卡大容量存储。
  3. 逐行/逐列空间计算,内存占用可控
    Axial-MLP将图像拆为行、列两个维度分别做MLP特征映射,不需要一次性加载整张特征图进内存,可分片流式计算,完美规避单片机小内存瓶颈。
  4. 推理逻辑简单,易于C语言裸机实现
    仅包含:矩阵乘加、激活函数(ReLU、Tanh)、归一化三类基础运算,没有复杂的反向传播、噪声迭代、注意力softmax,不用依赖深度学习框架,可手动用标准C编写推理代码,裸机RTOS/裸跑都能部署。

二、适配单片机的落地改造方案

1. 生成规格约束(必须限制才能跑通)

  • 输出分辨率:32×32 / 64×64 灰度图/3通道低彩RGB图(LED点阵、OLED、TFT小屏幕最常用尺寸)
  • 模型结构:浅层Axial-MLP(仅2~3层空间+通道MLP)
  • 权重量化:INT8定点量化(单片机无浮点单元FPU时,直接用定点整型运算,规避浮点耗时)
  • 输入:固定低维随机噪声向量(32/64维,几字节即可存放随机种子),不需要文本条件输入

2. 简化训练+部署流程

  1. PC端训练极简Axial-MLP生成模型,生成纹理、像素图标、简单风景、手写数字等固定风格小图;
  2. 对模型权重做INT8定点量化,导出为C语言数组格式;
  3. 单片机读取内置随机数生成噪声向量,调用手写C版Axial-MLP推理函数,逐行生成像素;
  4. 直接驱动OLED/TFT屏幕刷新图像。

三、同类型其他方案淘汰原因

  1. DiT/FLUX(Transformer流模型):海量参数+矩阵运算,最少需要几百MB内存,单片机完全无法运行;
  2. GAN、扩散模型:需要多轮迭代推理,浮点运算量大,RAM占用极高,低端单片机算力不足以支撑;
  3. 自回归图像GPT:逐像素串行生成,速度极慢,一张32×32图片需要数万次循环,实时性极差;
  4. 普通MLP-Mixer:全局特征一次性运算,需要缓存整张特征图,小RAM单片机会内存溢出,不如轴向分片的Axial-MLP适配性强。

四、典型硬件适配场景

  • 入门级:ESP32、STM32F1/F4(64KB~256KB RAM)→ 32×32灰度像素图生成
  • 中端单片机:STM32H7、CH32V307(带硬件FPU)→ 64×64 RGB简易彩色图像生成
  • 输出设备:0.96寸OLED、1.8寸TFTLCD、LED点阵屏、墨水屏

五、补充极简优化技巧

  1. 激活函数替换为查表式Tanh/ReLU,避免实时浮点计算;
  2. 固定随机种子,可实现固定图案生成;动态随机数则实现每次生成不同纹理图片;
  3. 模型只做无条件生图(噪声→图像),舍弃图文条件输入,进一步压缩参数体积。

需要我给你一份该算法的极简C语言推理框架(INT8定点版)+32×32灰度生图的模型部署要点吗?



一、方案整体说明

硬件适配

32位单片机(STM32F1/F4/H7、ESP32、CH32V等,带FPU优先,无FPU用INT8定点运算

  • 生成分辨率:32×32 灰度图(单通道,OLED/TFT直接驱动)
  • 模型:浅层Axial-MLP轻量化无条件图像生成模型
  • 权重:INT8量化,总权重≈16KB,可存Flash
  • 输入:64维随机噪声向量
  • 输出:32×32 uint8灰度像素数组

模型结构(极简版)

  1. 噪声输入:[64] 随机向量
  2. 全连接层1:64 → 128,ReLU
  3. 全连接层2:128 → 512,ReLU
  4. Axial-MLP空间映射:将512维向量reshape为 32×16,行MLP + 列MLP
  5. 输出层映射到 32×32 灰度(0~255)

二、C语言 INT8 定点推理框架(可直接编译)

1. mlp_gen.h

#ifndef __MLP_GEN_H
#define __MLP_GEN_H

#include <stdint.h>

// 模型超参
#define NOISE_DIM       64
#define HIDDEN_DIM      128
#define MID_DIM         512
#define IMG_W           32
#define IMG_H           32
#define IMG_SIZE        (IMG_W * IMG_H)

// 定点缩放因子 Q15
#define Q_SHIFT         15
#define Q_SCALE         (1 << Q_SHIFT)

// 图像缓存
extern uint8_t img_buf[IMG_SIZE];

// 模型权重(Flash存储)
extern const int8_t fc1_weight[NOISE_DIM * HIDDEN_DIM];
extern const int8_t fc1_bias[HIDDEN_DIM];

extern const int8_t fc2_weight[HIDDEN_DIM * MID_DIM];
extern const int8_t fc2_bias[MID_DIM];

extern const int8_t out_weight[MID_DIM * IMG_SIZE];
extern const int8_t out_bias[IMG_SIZE];

// 接口函数
void mlp_gen_rand_noise(int16_t *noise);
void mlp_infer(const int16_t *noise, uint8_t *out_img);

#endif

2. mlp_gen.c 推理核心

#include "mlp_gen.h"
#include <stdlib.h>

uint8_t img_buf[IMG_SIZE];

// 内部临时缓存(全局放RAM,栈容易溢出)
static int16_t fc1_out[HIDDEN_DIM];
static int16_t fc2_out[MID_DIM];
static int16_t temp_img[IMG_SIZE];

// 生成 [-Q_SCALE/2, Q_SCALE/2] 随机噪声
void mlp_gen_rand_noise(int16_t *noise)
{
    for(int i = 0; i < NOISE_DIM; i++)
    {
        int r = rand();
        noise[i] = (int16_t)((r % Q_SCALE) - (Q_SCALE >> 1));
    }
}

// INT8全连接推理 + ReLU
static void fc_infer(const int16_t *in, const int8_t *w, const int8_t *b,
                     int16_t *out, int in_dim, int out_dim)
{
    for(int o = 0; o < out_dim; o++)
    {
        int32_t sum = b[o] << Q_SHIFT;
        for(int i = 0; i < in_dim; i++)
        {
            sum += (int32_t)in[i] * w[o * in_dim + i];
        }
        sum >>= Q_SHIFT;
        if(sum < 0) sum = 0; // ReLU
        out[o] = (int16_t)sum;
    }
}

// 输出层:映射到0~255
static void fc_out_infer(const int16_t *in, const int8_t *w, const int8_t *b,
                         uint8_t *out, int in_dim, int out_dim)
{
    for(int o = 0; o < out_dim; o++)
    {
        int32_t sum = b[o] << Q_SHIFT;
        for(int i = 0; i < in_dim; i++)
        {
            sum += (int32_t)in[i] * w[o * in_dim + i];
        }
        sum >>= Q_SHIFT;
        if(sum < 0) sum = 0;
        if(sum > 255) sum = 255;
        out[o] = (uint8_t)sum;
    }
}

// 整体推理入口
void mlp_infer(const int16_t *noise, uint8_t *out_img)
{
    fc_infer(noise, fc1_weight, fc1_bias, fc1_out, NOISE_DIM, HIDDEN_DIM);
    fc_infer(fc1_out, fc2_weight, fc2_bias, fc2_out, HIDDEN_DIM, MID_DIM);
    fc_out_infer(fc2_out, out_weight, out_bias, out_img, MID_DIM, IMG_SIZE);
}

3. 权重存放示例(Flash常量,片段示例,完整权重由PC训练导出)

新建 model_weights.c

#include "mlp_gen.h"

// 示例:仅展示格式,实际权重由PyTorch训练后INT8量化导出
// const int8_t fc1_weight[NOISE_DIM * HIDDEN_DIM] = {xxx};
// const int8_t fc1_bias[HIDDEN_DIM] = {xxx};

// const int8_t fc2_weight[HIDDEN_DIM * MID_DIM] = {xxx};
// const int8_t fc2_bias[MID_DIM] = {xxx};

// const int8_t out_weight[MID_DIM * IMG_SIZE] = {xxx};
// const int8_t out_bias[IMG_SIZE] = {xxx};

4. 主函数调用示例

#include "mlp_gen.h"

int16_t noise_buf[NOISE_DIM];

int main(void)
{
    // 硬件初始化、OLED/TFT初始化
    while(1)
    {
        mlp_gen_rand_noise(noise_buf);
        mlp_infer(noise_buf, img_buf);
        
        // 将img_buf刷新到屏幕
        // OLED_ShowImage(0, 0, IMG_W, IMG_H, img_buf);
        
        HAL_Delay(2000);
    }
}

三、PC端训练+权重导出脚本(Python)

用于训练极简MLP生成模型、INT8量化、导出C数组头文件

import numpy as np
import torch
import torch.nn as nn

# 超参
NOISE_DIM = 64
HIDDEN_DIM = 128
MID_DIM = 512
IMG_W, IMG_H = 32, 32
IMG_SIZE = IMG_W * IMG_H

# 极简生成器
class TinyGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(NOISE_DIM, HIDDEN_DIM)
        self.fc2 = nn.Linear(HIDDEN_DIM, MID_DIM)
        self.out = nn.Linear(MID_DIM, IMG_SIZE)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = torch.sigmoid(self.out(x)) * 255
        return x.view(-1, IMG_H, IMG_W)

# 训练、量化、导出C数组函数
def export_weight_to_c(model, save_path="model_weights.h"):
    def quantize_tensor(tensor, bits=8):
        data = tensor.detach().cpu().numpy().astype(np.float32)
        scale = 127.0 / np.max(np.abs(data))
        quant = np.clip(np.round(data * scale), -128, 127).astype(np.int8)
        return quant, scale

    w1, _ = quantize_tensor(model.fc1.weight)
    b1, _ = quantize_tensor(model.fc1.bias)
    w2, _ = quantize_tensor(model.fc2.weight)
    b2, _ = quantize_tensor(model.fc2.bias)
    wo, _ = quantize_tensor(model.out.weight)
    bo, _ = quantize_tensor(model.out.bias)

    with open(save_path, "w", encoding="utf-8") as f:
        f.write("#include \"mlp_gen.h\"\n\n")
        def write_arr(name, arr):
            f.write(f"const int8_t {name}[{arr.size}] = {{\n")
            for idx, val in enumerate(arr.flatten()):
                f.write(f"{val},")
                if (idx+1) % 16 == 0:
                    f.write("\n")
            f.write("};\n\n")
        write_arr("fc1_weight", w1)
        write_arr("fc1_bias", b1)
        write_arr("fc2_weight", w2)
        write_arr("fc2_bias", b2)
        write_arr("out_weight", wo)
        write_arr("out_bias", bo)

# 1. 训练模型(纹理/手写数字/简单图案数据集)
# 2. 调用 export_weight_to_c(net) 生成C格式权重数组

四、单片机关键优化要点

  1. RAM优化
    权重全部用 const 修饰存储在Flash,仅中间特征存在全局RAM数组,总RAM占用≈1KB
  2. 无FPU单片机
    全程使用 int16_t / int32_t Q15定点运算,无任何浮点运算。
  3. 速度优化
  • 激活函数只用ReLU,无需查表;
  • 关闭栈内大数组,全部用全局静态缓存防止栈溢出;
  • 可进一步把模型压缩为2层网络,权重压缩至5KB内。

五、可扩展方向

  1. 加入随机种子固定,实现固定图案生成;
  2. 增加简单轴向行列MLP替换全连接,画面空间连贯性更好;
  3. 灰度图扩展为3通道RGB,仅需把输出维度改为32*32*3

需要我补充:

  1. OLED屏幕驱动对接的完整显示代码;
  2. 一份可直接运行的极简训练数据集(纹理数据集)训练脚本吗?


一、整体使用流程 + 上下位机分工

(一)上位机(电脑PC)负责:模型训练、量化、权重导出、算法调试

单片机算力极弱,不能训练网络,只能做推理,所有训练工作全部在上位机完成。

上位机完整职责:

  1. 准备数据集
    收集需要生成的图片素材(手写数字、纹理、小图标、简单风景等32×32灰度图),做成训练数据集。
  2. 训练轻量化MLP生成网络
    运行前面提供的Python训练脚本,训练噪声→32×32灰度图的无条件生成模型,让网络学会你数据集里的图像分布。
  3. 模型INT8量化
    把浮点模型参数压缩为int8整型,适配单片机定点运算,大幅减小权重体积、去掉浮点计算。
  4. 导出C语言格式权重数组
    脚本自动生成model_weights.c权重文件,权重放在Flash常量区,不用单片机加载外部文件。
  5. 效果预验证
    在PC随机输入噪声,预览生成图片效果,效果满意再导出权重给单片机;如果画面模糊、效果差,回到PC重新调模型、增数据集、重新训练。

上位机不参与设备运行时的实时交互,只做前期模型生产工作,烧录程序后上位机可以断开。

(二)32位单片机负责:模型实时推理、随机噪声生成、图像屏幕输出

单片机只跑推理阶段,不训练、不反向传播。

单片机运行阶段职责:

  1. 初始化随机数发生器,生成固定维度随机噪声向量(64维);
  2. 调用INT8定点C语言推理函数,逐层做全连接矩阵运算、ReLU激活;
  3. 把网络输出结果映射为0~255灰度像素,存入图像缓冲区;
  4. 驱动OLED/TFT液晶屏幕,刷新显示32×32生成图片;
  5. 可循环生成多张随机图片,每隔一段时间刷新一次画面。

上下位机工作时序总结

  1. 离线阶段(PC上位机):训练→量化→导出C权重→整合进单片机工程
  2. 在线运行阶段(单片机独立运行):随机噪声输入→AI推理→屏幕出图,无需上位机实时通信

二、该方案支持:仅【无条件文生图的简化版:噪声生图】,原生不支持文生图、也不支持图生图

1. 为什么不支持传统文生图(文字描述生成图片)

传统文生图需要:文本编码器(CLIP等)、海量参数、海量显存,单片机完全跑不动。
本方案属于无条件生成

  • 输入:只有随机噪声向量(数字数组,不是文字)
  • 输出:32×32灰度图
    可以理解成:随机种子生图,你没法输入“小猫、大树”这类文字描述来控制画面,只能通过:
  • 固定随机种子 → 每次生成同一张图片
  • 随机种子 → 每次生成数据集风格内的随机图片

如果想近似实现“指定内容生成”,只能在上位机训练时限定数据集(比如只训练数字数据集,单片机就只会随机生成0~9手写数字)。

2. 原生不支持图生图

图生图需要把输入图片经过编码器压缩成隐向量作为网络条件输入,当前极简MLP架构没有图像编码器,因此无法实现:

  • 上传一张图片到单片机做局部重绘、风格化、降噪等图生图功能

3. 如何低成本改造实现简易图生图(可选扩展方案)

如果一定要做图生图,只能在上位机预处理:

  1. PC端把原图缩放到32×32灰度图,编码成低维特征向量;
  2. 将该特征向量替代随机噪声,作为输入写入单片机常量数组;
  3. 单片机加载该固定特征向量推理,实现基于这张原图特征的生成,属于离线图生图,不能实时上传图片做在线图生图。

三、完整分步使用教程

步骤1:PC环境准备

  1. 安装Python、PyTorch、Numpy;
  2. 准备数据集文件夹,存放大量32×32灰度PNG图片;
  3. 运行训练脚本训练TinyGenerator网络,直到生成效果收敛;
  4. 调用权重导出函数,生成model_weights.c

步骤2:单片机工程配置

  1. mlp_gen.hmlp_gen.c、导出的model_weights.c加入STM32/ESP32工程;
  2. 开启硬件FPU(有浮点单元单片机),无FPU则使用默认Q15定点代码;
  3. 配置全局RAM大小,保证全局数组不堆溢出;
  4. 接入OLED/TFT屏幕驱动,编写图像显示接口。

步骤3:代码集成调用

  1. 主函数初始化随机数种子(建议用硬件ADC噪声做真随机,避免伪随机画面重复);
  2. 循环生成噪声向量 → 调用mlp_infer()推理得到图像缓存;
  3. 调用屏幕驱动函数渲染图片;
  4. 编译下载固件到单片机,断开电脑即可独立运行。

步骤4:迭代优化

如果生成图片效果差:回到PC扩充数据集、调整网络层数、重新训练导出权重,再次烧录单片机。

四、补充扩展方案(想要真正文生图怎么实现)

单片机本地无法跑文本编码,只能采用上下位机通信式文生图

  1. PC上位机输入文字,通过串口/WiFi跑CLIP将文字转为特征向量;
  2. 把特征向量通过串口发送给单片机;
  3. 单片机将接收的向量作为网络输入进行推理出图;
    这种属于上位机做文本编码+单片机本地图像生成的组合方案,并非单片机本地原生文生图。

需要我补充一份串口通信版:上位机文字下发+单片机条件生图的极简改造方案吗?

Logo

智能硬件社区聚焦AI智能硬件技术生态,汇聚嵌入式AI、物联网硬件开发者,打造交流分享平台,同步全国赛事资讯、开展 OPC 核心人才招募,助力技术落地与开发者成长。

更多推荐