diff --git a/internal/app/machined/pkg/runtime/v1alpha1/platform/azure/azure.go b/internal/app/machined/pkg/runtime/v1alpha1/platform/azure/azure.go index d34f937c8f..acd3651695 100644 --- a/internal/app/machined/pkg/runtime/v1alpha1/platform/azure/azure.go +++ b/internal/app/machined/pkg/runtime/v1alpha1/platform/azure/azure.go @@ -293,16 +293,16 @@ func (a *Azure) configFromCD() ([]byte, error) { // //nolint:gocyclo func (a *Azure) NetworkConfiguration(ctx context.Context, _ state.State, ch chan<- *runtime.PlatformNetworkConfig) error { - log.Printf("fetching azure instance config from: %q", AzureMetadataEndpoint) - - metadata, err := a.getMetadata(ctx) + metadata, apiVersion, err := a.getMetadata(ctx) if err != nil { return err } - log.Printf("fetching network config from %q", AzureInterfacesEndpoint) + interfacesEndpoint := fmt.Sprintf(AzureInterfacesEndpoint, apiVersion) + + log.Printf("fetching network config from %q", interfacesEndpoint) - metadataNetworkConfig, err := download.Download(ctx, AzureInterfacesEndpoint, + metadataNetworkConfig, err := download.Download(ctx, interfacesEndpoint, download.WithHeaders(map[string]string{"Metadata": "true"})) if err != nil { return fmt.Errorf("failed to fetch network config from metadata service: %w", err) @@ -319,11 +319,13 @@ func (a *Azure) NetworkConfiguration(ctx context.Context, _ state.State, ch chan return fmt.Errorf("failed to parse network metadata: %w", err) } - log.Printf("fetching load balancer metadata from: %q", AzureLoadbalancerEndpoint) + loadbalancerEndpoint := fmt.Sprintf(AzureLoadbalancerEndpoint, apiVersion) + + log.Printf("fetching load balancer metadata from: %q", loadbalancerEndpoint) var loadBalancerAddresses LoadBalancerMetadata - lbConfig, err := download.Download(ctx, AzureLoadbalancerEndpoint, + lbConfig, err := download.Download(ctx, loadbalancerEndpoint, download.WithHeaders(map[string]string{"Metadata": "true"}), download.WithErrorOnNotFound(errors.ErrNoConfigSource), download.WithErrorOnEmptyResponse(errors.ErrNoConfigSource)) diff --git a/internal/app/machined/pkg/runtime/v1alpha1/platform/azure/metadata.go b/internal/app/machined/pkg/runtime/v1alpha1/platform/azure/metadata.go index 2c944f12cb..7fb848b8c5 100644 --- a/internal/app/machined/pkg/runtime/v1alpha1/platform/azure/metadata.go +++ b/internal/app/machined/pkg/runtime/v1alpha1/platform/azure/metadata.go @@ -9,8 +9,8 @@ import ( "encoding/json" stderrors "errors" "fmt" + "log" - "github.com/siderolabs/talos/internal/app/machined/pkg/runtime/v1alpha1/platform/errors" "github.com/siderolabs/talos/pkg/download" ) @@ -19,15 +19,21 @@ const ( // ref: https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service // ref: https://github.com/Azure/azure-rest-api-specs/blob/main/specification/imds/data-plane/Microsoft.InstanceMetadataService/stable/2023-07-01/examples/GetInstanceMetadata.json + // AzureVersion is the version of the Azure metadata service. + AzureVersion = "2021-12-13" + + // AzureVersionFallback is the fallback version of the Azure metadata service (e.g. Azure Stack Hub). + AzureVersionFallback = "2019-06-01" + // AzureInternalEndpoint is the Azure Internal Channel IP // https://blogs.msdn.microsoft.com/mast/2015/05/18/what-is-the-ip-address-168-63-129-16/ AzureInternalEndpoint = "http://168.63.129.16" // AzureMetadataEndpoint is the local endpoint for the metadata. - AzureMetadataEndpoint = "http://169.254.169.254/metadata/instance/compute?api-version=2021-12-13&format=json" + AzureMetadataEndpoint = "http://169.254.169.254/metadata/instance/compute?api-version=%s&format=json" // AzureInterfacesEndpoint is the local endpoint to get external IPs. - AzureInterfacesEndpoint = "http://169.254.169.254/metadata/instance/network/interface?api-version=2021-12-13&format=json" + AzureInterfacesEndpoint = "http://169.254.169.254/metadata/instance/network/interface?api-version=%s&format=json" // AzureLoadbalancerEndpoint is the local endpoint for load balancer config. - AzureLoadbalancerEndpoint = "http://169.254.169.254/metadata/loadbalancer?api-version=2021-05-01&format=json" + AzureLoadbalancerEndpoint = "http://169.254.169.254/metadata/loadbalancer?api-version=%s&format=json" mnt = "/mnt" ) @@ -54,18 +60,38 @@ type ComputeMetadata struct { EvictionPolicy string `json:"evictionPolicy,omitempty"` } -func (a *Azure) getMetadata(ctx context.Context) (*ComputeMetadata, error) { - metadataDl, err := download.Download(ctx, AzureMetadataEndpoint, - download.WithHeaders(map[string]string{"Metadata": "true"})) - if err != nil && !stderrors.Is(err, errors.ErrNoHostname) { - return nil, fmt.Errorf("error fetching metadata: %w", err) +func (a *Azure) getMetadata(ctx context.Context) (*ComputeMetadata, string, error) { + apiVersion := AzureVersion + errBadRequest := stderrors.New("bad request") + + metadataEndpoint := fmt.Sprintf(AzureMetadataEndpoint, apiVersion) + + log.Printf("fetching azure instance config from: %q", metadataEndpoint) + + metadataDl, err := download.Download(ctx, metadataEndpoint, + download.WithHeaders(map[string]string{"Metadata": "true"}), + download.WithErrorOnBadRequest(errBadRequest), + ) + if err != nil && stderrors.Is(err, errBadRequest) { + apiVersion = AzureVersionFallback + metadataEndpoint = fmt.Sprintf(AzureMetadataEndpoint, apiVersion) + + log.Printf("fetching azure instance config from: %q", metadataEndpoint) + + metadataDl, err = download.Download(ctx, metadataEndpoint, + download.WithHeaders(map[string]string{"Metadata": "true"}), + ) + } + + if err != nil { + return nil, "", fmt.Errorf("error fetching metadata: %w", err) } var metadata ComputeMetadata if err = json.Unmarshal(metadataDl, &metadata); err != nil { - return nil, fmt.Errorf("failed to parse compute metadata: %w", err) + return nil, "", fmt.Errorf("failed to parse compute metadata: %w", err) } - return &metadata, nil + return &metadata, apiVersion, nil } diff --git a/pkg/download/download.go b/pkg/download/download.go index d2e5eb7e9d..79125e88ba 100644 --- a/pkg/download/download.go +++ b/pkg/download/download.go @@ -35,6 +35,7 @@ type downloadOptions struct { EndpointFunc func(context.Context) (string, error) ErrorOnNotFound error + ErrorOnBadRequest error ErrorOnEmptyResponse error Timeout time.Duration @@ -108,6 +109,13 @@ func WithErrorOnEmptyResponse(e error) Option { } } +// WithErrorOnBadRequest provides specific error to return when response has HTTP 400 error. +func WithErrorOnBadRequest(e error) Option { + return func(d *downloadOptions) { + d.ErrorOnBadRequest = e + } +} + // WithEndpointFunc provides a function that sets the endpoint of the download options. func WithEndpointFunc(endpointFunc func(context.Context) (string, error)) Option { return func(d *downloadOptions) { @@ -212,6 +220,7 @@ func Download(ctx context.Context, endpoint string, opts ...Option) (b []byte, e return b, nil } +//nolint:gocyclo func download(req *http.Request, options *downloadOptions) (data []byte, err error) { transport := httpdefaults.PatchTransport(cleanhttp.DefaultTransport()) transport.RegisterProtocol("tftp", NewTFTPTransport()) @@ -249,6 +258,10 @@ func download(req *http.Request, options *downloadOptions) (data []byte, err err return data, options.ErrorOnNotFound } + if resp.StatusCode == http.StatusBadRequest && options.ErrorOnBadRequest != nil { + return data, options.ErrorOnBadRequest + } + if resp.StatusCode != http.StatusOK { // try to read first 32 bytes of the response body // to provide more context in case of error diff --git a/pkg/download/download_test.go b/pkg/download/download_test.go index bf53064d48..3f2d48a445 100644 --- a/pkg/download/download_test.go +++ b/pkg/download/download_test.go @@ -53,6 +53,9 @@ func TestDownload(t *testing.T) { case "/base64": w.WriteHeader(http.StatusOK) w.Write([]byte("ZGF0YQ==")) //nolint:errcheck + case "/400": + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintln(w, "bad request") case "/404": w.WriteHeader(http.StatusNotFound) fmt.Fprintln(w, "not found") @@ -107,12 +110,24 @@ func TestDownload(t *testing.T) { opts: []download.Option{download.WithErrorOnNotFound(errors.New("gone forever"))}, expectedError: "gone forever", }, + { + name: "bad request error", + path: "/400", + opts: []download.Option{download.WithErrorOnBadRequest(errors.New("bad req"))}, + expectedError: "bad req", + }, { name: "failure 404", path: "/404", opts: []download.Option{download.WithTimeout(2 * time.Second)}, expectedError: "failed to download config, status code 404, body \"not found\\n\"", }, + { + name: "failure 400", + path: "/400", + opts: []download.Option{download.WithTimeout(2 * time.Second)}, + expectedError: "failed to download config, status code 400, body \"bad request\\n\"", + }, { name: "retry endpoint change", opts: []download.Option{