diff --git a/bulkvalidate/bulkvalidate.go b/bulkvalidate/bulkvalidate.go deleted file mode 100644 index eea7e64..0000000 --- a/bulkvalidate/bulkvalidate.go +++ /dev/null @@ -1,239 +0,0 @@ -package bulkvalidate - -import ( - "bufio" - "encoding/csv" - "encoding/json" - "fmt" - "io" - "log" - "os" - "strconv" - "strings" - - "github.com/lucasepe/codename" - "github.com/schollz/progressbar/v3" - - "github.com/customeros/mailsherpa/internal/run" - "github.com/customeros/mailsherpa/internal/syntax" - "github.com/customeros/mailsherpa/mailvalidate" -) - -const ( - batchSize = 10 - checkpointFile = "validation_checkpoint.json" -) - -type Checkpoint struct { - ProcessedRows int `json:"processedRows"` -} - -func RunBulkValidation(inputFilePath, outputFilePath string) error { - checkpoint, err := loadCheckpoint() - if err != nil { - return fmt.Errorf("error loading checkpoint: %w", err) - } - - reader, file, err := read_csv(inputFilePath) - if err != nil { - return fmt.Errorf("error reading input file: %w", err) - } - defer file.Close() - - // Read and store the header - header, err := reader.Read() - if err != nil { - return fmt.Errorf("error reading header: %w", err) - } - - // Skip to the last processed row - for i := 0; i < checkpoint.ProcessedRows; i++ { - _, err := reader.Read() - if err != nil { - if err == io.EOF { - break - } - return fmt.Errorf("error skipping to checkpoint: %w", err) - } - } - - catchAllResults := make(map[string]bool) - bar := progressbar.Default(-1) - - outputFileExists := fileExists(outputFilePath) - - for { - batch, err := readBatch(reader, batchSize) - if err != nil { - return fmt.Errorf("error reading batch: %w", err) - } - if len(batch) == 0 { - break - } - - results := processBatch(batch, catchAllResults) - - err = writeResultsFile(results, outputFilePath, outputFileExists, header) - if err != nil { - return fmt.Errorf("error writing results: %w", err) - } - - checkpoint.ProcessedRows += len(batch) - err = saveCheckpoint(checkpoint) - if err != nil { - return fmt.Errorf("error saving checkpoint: %w", err) - } - - bar.Add(len(batch)) - outputFileExists = true - } - - return nil -} - -func read_csv(filePath string) (*csv.Reader, *os.File, error) { - file, err := os.Open(filePath) - if err != nil { - return nil, nil, fmt.Errorf("error opening file: %w", err) - } - - reader := csv.NewReader(bufio.NewReader(file)) - return reader, file, nil -} - -func writeResultsFile(results []run.VerifyEmailResponse, filePath string, append bool, header []string) error { - flag := os.O_CREATE | os.O_WRONLY - if append { - flag |= os.O_APPEND - } else { - flag |= os.O_TRUNC - } - - file, err := os.OpenFile(filePath, flag, 0644) - if err != nil { - return fmt.Errorf("error opening file: %w", err) - } - defer file.Close() - - writer := csv.NewWriter(file) - defer writer.Flush() - - if !append { - if err := writer.Write(header); err != nil { - return fmt.Errorf("error writing header: %w", err) - } - } - - for _, resp := range results { - row := []string{ - resp.Email, resp.Syntax.User, resp.Syntax.Domain, - strconv.FormatBool(resp.Syntax.IsValid), - resp.Deliverable, - resp.Provider, resp.SecureGatewayProvider, - strconv.FormatBool(resp.IsRisky), - strconv.FormatBool(resp.Risk.IsFirewalled), - strconv.FormatBool(resp.Risk.IsFreeAccount), - strconv.FormatBool(resp.Risk.IsRoleAccount), - strconv.FormatBool(resp.Risk.IsMailboxFull), - strconv.FormatBool(resp.IsCatchAll), - resp.Smtp.ResponseCode, resp.Smtp.ErrorCode, resp.Smtp.Description, - } - if err := writer.Write(row); err != nil { - return fmt.Errorf("error writing row: %w", err) - } - } - - return nil -} - -func readBatch(reader *csv.Reader, batchSize int) ([]string, error) { - var batch []string - for i := 0; i < batchSize; i++ { - record, err := reader.Read() - if err != nil { - if err == io.EOF { - // End of file reached - return batch, nil - } - if err == csv.ErrFieldCount { - // Skip records with incorrect field count - continue - } - // Return any other error - return batch, err - } - if len(record) > 0 { - batch = append(batch, record[0]) - } - } - return batch, nil -} - -func processBatch(batch []string, catchAllResults map[string]bool) []run.VerifyEmailResponse { - var results []run.VerifyEmailResponse - - for _, email := range batch { - request := run.BuildRequest(email) - _, domain, _ := syntax.GetEmailUserAndDomain(email) - validateCatchAll := false - if _, exists := catchAllResults[domain]; !exists { - validateCatchAll = true - } - syntaxResults := mailvalidate.ValidateEmailSyntax(email) - domainResults := mailvalidate.ValidateDomain(request) - if domainResults.Error != "" { - log.Println(domainResults.Error) - } - emailResults := mailvalidate.ValidateEmail(request) - if emailResults.Error != "" { - log.Println(domainResults.Error) - } - isCatchAll := domainResults.IsCatchAll - if validateCatchAll { - catchAllResults[domain] = isCatchAll - } - result := run.BuildResponse(email, syntaxResults, domainResults, emailResults) - results = append(results, result) - } - - return results -} - -func loadCheckpoint() (Checkpoint, error) { - var checkpoint Checkpoint - file, err := os.Open(checkpointFile) - if os.IsNotExist(err) { - return checkpoint, nil - } - if err != nil { - return checkpoint, err - } - defer file.Close() - - err = json.NewDecoder(file).Decode(&checkpoint) - return checkpoint, err -} - -func saveCheckpoint(checkpoint Checkpoint) error { - file, err := os.Create(checkpointFile) - if err != nil { - return err - } - defer file.Close() - - return json.NewEncoder(file).Encode(checkpoint) -} - -func generateCatchAllUsername() string { - rng, err := codename.DefaultRNG() - if err != nil { - panic(err) - } - name := codename.Generate(rng, 0) - return strings.ReplaceAll(name, "-", "") -} - -func fileExists(filename string) bool { - _, err := os.Stat(filename) - return !os.IsNotExist(err) -} diff --git a/domaincheck/domain.go b/domaincheck/domain.go new file mode 100644 index 0000000..45d33d5 --- /dev/null +++ b/domaincheck/domain.go @@ -0,0 +1,181 @@ +package domaincheck + +import ( + "fmt" + "net" + "net/http" + "net/url" + "sort" + "strings" + "time" +) + +type DNS struct { + MX []string + SPF string + CNAME string + HasA bool + Errors []string +} + +func PrimaryDomainCheck(domain string) (bool, string) { + dns := CheckDNS(domain) + redirects, primaryDomain := CheckRedirects(domain) + + if !redirects && dns.CNAME == "" && len(dns.MX) > 0 && dns.HasA { + return true, "" + } + return false, primaryDomain +} + +func CheckDNS(domain string) DNS { + var dns DNS + var mxErr, spfErr error + + dns.HasA = hasAorAAAARecord(domain) + + dns.MX, mxErr = getMXRecordsForDomain(domain) + dns.SPF, spfErr = getSPFRecord(domain) + if mxErr != nil { + dns.Errors = append(dns.Errors, mxErr.Error()) + } + if spfErr != nil { + dns.Errors = append(dns.Errors, spfErr.Error()) + } + + exists, cname := getCNAMERecord(domain) + if exists { + dns.CNAME = cname + } + return dns +} + +func CheckRedirects(domain string) (bool, string) { + // Check for HTTP/HTTPS redirects + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + Timeout: 10 * time.Second, + } + + for _, protocol := range []string{"http", "https"} { + url := fmt.Sprintf("%s://%s", protocol, domain) + resp, err := client.Get(url) + if err != nil { + continue + } + defer resp.Body.Close() + + if resp.StatusCode >= 300 && resp.StatusCode < 400 { + location := resp.Header.Get("Location") + if location != "" { + location = extractDomain(location) + if location != domain { + return true, location + } + } + } + } + + return false, "" +} + +func extractDomain(urlStr string) string { + u, err := url.Parse(urlStr) + if err != nil { + return urlStr // Return as-is if parsing fails + } + + // Remove 'www.' prefix if present + domain := strings.TrimPrefix(u.Hostname(), "www.") + + // Split the domain and get the last two parts (or just one if it's a TLD) + parts := strings.Split(domain, ".") + if len(parts) > 2 { + return strings.Join(parts[len(parts)-2:], ".") + } + return domain +} + +func getMXRecordsForDomain(domain string) ([]string, error) { + mxRecords, err := getRawMXRecords(domain) + if err != nil { + return nil, err + } + + // Sort MX records by priority (lower number = higher priority) + sort.Slice(mxRecords, func(i, j int) bool { + return mxRecords[i].Pref < mxRecords[j].Pref + }) + + stripDot := func(s string) string { + return strings.ToLower(strings.TrimSuffix(s, ".")) + } + + // Extract hostnames into a string array + result := make([]string, len(mxRecords)) + for i, mx := range mxRecords { + result[i] = stripDot(mx.Host) + } + + return result, nil +} + +func getRawMXRecords(domain string) ([]*net.MX, error) { + mxRecords, err := net.LookupMX(domain) + if err != nil { + return nil, err + } + + return mxRecords, nil +} + +func getSPFRecord(domain string) (string, error) { + records, err := net.LookupTXT(domain) + if err != nil { + return "", fmt.Errorf("error looking up TXT records: %w", err) + } + for _, record := range records { + spfRecord := parseTXTRecord(record) + if strings.HasPrefix(spfRecord, "v=spf1") { + return spfRecord, nil + } + } + return "", fmt.Errorf("no SPF record found for domain %s", domain) +} + +func getCNAMERecord(domain string) (bool, string) { + cname, err := net.LookupCNAME(domain) + if err != nil { + return false, "" + } + + // Remove the trailing dot from the CNAME if present + cname = strings.TrimSuffix(cname, ".") + + // Check if the CNAME is different from the input domain + if cname != domain && cname != domain+"." { + return true, cname + } + + return false, "" +} + +func hasAorAAAARecord(domain string) bool { + ips, err := net.LookupIP(domain) + if err != nil { + return false + } + return len(ips) > 0 +} + +func parseTXTRecord(record string) string { + // Remove surrounding quotes if present + record = strings.Trim(record, "\"") + + // Replace multiple spaces with a single space + record = strings.Join(strings.Fields(record), " ") + + return record +} diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index fa7fe40..a36b8d8 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" - "github.com/customeros/mailsherpa/bulkvalidate" "github.com/customeros/mailsherpa/internal/run" "github.com/customeros/mailsherpa/mailvalidate" ) @@ -15,16 +14,11 @@ func PrintUsage() { fmt.Println("Usage: mailsherpa [arguments]") fmt.Println("Commands:") fmt.Println(" ") - fmt.Println(" bulk ") fmt.Println(" domain ") fmt.Println(" syntax ") fmt.Println(" version") } -func BulkVerify(inputFilePath, outputFilePath string) error { - return bulkvalidate.RunBulkValidation(inputFilePath, outputFilePath) -} - func VerifyDomain(domain string, printResults bool) mailvalidate.DomainValidation { request := run.BuildRequest(fmt.Sprintf("user@%s", domain)) domainResults := mailvalidate.ValidateDomain(request) diff --git a/internal/dns/authorizedSenders.go b/internal/dns/authorizedSenders.go index 5307dea..84b671c 100644 --- a/internal/dns/authorizedSenders.go +++ b/internal/dns/authorizedSenders.go @@ -3,6 +3,8 @@ package dns import ( "regexp" "strings" + + "github.com/customeros/mailsherpa/domaincheck" ) type AuthorizedSenders struct { @@ -13,7 +15,7 @@ type AuthorizedSenders struct { Other []string } -func GetAuthorizedSenders(dns DNS, knownProviders *KnownProviders) AuthorizedSenders { +func GetAuthorizedSenders(dns domaincheck.DNS, knownProviders *KnownProviders) AuthorizedSenders { if dns.SPF == "" { return AuthorizedSenders{} } diff --git a/internal/dns/dns.go b/internal/dns/dns.go index 921ab2c..790e79d 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -2,129 +2,22 @@ package dns import ( "fmt" - "net" - "sort" - "strings" + "github.com/customeros/mailsherpa/domaincheck" "github.com/customeros/mailsherpa/internal/syntax" ) -type DNS struct { - MX []string - SPF string - CNAME string - HasA bool - Errors []string -} - -func GetDNS(email string) DNS { - var dns DNS - var mxErr error - var spfErr error +func GetDNS(email string) domaincheck.DNS { + var dns domaincheck.DNS _, domain, ok := syntax.GetEmailUserAndDomain(email) if !ok { - mxErr = fmt.Errorf("No MX Records: Invalid email address") + mxErr := fmt.Errorf("No MX Records: Invalid email address") dns.Errors = append(dns.Errors, mxErr.Error()) return dns } - dns.HasA = hasAorAAAARecord(domain) - - dns.MX, mxErr = getMXRecordsForDomain(domain) - dns.SPF, spfErr = getSPFRecord(domain) - if mxErr != nil { - dns.Errors = append(dns.Errors, mxErr.Error()) - } - if spfErr != nil { - dns.Errors = append(dns.Errors, spfErr.Error()) - } + dns = domaincheck.CheckDNS(domain) - exists, cname := getCNAMERecord(domain) - if exists { - dns.CNAME = cname - } return dns } - -func getMXRecordsForDomain(domain string) ([]string, error) { - mxRecords, err := getRawMXRecords(domain) - if err != nil { - return nil, err - } - - // Sort MX records by priority (lower number = higher priority) - sort.Slice(mxRecords, func(i, j int) bool { - return mxRecords[i].Pref < mxRecords[j].Pref - }) - - stripDot := func(s string) string { - return strings.ToLower(strings.TrimSuffix(s, ".")) - } - - // Extract hostnames into a string array - result := make([]string, len(mxRecords)) - for i, mx := range mxRecords { - result[i] = stripDot(mx.Host) - } - - return result, nil -} - -func getRawMXRecords(domain string) ([]*net.MX, error) { - mxRecords, err := net.LookupMX(domain) - if err != nil { - return nil, err - } - - return mxRecords, nil -} - -func getSPFRecord(domain string) (string, error) { - records, err := net.LookupTXT(domain) - if err != nil { - return "", fmt.Errorf("error looking up TXT records: %w", err) - } - for _, record := range records { - spfRecord := parseTXTRecord(record) - if strings.HasPrefix(spfRecord, "v=spf1") { - return spfRecord, nil - } - } - return "", fmt.Errorf("no SPF record found for domain %s", domain) -} - -func parseTXTRecord(record string) string { - // Remove surrounding quotes if present - record = strings.Trim(record, "\"") - - // Replace multiple spaces with a single space - record = strings.Join(strings.Fields(record), " ") - - return record -} - -func getCNAMERecord(domain string) (bool, string) { - cname, err := net.LookupCNAME(domain) - if err != nil { - return false, "" - } - - // Remove the trailing dot from the CNAME if present - cname = strings.TrimSuffix(cname, ".") - - // Check if the CNAME is different from the input domain - if cname != domain && cname != domain+"." { - return true, cname - } - - return false, "" -} - -func hasAorAAAARecord(domain string) bool { - ips, err := net.LookupIP(domain) - if err != nil { - return false - } - return len(ips) > 0 -} diff --git a/internal/dns/emailProvider.go b/internal/dns/emailProvider.go index 1b37df6..046ba08 100644 --- a/internal/dns/emailProvider.go +++ b/internal/dns/emailProvider.go @@ -1,6 +1,8 @@ package dns -func GetEmailProviderFromMx(dns DNS, knownProviders KnownProviders) (emailProvider, firewall string) { +import "github.com/customeros/mailsherpa/domaincheck" + +func GetEmailProviderFromMx(dns domaincheck.DNS, knownProviders KnownProviders) (emailProvider, firewall string) { if len(dns.MX) == 0 { return "", "" } diff --git a/internal/mailserver/mailserver.go b/internal/mailserver/mailserver.go index c490086..def046d 100644 --- a/internal/mailserver/mailserver.go +++ b/internal/mailserver/mailserver.go @@ -11,7 +11,7 @@ import ( "github.com/pkg/errors" - "github.com/customeros/mailsherpa/internal/dns" + "github.com/customeros/mailsherpa/domaincheck" ) type SMPTValidation struct { @@ -23,7 +23,7 @@ type SMPTValidation struct { SmtpResponse string } -func VerifyEmailAddress(email, fromDomain, fromEmail string, dnsRecords dns.DNS) SMPTValidation { +func VerifyEmailAddress(email, fromDomain, fromEmail string, dnsRecords domaincheck.DNS) SMPTValidation { results := SMPTValidation{} if len(dnsRecords.MX) == 0 { diff --git a/mailvalidate/validation.go b/mailvalidate/validation.go index 84d12b0..4543203 100644 --- a/mailvalidate/validation.go +++ b/mailvalidate/validation.go @@ -3,14 +3,13 @@ package mailvalidate import ( "fmt" "log" - "net/http" - "net/url" "strings" "time" "github.com/pkg/errors" "github.com/rdegges/go-ipify" + "github.com/customeros/mailsherpa/domaincheck" "github.com/customeros/mailsherpa/internal/dns" "github.com/customeros/mailsherpa/internal/mailserver" "github.com/customeros/mailsherpa/internal/syntax" @@ -21,7 +20,7 @@ type EmailValidationRequest struct { FromDomain string FromEmail string CatchAllTestUser string - Dns *dns.DNS + Dns *domaincheck.DNS // applicable only for email validation. Pass results from domain validation DomainValidationParams *DomainValidationParams } @@ -127,7 +126,13 @@ func ValidateDomainWithCustomKnownProviders(validationRequest EmailValidationReq evaluateDnsRecords(&validationRequest, &knownProviders, &results) - redirects, primaryDomain := checkRedirects(validationRequest.Email) + _, domain, ok := syntax.GetEmailUserAndDomain(validationRequest.Email) + if !ok { + results.Error = fmt.Sprintf("Invalid Email Address") + return results + } + + redirects, primaryDomain := domaincheck.CheckRedirects(domain) if !redirects && validationRequest.Dns.CNAME == "" && results.HasMXRecord && validationRequest.Dns.HasA { results.IsPrimaryDomain = true } else { @@ -473,54 +478,3 @@ func getRetryTimestamp(minutesDelay int) int { retryTimestamp := time.Unix(currentEpochTime, 0).Add(time.Duration(minutesDelay) * time.Minute).Unix() return int(retryTimestamp) } - -func checkRedirects(email string) (bool, string) { - - _, domain, _ := syntax.GetEmailUserAndDomain(email) - - // Check for HTTP/HTTPS redirects - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - Timeout: 10 * time.Second, - } - - for _, protocol := range []string{"http", "https"} { - url := fmt.Sprintf("%s://%s", protocol, domain) - resp, err := client.Get(url) - if err != nil { - continue - } - defer resp.Body.Close() - - if resp.StatusCode >= 300 && resp.StatusCode < 400 { - location := resp.Header.Get("Location") - if location != "" { - location = extractDomain(location) - if location != domain { - return true, location - } - } - } - } - - return false, "" -} - -func extractDomain(urlStr string) string { - u, err := url.Parse(urlStr) - if err != nil { - return urlStr // Return as-is if parsing fails - } - - // Remove 'www.' prefix if present - domain := strings.TrimPrefix(u.Hostname(), "www.") - - // Split the domain and get the last two parts (or just one if it's a TLD) - parts := strings.Split(domain, ".") - if len(parts) > 2 { - return strings.Join(parts[len(parts)-2:], ".") - } - return domain -} diff --git a/main.go b/main.go index a47040b..c655cfa 100644 --- a/main.go +++ b/main.go @@ -12,12 +12,6 @@ func main() { args := flag.Args() switch args[0] { - case "bulk": - if len(args) != 3 { - fmt.Println("Usage: mailsherpa bulk ") - return - } - cmd.BulkVerify(args[1], args[2]) case "domain": if len(args) != 2 { fmt.Println("Usage: mailsherpa domain ")