logo
0
0
WeChat Login
docs: 精简并重构 README 文档

stream/inference

AI 推理集成子包 — 统一的推理框架抽象接口,支持适配任意推理后端(ONNX、TensorFlow、NCNN、TFLite 等)。

特性

  • 统一接口规范Backend 接口定义了与具体框架无关的推理能力
  • 多后端支持 — 通过工厂注册表动态扩展推理后端
  • 异步/流式推理 — 支持同步、异步、批量、流式多种推理模式
  • 多模型管理 — 单后端同时加载和管理多个模型
  • 事件驱动 — 与 stream.EventBus 深度集成,异步发布推理事件
  • 混合数据模式 — Tensor 支持安全拷贝和零拷贝两种模式
  • YAML 配置 — 支持 YAML 格式模型配置,优先级高于 JSON
  • 加密模型 — 支持 AES-GCM 加密模型,自动检测 .enc 后缀

架构

┌─────────────────────────────────────────────────────────────┐
│                        Engine(引擎)                        │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐      │
│  │ Task Queue   │→ │ Worker Pool  │→ │ EventBus      │      │
│  │(任务队列)  │  │(工作池)    │  │(事件总线)   │      │
│  └──────────────┘  └──────┬───────┘  └──────────────┘      │
│                            │                                 │
│                            ▼                                 │
│                    ┌───────────────┐                        │
│                    │   Backend     │                        │
│                    │  (接口)     │                        │
│                    └───────┬───────┘                        │
│           ┌────────────────┼────────────────┐               │
│           ▼                ▼                ▼               │
│    ┌──────────┐     ┌──────────┐     ┌──────────┐          │
│    │   ONNX   │     │    TF    │     │  NCNN    │  ...     │
│    └──────────┘     └──────────┘     └──────────┘          │
└─────────────────────────────────────────────────────────────┘

核心概念

Backend 接口

所有推理后端必须实现的核心接口:

type Backend interface {
    // ========== 生命周期 ==========
    Init(ctx context.Context, cfg BackendConfig) error
    Close() error

    // ========== 模型管理 ==========
    LoadModel(ctx context.Context, modelID string, model Model) error
    UnloadModel(ctx context.Context, modelID string) error
    GetModel(ctx context.Context, modelID string) (*Model, error)
    ListModels(ctx context.Context) ([]string, error)

    // ========== 推理接口 ==========
    Inference(ctx context.Context, modelID string, input *Tensor) (*Tensor, error)
    InferenceAsync(ctx context.Context, modelID string, input *Tensor) *Future
    InferenceBatch(ctx context.Context, modelID string, inputs []*Tensor) ([]*Tensor, error)

    // ========== 扩展功能 ==========
    Extensions() *Extensions
}

Tensor 数据结构

张量支持混合模式,兼顾安全性和性能:

type Tensor struct {
    Shape   Shape           // 形状 [N, H, W, C]
    Dtype   Dtype           // 数据类型
    Data    []byte          // 拷贝模式(安全)
    Ptr     unsafe.Pointer  // 零拷贝模式(高性能)
    Owner   interface{}     // 内存所有者
}

Device 设备类型

// CPU 设备
device := CPUDevice()

// GPU 设备(指定设备 ID)
device := GPUDevice(0)

快速开始

基本推理

package main

import (
    "context"
    "fmt"
    "log"

    "cnb.cool/svn/stream/inference"
)

func main() {
    // 1. 创建后端配置
    cfg := inference.BackendConfig{
        Type:    inference.BackendONNX,
        Device:  inference.CPUDevice(),
        Workers: 4,
    }

    // 2. 从注册表创建后端
    backend, err := inference.CreateBackend(inference.BackendONNX, cfg)
    if err != nil {
        log.Fatal(err)
    }
    defer backend.Close()

    // 3. 初始化后端
    ctx := context.Background()
    if err := backend.Init(ctx, cfg); err != nil {
        log.Fatal(err)
    }

    // 4. 加载模型
    model := inference.Model{
        ID:          "yolov8n",
        Format:      inference.ModelFormatONNX,
        Path:        "/models/yolov8n.onnx",
        InputShape:  inference.Shape{1, 640, 640, 3},
        OutputShape: inference.Shape{1, 8400, 85},
    }

    if err := backend.LoadModel(ctx, "yolov8n", model); err != nil {
        log.Fatal(err)
    }

    // 5. 准备输入张量
    input, err := inference.NewTensor(model.InputShape, inference.DtypeFloat32)
    if err != nil {
        log.Fatal(err)
    }

    // 6. 执行推理
    output, err := backend.Inference(ctx, "yolov8n", input)
    if err != nil {
        log.Fatal(err)
    }

    fmt.Printf("输出形状: %s\n", output.Shape)
}

