stream 是 Yamux 的扩展包,提供基于流首包内容的多路协议分发能力。它允许你在同一个 Yamux 会话上同时服务多种协议(如 HTTP/1、HTTP/2、TLS、自定义协议等),实现 cmux 风格的多路复用。
Match 和 MatchWithWriters 接口AnyOf、AllOf、BytePrefixMatcher 组合与复用 matcherServeContext 与 AcceptContextSetHooks 监控关键运行时事件┌─────────────────────────────────────────────┐ │ Single TCP Port │ │ │ │ │ ┌──────▼──────┐ │ │ │ Yamux │ │ │ │ Session │ │ │ └──────┬──────┘ │ │ │ │ │ ┌────────────┼────────────┐ │ │ ▼ ▼ ▼ │ │ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ │ │ HTTP/1 │ │ HTTP/2 │ │ gRPC │ │ │ │ Handler │ │ Handler │ │ Handler │ │ │ └─────────┘ └─────────┘ └─────────┘ │ └─────────────────────────────────────────────┘
Client Server │ │ │ ────────── New Yamux Stream ───────────▶ │ │ │ │ ◀────────── sniff first bytes ────────── │ │ │ │ ◀────────── route to handler ────────── │ │ │ │ ◀────────── protocol handshake ────────▶ │
package main
import (
"log"
"net"
"net/http"
"cnb.cool/zishuo/yamux"
"cnb.cool/zishuo/yamux/stream"
)
func main() {
// 监听 TCP 端口
ln, err := net.Listen("tcp", ":8080")
if err != nil {
log.Fatal(err)
}
defer ln.Close()
for {
conn, err := ln.Accept()
if err != nil {
log.Printf("Accept error: %v", err)
continue
}
go handleConnection(conn)
}
}
func handleConnection(conn net.Conn) {
// 创建 Yamux 会话;成功后由 session 接管 conn 的生命周期
session, err := yamux.Server(conn, nil)
if err != nil {
log.Printf("Yamux server error: %v", err)
conn.Close()
return
}
defer session.Close()
// 创建多路分发器
mux := stream.New(session)
// 配置协议路由
http1Listener := mux.Match(stream.HTTP1Fast())
http2Listener := mux.Match(stream.HTTP2())
defaultListener := mux.Match(stream.Any())
// 启动各协议服务
go func() {
http.Serve(http1Listener, http1Handler())
}()
go func() {
http.Serve(http2Listener, http2Handler())
}()
go func() {
serveRaw(defaultListener)
}()
// 开始服务(阻塞)
if err := mux.Serve(); err != nil {
log.Printf("Mux serve error: %v", err)
}
}
func http1Handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello from HTTP/1!"))
})
}
func http2Handler() http.Handler {
// HTTP/2 处理器
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello from HTTP/2!"))
})
}
func serveRaw(ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
// 处理非 HTTP 流量
}(conn)
}
}
// 基于 Yamux 会话创建多路分发器
func New(root Accepter) Mux
// 使用显式配置创建多路分发器
func NewWithConfig(root Accepter, cfg Config) Mux
Accepter 是 stream 运行所需的最小接口,*yamux.Session 已满足该接口:
type Accepter interface {
Accept() (net.Conn, error)
Close() error
Addr() net.Addr
}
type Listener interface {
net.Listener
AcceptContext(context.Context) (net.Conn, error)
Stats() ListenerStats
}
type Mux interface {
// Match 使用 Reader 匹配器创建监听器
Match(...Matcher) Listener
// MatchWithWriters 使用支持写入的匹配器
MatchWithWriters(...MatchWriter) Listener
// Serve 开始接受和处理连接(阻塞)
Serve() error
// ServeContext 在 context 生命周期内运行分发循环
ServeContext(context.Context) error
// Close 关闭多路分发器
Close() error
// HandleError 设置错误处理器
HandleError(ErrorHandler)
// SetHooks 设置生命周期钩子
SetHooks(Hooks)
// SetReadTimeout 设置匹配阶段读取超时
SetReadTimeout(time.Duration)
// Stats 返回 mux 与所有子 listener 的统计快照
Stats() MuxStats
}
// Matcher 根据流首包内容判断是否命中
type Matcher func(io.Reader) bool
// MatchWriter 可以在匹配阶段向连接写入握手数据
type MatchWriter func(io.Writer, io.Reader) bool
// ErrorHandler 决定遇到错误后是否继续 Serve
type ErrorHandler func(error) bool
type MuxStats struct {
Accepted uint64
Unmatched uint64
Errors uint64
Closed bool
Listeners []ListenerStats
}
type ListenerStats struct {
ListenerIndex int
Backlog int
BacklogCapacity int
Enqueued uint64
Accepted uint64
Dropped uint64
Closed bool
Matchers []MatcherStats
}
type MatcherStats struct {
ListenerIndex int
MatcherIndex int
Attempts uint64
Matches uint64
Misses uint64
}
type Hooks struct {
// OnServeStart: Serve() 开始时被调用
OnServeStart func(net.Addr)
// OnServeStop: Serve() 退出时被调用
OnServeStop func(error)
// OnAccept: 每次 Accept 成功后被调用
OnAccept func(net.Conn)
// OnMatch: 流成功匹配后被调用
OnMatch func(MatchEvent)
// OnUnmatched: 流未命中任何匹配器时被调用
OnUnmatched func(net.Conn)
// OnError: 发生错误时被调用
OnError func(error)
}
type MatchEvent struct {
Conn net.Conn // 被匹配的连接
ListenerIndex int // 命中的监听器索引
MatcherIndex int // 命中的匹配器索引
}
// HTTP1Fast - 通过 HTTP 方法前缀快速识别 HTTP/1.x
// 性能最优,适合大多数场景
matcher := stream.HTTP1Fast()
// 支持自定义 HTTP 方法
matcher := stream.HTTP1Fast("CUSTOM", "API")
// HTTP1 - 完整解析 HTTP/1.x 请求首行
// 更准确但开销稍大
matcher := stream.HTTP1()
// HTTP1HeaderField - 匹配特定请求头字段
matcher := stream.HTTP1HeaderField("Host", "api.example.com")
// HTTP1HeaderFieldPrefix - 匹配请求头字段前缀
matcher := stream.HTTP1HeaderFieldPrefix("User-Agent", "MyApp/")
// HTTP2 - 匹配 HTTP/2 客户端 preface
matcher := stream.HTTP2()
// HTTP2MatchSendSettings - 匹配后发送 SETTINGS 帧
// 对部分需要先收到 server preface 的客户端更友好
matcher := stream.HTTP2MatchSendSettings()
// MCPStreamableHTTP - 匹配 MCP Streamable HTTP POST/GET/DELETE 请求
// 推荐传入明确 endpoint path,避免与普通 JSON API 混淆
matcher := stream.MCPStreamableHTTP("/mcp")
// 未传 path 时仅依赖 Mcp-Method、MCP-Protocol-Version、
// MCP-Session-Id 等 MCP 专属 header 识别
matcher := stream.MCPStreamableHTTP()
// TLS - 匹配 TLS ClientHello(支持所有版本)
matcher := stream.TLS()
// 指定特定 TLS 版本
matcher := stream.TLS(tls.VersionTLS12, tls.VersionTLS13)
// 可用版本常量
// tls.VersionSSL30
// tls.VersionTLS10
// tls.VersionTLS11
// tls.VersionTLS12
// tls.VersionTLS13
// TLSSNI - 按 SNI 路由,支持精确域名和单级通配符
matcher := stream.TLSSNI("api.example.com", "*.internal.example.com")
// TLSALPN - 按 ALPN 路由 HTTP/2、HTTP/1.1、自定义协议
matcher := stream.TLSALPN("h2", "http/1.1")
// TLSClientHello - 自定义 ClientHello 信息路由
matcher := stream.TLSClientHello(func(info stream.TLSClientHelloInfo) bool {
return info.ServerName == "api.example.com" && len(info.ALPNProtocols) > 0
})
// Any - 匹配任意流(通常作为默认处理器)
matcher := stream.Any()
// AnyOf - 任一 matcher 命中即视为命中
matcher := stream.AnyOf(stream.HTTP1Fast(), stream.HTTP2())
// AllOf - 所有 matcher 都命中时才视为命中
matcher := stream.AllOf(
stream.HTTP1Fast(),
stream.HTTP1HeaderField("Host", "api.example.com"),
)
// PrefixMatcher - 按前缀匹配
matcher := stream.PrefixMatcher("GET ", "POST ", "PUT ")
// BytePrefixMatcher - 按二进制前缀匹配
matcher := stream.BytePrefixMatcher([]byte{0x16, 0x03, 0x03})
// 自定义协议检测
func MyProtocolMatcher() stream.Matcher {
return func(r io.Reader) bool {
buf := make([]byte, 4)
n, err := r.Read(buf)
if err != nil || n < 4 {
return false
}
// 检测魔数
return bytes.Equal(buf, []byte{0x42, 0x4D, 0x59, 0x50})
}
}
// 使用
myProtoListener := mux.Match(MyProtocolMatcher())
// 需要向客户端发送响应的匹配器
func HandshakeMatcher() stream.MatchWriter {
return func(w io.Writer, r io.Reader) bool {
buf := make([]byte, 1)
if _, err := r.Read(buf); err != nil {
return false
}
// 协议版本协商
if buf[0] == 0x01 {
// 发送接受响应
w.Write([]byte{0x01})
return true
}
// 拒绝并发送版本错误
w.Write([]byte{0xFF})
return false
}
}
// 使用
listener := mux.MatchWithWriters(HandshakeMatcher())
┌──────────────────────────────────────────────────────────────┐ │ Yamux Session │ │ (底层 Accepter) │ └─────────────────────┬────────────────────────────────────────┘ │ Accept() ▼ ┌─────────────────────┐ │ muxServer │ │ (serveConn) │ └──────────┬──────────┘ │ ┌─────────────┼─────────────┐ ▼ ▼ ▼ ┌─────────┐ ┌─────────┐ ┌─────────┐ │Matcher 1│ │Matcher 2│ │Matcher 3│ │HTTP1Fast│ │ HTTP2 │ │ Any │ └────┬────┘ └────┬────┘ └────┬────┘ │ │ │ └────────────┴────────────┘ │ ┌────────────┼────────────┐ ▼ ▼ ▼ ┌─────────┐ ┌─────────┐ ┌─────────┐ │Listener1│ │Listener2│ │Listener3│ │ (chan) │ │ (chan) │ │ (chan) │ └────┬────┘ └────┬────┘ └────┬────┘ │ │ │ ▼ ▼ ▼ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ HTTP/1 │ │ HTTP/2 │ │ Default │ │ Handler │ │ Handler │ │ Handler │ └─────────┘ └─────────┘ └─────────┘
stream 使用 muxConn 包装原始连接,实现透明的首包探测:
┌─────────────────────────────────────────────┐ │ muxConn │ │ ┌───────────────────────────────────────┐ │ │ │ sniffing buffer │ │ │ │ ┌─────────────────────────────┐ │ │ │ │ │ already read for matching │ │ │ │ │ └─────────────────────────────┘ │ │ │ │ ▲ │ │ │ │ │ startSniffing() │ │ │ │ ┌────┴────┐ │ │ │ │ │ io.TeeReader │ │ │ │ └────┬────┘ │ │ │ │ │ │ │ │ │ ┌─────────┴──────────┐ │ │ │ │ │ underlying conn │ │ │ │ │ └────────────────────┘ │ │ │ └─────────────────────────────────────────┘ │ └─────────────────────────────────────────────┘
// ✅ 正确的顺序:从具体到通用
mux := stream.New(session)
// 1. 最具体的协议
grpcListener := mux.Match(GRPCMatcher())
// 2. 标准协议
http2Listener := mux.Match(stream.HTTP2())
http1Listener := mux.Match(stream.HTTP1Fast())
// 3. 最后的兜底
defaultListener := mux.Match(stream.Any())
mux := stream.New(session)
// 设置匹配阶段读取超时
// 防止恶意连接长时间不发送数据
mux.SetReadTimeout(5 * time.Second)
mux := stream.New(session)
// 自定义错误处理器
mux.HandleError(func(err error) bool {
// 返回 true 表示继续服务,false 表示停止
if errors.Is(err, stream.ErrServerClosed) {
return false
}
// 记录错误
log.Printf("stream error: %v", err)
// 对于临时错误继续服务
if netErr, ok := err.(net.Error); ok {
return netErr.Temporary()
}
return true
})
func serveWithGracefulShutdown(session *yamux.Session) {
mux := stream.New(session)
// 监听系统信号
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
log.Println("Shutting down...")
mux.Close()
}()
if err := mux.Serve(); err != nil && !errors.Is(err, stream.ErrServerClosed) {
log.Printf("Serve error: %v", err)
}
}
mux.SetHooks(stream.Hooks{
OnServeStart: func(addr net.Addr) {
metrics.Gauge("stream.server.running", 1,
metrics.Tag("addr", addr.String()))
},
OnServeStop: func(err error) {
metrics.Gauge("stream.server.running", 0)
if err != nil {
metrics.Counter("stream.server.errors", 1)
}
},
OnAccept: func(conn net.Conn) {
metrics.Counter("stream.connections.accepted", 1)
},
OnMatch: func(event stream.MatchEvent) {
metrics.Counter("stream.streams.matched", 1,
metrics.Tag("listener", strconv.Itoa(event.ListenerIndex)),
metrics.Tag("matcher", strconv.Itoa(event.MatcherIndex)))
},
OnUnmatched: func(conn net.Conn) {
metrics.Counter("stream.streams.unmatched", 1)
},
OnError: func(err error) {
metrics.Counter("stream.errors", 1,
metrics.Tag("type", reflect.TypeOf(err).String()))
},
})
package main
import (
"context"
"crypto/tls"
"errors"
"log"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"cnb.cool/zishuo/yamux"
"cnb.cool/zishuo/yamux/stream"
)
type MultiProtocolServer struct {
session *yamux.Session
mux stream.Mux
}
func NewMultiProtocolServer(conn net.Conn) (*MultiProtocolServer, error) {
session, err := yamux.Server(conn, nil)
if err != nil {
conn.Close()
return nil, err
}
mux := stream.New(session)
// 配置路由
server := &MultiProtocolServer{session: session, mux: mux}
server.setupRoutes()
server.setupHooks()
return server, nil
}
func (s *MultiProtocolServer) setupRoutes() {
// HTTP/2 优先检查(HTTP/2 preface 更独特)
http2Ln := s.mux.Match(stream.HTTP2MatchSendSettings())
go s.serveHTTP2(http2Ln)
// HTTP/1.x
http1Ln := s.mux.Match(stream.HTTP1Fast())
go s.serveHTTP1(http1Ln)
// TLS(可能是其他协议)
tlsLn := s.mux.Match(stream.TLS())
go s.serveTLS(tlsLn)
// 兜底:原始 TCP
rawLn := s.mux.Match(stream.Any())
go s.serveRaw(rawLn)
}
func (s *MultiProtocolServer) setupHooks() {
s.mux.SetHooks(stream.Hooks{
OnServeStart: func(addr net.Addr) {
log.Printf("[stream] Server started on %s", addr)
},
OnServeStop: func(err error) {
if err != nil && !errors.Is(err, stream.ErrServerClosed) {
log.Printf("[stream] Server stopped with error: %v", err)
} else {
log.Println("[stream] Server stopped gracefully")
}
},
OnMatch: func(event stream.MatchEvent) {
log.Printf("[stream] Stream matched: listener=%d, matcher=%d",
event.ListenerIndex, event.MatcherIndex)
},
OnUnmatched: func(conn net.Conn) {
log.Printf("[stream] Unmatched connection from %s", conn.RemoteAddr())
},
})
// 错误处理
s.mux.HandleError(func(err error) bool {
if errors.Is(err, stream.ErrServerClosed) {
return false
}
log.Printf("[stream] Error: %v", err)
return true
})
// 设置超时
s.mux.SetReadTimeout(10 * time.Second)
}
func (s *MultiProtocolServer) serveHTTP1(ln net.Listener) {
server := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello from HTTP/1.1!"))
}),
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
}
if err := server.Serve(ln); err != nil && !errors.Is(err, stream.ErrListenerClosed) {
log.Printf("[HTTP/1] Server error: %v", err)
}
}
func (s *MultiProtocolServer) serveHTTP2(ln net.Listener) {
server := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello from HTTP/2!"))
}),
}
if err := server.Serve(ln); err != nil && !errors.Is(err, stream.ErrListenerClosed) {
log.Printf("[HTTP/2] Server error: %v", err)
}
}
func (s *MultiProtocolServer) serveTLS(ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
// 处理 TLS 握手或代理到 TLS 终端
log.Printf("[TLS] Connection from %s", c.RemoteAddr())
}(conn)
}
}
func (s *MultiProtocolServer) serveRaw(ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
// 处理原始协议
buf := make([]byte, 1024)
n, _ := c.Read(buf)
c.Write([]byte("Echo: "))
c.Write(buf[:n])
}(conn)
}
}
func (s *MultiProtocolServer) Run(ctx context.Context) error {
errCh := make(chan error, 1)
go func() {
errCh <- s.mux.Serve()
}()
select {
case err := <-errCh:
return err
case <-ctx.Done():
s.mux.Close()
return ctx.Err()
}
}
func main() {
ln, err := net.Listen("tcp", ":8080")
if err != nil {
log.Fatal(err)
}
defer ln.Close()
log.Println("Multi-protocol server listening on :8080")
for {
conn, err := ln.Accept()
if err != nil {
log.Printf("Accept error: %v", err)
continue
}
go func(c net.Conn) {
server, err := NewMultiProtocolServer(c)
if err != nil {
log.Printf("Server creation error: %v", err)
c.Close()
return
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// 信号处理
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
cancel()
}()
if err := server.Run(ctx); err != nil {
log.Printf("Server error: %v", err)
}
}(conn)
}
}
| 特性 | cmux | stream |
|---|---|---|
| 底层连接 | 直接 TCP | Yamux Session |
| 多路复用层 | 单一层级 | 两层(Yamux + Protocol) |
| 适用场景 | 单一端口多协议 | 单个连接多流多协议 |
| 协议嗅探 | ✅ | ✅ |
| HTTP/2 支持 | ✅ | ✅ |
本项目采用 Mozilla Public License 2.0 许可证。
Copyright IBM Corp. 2014, 2025 SPDX-License-Identifier: MPL-2.0