手搓 RPC 框架 -- #7 代码篇
客户端代码
package rpc
import (
"context"
"errors"
"fmt"
"net"
"reflect"
"strconv"
"time"
"github.com/luxpo/luxpo/go2nd/micro/rpc2/message"
"github.com/luxpo/luxpo/go2nd/micro/rpc2/serialize"
"github.com/luxpo/luxpo/go2nd/micro/rpc2/serialize/json"
"github.com/silenceper/pool"
)
// Errors reported by Client.InitService when a service value cannot be
// turned into a client-side proxy.
var (
	// ErrServiceNil is returned when the service argument is nil.
	ErrServiceNil = errors.New("rpc: nil service is not supported")
	// ErrServiceWrongType is returned when the service is not a pointer to a struct.
	ErrServiceWrongType = errors.New("rpc: only first-level pointer to struct is supported")
)
// InitService replaces every settable func-typed field of service with a stub
// that performs the call remotely through c, encoding payloads with the
// client's serializer. service must be a non-nil pointer to a struct.
func (c *Client) InitService(service Service) error {
	var proxy Proxy = c
	return setFuncField(service, proxy, c.serializer)
}
// setFuncField fills every settable func-typed field of service with a stub
// built by reflect.MakeFunc. When called, a stub:
//   - encodes its second argument with s,
//   - builds a message.Request addressed to service.Name() / the field name,
//     carrying the ctx deadline (Unix milliseconds) and the one-way marker
//     in Meta,
//   - sends it through proxy.Invoke,
//   - decodes resp.Data into a freshly allocated result and turns a non-empty
//     resp.Error into an error.
//
// service must be a non-nil pointer to a struct; otherwise ErrServiceNil or
// ErrServiceWrongType is returned.
//
// NOTE(review): field signatures are assumed to be
// func(context.Context, *Req) (*Resp, error). They are not validated here,
// and other shapes panic when the stub is called.
func setFuncField(service Service, proxy Proxy, s serialize.Serializer) error {
	if service == nil {
		return ErrServiceNil
	}
	val := reflect.ValueOf(service)
	typ := val.Type()
	// Only a first-level pointer to a struct is supported: the struct's fields
	// must be addressable so they can be overwritten below.
	if typ.Kind() != reflect.Pointer || typ.Elem().Kind() != reflect.Struct {
		return ErrServiceWrongType
	}
	val = val.Elem()
	typ = typ.Elem()

	// Zero value of the error interface, used as the error result on success.
	noErr := reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())

	for i := 0; i < typ.NumField(); i++ {
		fieldTyp := typ.Field(i)
		fieldVal := val.Field(i)
		if !fieldVal.CanSet() || fieldVal.Kind() != reflect.Func {
			continue
		}
		// The response type is fixed per field. Resolve it once here instead of
		// on every call. Out(0) is assumed to be a pointer type.
		respTyp := fieldTyp.Type.Out(0).Elem()
		methodName := fieldTyp.Name

		stub := func(args []reflect.Value) []reflect.Value {
			ctx := args[0].Interface().(context.Context)
			retVal := reflect.New(respTyp)
			fail := func(err error) []reflect.Value {
				return []reflect.Value{retVal, reflect.ValueOf(err)}
			}

			reqArg, err := s.Encode(args[1].Interface())
			if err != nil {
				return fail(fmt.Errorf("rpc: encode request of %s.%s: %w", service.Name(), methodName, err))
			}

			meta := make(map[string]string, 2)
			// Forward the caller's timeout so the server can honour it.
			if deadline, ok := ctx.Deadline(); ok {
				meta["deadline"] = strconv.FormatInt(deadline.UnixMilli(), 10)
			}
			// Add the marker next to the deadline. Replacing the whole map here
			// would drop the deadline of a one-way call.
			if isOneway(ctx) {
				meta["one-way"] = "true"
			}

			req := &message.Request{
				ServiceName: service.Name(),
				MethodName:  methodName,
				Data:        reqArg,
				Serializer:  s.Code(),
				Meta:        meta,
			}
			req.CalculateHeaderLength()
			req.CalculateBodyLength()

			resp, err := proxy.Invoke(ctx, req)
			if err != nil {
				return fail(err)
			}

			// A decode failure takes precedence over an error sent by the server.
			if len(resp.Data) > 0 {
				if err := s.Decode(resp.Data, retVal.Interface()); err != nil {
					return fail(fmt.Errorf("rpc: decode response of %s.%s: %w", service.Name(), methodName, err))
				}
			}
			// An error reported by the server travels as plain text in resp.Error.
			if len(resp.Error) > 0 {
				return []reflect.Value{retVal, reflect.ValueOf(errors.New(string(resp.Error)))}
			}
			return []reflect.Value{retVal, noErr}
		}
		fieldVal.Set(reflect.MakeFunc(fieldTyp.Type, stub))
	}
	return nil
}
// Client is the RPC client. It sends requests over pooled connections and
// encodes/decodes payloads with its serializer.
type Client struct {
	// pool holds the connections to the server (net.Conn values, created by NewClient).
	pool pool.Pool
	// serializer encodes request payloads and decodes response payloads.
	serializer serialize.Serializer
}
// ClientOption customises a Client while NewClient builds it.
type ClientOption func(client *Client)

