diff --git a/src/main_test.go b/src/main_test.go index 34d4a74..ea27ce1 100644 --- a/src/main_test.go +++ b/src/main_test.go @@ -2,6 +2,9 @@ package main import ( "encoding/hex" + "fmt" + "os" + "runtime/trace" "testing" "github.com/google/gopacket" @@ -40,16 +43,15 @@ func DummySink() { select { case <-testResultChannel: cnt++ - if cnt%1000000 == 0 { - println(cnt) - } } } } -var encoderList []packetEncoder - -func benchmarkUdpPacketProcessingWorker(i uint, b *testing.B) { +func benchmarkUdpPacketProcessingWorker(workers uint, b *testing.B) { + fmt.Println("Benchmarking UDP Packet Processing Worker") + file, _ := os.Create("trace.out") + trace.Start(file) + defer trace.Stop() e := packetEncoder{ 53, testInputChannel, @@ -60,11 +62,10 @@ func benchmarkUdpPacketProcessingWorker(i uint, b *testing.B) { testTcpChannels, testReturnChannel, testResultChannel, - i, + workers, testDoneChannel, false, } - encoderList = append(encoderList, e) go e.run() testTcpChannels = append(testTcpChannels, make(chan tcpPacket, TCPAssemblyChannelSize)) go DummySink() @@ -74,9 +75,10 @@ func benchmarkUdpPacketProcessingWorker(i uint, b *testing.B) { } } -func BenchmarkUdpPacketProcessingWorker1(b *testing.B) { benchmarkUdpPacketProcessingWorker(1, b) } -func BenchmarkUdpPacketProcessingWorker2(b *testing.B) { benchmarkUdpPacketProcessingWorker(2, b) } -func BenchmarkUdpPacketProcessingWorker4(b *testing.B) { benchmarkUdpPacketProcessingWorker(4, b) } -func BenchmarkUdpPacketProcessingWorker6(b *testing.B) { benchmarkUdpPacketProcessingWorker(6, b) } -func BenchmarkUdpPacketProcessingWorker8(b *testing.B) { benchmarkUdpPacketProcessingWorker(8, b) } -func BenchmarkUdpPacketProcessingWorker16(b *testing.B) { benchmarkUdpPacketProcessingWorker(16, b) } +func BenchmarkUdpPacketProcessingWorker1(b *testing.B) { benchmarkUdpPacketProcessingWorker(1, b) } + +// func BenchmarkUdpPacketProcessingWorker2(b *testing.B) { benchmarkUdpPacketProcessingWorker(2, b) } +// func BenchmarkUdpPacketProcessingWorker4(b *testing.B) { benchmarkUdpPacketProcessingWorker(4, b) } +// func BenchmarkUdpPacketProcessingWorker6(b *testing.B) { benchmarkUdpPacketProcessingWorker(6, b) } +// func BenchmarkUdpPacketProcessingWorker8(b *testing.B) { benchmarkUdpPacketProcessingWorker(8, b) } +// func BenchmarkUdpPacketProcessingWorker16(b *testing.B) { benchmarkUdpPacketProcessingWorker(16, b) } diff --git a/src/packet.go b/src/packet.go index 73361e0..82e448a 100644 --- a/src/packet.go +++ b/src/packet.go @@ -9,6 +9,7 @@ import ( "github.com/google/gopacket/layers" mkdns "github.com/miekg/dns" "github.com/mosajjal/dnsmonster/types" + log "github.com/sirupsen/logrus" ) func (encoder *packetEncoder) processTransport(foundLayerTypes *[]gopacket.LayerType, udp *layers.UDP, tcp *layers.TCP, flow gopacket.Flow, timestamp time.Time, IPVersion uint8, SrcIP, DstIP net.IP) { @@ -62,48 +63,44 @@ func (encoder *packetEncoder) inputHandlerWorker(p chan gopacket.Packet) { } parser := gopacket.NewDecodingLayerParser(startLayer, decodeLayers...) foundLayerTypes := []gopacket.LayerType{} - for { - select { - case packet := <-p: - timestamp := packet.Metadata().Timestamp - if timestamp.IsZero() { - timestamp = time.Now() - } - _ = parser.DecodeLayers(packet.Data(), &foundLayerTypes) - // first parse the ip layer, so we can find fragmented packets - for _, layerType := range foundLayerTypes { - switch layerType { - case layers.LayerTypeIPv4: - // Check for fragmentation - if ip4.Flags&layers.IPv4DontFragment == 0 && (ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0) { - // Packet is fragmented, send it to the defragger - encoder.ip4Defrgger <- ipv4ToDefrag{ - ip4, - timestamp, - } - break + for packet := range p { + timestamp := packet.Metadata().Timestamp + if timestamp.IsZero() { + timestamp = time.Now() + } + _ = parser.DecodeLayers(packet.Data(), &foundLayerTypes) + // first parse the ip layer, so we can find fragmented packets + for _, layerType := range foundLayerTypes { + switch layerType { + case layers.LayerTypeIPv4: + // Check for fragmentation + if ip4.Flags&layers.IPv4DontFragment == 0 && (ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0) { + // Packet is fragmented, send it to the defragger + encoder.ip4Defrgger <- ipv4ToDefrag{ + ip4, + timestamp, } + } else { // log.Infof("packet %v coming to %p\n", timestamp, &encoder) encoder.processTransport(&foundLayerTypes, &udp, &tcp, ip4.NetworkFlow(), timestamp, 4, ip4.SrcIP, ip4.DstIP) - continue - case layers.LayerTypeIPv6: - // Store the packet metadata - if ip6.NextHeader == layers.IPProtocolIPv6Fragment { - // TODO: Move the parsing to DecodingLayer when gopacket support it - if frag := packet.Layer(layers.LayerTypeIPv6Fragment).(*layers.IPv6Fragment); frag != nil { - encoder.ip6Defrgger <- ipv6FragmentInfo{ - ip6, - *frag, - timestamp, - } + } + case layers.LayerTypeIPv6: + // Store the packet metadata + if ip6.NextHeader == layers.IPProtocolIPv6Fragment { + // TODO: Move the parsing to DecodingLayer when gopacket support it + if frag := packet.Layer(layers.LayerTypeIPv6Fragment).(*layers.IPv6Fragment); frag != nil { + encoder.ip6Defrgger <- ipv6FragmentInfo{ + ip6, + *frag, + timestamp, } - } else { - encoder.processTransport(&foundLayerTypes, &udp, &tcp, ip6.NetworkFlow(), timestamp, 6, ip6.SrcIP, ip6.DstIP) } + } else { + encoder.processTransport(&foundLayerTypes, &udp, &tcp, ip6.NetworkFlow(), timestamp, 6, ip6.SrcIP, ip6.DstIP) } } - break } + } } @@ -127,6 +124,7 @@ func (encoder *packetEncoder) run() { var handlerChanList []chan gopacket.Packet for i := 0; i < int(encoder.handlerCount); i++ { + log.Infof("Creating handler #%d\n", i) handlerChanList = append(handlerChanList, make(chan gopacket.Packet, 10000)) //todo: parameter for size of this channel needs to be defined go encoder.inputHandlerWorker(handlerChanList[i]) } @@ -167,7 +165,7 @@ func (encoder *packetEncoder) run() { } encoder.processTransport(&foundLayerTypes, &udp, &tcp, packet.ip.NetworkFlow(), packet.timestamp, 6, packet.ip.SrcIP, packet.ip.DstIP) case packet := <-encoder.input: - handlerChanList[packet.TransportLayer().TransportFlow().FastHash()%uint64(encoder.handlerCount)] <- packet + handlerChanList[packet.NetworkLayer().NetworkFlow().FastHash()%uint64(encoder.handlerCount)] <- packet case <-encoder.done: continue }