diff --git a/cmd/process-agent/subcommands/check/check.go b/cmd/process-agent/subcommands/check/check.go index d91447f346d19..be338e8df2252 100644 --- a/cmd/process-agent/subcommands/check/check.go +++ b/cmd/process-agent/subcommands/check/check.go @@ -182,10 +182,12 @@ func RunCheckCmd(deps Dependencies) error { names = append(names, ch.Name()) _, processModuleEnabled := deps.Syscfg.SysProbeObject().EnabledModules[sysconfig.ProcessModule] + _, networkTracerModuleEnabled := deps.Syscfg.SysProbeObject().EnabledModules[sysconfig.NetworkTracerModule] cfg := &checks.SysProbeConfig{ - MaxConnsPerMessage: deps.Syscfg.SysProbeObject().MaxConnsPerMessage, - SystemProbeAddress: deps.Syscfg.SysProbeObject().SocketAddress, - ProcessModuleEnabled: processModuleEnabled, + MaxConnsPerMessage: deps.Syscfg.SysProbeObject().MaxConnsPerMessage, + SystemProbeAddress: deps.Syscfg.SysProbeObject().SocketAddress, + ProcessModuleEnabled: processModuleEnabled, + NetworkTracerModuleEnabled: networkTracerModuleEnabled, } if !matchingCheck(deps.CliParams.checkName, ch) { diff --git a/cmd/system-probe/modules/network_tracer.go b/cmd/system-probe/modules/network_tracer.go index 4852575b36a75..44f2af55d9c78 100644 --- a/cmd/system-probe/modules/network_tracer.go +++ b/cmd/system-probe/modules/network_tracer.go @@ -12,6 +12,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "os" "runtime" @@ -108,6 +109,16 @@ func (nt *networkTracer) Register(httpMux *module.Router) error { logRequests(id, count, len(cs.Conns), start) })) + httpMux.HandleFunc("/network_id", utils.WithConcurrencyLimit(utils.DefaultMaxConcurrentRequests, func(w http.ResponseWriter, req *http.Request) { + id, err := nt.tracer.GetNetworkID(req.Context()) + if err != nil { + log.Errorf("unable to retrieve network_id: %s", err) + w.WriteHeader(500) + return + } + io.WriteString(w, id) + })) + httpMux.HandleFunc("/register", utils.WithConcurrencyLimit(utils.DefaultMaxConcurrentRequests, func(w http.ResponseWriter, req *http.Request) { id := getClientID(req) err := nt.tracer.RegisterClient(id) diff --git a/pkg/network/tracer/tracer.go b/pkg/network/tracer/tracer.go index fded4ce4d70c0..5017167091fca 100644 --- a/pkg/network/tracer/tracer.go +++ b/pkg/network/tracer/tracer.go @@ -40,6 +40,7 @@ import ( "github.com/DataDog/datadog-agent/pkg/process/util" timeresolver "github.com/DataDog/datadog-agent/pkg/security/resolvers/time" "github.com/DataDog/datadog-agent/pkg/telemetry" + "github.com/DataDog/datadog-agent/pkg/util/ec2" "github.com/DataDog/datadog-agent/pkg/util/kernel" "github.com/DataDog/datadog-agent/pkg/util/log" ) @@ -850,3 +851,17 @@ func newUSMMonitor(c *config.Config, tracer connection.Tracer) *usm.Monitor { return monitor } + +// GetNetworkID retrieves the vpc_id (network_id) from IMDS +func (t *Tracer) GetNetworkID(context context.Context) (string, error) { + id := "" + err := kernel.WithRootNS(kernel.ProcFSRoot(), func() error { + var err error + id, err = ec2.GetNetworkID(context) + return err + }) + if err != nil { + return "", err + } + return id, nil +} diff --git a/pkg/network/tracer/tracer_unsupported.go b/pkg/network/tracer/tracer_unsupported.go index bdb6abdf3dbf5..f3ef15179c0b7 100644 --- a/pkg/network/tracer/tracer_unsupported.go +++ b/pkg/network/tracer/tracer_unsupported.go @@ -34,6 +34,11 @@ func (t *Tracer) GetActiveConnections(_ string) (*network.Connections, error) { return nil, ebpf.ErrNotImplemented } +// GetNetworkID is not implemented on this OS for Tracer +func (t *Tracer) GetNetworkID(_ context.Context) (string, error) { + return "", ebpf.ErrNotImplemented +} + // RegisterClient registers the client func (t *Tracer) RegisterClient(_ string) error { return ebpf.ErrNotImplemented diff --git a/pkg/network/tracer/tracer_windows.go b/pkg/network/tracer/tracer_windows.go index a4677a19c501a..fba6ea78a95b0 100644 --- a/pkg/network/tracer/tracer_windows.go +++ b/pkg/network/tracer/tracer_windows.go @@ -309,6 +309,11 @@ func (t *Tracer) DebugDumpProcessCache(_ context.Context) (interface{}, error) { return nil, ebpf.ErrNotImplemented } +// GetNetworkID is not implemented on this OS for Tracer +func (t *Tracer) GetNetworkID(_ context.Context) (string, error) { + return "", ebpf.ErrNotImplemented +} + func newUSMMonitor(c *config.Config, dh driver.Handle) usm.Monitor { if !c.EnableHTTPMonitoring && !c.EnableNativeTLSMonitoring { return nil diff --git a/pkg/process/checks/checks.go b/pkg/process/checks/checks.go index 139abb92720dd..b7aabbd8c794e 100644 --- a/pkg/process/checks/checks.go +++ b/pkg/process/checks/checks.go @@ -35,6 +35,8 @@ type SysProbeConfig struct { SystemProbeAddress string // System probe process module on/off configuration ProcessModuleEnabled bool + // System probe network_tracer module on/off configuration + NetworkTracerModuleEnabled bool } // Check is an interface for Agent checks that collect data. Each check returns diff --git a/pkg/process/checks/container.go b/pkg/process/checks/container.go index de3e40fd00b43..1e4187d46391c 100644 --- a/pkg/process/checks/container.go +++ b/pkg/process/checks/container.go @@ -6,7 +6,6 @@ package checks import ( - "context" "fmt" "math" "sync" @@ -16,9 +15,9 @@ import ( workloadmeta "github.com/DataDog/datadog-agent/comp/core/workloadmeta/def" ddconfig "github.com/DataDog/datadog-agent/pkg/config" + "github.com/DataDog/datadog-agent/pkg/process/net" "github.com/DataDog/datadog-agent/pkg/process/statsd" proccontainers "github.com/DataDog/datadog-agent/pkg/process/util/containers" - "github.com/DataDog/datadog-agent/pkg/util/cloudproviders" "github.com/DataDog/datadog-agent/pkg/util/flavor" "github.com/DataDog/datadog-agent/pkg/util/log" ) @@ -53,11 +52,21 @@ type ContainerCheck struct { } // Init initializes a ContainerCheck instance. -func (c *ContainerCheck) Init(_ *SysProbeConfig, info *HostInfo, _ bool) error { +func (c *ContainerCheck) Init(syscfg *SysProbeConfig, info *HostInfo, _ bool) error { c.containerProvider = proccontainers.GetSharedContainerProvider(c.wmeta) c.hostInfo = info - networkID, err := cloudproviders.GetNetworkID(context.TODO()) + var tu *net.RemoteSysProbeUtil + var err error + if syscfg.NetworkTracerModuleEnabled { + // Calling the remote tracer will cause it to initialize and check connectivity + tu, err = net.GetRemoteSystemProbeUtil(syscfg.SystemProbeAddress) + if err != nil { + log.Warnf("could not initiate connection with system probe: %s", err) + } + } + + networkID, err := retryGetNetworkID(tu) if err != nil { log.Infof("no network ID detected: %s", err) } diff --git a/pkg/process/checks/net.go b/pkg/process/checks/net.go index 26e01d0677061..5396fdd26c10f 100644 --- a/pkg/process/checks/net.go +++ b/pkg/process/checks/net.go @@ -107,7 +107,7 @@ func (c *ConnectionsCheck) Init(syscfg *SysProbeConfig, hostInfo *HostInfo, _ bo } } - networkID, err := cloudproviders.GetNetworkID(context.TODO()) + networkID, err := retryGetNetworkID(tu) if err != nil { log.Infof("no network ID detected: %s", err) } @@ -503,3 +503,17 @@ func convertAndEnrichWithServiceCtx(tags []string, tagOffsets []uint32, serviceC return tagsStr } + +// fetches network_id from the current netNS or from the system probe if necessary, where the root netNS is used +func retryGetNetworkID(sysProbeUtil *net.RemoteSysProbeUtil) (string, error) { + networkID, err := cloudproviders.GetNetworkID(context.TODO()) + if err != nil && sysProbeUtil != nil { + log.Infof("no network ID detected. retrying via system-probe: %s", err) + networkID, err = sysProbeUtil.GetNetworkID() + if err != nil { + log.Infof("failed to get network ID from system-probe: %s", err) + return "", err + } + } + return networkID, err +} diff --git a/pkg/process/checks/process.go b/pkg/process/checks/process.go index f35e71704d1cb..26685da9e15aa 100644 --- a/pkg/process/checks/process.go +++ b/pkg/process/checks/process.go @@ -6,7 +6,6 @@ package checks import ( - "context" "errors" "fmt" "math" @@ -28,7 +27,6 @@ import ( "github.com/DataDog/datadog-agent/pkg/process/statsd" "github.com/DataDog/datadog-agent/pkg/process/util" proccontainers "github.com/DataDog/datadog-agent/pkg/process/util/containers" - "github.com/DataDog/datadog-agent/pkg/util/cloudproviders" "github.com/DataDog/datadog-agent/pkg/util/flavor" "github.com/DataDog/datadog-agent/pkg/util/log" "github.com/DataDog/datadog-agent/pkg/util/subscriptions" @@ -137,7 +135,17 @@ func (p *ProcessCheck) Init(syscfg *SysProbeConfig, info *HostInfo, oneShot bool p.notInitializedLogLimit = log.NewLogLimit(1, time.Minute*10) - networkID, err := cloudproviders.GetNetworkID(context.TODO()) + var tu *net.RemoteSysProbeUtil + var err error + if syscfg.NetworkTracerModuleEnabled { + // Calling the remote tracer will cause it to initialize and check connectivity + tu, err = net.GetRemoteSystemProbeUtil(syscfg.SystemProbeAddress) + if err != nil { + log.Warnf("could not initiate connection with system probe: %s", err) + } + } + + networkID, err := retryGetNetworkID(tu) if err != nil { log.Infof("no network ID detected: %s", err) } diff --git a/pkg/process/net/common.go b/pkg/process/net/common.go index a9b7a64430143..640c3e82dadd2 100644 --- a/pkg/process/net/common.go +++ b/pkg/process/net/common.go @@ -44,6 +44,7 @@ type Conn interface { const ( contentTypeProtobuf = "application/protobuf" + contentTypeJSON = "application/json" ) var ( @@ -166,6 +167,32 @@ func (r *RemoteSysProbeUtil) GetConnections(clientID string) (*model.Connections return conns, nil } +// GetNetworkID fetches the network_id (vpc_id) from system-probe +func (r *RemoteSysProbeUtil) GetNetworkID() (string, error) { + req, err := http.NewRequest("GET", networkIDURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "text/plain") + resp, err := r.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("network_id request failed: url: %s, status code: %d", networkIDURL, resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body: %w", err) + } + + return string(body), nil +} + // GetPing returns the results of a ping to a host func (r *RemoteSysProbeUtil) GetPing(clientID string, host string, count int, interval time.Duration, timeout time.Duration) ([]byte, error) { req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s?client_id=%s&count=%d&interval=%d&timeout=%d", pingURL, host, clientID, count, interval, timeout), nil) @@ -173,7 +200,7 @@ func (r *RemoteSysProbeUtil) GetPing(clientID string, host string, count int, in return nil, err } - req.Header.Set("Accept", "application/json") + req.Header.Set("Accept", contentTypeJSON) resp, err := r.httpClient.Do(req) if err != nil { return nil, err @@ -208,7 +235,7 @@ func (r *RemoteSysProbeUtil) GetTraceroute(clientID string, host string, port ui return nil, err } - req.Header.Set("Accept", "application/json") + req.Header.Set("Accept", contentTypeJSON) resp, err := r.tracerouteClient.Do(req) if err != nil { return nil, err diff --git a/pkg/process/net/common_linux.go b/pkg/process/net/common_linux.go index 2dc5c7db28c8d..7fee3ffdb1cb9 100644 --- a/pkg/process/net/common_linux.go +++ b/pkg/process/net/common_linux.go @@ -18,6 +18,7 @@ const ( pingURL = "http://unix/" + string(sysconfig.PingModule) + "/ping/" tracerouteURL = "http://unix/" + string(sysconfig.TracerouteModule) + "/traceroute/" connectionsURL = "http://unix/" + string(sysconfig.NetworkTracerModule) + "/connections" + networkIDURL = "http://unix/" + string(sysconfig.NetworkTracerModule) + "/network_id" procStatsURL = "http://unix/" + string(sysconfig.ProcessModule) + "/stats" registerURL = "http://unix/" + string(sysconfig.NetworkTracerModule) + "/register" statsURL = "http://unix/debug/stats" diff --git a/pkg/process/net/common_unsupported.go b/pkg/process/net/common_unsupported.go index 03a481a2de400..ebdea5968e5bb 100644 --- a/pkg/process/net/common_unsupported.go +++ b/pkg/process/net/common_unsupported.go @@ -40,6 +40,11 @@ func (r *RemoteSysProbeUtil) GetConnections(_ string) (*model.Connections, error return nil, ErrNotImplemented } +// GetNetworkID is not supported +func (r *RemoteSysProbeUtil) GetNetworkID() (string, error) { + return "", ErrNotImplemented +} + // GetStats is not supported func (r *RemoteSysProbeUtil) GetStats() (map[string]interface{}, error) { return nil, ErrNotImplemented diff --git a/pkg/process/net/common_windows.go b/pkg/process/net/common_windows.go index 4ad0d218e65f5..83d8440825e4a 100644 --- a/pkg/process/net/common_windows.go +++ b/pkg/process/net/common_windows.go @@ -15,6 +15,7 @@ import ( const ( connectionsURL = "http://localhost:3333/" + string(sysconfig.NetworkTracerModule) + "/connections" + networkIDURL = "http://unix/" + string(sysconfig.NetworkTracerModule) + "/network_id" registerURL = "http://localhost:3333/" + string(sysconfig.NetworkTracerModule) + "/register" languageDetectionURL = "http://localhost:3333/" + string(sysconfig.LanguageDetectionModule) + "/detect" statsURL = "http://localhost:3333/debug/stats" diff --git a/pkg/process/net/mocks/sys_probe_util.go b/pkg/process/net/mocks/sys_probe_util.go index 3bf0b2c1d7270..0d0af5300fa4f 100644 --- a/pkg/process/net/mocks/sys_probe_util.go +++ b/pkg/process/net/mocks/sys_probe_util.go @@ -43,6 +43,34 @@ func (_m *SysProbeUtil) GetConnections(clientID string) (*process.Connections, e return r0, r1 } +// GetNetworkID provides a mock function with given fields: +func (_m *SysProbeUtil) GetNetworkID() (string, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetNetworkID") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func() (string, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetProcStats provides a mock function with given fields: pids func (_m *SysProbeUtil) GetProcStats(pids []int32) (*process.ProcStatsWithPermByPID, error) { ret := _m.Called(pids) diff --git a/pkg/process/net/shared.go b/pkg/process/net/shared.go index 72a6e418865c6..a0a7aa18ae327 100644 --- a/pkg/process/net/shared.go +++ b/pkg/process/net/shared.go @@ -13,4 +13,5 @@ type SysProbeUtil interface { GetStats() (map[string]interface{}, error) GetProcStats(pids []int32) (*model.ProcStatsWithPermByPID, error) Register(clientID string) error + GetNetworkID() (string, error) } diff --git a/pkg/util/cloudproviders/network.go b/pkg/util/cloudproviders/network.go index 12c7496579c3f..c183ea96ce06f 100644 --- a/pkg/util/cloudproviders/network.go +++ b/pkg/util/cloudproviders/network.go @@ -30,7 +30,7 @@ func GetNetworkID(ctx context.Context) (string, error) { return cache.Get[string]( networkIDCacheKey, func() (string, error) { - // the the id from configuration + // the id from configuration if networkID := config.Datadog().GetString("network.id"); networkID != "" { log.Debugf("GetNetworkID: using configured network ID: %s", networkID) return networkID, nil diff --git a/pkg/util/ec2/ec2_test.go b/pkg/util/ec2/ec2_test.go index d92242dada46f..cbb660df50c8a 100644 --- a/pkg/util/ec2/ec2_test.go +++ b/pkg/util/ec2/ec2_test.go @@ -30,6 +30,8 @@ var ( initialTokenURL = tokenURL ) +const testIMDSToken = "AQAAAFKw7LyqwVmmBMkqXHpDBuDWw2GnfGswTHi2yiIOGvzD7OMaWw==" + func resetPackageVars() { config.Datadog().SetWithoutSource("ec2_metadata_timeout", initialTimeout) metadataURL = initialMetadataURL @@ -301,12 +303,11 @@ func TestExtractClusterName(t *testing.T) { func TestGetToken(t *testing.T) { ctx := context.Background() - originalToken := "AQAAAFKw7LyqwVmmBMkqXHpDBuDWw2GnfGswTHi2yiIOGvzD7OMaWw==" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") h := r.Header.Get("X-aws-ec2-metadata-token-ttl-seconds") if h != "" && r.Method == http.MethodPut { - io.WriteString(w, originalToken) + io.WriteString(w, testIMDSToken) } else { w.WriteHeader(http.StatusNotFound) } @@ -319,7 +320,7 @@ func TestGetToken(t *testing.T) { token, err := token.Get(ctx) require.NoError(t, err) - assert.Equal(t, originalToken, token) + assert.Equal(t, testIMDSToken, token) } func TestMetedataRequestWithToken(t *testing.T) { @@ -331,7 +332,6 @@ func TestMetedataRequestWithToken(t *testing.T) { ctx := context.Background() ipv4 := "198.51.100.1" - tok := "AQAAAFKw7LyqwVmmBMkqXHpDBuDWw2GnfGswTHi2yiIOGvzD7OMaWw==" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") @@ -345,11 +345,11 @@ func TestMetedataRequestWithToken(t *testing.T) { r.Header.Add("X-sequence", fmt.Sprintf("%v", seq)) seq++ requestForToken = r - io.WriteString(w, tok) + io.WriteString(w, testIMDSToken) case http.MethodGet: // Should be a metadata request t := r.Header.Get("X-aws-ec2-metadata-token") - if t != tok { + if t != testIMDSToken { r.Header.Add("X-sequence", fmt.Sprintf("%v", seq)) seq++ requestWithoutToken = r @@ -386,7 +386,7 @@ func TestMetedataRequestWithToken(t *testing.T) { assert.Equal(t, fmt.Sprint(config.Datadog().GetInt("ec2_metadata_token_lifetime")), requestForToken.Header.Get("X-aws-ec2-metadata-token-ttl-seconds")) assert.Equal(t, http.MethodPut, requestForToken.Method) assert.Equal(t, "/", requestForToken.RequestURI) - assert.Equal(t, tok, requestWithToken.Header.Get("X-aws-ec2-metadata-token")) + assert.Equal(t, testIMDSToken, requestWithToken.Header.Get("X-aws-ec2-metadata-token")) assert.Equal(t, "/public-ipv4", requestWithToken.RequestURI) assert.Equal(t, http.MethodGet, requestWithToken.Method) @@ -515,7 +515,7 @@ func TestMetadataSourceIMDS(t *testing.T) { w.Header().Set("Content-Type", "text/plain") switch r.Method { case http.MethodPut: // token request - io.WriteString(w, "AQAAAFKw7LyqwVmmBMkqXHpDBuDWw2GnfGswTHi2yiIOGvzD7OMaWw==") + io.WriteString(w, testIMDSToken) case http.MethodGet: // metadata request switch r.RequestURI { case "/hostname": diff --git a/pkg/util/ec2/imds_helpers.go b/pkg/util/ec2/imds_helpers.go index 510fad39f43c4..afc2ef22fffbd 100644 --- a/pkg/util/ec2/imds_helpers.go +++ b/pkg/util/ec2/imds_helpers.go @@ -77,7 +77,7 @@ func doHTTPRequest(ctx context.Context, url string, forceIMDSv2 bool) (string, e tokenValue, err := token.Get(ctx) if err != nil { if forceIMDSv2 { - return "", fmt.Errorf("Could not fetch token from IMDSv2") + return "", fmt.Errorf("could not fetch token from IMDSv2") } log.Warnf("ec2_prefer_imdsv2 is set to true in the configuration but the agent was unable to proceed: %s", err) } else { diff --git a/pkg/util/ec2/network.go b/pkg/util/ec2/network.go index 5fafa6bed62d7..a7fa4730513a7 100644 --- a/pkg/util/ec2/network.go +++ b/pkg/util/ec2/network.go @@ -30,9 +30,9 @@ func GetPublicIPv4(ctx context.Context) (string, error) { var networkIDFetcher = cachedfetch.Fetcher{ Name: "VPC IDs", Attempt: func(ctx context.Context) (interface{}, error) { - resp, err := getMetadataItem(ctx, imdsNetworkMacs, false) + resp, err := getMetadataItem(ctx, imdsNetworkMacs, true) if err != nil { - return "", err + return "", fmt.Errorf("EC2: GetNetworkID failed to get mac addresses: %w", err) } macs := strings.Split(strings.TrimSpace(resp), "\n") @@ -43,9 +43,9 @@ var networkIDFetcher = cachedfetch.Fetcher{ continue } mac = strings.TrimSuffix(mac, "/") - id, err := getMetadataItem(ctx, fmt.Sprintf("%s/%s/vpc-id", imdsNetworkMacs, mac), false) + id, err := getMetadataItem(ctx, fmt.Sprintf("%s/%s/vpc-id", imdsNetworkMacs, mac), true) if err != nil { - return "", err + return "", fmt.Errorf("EC2: GetNetworkID failed to get vpc id for mac %s: %w", mac, err) } vpcIDs.Add(id) } diff --git a/pkg/util/ec2/network_test.go b/pkg/util/ec2/network_test.go index 7fa773b41b888..1e4ca0bc36b42 100644 --- a/pkg/util/ec2/network_test.go +++ b/pkg/util/ec2/network_test.go @@ -23,11 +23,16 @@ func TestGetPublicIPv4(t *testing.T) { ip := "10.0.0.2" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") - switch r.RequestURI { - case "/public-ipv4": - io.WriteString(w, ip) - default: - w.WriteHeader(http.StatusNotFound) + switch r.Method { + case http.MethodPut: // token request + io.WriteString(w, testIMDSToken) + case http.MethodGet: // metadata request + switch r.RequestURI { + case "/public-ipv4": + io.WriteString(w, ip) + default: + w.WriteHeader(http.StatusNotFound) + } } })) @@ -47,18 +52,24 @@ func TestGetNetworkID(t *testing.T) { vpc := "vpc-12345" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") - switch r.RequestURI { - case "/network/interfaces/macs": - io.WriteString(w, mac+"/") - case "/network/interfaces/macs/00:00:00:00:00/vpc-id": - io.WriteString(w, vpc) - default: - w.WriteHeader(http.StatusNotFound) + switch r.Method { + case http.MethodPut: // token request + io.WriteString(w, testIMDSToken) + case http.MethodGet: // metadata request + switch r.RequestURI { + case "/network/interfaces/macs": + io.WriteString(w, mac+"/") + case "/network/interfaces/macs/00:00:00:00:00/vpc-id": + io.WriteString(w, vpc) + default: + w.WriteHeader(http.StatusNotFound) + } } })) defer ts.Close() metadataURL = ts.URL + tokenURL = ts.URL config.Datadog().SetWithoutSource("ec2_metadata_timeout", 1000) defer resetPackageVars() @@ -69,18 +80,25 @@ func TestGetNetworkID(t *testing.T) { func TestGetInstanceIDNoMac(t *testing.T) { ctx := context.Background() - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - io.WriteString(w, "") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + switch r.Method { + case http.MethodPut: // token request + io.WriteString(w, testIMDSToken) + case http.MethodGet: // metadata request + io.WriteString(w, "") + } })) defer ts.Close() metadataURL = ts.URL + tokenURL = ts.URL config.Datadog().SetWithoutSource("ec2_metadata_timeout", 1000) defer resetPackageVars() _, err := GetNetworkID(ctx) require.Error(t, err) - assert.Contains(t, err.Error(), "no mac addresses returned") + assert.Contains(t, err.Error(), "EC2: GetNetworkID no mac addresses returned") } func TestGetInstanceIDMultipleVPC(t *testing.T) { @@ -91,21 +109,27 @@ func TestGetInstanceIDMultipleVPC(t *testing.T) { vpc2 := "vpc-6789" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") - switch r.RequestURI { - case "/network/interfaces/macs": - io.WriteString(w, mac+"/\n") - io.WriteString(w, mac2+"/\n") - case "/network/interfaces/macs/00:00:00:00:00/vpc-id": - io.WriteString(w, vpc) - case "/network/interfaces/macs/00:00:00:00:01/vpc-id": - io.WriteString(w, vpc2) - default: - w.WriteHeader(http.StatusNotFound) + switch r.Method { + case http.MethodPut: // token request + io.WriteString(w, testIMDSToken) + case http.MethodGet: // metadata request + switch r.RequestURI { + case "/network/interfaces/macs": + io.WriteString(w, mac+"/\n") + io.WriteString(w, mac2+"/\n") + case "/network/interfaces/macs/00:00:00:00:00/vpc-id": + io.WriteString(w, vpc) + case "/network/interfaces/macs/00:00:00:00:01/vpc-id": + io.WriteString(w, vpc2) + default: + w.WriteHeader(http.StatusNotFound) + } } })) defer ts.Close() metadataURL = ts.URL + tokenURL = ts.URL config.Datadog().SetWithoutSource("ec2_metadata_timeout", 1000) defer resetPackageVars()