// ClientWithSerializer makes the client use sl instead of the default
// serializer for request and response payloads.
func ClientWithSerializer(sl serialize.Serializer) ClientOption {
	opt := func(c *Client) { c.serializer = sl }
	return opt
}
// NewClient creates a Client that talks to addr over network (for example
// "tcp"). It uses a channel-based pool of connections, each dialed with a
// 3-second timeout. The payload serializer defaults to JSON and can be
// replaced via opts. An error is returned if the pool cannot be created.
func NewClient(network, addr string, opts ...ClientOption) (*Client, error) {
	dial := func() (interface{}, error) {
		return net.DialTimeout(network, addr, 3*time.Second)
	}
	closeConn := func(v interface{}) error {
		return v.(net.Conn).Close()
	}
	cfg := &pool.Config{
		InitialCap:  1,
		MaxCap:      30,
		MaxIdle:     10,
		IdleTimeout: time.Minute,
		Factory:     dial,
		Close:       closeConn,
	}
	p, err := pool.NewChannelPool(cfg)
	if err != nil {
		return nil, err
	}

	c := &Client{pool: p, serializer: &json.Serializer{}}
	for _, apply := range opts {
		apply(c)
	}
	return c, nil
}
// Invoke sends req to the server and waits for its response.
//
// It returns ctx.Err() immediately if ctx is already done, or as soon as ctx
// is done while waiting. In the latter case the call keeps running in the
// background until doInvoke returns, and its result is discarded.
func (c *Client) Invoke(ctx context.Context, req *message.Request) (*message.Response, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	type result struct {
		resp *message.Response
		err  error
	}
	// Buffered with room for the single result, so the worker can always
	// deliver it and exit, even after Invoke has given up on ctx. The channel
	// is deliberately never closed. Closing it on return would make a late
	// send from the worker panic.
	done := make(chan result, 1)
	go func() {
		resp, err := c.doInvoke(ctx, req)
		done <- result{resp: resp, err: err}
	}()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case r := <-done:
		return r.resp, r.err
	}
}
// doInvoke serialises req into its wire form, ships it with Send and decodes
// the returned frame into a Response.
func (c *Client) doInvoke(ctx context.Context, req *message.Request) (*message.Response, error) {
	raw, err := c.Send(ctx, message.EncodeReq(req))
	if err == nil {
		return message.DecodeResp(raw), nil
	}
	return nil, err
}
// Send writes the encoded request data to a pooled connection. For a one-way
// ctx it returns a nil payload with the error "micro: oneway" once the write
// succeeds. Otherwise it reads and returns one framed response via ReadMsg.
//
// If ctx has a deadline, it is applied to the connection's I/O, so a silent
// peer cannot block the call forever. A connection that hit an I/O error is
// closed instead of being returned to the pool.
//
// NOTE(review): cancellation of a ctx without a deadline is not propagated to
// the socket.
func (c *Client) Send(ctx context.Context, data []byte) ([]byte, error) {
	val, err := c.pool.Get()
	if err != nil {
		return nil, err
	}
	conn := val.(net.Conn)

	// release gives conn back to the pool. It closes conn instead when broken
	// is true or the per-call deadline cannot be cleared. After a failed read
	// or write the stream may sit mid-frame, and reusing it would corrupt the
	// next caller's reply. Put/Close errors are ignored: there is nothing
	// useful left to do with them here.
	release := func(broken bool) {
		if !broken && conn.SetDeadline(time.Time{}) == nil {
			_ = c.pool.Put(val)
			return
		}
		_ = c.pool.Close(val)
	}

	if deadline, ok := ctx.Deadline(); ok {
		if err := conn.SetDeadline(deadline); err != nil {
			release(true)
			return nil, err
		}
	}

	if _, err := conn.Write(data); err != nil {
		release(true)
		return nil, err
	}
	if isOneway(ctx) {
		// One-way calls return right after the write and never read a reply.
		release(false)
		return nil, errors.New("micro: oneway")
	}

	resp, err := ReadMsg(conn)
	if err != nil {
		release(true)
		return nil, err
	}
	release(false)
	return resp, nil
}
服务端代码
package rpc
import (
"context"
"errors"
"log"
"net"
"reflect"
"strconv"
"time"
"github.com/luxpo/luxpo/go2nd/micro/rpc2/message"
"github.com/luxpo/luxpo/go2nd/micro/rpc2/serialize"
"github.com/luxpo/luxpo/go2nd/micro/rpc2/serialize/json"
)
// numOfLengthBytes is the size of the length prefix at the start of every
// frame: a 4-byte header length followed by a 4-byte body length.
const numOfLengthBytes = 4 + 4
// Server dispatches incoming RPC requests to registered services.
type Server struct {
	// stubs maps a service name (Service.Name) to the stub that invokes it.
	stubs map[string]reflectionStub
	// serializers maps a serializer code (Serializer.Code) to its implementation.
	serializers map[uint8]serialize.Serializer
}
// NewServer returns an empty Server with the JSON serializer pre-registered.
// Other serializers can be added with RegisterSerializer.
func NewServer() *Server {
	srv := &Server{}
	srv.stubs = make(map[string]reflectionStub, 16)
	srv.serializers = make(map[uint8]serialize.Serializer, 4)
	srv.RegisterSerializer(&json.Serializer{})
	return srv
}
// RegisterSerializer makes sl available to requests whose Serializer field
// equals sl.Code(). It replaces any serializer already registered under that code.
func (s *Server) RegisterSerializer(sl serialize.Serializer) {
	code := sl.Code()
	s.serializers[code] = sl
}
// RegisterService exposes service under service.Name(). The stub shares the
// server's serializer map, so serializers registered afterwards also apply.
func (s *Server) RegisterService(service Service) {
	stub := reflectionStub{
		service:     service,
		value:       reflect.ValueOf(service),
		serializers: s.serializers,
	}
	s.stubs[service.Name()] = stub
}
// Start listens on network/address and serves each accepted connection in
// its own goroutine. It blocks until Accept fails and returns that error.
// The listener is closed before Start returns.
func (s *Server) Start(network, address string) error {
	listener, err := net.Listen(network, address)
	if err != nil {
		return err
	}
	// Release the listening socket once we stop accepting. The close error
	// is irrelevant here because the Accept error is what gets reported.
	defer func() { _ = listener.Close() }()

	for {
		conn, err := listener.Accept()
		if err != nil {
			return err
		}
		go func() {
			// handleConn only returns once it stops serving conn, so the
			// connection is closed whatever the outcome.
			_ = s.handleConn(conn)
			_ = conn.Close()
		}()
	}
}
// handleConn serves requests arriving on conn one at a time until reading or
// writing fails, and returns that error. Each request gets a context carrying
// the client's deadline (see newRequestContext) and is dispatched through
// s.Invoke. A handler error is sent back in resp.Error; it does not end the
// connection.
//
// One-way requests (Meta["one-way"] == "true") run in the background and get
// no reply. The client does not read one, so a reply would stay on the
// connection and be mistaken for the response to the next call.
func (s *Server) handleConn(conn net.Conn) error {
	for {
		reqBs, err := ReadMsg(conn)
		if err != nil {
			return err
		}
		req := message.DecodeReq(reqBs)
		ctx, cancel := newRequestContext(req)

		if req.Meta["one-way"] == "true" {
			go func() {
				// Release ctx only after the handler finished. Cancelling it when
				// dispatch returns would abort a still-running handler.
				defer cancel()
				// ctx is deliberately not marked with CtxWithOneway: Invoke would
				// then spawn yet another goroutine and return at once, and the
				// deferred cancel would fire while the handler is still running.
				// The outcome is discarded because nobody reads it.
				_, _ = s.Invoke(ctx, req)
			}()
			continue
		}

		resp, err := s.Invoke(ctx, req)
		cancel()
		if err != nil {
			// Business error: forwarded to the client inside the response.
			resp.Error = []byte(err.Error())
		}
		resp.CalculateHeaderLength()
		resp.CalculateBodyLength()
		if _, err = conn.Write(message.EncodeResp(resp)); err != nil {
			return err
		}
	}
}

