Go猜想录
大道至简,悟者天成
手搓 RPC 框架 -- #7 代码篇

客户端代码

package rpc

import (
	"context"
	"errors"
	"fmt"
	"net"
	"reflect"
	"strconv"
	"time"

	"github.com/luxpo/luxpo/go2nd/micro/rpc2/message"
	"github.com/luxpo/luxpo/go2nd/micro/rpc2/serialize"
	"github.com/luxpo/luxpo/go2nd/micro/rpc2/serialize/json"
	"github.com/silenceper/pool"
)

var (
	// ErrServiceNil is returned by InitService when the service is nil.
	ErrServiceNil = errors.New("rpc: nil service is not supported")

	// ErrServiceWrongType is returned by InitService when the service is not
	// a pointer to a struct.
	ErrServiceWrongType = errors.New("rpc: only first-level pointer to struct is supported")
)

// InitService fills the func-typed fields of service with stubs that send
// remote calls through c, encoding arguments with c's serializer.
func (c *Client) InitService(service Service) error {
	proxy, codec := c, c.serializer
	return setFuncField(service, proxy, codec)
}

// setFuncField replaces every settable func-typed field of service with a
// generated stub that forwards the call to proxy as an RPC request.
//
// service must be a non-nil pointer to a struct. Each settable func field
// must have the shape func(context.Context, Req) (*Resp, error): the Req
// argument is encoded with s, the field name is used as the RPC method name
// and service.Name() as the service name. Unexported (non-settable) fields and
// fields that are not funcs are skipped.
//
// It returns ErrServiceNil for a nil service, ErrServiceWrongType when service
// is not a pointer to a struct, and a descriptive error when a func field has
// an unsupported signature. Fields before the offending one have already been
// replaced when that error is returned.
func setFuncField(service Service, proxy Proxy, s serialize.Serializer) error {
	if service == nil {
		return ErrServiceNil
	}

	val := reflect.ValueOf(service)
	typ := val.Type()
	// Only a first-level pointer to a struct is supported.
	if typ.Kind() != reflect.Pointer || typ.Elem().Kind() != reflect.Struct {
		return ErrServiceWrongType
	}

	val = val.Elem()
	typ = typ.Elem()

	ctxType := reflect.TypeOf((*context.Context)(nil)).Elem()
	errType := reflect.TypeOf((*error)(nil)).Elem()

	for i := 0; i < typ.NumField(); i++ {
		fieldTyp := typ.Field(i)
		fieldVal := val.Field(i)
		if !fieldVal.CanSet() || fieldVal.Kind() != reflect.Func {
			continue
		}

		fnType := fieldTyp.Type
		// Validate up front: the stub below indexes args[0]/args[1] and calls
		// Out(0).Elem(), which would otherwise panic only when the field is called.
		if fnType.NumIn() != 2 || !fnType.In(0).Implements(ctxType) ||
			fnType.NumOut() != 2 || fnType.Out(0).Kind() != reflect.Pointer ||
			!errType.AssignableTo(fnType.Out(1)) {
			return fmt.Errorf("rpc: field %s has unsupported signature %s, want func(context.Context, Req) (*Resp, error)",
				fieldTyp.Name, fnType)
		}

		// Computed once per field instead of on every call.
		methodName := fieldTyp.Name
		retType := fnType.Out(0).Elem()

		fn := func(args []reflect.Value) []reflect.Value {
			ctx := args[0].Interface().(context.Context)
			// Always return a non-nil *Resp, even on failure.
			retVal := reflect.New(retType)
			fail := func(err error) []reflect.Value {
				return []reflect.Value{retVal, reflect.ValueOf(err)}
			}

			reqArg, err := s.Encode(args[1].Interface())
			if err != nil {
				return fail(err)
			}

			meta := make(map[string]string, 2)
			// Propagate the caller's deadline as Unix milliseconds.
			if deadline, ok := ctx.Deadline(); ok {
				meta["deadline"] = strconv.FormatInt(deadline.UnixMilli(), 10)
			}
			if isOneway(ctx) {
				// NOTE(review): this replaces the whole map, so a oneway request is
				// sent without the deadline computed above. The server cancels the
				// request context right after dispatching a oneway call, so sending
				// a deadline could cut the async handler short. Confirm before changing.
				meta = map[string]string{"one-way": "true"}
			}

			req := &message.Request{
				ServiceName: service.Name(),
				MethodName:  methodName,
				Data:        reqArg,
				Serializer:  s.Code(),
				Meta:        meta,
			}
			req.CalculateHeaderLength()
			req.CalculateBodyLength()

			resp, err := proxy.Invoke(ctx, req)
			if err != nil {
				return fail(err)
			}

			if len(resp.Data) > 0 {
				if err := s.Decode(resp.Data, retVal.Interface()); err != nil {
					// A decode failure takes precedence over any server-side error.
					return fail(err)
				}
			}

			if len(resp.Error) > 0 {
				// The server transports its error as plain text.
				return fail(errors.New(string(resp.Error)))
			}
			return []reflect.Value{retVal, reflect.Zero(errType)}
		}
		fieldVal.Set(reflect.MakeFunc(fnType, fn))
	}

	return nil
}

