logo
0
0
WeChat Login
feat(stream): add matcher combinators

stream

stream 是 Yamux 的扩展包,提供基于流首包内容的多路协议分发能力。它允许你在同一个 Yamux 会话上同时服务多种协议(如 HTTP/1、HTTP/2、TLS、自定义协议等),实现 cmux 风格的多路复用。

目录

核心特性

  • cmux 风格 API:熟悉的 MatchMatchWithWriters 接口
  • 零拷贝 sniffing:缓冲探测机制确保处理器仍能读取原始前缀字节
  • 丰富的内置匹配器:支持 HTTP/1.x、HTTP/2、MCP Streamable HTTP、TLS SNI/ALPN 等常见协议
  • 轻量组合件:支持 AnyOfAllOfBytePrefixMatcher 组合与复用 matcher
  • 可配置超时:匹配阶段可设置读取超时
  • context-aware 控制面:支持 ServeContextAcceptContext
  • 运行时统计:暴露 mux / listener / matcher 分层统计快照
  • 生命周期钩子:通过 SetHooks 监控关键运行时事件
  • 无协议侵入:不修改 Yamux 的 wire 格式

使用场景

多协议服务

┌─────────────────────────────────────────────┐ │ Single TCP Port │ │ │ │ │ ┌──────▼──────┐ │ │ │ Yamux │ │ │ │ Session │ │ │ └──────┬──────┘ │ │ │ │ │ ┌────────────┼────────────┐ │ │ ▼ ▼ ▼ │ │ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ │ │ HTTP/1 │ │ HTTP/2 │ │ gRPC │ │ │ │ Handler │ │ Handler │ │ Handler │ │ │ └─────────┘ └─────────┘ └─────────┘ │ └─────────────────────────────────────────────┘

协议版本协商

Client Server │ │ │ ────────── New Yamux Stream ───────────▶ │ │ │ │ ◀────────── sniff first bytes ────────── │ │ │ │ ◀────────── route to handler ────────── │ │ │ │ ◀────────── protocol handshake ────────▶ │

快速开始