// newRequestContext builds the context for req. When Meta["deadline"] holds a
// Unix-millisecond timestamp, the context expires at that instant. A
// malformed value is logged and ignored. The returned cancel must always be
// called.
func newRequestContext(req *message.Request) (context.Context, context.CancelFunc) {
	deadlineStr, ok := req.Meta["deadline"]
	if !ok {
		return context.WithCancel(context.Background())
	}
	ms, err := strconv.ParseInt(deadlineStr, 10, 64)
	if err != nil {
		log.Printf("rpc: ignoring malformed deadline %q: %v", deadlineStr, err)
		return context.WithCancel(context.Background())
	}
	return context.WithDeadline(context.Background(), time.UnixMilli(ms))
}
// Invoke routes req to the stub registered under req.ServiceName. The
// returned Response echoes the request's ID, version, compressor and
// serializer.
//
// An unknown service yields "service not available", even for one-way
// calls. For a one-way ctx the stub runs in a new goroutine with its result
// discarded, and Invoke returns (nil, nil) at once, so callers must not
// dereference the response then. Otherwise resp.Data holds the encoded result
// and the method's error is returned alongside it.
func (s *Server) Invoke(ctx context.Context, req *message.Request) (*message.Response, error) {
	resp := &message.Response{
		RequestID:  req.RequestID,
		Version:    req.Version,
		Compressor: req.Compressor,
		Serializer: req.Serializer,
	}
	stub, found := s.stubs[req.ServiceName]
	if !found {
		return resp, errors.New("service not available")
	}
	if isOneway(ctx) {
		// Fire and forget: nobody waits for the outcome of a one-way call.
		go func() {
			_, _ = stub.invoke(ctx, req)
		}()
		return nil, nil
	}
	data, err := stub.invoke(ctx, req)
	resp.Data = data
	return resp, err
}
// reflectionStub calls methods of one registered service by name via reflection.
type reflectionStub struct {
	// service is the registered implementation.
	service Service
	// value is reflect.ValueOf(service); methods are looked up on it.
	value reflect.Value
	// serializers is the Server's map (shared, not copied), keyed by serializer code.
	serializers map[uint8]serialize.Serializer
}
// invoke performs one call against the stub's service:
//   - it looks up the method named req.MethodName,
//   - decodes req.Data with the serializer selected by req.Serializer into a
//     new value of the method's second parameter type,
//   - calls the method with ctx and that value,
//   - encodes the first result (when non-nil) with the same serializer.
//
// The method's own error is returned next to the encoded data. Lookup,
// decode and encode failures are returned with nil data.
//
// NOTE(review): methods are assumed to have the signature
// func(context.Context, *Req) (*Resp, error); this is not validated.
func (s *reflectionStub) invoke(ctx context.Context, req *message.Request) ([]byte, error) {
	method := s.value.MethodByName(req.MethodName)
	if !method.IsValid() {
		// Without this check method.Type() below panics on an unknown name.
		return nil, errors.New("micro: method not found " + req.MethodName)
	}
	serializer, ok := s.serializers[req.Serializer]
	if !ok {
		return nil, errors.New("micro: not supported serializer " + strconv.FormatUint(uint64(req.Serializer), 10))
	}

	inReq := reflect.New(method.Type().In(1).Elem())
	if err := serializer.Decode(req.Data, inReq.Interface()); err != nil {
		return nil, err
	}

	// Pass the request context through so the handler sees the client's
	// deadline and cancellation.
	results := method.Call([]reflect.Value{reflect.ValueOf(ctx), inReq})

	var callErr error
	if e, ok := results[1].Interface().(error); ok {
		callErr = e
	}
	if results[0].IsNil() {
		return nil, callErr
	}
	data, err := serializer.Encode(results[0].Interface())
	if err != nil {
		return nil, err
	}
	return data, callErr
}
context
package rpc
import "context"
// onewayKey is the unexported context key marking a call as one-way.
type onewayKey struct{}

