Skip to content

Commit

Permalink
Merge pull request #12 from nberlee/fix-concurrency
Browse files Browse the repository at this point in the history
fix: potentential conurrency issue when running as a module
  • Loading branch information
nberlee authored Apr 8, 2023
2 parents a71ff7c + 1e6ac37 commit 411373f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ jobs:
run: go build -v ./...

- name: Test
run: go test -v ./...
run: go test -race -v ./...
6 changes: 4 additions & 2 deletions netstat/export_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package netstat

var pd = newProcessData()

var (
// Exported for testing
ParseAddr = parseAddr
ParseIPv4 = parseIPv4
ParseIPv6 = parseIPv6
ParseSockTab = parseSockTab
OpenFileStream = openFileStream
ParseSockTab = pd.parseSockTab
OpenFileStream = pd.openFileStream
)
40 changes: 25 additions & 15 deletions netstat/netstat_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,17 @@ const (
Closing
)

var fdProcess = make(map[uint64]*common.Process)
var pidNetNS map[uint32]string
type processData struct {
fdProcess map[uint64]*common.Process
pidNetNS map[uint32]string
}

func newProcessData() *processData {
return &processData{
fdProcess: make(map[uint64]*common.Process),
pidNetNS: make(map[uint32]string),
}
}

var skStates = [...]string{
"UNKNOWN",
Expand Down Expand Up @@ -137,7 +146,7 @@ func parseAddr(s string) (*SockEndpoint, error) {
return &SockEndpoint{IP: ip, Port: uint16(v)}, nil
}

func parseSockTab(reader io.Reader, accept AcceptFn, transport string, podPid uint32) ([]SockTabEntry, error) {
func (pd *processData) parseSockTab(reader io.Reader, accept AcceptFn, transport string, podPid uint32) ([]SockTabEntry, error) {
scanner := bufio.NewScanner(reader)
scanner.Scan()

Expand Down Expand Up @@ -177,13 +186,13 @@ func parseSockTab(reader io.Reader, accept AcceptFn, transport string, podPid ui
}
entry.Transport = transport
if podPid != 0 {
if netNsName, ok := pidNetNS[podPid]; ok {
if netNsName, ok := pd.pidNetNS[podPid]; ok {
entry.NetNS = netNsName
} else {
entry.NetNS = strconv.Itoa(int(podPid))
}
}
entry.Process = fdProcess[entry.Inode]
entry.Process = pd.fdProcess[entry.Inode]
if accept(&entry) {
sockTab = append(sockTab, entry)
}
Expand All @@ -198,15 +207,17 @@ func parseSockTab(reader io.Reader, accept AcceptFn, transport string, podPid ui
func Netstat(ctx context.Context, feature EnableFeatures, fn AcceptFn) ([]SockTabEntry, error) {
var err error

pids, err := mergePids(feature)
pd := newProcessData()

pids, err := pd.mergePids(feature)
if err != nil {
return nil, err
}

files := procFiles(feature, pids)

if feature.PID {
fdProcess, err = processes.GetProcessFDs(ctx)
pd.fdProcess, err = processes.GetProcessFDs(ctx)
if err != nil {
return nil, err
}
Expand All @@ -226,7 +237,7 @@ func Netstat(ctx context.Context, feature EnableFeatures, fn AcceptFn) ([]SockTa
chs[i] <- []SockTabEntry{}
return
default:
tabs, err := openFileStream(file, fn)
tabs, err := pd.openFileStream(file, fn)
if err != nil {
// Send an empty slice if there was an error.
chs[i] <- []SockTabEntry{}
Expand All @@ -252,7 +263,7 @@ func Netstat(ctx context.Context, feature EnableFeatures, fn AcceptFn) ([]SockTa
return combinedTabs, nil
}

func mergePids(feature EnableFeatures) (pids []string, err error) {
func (pd *processData) mergePids(feature EnableFeatures) (pids []string, err error) {
if feature.AllNetNs {
netNsName, err := netns.GetNetNSNames()
if err != nil {
Expand All @@ -267,17 +278,16 @@ func mergePids(feature EnableFeatures) (pids []string, err error) {
feature.NetNsPids = []uint32{}
}

pidNetNS = map[uint32]string{}
if len(feature.NetNsName) > 0 {
pidNetNS = *netns.GetNetNsPids(feature.NetNsName)
pd.pidNetNS = *netns.GetNetNsPids(feature.NetNsName)
}

hostNetNsIndex := 0
if !feature.NoHostNetwork {
hostNetNsIndex = 1
}

lengthPids := len(pidNetNS) + len(feature.NetNsPids) + hostNetNsIndex
lengthPids := len(pd.pidNetNS) + len(feature.NetNsPids) + hostNetNsIndex
pids = make([]string, lengthPids)

if !feature.NoHostNetwork {
Expand All @@ -286,7 +296,7 @@ func mergePids(feature EnableFeatures) (pids []string, err error) {

netNsNameIndex := 0

for pid := range pidNetNS {
for pid := range pd.pidNetNS {
pids[netNsNameIndex+hostNetNsIndex] = strconv.Itoa(int(pid))
netNsNameIndex++
}
Expand Down Expand Up @@ -329,7 +339,7 @@ func procFiles(feature EnableFeatures, pids []string) (files []string) {
return files
}

func openFileStream(file string, fn AcceptFn) ([]SockTabEntry, error) {
func (pd *processData) openFileStream(file string, fn AcceptFn) ([]SockTabEntry, error) {
f, err := os.Open(file)
if err != nil {
return nil, err
Expand All @@ -338,7 +348,7 @@ func openFileStream(file string, fn AcceptFn) ([]SockTabEntry, error) {
_, transport := path.Split(file)
podPid, _ := strconv.ParseUint(strings.Split(file, "/")[2], 10, 32)

tabs, err := parseSockTab(f, fn, transport, uint32(podPid))
tabs, err := pd.parseSockTab(f, fn, transport, uint32(podPid))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 411373f

Please sign in to comment.