Skip to content

Commit

Permalink
- fixed client map / net mess (you can now use CIDR notation in addit…
Browse files Browse the repository at this point in the history
…ion to ip addresses, will attempt to match ip address first and then subnets. Will reject if not found, will not use default secret)

- added error handler
  • Loading branch information
foodforarabbit committed Jun 12, 2017
1 parent 2d9a04f commit 8f9afd6
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 54 deletions.
7 changes: 6 additions & 1 deletion cmd/radserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ var secret = flag.String("secret", "", "shared RADIUS secret between clients and
var command string
var arguments []string

func error_handler(e error, p *radius.Packet) {
log.Printf("radserver err: %v", e)
}

func handler(w radius.ResponseWriter, p *radius.Packet) {

username, password, ok := p.PAP()
Expand Down Expand Up @@ -140,10 +144,11 @@ func main() {
}
server := radius.Server{
Handler: radius.HandlerFunc(handler),
ErrorHandler: radius.ErrorFunc(error_handler),
Secret: []byte(*secret),
Dictionary: radius.Builtin,
ClientsMap: map[string]string{
"127.0.0.1": "abc",
"127.0.0.51/24": "abc",
},
}

Expand Down
127 changes: 74 additions & 53 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,26 @@ type Handler interface {
ServeRadius(w ResponseWriter, p *Packet)
}

type ErrorHandler interface {
Error(e error, p *Packet)
}

// HandlerFunc is a wrapper that allows ordinary functions to be used as a
// handler.
type HandlerFunc func(w ResponseWriter, p *Packet)

type ErrorFunc func(e error, p *Packet)

// ServeRadius calls h(w, p).
func (h HandlerFunc) ServeRadius(w ResponseWriter, p *Packet) {
h(w, p)
}

// Handle any errors
func (h ErrorFunc) Error(e error, p *Packet) {
h(e, p)
}


// ResponseWriter is used by Handler when replying to a RADIUS packet.
type ResponseWriter interface {
Expand Down Expand Up @@ -129,75 +140,80 @@ type Server struct {
// Client->Secret mapping
ClientsMap map[string]string
clientIP []string
ClientNets []net.IPNet
ClientSecrets []string
ClientIPMap map[string]string
ClientNetMap map[string]string

// Dictionary used when decoding incoming packets.
Dictionary *Dictionary

// The packet handler that handles incoming, valid packets.
Handler Handler

// Error handler for any errors outside the handler
ErrorHandler ErrorHandler

// Listener
listener *net.UDPConn

// quit channel
CloseChan chan bool
}

func (s *Server) ResetClientNets() error {
func (s *Server) ResetClientNetMap() error {

s.ClientNets = nil
s.ClientSecrets = nil
s.ClientNetMap = make(map[string]string, 0)
s.ClientIPMap = make(map[string]string, 0)
ipParseErrors := make([]string, 0)

// return errors.New("Unable to parse CIDR or IP " + k)
if s.ClientsMap != nil {
for k, v := range s.ClientsMap {

_, subnet, err := net.ParseCIDR(k)
if err != nil {
return errors.New("Unable to parse CIDR or IP " + k)
ip, subnet, err := net.ParseCIDR(k)
if err == nil {
s.ClientNetMap[subnet.String()] = v
} else {
ip = net.ParseIP(k)
if ip != nil {
s.ClientIPMap[string(ip)] = v
} else {
ipParseErrors = append(ipParseErrors, k)
}
}

s.ClientNets = append(s.ClientNets, *subnet)
s.ClientSecrets = append(s.ClientSecrets, v)
}
}

return nil
}


func (s *Server) CheckClientsMap() error {

if s.ClientsMap != nil {
for k, _ := range s.ClientsMap {
ip := net.ParseIP(k)
if ip == nil {
return errors.New("Not legal Clients IP address")
}
}
if len(ipParseErrors) > 0{
return errors.New("Unable to parse CIDR or IP " + fmt.Sprintf("%v", ipParseErrors))
}

return nil
}

func (s *Server) AddClientsMap(m map[string]string ) {
if s.ClientsMap == nil && len(m) > 0 {
s.ClientsMap = m
// s.ResetClientNets()
s.ResetClientNetMap()
}

}

func defaultErrorHandler(e error, p *Packet) {
log.Printf("Radius Server Error %v", e)
}

// ListenAndServe starts a RADIUS server on the address given in s.
func (s *Server) ListenAndServe() error {
if s.listener != nil {
return errors.New("radius: server already started")
}

if s.ErrorHandler == nil {
s.ErrorHandler = ErrorFunc(defaultErrorHandler)
}

if s.Handler == nil {
return errors.New("radius: nil Handler")
err := errors.New("radius: nil Handler")
s.ErrorHandler.Error(err, nil)
return err
}

if s.CloseChan == nil {
Expand All @@ -216,21 +232,20 @@ func (s *Server) ListenAndServe() error {

addr, err := net.ResolveUDPAddr(network, addrStr)
if err != nil {
s.ErrorHandler.Error(err, nil)
return err
}
s.listener, err = net.ListenUDP(network, addr)
if err != nil {
s.ErrorHandler.Error(err, nil)
return err
}

if s.ClientsMap != nil {
// double check, either IP or IPNet range
err = s.ResetClientNets()
err = s.ResetClientNetMap()
if err != nil {
err = s.CheckClientsMap()
if err != nil {
return err
}
s.ErrorHandler.Error(err, nil)
}
}

Expand All @@ -252,6 +267,7 @@ func (s *Server) ListenAndServe() error {
buff := make([]byte, 4096)
n, remoteAddr, err := s.listener.ReadFromUDP(buff)
if err != nil && !err.(*net.OpError).Temporary() {
s.ErrorHandler.Error(err, nil)
break
}

Expand All @@ -263,34 +279,37 @@ func (s *Server) ListenAndServe() error {
go func(conn *net.UDPConn, buff []byte, remoteAddr *net.UDPAddr) {
secret := s.Secret

log.Println("Remote IP: ",remoteAddr.IP)
legal := false
inClientIPMap := false
inClientNetMap := false

if s.ClientsMap[ fmt.Sprintf("%v",remoteAddr.IP)] != "" {
legal = true
secret = []byte( s.ClientsMap[ fmt.Sprintf("%v",remoteAddr.IP) ] )
if s.ClientIPMap[string(remoteAddr.IP)] != "" {
secret = []byte( s.ClientIPMap[string(remoteAddr.IP)] )
inClientIPMap = true
} else {
log.Println( s.ClientsMap ,fmt.Sprintf("%v",remoteAddr.IP))
defer conn.Close()
if s.ClientNetMap != nil {
for k, v := range s.ClientNetMap {

_, subnet, err := net.ParseCIDR(k)
if err != nil {
s.ErrorHandler.Error(err, nil)
}
if subnet.Contains(remoteAddr.IP) {
secret = []byte(v)
inClientNetMap = true
break
}
}
}
}

if legal == false {
log.Println(remoteAddr.IP," not in clients map")
if !inClientIPMap && !inClientNetMap {
err := errors.New(fmt.Sprintf("%v", remoteAddr.IP) + " is not configured")
s.ErrorHandler.Error(err, nil)
return
}

if s.ClientNets != nil {
log.Println(remoteAddr.IP)
for k, v := range s.ClientNets {
if v.Contains(remoteAddr.IP) {
log.Println(remoteAddr.IP)
secret = []byte(s.ClientSecrets[k])
}
}
}

packet, err := Parse(buff, secret, s.Dictionary)
if err != nil {
s.ErrorHandler.Error(err, nil)
return
}

Expand All @@ -302,6 +321,8 @@ func (s *Server) ListenAndServe() error {
activeLock.Lock()
if _, ok := active[key]; ok {
activeLock.Unlock()
err = errors.New(remoteAddr.String() + " busy")
s.ErrorHandler.Error(err, nil)
return
}
active[key] = true
Expand Down

0 comments on commit 8f9afd6

Please sign in to comment.