// Client is the caller side of the RPC framework. Its Invoke method has the
// Proxy signature, and InitService uses it to back generated service stubs.
type Client struct {
	pool       pool.Pool            // pooled connections; values are net.Conn
	serializer serialize.Serializer // codec for request arguments and responses
}

// ClientOption customizes a Client while NewClient builds it.
type ClientOption func(c *Client)

// ClientWithSerializer makes the client encode request arguments and decode
// responses with sl instead of the default JSON serializer.
func ClientWithSerializer(sl serialize.Serializer) ClientOption {
	return func(c *Client) { c.serializer = sl }
}

// NewClient creates a Client whose connections to addr (dialed over network
// with a 3-second timeout) are kept in a channel-based pool.
//
// The client starts with the JSON serializer; opts are applied afterwards, in
// order. Any error from creating the pool is returned unchanged.
func NewClient(network, addr string, opts ...ClientOption) (*Client, error) {
	dial := func() (interface{}, error) {
		return net.DialTimeout(network, addr, 3*time.Second)
	}
	closeConn := func(v interface{}) error {
		return v.(net.Conn).Close()
	}

	connPool, err := pool.NewChannelPool(&pool.Config{
		InitialCap:  1,
		MaxCap:      30,
		MaxIdle:     10,
		IdleTimeout: time.Minute,
		Factory:     dial,
		Close:       closeConn,
	})
	if err != nil {
		return nil, err
	}

	client := &Client{pool: connPool, serializer: &json.Serializer{}}
	for _, apply := range opts {
		apply(client)
	}
	return client, nil
}

// Invoke sends req over a pooled connection and waits for either the
// response or the end of ctx, whichever comes first.
//
// If ctx is already done, its error is returned and nothing is sent. If ctx
// ends while the call is in flight, ctx.Err() is returned; the in-flight call
// keeps running in the background until it finishes on its own.
func (c *Client) Invoke(ctx context.Context, req *message.Request) (*message.Response, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	type result struct {
		resp *message.Response
		err  error
	}
	// Buffered and never closed, so the worker can always deliver its result
	// and exit, even after Invoke has already returned on ctx.Done(). Closing
	// the channel on return would make that late send panic.
	done := make(chan result, 1)
	go func() {
		resp, err := c.doInvoke(ctx, req)
		done <- result{resp: resp, err: err}
	}()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case r := <-done:
		return r.resp, r.err
	}
}

// doInvoke encodes req, sends it with Send and decodes the raw reply.
// Errors from Send (including the oneway marker error) are returned as-is.
func (c *Client) doInvoke(ctx context.Context, req *message.Request) (*message.Response, error) {
	raw, err := c.Send(ctx, message.EncodeReq(req))
	if err != nil {
		return nil, err
	}
	resp := message.DecodeResp(raw)
	return resp, nil
}

// Send writes data on a pooled connection and, unless ctx is marked oneway,
// reads back one framed response with ReadMsg.
//
// A connection that fails to write or read is closed and dropped from the
// pool instead of being reused, because its position in the byte stream is
// unknown. Oneway calls return a nil payload with the error "micro: oneway".
func (c *Client) Send(ctx context.Context, data []byte) ([]byte, error) {
	val, err := c.pool.Get()
	if err != nil {
		return nil, err
	}
	conn := val.(net.Conn)

	if _, err = conn.Write(data); err != nil {
		// Best effort: the write error is what the caller needs to see.
		_ = c.pool.Close(val)
		return nil, err
	}

	if isOneway(ctx) {
		// A oneway call does not wait for a reply, so the connection is clean.
		if perr := c.pool.Put(val); perr != nil {
			return nil, perr
		}
		return nil, errors.New("micro: oneway")
	}

	resp, err := ReadMsg(conn)
	if err != nil {
		// Best effort: the read error is what the caller needs to see.
		_ = c.pool.Close(val)
		return nil, err
	}
	if perr := c.pool.Put(val); perr != nil {
		return nil, perr
	}
	return resp, nil
}

