Skip to content

Commit

Permalink
moved back to zerolog and pretty logging
Browse files Browse the repository at this point in the history
  • Loading branch information
mosajjal committed May 11, 2023
1 parent a75e2c4 commit 46de3a0
Show file tree
Hide file tree
Showing 11 changed files with 180 additions and 155 deletions.
13 changes: 6 additions & 7 deletions acl/acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"sort"

"github.com/knadh/koanf"
"golang.org/x/exp/slog"
"github.com/rs/zerolog"
)

// Decision is the type of decision that an ACL can make for each connection info
Expand Down Expand Up @@ -45,11 +45,11 @@ type ACL interface {
Decide(*ConnInfo) error
Name() string
Priority() uint
ConfigAndStart(*slog.Logger, *koanf.Koanf) error
ConfigAndStart(zerolog.Logger, *koanf.Koanf) error
}

// StartACLs starts all the ACLs that have been configured and registered
func StartACLs(log *slog.Logger, k *koanf.Koanf) ([]ACL, error) {
func StartACLs(log zerolog.Logger, k *koanf.Koanf) ([]ACL, error) {
var a []ACL
aclK := k.Cut("acl")
for _, acl := range availableACLs {
Expand All @@ -58,16 +58,15 @@ func StartACLs(log *slog.Logger, k *koanf.Koanf) ([]ACL, error) {
if !aclK.Bool(fmt.Sprintf("%s.enabled", (acl).Name())) {
continue
}
l := slog.New(log.Handler().WithAttrs([]slog.Attr{{Key: "service", Value: slog.StringValue((acl).Name())}}))
var l = log.With().Str("acl", (acl).Name()).Logger()
// we pass the full config to each ACL so that they can cut it themselves. it's needed for some ACLs that need
// to read the config of other ACLs or the global config
if err := acl.ConfigAndStart(l, k); err != nil {
log.Warn("failed to start ACL", "name", (acl).Name(), "err", err)
log.Warn().Msgf("failed to start ACL %s with error %s", (acl).Name(), err)
return a, err
}
a = append(a, acl)
log.Info("started ACL", "name", (acl).Name())
fmt.Printf("%+v\n", a)
log.Info().Msgf("started ACL: '%s'", (acl).Name())

}
return a, nil
Expand Down
24 changes: 12 additions & 12 deletions acl/cidr.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
"time"

"github.com/knadh/koanf"
"github.com/rs/zerolog"
"github.com/yl2chen/cidranger"
"golang.org/x/exp/slog"
)

// CIDR acl allows sniproxy to use a list of CIDR to allow or reject connections
Expand All @@ -23,18 +23,18 @@ type cidr struct {
RefreshInterval time.Duration `yaml:"refresh_interval"`
AllowRanger cidranger.Ranger
RejectRanger cidranger.Ranger
logger *slog.Logger
logger zerolog.Logger
priority uint
}

func (d *cidr) LoadCIDRCSV(path string) error {
d.AllowRanger = cidranger.NewPCTrieRanger()
d.RejectRanger = cidranger.NewPCTrieRanger()

d.logger.Info("Loading the CIDR from file/url")
d.logger.Info().Msg("Loading the CIDR from file/url")
var scanner *bufio.Scanner
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
d.logger.Info("CIDR list is a URL, trying to fetch")
d.logger.Info().Msg("CIDR list is a URL, trying to fetch")
client := http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
r.URL.Opaque = r.URL.Path
Expand All @@ -43,10 +43,10 @@ func (d *cidr) LoadCIDRCSV(path string) error {
}
resp, err := client.Get(path)
if err != nil {
d.logger.Error(err.Error())
d.logger.Err(err)
return err
}
d.logger.Info("(re)fetching URL", "path", path)
d.logger.Info().Msgf("(re)fetching URL: ", path)
defer resp.Body.Close()
scanner = bufio.NewScanner(resp.Body)

Expand All @@ -55,7 +55,7 @@ func (d *cidr) LoadCIDRCSV(path string) error {
if err != nil {
return err
}
d.logger.Info("(re)loading File", "path", path)
d.logger.Info().Msgf("(re)loading file: ", path)
defer file.Close()
scanner = bufio.NewScanner(file)
}
Expand All @@ -65,7 +65,7 @@ func (d *cidr) LoadCIDRCSV(path string) error {
// cut the line at the first comma
cidr, policy, found := strings.Cut(row, ",")
if !found {
d.logger.Info(cidr + " is not a valid csv line, assuming reject")
d.logger.Info().Msg(cidr + " is not a valid csv line, assuming reject")
}
if policy == "allow" {
if _, netw, err := net.ParseCIDR(cidr); err == nil {
Expand All @@ -74,7 +74,7 @@ func (d *cidr) LoadCIDRCSV(path string) error {
if _, netw, err := net.ParseCIDR(cidr + "/32"); err == nil {
_ = d.AllowRanger.Insert(cidranger.NewBasicRangerEntry(*netw))
} else {
d.logger.Error(err.Error())
d.logger.Err(err)
}
}
} else {
Expand All @@ -84,12 +84,12 @@ func (d *cidr) LoadCIDRCSV(path string) error {
if _, netw, err := net.ParseCIDR(cidr + "/32"); err == nil {
_ = d.RejectRanger.Insert(cidranger.NewBasicRangerEntry(*netw))
} else {
d.logger.Error(err.Error())
d.logger.Err(err)
}
}
}
}
d.logger.Info("cidrs loaded", "len", d.AllowRanger.Len())
d.logger.Info().Msgf("%d cidr(s) loaded", d.AllowRanger.Len())

return nil
}
Expand Down Expand Up @@ -128,7 +128,7 @@ func (d cidr) Priority() uint {
}

// Config function is what starts the ACL
func (d *cidr) ConfigAndStart(logger *slog.Logger, c *koanf.Koanf) error {
func (d *cidr) ConfigAndStart(logger zerolog.Logger, c *koanf.Koanf) error {
c = c.Cut(fmt.Sprintf("acl.%s", d.Name()))
d.logger = logger
d.Path = c.String("path")
Expand Down
26 changes: 13 additions & 13 deletions acl/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

"github.com/golang-collections/collections/tst"
"github.com/knadh/koanf"
"golang.org/x/exp/slog"
"github.com/rs/zerolog"
)

// domain ACL makes a decision on a connection based on the domain name derived
Expand All @@ -22,7 +22,7 @@ type domain struct {
routePrefixes *tst.TernarySearchTree
routeSuffixes *tst.TernarySearchTree
routeFQDNs map[string]uint8
logger *slog.Logger
logger zerolog.Logger
priority uint
}

Expand Down Expand Up @@ -72,10 +72,10 @@ func reverse(s string) string {
// 2. a TST for all the suffixes (type 2)
// 3. a hashtable for all the full match fqdn (type 3)
func (d *domain) LoadDomainsCsv(Filename string) error {
d.logger.Info("Loading the domain from file/url")
d.logger.Info().Msg("Loading the domain from file/url")
var scanner *bufio.Scanner
if strings.HasPrefix(Filename, "http://") || strings.HasPrefix(Filename, "https://") {
d.logger.Info("domain list is a URL, trying to fetch")
d.logger.Info().Msg("domain list is a URL, trying to fetch")
client := http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
r.URL.Opaque = r.URL.Path
Expand All @@ -84,10 +84,10 @@ func (d *domain) LoadDomainsCsv(Filename string) error {
}
resp, err := client.Get(Filename)
if err != nil {
d.logger.Error(err.Error())
d.logger.Err(err)
return err
}
d.logger.Info("(re)fetching URL", "url", Filename)
d.logger.Info().Msgf("(re)fetching URL: %s", Filename)
defer resp.Body.Close()
scanner = bufio.NewScanner(resp.Body)

Expand All @@ -96,7 +96,7 @@ func (d *domain) LoadDomainsCsv(Filename string) error {
if err != nil {
return err
}
d.logger.Info("(re)loading File", "file", Filename)
d.logger.Info().Msgf("(re)loading file: %s", Filename)
defer file.Close()
scanner = bufio.NewScanner(file)
}
Expand All @@ -105,7 +105,7 @@ func (d *domain) LoadDomainsCsv(Filename string) error {
// split the line by comma to understand thed.logger.c
fqdn := strings.Split(lowerCaseLine, ",")
if len(fqdn) != 2 {
d.logger.Info(lowerCaseLine + " is not a valid line, assuming FQDN")
d.logger.Info().Msg(lowerCaseLine + " is not a valid line, assuming FQDN")
fqdn = []string{lowerCaseLine, "fqdn"}
}
// add the fqdn to the hashtable with its type
Expand All @@ -121,11 +121,11 @@ func (d *domain) LoadDomainsCsv(Filename string) error {
d.routeFQDNs[fqdn[0]] = matchFQDN
default:
//d.logger.Warnf("%s is not a valid line, assuming fqdn", lowerCaseLine)
d.logger.Info(lowerCaseLine + " is not a valid line, assuming FQDN")
d.logger.Info().Msg(lowerCaseLine + " is not a valid line, assuming FQDN")
d.routeFQDNs[fqdn[0]] = matchFQDN
}
}
d.logger.Info(fmt.Sprintf("%s loaded with %d prefix, %d suffix and %d fqdn", Filename, d.routePrefixes.Len(), d.routeSuffixes.Len(), len(d.routeFQDNs)-d.routePrefixes.Len()-d.routeSuffixes.Len()))
d.logger.Info().Msgf("%s loaded with %d prefix, %d suffix and %d fqdn", Filename, d.routePrefixes.Len(), d.routeSuffixes.Len(), len(d.routeFQDNs)-d.routePrefixes.Len()-d.routeSuffixes.Len())

return nil
}
Expand All @@ -145,10 +145,10 @@ func (d domain) Decide(c *ConnInfo) error {
return nil
}
if d.inDomainList(c.Domain) {
d.logger.Debug("domain not going through proxy", "domain", c.Domain)
d.logger.Debug().Msgf("domain not going through proxy: %s", c.Domain)
c.Decision = OriginIP
} else {
d.logger.Debug("domain going through proxy", "domain", c.Domain)
d.logger.Debug().Msgf("domain going through proxy: %s", c.Domain)
c.Decision = ProxyIP
}
return nil
Expand All @@ -160,7 +160,7 @@ func (d domain) Priority() uint {
return d.priority
}

func (d *domain) ConfigAndStart(logger *slog.Logger, c *koanf.Koanf) error {
func (d *domain) ConfigAndStart(logger zerolog.Logger, c *koanf.Koanf) error {
c = c.Cut(fmt.Sprintf("acl.%s", d.Name()))
d.logger = logger
d.routePrefixes = tst.New()
Expand Down
28 changes: 14 additions & 14 deletions acl/geoip.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (

"github.com/knadh/koanf"
"github.com/oschwald/maxminddb-golang"
"github.com/rs/zerolog"
"golang.org/x/exp/slices"
"golang.org/x/exp/slog"
)

// geoIP is an ACL that checks the geolocation of incoming connections and
Expand All @@ -29,7 +29,7 @@ type geoIP struct {
BlockedCountries []string
Refresh time.Duration
mmdb *maxminddb.Reader
logger *slog.Logger
logger zerolog.Logger
priority uint
}

Expand Down Expand Up @@ -59,15 +59,15 @@ func (g geoIP) getCountry(ipAddr string) (string, error) {
// initializeGeoIP loads the geolocation database from the specified g.Path.
func (g *geoIP) initializeGeoIP() error {

g.logger.Info("Loading the domain from file/url")
g.logger.Info().Msg("Loading the domain from file/url")
var scanner []byte
if strings.HasPrefix(g.Path, "http://") || strings.HasPrefix(g.Path, "https://") {
g.logger.Info("domain list is a URL, trying to fetch")
g.logger.Info().Msg("domain list is a URL, trying to fetch")
resp, err := http.Get(g.Path)
if err != nil {
return err
}
g.logger.Info("(re)fetching", "Path", g.Path)
g.logger.Info().Msgf("(re)fetching %s", g.Path)
defer resp.Body.Close()
scanner, err = io.ReadAll(resp.Body)
if err != nil {
Expand All @@ -79,27 +79,27 @@ func (g *geoIP) initializeGeoIP() error {
if err != nil {
return err
}
g.logger.Info("(re)loading File: ", g.Path)
g.logger.Info().Msgf("(re)loading File: %s", g.Path)
defer file.Close()
n, err := file.Read(scanner)
if err != nil {
return err
}
g.logger.Info("geolocation database loaded", n)
g.logger.Info().Msgf("geolocation database with %d bytes loaded", n)

}
var err error
if g.mmdb, err = maxminddb.FromBytes(scanner); err != nil {
//g.logger.Warn("%d bytes read, %s", len(scanner), err)
return err
}
g.logger.Info("Loaded MMDB")
g.logger.Info().Msg("Loaded MMDB")
for range time.NewTicker(g.Refresh).C {
if g.mmdb, err = maxminddb.FromBytes(scanner); err != nil {
//g.logger.Warn("%d bytes read, %s", len(scanner), err)
return err
}
g.logger.Info("Loaded MMDB %v", g.mmdb)
g.logger.Info().Msgf("Loaded MMDB %v", g.mmdb)
}
return nil
}
Expand Down Expand Up @@ -127,10 +127,10 @@ func (g geoIP) checkGeoIPSkip(addr net.Addr) bool {
var country string
country, err := g.getCountry(ip)
country = strings.ToLower(country)
g.logger.Debug("incoming tcp connection", "ip", ip, "country", country)
g.logger.Debug().Msgf("incoming tcp connection from ip %s and country %s", ip, country)

if err != nil {
g.logger.Info("Failed to get the geolocation", "ip", ip, "country", country)
g.logger.Info().Msgf("Failed to get the geolocation of ip %s", ip)
return false
}
if slices.Contains(g.BlockedCountries, country) {
Expand All @@ -153,10 +153,10 @@ func (g geoIP) checkGeoIPSkip(addr net.Addr) bool {
func (g geoIP) Decide(c *ConnInfo) error {
// in checkGeoIPSkip, false is reject
if !g.checkGeoIPSkip(c.SrcIP) {
g.logger.Info("Rejecting connection from", "ip", c.SrcIP)
g.logger.Info().Msgf("rejecting connection from ip %s", c.SrcIP)
c.Decision = Reject
}
g.logger.Debug("GeoIP decision", "ip", c.SrcIP, "decision", c.Decision)
g.logger.Debug().Msgf("GeoIP decision for ip %s is %#v", c.SrcIP, c.Decision)
return nil
}
func (g geoIP) Name() string {
Expand All @@ -166,7 +166,7 @@ func (g geoIP) Priority() uint {
return g.priority
}

func (g *geoIP) ConfigAndStart(logger *slog.Logger, c *koanf.Koanf) error {
func (g *geoIP) ConfigAndStart(logger zerolog.Logger, c *koanf.Koanf) error {
c = c.Cut(fmt.Sprintf("acl.%s", g.Name()))
g.logger = logger
g.Path = c.String("path")
Expand Down
16 changes: 8 additions & 8 deletions acl/override.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/knadh/koanf"
doh "github.com/mosajjal/sniproxy/dohserver"
dohserver "github.com/mosajjal/sniproxy/dohserver"
"golang.org/x/exp/slog"
"github.com/rs/zerolog"
"inet.af/tcpproxy"
)

Expand All @@ -25,7 +25,7 @@ type override struct {
tcpproxyport int
tlsCert string
tlsKey string
logger *slog.Logger
logger zerolog.Logger
}

// GetFreePort returns a random open port
Expand All @@ -49,15 +49,15 @@ func (o *override) startProxy() {
var err error
o.tcpproxyport, err = GetFreePort()
if err != nil {
o.logger.Error("failed to get a free port for tcpproxy: %s", err)
o.logger.Error().Msgf("failed to get a free port for tcpproxy: %s", err)
return
}
for k, v := range o.rules {
o.logger.Info("adding overide rule", k, v)
o.logger.Info().Msgf("adding overide rule %s -> %s", k, v)
// TODO: create a regex matcher for SNIRoute
o.tcpproxy.AddSNIRoute(fmt.Sprintf("127.0.0.1:%d", o.tcpproxyport), k, tcpproxy.To(v))
}
o.logger.Info("starting tcpproxy", "port", o.tcpproxyport)
o.logger.Info().Msgf("starting tcpproxy on port %d", o.tcpproxyport)
o.tcpproxy.Run()
}

Expand All @@ -81,7 +81,7 @@ func (o override) Priority() uint {
return o.priority
}

func (o *override) ConfigAndStart(logger *slog.Logger, c *koanf.Koanf) error {
func (o *override) ConfigAndStart(logger zerolog.Logger, c *koanf.Koanf) error {
DNSBind := c.String("general.bind_dns_over_udp")
c = c.Cut(fmt.Sprintf("acl.%s", o.Name()))
tmpRules := c.StringMap("rules")
Expand All @@ -100,9 +100,9 @@ func (o *override) ConfigAndStart(logger *slog.Logger, c *koanf.Koanf) error {
dohConfig.Listen = []string{fmt.Sprintf("127.0.0.1:%d", o.dohPort)}
if o.tlsCert == "" || o.tlsKey == "" {
_, _, err := doh.GenerateSelfSignedCertKey(dohSNI, nil, nil, os.TempDir())
o.logger.Info("certificate was not provided, generating a self signed cert in temp directory")
o.logger.Info().Msg("certificate was not provided, generating a self signed cert in temp directory")
if err != nil {
o.logger.Error("error while generating self-signed cert: ", "error", err)
o.logger.Error().Msgf("error while generating self-signed cert: %s", err)
}
o.tlsCert = filepath.Join(os.TempDir(), dohSNI+".crt")
o.tlsKey = filepath.Join(os.TempDir(), dohSNI+".key")
Expand Down
Loading

0 comments on commit 46de3a0

Please sign in to comment.