使用 Engine(带事件)

import "cnb.cool/svn/stream"

// 创建事件总线
eventBus := stream.NewEventBus(stream.EventBusConfig{})
defer eventBus.Close()

// 订阅推理事件
eventBus.Subscribe("inference-handler", stream.EventHandlerFunc(func(e stream.Event) {
    switch e.Type {
    case stream.EventInferenceCompleted:
        fmt.Printf("推理完成: %v\n", e.Data)
    case stream.EventInferenceFailed:
        fmt.Printf("推理失败: %v\n", e.Error)
    }
}))

// 创建引擎
cfg := inference.BackendConfig{
    Type:     inference.BackendONNX,
    Device:   inference.CPUDevice(),
    EventBus: eventBus,
}

engine, err := inference.NewEngine(cfg, eventBus)
if err != nil {
    log.Fatal(err)
}
defer engine.Stop()

if err := engine.Start(); err != nil {
    log.Fatal(err)
}

// 加载模型
model := inference.Model{...}
if err := engine.LoadModel(ctx, model); err != nil {
    log.Fatal(err)
}

// 提交推理任务
task := &inference.Task{
    SessionID: "session-001",
    ModelID:   "yolov8n",
    Tensor:    inputTensor,
}

if err := engine.Submit(task); err != nil {
    log.Fatal(err)
}

异步推理

// 异步推理,立即返回 Future
future := backend.InferenceAsync(ctx, "yolov8n", input)

// 可以做其他事情...

// 获取结果
output, err := future.Get()
if err != nil {
    log.Fatal(err)
}

// 或带超时
output, err := future.GetTimeout(5 * time.Second)

// 或使用回调
future.Then(func(output *Tensor, err error) {
    if err != nil {
        log.Printf("推理失败: %v", err)
        return
    }
    fmt.Printf("输出: %v\n", output.Shape)
})

批量推理

// 准备批量输入
inputs := []*inference.Tensor{input1, input2, input3}

// 批量推理(通常比单独调用更高效)
outputs, err := backend.InferenceBatch(ctx, "yolov8n", inputs)
if err != nil {
    log.Fatal(err)
}

fmt.Printf("批量处理 %d 个输入\n", len(outputs))

Frame 与 Tensor 互转

import "cnb.cool/svn/stream"

// Frame 转 Tensor
frame := &stream.Frame{
    Width:   1920,
    Height:  1080,
    Payload: imageData,
}

tensor, err := inference.FrameToTensor(frame, inference.Shape{1, 1080, 1920, 3})
if err != nil {
    log.Fatal(err)
}

// Tensor 转 Frame
frame2, err := inference.TensorToFrame(tensor, "session-001")
if err != nil {
    log.Fatal(err)
}

后端注册

使用内置后端

// MockBackend 自动注册,可直接使用
backend, err := inference.CreateBackend(inference.BackendMock, cfg)

注册自定义后端

// 1. 实现 Backend 接口
type MyCustomBackend struct {
    // ...
}

func (b *MyCustomBackend) Init(ctx context.Context, cfg inference.BackendConfig) error {
    // 初始化逻辑
    return nil
}

// ... 实现其他方法

// 2. 注册到全局注册表
func init() {
    inference.RegisterBackend("my-backend", func(cfg inference.BackendConfig) (inference.Backend, error) {
        return &MyCustomBackend{}, nil
    })
}

// 3. 使用
backend, err := inference.CreateBackend("my-backend", cfg)

事件类型

推理模块通过 EventBus 发布以下事件:

事件类型说明
EventInferenceBackendReady后端初始化完成
EventInferenceModelLoaded模型加载成功
EventInferenceModelUnloaded模型已卸载
EventInferenceStarted推理任务开始
EventInferenceCompleted推理任务完成
EventInferenceFailed推理任务失败
EventInferenceQueueFull推理队列已满

扩展接口

扩展功能通过 Extensions() 方法访问,后端选择性实现:

ext := backend.Extensions()
if ext == nil {
    log.Println("该后端不支持任何扩展功能")
    return
}

元数据查询

if ext := backend.Extensions(); ext != nil && ext.Metadata != nil {
    metadata, err := ext.Metadata.Metadata(ctx)
    fmt.Printf("后端: %s, 版本: %s\n", metadata.Name, metadata.Version)
}

流式推理

if ext := backend.Extensions(); ext != nil && ext.Stream != nil {
    inputCh := make(chan inference.Tensor, 10)

    go func() {
        for _, frame := range frames {
            inputCh <- tensor
        }
        close(inputCh)
    }()

    outputCh := ext.Stream.InferenceStream(ctx, "model-id", inputCh)

    for result := range outputCh {
        if result.Error != nil {
            log.Printf("错误: %v", result.Error)
            continue
        }
        processResult(result.Output)
    }
}

