diff --git a/cmd/lookup.go b/cmd/lookup.go index 7735f39..3182d07 100644 --- a/cmd/lookup.go +++ b/cmd/lookup.go @@ -73,6 +73,30 @@ func lookupAction(out io.Writer, ouiCsvFile io.Reader, s string) error { // Lookup the vendor in the OUI database vendor := db.FindOuiByAssignment(assignment) + // Check if the --include flag is set + include := viper.GetString("lookup.include") + if include != "" && vendor != nil { + // If the --include flag is set, check if the vendor name + // contains the specified string (case insensitive) + if !vendor.Contains(include) { + // If the vendor name does not contain the specified string, + // skip to the next MAC address + continue + } + } + + // Check if the --exclude flag is set + exclude := viper.GetString("lookup.exclude") + if exclude != "" && vendor != nil { + // If the --exclude flag is set, check if the vendor name + // contains the specified string (case insensitive) + if vendor.Contains(exclude) { + // If the vendor name contains the specified string, + // skip to the next MAC address + continue + } + } + if vendor != nil { // Write in CSV format if the --csv flag is set if viper.GetBool("lookup.csv") { @@ -252,4 +276,12 @@ func init() { // Set to the value of the --csv flag if set lookupCmd.PersistentFlags().BoolP("csv", "c", false, "write output in CSV format") viper.BindPFlag("lookup.csv", lookupCmd.PersistentFlags().Lookup("csv")) + + // Set to the value of the --include flag if set + lookupCmd.Flags().StringP("include", "I", "", "output only results that include this string (case insensitive)") + viper.BindPFlag("lookup.include", lookupCmd.Flags().Lookup("include")) + + // Set to the value of the --exclude flag if set + lookupCmd.Flags().StringP("exclude", "E", "", "filter out results that contain this string (case insensitive)") + viper.BindPFlag("lookup.exclude", lookupCmd.Flags().Lookup("exclude")) } diff --git a/oui/oui.go b/oui/oui.go index af6ad96..572fa6c 100644 --- a/oui/oui.go +++ b/oui/oui.go @@ -51,6 +51,18 @@ type Oui struct { Address string // The organization street address } +// Contains returns true if the OUI entry contains the specified string +// in any of the OUI fields. The search is case-insensitive. +func (o *Oui) Contains(s string) bool { + // Search is case-insensitive so convert the search string to lowercase + s = strings.ToLower(s) + + // Check if the search string is contained in any of the OUI fields + return strings.Contains(strings.ToLower(o.Assignment), s) || + strings.Contains(strings.ToLower(o.Organization), s) || + strings.Contains(strings.ToLower(o.Address), s) +} + // The OUI database type OuiDb struct { // The OUI database diff --git a/oui/oui_test.go b/oui/oui_test.go index e89b213..e87c338 100644 --- a/oui/oui_test.go +++ b/oui/oui_test.go @@ -261,3 +261,66 @@ MA-L,222222,Texas Instruments,12500 TI Blvd Dallas TX US 75243` } } } + +// TestOuiContains tests the Contains function of the Oui type. +func TestOuiContains(t *testing.T) { + // Create a test CSV database + entry := oui.Oui{ + Assignment: "1A2B3C", + Organization: "Banana, Inc.", + Address: "1 Infinite Fruity Loop CA US 12014", + } + + // Setup test cases + testCases := []struct { + name string + input string + expected bool + }{ + { + name: "FullAssignment", + input: "1A2B3C", expected: true, + }, + { + name: "PartialAssignment", + input: "1A2B3", expected: true, + }, + { + name: "PartialAssignmentLowercase", + input: "1a2b3", expected: true, + }, + { + name: "FullOrganization", + input: "Banana, Inc.", expected: true, + }, + { + name: "PartialOrganizationLowercase", + input: "banana", expected: true, + }, + { + name: "FullAddress", + input: "1 Infinite Fruity Loop CA US 12014", expected: true, + }, + { + name: "PartialAddressUppercase", + input: "LOOP", expected: true, + }, + { + name: "PartialAddressLowercase", + input: "fruit", expected: true, + }, + { + name: "NotFound", + input: "111222", expected: false}, + } + + // Loop through the test cases + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + // Verify that the database was loaded correctly + if entry.Contains(testCase.input) != testCase.expected { + t.Errorf("expected %v, got %v", testCase.expected, entry.Contains(testCase.input)) + } + }) + } +}