diff --git a/.gitignore b/.gitignore index dd353cf6..89d2c2f8 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,9 @@ coverage.txt *.test *.out + +# Build artifacts +/acr-cli +/acr/acr +/cmd/acr/acr *.html \ No newline at end of file diff --git a/README.md b/README.md index 32e9a1ed..45215d5e 100644 --- a/README.md +++ b/README.md @@ -257,6 +257,10 @@ acr purge \ --include-locked ``` +#### ABAC batch size (environment variable) +For registries with ABAC enabled, repositories are processed in batches. The batch size controls how many repositories share a single token scope. Token refresh happens dynamically when API calls detect token expiration, using the current batch's repository scope. The batch size can be configured via the `ABAC_BATCH_SIZE` environment variable (default=10) + + ### Integration with ACR Tasks To run a locally built version of the ACR-CLI using ACR Tasks follow these steps: diff --git a/cmd/acr/purge.go b/cmd/acr/purge.go index 837b1518..7fd8bf2a 100644 --- a/cmd/acr/purge.go +++ b/cmd/acr/purge.go @@ -7,8 +7,10 @@ import ( "context" "fmt" "net/http" + "os" "runtime" "sort" + "strconv" "strings" "time" @@ -89,6 +91,7 @@ type purgeParameters struct { includeLocked bool concurrency int repoPageSize int32 + verbose bool } // newPurgeCmd defines the purge command. @@ -178,7 +181,7 @@ func newPurgeCmd(rootParams *rootParameters) *cobra.Command { // Combine flags for clarity - these are mutually exclusive supportUntaggedCleanup := purgeParams.untagged || purgeParams.untaggedOnly - deletedTagsCount, deletedManifestsCount, err := purge(ctx, acrClient, loginURL, repoParallelism, agoDuration, purgeParams.keep, purgeParams.filterTimeout, supportUntaggedCleanup, purgeParams.untaggedOnly, tagFilters, purgeParams.dryRun, purgeParams.includeLocked) + deletedTagsCount, deletedManifestsCount, err := purge(ctx, acrClient, loginURL, repoParallelism, agoDuration, purgeParams.keep, purgeParams.filterTimeout, supportUntaggedCleanup, purgeParams.untaggedOnly, tagFilters, purgeParams.dryRun, purgeParams.includeLocked, purgeParams.verbose) if err != nil { fmt.Printf("Failed to complete purge: %v \n", err) @@ -208,6 +211,7 @@ func newPurgeCmd(rootParams *rootParameters) *cobra.Command { cmd.Flags().Int64Var(&purgeParams.filterTimeout, "filter-timeout-seconds", defaultRegexpMatchTimeoutSeconds, "This limits the evaluation of the regex filter, and will return a timeout error if this duration is exceeded during a single evaluation. If written incorrectly a regexp filter with backtracking can result in an infinite loop.") cmd.Flags().IntVar(&purgeParams.concurrency, "concurrency", defaultPoolSize, concurrencyDescription) cmd.Flags().Int32Var(&purgeParams.repoPageSize, "repository-page-size", defaultRepoPageSize, repoPageSizeDescription) + cmd.Flags().BoolVar(&purgeParams.verbose, "verbose", false, "Enable verbose output including detailed repository names during ABAC token operations") cmd.Flags().BoolP("help", "h", false, "Print usage") // Make filter and ago conditionally required based on untagged-only flag cmd.MarkFlagsOneRequired("filter", "untagged-only") @@ -226,36 +230,75 @@ func purge(ctx context.Context, untaggedOnly bool, tagFilters map[string]string, dryRun bool, - includeLocked bool) (deletedTagsCount int, deletedManifestsCount int, err error) { - - // In order to print a summary of the deleted tags/manifests the counters get updated everytime a repo is purged. - for repoName, tagRegex := range tagFilters { - var singleDeletedTagsCount int - var manifestToTagsCountMap map[string]int - - // Handle tag deletion based on mode - if untaggedOnly { - // Initialize empty map for untagged-only mode (no tag deletion) - manifestToTagsCountMap = make(map[string]int) - } else { - // Standard mode: delete matching tags first - singleDeletedTagsCount, manifestToTagsCountMap, err = purgeTags(ctx, acrClient, repoParallelism, loginURL, repoName, agoDuration, tagRegex, keep, filterTimeout, dryRun, includeLocked) - if err != nil { - return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to purge tags: %w", err) + includeLocked bool, + verbose bool) (deletedTagsCount int, deletedManifestsCount int, err error) { + + // Load ABAC batch size from environment variable + abacBatchSize := 10 // default + if envVal, exists := os.LookupEnv("ABAC_BATCH_SIZE"); exists { + if parsed, err := strconv.Atoi(envVal); err == nil && parsed > 0 { + abacBatchSize = parsed + } + } + + // Collect all repository names into a slice for batching + repos := make([]string, 0, len(tagFilters)) + for repoName := range tagFilters { + repos = append(repos, repoName) + } + + // Process repositories in batches of abacBatchSize. + // For ABAC-enabled registries, we set the current repositories for the batch so that + // token refresh happens dynamically when needed (on API calls that detect token expiration). + // For non-ABAC registries, the batching loop is harmless (no special token handling needed). + for i := 0; i < len(repos); i += abacBatchSize { + end := i + abacBatchSize + if end > len(repos) { + end = len(repos) + } + batch := repos[i:end] + + // For ABAC registries, set the current repositories for this batch. + // Token refresh will happen dynamically when API calls detect token expiration. + if acrClient.IsAbac() { + acrClient.SetCurrentRepositories(batch) + if verbose { + fmt.Printf("ABAC: Setting token scope for %d repositories: %v\n", len(batch), batch) + } else { + fmt.Printf("ABAC: Setting token scope for %d repositories\n", len(batch)) } } - singleDeletedManifestsCount := 0 - // If the untagged flag is set or untagged-only mode is enabled, delete manifests - if removeUntaggedManifests { - singleDeletedManifestsCount, err = purgeDanglingManifests(ctx, acrClient, repoParallelism, loginURL, repoName, agoDuration, keep, manifestToTagsCountMap, dryRun, includeLocked) - if err != nil { - return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to purge manifests: %w", err) + // Process all repositories in this batch + for _, repoName := range batch { + tagRegex := tagFilters[repoName] + var singleDeletedTagsCount int + var manifestToTagsCountMap map[string]int + + // Handle tag deletion based on mode + if untaggedOnly { + // Initialize empty map for untagged-only mode (no tag deletion) + manifestToTagsCountMap = make(map[string]int) + } else { + // Standard mode: delete matching tags first + singleDeletedTagsCount, manifestToTagsCountMap, err = purgeTags(ctx, acrClient, repoParallelism, loginURL, repoName, agoDuration, tagRegex, keep, filterTimeout, dryRun, includeLocked) + if err != nil { + return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to purge tags: %w", err) + } + } + + singleDeletedManifestsCount := 0 + // If the untagged flag is set or untagged-only mode is enabled, delete manifests + if removeUntaggedManifests { + singleDeletedManifestsCount, err = purgeDanglingManifests(ctx, acrClient, repoParallelism, loginURL, repoName, agoDuration, keep, manifestToTagsCountMap, dryRun, includeLocked) + if err != nil { + return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to purge manifests: %w", err) + } } + // After every repository is purged the counters are updated. + deletedTagsCount += singleDeletedTagsCount + deletedManifestsCount += singleDeletedManifestsCount } - // After every repository is purged the counters are updated. - deletedTagsCount += singleDeletedTagsCount - deletedManifestsCount += singleDeletedManifestsCount } return deletedTagsCount, deletedManifestsCount, nil diff --git a/cmd/acr/purge_test.go b/cmd/acr/purge_test.go index 912d3d7b..20fb3b05 100644 --- a/cmd/acr/purge_test.go +++ b/cmd/acr/purge_test.go @@ -552,9 +552,13 @@ func TestDryRun(t *testing.T) { t.Run("RepositoryNotFoundTest", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + // Mock IsAbac to return false (non-ABAC registry) to use standard wildcard token flow + mockClient.On("IsAbac").Return(false) + // Need a .Maybe() since it's only called for ABAC registries (this test mocks IsAbac to return false) + mockClient.On("IsTokenExpired").Return(false).Maybe() mockClient.On("GetAcrManifests", mock.Anything, testRepo, "", "").Return(notFoundManifestResponse, errors.New("testRepo not found")).Once() mockClient.On("GetAcrTags", mock.Anything, testRepo, "timedesc", "").Return(notFoundTagResponse, errors.New("testRepo not found")).Once() - deletedTags, deletedManifests, err := purge(testCtx, mockClient, testLoginURL, 60, -24*time.Hour, 0, 1, true, false, map[string]string{testRepo: "[\\s\\S]*"}, true, false) + deletedTags, deletedManifests, err := purge(testCtx, mockClient, testLoginURL, 60, -24*time.Hour, 0, 1, true, false, map[string]string{testRepo: "[\\s\\S]*"}, true, false, false) assert.Equal(0, deletedTags, "Number of deleted elements should be 0") assert.Equal(0, deletedManifests, "Number of deleted elements should be 0") assert.Equal(nil, err, "Error should be nil") diff --git a/cmd/acr/purge_untagged_only_test.go b/cmd/acr/purge_untagged_only_test.go index 4d59f23c..8c41531f 100644 --- a/cmd/acr/purge_untagged_only_test.go +++ b/cmd/acr/purge_untagged_only_test.go @@ -3,8 +3,11 @@ package main import ( + "bytes" "context" + "io" "net/http" + "os" "testing" "github.com/Azure/acr-cli/acr" @@ -25,6 +28,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyPurgeManifestsOnly", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() // Setup mock response for manifests without tags manifestDigest := "sha256:abc123" @@ -85,6 +90,7 @@ func TestPurgeUntaggedOnly(t *testing.T) { map[string]string{testRepo: ".*"}, false, // dryRun false, // includeLocked + false, // verbose ) assert.Equal(0, deletedTagsCount, "No tags should be deleted in untagged-only mode") @@ -97,6 +103,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyNoFilterAllRepos", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() // We won't test GetRepositories here since the purge function is called // with already-created tagFilters. Instead test that all repos are processed. @@ -137,6 +145,7 @@ func TestPurgeUntaggedOnly(t *testing.T) { tagFilters, false, // dryRun false, // includeLocked + false, // verbose ) assert.Equal(0, deletedTagsCount, "No tags should be deleted") @@ -149,6 +158,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyWithFilter", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() manifestDigest := "sha256:def456" mediaType := "application/vnd.docker.distribution.manifest.v2+json" @@ -206,6 +217,7 @@ func TestPurgeUntaggedOnly(t *testing.T) { map[string]string{"specific-repo": ".*"}, false, // dryRun false, // includeLocked + false, // verbose ) assert.Equal(0, deletedTagsCount, "No tags should be deleted in untagged-only mode") @@ -218,6 +230,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyDryRun", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() manifestDigest := "sha256:ghi789" mediaType := "application/vnd.docker.distribution.manifest.v2+json" @@ -270,6 +284,7 @@ func TestPurgeUntaggedOnly(t *testing.T) { map[string]string{testRepo: ".*"}, true, // dryRun false, // includeLocked + false, // verbose ) assert.Equal(0, deletedTagsCount, "No tags should be deleted in dry-run") @@ -282,6 +297,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyWithLockedManifests", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() // Create locked and unlocked untagged manifests lockedDigest := "sha256:locked123" @@ -351,6 +368,7 @@ func TestPurgeUntaggedOnly(t *testing.T) { map[string]string{testRepo: ".*"}, false, // dryRun false, // includeLocked = false + false, // verbose ) assert.Equal(0, deletedTagsCount, "No tags should be deleted") @@ -363,6 +381,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyWithIncludeLocked", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() // Create locked untagged manifest lockedDigest := "sha256:locked789" @@ -429,6 +449,7 @@ func TestPurgeUntaggedOnly(t *testing.T) { map[string]string{testRepo: ".*"}, false, // dryRun true, // includeLocked = true + false, // verbose ) assert.Equal(0, deletedTagsCount, "No tags should be deleted") @@ -750,3 +771,194 @@ func TestPurgeDanglingManifestsWithAgoAndKeep(t *testing.T) { mockClient.AssertExpectations(t) }) } + +// TestPurgeAbacVerboseMode tests the verbose output for ABAC registries +func TestPurgeAbacVerboseMode(t *testing.T) { + testCtx := context.Background() + testLoginURL := "registry.azurecr.io" + defaultPoolSize := 1 + + // Test: ABAC verbose mode should output repository names + t.Run("AbacVerboseModeOutputsRepoNames", func(t *testing.T) { + assert := assert.New(t) + mockClient := &mocks.AcrCLIClientInterface{} + + // Mock ABAC registry + mockClient.On("IsAbac").Return(true) + mockClient.On("SetCurrentRepositories", mock.Anything).Return() + + // Empty manifests result for each repo + emptyManifestsResult := &acr.Manifests{ + Response: autorest.Response{ + Response: &http.Response{ + StatusCode: 200, + }, + }, + ManifestsAttributes: &[]acr.ManifestAttributesBase{}, + } + + repos := []string{"repo1", "repo2", "repo3"} + for _, repo := range repos { + mockClient.On("GetAcrManifests", mock.Anything, repo, "", "").Return(emptyManifestsResult, nil).Once() + } + + tagFilters := make(map[string]string) + for _, repo := range repos { + tagFilters[repo] = ".*" + } + + // Capture stdout to verify verbose output + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + // Call purge with verbose=true and ABAC enabled + deletedTagsCount, deletedManifestsCount, err := purge( + testCtx, + mockClient, + testLoginURL, + defaultPoolSize, + 0, // ago + 0, // keep + 60, // filterTimeout + true, // removeUntaggedManifests + true, // untaggedOnly + tagFilters, + false, // dryRun + false, // includeLocked + true, // verbose = true + ) + + // Restore stdout and read captured output + w.Close() + os.Stdout = oldStdout + var buf bytes.Buffer + io.Copy(&buf, r) + output := buf.String() + + assert.Equal(0, deletedTagsCount, "No tags should be deleted") + assert.Equal(0, deletedManifestsCount, "No manifests deleted when none are untagged") + assert.Nil(err, "Error should be nil") + // Verify verbose output contains repository names + assert.Contains(output, "ABAC: Setting token scope for 3 repositories:", "Should output repo count") + assert.Contains(output, "repo1", "Should output repo1 in verbose mode") + assert.Contains(output, "repo2", "Should output repo2 in verbose mode") + assert.Contains(output, "repo3", "Should output repo3 in verbose mode") + mockClient.AssertCalled(t, "SetCurrentRepositories", mock.Anything) + mockClient.AssertExpectations(t) + }) + + // Test: ABAC non-verbose mode should only output count, not names + t.Run("AbacNonVerboseModeOutputsCountOnly", func(t *testing.T) { + assert := assert.New(t) + mockClient := &mocks.AcrCLIClientInterface{} + + // Mock ABAC registry + mockClient.On("IsAbac").Return(true) + mockClient.On("SetCurrentRepositories", mock.Anything).Return() + + // Empty manifests result for each repo + emptyManifestsResult := &acr.Manifests{ + Response: autorest.Response{ + Response: &http.Response{ + StatusCode: 200, + }, + }, + ManifestsAttributes: &[]acr.ManifestAttributesBase{}, + } + + repos := []string{"repo1", "repo2", "repo3"} + for _, repo := range repos { + mockClient.On("GetAcrManifests", mock.Anything, repo, "", "").Return(emptyManifestsResult, nil).Once() + } + + tagFilters := make(map[string]string) + for _, repo := range repos { + tagFilters[repo] = ".*" + } + + // Capture stdout to verify non-verbose output + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + // Call purge with verbose=false and ABAC enabled + deletedTagsCount, deletedManifestsCount, err := purge( + testCtx, + mockClient, + testLoginURL, + defaultPoolSize, + 0, // ago + 0, // keep + 60, // filterTimeout + true, // removeUntaggedManifests + true, // untaggedOnly + tagFilters, + false, // dryRun + false, // includeLocked + false, // verbose = false + ) + + // Restore stdout and read captured output + w.Close() + os.Stdout = oldStdout + var buf bytes.Buffer + io.Copy(&buf, r) + output := buf.String() + + assert.Equal(0, deletedTagsCount, "No tags should be deleted") + assert.Equal(0, deletedManifestsCount, "No manifests deleted when none are untagged") + assert.Nil(err, "Error should be nil") + // Verify non-verbose output contains count but NOT the repository list + assert.Contains(output, "ABAC: Setting token scope for 3 repositories", "Should output repo count") + // The non-verbose output should NOT contain the bracketed list of repos + assert.NotContains(output, "[repo1", "Should NOT output repo list in non-verbose mode") + mockClient.AssertCalled(t, "SetCurrentRepositories", mock.Anything) + mockClient.AssertExpectations(t) + }) + + // Test: Non-ABAC registry should not call SetCurrentRepositories + t.Run("NonAbacDoesNotCallSetCurrentRepositories", func(t *testing.T) { + assert := assert.New(t) + mockClient := &mocks.AcrCLIClientInterface{} + + // Mock non-ABAC registry + mockClient.On("IsAbac").Return(false) + + // Empty manifests result + emptyManifestsResult := &acr.Manifests{ + Response: autorest.Response{ + Response: &http.Response{ + StatusCode: 200, + }, + }, + ManifestsAttributes: &[]acr.ManifestAttributesBase{}, + } + + mockClient.On("GetAcrManifests", mock.Anything, "test-repo", "", "").Return(emptyManifestsResult, nil).Once() + + // Call purge with verbose=true but non-ABAC registry + deletedTagsCount, deletedManifestsCount, err := purge( + testCtx, + mockClient, + testLoginURL, + defaultPoolSize, + 0, // ago + 0, // keep + 60, // filterTimeout + true, // removeUntaggedManifests + true, // untaggedOnly + map[string]string{"test-repo": ".*"}, + false, // dryRun + false, // includeLocked + true, // verbose = true + ) + + assert.Equal(0, deletedTagsCount, "No tags should be deleted") + assert.Equal(0, deletedManifestsCount, "No manifests deleted") + assert.Nil(err, "Error should be nil") + // Verify SetCurrentRepositories was NOT called for non-ABAC + mockClient.AssertNotCalled(t, "SetCurrentRepositories", mock.Anything) + mockClient.AssertExpectations(t) + }) +} diff --git a/cmd/mocks/AcrCLIClientInterface.go b/cmd/mocks/AcrCLIClientInterface.go index d553d8d5..bac2214d 100644 --- a/cmd/mocks/AcrCLIClientInterface.go +++ b/cmd/mocks/AcrCLIClientInterface.go @@ -2,11 +2,15 @@ package mocks -import acr "github.com/Azure/acr-cli/acr" +import ( + acr "github.com/Azure/acr-cli/acr" -import autorest "github.com/Azure/go-autorest/autorest" -import context "context" -import mock "github.com/stretchr/testify/mock" + autorest "github.com/Azure/go-autorest/autorest" + + context "context" + + mock "github.com/stretchr/testify/mock" +) // AcrCLIClientInterface is an autogenerated mock type for the AcrCLIClientInterface type type AcrCLIClientInterface struct { @@ -196,3 +200,50 @@ func (_m *AcrCLIClientInterface) UpdateAcrManifestAttributes(ctx context.Context return r0, r1 } + +// IsAbac provides a mock function that returns whether the registry is ABAC-enabled +func (_m *AcrCLIClientInterface) IsAbac() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// IsTokenExpired provides a mock function for checking if token is expired +func (_m *AcrCLIClientInterface) IsTokenExpired() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// RefreshTokenForAbac provides a mock function for refreshing tokens with specific repository scopes +func (_m *AcrCLIClientInterface) RefreshTokenForAbac(ctx context.Context, repositories []string) error { + ret := _m.Called(ctx, repositories) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []string) error); ok { + r0 = rf(ctx, repositories) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetCurrentRepositories provides a mock function for setting the current repositories for ABAC token scope +func (_m *AcrCLIClientInterface) SetCurrentRepositories(repositories []string) { + _m.Called(repositories) +} diff --git a/internal/api/acrsdk.go b/internal/api/acrsdk.go index 0b3a7c11..fe9438c0 100644 --- a/internal/api/acrsdk.go +++ b/internal/api/acrsdk.go @@ -7,6 +7,7 @@ package api import ( "bytes" "context" + "fmt" "io/ioutil" "strings" "time" @@ -52,6 +53,13 @@ type AcrCLIClient struct { // accessTokenExp refers to the expiration time for the access token, it is in a unix time format represented by a // 64 bit integer. accessTokenExp int64 + // isAbac indicates whether this registry uses Attribute-Based Access Control (ABAC). + // ABAC registries require repository-level permissions instead of registry-wide wildcards. + // This is detected by checking if the refresh token contains the "aad_identity" claim. + isAbac bool + // currentRepositories holds the repository names for which the current ABAC token has permissions. + // This is used for dynamic token refresh when the token expires during operations. + currentRepositories []string } // LoginURL returns the FQDN for a registry. @@ -91,10 +99,29 @@ func newAcrCLIClientWithBasicAuth(loginURL string, username string, password str } // newAcrCLIClientWithBearerAuth creates a client that uses bearer token authentication. +// It detects if the registry is ABAC-enabled by checking for the "aad_identity" claim in the refresh token. +// For ABAC registries, it only requests catalog access initially; repository access is requested per-batch. +// For non-ABAC registries, it requests the traditional wildcard scope for all repositories. func newAcrCLIClientWithBearerAuth(loginURL string, refreshToken string) (AcrCLIClient, error) { + // Detect if this is an ABAC-enabled registry by checking for aad_identity claim + isAbac := hasAadIdentityClaim(refreshToken) + newAcrCLIClient := newAcrCLIClient(loginURL) + newAcrCLIClient.isAbac = isAbac + ctx := context.Background() - accessTokenResponse, err := newAcrCLIClient.AutorestClient.GetAcrAccessToken(ctx, loginURL, "registry:catalog:* repository:*:*", refreshToken) + var scope string + if isAbac { + // For ABAC registries, only request catalog access initially. + // Repository-level access will be requested on-demand per repository or batch. + // This is because ABAC registries cannot grant wildcard repository access. + scope = "registry:catalog:*" + } else { + // For non-ABAC registries, request full wildcard access to all repositories. + scope = "registry:catalog:* repository:*:*" + } + + accessTokenResponse, err := newAcrCLIClient.AutorestClient.GetAcrAccessToken(ctx, loginURL, scope, refreshToken) if err != nil { return newAcrCLIClient, err } @@ -154,25 +181,129 @@ func GetAcrCLIClientWithAuth(loginURL string, username string, password string, } // refreshAcrCLIClientToken obtains a new token and gets its expiration time. -func refreshAcrCLIClientToken(ctx context.Context, c *AcrCLIClient) error { - accessTokenResponse, err := c.AutorestClient.GetAcrAccessToken(ctx, c.loginURL, "repository:*:*", c.token.RefreshToken) +// For non-ABAC registries, this uses the wildcard scope. +// For ABAC registries, this uses the currentRepositories to refresh with the appropriate scope. +func refreshAcrCLIClientToken(ctx context.Context, c *AcrCLIClient, repoName string) error { + var scope string + if c.isAbac { + // For ABAC registries, build scope from currentRepositories and ensure repoName is included + repoSet := make(map[string]bool) + for _, repo := range c.currentRepositories { + repoSet[repo] = true + } + // Ensure the current repoName is in the set + if repoName != "" { + repoSet[repoName] = true + } + var scopeParts []string + for repo := range repoSet { + scopeParts = append(scopeParts, fmt.Sprintf("repository:%s:pull,delete,metadata_read,metadata_write", repo)) + } + if len(scopeParts) == 0 { + // Fallback: if no repositories specified, return error for ABAC + return errors.New("ABAC registry requires repository scope but none specified") + } + scope = strings.Join(scopeParts, " ") + } else { + // For non-ABAC registries, use the wildcard scope + scope = "repository:*:*" + } + + accessTokenResponse, err := c.AutorestClient.GetAcrAccessToken(ctx, c.loginURL, scope, c.token.RefreshToken) + if err != nil { + return err + } + token := &adal.Token{ + AccessToken: *accessTokenResponse.AccessToken, + RefreshToken: c.token.RefreshToken, + } + c.token = token + c.AutorestClient.Authorizer = autorest.NewBearerAuthorizer(token) + exp, err := getExpiration(token.AccessToken) if err != nil { return err } + c.accessTokenExp = exp + return nil +} + +// hasAadIdentityClaim checks if a JWT token contains the "aad_identity" claim. +// The presence of this claim indicates that the registry is ABAC-enabled. +// ABAC (Attribute-Based Access Control) registries grant permissions at the repository level, +// not at the registry level, so wildcard scopes like "repository:*:*" will not work. +func hasAadIdentityClaim(tokenString string) bool { + parser := jwt.Parser{SkipClaimsValidation: true} + mapC := jwt.MapClaims{} + // We only need to check for the claim, not verify the signature + _, _, err := parser.ParseUnverified(tokenString, mapC) + if err != nil { + return false + } + _, ok := mapC["aad_identity"] + return ok +} + +// SetCurrentRepositories sets the repositories for which ABAC token refresh should request permissions. +// This should be called before performing operations on repositories in ABAC-enabled registries. +// When the token expires, the refresh will automatically request permissions for these repositories. +func (c *AcrCLIClient) SetCurrentRepositories(repositories []string) { + c.currentRepositories = repositories +} + +// RefreshTokenForAbac obtains a new access token scoped to specific repositories. +// This is used for ABAC-enabled registries where wildcard repository access is not allowed. +// The token will include permissions for all specified repositories. +// It also updates currentRepositories so subsequent automatic refreshes use the same scope. +// +// Parameters: +// - repositories: list of repository names to request access for +// +// The scope format is: "registry:catalog:* repository::pull repository::delete ..." +// This allows batching multiple repositories into a single token request for efficiency. +func (c *AcrCLIClient) RefreshTokenForAbac(ctx context.Context, repositories []string) error { + if c.token == nil { + return errors.New("no refresh token available for ABAC token refresh") + } + + // Update the current repositories so automatic refreshes use the same scope + c.currentRepositories = repositories + + // Build the scope string for all requested repositories. + // Each repository needs pull, delete, and metadata permissions for purge operations. + // Format: "repository:repo1:pull,delete,metadata_read,metadata_write repository:repo2:pull,delete,metadata_read,metadata_write ..." + var scopeParts []string + for _, repo := range repositories { + scopeParts = append(scopeParts, fmt.Sprintf("repository:%s:pull,delete,metadata_read,metadata_write", repo)) + } + scope := strings.Join(scopeParts, " ") + + accessTokenResponse, err := c.AutorestClient.GetAcrAccessToken(ctx, c.loginURL, scope, c.token.RefreshToken) + if err != nil { + return errors.Wrap(err, "failed to refresh token for ABAC repositories") + } + token := &adal.Token{ AccessToken: *accessTokenResponse.AccessToken, RefreshToken: c.token.RefreshToken, } c.token = token c.AutorestClient.Authorizer = autorest.NewBearerAuthorizer(token) + exp, err := getExpiration(token.AccessToken) if err != nil { return err } c.accessTokenExp = exp + return nil } +// IsAbac returns true if this client is connected to an ABAC-enabled registry. +// ABAC registries require repository-level token scopes instead of wildcard scopes. +func (c *AcrCLIClient) IsAbac() bool { + return c.isAbac +} + // getExpiration is used to obtain the expiration out of a jwt token. func getExpiration(token string) (int64, error) { parser := jwt.Parser{SkipClaimsValidation: true} @@ -198,10 +329,17 @@ func (c *AcrCLIClient) isExpired() bool { return (time.Now().Add(5 * time.Minute)).Unix() > c.accessTokenExp } +// IsTokenExpired returns true when the token is expired or close to expiring. +// This is the public version of isExpired for use by callers that need to check +// token expiration before making batched ABAC token refresh requests. +func (c *AcrCLIClient) IsTokenExpired() bool { + return c.isExpired() +} + // GetAcrTags list the tags of a repository with their attributes. func (c *AcrCLIClient) GetAcrTags(ctx context.Context, repoName string, orderBy string, last string) (*acrapi.RepositoryTagsType, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -216,7 +354,7 @@ func (c *AcrCLIClient) GetAcrTags(ctx context.Context, repoName string, orderBy // DeleteAcrTag deletes the tag by reference. func (c *AcrCLIClient) DeleteAcrTag(ctx context.Context, repoName string, reference string) (*autorest.Response, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -230,7 +368,7 @@ func (c *AcrCLIClient) DeleteAcrTag(ctx context.Context, repoName string, refere // GetAcrManifests list all the manifest in a repository with their attributes. func (c *AcrCLIClient) GetAcrManifests(ctx context.Context, repoName string, orderBy string, last string) (*acrapi.Manifests, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -244,7 +382,7 @@ func (c *AcrCLIClient) GetAcrManifests(ctx context.Context, repoName string, ord // DeleteManifest deletes a manifest using the digest as a reference. func (c *AcrCLIClient) DeleteManifest(ctx context.Context, repoName string, reference string) (*autorest.Response, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -259,7 +397,7 @@ func (c *AcrCLIClient) DeleteManifest(ctx context.Context, repoName string, refe // This is used when a manifest list is wanted, first the bytes are obtained and then unmarshalled into a new struct. func (c *AcrCLIClient) GetManifest(ctx context.Context, repoName string, reference string) ([]byte, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -299,7 +437,7 @@ func (c *AcrCLIClient) GetManifest(ctx context.Context, repoName string, referen // GetAcrManifestAttributes gets the attributes of a manifest. func (c *AcrCLIClient) GetAcrManifestAttributes(ctx context.Context, repoName string, reference string) (*acrapi.ManifestAttributes, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -313,7 +451,7 @@ func (c *AcrCLIClient) GetAcrManifestAttributes(ctx context.Context, repoName st // UpdateAcrTagAttributes updates tag attributes to enable/disable deletion and writing. func (c *AcrCLIClient) UpdateAcrTagAttributes(ctx context.Context, repoName string, reference string, value *acrapi.ChangeableAttributes) (*autorest.Response, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -327,7 +465,7 @@ func (c *AcrCLIClient) UpdateAcrTagAttributes(ctx context.Context, repoName stri // UpdateAcrManifestAttributes updates manifest attributes to enable/disable deletion and writing. func (c *AcrCLIClient) UpdateAcrManifestAttributes(ctx context.Context, repoName string, reference string, value *acrapi.ChangeableAttributes) (*autorest.Response, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -348,4 +486,13 @@ type AcrCLIClientInterface interface { GetAcrManifestAttributes(ctx context.Context, repoName string, reference string) (*acrapi.ManifestAttributes, error) UpdateAcrTagAttributes(ctx context.Context, repoName string, reference string, value *acrapi.ChangeableAttributes) (*autorest.Response, error) UpdateAcrManifestAttributes(ctx context.Context, repoName string, reference string, value *acrapi.ChangeableAttributes) (*autorest.Response, error) + + // IsAbac returns true if the registry uses Attribute-Based Access Control. + IsAbac() bool + // IsTokenExpired returns true if the access token is expired or close to expiring. + IsTokenExpired() bool + // RefreshTokenForAbac refreshes the access token with scopes for specific repositories. + RefreshTokenForAbac(ctx context.Context, repositories []string) error + // SetCurrentRepositories sets the repositories for ABAC token refresh scope. + SetCurrentRepositories(repositories []string) } diff --git a/internal/api/acrsdk_test.go b/internal/api/acrsdk_test.go index 60224873..30d9c266 100644 --- a/internal/api/acrsdk_test.go +++ b/internal/api/acrsdk_test.go @@ -4,6 +4,7 @@ package api import ( + "context" "encoding/base64" "fmt" "net/http" @@ -234,3 +235,209 @@ func TestGetAcrCLIClientWithAuth(t *testing.T) { }) } } + +// TestHasAadIdentityClaim tests the ABAC detection function +func TestHasAadIdentityClaim(t *testing.T) { + tests := []struct { + name string + token string + expected bool + }{ + { + name: "token with aad_identity claim - ABAC enabled", + // JWT with {"aad_identity": "user@example.com"} in payload + token: strings.Join([]string{ + base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)), + base64.RawURLEncoding.EncodeToString([]byte(`{"exp":1563910981,"aad_identity":"user@example.com"}`)), + "", + }, "."), + expected: true, + }, + { + name: "token without aad_identity claim - non-ABAC", + // JWT without aad_identity + token: strings.Join([]string{ + base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)), + base64.RawURLEncoding.EncodeToString([]byte(`{"exp":1563910981}`)), + "", + }, "."), + expected: false, + }, + { + name: "invalid token", + token: "not-a-valid-jwt", + expected: false, + }, + { + name: "empty token", + token: "", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := hasAadIdentityClaim(tt.token) + if result != tt.expected { + t.Errorf("hasAadIdentityClaim() = %v, expected %v", result, tt.expected) + } + }) + } +} + +// TestAcrCLIClientIsAbac tests the IsAbac method +func TestAcrCLIClientIsAbac(t *testing.T) { + tests := []struct { + name string + isAbac bool + expected bool + }{ + { + name: "ABAC enabled client", + isAbac: true, + expected: true, + }, + { + name: "non-ABAC client", + isAbac: false, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := AcrCLIClient{ + isAbac: tt.isAbac, + } + result := client.IsAbac() + if result != tt.expected { + t.Errorf("IsAbac() = %v, expected %v", result, tt.expected) + } + }) + } +} + +// TestRefreshAcrCLIClientTokenAbac tests the ABAC-aware token refresh path. +// This ensures that when SDK methods (GetAcrTags, DeleteAcrTag, etc.) detect token expiry, +// the refresh uses repository-scoped tokens for ABAC registries instead of wildcard scope. +func TestRefreshAcrCLIClientTokenAbac(t *testing.T) { + testAccessToken := strings.Join([]string{ + base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)), + base64.RawURLEncoding.EncodeToString([]byte(`{"exp":9999999999}`)), // Far future expiry + "", + }, ".") + testRefreshToken := "test/refresh/token" + + tests := []struct { + name string + isAbac bool + currentRepositories []string + repoName string + expectedScopePrefix string // What the scope should start with or contain + shouldContainRepo string // Specific repo that must be in scope + wantErr bool + }{ + { + name: "ABAC with currentRepositories and repoName - includes both", + isAbac: true, + currentRepositories: []string{"repo1", "repo2"}, + repoName: "repo3", + expectedScopePrefix: "repository:", + shouldContainRepo: "repo3", + wantErr: false, + }, + { + name: "ABAC with only repoName - uses repoName for scope", + isAbac: true, + currentRepositories: []string{}, + repoName: "my-repo", + expectedScopePrefix: "repository:my-repo:", + shouldContainRepo: "my-repo", + wantErr: false, + }, + { + name: "ABAC with repoName already in currentRepositories - no duplicate", + isAbac: true, + currentRepositories: []string{"repo1", "repo2"}, + repoName: "repo1", + expectedScopePrefix: "repository:", + shouldContainRepo: "repo1", + wantErr: false, + }, + { + name: "ABAC with no repos and no repoName - returns error", + isAbac: true, + currentRepositories: []string{}, + repoName: "", + wantErr: true, + }, + { + name: "Non-ABAC registry - uses wildcard scope", + isAbac: false, + currentRepositories: []string{}, + repoName: "any-repo", + expectedScopePrefix: "repository:*:*", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedScope string + + // Create a test server that captures the scope parameter + as := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusNotFound) + return + } + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + capturedScope = r.PostForm.Get("scope") + // Return a valid access token + fmt.Fprintf(w, `{"access_token":%q}`, testAccessToken) + })) + defer as.Close() + + // Create client with test configuration + client := newAcrCLIClient(as.URL) + client.isAbac = tt.isAbac + client.currentRepositories = tt.currentRepositories + client.token = &adal.Token{ + AccessToken: testAccessToken, + RefreshToken: testRefreshToken, + } + // Replace transport to trust test server + client.AutorestClient.Sender = as.Client() + + // Call refreshAcrCLIClientToken + err := refreshAcrCLIClientToken(context.Background(), &client, tt.repoName) + + // Check error expectation + if (err != nil) != tt.wantErr { + t.Errorf("refreshAcrCLIClientToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return + } + + // Verify the scope was correct + if tt.expectedScopePrefix != "" && !strings.Contains(capturedScope, tt.expectedScopePrefix) { + t.Errorf("Expected scope to contain %q, got %q", tt.expectedScopePrefix, capturedScope) + } + + if tt.shouldContainRepo != "" && !strings.Contains(capturedScope, tt.shouldContainRepo) { + t.Errorf("Expected scope to contain repo %q, got %q", tt.shouldContainRepo, capturedScope) + } + + // For ABAC, verify we're NOT using wildcard + if tt.isAbac && strings.Contains(capturedScope, "repository:*:*") { + t.Errorf("ABAC refresh should NOT use wildcard scope, got %q", capturedScope) + } + }) + } +}