模型配置

从 YAML 配置文件加载

配置文件命名为模型路径加 .yaml 扩展名,如 /models/yolov8n.onnx.yaml

优先级:.yaml > .yml > .json(向后兼容)。

id: yolov8n
format: onnx
input_shape: [1, 3, 640, 640]
output_shape: [1, 8400, 85]
labels:
  - person
  - car
  - dog
preprocess:
  normalize: true
  mean: [0.485, 0.456, 0.406]
  std: [0.229, 0.224, 0.225]
  resize_to: [640, 640]
  color_space: RGB
postprocess:
  threshold: 0.5
  nms_threshold: 0.45
  max_detections: 100
metadata:
  author: "your-team"
  version: "1.0.0"
loader := inference.NewModelLoader()
model, err := loader.LoadFromFile("/models/yolov8n.onnx")

加密模型

支持通过 .enc.encrypted 扩展名自动检测加密模型。

import "cnb.cool/svn/stream/inference"

// 创建 AES 解密器(256 位密钥)
key := []byte{ /* 32 字节密钥 */ }
decryptor := inference.NewAESDecryptor(key)

// 模型配置
model := inference.Model{
    ID:          "encrypted-model",
    Format:      inference.ModelFormatONNX,
    Path:        "/models/yolov8n.onnx.enc",  // .enc 后缀触发解密
    InputShape:  inference.Shape{1, 3, 640, 640},
    OutputShape: inference.Shape{1, 8400, 85},
    Decryptor:   decryptor,  // 提供解密器
}

// 加载时自动解密(200M 以下模型使用内存解密)
err := backend.LoadModel(ctx, "encrypted-model", model)

注意:加密接口仅供测试和开发使用。生产环境应使用专门的加密工具和密钥管理服务(如 HSM、KMS)。

Dtype 类型

类型字节数说明
DtypeFloat32432 位浮点数
DtypeFloat16216 位浮点数
DtypeInt818 位整数
DtypeUint818 位无符号整数
DtypeInt32432 位整数
DtypeInt64864 位整数
DtypeBool1布尔值

错误处理

import "cnb.cool/svn/stream/inference"

// 推理失败
_, err := backend.Inference(ctx, "model-id", input)
if err != nil {
    // 检查是否是后端错误
    if inference.IsBackendError(err) {
        code := inference.GetErrorCode(err)
        switch code {
        case inference.ErrModelNotFound:
            // 处理模型未找到
        case inference.ErrTimeout:
            // 处理超时
        case inference.ErrInferenceFailed:
            // 处理推理失败
        }
    }
}

测试

# 运行所有测试
cd /workspace
go test -v ./inference/...

# 运行短测试(跳过耗时测试)
go test -v -short ./inference/...

# 运行基准测试
go test -bench=. ./inference/...

目录结构

inference/
├── types.go              # 核心类型(Tensor、Dtype、Device、Shape)
├── errors.go             # 结构化错误类型
├── future.go             # Future/Promise 实现
├── backend.go            # Backend 核心接口
├── backend_ext.go        # 扩展接口(Metadata、Stream、Preprocessor)
├── config.go             # 配置结构
├── registry.go           # 后端工厂注册表
├── converter.go          # Frame/Tensor 互转
├── backend_stubs.go      # MockBackend 实现
├── model_loader.go       # 模型配置加载器
├── inference_engine.go   # 推理引擎(工作池 + 事件集成)
├── types_test.go         # 类型测试
├── backend_test.go       # 接口测试
└── model_loader_test.go  # 加载器测试

已知限制

  1. 当前实现MockBackend 为模拟实现,仅返回固定形状的空数据
  2. 生产部署:需要实现具体的后端(ONNX Runtime、TensorFlow、NCNN、TFLite)的 CGO 绑定
  3. 预处理:当前 converter.go 中的预处理为简化实现,生产环境需要完整的图像处理库
  4. 零拷贝:零拷贝模式需要后端实现支持,当前仅支持拷贝模式

后续适配

适配新的推理框架只需实现 Backend 接口:

  1. 创建后端文件(如 backend_onnx.go
  2. 实现 Backend 接口的所有方法
  3. 实现 BackendFactory 工厂函数
  4. 调用 RegisterBackend() 注册

示例:

package inference

type ONNXBackend struct {
    // ONNX Runtime C API 绑定
}

func (b *ONNXBackend) Init(ctx context.Context, cfg BackendConfig) error {
    // 初始化 ONNX Runtime
    return nil
}

// ... 实现其他方法

func NewONNXBackend(cfg BackendConfig) (Backend, error) {
    return &ONNXBackend{}, nil
}

func init() {
    RegisterBackend(BackendONNX, NewONNXBackend)
}

许可证

MIT License