Go猜想录
大道至简,悟者天成
context 包源码分析

包核心方法

context 包的核心 API 有四个:

  • context. WithValue:设置键值对,并且返回一个新的 context 实例
  • context. WithCancel
  • context. WithDeadline
  • context. WithTimeout:三者都返回一个可取消的 context 实例,和取消函数

WithTimeout

func TestContext(t *testing.T) {
	ctx := context.Background()
	timeoutCtx, cancel := context.WithTimeout(ctx, time.Second)
	defer cancel()

	go handle(timeoutCtx, 500*time.Millisecond)
	// go handle(timeoutCtx, 1500*time.Millisecond)
	select {
	case <-timeoutCtx.Done():
		fmt.Println("main", timeoutCtx.Err())
	}
}

func handle(ctx context.Context, duration time.Duration) {
	select {
	case <-ctx.Done():
		fmt.Println("handle", ctx.Err())
	case <-time.After(duration):
		fmt.Println("process request for", duration)
	}
}

Deadline

func TestContext(t *testing.T) {
	ctx := context.Background()
	timeoutCtx, cancel := context.WithTimeout(ctx, time.Second)
	defer cancel()

	dl, ok := timeoutCtx.Deadline()
	fmt.Println(dl, ok)
}

WithValue

func TestContext(t *testing.T) {
	ctx := context.Background()
	type key string
	var k key = "abc"
	valCtx := context.WithValue(ctx, k, 123)

	val := valCtx.Value(k)
	fmt.Println(val)
}

Context 接口

Context 接口核心 API 有四个:

  • Deadline:返回过期时间,如果 ok 为 false,说明没有设置过期时间。不常用
  • Done:返回一个 channel,一般用于监听 Context 实例的信号,比如说过期,或者正常关闭。常用
  • Err:返回一个错误用于表达 Context 发生了什么。比较常用
    • Canceled => 正常关闭
    • DeadlineExceeded => 过期超时
  • Value:取值。非常常用

安全传递数据

context 包我们就用来做两件事:

  • 安全传递数据
    • 是指在请求执行上下文中线程安全地传递数据,依赖于 WithValue 方法
    • 因为 Go 本身没有 thread-local 机制,所以大部分类似的功能都是借助于 context 来实现的
  • 控制链路

例子:

  • 链路追踪的 trace id
  • AB 测试的标记位
  • 压力测试标记位
  • 分库分表中间件中传递 Sharding hint
  • ORM 中间件传递 SQL hint
  • Web 框架传递上下文

context-1.png

进程内传递就是依赖于 context. Context 传递的,也就是意味着所有的方法都必须有 context. Context 参数。

父子关系

特点:context 的实例之间存在父子关系

  • 当父亲取消或者超时,所有派生的子 context 都被取消或者超时
  • 当找 key 的时候,子 context 先看自己有没有,没有则去祖先里面找

控制是从上至下的,查找是从下至上的。

context-2.png

例子:父 context 取消,子 context 也被取消

context-3.png

func TestParentContext(t *testing.T) {
	ctx := context.Background()
	dlCtx, cancel := context.WithDeadline(ctx, time.Now().Add(time.Minute))
	type key string
	var k key = "abc"
	childCtx := context.WithValue(dlCtx, k, 123)
	cancel()
	err := childCtx.Err()
	fmt.Println(err)
}

父无法访问子内容

  • 父 context 无法拿到子 context 设置的值
  • 子 context 可以拿到父 context 设置的值

context-4.png

func TestContext(t *testing.T) {
	ctx := context.Background()

	parentKey := key("parent")
	subKey := key("sub")

	parent := context.WithValue(ctx, parentKey, "parent val")
	sub := context.WithValue(parent, subKey, "sub val")

	fmt.Println(parent.Value(parentKey))
	fmt.Println(parent.Value(subKey))
	fmt.Println(sub.Value(parentKey))
	fmt.Println(sub.Value(subKey))
}

因为父 context 始终无法拿到子 context 设置的值,所以在逼不得已的时候我们可以在父 context 里面放一个 map,后续都是修改这个 map。

func TestParentValueContext(t *testing.T) {
	ctx := context.Background()

	parentKey := key("parent")
	subKey := key("sub")

	parent := context.WithValue(ctx, parentKey, map[string]string{})
	sub := context.WithValue(parent, subKey, "sub val")

	m := sub.Value(parentKey).(map[string]string)
	m["sub"] = "sub val"

	v1 := parent.Value(parentKey)
	fmt.Println(v1)
	v2 := parent.Value(subKey)
	fmt.Println(v2)
}

// output:
// map[sub:sub val]
// <nil>

valueCtx 实现

