diff --git a/udp/README.md b/udp/README.md index eae70b78..86959cb5 100644 --- a/udp/README.md +++ b/udp/README.md @@ -7,7 +7,6 @@ package main import ( "log" - "net" "github.com/go-kratos/kratos/v2" @@ -17,11 +16,11 @@ import ( func main() { err := kratos.New( kratos.Server( - udp.NewServer(":12190", udp.WithHandler(func(conn net.PacketConn, buf []byte, addr net.Addr) { - log.Println(string(buf)) - }), udp.WithRecoveryHandler(func(conn net.PacketConn, buf []byte, addr net.Addr, err interface{}) { + udp.NewServer(":12190", udp.WithHandler(func(msg *udp.Message) { + log.Printf("receive message: %s", msg.Body) + }), udp.WithRecoveryHandler(func(msg *udp.Message, err interface{}) { log.Println(err) - })), + }), udp.WithReadChanSize(10240)), ), ).Run() diff --git a/udp/server.go b/udp/server.go index c5467fcd..3ccc7e68 100644 --- a/udp/server.go +++ b/udp/server.go @@ -4,8 +4,15 @@ import ( "context" "log" "net" + "sync" ) +type Message struct { + Conn net.PacketConn + Addr net.Addr + Body []byte +} + type Server struct { address string @@ -13,9 +20,15 @@ type Server struct { conn net.PacketConn - handler func(conn net.PacketConn, buf []byte, addr net.Addr) + handler func(message *Message) + + recoveryHandler func(message *Message, err interface{}) - recoveryHandler func(conn net.PacketConn, buf []byte, addr net.Addr, err interface{}) + readChan chan *Message + readChanSize int // readChan size + + stoped chan struct{} + stopedOnce sync.Once } type Option func(*Server) @@ -28,7 +41,7 @@ func WithBufSize(bufSize int) Option { } } -func WithHandler(handler func(conn net.PacketConn, buf []byte, addr net.Addr)) Option { +func WithHandler(handler func(message *Message)) Option { return func(s *Server) { if handler != nil { s.handler = handler @@ -36,7 +49,7 @@ func WithHandler(handler func(conn net.PacketConn, buf []byte, addr net.Addr)) O } } -func WithRecoveryHandler(handler func(conn net.PacketConn, buf []byte, addr net.Addr, err interface{})) Option { +func WithRecoveryHandler(handler func(message *Message, err interface{})) Option { return func(s *Server) { if handler != nil { s.recoveryHandler = handler @@ -44,16 +57,28 @@ func WithRecoveryHandler(handler func(conn net.PacketConn, buf []byte, addr net. } } +func WithReadChanSize(readChanSize int) Option { + return func(s *Server) { + if readChanSize > 0 { + s.readChanSize = readChanSize + } + } +} + func NewServer(address string, opts ...Option) *Server { s := &Server{ - address: address, - bufSize: 1024, + address: address, + bufSize: 1024, + readChanSize: 1024, + stoped: make(chan struct{}), } for _, opt := range opts { opt(s) } + s.readChan = make(chan *Message, s.readChanSize) + return s } @@ -65,38 +90,65 @@ func (s *Server) Start(ctx context.Context) (err error) { log.Printf("udp server: listening on %s\n", s.address) + go s.start() + buf := make([]byte, s.bufSize) for { n, addr, err := s.conn.ReadFrom(buf) if err != nil { + s.stop() return err } - if s.handler == nil { - log.Printf("udp server: receive from %s: %s\n", addr.String(), string(buf)) - continue + s.readChan <- &Message{ + Conn: s.conn, + Addr: addr, + Body: buf[:n], } - - go s.handle(buf[:n], addr) } } -func (s *Server) handle(buf []byte, addr net.Addr) { +func (s *Server) start() { + for { + select { + case <-s.stoped: + return + case message := <-s.readChan: + if s.handler != nil { + s.handle(message) + } + } + } +} + +func (s *Server) handle(message *Message) { if s.recoveryHandler != nil { defer func() { if err := recover(); err != nil { - s.recoveryHandler(s.conn, buf, addr, err) + s.recoveryHandler(message, err) } }() } - s.handler(s.conn, buf, addr) + s.handler(message) } func (s *Server) Stop(ctx context.Context) error { log.Println("udp server: stopping") + s.stop() + + if s.conn == nil { + return nil + } + return s.conn.Close() } + +func (s *Server) stop() { + s.stopedOnce.Do(func() { + close(s.stoped) + }) +} diff --git a/udp/server_test.go b/udp/server_test.go index ee2ac75e..b59636f7 100644 --- a/udp/server_test.go +++ b/udp/server_test.go @@ -20,9 +20,9 @@ func TestServer(t *testing.T) { go func() { defer wg.Done() - server = NewServer(":12190", WithHandler(func(conn net.PacketConn, buf []byte, addr net.Addr) { - done <- buf - }), WithRecoveryHandler(func(conn net.PacketConn, buf []byte, addr net.Addr, err interface{}) { + server = NewServer(":12190", WithHandler(func(msg *Message) { + done <- msg.Body + }), WithRecoveryHandler(func(msg *Message, err interface{}) { t.Log(err) }), WithBufSize(1024))