// CtxWithOneway returns a child of ctx that marks the call as one-way. The
// client then sends the request without waiting for a reply.
func CtxWithOneway(ctx context.Context) context.Context {
	var key onewayKey
	return context.WithValue(ctx, key, true)
}

// isOneway reports whether ctx carries a true one-way marker.
func isOneway(ctx context.Context) bool {
	flag, _ := ctx.Value(onewayKey{}).(bool)
	return flag
}
tcp
package rpc
import (
	"encoding/binary"
	"fmt"
	"io"
	"net"
)
// ReadMsg reads one length-prefixed frame from conn and returns the complete
// frame, including its numOfLengthBytes-byte prefix.
//
// The prefix holds two big-endian uint32 values: the header length and the
// body length. Their sum is the total frame size, so the header length is
// expected to include the prefix itself. That is why the prefix is copied
// back to the front of the returned slice.
//
// Both reads use io.ReadFull. A single conn.Read on a stream may return fewer
// bytes than requested, which would otherwise yield a truncated frame.
//
// NOTE(review): the frame size comes straight from the peer and is not
// capped. Consider enforcing a maximum if the peer is untrusted.
func ReadMsg(conn net.Conn) ([]byte, error) {
	lenBs := make([]byte, numOfLengthBytes)
	if _, err := io.ReadFull(conn, lenBs); err != nil {
		return nil, err
	}
	headerLength := binary.BigEndian.Uint32(lenBs[:4])
	bodyLength := binary.BigEndian.Uint32(lenBs[4:8])
	// Sum in uint64 so two large uint32 values cannot wrap around.
	length := uint64(headerLength) + uint64(bodyLength)
	if length < numOfLengthBytes {
		return nil, fmt.Errorf("rpc: invalid frame length %d, shorter than the %d-byte length prefix", length, numOfLengthBytes)
	}

	data := make([]byte, length)
	copy(data, lenBs)
	if _, err := io.ReadFull(conn, data[numOfLengthBytes:]); err != nil {
		return nil, err
	}
	return data, nil
}
types
// Service is implemented by both sides of an RPC service. Name is the
// routing key: the client sends it as Request.ServiceName, and the server
// registers its stub under it.
type Service interface {
	Name() string
}
// Proxy is the transport behind generated client stubs: it delivers a request
// and returns the matching response. Client implements it.
type Proxy interface {
	Invoke(ctx context.Context, req *message.Request) (*message.Response, error)
}
service
// UserService is the client-side view of "user-service".
// Its fields are assigned via reflection (Client.InitService). They have
// func type but are not methods defined on UserService; each is essentially
// a field, and the field name is sent as the remote method name.
type UserService struct {
	// GetById calls UserServiceServer.GetById remotely.
	GetById func(ctx context.Context, req *GetByIdReq) (*GetByIdResp, error)
	// GetByIdProto is the same call using the gen types (presumably protobuf-generated).
	GetByIdProto func(ctx context.Context, req *gen.GetByIdReq) (*gen.GetByIdResp, error)
}