valueCtx 用于存储 key-value 数据,特点:

  • 典型的装饰器模式:在已有 Context 的基础上附加一个存储 key-value 的功能
  • 只能存储一个 key, val:为什么不用 map?
    • context 包的设计理念就是将 Context 设计成不可变
// A valueCtx carries a key-value pair. It implements Value for that key and
// delegates all other calls to the embedded Context.
type valueCtx struct {
	Context
	key, val any
}

func (c *valueCtx) Value(key any) any {
	if c.key == key {
		return c.val
	}
	return value(c.Context, key)
}

控制

context 包提供了三个控制方法, WithCancel、WithDeadline 和 WithTimeout。

三者用法大同小异:

  • 没有过期时间,但是又需要在必要的时候取消,使用 WithCancel
  • 在固定时间点过期,使用 WithDeadline
  • 在一段时间后过期,使用 WithTimeout

而后便是监听 Done () 返回的 channel,不管是主动调用 cancel () 还是超时,都能从这个 channel 里面取出来数据。后面可以用 Err () 方法来判断究竟是哪种情况。

父亲可以控制儿子,但是儿子控制不了父亲

func TestParentValueContext(t *testing.T) {
	ctx := context.Background()
	parentCtx, cancel1 := context.WithTimeout(ctx, time.Second)
	subCtx, cancel2 := context.WithTimeout(parentCtx, 3*time.Second)
	go func() {
		<-subCtx.Done()
		fmt.Println("subCtx done:", subCtx.Err())
	}()

	time.Sleep(2 * time.Second)
	cancel2()
	cancel1()
}

// output:
// subCtx done: context deadline exceeded

子 context 试图重新设置超时时间,然而并没有成功,它依旧受到了父亲的控制

控制超时

WithTimeout

最经典用法是利用 context 来控制超时。

控制超时,相当于我们同时监听两个 channel,一个是正常业务结束的 channel,一个是 Done () 返回的。

func TestBusinessTimeout(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	done := make(chan struct{})
	go slowBusiness(done)
	select {
	case <-ctx.Done():
		fmt.Println(ctx.Err())
	case <-done:
		fmt.Println("business end")
	}
}

func slowBusiness(done chan<- struct{}) {
	time.Sleep(2 * time.Second)
	done <- struct{}{}
}

// output:
// context deadline exceeded

time. AfterFunc

另外一种超时控制是采用 time.AfterFunc: 一般这种用法我们会认为是定时任务,而不是超时控制。

这种超时控制有两个弊端:

  • 如果不主动取消,那么 AfterFunc 是必然会执行的
  • 如果主动取消,那么在业务正常结束到主动取消之间,有一个短时间的时间差
func TestTimeoutAfter(t *testing.T) {
	bsChan := make(chan struct{})
	go func() {
		slowBusiness(bsChan)
	}()

	timer := time.AfterFunc(time.Second, func() {
		fmt.Println("timer timeout")
	})
	<-bsChan
	fmt.Println("business end")
	timer.Stop()
}

func slowBusiness(done chan<- struct{}) {
	time.Sleep(2 * time.Second)
	done <- struct{}{}
}

// output:
// timer timeout
// business end

例子:DB.conn 控制超时

首先直接检查一次 context.Context 有没有超时。

这种提前检测一下的用法还是比较常见的。比如说 RPC 链路超时控制就可以先看看 context 有没有超时。

如果超时则可以不发送请求,直接返回超时响应。

// /usr/local/go/src/database/sql/sql.go:1294