package main import ( "log" "net" "net/http" "cnb.cool/zishuo/yamux" "cnb.cool/zishuo/yamux/stream" ) func main() { // 监听 TCP 端口 ln, err := net.Listen("tcp", ":8080") if err != nil { log.Fatal(err) } defer ln.Close() for { conn, err := ln.Accept() if err != nil { log.Printf("Accept error: %v", err) continue } go handleConnection(conn) } } func handleConnection(conn net.Conn) { // 创建 Yamux 会话;成功后由 session 接管 conn 的生命周期 session, err := yamux.Server(conn, nil) if err != nil { log.Printf("Yamux server error: %v", err) conn.Close() return } defer session.Close() // 创建多路分发器 mux := stream.New(session) // 配置协议路由 http1Listener := mux.Match(stream.HTTP1Fast()) http2Listener := mux.Match(stream.HTTP2()) defaultListener := mux.Match(stream.Any()) // 启动各协议服务 go func() { http.Serve(http1Listener, http1Handler()) }() go func() { http.Serve(http2Listener, http2Handler()) }() go func() { serveRaw(defaultListener) }() // 开始服务(阻塞) if err := mux.Serve(); err != nil { log.Printf("Mux serve error: %v", err) } } func http1Handler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello from HTTP/1!")) }) } func http2Handler() http.Handler { // HTTP/2 处理器 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello from HTTP/2!")) }) } func serveRaw(ln net.Listener) { for { conn, err := ln.Accept() if err != nil { return } go func(c net.Conn) { defer c.Close() // 处理非 HTTP 流量 }(conn) } }

API 文档

创建多路分发器

// 基于 Yamux 会话创建多路分发器 func New(root Accepter) Mux // 使用显式配置创建多路分发器 func NewWithConfig(root Accepter, cfg Config) Mux

Accepter 是 stream 运行所需的最小接口,*yamux.Session 已满足该接口:

type Accepter interface { Accept() (net.Conn, error) Close() error Addr() net.Addr }

协议匹配

type Listener interface { net.Listener AcceptContext(context.Context) (net.Conn, error) Stats() ListenerStats } type Mux interface { // Match 使用 Reader 匹配器创建监听器 Match(...Matcher) Listener // MatchWithWriters 使用支持写入的匹配器 MatchWithWriters(...MatchWriter) Listener // Serve 开始接受和处理连接(阻塞) Serve() error // ServeContext 在 context 生命周期内运行分发循环 ServeContext(context.Context) error // Close 关闭多路分发器 Close() error // HandleError 设置错误处理器 HandleError(ErrorHandler) // SetHooks 设置生命周期钩子 SetHooks(Hooks) // SetReadTimeout 设置匹配阶段读取超时 SetReadTimeout(time.Duration) // Stats 返回 mux 与所有子 listener 的统计快照 Stats() MuxStats }

匹配器类型

// Matcher 根据流首包内容判断是否命中 type Matcher func(io.Reader) bool // MatchWriter 可以在匹配阶段向连接写入握手数据 type MatchWriter func(io.Writer, io.Reader) bool // ErrorHandler 决定遇到错误后是否继续 Serve type ErrorHandler func(error) bool

统计快照

type MuxStats struct { Accepted uint64 Unmatched uint64 Errors uint64 Closed bool Listeners []ListenerStats } type ListenerStats struct { ListenerIndex int Backlog int BacklogCapacity int Enqueued uint64 Accepted uint64 Dropped uint64 Closed bool Matchers []MatcherStats } type MatcherStats struct { ListenerIndex int MatcherIndex int Attempts uint64 Matches uint64 Misses uint64 }

生命周期钩子

type Hooks struct { // OnServeStart: Serve() 开始时被调用 OnServeStart func(net.Addr) // OnServeStop: Serve() 退出时被调用 OnServeStop func(error) // OnAccept: 每次 Accept 成功后被调用 OnAccept func(net.Conn) // OnMatch: 流成功匹配后被调用 OnMatch func(MatchEvent) // OnUnmatched: 流未命中任何匹配器时被调用 OnUnmatched func(net.Conn) // OnError: 发生错误时被调用 OnError func(error) } type MatchEvent struct { Conn net.Conn // 被匹配的连接 ListenerIndex int // 命中的监听器索引 MatcherIndex int // 命中的匹配器索引 }

内置匹配器

HTTP 协议

// HTTP1Fast - 通过 HTTP 方法前缀快速识别 HTTP/1.x // 性能最优,适合大多数场景 matcher := stream.HTTP1Fast() // 支持自定义 HTTP 方法 matcher := stream.HTTP1Fast("CUSTOM", "API") // HTTP1 - 完整解析 HTTP/1.x 请求首行 // 更准确但开销稍大 matcher := stream.HTTP1() // HTTP1HeaderField - 匹配特定请求头字段 matcher := stream.HTTP1HeaderField("Host", "api.example.com") // HTTP1HeaderFieldPrefix - 匹配请求头字段前缀 matcher := stream.HTTP1HeaderFieldPrefix("User-Agent", "MyApp/") // HTTP2 - 匹配 HTTP/2 客户端 preface matcher := stream.HTTP2() // HTTP2MatchSendSettings - 匹配后发送 SETTINGS 帧 // 对部分需要先收到 server preface 的客户端更友好 matcher := stream.HTTP2MatchSendSettings()

MCP Streamable HTTP

// MCPStreamableHTTP - 匹配 MCP Streamable HTTP POST/GET/DELETE 请求 // 推荐传入明确 endpoint path,避免与普通 JSON API 混淆 matcher := stream.MCPStreamableHTTP("/mcp") // 未传 path 时仅依赖 Mcp-Method、MCP-Protocol-Version、 // MCP-Session-Id 等 MCP 专属 header 识别 matcher := stream.MCPStreamableHTTP()

TLS 协议

// TLS - 匹配 TLS ClientHello(支持所有版本) matcher := stream.TLS() // 指定特定 TLS 版本 matcher := stream.TLS(tls.VersionTLS12, tls.VersionTLS13) // 可用版本常量 // tls.VersionSSL30 // tls.VersionTLS10 // tls.VersionTLS11 // tls.VersionTLS12 // tls.VersionTLS13 // TLSSNI - 按 SNI 路由,支持精确域名和单级通配符 matcher := stream.TLSSNI("api.example.com", "*.internal.example.com") // TLSALPN - 按 ALPN 路由 HTTP/2、HTTP/1.1、自定义协议 matcher := stream.TLSALPN("h2", "http/1.1") // TLSClientHello - 自定义 ClientHello 信息路由 matcher := stream.TLSClientHello(func(info stream.TLSClientHelloInfo) bool { return info.ServerName == "api.example.com" && len(info.ALPNProtocols) > 0 })

通用匹配器

// Any - 匹配任意流(通常作为默认处理器) matcher := stream.Any() // AnyOf - 任一 matcher 命中即视为命中 matcher := stream.AnyOf(stream.HTTP1Fast(), stream.HTTP2()) // AllOf - 所有 matcher 都命中时才视为命中 matcher := stream.AllOf( stream.HTTP1Fast(), stream.HTTP1HeaderField("Host", "api.example.com"), ) // PrefixMatcher - 按前缀匹配 matcher := stream.PrefixMatcher("GET ", "POST ", "PUT ") // BytePrefixMatcher - 按二进制前缀匹配 matcher := stream.BytePrefixMatcher([]byte{0x16, 0x03, 0x03})

自定义匹配器

简单 Matcher

// 自定义协议检测 func MyProtocolMatcher() stream.Matcher { return func(r io.Reader) bool { buf := make([]byte, 4) n, err := r.Read(buf) if err != nil || n < 4 { return false } // 检测魔数 return bytes.Equal(buf, []byte{0x42, 0x4D, 0x59, 0x50}) } } // 使用 myProtoListener := mux.Match(MyProtocolMatcher())

MatchWriter(需要握手的情况)

// 需要向客户端发送响应的匹配器 func HandshakeMatcher() stream.MatchWriter { return func(w io.Writer, r io.Reader) bool { buf := make([]byte, 1) if _, err := r.Read(buf); err != nil { return false } // 协议版本协商 if buf[0] == 0x01 { // 发送接受响应 w.Write([]byte{0x01}) return true } // 拒绝并发送版本错误 w.Write([]byte{0xFF}) return false } } // 使用 listener := mux.MatchWithWriters(HandshakeMatcher())

架构设计

数据流

┌──────────────────────────────────────────────────────────────┐ │ Yamux Session │ │ (底层 Accepter) │ └─────────────────────┬────────────────────────────────────────┘ │ Accept() ▼ ┌─────────────────────┐ │ muxServer │ │ (serveConn) │ └──────────┬──────────┘ │ ┌─────────────┼─────────────┐ ▼ ▼ ▼ ┌─────────┐ ┌─────────┐ ┌─────────┐ │Matcher 1│ │Matcher 2│ │Matcher 3│ │HTTP1Fast│ │ HTTP2 │ │ Any │ └────┬────┘ └────┬────┘ └────┬────┘ │ │ │ └────────────┴────────────┘ │ ┌────────────┼────────────┐ ▼ ▼ ▼ ┌─────────┐ ┌─────────┐ ┌─────────┐ │Listener1│ │Listener2│ │Listener3│ │ (chan) │ │ (chan) │ │ (chan) │ └────┬────┘ └────┬────┘ └────┬────┘ │ │ │ ▼ ▼ ▼ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ HTTP/1 │ │ HTTP/2 │ │ Default │ │ Handler │ │ Handler │ │ Handler │ └─────────┘ └─────────┘ └─────────┘

缓冲 Sniffing

stream 使用 muxConn 包装原始连接,实现透明的首包探测:

┌─────────────────────────────────────────────┐ │ muxConn │ │ ┌───────────────────────────────────────┐ │ │ │ sniffing buffer │ │ │ │ ┌─────────────────────────────┐ │ │ │ │ │ already read for matching │ │ │ │ │ └─────────────────────────────┘ │ │ │ │ ▲ │ │ │ │ │ startSniffing() │ │ │ │ ┌────┴────┐ │ │ │ │ │ io.TeeReader │ │ │ │ └────┬────┘ │ │ │ │ │ │ │ │ │ ┌─────────┴──────────┐ │ │ │ │ │ underlying conn │ │ │ │ │ └────────────────────┘ │ │ │ └─────────────────────────────────────────┘ │ └─────────────────────────────────────────────┘

最佳实践

1. 匹配器顺序

// ✅ 正确的顺序:从具体到通用 mux := stream.New(session) // 1. 最具体的协议 grpcListener := mux.Match(GRPCMatcher()) // 2. 标准协议 http2Listener := mux.Match(stream.HTTP2()) http1Listener := mux.Match(stream.HTTP1Fast()) // 3. 最后的兜底 defaultListener := mux.Match(stream.Any())

2. 超时设置

mux := stream.New(session) // 设置匹配阶段读取超时 // 防止恶意连接长时间不发送数据 mux.SetReadTimeout(5 * time.Second)

3. 错误处理

mux := stream.New(session) // 自定义错误处理器 mux.HandleError(func(err error) bool { // 返回 true 表示继续服务,false 表示停止 if errors.Is(err, stream.ErrServerClosed) { return false } // 记录错误 log.Printf("stream error: %v", err) // 对于临时错误继续服务 if netErr, ok := err.(net.Error); ok { return netErr.Temporary() } return true })

4. 优雅关闭

func serveWithGracefulShutdown(session *yamux.Session) { mux := stream.New(session) // 监听系统信号 sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) go func() { <-sigCh log.Println("Shutting down...") mux.Close() }() if err := mux.Serve(); err != nil && !errors.Is(err, stream.ErrServerClosed) { log.Printf("Serve error: %v", err) } }

5. 指标收集

mux.SetHooks(stream.Hooks{ OnServeStart: func(addr net.Addr) { metrics.Gauge("stream.server.running", 1, metrics.Tag("addr", addr.String())) }, OnServeStop: func(err error) { metrics.Gauge("stream.server.running", 0) if err != nil { metrics.Counter("stream.server.errors", 1) } }, OnAccept: func(conn net.Conn) { metrics.Counter("stream.connections.accepted", 1) }, OnMatch: func(event stream.MatchEvent) { metrics.Counter("stream.streams.matched", 1, metrics.Tag("listener", strconv.Itoa(event.ListenerIndex)), metrics.Tag("matcher", strconv.Itoa(event.MatcherIndex))) }, OnUnmatched: func(conn net.Conn) { metrics.Counter("stream.streams.unmatched", 1) }, OnError: func(err error) { metrics.Counter("stream.errors", 1, metrics.Tag("type", reflect.TypeOf(err).String())) }, })

完整示例

多协议网关

package main import ( "context" "crypto/tls" "errors" "log" "net" "net/http" "os" "os/signal" "syscall" "time" "cnb.cool/zishuo/yamux" "cnb.cool/zishuo/yamux/stream" ) type MultiProtocolServer struct { session *yamux.Session mux stream.Mux } func NewMultiProtocolServer(conn net.Conn) (*MultiProtocolServer, error) { session, err := yamux.Server(conn, nil) if err != nil { conn.Close() return nil, err } mux := stream.New(session) // 配置路由 server := &MultiProtocolServer{session: session, mux: mux} server.setupRoutes() server.setupHooks() return server, nil } func (s *MultiProtocolServer) setupRoutes() { // HTTP/2 优先检查(HTTP/2 preface 更独特) http2Ln := s.mux.Match(stream.HTTP2MatchSendSettings()) go s.serveHTTP2(http2Ln) // HTTP/1.x http1Ln := s.mux.Match(stream.HTTP1Fast()) go s.serveHTTP1(http1Ln) // TLS(可能是其他协议) tlsLn := s.mux.Match(stream.TLS()) go s.serveTLS(tlsLn) // 兜底:原始 TCP rawLn := s.mux.Match(stream.Any()) go s.serveRaw(rawLn) } func (s *MultiProtocolServer) setupHooks() { s.mux.SetHooks(stream.Hooks{ OnServeStart: func(addr net.Addr) { log.Printf("[stream] Server started on %s", addr) }, OnServeStop: func(err error) { if err != nil && !errors.Is(err, stream.ErrServerClosed) { log.Printf("[stream] Server stopped with error: %v", err) } else { log.Println("[stream] Server stopped gracefully") } }, OnMatch: func(event stream.MatchEvent) { log.Printf("[stream] Stream matched: listener=%d, matcher=%d", event.ListenerIndex, event.MatcherIndex) }, OnUnmatched: func(conn net.Conn) { log.Printf("[stream] Unmatched connection from %s", conn.RemoteAddr()) }, }) // 错误处理 s.mux.HandleError(func(err error) bool { if errors.Is(err, stream.ErrServerClosed) { return false } log.Printf("[stream] Error: %v", err) return true }) // 设置超时 s.mux.SetReadTimeout(10 * time.Second) } func (s *MultiProtocolServer) serveHTTP1(ln net.Listener) { server := &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello from HTTP/1.1!")) }), ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, } if err := server.Serve(ln); err != nil && !errors.Is(err, stream.ErrListenerClosed) { log.Printf("[HTTP/1] Server error: %v", err) } } func (s *MultiProtocolServer) serveHTTP2(ln net.Listener) { server := &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello from HTTP/2!")) }), } if err := server.Serve(ln); err != nil && !errors.Is(err, stream.ErrListenerClosed) { log.Printf("[HTTP/2] Server error: %v", err) } } func (s *MultiProtocolServer) serveTLS(ln net.Listener) { for { conn, err := ln.Accept() if err != nil { return } go func(c net.Conn) { defer c.Close() // 处理 TLS 握手或代理到 TLS 终端 log.Printf("[TLS] Connection from %s", c.RemoteAddr()) }(conn) } } func (s *MultiProtocolServer) serveRaw(ln net.Listener) { for { conn, err := ln.Accept() if err != nil { return } go func(c net.Conn) { defer c.Close() // 处理原始协议 buf := make([]byte, 1024) n, _ := c.Read(buf) c.Write([]byte("Echo: ")) c.Write(buf[:n]) }(conn) } } func (s *MultiProtocolServer) Run(ctx context.Context) error { errCh := make(chan error, 1) go func() { errCh <- s.mux.Serve() }() select { case err := <-errCh: return err case <-ctx.Done(): s.mux.Close() return ctx.Err() } } func main() { ln, err := net.Listen("tcp", ":8080") if err != nil { log.Fatal(err) } defer ln.Close() log.Println("Multi-protocol server listening on :8080") for { conn, err := ln.Accept() if err != nil { log.Printf("Accept error: %v", err) continue } go func(c net.Conn) { server, err := NewMultiProtocolServer(c) if err != nil { log.Printf("Server creation error: %v", err) c.Close() return } ctx, cancel := context.WithCancel(context.Background()) defer cancel() // 信号处理 sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) go func() { <-sigCh cancel() }() if err := server.Run(ctx); err != nil { log.Printf("Server error: %v", err) } }(conn) } }

与 cmux 的区别

特性cmuxstream
底层连接直接 TCPYamux Session
多路复用层单一层级两层(Yamux + Protocol)
适用场景单一端口多协议单个连接多流多协议
协议嗅探
HTTP/2 支持

许可证

本项目采用 Mozilla Public License 2.0 许可证。

Copyright IBM Corp. 2014, 2025 SPDX-License-Identifier: MPL-2.0