// Name returns the service name; it must match UserServiceServer.Name for routing.
func (u UserService) Name() string {
	const serviceName = "user-service"
	return serviceName
}
// GetByIdReq is the request payload of UserService.GetById.
type GetByIdReq struct {
	// Id is the identifier to look up.
	Id int
}

// GetByIdResp is the response payload of UserService.GetById.
type GetByIdResp struct {
	// Msg is the message returned by the server.
	Msg string
}
// UserServiceServer is the server-side implementation of "user-service".
// Each method logs the request and answers with Msg, returning Err as its error.
type UserServiceServer struct {
	Err error
	Msg string
}

// GetById logs req and returns a response carrying u.Msg, together with u.Err.
func (u *UserServiceServer) GetById(ctx context.Context, req *GetByIdReq) (*GetByIdResp, error) {
	log.Println(req)
	resp := &GetByIdResp{Msg: u.Msg}
	return resp, u.Err
}

// GetByIdProto is GetById for the gen types: u.Msg becomes the user's Name.
func (u *UserServiceServer) GetByIdProto(ctx context.Context, req *gen.GetByIdReq) (*gen.GetByIdResp, error) {
	log.Println(req)
	user := &gen.User{Name: u.Msg}
	return &gen.GetByIdResp{User: user}, u.Err
}

// Name returns "user-service", matching UserService.Name so requests route here.
func (u *UserServiceServer) Name() string {
	return "user-service"
}
使用
// Server side: expose UserServiceServer. The proto serializer is registered
// after the service; that works because the stub shares the server's
// serializer map.
server := NewServer()
service := &UserServiceServer{}
server.RegisterService(service)
server.RegisterSerializer(&proto.Serializer{})
go func() {
	err := server.Start("tcp", ":8081")
	t.Log(err)
}()
// Give the server time to start listening before the client dials.
time.Sleep(time.Second * 3)

// Client side: NewClient takes the network first, then the address.
usClient := &UserService{}
client, err := NewClient("tcp", ":8081", ClientWithSerializer(&proto.Serializer{}))
require.NoError(t, err)
err = client.InitService(usClient)
require.NoError(t, err)
resp, er := usClient.GetByIdProto(context.Background(), &gen.GetByIdReq{Id: 123})
log.Println(resp, er)
本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。