服务端代码

package rpc

import (
	"context"
	"errors"
	"log"
	"net"
	"reflect"
	"strconv"
	"time"

	"github.com/luxpo/luxpo/go2nd/micro/rpc2/message"
	"github.com/luxpo/luxpo/go2nd/micro/rpc2/serialize"
	"github.com/luxpo/luxpo/go2nd/micro/rpc2/serialize/json"
)

// numOfLengthBytes is the size of the frame prefix: a 4-byte header length
// followed by a 4-byte body length (see ReadMsg).
const numOfLengthBytes = 8

// Server hosts registered services and answers framed RPC requests.
type Server struct {
	// stubs maps a service name (Service.Name) to its dispatcher.
	stubs       map[string]reflectionStub
	// serializers maps a serializer code (Serializer.Code) to its codec; the
	// same map is shared by reference with every registered stub.
	serializers map[uint8]serialize.Serializer
}

// NewServer returns a Server with no services and the JSON serializer
// already registered.
func NewServer() *Server {
	srv := new(Server)
	srv.stubs = make(map[string]reflectionStub, 16)
	srv.serializers = make(map[uint8]serialize.Serializer, 4)
	srv.RegisterSerializer(&json.Serializer{})
	return srv
}

// RegisterSerializer makes sl available under the code sl.Code(), replacing
// any serializer previously registered with that code. Because stubs share
// the server's serializer map, this also affects services registered earlier.
func (s *Server) RegisterSerializer(sl serialize.Serializer) {
	code := sl.Code()
	s.serializers[code] = sl
}

// RegisterService exposes service under service.Name(); registering another
// service with the same name replaces it. The stub keeps a reference to the
// server's serializer map, so serializers registered later are usable too.
func (s *Server) RegisterService(service Service) {
	name := service.Name()
	stub := reflectionStub{
		service:     service,
		value:       reflect.ValueOf(service),
		serializers: s.serializers,
	}
	s.stubs[name] = stub
}

// Start listens on address over network and serves every accepted connection
// in its own goroutine. It blocks until Accept fails and returns that error.
// The listener is closed when Start returns.
func (s *Server) Start(network, address string) error {
	listener, err := net.Listen(network, address)
	if err != nil {
		return err
	}
	// The Close error is ignored: Start is already returning Accept's error.
	defer listener.Close()

	for {
		conn, err := listener.Accept()
		if err != nil {
			return err
		}
		go func() {
			// handleConn only returns once the connection has failed, so the
			// connection is always released afterwards.
			defer conn.Close()
			_ = s.handleConn(conn)
		}()
	}
}

// handleConn serves requests on conn one at a time until reading or writing
// fails, and returns that error.
//
// For each request it rebuilds a context from the request meta: a
// "deadline" entry (Unix milliseconds) becomes a context deadline, and
// "one-way" == "true" marks the context as oneway. Oneway requests get no
// reply frame, because the client returns right after writing them and never
// reads one. Their dispatch errors are only logged.
func (s *Server) handleConn(conn net.Conn) error {
	for {
		reqBs, err := ReadMsg(conn)
		if err != nil {
			return err
		}

		// Rebuild the call information from the wire bytes.
		req := message.DecodeReq(reqBs)

		ctx := context.Background()
		cancel := func() {}
		if deadlineStr, ok := req.Meta["deadline"]; ok {
			// A malformed deadline is ignored; the call then runs without one.
			if deadline, er := strconv.ParseInt(deadlineStr, 10, 64); er == nil {
				ctx, cancel = context.WithDeadline(ctx, time.UnixMilli(deadline))
			}
		}
		oneway := req.Meta["one-way"] == "true"
		if oneway {
			ctx = CtxWithOneway(ctx)
		}

		resp, err := s.Invoke(ctx, req)
		// NOTE(review): for oneway calls the handler may still be running
		// asynchronously when ctx is canceled here. This only matters if a
		// oneway request carries a deadline; confirm.
		cancel()

		if oneway {
			// A reply here would be read by the next call that reuses this
			// pooled client connection, so nothing is written.
			if err != nil {
				log.Printf("rpc: oneway %s.%s: %v", req.ServiceName, req.MethodName, err)
			}
			continue
		}
		if err != nil {
			// Send the business (or dispatch) error back as text.
			resp.Error = []byte(err.Error())
		}

		resp.CalculateHeaderLength()
		resp.CalculateBodyLength()

		if _, err = conn.Write(message.EncodeResp(resp)); err != nil {
			return err
		}
	}
}