// conn returns a newly-opened or cached *driverConn.
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
	db.mu.Lock()
	if db.closed {
		db.mu.Unlock()
		return nil, errDBClosed
	}
	// Check if the context is expired.
	select {
	default:
	case <-ctx.Done():
		db.mu.Unlock()
		return nil, ctx.Err()
	}
	lifetime := db.maxLifetime

	// Prefer a free connection, if possible.
	last := len(db.freeConn) - 1
	if strategy == cachedOrNewConn && last >= 0 {
...

超时控制至少两个分支:

  • 超时分支
  • 正常业务分支

所以普遍来说 context.Context 会和 select - case 一起使用。

// /usr/local/go/src/database/sql/sql.go:1340

// Timeout the connection request with the context.
		select {
		case <-ctx.Done():
			// Remove the connection request and ensure no value has been sent
			// on it after removing.
			db.mu.Lock()
			delete(db.connRequests, reqKey)
			db.mu.Unlock()

			atomic.AddInt64(&db.waitDuration, int64(time.Since(waitStart)))

			select {
			default:
			case ret, ok := <-req:
				if ok && ret.conn != nil {
					db.putConn(ret.conn, ret.err, false)
				}
			}
			return nil, ctx.Err()
		case ret, ok := <-req:

例子:http.Request 使用 context 作为字段类型

  • http.Request 本身就是 request-scope 的。
  • http.Request 里面的 ctx 依旧设计为不可变的,我们只能创建一个新的 http.Request。
Response *Response

	// ctx is either the client or server context. It should only
	// be modified via copying the whole Request using WithContext.
	// It is unexported to prevent people from using Context wrong
	// and mutating the contexts held by callers of the same request.
	ctx context.Context
}

所以实际上我们没有办法修改一个已有的 http.Request 里面的 ctx。

即便我们要把 context.Context 做成字段,也要遵循类似的用法。

// WithContext returns a shallow copy of r with its context changed
// to ctx. The provided ctx must be non-nil.
//
// For outgoing client request, the context controls the entire
// lifetime of a request and its response: obtaining a connection,
// sending the request, and reading the response headers and body.
//
// To create a new request with a context, use NewRequestWithContext.
// To change the context of a request, such as an incoming request you
// want to modify before sending back out, use Request.Clone. Between
// those two uses, it's rare to need WithContext.
func (r *Request) WithContext(ctx context.Context) *Request {
	if ctx == nil {
		panic("nil context")
	}
	r2 := new(Request)
	*r2 = *r
	r2.ctx = ctx
	return r2
}

例子:errgroup.WithContext 利用 context 来传递信号

  • errgroup.WithContext 会返回一个 context.Context 实例
  • 如果 errgroup.Group 的 Wait 返回,或者任何一个 Group 执行的函数返回 error,context.Context 实例都会被取消(一损俱损)
  • 所以用户可以通过监听 ctx.Done() 来判断 errgroup.Group 的执行情况

这是典型的将 context.Context 作为信号载体的用法,本质是依赖于 channel 的特性。

下边是 Kratos 利用这个特性来优雅启动服务实例,并且监听服务实例启动情况的代码片段。

  • 如果 30 行返回,说明是启动有问题,那么其它启动没有问题的 Server 也会退出,确保要么全部成功,要么全部失败
  • 如果 54 行返回,说明监听到了退出信号,比如说 ctrl+C,Server 都会退出。

注意:所有的 Server 调用 Stop 都是创建了一个新的 context,这是因为关闭的时候需要摆脱启动时候的 context 的控制。

// Run executes all OnStart hooks registered with the application's Lifecycle.
func (a *App) Run() error {
	instance, err := a.buildInstance()
	if err != nil {
		return err
	}
	a.mu.Lock()
	a.instance = instance
	a.mu.Unlock()
	sctx := NewContext(a.ctx, a)
	eg, ctx := errgroup.WithContext(sctx)
	wg := sync.WaitGroup{}

	for _, fn := range a.opts.beforeStart {
		if err = fn(sctx); err != nil {
			return err
		}
	}
	for _, srv := range a.opts.servers {
		srv := srv
		eg.Go(func() error {
			<-ctx.Done() // wait for stop signal
			stopCtx, cancel := context.WithTimeout(NewContext(a.opts.ctx, a), a.opts.stopTimeout)
			defer cancel()
			return srv.Stop(stopCtx)
		})
		wg.Add(1)
		eg.Go(func() error {
			wg.Done() // here is to ensure server start has begun running before register, so defer is not needed
			return srv.Start(sctx)
		})
	}
	wg.Wait()
	if a.opts.registrar != nil {
		rctx, rcancel := context.WithTimeout(ctx, a.opts.registrarTimeout)
		defer rcancel()
		if err = a.opts.registrar.Register(rctx, instance); err != nil {
			return err
		}
	}
	for _, fn := range a.opts.afterStart {
		if err = fn(sctx); err != nil {
			return err
		}
	}

	c := make(chan os.Signal, 1)
	signal.Notify(c, a.opts.sigs...)
	eg.Go(func() error {
		select {
		case <-ctx.Done():
			return nil
		case <-c:
			return a.Stop()
		}
	})
	if err = eg.Wait(); err != nil && !errors.Is(err, context.Canceled) {
		return err
	}
	for _, fn := range a.opts.afterStop {
		err = fn(sctx)
	}
	return err
}

开源实例 —— Kratos 的启动过程

这是一个综合的用例:

  1. errgroup + context.Context 协调 server 启动过程,以及关闭
  2. channel 监听系统信号
  3. WaitGroup 协调所有 server 启动
  4. Context 设置超时

作为启动过程要考虑:

  1. 监听系统关闭信号
  2. 监控 server 启动过程。如果有一个启动失败,那么应该全部直接失败,退出
  3. 监控 server 异常退出

cancelCtx 实现

cancelCtx 也是典型的装饰器模式:在已有 Context 的基础上,加上取消的功能。

核心实现:

  • Done 方法是通过类似于 double-check 的机制写的。这种原子操作和锁结合的用法比较罕见。(思考: 能不能换成读写锁?)
  • 利用 children 来维护了所有的衍生节点,难点就在于它是如何维护这个衍生节点。
// A cancelCtx can be canceled. When canceled, it also cancels any children
// that implement canceler.
type cancelCtx struct {
	Context

	mu       sync.Mutex            // protects following fields
	done     atomic.Value          // of chan struct{}, created lazily, closed by first cancel call
	children map[canceler]struct{} // set to nil by the first cancel call
	err      error                 // set to non-nil by the first cancel call
}
func (c *cancelCtx) Done() <-chan struct{} {
	d := c.done.Load()
	if d != nil {
		return d.(chan struct{})
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	d = c.done.Load()
	if d == nil {
		d = make(chan struct{})
		c.done.Store(d)
	}
	return d.(chan struct{})
}

children:核心是儿子把自己加进去父亲的 children 字段里面。

但是因为 Context 里面存在非常多的层级,所以父亲不一定是 cancelCtx,因此本质上是找最近属于 cancelCtx 类型的祖先,然后儿子把自己加进去。

cancel 就是遍历 children,挨个调用 cancel。然后儿子调用孙子的 cancel,子子孙孙无穷匮也。

// propagateCancel arranges for child to be canceled when parent is.
func propagateCancel(parent Context, child canceler) {
	done := parent.Done()
	if done == nil {
		return // parent is never canceled
	}

	select {
	case <-done:
		// parent is already canceled
		child.cancel(false, parent.Err())
		return
	default:
	}

// 找到最近的是 cancelCtx 类型的祖先,然后将 child 加进去祖先的 children 里面
	if p, ok := parentCancelCtx(parent); ok {
		p.mu.Lock()
		if p.err != nil {
			// parent has already been canceled
			child.cancel(false, p.err)
		} else {
			if p.children == nil {
				p.children = make(map[canceler]struct{})
			}
			p.children[child] = struct{}{}
		}
		p.mu.Unlock()
	} else {
// 找不到就只需要监听到 parent 的信号,或者自己的信号。这些信号源自 cancel 或者超时
		atomic.AddInt32(&goroutines, +1)
		go func() {
			select {
			case <-parent.Done():
				child.cancel(false, parent.Err())
			case <-child.Done():
			}
		}()
	}
}
// cancel closes c.done, cancels each of c's children, and, if
// removeFromParent is true, removes c from its parent's children.
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
	if err == nil {
		panic("context: internal error: missing cancel error")
	}
	c.mu.Lock()
	if c.err != nil {
		c.mu.Unlock()
		return // already canceled
	}
	c.err = err
	d, _ := c.done.Load().(chan struct{})
	if d == nil {
		c.done.Store(closedchan)
	} else {
		close(d)
	}
	for child := range c.children {
		// NOTE: acquiring the child's lock while holding parent's lock.
		child.cancel(false, err)
	}
	c.children = nil
	c.mu.Unlock()

	if removeFromParent {
		removeChild(c.Context, c)
	}
}

timerCtx 实现

timerCtx 也是装饰器模式:在已有 cancelCtx 的基础上增加了超时的功能。

实现要点:

  • WithTimeout 和 WithDeadline 本质一样
  • WithDeadline 里面,在创建 timerCtx 的时候利用 time.AfterFunc 来实现超时
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
// implement Done and Err. It implements cancel by stopping its timer then
// delegating to cancelCtx.cancel.
type timerCtx struct {
	cancelCtx
	timer *time.Timer // Under cancelCtx.mu.

	deadline time.Time
}
}
// cancel 关系依旧要建立起来
	propagateCancel(parent, c)
	dur := time.Until(d)
	if dur <= 0 {
		c.cancel(true, DeadlineExceeded) // deadline has already passed
		return c, func() { c.cancel(false, Canceled) }
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.err == nil {
		c.timer = time.AfterFunc(dur, func() {
// 超时就 cancel
			c.cancel(true, DeadlineExceeded)
		})
	}
	return c, func() { c.cancel(true, Canceled) }
}

使用注意事项

  • 一般只用做方法参数,而且是作为第一个参数
  • 所有公共方法,除非是 util,helper 之类的方法,否则都加上 context 参数
  • 不要用作结构体字段,除非你的结构体本身也是表达一个上下文的概念

总结

  • context.Context 使用场景:上下文传递和超时控制
  • context.Context 原理:
    • 父亲如何控制儿子:通过儿子主动加入到父亲的 children 里面,父亲只需要遍历就可以
    • valueCtx 和 timeCtx 的原理

知识共享许可协议

本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。