diff --git a/Ipacker.go b/Ipacker.go index d845599..01d98e3 100644 --- a/Ipacker.go +++ b/Ipacker.go @@ -1,8 +1,7 @@ package network -import "io" - type IPacker interface { - Pack(message *Message) ([]byte, error) - Unpack(reader io.Reader) (*Message, error) + Pack(msgID uint16, msg interface{}) ([]byte, error) + Read(*TcpConnX) ([]byte, error) + Unpack([]byte) (*Message, error) } diff --git a/buffer_packer.go b/buffer_packer.go new file mode 100644 index 0000000..1e837ac --- /dev/null +++ b/buffer_packer.go @@ -0,0 +1,180 @@ +package network + +import ( + "encoding/binary" + "errors" + "fmt" + "github.com/golang/protobuf/proto" + "io" + "math" +) + +type BufferPacker struct { + lenMsgLen int32 + minMsgLen uint32 + maxMsgLen uint32 + recvBuff *ByteBuffer + sendBuff *ByteBuffer + byteOrder binary.ByteOrder +} + +func newInActionPacker() *BufferPacker { + msgParser := &BufferPacker{ + lenMsgLen: 4, + minMsgLen: 2, + maxMsgLen: 2 * 1024 * 1024, + recvBuff: NewByteBuffer(), + sendBuff: NewByteBuffer(), + byteOrder: binary.LittleEndian, + } + return msgParser +} + +// SetMsgLen It's dangerous to call the method on reading or writing +func (p *BufferPacker) SetMsgLen(lenMsgLen int32, minMsgLen uint32, maxMsgLen uint32) { + if lenMsgLen == 1 || lenMsgLen == 2 || lenMsgLen == 4 { + p.lenMsgLen = lenMsgLen + } + if minMsgLen != 0 { + p.minMsgLen = minMsgLen + } + if maxMsgLen != 0 { + p.maxMsgLen = maxMsgLen + } + + var max uint32 + switch p.lenMsgLen { + case 1: + max = math.MaxUint8 + case 2: + max = math.MaxUint16 + case 4: + max = math.MaxUint32 + } + if p.minMsgLen > max { + p.minMsgLen = max + } + if p.maxMsgLen > max { + p.maxMsgLen = max + } +} + +// Read goroutine safe +func (p *BufferPacker) Read(conn *TcpConnX) ([]byte, error) { + + p.recvBuff.EnsureWritableBytes(p.lenMsgLen) + + readLen, err := io.ReadFull(conn, p.recvBuff.WriteBuff()[:p.lenMsgLen]) + // read len + if err != nil { + return nil, fmt.Errorf("%v readLen:%v", err, readLen) + } + p.recvBuff.WriteBytes(int32(readLen)) + + // parse len + var msgLen uint32 + switch p.lenMsgLen { + case 2: + msgLen = uint32(p.recvBuff.ReadInt16()) + case 4: + msgLen = uint32(p.recvBuff.ReadInt32()) + } + + // check len + if msgLen > p.maxMsgLen { + return nil, errors.New("message too long") + } else if msgLen < p.minMsgLen { + return nil, errors.New("message too short") + } + + p.recvBuff.EnsureWritableBytes(int32(msgLen)) + + rLen, err := io.ReadFull(conn, p.recvBuff.WriteBuff()[:msgLen]) + if err != nil { + return nil, fmt.Errorf("%v msgLen:%v readLen:%v", err, msgLen, rLen) + } + p.recvBuff.WriteBytes(int32(rLen)) + + /* + // 保留了2字节flag 暂时未处理 + var flag uint16 + flag = uint16(p.recvBuff.ReadInt16()) + */ + p.recvBuff.Skip(2) // 跳过2字节保留字段 + + // 减去2字节的保留字段长度 + return p.recvBuff.NextBytes(int32(msgLen - 2)), nil + +} + +// goroutine safe +func (p *BufferPacker) Write(conn *TcpConnX, buff ...byte) error { + // get len + msgLen := uint32(len(buff)) + + // check len + if msgLen > p.maxMsgLen { + return errors.New("message too long") + } else if msgLen < p.minMsgLen { + return errors.New("message too short") + } + + // write len + switch p.lenMsgLen { + case 2: + p.sendBuff.AppendInt16(int16(msgLen)) + case 4: + p.sendBuff.AppendInt32(int32(msgLen)) + } + + p.sendBuff.Append(buff) + // write data + writeBuff := p.sendBuff.ReadBuff()[:p.sendBuff.Length()] + + _, err := conn.Write(writeBuff) + + p.sendBuff.Reset() + + return err +} + +func (p *BufferPacker) reset() { + p.recvBuff = NewByteBuffer() + p.sendBuff = NewByteBuffer() +} + +func (p *BufferPacker) Pack(msgID uint16, msg interface{}) ([]byte, error) { + pbMsg, ok := msg.(proto.Message) + if !ok { + return []byte{}, fmt.Errorf("msg is not protobuf message") + } + // data + data, err := proto.Marshal(pbMsg) + if err != nil { + return data, err + } + // 4byte = len(flag)[2byte] + len(msgID)[2byte] + buf := make([]byte, 4+len(data)) + if p.byteOrder == binary.LittleEndian { + binary.LittleEndian.PutUint16(buf[0:2], 0) + binary.LittleEndian.PutUint16(buf[2:], msgID) + } else { + binary.BigEndian.PutUint16(buf[0:2], 0) + binary.BigEndian.PutUint16(buf[2:], msgID) + } + copy(buf[4:], data) + return buf, err +} + +// Unpack id + protobuf data +func (p *BufferPacker) Unpack(data []byte) (*Message, error) { + if len(data) < 2 { + return nil, errors.New("protobuf data too short") + } + msgID := p.byteOrder.Uint16(data[:2]) + msg := &Message{ + ID: uint64(msgID), + Data: data[2:], + } + return msg, nil +} diff --git a/byte_buffer.go b/byte_buffer.go new file mode 100644 index 0000000..cdf17fc --- /dev/null +++ b/byte_buffer.go @@ -0,0 +1,230 @@ +package network + +import ( + "encoding/binary" +) + +const ( + cheapPrependSize = 8 + initialSize = 1024 +) + +// ByteBuffer 字节buff +type ByteBuffer struct { + mBuffer []byte + mCapacity int32 + readIndex int32 + writeIndex int32 + reservedPrependSize int32 + littleEndian bool +} + +// NewByteBuffer 创建一个字节buffer +func NewByteBuffer() *ByteBuffer { + return &ByteBuffer{ + mBuffer: make([]byte, cheapPrependSize+initialSize), + mCapacity: cheapPrependSize + initialSize, + readIndex: cheapPrependSize, + writeIndex: cheapPrependSize, + reservedPrependSize: cheapPrependSize, + littleEndian: true, + } +} + +// SetByteOrder It's dangerous to call the method on reading or writing +func (bf *ByteBuffer) SetByteOrder(littleEndian bool) { + bf.littleEndian = littleEndian +} + +// Length ... +func (bf *ByteBuffer) Length() int32 { + return bf.writeIndex - bf.readIndex +} + +// Swap ... +func (bf *ByteBuffer) Swap(other *ByteBuffer) { +} + +// Skip advances the reading index of the buffer +func (bf *ByteBuffer) Skip(len int32) { + if len < bf.Length() { + bf.readIndex = bf.readIndex + len + } else { + bf.Reset() + } +} + +// Retrieve ... +func (bf *ByteBuffer) Retrieve(len int32) { + bf.Skip(len) +} + +// Reset ... +func (bf *ByteBuffer) Reset() { + bf.Truncate(0) +} + +// Truncate ... +func (bf *ByteBuffer) Truncate(n int32) { + if n == 0 { + bf.readIndex = bf.reservedPrependSize + bf.writeIndex = bf.reservedPrependSize + } else if bf.writeIndex > (bf.readIndex + n) { + bf.writeIndex = bf.readIndex + n + } +} + +// Reserve ... +func (bf *ByteBuffer) Reserve(len int32) { + if bf.mCapacity >= len+bf.reservedPrependSize { + return + } + bf.grow(len + bf.reservedPrependSize) +} + +// Append ... +func (bf *ByteBuffer) Append(buff []byte) { + size := len(buff) + if size == 0 { + return + } + bf.write(buff, int32(size)) +} + +// AppendInt64 ... +func (bf *ByteBuffer) AppendInt64(x int64) { + buff := make([]byte, 8) + if bf.littleEndian { + binary.LittleEndian.PutUint64(buff, uint64(x)) + } else { + binary.BigEndian.PutUint64(buff, uint64(x)) + } + bf.write(buff, 8) +} + +// AppendInt32 ... +func (bf *ByteBuffer) AppendInt32(x int32) { + buff := make([]byte, 4) + if bf.littleEndian { + binary.LittleEndian.PutUint32(buff, uint32(x)) + } else { + binary.BigEndian.PutUint32(buff, uint32(x)) + } + bf.write(buff, 4) +} + +// AppendInt16 ... +func (bf *ByteBuffer) AppendInt16(x int16) { + buff := make([]byte, 2) + if bf.littleEndian { + binary.LittleEndian.PutUint16(buff, uint16(x)) + } else { + binary.BigEndian.PutUint16(buff, uint16(x)) + } + bf.write(buff, 2) +} + +// ReadInt64 ... +func (bf *ByteBuffer) ReadInt64() int64 { + buff := bf.mBuffer[bf.readIndex : bf.readIndex+8] + var result uint64 + if bf.littleEndian { + result = binary.LittleEndian.Uint64(buff) + } else { + result = binary.BigEndian.Uint64(buff) + } + bf.Skip(8) + return int64(result) +} + +// ReadInt32 ... +func (bf *ByteBuffer) ReadInt32() int32 { + buff := bf.mBuffer[bf.readIndex : bf.readIndex+4] + var result uint32 + if bf.littleEndian { + result = binary.LittleEndian.Uint32(buff) + } else { + result = binary.BigEndian.Uint32(buff) + } + bf.Skip(4) + return int32(result) +} + +// ReadInt16 ... +func (bf *ByteBuffer) ReadInt16() int16 { + buff := bf.mBuffer[bf.readIndex : bf.readIndex+2] + var result uint16 + if bf.littleEndian { + result = binary.LittleEndian.Uint16(buff) + } else { + result = binary.BigEndian.Uint16(buff) + } + bf.Skip(2) + return int16(result) +} + +// NextBytes 读取N字节 +func (bf *ByteBuffer) NextBytes(len int32) []byte { + msgData := bf.mBuffer[bf.readIndex : bf.readIndex+len] + bf.readIndex += len + return msgData +} + +// EnsureWritableBytes ... +func (bf *ByteBuffer) EnsureWritableBytes(len int32) { + if bf.writableBytes() < len { + bf.grow(len) + } +} + +func (bf *ByteBuffer) grow(len int32) { + if bf.writableBytes()+bf.prependableBytes() < len+bf.reservedPrependSize { + newCap := (bf.mCapacity << 1) + len + buff := make([]byte, newCap) + copy(buff, bf.mBuffer) + bf.mCapacity = newCap + bf.mBuffer = buff + } else { + readable := bf.Length() + copy(bf.mBuffer[bf.reservedPrependSize:], bf.mBuffer[bf.readIndex:bf.writeIndex]) + bf.readIndex = bf.reservedPrependSize + bf.writeIndex = bf.readIndex + readable + } +} + +func (bf *ByteBuffer) write(buff []byte, len int32) { + bf.EnsureWritableBytes(len) + copy(bf.mBuffer[bf.writeIndex:], buff) + bf.writeIndex = bf.writeIndex + len +} + +func (bf *ByteBuffer) writableBytes() int32 { + return bf.mCapacity - bf.writeIndex +} + +func (bf *ByteBuffer) prependableBytes() int32 { + return bf.readIndex +} + +// WriteBuff ... +func (bf *ByteBuffer) WriteBuff() []byte { + buffLen := int32(len(bf.mBuffer)) + if bf.writeIndex >= buffLen { + return nil + } + return bf.mBuffer[bf.writeIndex:] +} + +// ReadBuff ... +func (bf *ByteBuffer) ReadBuff() []byte { + buffLen := int32(len(bf.mBuffer)) + if bf.readIndex >= buffLen { + return nil + } + return bf.mBuffer[bf.readIndex:] +} + +// WriteBytes 写入n字节 +func (bf *ByteBuffer) WriteBytes(n int32) { + bf.writeIndex += n +} diff --git a/client.go b/client.go index ac91821..b32acd9 100644 --- a/client.go +++ b/client.go @@ -1,68 +1,97 @@ package network import ( - "encoding/binary" - "fmt" + "github.com/phuhao00/spoor" "net" + "runtime/debug" + "sync/atomic" ) type Client struct { - Address string - packer IPacker - ChMsg chan *Message - OnMessage func(message *ClientPacket) + *TcpConnX + Address string + ChMsg chan *Message + OnMessageCb func(message *Packet) + logger *spoor.Spoor + bufferSize int + running atomic.Value + OnCloseCallBack func() + closed int32 } -func NewClient(address string) *Client { - return &Client{ - Address: address, - packer: &NormalPacker{ - ByteOrder: binary.BigEndian, - }, - ChMsg: make(chan *Message, 1), +func NewClient(address string, connBuffSize int, logger *spoor.Spoor) *Client { + client := &Client{ + bufferSize: connBuffSize, + Address: address, + logger: logger, + TcpConnX: nil, } + client.running.Store(false) + return client } -func (c *Client) Run() { - conn, err := net.Dial("tcp6", c.Address) +func (c *Client) Dial() (*net.TCPConn, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp4", c.Address) + if err != nil { - fmt.Println(err) - return + return nil, err } - go c.Read(conn) - go c.Write(conn) -} + conn, err := net.DialTCP("tcp6", nil, tcpAddr) -func (c *Client) Write(conn net.Conn) { - for { - select { - case msg := <-c.ChMsg: - c.Send(conn, msg) - } + if err != nil { + return nil, err } + + return conn, nil } -func (c *Client) Send(conn net.Conn, message *Message) { - pack, err := c.packer.Pack(message) +func (c *Client) Run() { + conn, err := c.Dial() if err != nil { - //fmt.Println(err) + c.logger.ErrorF("%v", err) return } - conn.Write(pack) + tcpConnX, err := NewTcpConnX(conn, c.bufferSize, c.logger) + if err != nil { + c.logger.ErrorF("%v", err) + return + } + c.TcpConnX = tcpConnX + c.Impl = c + c.Reset() + c.running.Store(true) + go c.Connect() } -func (c *Client) Read(conn net.Conn) { - for { - message, err := c.packer.Unpack(conn) - if err != nil { - //fmt.Println(err) - continue +func (c *Client) OnClose() { + if c.OnCloseCallBack != nil { + c.OnCloseCallBack() + } + c.running.Store(false) + c.TcpConnX.OnClose() +} + +func (c *Client) OnMessage(data *Message, conn *TcpConnX) { + + c.Verify() + + defer func() { + if err := recover(); err != nil { + c.logger.ErrorF("[OnMessage] panic ", err, "\n", string(debug.Stack())) } - c.OnMessage(&ClientPacket{ - Msg: message, - Conn: conn, - }) - fmt.Println("read msg:", string(message.Data)) + }() + + c.OnMessageCb(&Packet{ + Msg: data, + Conn: conn, + }) +} + +// Close 关闭连接 +func (c *Client) Close() { + if atomic.CompareAndSwapInt32(&c.closed, 0, 1) { + c.Conn.Close() + close(c.stopped) } } diff --git a/example/client/mian.go b/example/client/mian.go new file mode 100644 index 0000000..b6693f1 --- /dev/null +++ b/example/client/mian.go @@ -0,0 +1,37 @@ +package main + +import ( + "fmt" + "github.com/phuhao00/greatestworks-proto/gen/player" + "github.com/phuhao00/network" + "github.com/phuhao00/network/example/logger" + "time" +) + +func main() { + client := network.NewClient(":8023", 200, logger.Logger) + client.OnMessageCb = OnClientMessage + client.Run() + go Tick(client) + select {} +} + +func OnClientMessage(packet *network.Packet) { + if packet.Msg.ID == 1 { + fmt.Println("hello world") + } + time.Sleep(time.Second) + packet.Conn.AsyncSend(3, &player.ChatMessage{ + Content: "abd", + Extra: nil, + }) +} + +func Tick(c *network.Client) { + + c.TcpConnX.AsyncSend(3, &player.ChatMessage{ + Content: "abd", + Extra: nil, + }) + +} diff --git a/example/logger/log.go b/example/logger/log.go new file mode 100644 index 0000000..a2720e3 --- /dev/null +++ b/example/logger/log.go @@ -0,0 +1,21 @@ +package logger + +import ( + "log" + "sync" + + "github.com/phuhao00/spoor" +) + +var ( + Logger *spoor.Spoor + onceInitLogger sync.Once +) + +func init() { + onceInitLogger.Do(func() { + fileWriter := spoor.NewFileWriter("log", 0, 0, 0) + l := spoor.NewSpoor(spoor.DEBUG, "", log.Ldate|log.Ltime|log.Lmicroseconds|log.Llongfile, spoor.WithFileWriter(fileWriter)) + Logger = l + }) +} diff --git a/example/server/main.go b/example/server/main.go new file mode 100644 index 0000000..c807f85 --- /dev/null +++ b/example/server/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "fmt" + "github.com/phuhao00/greatestworks-proto/gen/player" + "github.com/phuhao00/network" + "github.com/phuhao00/network/example/logger" + "time" +) + +func main() { + server := network.NewServer(":8023", 100, 200, logger.Logger) + server.MessageHandler = OnMessage + go server.Run() + select {} +} + +func OnMessage(packet *network.Packet) { + time.Sleep(time.Second) + + if packet.Msg.ID == 3 { + fmt.Println("hello hao") + } + packet.Conn.AsyncSend( + 1, + &player.SCSendChatMsg{}, + ) +} diff --git a/go.mod b/go.mod index 2674044..8c7c120 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,11 @@ module github.com/phuhao00/network go 1.18 + +require ( + github.com/golang/protobuf v1.5.2 + github.com/phuhao00/greatestworks-proto v1.0.1 + github.com/phuhao00/spoor v1.0.2 +) + +require google.golang.org/protobuf v1.28.1 // indirect diff --git a/normal_packer.go b/normal_packer.go new file mode 100644 index 0000000..78b33dc --- /dev/null +++ b/normal_packer.go @@ -0,0 +1,64 @@ +package network + +import ( + "encoding/binary" + "errors" + "fmt" + "github.com/golang/protobuf/proto" + "io" + "net" + "time" +) + +type NormalPacker struct { + ByteOrder binary.ByteOrder +} + +func (p *NormalPacker) Pack(msgID uint16, msg interface{}) ([]byte, error) { + pbMsg, ok := msg.(proto.Message) + if !ok { + return []byte{}, fmt.Errorf("msg is not protobuf message") + } + data, err := proto.Marshal(pbMsg) + if err != nil { + return data, err + } + buffer := make([]byte, 8+8+len(data)) + p.ByteOrder.PutUint64(buffer[0:8], uint64(len(buffer))) + p.ByteOrder.PutUint64(buffer[8:16], uint64(msgID)) + copy(buffer[16:], data) + return buffer, nil +} + +func (p *NormalPacker) Read(conn *TcpConnX) ([]byte, error) { + err := conn.Conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Second * 10)) + if err != nil { + return nil, err + } + buffer := make([]byte, 8+8) + + if _, err := io.ReadFull(conn.Conn, buffer); err != nil { + return nil, err + } + totalSize := p.ByteOrder.Uint64(buffer[:8]) + dataSize := totalSize - 8 - 8 + data := make([]byte, 8+dataSize) + copy(data[:8], buffer[8:]) + if _, err := io.ReadFull(conn.Conn, data[8:]); err != nil { + return nil, err + } + return data, nil +} + +// Unpack id + protobuf data +func (p *NormalPacker) Unpack(data []byte) (*Message, error) { + if len(data) < 2 { + return nil, errors.New("protobuf data too short") + } + msgID := p.ByteOrder.Uint16(data[:2]) + msg := &Message{ + ID: uint64(msgID), + Data: data[2:], + } + return msg, nil +} diff --git a/packer.go b/packer.go deleted file mode 100644 index a83af5f..0000000 --- a/packer.go +++ /dev/null @@ -1,45 +0,0 @@ -package network - -import ( - "encoding/binary" - "io" - "net" - "time" -) - -type NormalPacker struct { - ByteOrder binary.ByteOrder -} - -func (p *NormalPacker) Pack(message *Message) ([]byte, error) { - buffer := make([]byte, 8+8+len(message.Data)) - p.ByteOrder.PutUint64(buffer[0:8], uint64(len(buffer))) - p.ByteOrder.PutUint64(buffer[8:16], message.ID) - copy(buffer[16:], message.Data) - return buffer, nil -} - -func (p *NormalPacker) Unpack(reader io.Reader) (*Message, error) { - err := reader.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Second * 10)) - if err != nil { - return nil, err - } - buffer := make([]byte, 8+8) - - if _, err := io.ReadFull(reader, buffer); err != nil { - return nil, err - } - totalSize := p.ByteOrder.Uint64(buffer[:8]) - Id := p.ByteOrder.Uint64(buffer[8:]) - dataSize := totalSize - 8 - 8 - data := make([]byte, dataSize) - - if _, err := io.ReadFull(reader, data); err != nil { - return nil, err - } - msg := &Message{ - ID: Id, - Data: data, - } - return msg, nil -} diff --git a/packet.go b/packet.go index ef4e459..14e5bf7 100644 --- a/packet.go +++ b/packet.go @@ -1,13 +1,6 @@ package network -import "net" - -type ClientPacket struct { - Msg *Message - Conn net.Conn -} - -type SessionPacket struct { +type Packet struct { Msg *Message - Sess *Session + Conn *TcpConnX } diff --git a/server.go b/server.go index fb4b878..59e81f7 100644 --- a/server.go +++ b/server.go @@ -1,47 +1,165 @@ package network import ( - "fmt" + "github.com/phuhao00/spoor" "net" + "os" + "runtime/debug" + "sync" + "sync/atomic" + "time" ) type Server struct { - tcpListener net.Listener - OnSessionPacket func(packet *SessionPacket) - Address string + pid int64 + Addr string + MaxConnNum int + ln *net.TCPListener + connSet map[net.Conn]interface{} + counter int64 + idCounter int64 + mutexConn sync.Mutex + wgLn sync.WaitGroup + wgConn sync.WaitGroup + connBuffSize int + logger *spoor.Spoor + MessageHandler func(packet *Packet) } -func NewServer(address string) *Server { - - s := &Server{Address: address} - +func NewServer(addr string, maxConnNum int, buffSize int, logger *spoor.Spoor) *Server { + s := &Server{ + Addr: addr, + MaxConnNum: maxConnNum, + connBuffSize: buffSize, + logger: logger, + } + s.Init() return s - } -func (s *Server) Run() { - resolveTCPAddr, err := net.ResolveTCPAddr("tcp6", s.Address) +func (s *Server) Init() { + tcpAddr, err := net.ResolveTCPAddr("tcp4", s.Addr) + if err != nil { - panic(err) + s.logger.FatalF("[net] addr resolve error", tcpAddr, err) } - tcpListener, err := net.ListenTCP("tcp6", resolveTCPAddr) + + ln, err := net.ListenTCP("tcp6", tcpAddr) + if err != nil { - panic(err) + s.logger.FatalF("%v", err) + } + + if s.MaxConnNum <= 0 { + s.MaxConnNum = 100 + s.logger.InfoF("invalid MaxConnNum, reset to %v", s.MaxConnNum) } - s.tcpListener = tcpListener + + s.ln = ln + s.connSet = make(map[net.Conn]interface{}) + s.counter = 1 + s.idCounter = 1 + s.pid = int64(os.Getpid()) + s.logger.InfoF("Server Listen %s", s.ln.Addr().String()) +} + +func (s *Server) Run() { + defer func() { + if err := recover(); err != nil { + s.logger.ErrorF("[net] panic", err, "\n", string(debug.Stack())) + } + }() + + s.wgLn.Add(1) + defer s.wgLn.Done() + + var tempDelay time.Duration for { - conn, err := s.tcpListener.Accept() + conn, err := s.ln.AcceptTCP() + if err != nil { if _, ok := err.(net.Error); ok { - fmt.Println(err) + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + s.logger.InfoF("accept error: %v; retrying in %v", err, tempDelay) + time.Sleep(tempDelay) continue } + return } + tempDelay = 0 - newSession := NewSession(conn) - newSession.MessageHandler = s.OnSessionPacket - SessionMgrInstance.AddSession(newSession) - newSession.Run() - SessionMgrInstance.DelSession(newSession.UId) + if atomic.LoadInt64(&s.counter) >= int64(s.MaxConnNum) { + conn.Close() + s.logger.InfoF("too many connections %v", atomic.LoadInt64(&s.counter)) + continue + } + tcpConnX, err := NewTcpConnX(conn, s.connBuffSize, s.logger) + if err != nil { + s.logger.ErrorF("%v", err) + return + } + s.addConn(conn, tcpConnX) + tcpConnX.Impl = s + s.wgConn.Add(1) + go func() { + tcpConnX.Connect() + s.removeConn(conn, tcpConnX) + s.wgConn.Done() + }() } } + +func (s *Server) Close() { + s.ln.Close() + s.wgLn.Wait() + + s.mutexConn.Lock() + for conn := range s.connSet { + conn.Close() + } + s.connSet = nil + s.mutexConn.Unlock() + s.wgConn.Wait() +} + +func (s *Server) addConn(conn net.Conn, tcpConnX *TcpConnX) { + s.mutexConn.Lock() + atomic.AddInt64(&s.counter, 1) + s.connSet[conn] = conn + nowTime := time.Now().Unix() + idCounter := atomic.AddInt64(&s.idCounter, 1) + connId := (nowTime << 32) | (s.pid << 24) | idCounter + tcpConnX.ConnID = connId + s.mutexConn.Unlock() + tcpConnX.OnConnect() +} + +func (s *Server) removeConn(conn net.Conn, tcpConn *TcpConnX) { + tcpConn.Close() + s.mutexConn.Lock() + atomic.AddInt64(&s.counter, -1) + delete(s.connSet, conn) + s.mutexConn.Unlock() +} + +func (s *Server) OnMessage(message *Message, conn *TcpConnX) { + s.MessageHandler(&Packet{ + Msg: message, + Conn: conn, + }) +} + +func (s *Server) OnClose() { + +} + +func (s *Server) OnConnect() { + +} diff --git a/session.go b/session.go deleted file mode 100644 index 35ed047..0000000 --- a/session.go +++ /dev/null @@ -1,76 +0,0 @@ -package network - -import ( - "encoding/binary" - "fmt" - "net" - "time" -) - -type Session struct { - UId uint64 - Conn net.Conn - IsClose bool - packer IPacker - WriteCh chan *Message - IsPlayerOnline bool - MessageHandler func(packet *SessionPacket) - // -} - -func NewSession(conn net.Conn) *Session { - return &Session{Conn: conn, packer: &NormalPacker{ByteOrder: binary.BigEndian}, WriteCh: make(chan *Message, 10)} -} - -func (s *Session) Run() { - go s.Read() - go s.Write() - -} - -func (s *Session) Read() { - for { - err := s.Conn.SetReadDeadline(time.Now().Add(time.Second)) - if err != nil { - fmt.Println(err) - continue - } - message, err := s.packer.Unpack(s.Conn) - if _, ok := err.(net.Error); ok { - continue - } - fmt.Println("receive message:", string(message.Data)) - s.MessageHandler(&SessionPacket{ - Msg: message, - Sess: s, - }) - } -} - -func (s *Session) Write() { - for { - select { - case resp := <-s.WriteCh: - s.send(resp) - } - } -} - -func (s *Session) send(message *Message) { - err := s.Conn.SetWriteDeadline(time.Now().Add(time.Second)) - if err != nil { - fmt.Println(err) - return - } - bytes, err := s.packer.Pack(message) - if err != nil { - fmt.Println(err) - return - } - s.Conn.Write(bytes) - -} - -func (s *Session) SendMsg(msg *Message) { - s.WriteCh <- msg -} diff --git a/session_mgr.go b/session_mgr.go deleted file mode 100644 index 13e4047..0000000 --- a/session_mgr.go +++ /dev/null @@ -1,43 +0,0 @@ -package network - -import "sync" - -type SessionMgr struct { - Sessions map[uint64]*Session - Counter int64 //计数器 - Mutex sync.Mutex - Pid int64 -} - -var ( - SessionMgrInstance SessionMgr - onceInitSessionMgr sync.Once -) - -func init() { - onceInitSessionMgr.Do(func() { - SessionMgrInstance = SessionMgr{ - Sessions: make(map[uint64]*Session), - Counter: 0, - Mutex: sync.Mutex{}, - } - }) -} - -//AddSession ... -func (sm *SessionMgr) AddSession(s *Session) { - sm.Mutex.Lock() - defer sm.Mutex.Unlock() - if val := sm.Sessions[s.UId]; val != nil { - if val.IsClose { - sm.Sessions[s.UId] = s - } else { - return - } - } -} - -//DelSession ... -func (sm *SessionMgr) DelSession(UId uint64) { - delete(sm.Sessions, UId) -} diff --git a/tcp_connx.go b/tcp_connx.go new file mode 100644 index 0000000..0a3aaef --- /dev/null +++ b/tcp_connx.go @@ -0,0 +1,320 @@ +package network + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "reflect" + "runtime/debug" + "sync" + "sync/atomic" + "time" + + "github.com/phuhao00/spoor" +) + +type IConn interface { + OnConnect() + OnClose() + OnMessage(*Message, *TcpConnX) +} + +const timeoutTime = 30 // 连接通过验证的超时时间 + +type TcpConnX struct { + Conn net.Conn + Impl IConn + ConnID int64 + verify int32 + closed int32 + stopped chan bool + signal chan interface{} + lastSignal chan interface{} + wgRW sync.WaitGroup + msgParser *BufferPacker + msgBuffSize int + logger *spoor.Spoor +} + +func NewTcpConnX(conn *net.TCPConn, msgBuffSize int, logger *spoor.Spoor) (*TcpConnX, error) { + tcpConn := &TcpConnX{ + closed: -1, + verify: 0, + stopped: make(chan bool, 1), + signal: make(chan interface{}, 100), + lastSignal: make(chan interface{}, 1), + Conn: conn, + msgBuffSize: msgBuffSize, + msgParser: newInActionPacker(), + logger: logger, + } + // Try to open keepalive for tcp. + err := conn.SetKeepAlive(true) + if err != nil { + return nil, err + } + err = conn.SetKeepAlivePeriod(1 * time.Minute) + if err != nil { + return nil, err + } + // disable Nagle algorithm. + err = conn.SetNoDelay(true) + if err != nil { + return nil, err + } + err = conn.SetWriteBuffer(msgBuffSize) + if err != nil { + return nil, err + } + err = conn.SetReadBuffer(msgBuffSize) + if err != nil { + return nil, err + } + return tcpConn, nil +} + +func (c *TcpConnX) Connect() { + if atomic.CompareAndSwapInt32(&c.closed, -1, 0) { + c.wgRW.Add(1) + go c.HandleRead() + c.wgRW.Add(1) + go c.HandleWrite() + } + timeout := time.NewTimer(time.Second * timeoutTime) +L: + for { + select { + // 等待通到返回 返回后检查连接是否验证完成 如果没有验证 则关闭连接 + case <-timeout.C: + if !c.Verified() { + c.logger.ErrorF("[Connect] 验证超时 ip addr %s", c.RemoteAddr()) + c.Close() + break L + } + case <-c.stopped: + break L + } + } + timeout.Stop() + c.wgRW.Wait() + c.Impl.OnClose() +} + +func (c *TcpConnX) HandleRead() { + defer func() { + if err := recover(); err != nil { + c.logger.ErrorF("[HandleRead] panic ", err, "\n", string(debug.Stack())) + } + }() + defer c.Close() + + defer c.wgRW.Done() + + for { + data, err := c.msgParser.Read(c) + if err != nil { + if err != io.EOF { + c.logger.ErrorF("read message error: %v", err) + } + break + } + message, err := c.msgParser.Unpack(data) + c.Impl.OnMessage(message, c) + } +} + +func (c *TcpConnX) HandleWrite() { + defer func() { + if err := recover(); err != nil { + c.logger.ErrorF("[HandleWrite] panic", err, "\n", string(debug.Stack())) + } + }() + defer c.Close() + defer c.wgRW.Done() + for { + select { + case signal := <-c.signal: // 普通消息 + data, ok := signal.([]byte) + if !ok { + c.logger.ErrorF("write message %v error: msg is not bytes", reflect.TypeOf(signal)) + return + } + err := c.msgParser.Write(c, data...) + if err != nil { + c.logger.ErrorF("write message %v error: %v", reflect.TypeOf(signal), err) + return + } + case signal := <-c.lastSignal: // 最后一个通知消息 + data, ok := signal.([]byte) + if !ok { + c.logger.ErrorF("write message %v error: msg is not bytes", reflect.TypeOf(signal)) + return + } + err := c.msgParser.Write(c, data...) + if err != nil { + c.logger.ErrorF("write message %v error: %v", reflect.TypeOf(signal), err) + return + } + time.Sleep(2 * time.Second) + return + case <-c.stopped: // 连接关闭通知 + return + } + } +} + +func (c *TcpConnX) AsyncSend(msgID uint16, msg interface{}) bool { + + if c.IsShutdown() { + return false + } + + data, err := c.msgParser.Pack(msgID, msg) + if err != nil { + c.logger.ErrorF("[AsyncSend] Pack msgID:%v and msg to bytes error:%v", msgID, err) + return false + } + + if uint32(len(data)) > c.msgParser.maxMsgLen { + c.logger.ErrorF("[AsyncSend] 发送的消息包体过长 msgID:%v", msgID) + return false + } + + err = c.Signal(data) + if err != nil { + c.Close() + c.logger.ErrorF("%v", err) + return false + } + + return true +} + +func (c *TcpConnX) AsyncSendRowMsg(data []byte) bool { + + if c.IsShutdown() { + return false + } + + if uint32(len(data)) > c.msgParser.maxMsgLen { + c.logger.ErrorF("[AsyncSendRowMsg] 发送的消息包体过长 AsyncSendRowMsg") + return false + } + + err := c.Signal(data) + if err != nil { + c.Close() + c.logger.ErrorF("%v", err) + return false + } + + return true +} + +// AsyncSendLastPacket 缓存在发送队列里等待发送goroutine取出 (发送最后一个消息 发送会关闭tcp连接 终止tcp goroutine) +func (c *TcpConnX) AsyncSendLastPacket(msgID uint16, msg interface{}) bool { + data, err := c.msgParser.Pack(msgID, msg) + if err != nil { + c.logger.ErrorF("[AsyncSendLastPacket] Pack msgID:%v and msg to bytes error:%v", msgID, err) + return false + } + + if uint32(len(data)) > c.msgParser.maxMsgLen { + c.logger.ErrorF("[AsyncSendLastPacket] 发送的消息包体过长 msgID:%v", msgID) + return false + } + + err = c.LastSignal(data) + if err != nil { + c.Close() + c.logger.ErrorF("%v", err) + return false + } + + return true +} + +func (c *TcpConnX) Signal(signal []byte) error { + select { + case c.signal <- signal: + return nil + default: + { + cmd := binary.LittleEndian.Uint16(signal[2:4]) + return fmt.Errorf("[Signal] buffer full blocking connID:%v cmd:%v", c.ConnID, cmd) + } + } +} + +func (c *TcpConnX) LastSignal(signal []byte) error { + select { + case c.lastSignal <- signal: + return nil + default: + { + cmd := binary.LittleEndian.Uint16(signal[2:4]) + return fmt.Errorf("[LastSignal] buffer full blocking connID:%v cmd:%v", c.ConnID, cmd) + } + } +} + +func (c *TcpConnX) Verified() bool { + return atomic.LoadInt32(&c.verify) != 0 +} + +func (c *TcpConnX) Verify() { + atomic.CompareAndSwapInt32(&c.verify, 0, 1) +} + +func (c *TcpConnX) IsClosed() bool { + return atomic.LoadInt32(&c.closed) != 0 +} + +func (c *TcpConnX) IsShutdown() bool { + return atomic.LoadInt32(&c.closed) == 1 +} + +func (c *TcpConnX) Close() { + if atomic.CompareAndSwapInt32(&c.closed, 0, 1) { + c.Conn.Close() + close(c.stopped) + } +} + +func (c *TcpConnX) Read(b []byte) (int, error) { + return c.Conn.Read(b) +} + +func (c *TcpConnX) Write(b []byte) (int, error) { + return c.Conn.Write(b) +} + +func (c *TcpConnX) LocalAddr() net.Addr { + return c.Conn.LocalAddr() +} + +func (c *TcpConnX) RemoteAddr() net.Addr { + return c.Conn.RemoteAddr() +} + +func (c *TcpConnX) Reset() { + if atomic.LoadInt32(&c.closed) == -1 { + return + } + c.closed = -1 + c.verify = 0 + c.stopped = make(chan bool, 1) + c.signal = make(chan interface{}, c.msgBuffSize) + c.lastSignal = make(chan interface{}, 1) + c.msgParser.reset() +} + +// OnConnect ... +func (c *TcpConnX) OnConnect() { + c.logger.DebugF("[OnConnect] 建立连接 local:%s remote:%s", c.LocalAddr(), c.RemoteAddr()) +} + +func (c *TcpConnX) OnClose() { + c.logger.InfoF("[OnConnect] 断开连接 local:%s remote:%s", c.LocalAddr(), c.RemoteAddr()) +}