diff --git a/cmd/radserver/main.go b/cmd/radserver/main.go index 2b5d514..84b2c11 100644 --- a/cmd/radserver/main.go +++ b/cmd/radserver/main.go @@ -16,6 +16,10 @@ var secret = flag.String("secret", "", "shared RADIUS secret between clients and var command string var arguments []string +func error_handler(e error, p *radius.Packet) { + log.Printf("radserver err: %v", e) +} + func handler(w radius.ResponseWriter, p *radius.Packet) { username, password, ok := p.PAP() @@ -140,10 +144,11 @@ func main() { } server := radius.Server{ Handler: radius.HandlerFunc(handler), + ErrorHandler: radius.ErrorFunc(error_handler), Secret: []byte(*secret), Dictionary: radius.Builtin, ClientsMap: map[string]string{ - "127.0.0.1": "abc", + "127.0.0.51/24": "abc", }, } diff --git a/server.go b/server.go index e55db09..0e92dbc 100644 --- a/server.go +++ b/server.go @@ -13,15 +13,26 @@ type Handler interface { ServeRadius(w ResponseWriter, p *Packet) } +type ErrorHandler interface { + Error(e error, p *Packet) +} + // HandlerFunc is a wrapper that allows ordinary functions to be used as a // handler. type HandlerFunc func(w ResponseWriter, p *Packet) +type ErrorFunc func(e error, p *Packet) + // ServeRadius calls h(w, p). func (h HandlerFunc) ServeRadius(w ResponseWriter, p *Packet) { h(w, p) } +// Handle any errors +func (h ErrorFunc) Error(e error, p *Packet) { + h(e, p) +} + // ResponseWriter is used by Handler when replying to a RADIUS packet. type ResponseWriter interface { @@ -129,8 +140,8 @@ type Server struct { // Client->Secret mapping ClientsMap map[string]string clientIP []string - ClientNets []net.IPNet - ClientSecrets []string + ClientIPMap map[string]string + ClientNetMap map[string]string // Dictionary used when decoding incoming packets. Dictionary *Dictionary @@ -138,6 +149,9 @@ type Server struct { // The packet handler that handles incoming, valid packets. Handler Handler + // Error handler for any errors outside the handler + ErrorHandler ErrorHandler + // Listener listener *net.UDPConn @@ -145,50 +159,46 @@ type Server struct { CloseChan chan bool } -func (s *Server) ResetClientNets() error { +func (s *Server) ResetClientNetMap() error { - s.ClientNets = nil - s.ClientSecrets = nil + s.ClientNetMap = make(map[string]string, 0) + s.ClientIPMap = make(map[string]string, 0) + ipParseErrors := make([]string, 0) + // return errors.New("Unable to parse CIDR or IP " + k) if s.ClientsMap != nil { for k, v := range s.ClientsMap { - _, subnet, err := net.ParseCIDR(k) - if err != nil { - return errors.New("Unable to parse CIDR or IP " + k) + ip, subnet, err := net.ParseCIDR(k) + if err == nil { + s.ClientNetMap[subnet.String()] = v + } else { + ip = net.ParseIP(k) + if ip != nil { + s.ClientIPMap[string(ip)] = v + } else { + ipParseErrors = append(ipParseErrors, k) + } } - - s.ClientNets = append(s.ClientNets, *subnet) - s.ClientSecrets = append(s.ClientSecrets, v) } } - - return nil -} - - -func (s *Server) CheckClientsMap() error { - - if s.ClientsMap != nil { - for k, _ := range s.ClientsMap { - ip := net.ParseIP(k) - if ip == nil { - return errors.New("Not legal Clients IP address") - } - } + if len(ipParseErrors) > 0{ + return errors.New("Unable to parse CIDR or IP " + fmt.Sprintf("%v", ipParseErrors)) } - return nil } func (s *Server) AddClientsMap(m map[string]string ) { if s.ClientsMap == nil && len(m) > 0 { s.ClientsMap = m - // s.ResetClientNets() + s.ResetClientNetMap() } } +func defaultErrorHandler(e error, p *Packet) { + log.Printf("Radius Server Error %v", e) +} // ListenAndServe starts a RADIUS server on the address given in s. func (s *Server) ListenAndServe() error { @@ -196,8 +206,14 @@ func (s *Server) ListenAndServe() error { return errors.New("radius: server already started") } + if s.ErrorHandler == nil { + s.ErrorHandler = ErrorFunc(defaultErrorHandler) + } + if s.Handler == nil { - return errors.New("radius: nil Handler") + err := errors.New("radius: nil Handler") + s.ErrorHandler.Error(err, nil) + return err } if s.CloseChan == nil { @@ -216,21 +232,20 @@ func (s *Server) ListenAndServe() error { addr, err := net.ResolveUDPAddr(network, addrStr) if err != nil { + s.ErrorHandler.Error(err, nil) return err } s.listener, err = net.ListenUDP(network, addr) if err != nil { + s.ErrorHandler.Error(err, nil) return err } if s.ClientsMap != nil { // double check, either IP or IPNet range - err = s.ResetClientNets() + err = s.ResetClientNetMap() if err != nil { - err = s.CheckClientsMap() - if err != nil { - return err - } + s.ErrorHandler.Error(err, nil) } } @@ -252,6 +267,7 @@ func (s *Server) ListenAndServe() error { buff := make([]byte, 4096) n, remoteAddr, err := s.listener.ReadFromUDP(buff) if err != nil && !err.(*net.OpError).Temporary() { + s.ErrorHandler.Error(err, nil) break } @@ -263,34 +279,37 @@ func (s *Server) ListenAndServe() error { go func(conn *net.UDPConn, buff []byte, remoteAddr *net.UDPAddr) { secret := s.Secret - log.Println("Remote IP: ",remoteAddr.IP) - legal := false + inClientIPMap := false + inClientNetMap := false - if s.ClientsMap[ fmt.Sprintf("%v",remoteAddr.IP)] != "" { - legal = true - secret = []byte( s.ClientsMap[ fmt.Sprintf("%v",remoteAddr.IP) ] ) + if s.ClientIPMap[string(remoteAddr.IP)] != "" { + secret = []byte( s.ClientIPMap[string(remoteAddr.IP)] ) + inClientIPMap = true } else { - log.Println( s.ClientsMap ,fmt.Sprintf("%v",remoteAddr.IP)) - defer conn.Close() + if s.ClientNetMap != nil { + for k, v := range s.ClientNetMap { + + _, subnet, err := net.ParseCIDR(k) + if err != nil { + s.ErrorHandler.Error(err, nil) + } + if subnet.Contains(remoteAddr.IP) { + secret = []byte(v) + inClientNetMap = true + break + } + } + } } - - if legal == false { - log.Println(remoteAddr.IP," not in clients map") + if !inClientIPMap && !inClientNetMap { + err := errors.New(fmt.Sprintf("%v", remoteAddr.IP) + " is not configured") + s.ErrorHandler.Error(err, nil) return } - if s.ClientNets != nil { - log.Println(remoteAddr.IP) - for k, v := range s.ClientNets { - if v.Contains(remoteAddr.IP) { - log.Println(remoteAddr.IP) - secret = []byte(s.ClientSecrets[k]) - } - } - } - packet, err := Parse(buff, secret, s.Dictionary) if err != nil { + s.ErrorHandler.Error(err, nil) return } @@ -302,6 +321,8 @@ func (s *Server) ListenAndServe() error { activeLock.Lock() if _, ok := active[key]; ok { activeLock.Unlock() + err = errors.New(remoteAddr.String() + " busy") + s.ErrorHandler.Error(err, nil) return } active[key] = true