// Invoke dispatches req to the stub registered under req.ServiceName.
//
// The response echoes RequestID, Version, Compressor and Serializer from req.
// If the service is not registered, that response is returned together with
// a "service not available" error, even for oneway calls. For a oneway ctx the
// stub runs in a new goroutine and Invoke returns (nil, nil) immediately;
// errors from that goroutine are logged. Otherwise the encoded result is put
// in resp.Data and the handler's error, if any, is returned alongside it.
func (s *Server) Invoke(ctx context.Context, req *message.Request) (*message.Response, error) {
	stub, ok := s.stubs[req.ServiceName]

	resp := &message.Response{
		RequestID:  req.RequestID,
		Version:    req.Version,
		Compressor: req.Compressor,
		Serializer: req.Serializer,
	}

	if !ok {
		// Returned even for oneway calls; the caller decides what to do with it.
		return resp, errors.New("service not available")
	}

	if isOneway(ctx) {
		go func() {
			// Nobody waits for a oneway result, so failures can only be logged.
			if _, err := stub.invoke(ctx, req); err != nil {
				log.Printf("rpc: oneway %s.%s: %v", req.ServiceName, req.MethodName, err)
			}
		}()
		return nil, nil
	}

	respData, err := stub.invoke(ctx, req)
	resp.Data = respData
	return resp, err
}

// reflectionStub dispatches RPC calls to one registered service by looking
// up and calling its methods through reflection.
type reflectionStub struct {
	// service is the registered implementation.
	service     Service
	// value is reflect.ValueOf(service), used for method lookup.
	value       reflect.Value
	// serializers is the server's codec table, shared by reference.
	serializers map[uint8]serialize.Serializer
}

// invoke decodes req.Data into the argument of the method named
// req.MethodName, calls it with ctx, and encodes its first result with the
// serializer identified by req.Serializer.
//
// Methods must have the shape func(context.Context, *Req) (*Resp, error).
// It returns an error (rather than panicking) when the method does not exist
// or has another shape, when the serializer code is unknown, or when decoding
// or encoding fails. A handler error is returned together with the encoded
// response when the handler also returned a non-nil one.
func (s *reflectionStub) invoke(ctx context.Context, req *message.Request) ([]byte, error) {
	// Find the method by reflection; the name comes from the network, so it
	// may be unknown or refer to a method with an unsupported shape.
	method := s.value.MethodByName(req.MethodName)
	if !method.IsValid() {
		return nil, errors.New("micro: method not found " + req.MethodName)
	}
	mt := method.Type()
	errType := reflect.TypeOf((*error)(nil)).Elem()
	if mt.NumIn() != 2 || mt.In(1).Kind() != reflect.Pointer ||
		mt.NumOut() != 2 || mt.Out(0).Kind() != reflect.Pointer || mt.Out(1) != errType {
		return nil, errors.New("micro: unsupported method signature " + req.MethodName)
	}

	serializer, ok := s.serializers[req.Serializer]
	if !ok {
		return nil, errors.New("micro: not supported serializer " + strconv.FormatUint(uint64(req.Serializer), 10))
	}
	inReq := reflect.New(mt.In(1).Elem())
	if err := serializer.Decode(req.Data, inReq.Interface()); err != nil {
		return nil, err
	}

	// Pass the request context so the deadline and oneway flag rebuilt by the
	// server reach the handler.
	result := method.Call([]reflect.Value{reflect.ValueOf(ctx), inReq})

	var err error
	if e := result[1].Interface(); e != nil {
		err = e.(error)
	}

	if result[0].IsNil() {
		return nil, err
	}
	res, encErr := serializer.Encode(result[0].Interface())
	if encErr != nil {
		return nil, encErr
	}
	return res, err
}

context

package rpc

import "context"

// onewayKey is the private context key under which the oneway flag is stored.
type onewayKey struct{}

// CtxWithOneway returns a copy of ctx marked as a oneway call: the client
// writes the request and returns without reading a response.
func CtxWithOneway(ctx context.Context) context.Context {
	return context.WithValue(ctx, onewayKey{}, true)
}

// isOneway reports whether ctx carries a oneway flag set to the bool true.
func isOneway(ctx context.Context) bool {
	flag, _ := ctx.Value(onewayKey{}).(bool)
	return flag
}

tcp

package rpc

import (
	"encoding/binary"
	"fmt"
	"io"
	"net"
)

// ReadMsg reads one length-prefixed frame from conn and returns the whole
// frame, including its 8-byte prefix.
//
// The prefix holds two big-endian uint32 values: the header length and the
// body length. Their sum is the total frame length, so the header length
// includes the prefix itself. io.ReadFull is used for both reads because a
// single Read on a stream connection may return fewer bytes than requested.
func ReadMsg(conn net.Conn) ([]byte, error) {
	lenBs := make([]byte, numOfLengthBytes)
	if _, err := io.ReadFull(conn, lenBs); err != nil {
		return nil, err
	}

	headerLength := binary.BigEndian.Uint32(lenBs[:4])
	bodyLength := binary.BigEndian.Uint32(lenBs[4:8])
	// Sum in uint64 so two large uint32 values cannot wrap around.
	length := uint64(headerLength) + uint64(bodyLength)
	if length < numOfLengthBytes {
		return nil, fmt.Errorf("rpc: frame length %d is shorter than its %d-byte length prefix", length, numOfLengthBytes)
	}

	data := make([]byte, length)
	copy(data[:numOfLengthBytes], lenBs)
	if _, err := io.ReadFull(conn, data[numOfLengthBytes:]); err != nil {
		return nil, err
	}
	return data, nil
}

types

// Service is implemented by both client-side stubs and server-side
// implementations. Name identifies the service on the wire: the client sends
// it as Request.ServiceName, and the server registers the service under it.
type Service interface {
	Name() string
}

// Proxy sends an encoded RPC request and returns the server's response.
// Client implements it, and the stubs installed by InitService call it.
type Proxy interface {
	Invoke(ctx context.Context, req *message.Request) (*message.Response, error)
}

service


type UserService struct {
	// These fields are assigned via reflection by Client.InitService.
	// A func-typed field is not a method defined on UserService;
	// it is simply a field that holds a function value.
	GetById func(ctx context.Context, req *GetByIdReq) (*GetByIdResp, error)

	// GetByIdProto uses the generated gen types (presumably protobuf, judging
	// by the proto serializer in the usage example; verify).
	GetByIdProto func(ctx context.Context, req *gen.GetByIdReq) (*gen.GetByIdResp, error)
}

// Name returns the service name this client stub sends requests to.
func (UserService) Name() string { return "user-service" }

// GetByIdReq is the request argument of UserService.GetById.
type GetByIdReq struct {
	Id int // identifier to look up
}

// GetByIdResp is the response of UserService.GetById.
type GetByIdResp struct {
	Msg string // message produced by the server
}

// UserServiceServer is the example server-side implementation of
// "user-service". Every method returns Msg and Err as configured.
type UserServiceServer struct {
	Err error  // returned as the error of every call
	Msg string // placed in every response
}

// GetById logs the request and answers with the configured Msg and Err.
func (u *UserServiceServer) GetById(ctx context.Context, req *GetByIdReq) (*GetByIdResp, error) {
	log.Println(req)
	resp := &GetByIdResp{Msg: u.Msg}
	return resp, u.Err
}

// GetByIdProto logs the request and answers with a user named after the
// configured Msg, together with the configured Err.
func (u *UserServiceServer) GetByIdProto(ctx context.Context, req *gen.GetByIdReq) (*gen.GetByIdResp, error) {
	log.Println(req)
	user := &gen.User{Name: u.Msg}
	return &gen.GetByIdResp{User: user}, u.Err
}

// Name returns the name this implementation is registered under; it must
// match the client stub's UserService.Name.
func (u *UserServiceServer) Name() string { return "user-service" }

使用

server := NewServer()
service := &UserServiceServer{}
server.RegisterService(service)
server.RegisterSerializer(&proto.Serializer{})
go func() {
	err := server.Start("tcp", ":8081")
	t.Log(err)
}()
// Give the server time to start listening.
time.Sleep(time.Second * 3)

usClient := &UserService{}
// NewClient takes the network first, then the address, then options.
client, err := NewClient("tcp", ":8081", ClientWithSerializer(&proto.Serializer{}))
require.NoError(t, err)
err = client.InitService(usClient)
require.NoError(t, err)
resp, er := usClient.GetByIdProto(context.Background(), &gen.GetByIdReq{Id: 123})
log.Println(resp, er)

知识共享许可协议

本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。