diff --git a/br/cmd/br/restore.go b/br/cmd/br/restore.go index 916ed3b703933..f991163813af2 100644 --- a/br/cmd/br/restore.go +++ b/br/cmd/br/restore.go @@ -74,14 +74,14 @@ func runRestoreCommand(command *cobra.Command, cmdName string) error { if err := task.RunRestore(GetDefaultContext(), tidbGlue, cmdName, &cfg); err != nil { log.Error("failed to restore", zap.Error(err)) - printWorkaroundOnFullRestoreError(command, err) + printWorkaroundOnFullRestoreError(err) return errors.Trace(err) } return nil } // print workaround when we met not fresh or incompatible cluster error on full cluster restore -func printWorkaroundOnFullRestoreError(command *cobra.Command, err error) { +func printWorkaroundOnFullRestoreError(err error) { if !errors.ErrorEqual(err, berrors.ErrRestoreNotFreshCluster) && !errors.ErrorEqual(err, berrors.ErrRestoreIncompatibleSys) { return diff --git a/br/pkg/checkpoint/checkpoint.go b/br/pkg/checkpoint/checkpoint.go index 4b397a60e5eeb..80f597eb7987a 100644 --- a/br/pkg/checkpoint/checkpoint.go +++ b/br/pkg/checkpoint/checkpoint.go @@ -724,7 +724,7 @@ func walkCheckpointFile[K KeyType, V ValueType]( pastDureTime = checkpointData.DureTime } for _, meta := range checkpointData.RangeGroupMetas { - decryptContent, err := metautil.Decrypt(meta.RangeGroupsEncriptedData, cipher, meta.CipherIv) + decryptContent, err := utils.Decrypt(meta.RangeGroupsEncriptedData, cipher, meta.CipherIv) if err != nil { return errors.Trace(err) } diff --git a/br/pkg/conn/conn.go b/br/pkg/conn/conn.go index cdb81a011c8a5..84122c596280e 100644 --- a/br/pkg/conn/conn.go +++ b/br/pkg/conn/conn.go @@ -296,8 +296,8 @@ func (mgr *Mgr) Close() { mgr.PdController.Close() } -// GetTS gets current ts from pd. -func (mgr *Mgr) GetTS(ctx context.Context) (uint64, error) { +// GetCurrentTsFromPd gets current ts from pd. 
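+// It composes the physical and logical timestamps returned by PD into a single TSO.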
+func (mgr *Mgr) GetCurrentTsFromPd(ctx context.Context) (uint64, error) { p, l, err := mgr.GetPDClient().GetTS(ctx) if err != nil { return 0, errors.Trace(err) diff --git a/br/pkg/encryption/BUILD.bazel b/br/pkg/encryption/BUILD.bazel new file mode 100644 index 0000000000000..1f45ff810430e --- /dev/null +++ b/br/pkg/encryption/BUILD.bazel @@ -0,0 +1,15 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "encryption", + srcs = ["manager.go"], + importpath = "github.com/pingcap/tidb/br/pkg/encryption", + visibility = ["//visibility:public"], + deps = [ + "//br/pkg/encryption/master_key", + "//br/pkg/utils", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_kvproto//pkg/brpb", + "@com_github_pingcap_kvproto//pkg/encryptionpb", + ], +) diff --git a/br/pkg/encryption/manager.go b/br/pkg/encryption/manager.go new file mode 100644 index 0000000000000..cf20e3cddb381 --- /dev/null +++ b/br/pkg/encryption/manager.go @@ -0,0 +1,82 @@ +package encryption + +import ( + "context" + + "github.com/pingcap/errors" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" + encryption "github.com/pingcap/tidb/br/pkg/encryption/master_key" + "github.com/pingcap/tidb/br/pkg/utils" +) + +type Manager struct { + cipherInfo *backuppb.CipherInfo + masterKeyBackends *encryption.MultiMasterKeyBackend + encryptionMethod *encryptionpb.EncryptionMethod +} + +func NewManager(cipherInfo *backuppb.CipherInfo, masterKeyConfigs *backuppb.MasterKeyConfig) (*Manager, error) { + // should never happen since config has default + if cipherInfo == nil || masterKeyConfigs == nil { + return nil, errors.New("cipherInfo or masterKeyConfigs is nil") + } + + if cipherInfo.CipherType != encryptionpb.EncryptionMethod_PLAINTEXT { + return &Manager{ + cipherInfo: cipherInfo, + masterKeyBackends: nil, + encryptionMethod: nil, + }, nil + } + + if masterKeyConfigs.EncryptionType != encryptionpb.EncryptionMethod_PLAINTEXT { + masterKeyBackends, err := encryption.NewMultiMasterKeyBackend(masterKeyConfigs.GetMasterKeys()) + if err != nil { + return nil, errors.Trace(err) + } + return &Manager{ + cipherInfo: nil, + masterKeyBackends: masterKeyBackends, + encryptionMethod: &masterKeyConfigs.EncryptionType, + }, nil + } + return nil, nil +} + +func (m *Manager) Decrypt(ctx context.Context, content []byte, fileEncryptionInfo *encryptionpb.FileEncryptionInfo) ([]byte, error) { + switch mode := fileEncryptionInfo.Mode.(type) { + case *encryptionpb.FileEncryptionInfo_PlainTextDataKey: + if m.cipherInfo == nil { + return nil, errors.New("plaintext data key info is required but not set") + } + decryptedContent, err := utils.Decrypt(content, m.cipherInfo, fileEncryptionInfo.FileIv) + if err != nil { + return nil, errors.Annotate(err, "failed to decrypt content using plaintext data key") + } + return decryptedContent, nil + case *encryptionpb.FileEncryptionInfo_MasterKeyBased: + encryptedContents := fileEncryptionInfo.GetMasterKeyBased().DataKeyEncryptedContent + if encryptedContents == nil || len(encryptedContents) == 0 { + return nil, errors.New("should contain at least one encrypted data key") + } + // pick first one, the list is for future expansion of multiple encrypted data keys by different master key backend + encryptedContent := encryptedContents[0] + decryptedDataKey, err := m.masterKeyBackends.Decrypt(ctx, encryptedContent) + if err != nil { + return nil, errors.Annotate(err, "failed to decrypt data key using master key") + } + + cipherInfo := 
backuppb.CipherInfo{ + CipherType: fileEncryptionInfo.EncryptionMethod, + CipherKey: decryptedDataKey, + } + decryptedContent, err := utils.Decrypt(content, &cipherInfo, fileEncryptionInfo.FileIv) + if err != nil { + return nil, errors.Annotate(err, "failed to decrypt content using decrypted data key") + } + return decryptedContent, nil + default: + return nil, errors.Errorf("internal error: unsupported encryption mode type %T", mode) + } +} diff --git a/br/pkg/encryption/master_key/BUILD.bazel b/br/pkg/encryption/master_key/BUILD.bazel new file mode 100644 index 0000000000000..cd0ded05373c3 --- /dev/null +++ b/br/pkg/encryption/master_key/BUILD.bazel @@ -0,0 +1,41 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "master_key", + srcs = [ + "common.go", + "file_backend.go", + "kms_backend.go", + "master_key.go", + "mem_backend.go", + "multi_master_key_backend.go", + ], + importpath = "github.com/pingcap/tidb/br/pkg/encryption/master_key", + visibility = ["//visibility:public"], + deps = [ + "//br/pkg/kms:aws", + "//br/pkg/utils", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_kvproto//pkg/encryptionpb", + "@com_github_pingcap_log//:log", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "master_key_test", + srcs = [ + "file_backend_test.go", + "kms_backend_test.go", + "mem_backend_test.go", + "multi_master_key_backend_test.go", + ], + embed = [":master_key"], + deps = [ + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_kvproto//pkg/encryptionpb", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//mock", + "@com_github_stretchr_testify//require", + ], +) diff --git a/br/pkg/encryption/master_key/common.go b/br/pkg/encryption/master_key/common.go new file mode 100644 index 0000000000000..a2293876990ee --- /dev/null +++ b/br/pkg/encryption/master_key/common.go @@ -0,0 +1,29 @@ +package encryption + +import ( + "crypto/rand" + "encoding/binary" + "time" +) + +// must keep it same with the constants in TiKV implementation +const ( + MetadataKeyMethod string = "method" + MetadataKeyIv string = "iv" + MetadataKeyAesGcmTag string = "aes_gcm_tag" + MetadataKeyKmsVendor string = "kms_vendor" + MetadataKeyKmsCiphertextKey string = "kms_ciphertext_key" + MetadataMethodAes256Gcm string = "aes256-gcm" +) + +type IV [12]byte + +func NewIV() IV { + var iv IV + binary.BigEndian.PutUint64(iv[:8], uint64(time.Now().UnixNano())) + // Fill the remaining 4 bytes with random data + if _, err := rand.Read(iv[8:]); err != nil { + panic(err) // Handle this error appropriately in production code + } + return iv +} diff --git a/br/pkg/encryption/master_key/file_backend.go b/br/pkg/encryption/master_key/file_backend.go new file mode 100644 index 0000000000000..b57104311a853 --- /dev/null +++ b/br/pkg/encryption/master_key/file_backend.go @@ -0,0 +1,58 @@ +package encryption + +import ( + "context" + "encoding/hex" + "os" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" +) + +const AesGcmKeyLen = 32 // AES-256 key length + +type FileBackend struct { + memCache *MemAesGcmBackend +} + +func createFileBackend(keyPath string) (*FileBackend, error) { + // FileBackend uses AES-256-GCM + keyLen := AesGcmKeyLen + + content, err := os.ReadFile(keyPath) + if err != nil { + return nil, errors.Annotate(err, "failed to read master key file from disk") + } + + fileLen := len(content) + expectedLen := keyLen*2 + 1 // hex-encoded key + newline + + if fileLen != expectedLen { + return nil, 
errors.Errorf("mismatch master key file size, expected %d, actual %d", expectedLen, fileLen) + } + + if content[fileLen-1] != '\n' { + return nil, errors.Errorf("master key file should end with newline") + } + + key, err := hex.DecodeString(string(content[:fileLen-1])) + if err != nil { + return nil, errors.Annotate(err, "failed to decode master key from file") + } + + backend, err := NewMemAesGcmBackend(key) + if err != nil { + return nil, errors.Annotate(err, "failed to create MemAesGcmBackend") + } + + return &FileBackend{memCache: backend}, nil +} + +func (f *FileBackend) Encrypt(ctx context.Context, plaintext []byte) (*encryptionpb.EncryptedContent, error) { + iv := NewIV() + return f.memCache.EncryptContent(ctx, plaintext, iv) +} + +func (f *FileBackend) Decrypt(ctx context.Context, content *encryptionpb.EncryptedContent) ([]byte, error) { + return f.memCache.DecryptContent(ctx, content) +} diff --git a/br/pkg/encryption/master_key/file_backend_test.go b/br/pkg/encryption/master_key/file_backend_test.go new file mode 100644 index 0000000000000..09e466f07e977 --- /dev/null +++ b/br/pkg/encryption/master_key/file_backend_test.go @@ -0,0 +1,103 @@ +package encryption + +import ( + "context" + "encoding/hex" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TempKeyFile represents a temporary key file for testing +type TempKeyFile struct { + Path string + file *os.File +} + +// Cleanup closes and removes the temporary file +func (tkf *TempKeyFile) Cleanup() { + if tkf.file != nil { + tkf.file.Close() + } + os.Remove(tkf.Path) +} + +// createMasterKeyFile creates a temporary master key file for testing +func createMasterKeyFile() (*TempKeyFile, error) { + tempFile, err := os.CreateTemp("", "test_key_*.txt") + if err != nil { + return nil, err + } + + _, err = tempFile.WriteString("c3d99825f2181f4808acd2068eac7441a65bd428f14d2aab43fefc0129091139\n") + if err != nil { + tempFile.Close() + os.Remove(tempFile.Name()) + return nil, err + } + + return &TempKeyFile{ + Path: tempFile.Name(), + file: tempFile, + }, nil +} + +func TestFileBackendAes256Gcm(t *testing.T) { + pt, err := hex.DecodeString("25431587e9ecffc7c37f8d6d52a9bc3310651d46fb0e3bad2726c8f2db653749") + require.NoError(t, err) + ct, err := hex.DecodeString("84e5f23f95648fa247cb28eef53abec947dbf05ac953734618111583840bd980") + require.NoError(t, err) + iv, err := hex.DecodeString("cafabd9672ca6c79a2fbdc22") + require.NoError(t, err) + + tempKeyFile, err := createMasterKeyFile() + require.NoError(t, err) + defer tempKeyFile.Cleanup() + + backend, err := createFileBackend(tempKeyFile.Path) + require.NoError(t, err) + + ctx := context.Background() + encryptedContent, err := backend.memCache.EncryptContent(ctx, pt, IV(iv)) + require.NoError(t, err) + assert.Equal(t, ct, encryptedContent.Content) + + plaintext, err := backend.Decrypt(ctx, encryptedContent) + require.NoError(t, err) + assert.Equal(t, pt, plaintext) +} + +func TestFileBackendAuthenticate(t *testing.T) { + pt := []byte{1, 2, 3} + + tempKeyFile, err := createMasterKeyFile() + require.NoError(t, err) + defer tempKeyFile.Cleanup() + + backend, err := createFileBackend(tempKeyFile.Path) + require.NoError(t, err) + + ctx := context.Background() + encryptedContent, err := backend.Encrypt(ctx, pt) + require.NoError(t, err) + + plaintext, err := backend.Decrypt(ctx, encryptedContent) + require.NoError(t, err) + assert.Equal(t, pt, plaintext) + + // Test checksum mismatch + encryptedContent1 := *encryptedContent + 
encryptedContent1.Metadata[MetadataKeyAesGcmTag][0] ^= 0xFF + _, err = backend.Decrypt(ctx, &encryptedContent1) + assert.Error(t, err) + assert.Contains(t, err.Error(), wrongMasterKey) + + // Test checksum not found + encryptedContent2 := *encryptedContent + delete(encryptedContent2.Metadata, MetadataKeyAesGcmTag) + _, err = backend.Decrypt(ctx, &encryptedContent2) + assert.Error(t, err) + assert.Contains(t, err.Error(), gcmTagNotFound) +} diff --git a/br/pkg/encryption/master_key/kms_backend.go b/br/pkg/encryption/master_key/kms_backend.go new file mode 100644 index 0000000000000..241a4965cf35f --- /dev/null +++ b/br/pkg/encryption/master_key/kms_backend.go @@ -0,0 +1,89 @@ +package encryption + +import ( + "context" + "sync" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/pingcap/tidb/br/pkg/kms" + "github.com/pingcap/tidb/br/pkg/utils" +) + +type CachedKeys struct { + encryptionBackend *MemAesGcmBackend + cachedCiphertextKey *kms.EncryptedKey +} + +type KmsBackend struct { + state struct { + sync.Mutex + cached *CachedKeys + } + kmsProvider kms.Provider +} + +func NewKmsBackend(kmsProvider kms.Provider) (*KmsBackend, error) { + return &KmsBackend{ + kmsProvider: kmsProvider, + }, nil +} + +func (k *KmsBackend) decryptContent(ctx context.Context, content *encryptionpb.EncryptedContent) ([]byte, error) { + vendorName := k.kmsProvider.Name() + if val, ok := content.Metadata[MetadataKeyKmsVendor]; !ok { + return nil, errors.New("wrong master key: missing KMS vendor") + } else if string(val) != vendorName { + return nil, errors.Errorf("KMS vendor mismatch expect %s got %s", vendorName, string(val)) + } + + ciphertextKeyBytes, ok := content.Metadata[MetadataKeyKmsCiphertextKey] + if !ok { + return nil, errors.New("KMS ciphertext key not found") + } + ciphertextKey, err := kms.NewEncryptedKey(ciphertextKeyBytes) + if err != nil { + return nil, errors.Annotate(err, "failed to create encrypted key") + } + + k.state.Lock() + defer k.state.Unlock() + + if k.state.cached != nil && k.state.cached.cachedCiphertextKey.Equal(&ciphertextKey) { + return k.state.cached.encryptionBackend.DecryptContent(ctx, content) + } + + // piggyback on NewDownloadSSTBackoffer, a refactor is ongoing to remove all the backoffers + // so users don't need to write a backoffer for every type + decryptedKey, err := utils.WithRetryV2(ctx, utils.NewDownloadSSTBackoffer(), func(ctx context.Context) ([]byte, error) { + return k.kmsProvider.DecryptDataKey(ctx, ciphertextKey) + }) + if err != nil { + return nil, errors.Annotate(err, "decrypt encrypted key failed") + + } + + plaintextKey, err := kms.NewPlainKey(decryptedKey, kms.CryptographyTypeAesGcm256) + if err != nil { + return nil, errors.Annotate(err, "failed to create plain key from decrypted data key") + } + dataKey := kms.DataKeyPair{ + Encrypted: &ciphertextKey, + Plaintext: plaintextKey, + } + backend, err := NewMemAesGcmBackend(dataKey.Plaintext.Key()) + if err != nil { + return nil, errors.Annotate(err, "failed to create MemAesGcmBackend") + } + + k.state.cached = &CachedKeys{ + encryptionBackend: backend, + cachedCiphertextKey: &ciphertextKey, + } + + return k.state.cached.encryptionBackend.DecryptContent(ctx, content) +} + +func (k *KmsBackend) Decrypt(ctx context.Context, content *encryptionpb.EncryptedContent) ([]byte, error) { + return k.decryptContent(ctx, content) +} diff --git a/br/pkg/encryption/master_key/kms_backend_test.go b/br/pkg/encryption/master_key/kms_backend_test.go new file mode 100644 index 0000000000000..8bd0b95200ad1 --- 
/dev/null +++ b/br/pkg/encryption/master_key/kms_backend_test.go @@ -0,0 +1,105 @@ +package encryption + +import ( + "context" + "testing" + + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/stretchr/testify/require" +) + +type mockKmsProvider struct { + name string + decryptCounter int +} + +func (m *mockKmsProvider) Name() string { + return m.name +} + +func (m *mockKmsProvider) DecryptDataKey(_ctx context.Context, _encryptedKey []byte) ([]byte, error) { + m.decryptCounter++ + return []byte("decrypted_key"), nil +} + +func TestKmsBackendDecrypt(t *testing.T) { + ctx := context.Background() + mockProvider := &mockKmsProvider{name: "mock_kms"} + backend, err := NewKmsBackend(mockProvider) + require.NoError(t, err) + + ciphertextKey := []byte("ciphertext_key") + content := &encryptionpb.EncryptedContent{ + Metadata: map[string][]byte{ + MetadataKeyKmsVendor: []byte("mock_kms"), + MetadataKeyKmsCiphertextKey: ciphertextKey, + }, + Content: []byte("encrypted_content"), + } + + // First decryption + _, err = backend.Decrypt(ctx, content) + require.NoError(t, err) + require.Equal(t, 1, mockProvider.decryptCounter, "KMS provider should be called once") + + // Second decryption with the same ciphertext key (should use cache) + _, err = backend.Decrypt(ctx, content) + require.NoError(t, err) + require.Equal(t, 1, mockProvider.decryptCounter, "KMS provider should not be called again") + + // Third decryption with a different ciphertext key + content.Metadata[MetadataKeyKmsCiphertextKey] = []byte("new_ciphertext_key") + _, err = backend.Decrypt(ctx, content) + require.NoError(t, err) + require.Equal(t, 2, mockProvider.decryptCounter, "KMS provider should be called again for a new key") +} + +func TestKmsBackendDecryptErrors(t *testing.T) { + ctx := context.Background() + mockProvider := &mockKmsProvider{name: "mock_kms"} + backend, err := NewKmsBackend(mockProvider) + require.NoError(t, err) + + testCases := []struct { + name string + content *encryptionpb.EncryptedContent + errMsg string + }{ + { + name: "missing KMS vendor", + content: &encryptionpb.EncryptedContent{ + Metadata: map[string][]byte{ + MetadataKeyKmsCiphertextKey: []byte("ciphertext_key"), + }, + }, + errMsg: "wrong master key: missing KMS vendor", + }, + { + name: "KMS vendor mismatch", + content: &encryptionpb.EncryptedContent{ + Metadata: map[string][]byte{ + MetadataKeyKmsVendor: []byte("wrong_kms"), + MetadataKeyKmsCiphertextKey: []byte("ciphertext_key"), + }, + }, + errMsg: "KMS vendor mismatch expect mock_kms got wrong_kms", + }, + { + name: "missing ciphertext key", + content: &encryptionpb.EncryptedContent{ + Metadata: map[string][]byte{ + MetadataKeyKmsVendor: []byte("mock_kms"), + }, + }, + errMsg: "KMS ciphertext key not found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := backend.Decrypt(ctx, tc.content) + require.Error(t, err) + require.Contains(t, err.Error(), tc.errMsg) + }) + } +} diff --git a/br/pkg/encryption/master_key/master_key.go b/br/pkg/encryption/master_key/master_key.go new file mode 100644 index 0000000000000..aa22992fc2e43 --- /dev/null +++ b/br/pkg/encryption/master_key/master_key.go @@ -0,0 +1,74 @@ +package encryption + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/kms" + "go.uber.org/zap" +) + +const ( + StorageVendorNameAWS = "aws" + StorageVendorNameAzure = "azure" + StorageVendorNameGCP = "gcp" +) + +// Backend is an 
interface that defines the methods required for an encryption backend. +type Backend interface { + // Decrypt takes an EncryptedContent and returns the decrypted plaintext as a byte slice or an error. + Decrypt(ctx context.Context, ciphertext *encryptionpb.EncryptedContent) ([]byte, error) +} + +func CreateBackend(config *encryptionpb.MasterKey) (Backend, error) { + if config == nil { + return nil, errors.Errorf("master key config is nil") + } + + switch backend := config.Backend.(type) { + case *encryptionpb.MasterKey_Plaintext: + // no need to create backend for plaintext + return nil, nil + case *encryptionpb.MasterKey_File: + fileBackend, err := createFileBackend(backend.File.Path) + if err != nil { + return nil, errors.Annotate(err, "failed to create file backend") + } + return fileBackend, nil + case *encryptionpb.MasterKey_Kms: + return createCloudBackend(backend.Kms) + default: + return nil, errors.New("unknown master key backend type") + } +} + +func createCloudBackend(config *encryptionpb.MasterKeyKms) (Backend, error) { + log.Info("creating cloud KMS backend", + zap.String("region", config.GetRegion()), + zap.String("endpoint", config.GetEndpoint()), + zap.String("key_id", config.GetKeyId()), + zap.String("vendor", config.GetVendor())) + + switch config.Vendor { + case StorageVendorNameAWS: + kmsProvider, err := kms.NewAwsKms(config) + if err != nil { + return nil, errors.Annotate(err, "new AWS KMS") + } + return NewKmsBackend(kmsProvider) + + case StorageVendorNameAzure: + return nil, errors.Errorf("Azure KMS is not implemented") + case StorageVendorNameGCP: + kmsProvider, err := kms.NewGcpKms(config) + if err != nil { + return nil, errors.Annotate(err, "new GCP KMS") + } + return NewKmsBackend(kmsProvider) + + default: + return nil, errors.Errorf("vendor not found: %s", config.Vendor) + } +} diff --git a/br/pkg/encryption/master_key/mem_backend.go b/br/pkg/encryption/master_key/mem_backend.go new file mode 100644 index 0000000000000..ac8236cea2d49 --- /dev/null +++ b/br/pkg/encryption/master_key/mem_backend.go @@ -0,0 +1,94 @@ +package encryption + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "fmt" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/pingcap/tidb/br/pkg/kms" +) + +const ( + gcmTagNotFound = "aes gcm tag not found" + wrongMasterKey = "wrong master key" +) + +type MemAesGcmBackend struct { + key *kms.PlainKey +} + +func NewMemAesGcmBackend(key []byte) (*MemAesGcmBackend, error) { + plainKey, err := kms.NewPlainKey(key, kms.CryptographyTypeAesGcm256) + if err != nil { + return nil, errors.Annotate(err, "failed to create new mem aes gcm backend") + } + return &MemAesGcmBackend{ + key: plainKey, + }, nil +} + +func (m *MemAesGcmBackend) EncryptContent(ctx context.Context, plaintext []byte, iv IV) (*encryptionpb.EncryptedContent, error) { + content := encryptionpb.EncryptedContent{ + Metadata: make(map[string][]byte), + } + content.Metadata[MetadataKeyMethod] = []byte(MetadataMethodAes256Gcm) + content.Metadata[MetadataKeyIv] = iv[:] + + block, err := aes.NewCipher(m.key.Key()) + if err != nil { + return nil, err + } + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + ciphertext := aesgcm.Seal(nil, iv[:], plaintext, nil) + content.Content = ciphertext[:len(ciphertext)-aesgcm.Overhead()] + content.Metadata[MetadataKeyAesGcmTag] = ciphertext[len(ciphertext)-aesgcm.Overhead():] + + return &content, nil +} + +func (m *MemAesGcmBackend) DecryptContent(ctx context.Context, content 
*encryptionpb.EncryptedContent) ([]byte, error) { + method, ok := content.Metadata[MetadataKeyMethod] + if !ok { + return nil, fmt.Errorf("metadata %s not found", MetadataKeyMethod) + } + if string(method) != MetadataMethodAes256Gcm { + return nil, errors.Errorf("encryption method mismatch, expected %s vs actual %s", + MetadataMethodAes256Gcm, method) + } + + ivValue, ok := content.Metadata[MetadataKeyIv] + if !ok { + return nil, errors.Errorf("metadata %s not found", MetadataKeyIv) + } + var iv IV + copy(iv[:], ivValue) + + tag, ok := content.Metadata[MetadataKeyAesGcmTag] + if !ok { + return nil, errors.New("aes gcm tag not found") + } + + block, err := aes.NewCipher(m.key.Key()) + if err != nil { + return nil, err + } + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + ciphertext := append(content.Content, tag...) + plaintext, err := aesgcm.Open(nil, iv[:], ciphertext, nil) + if err != nil { + return nil, errors.Annotate(err, wrongMasterKey+" :decrypt in GCM mode failed") + } + + return plaintext, nil +} diff --git a/br/pkg/encryption/master_key/mem_backend_test.go b/br/pkg/encryption/master_key/mem_backend_test.go new file mode 100644 index 0000000000000..6124c10510e11 --- /dev/null +++ b/br/pkg/encryption/master_key/mem_backend_test.go @@ -0,0 +1,125 @@ +package encryption + +import ( + "bytes" + "context" + "testing" +) + +func TestNewMemAesGcmBackend(t *testing.T) { + key := make([]byte, 32) // 256-bit key + _, err := NewMemAesGcmBackend(key) + if err != nil { + t.Fatalf("Failed to create MemAesGcmBackend: %v", err) + } + + shortKey := make([]byte, 16) + _, err = NewMemAesGcmBackend(shortKey) + if err == nil { + t.Fatal("Expected error for short key, got nil") + } +} + +func TestEncryptDecrypt(t *testing.T) { + key := make([]byte, 32) + backend, err := NewMemAesGcmBackend(key) + if err != nil { + t.Fatalf("Failed to create MemAesGcmBackend: %v", err) + } + + plaintext := []byte("Hello, World!") + iv := NewIV() + + ctx := context.Background() + encrypted, err := backend.EncryptContent(ctx, plaintext, iv) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + decrypted, err := backend.DecryptContent(ctx, encrypted) + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + if !bytes.Equal(plaintext, decrypted) { + t.Fatalf("Decrypted text doesn't match original. 
Got %v, want %v", decrypted, plaintext) + } +} + +func TestDecryptWithWrongKey(t *testing.T) { + key1 := make([]byte, 32) + key2 := make([]byte, 32) + for i := range key2 { + key2[i] = 1 // Different from key1 + } + + backend1, _ := NewMemAesGcmBackend(key1) + backend2, _ := NewMemAesGcmBackend(key2) + + plaintext := []byte("Hello, World!") + iv := NewIV() + + ctx := context.Background() + encrypted, _ := backend1.EncryptContent(ctx, plaintext, iv) + _, err := backend2.DecryptContent(ctx, encrypted) + if err == nil { + t.Fatal("Expected decryption with wrong key to fail, but it succeeded") + } +} + +func TestDecryptWithTamperedCiphertext(t *testing.T) { + key := make([]byte, 32) + backend, _ := NewMemAesGcmBackend(key) + + plaintext := []byte("Hello, World!") + iv := NewIV() + + ctx := context.Background() + encrypted, _ := backend.EncryptContent(ctx, plaintext, iv) + encrypted.Content[0] ^= 1 // Tamper with the ciphertext + + _, err := backend.DecryptContent(ctx, encrypted) + if err == nil { + t.Fatal("Expected decryption of tampered ciphertext to fail, but it succeeded") + } +} + +func TestDecryptWithMissingMetadata(t *testing.T) { + key := make([]byte, 32) + backend, _ := NewMemAesGcmBackend(key) + + plaintext := []byte("Hello, World!") + iv := NewIV() + + ctx := context.Background() + encrypted, _ := backend.EncryptContent(ctx, plaintext, iv) + delete(encrypted.Metadata, MetadataKeyMethod) + + _, err := backend.DecryptContent(ctx, encrypted) + if err == nil { + t.Fatal("Expected decryption with missing metadata to fail, but it succeeded") + } +} + +func TestEncryptDecryptLargeData(t *testing.T) { + key := make([]byte, 32) + backend, _ := NewMemAesGcmBackend(key) + + plaintext := make([]byte, 1000000) // 1 MB of data + iv := NewIV() + + ctx := context.Background() + encrypted, err := backend.EncryptContent(ctx, plaintext, iv) + if err != nil { + t.Fatalf("Encryption of large data failed: %v", err) + } + + decrypted, err := backend.DecryptContent(ctx, encrypted) + if err != nil { + t.Fatalf("Decryption of large data failed: %v", err) + } + + if !bytes.Equal(plaintext, decrypted) { + t.Fatal("Decrypted large data doesn't match original") + } +} diff --git a/br/pkg/encryption/master_key/multi_master_key_backend.go b/br/pkg/encryption/master_key/multi_master_key_backend.go new file mode 100644 index 0000000000000..71a26d7a4e352 --- /dev/null +++ b/br/pkg/encryption/master_key/multi_master_key_backend.go @@ -0,0 +1,47 @@ +package encryption + +import ( + "context" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" +) + +type MultiMasterKeyBackend struct { + backends []Backend +} + +func NewMultiMasterKeyBackend(masterKeysProto []*encryptionpb.MasterKey) (*MultiMasterKeyBackend, error) { + if masterKeysProto == nil && len(masterKeysProto) == 0 { + return nil, errors.New("must provide at least one master key") + } + var backends []Backend + for _, masterKeyProto := range masterKeysProto { + backend, err := CreateBackend(masterKeyProto) + if err != nil { + return nil, errors.Trace(err) + } + backends = append(backends, backend) + } + return &MultiMasterKeyBackend{ + backends: backends, + }, nil +} + +func (m *MultiMasterKeyBackend) Decrypt(ctx context.Context, encryptedContent *encryptionpb.EncryptedContent) ([]byte, error) { + if len(m.backends) == 0 { + return nil, errors.New("internal error: should always contain at least one backend") + } + + var errMsgs []string + for _, masterKeyBackend := range m.backends { + res, err := 
masterKeyBackend.Decrypt(ctx, encryptedContent) + if err == nil { + return res, nil + } + errMsgs = append(errMsgs, errors.ErrorStack(err)) + } + + return nil, errors.Errorf("failed to decrypt in multi master key backend: %s", strings.Join(errMsgs, ",")) +} diff --git a/br/pkg/encryption/master_key/multi_master_key_backend_test.go b/br/pkg/encryption/master_key/multi_master_key_backend_test.go new file mode 100644 index 0000000000000..f32f63c6ebf6e --- /dev/null +++ b/br/pkg/encryption/master_key/multi_master_key_backend_test.go @@ -0,0 +1,99 @@ +package encryption + +import ( + "context" + "testing" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockBackend is a mock implementation of the Backend interface +type MockBackend struct { + mock.Mock +} + +func (m *MockBackend) Decrypt(ctx context.Context, encryptedContent *encryptionpb.EncryptedContent) ([]byte, error) { + args := m.Called(ctx, encryptedContent) + // The first return value should be []byte or nil + if ret := args.Get(0); ret != nil { + return ret.([]byte), args.Error(1) + } + return nil, args.Error(1) +} + +func TestMultiMasterKeyBackendDecrypt(t *testing.T) { + ctx := context.Background() + encryptedContent := &encryptionpb.EncryptedContent{Content: []byte("encrypted")} + + t.Run("success first backend", func(t *testing.T) { + mock1 := new(MockBackend) + mock1.On("Decrypt", ctx, encryptedContent).Return([]byte("decrypted"), nil) + + mock2 := new(MockBackend) + + backend := &MultiMasterKeyBackend{ + backends: []Backend{mock1, mock2}, + } + + result, err := backend.Decrypt(ctx, encryptedContent) + assert.NoError(t, err) + assert.Equal(t, []byte("decrypted"), result) + + mock1.AssertExpectations(t) + mock2.AssertNotCalled(t, "Decrypt") + }) + + t.Run("success second backend", func(t *testing.T) { + mock1 := new(MockBackend) + mock1.On("Decrypt", ctx, encryptedContent).Return(nil, errors.New("failed")) + + mock2 := new(MockBackend) + mock2.On("Decrypt", ctx, encryptedContent).Return([]byte("decrypted"), nil) + + backend := &MultiMasterKeyBackend{ + backends: []Backend{mock1, mock2}, + } + + result, err := backend.Decrypt(ctx, encryptedContent) + assert.NoError(t, err) + assert.Equal(t, []byte("decrypted"), result) + + mock1.AssertExpectations(t) + mock2.AssertExpectations(t) + }) + + t.Run("all backends fail", func(t *testing.T) { + mock1 := new(MockBackend) + mock1.On("Decrypt", ctx, encryptedContent).Return(nil, errors.New("failed1")) + + mock2 := new(MockBackend) + mock2.On("Decrypt", ctx, encryptedContent).Return(nil, errors.New("failed2")) + + backend := &MultiMasterKeyBackend{ + backends: []Backend{mock1, mock2}, + } + + result, err := backend.Decrypt(ctx, encryptedContent) + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "failed1") + assert.Contains(t, err.Error(), "failed2") + + mock1.AssertExpectations(t) + mock2.AssertExpectations(t) + }) + + t.Run("no backends", func(t *testing.T) { + backend := &MultiMasterKeyBackend{ + backends: []Backend{}, + } + + result, err := backend.Decrypt(ctx, encryptedContent) + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "internal error") + }) +} diff --git a/br/pkg/kms/BUILD.bazel b/br/pkg/kms/BUILD.bazel new file mode 100644 index 0000000000000..8e8d6e56b624e --- /dev/null +++ b/br/pkg/kms/BUILD.bazel @@ -0,0 +1,21 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "aws", + 
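# note: this go_library target is named "aws" rather than "kms"; dependents reference it as "//br/pkg/kms:aws" (see master_key/BUILD.bazel above) +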
srcs = [ + "aws.go", + "common.go", + "gcp.go", + "kms.go", + ], + importpath = "github.com/pingcap/tidb/br/pkg/kms", + visibility = ["//visibility:public"], + deps = [ + "@com_github_aws_aws_sdk_go//aws", + "@com_github_aws_aws_sdk_go//aws/session", + "@com_github_aws_aws_sdk_go//service/kms", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_kvproto//pkg/encryptionpb", + "@org_golang_x_oauth2//google", + ], +) diff --git a/br/pkg/kms/aws.go b/br/pkg/kms/aws.go new file mode 100644 index 0000000000000..1edf1b46e66e6 --- /dev/null +++ b/br/pkg/kms/aws.go @@ -0,0 +1,71 @@ +package kms + +import ( + "context" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kms" + pErrors "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" +) + +const ( + EncryptionVendorNameAwsKms = "AWS" +) + +type AwsKms struct { + client *kms.KMS + currentKeyID string + region string + endpoint string +} + +func NewAwsKms(masterKeyConfig *encryptionpb.MasterKeyKms) (*AwsKms, error) { + sess, err := session.NewSession(&aws.Config{ + Region: aws.String(masterKeyConfig.Region), + Endpoint: aws.String(masterKeyConfig.Endpoint), + }) + if err != nil { + return nil, pErrors.Annotate(err, "failed to create AWS session") + } + + return &AwsKms{ + client: kms.New(sess), + currentKeyID: masterKeyConfig.KeyId, + region: masterKeyConfig.Region, + endpoint: masterKeyConfig.Endpoint, + }, nil +} + +func (a *AwsKms) Name() string { + return EncryptionVendorNameAwsKms +} + +func (a *AwsKms) DecryptDataKey(ctx context.Context, dataKey []byte) ([]byte, error) { + input := &kms.DecryptInput{ + CiphertextBlob: dataKey, + KeyId: aws.String(a.currentKeyID), + } + + result, err := a.client.DecryptWithContext(ctx, input) + if err != nil { + return nil, classifyDecryptError(err) + } + + return result.Plaintext, nil +} + +// classifyDecryptError maps AWS KMS SDK v1 error types to more descriptive errors +func classifyDecryptError(err error) error { + switch err := err.(type) { + case *kms.NotFoundException, *kms.InvalidKeyUsageException: + return pErrors.Annotate(err, "wrong master key") + case *kms.DependencyTimeoutException: + return pErrors.Annotate(err, "API timeout") + case *kms.InternalException: + return pErrors.Annotate(err, "API internal error") + default: + return pErrors.Annotate(err, "KMS error") + } +} diff --git a/br/pkg/kms/common.go b/br/pkg/kms/common.go new file mode 100644 index 0000000000000..ab603cda18daf --- /dev/null +++ b/br/pkg/kms/common.go @@ -0,0 +1,69 @@ +package kms + +import ( + "bytes" + + "github.com/pingcap/errors" +) + +// EncryptedKey is used to mark data as an encrypted key +type EncryptedKey []byte + +func NewEncryptedKey(key []byte) (EncryptedKey, error) { + if len(key) == 0 { + return nil, errors.New("encrypted key cannot be empty") + } + return key, nil +} + +// Equal method for EncryptedKey +func (e EncryptedKey) Equal(other *EncryptedKey) bool { + return bytes.Equal(e, *other) +} + +// CryptographyType represents different cryptography methods +type CryptographyType int + +const ( + CryptographyTypePlain CryptographyType = iota + CryptographyTypeAesGcm256 +) + +func (c CryptographyType) TargetKeySize() int { + switch c { + case CryptographyTypePlain: + return 0 // Plain text has no limitation + case CryptographyTypeAesGcm256: + return 32 + default: + return 0 + } +} + +// PlainKey is used to mark a byte slice as a plaintext key +type PlainKey struct { + tag CryptographyType + key []byte +} + +func NewPlainKey(key []byte, t 
CryptographyType) (*PlainKey, error) { + limitation := t.TargetKeySize() + if limitation > 0 && len(key) != limitation { + return nil, errors.Errorf("encryption method and key length mismatch, expect %d got %d", limitation, len(key)) + } + return &PlainKey{key: key, tag: t}, nil +} + +func (p *PlainKey) KeyTag() CryptographyType { + return p.tag +} + +func (p *PlainKey) Key() []byte { + return p.key +} + +// DataKeyPair represents a pair of encrypted and plaintext keys +type DataKeyPair struct { + Encrypted *EncryptedKey + Plaintext *PlainKey +} diff --git a/br/pkg/kms/gcp.go b/br/pkg/kms/gcp.go new file mode 100644 index 0000000000000..c1fcf5ce0bb26 --- /dev/null +++ b/br/pkg/kms/gcp.go @@ -0,0 +1,166 @@ +package kms + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "hash/crc32" + "io" + "net/http" + "regexp" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" + "golang.org/x/oauth2/google" +) + +const ( + GcpKmsEndpoint = "https://cloudkms.googleapis.com/v1/" + MethodDecrypt = "decrypt" + KeyIdPattern = `^projects/([^/]+)/locations/([^/]+)/keyRings/([^/]+)/cryptoKeys/([^/]+)/?$` + StorageVendorNameGcp = "gcp" +) + +var KeyIdRegex = regexp.MustCompile(KeyIdPattern) + +type GcpKms struct { + config *encryptionpb.MasterKeyKms + location string + client *http.Client +} + +func NewGcpKms(config *encryptionpb.MasterKeyKms) (*GcpKms, error) { + if config.GcpKms == nil { + return nil, errors.New("GCP config is missing") + } + if !KeyIdRegex.MatchString(config.KeyId) { + return nil, errors.Errorf("invalid key: '%s'", config.KeyId) + } + if strings.HasSuffix(config.KeyId, "/") { + config.KeyId = strings.TrimSuffix(config.KeyId, "/") + } + location := strings.Join(strings.Split(config.KeyId, "/")[:4], "/") + + client, err := google.DefaultClient(context.Background(), "https://www.googleapis.com/auth/cloud-platform") + if err != nil { + return nil, errors.Errorf("failed to create GCP client: %v", err) + } + + return &GcpKms{ + config: config, + location: location, + client: client, + }, nil +} + +func (g *GcpKms) doJSONRequest(ctx context.Context, keyName, method string, data interface{}) ([]byte, error) { + url := g.formatCallURL(keyName, method) + jsonData, err := json.Marshal(data) + if err != nil { + return nil, errors.Annotate(err, "failed to marshal request") + } + + req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(jsonData))) + if err != nil { + return nil, errors.Annotate(err, "failed to create request") + } + req.Header.Set("Content-Type", "application/json") + + start := time.Now() + resp, err := g.client.Do(req) + if err != nil { + return nil, errors.Annotate(err, "request failed") + } + defer resp.Body.Close() + + // TODO: Implement metrics + _ = time.Since(start) + + if resp.StatusCode != http.StatusOK { + return nil, errors.Errorf("request failed with status: %s", resp.Status) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Annotate(err, "failed to read response body") + } + + return body, nil +} + +func (g *GcpKms) formatCallURL(key, method string) string { + return fmt.Sprintf("%s%s/:%s?alt=json", GcpKmsEndpoint, key, method) +} + +func (g *GcpKms) Name() string { + return StorageVendorNameGcp +} + +func (g *GcpKms) DecryptDataKey(ctx context.Context, dataKey []byte) ([]byte, error) { + decryptReq := DecryptRequest{ + Ciphertext: base64.StdEncoding.EncodeToString(dataKey), + CiphertextCrc32c: int64(crc32.Checksum(dataKey, 
crc32.MakeTable(crc32.Castagnoli))), + } + + respBody, err := g.doJSONRequest(ctx, g.config.KeyId, MethodDecrypt, decryptReq) + if err != nil { + return nil, errors.Annotate(err, "decrypt request failed") + } + + var resp DecryptResp + if err := json.Unmarshal(respBody, &resp); err != nil { + return nil, errors.Annotate(err, "failed to unmarshal decrypt response") + } + + plaintext, err := base64.StdEncoding.DecodeString(resp.Plaintext) + if err != nil { + return nil, errors.Annotate(err, "failed to decode plaintext") + } + + if err := checkCRC32(plaintext, resp.PlaintextCrc32c); err != nil { + return nil, err + } + + return plaintext, nil +} + +type EncryptRequest struct { + Plaintext string `json:"plaintext"` + PlaintextCrc32c int64 `json:"plaintextCrc32c,string"` +} + +type EncryptResp struct { + Ciphertext string `json:"ciphertext"` + CiphertextCrc32c int64 `json:"ciphertextCrc32c,string"` +} + +type DecryptRequest struct { + Ciphertext string `json:"ciphertext"` + CiphertextCrc32c int64 `json:"ciphertextCrc32c,string"` +} + +type DecryptResp struct { + Plaintext string `json:"plaintext"` + PlaintextCrc32c int64 `json:"plaintextCrc32c,string"` +} + +type GenRandomBytesReq struct { + LengthBytes int `json:"lengthBytes"` + ProtectionLevel string `json:"protectionLevel"` +} + +type GenRandomBytesResp struct { + Data string `json:"data"` + DataCrc32c int64 `json:"dataCrc32c,string"` +} + +func checkCRC32(data []byte, expected int64) error { + crc := int64(crc32.Checksum(data, crc32.MakeTable(crc32.Castagnoli))) + if crc != expected { + return errors.Errorf("crc32c mismatch, expected: %d, got: %d", expected, crc) + } + return nil +} diff --git a/br/pkg/kms/kms.go b/br/pkg/kms/kms.go new file mode 100644 index 0000000000000..17090037da377 --- /dev/null +++ b/br/pkg/kms/kms.go @@ -0,0 +1,10 @@ +package kms + +import "context" + +// Provider is an interface for key management service providers +// implement encrypt data key in future if needed +type Provider interface { + DecryptDataKey(ctx context.Context, dataKey []byte) ([]byte, error) + Name() string +} diff --git a/br/pkg/metautil/BUILD.bazel b/br/pkg/metautil/BUILD.bazel index 44a35bc2b209c..a7008e6283859 100644 --- a/br/pkg/metautil/BUILD.bazel +++ b/br/pkg/metautil/BUILD.bazel @@ -14,6 +14,7 @@ go_library( "//br/pkg/logutil", "//br/pkg/storage", "//br/pkg/summary", + "//br/pkg/utils", "//pkg/meta/model", "//pkg/statistics/handle", "//pkg/statistics/handle/types", @@ -47,6 +48,7 @@ go_test( shard_count = 9, deps = [ "//br/pkg/storage", + "//br/pkg/utils", "//pkg/meta/model", "//pkg/parser/model", "//pkg/statistics/handle/types", diff --git a/br/pkg/metautil/metafile.go b/br/pkg/metautil/metafile.go index 83542d2880e14..d1c7196f20c22 100644 --- a/br/pkg/metautil/metafile.go +++ b/br/pkg/metautil/metafile.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/statistics/handle/util" "github.com/pingcap/tidb/pkg/tablecodec" @@ -87,22 +88,19 @@ func Encrypt(content []byte, cipher *backuppb.CipherInfo) (encryptedContent, iv } } -// Decrypt decrypts the content according to CipherInfo and IV. 
-func Decrypt(content []byte, cipher *backuppb.CipherInfo, iv []byte) ([]byte, error) { - if len(content) == 0 || cipher == nil { - return content, nil +func DecryptFullBackupMetaIfNeeded(metaData []byte, cipherInfo *backuppb.CipherInfo) ([]byte, error) { + // the prefix of backup meta file is iv(16 bytes) if encryption method is valid + var iv []byte + if cipherInfo != nil && cipherInfo.CipherType != encryptionpb.EncryptionMethod_PLAINTEXT { + iv = metaData[:CrypterIvLen] + } else { + return metaData, nil } - - switch cipher.CipherType { - case encryptionpb.EncryptionMethod_PLAINTEXT: - return content, nil - case encryptionpb.EncryptionMethod_AES128_CTR, - encryptionpb.EncryptionMethod_AES192_CTR, - encryptionpb.EncryptionMethod_AES256_CTR: - return encrypt.AESDecryptWithCTR(content, cipher.CipherKey, iv) - default: - return content, errors.Annotate(berrors.ErrInvalidArgument, "cipher type invalid") + decryptBackupMeta, err := utils.Decrypt(metaData[len(iv):], cipherInfo, iv) + if err != nil { + return nil, errors.Annotate(err, "decrypt failed with wrong key") } + return decryptBackupMeta, nil } // walkLeafMetaFile walks the leaves of the given metafile, and deal with it by calling the function `output`. @@ -130,7 +128,7 @@ func walkLeafMetaFile( return errors.Trace(err) } - decryptContent, err := Decrypt(content, cipher, node.CipherIv) + decryptContent, err := utils.Decrypt(content, cipher, node.CipherIv) if err != nil { return errors.Trace(err) } diff --git a/br/pkg/metautil/metafile_test.go b/br/pkg/metautil/metafile_test.go index 4a010b49654a4..1cea0f86d8d0b 100644 --- a/br/pkg/metautil/metafile_test.go +++ b/br/pkg/metautil/metafile_test.go @@ -11,6 +11,7 @@ import ( backuppb "github.com/pingcap/kvproto/pkg/brpb" "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/utils" "github.com/stretchr/testify/require" ) @@ -203,14 +204,14 @@ func TestEncryptAndDecrypt(t *testing.T) { require.NoError(t, err) require.Equal(t, originalData, encryptData) - decryptData, err := Decrypt(encryptData, &cipher, iv) + decryptData, err := utils.Decrypt(encryptData, &cipher, iv) require.NoError(t, err) require.Equal(t, decryptData, originalData) } else { require.NoError(t, err) require.NotEqual(t, originalData, encryptData) - decryptData, err := Decrypt(encryptData, &cipher, iv) + decryptData, err := utils.Decrypt(encryptData, &cipher, iv) require.NoError(t, err) require.Equal(t, decryptData, originalData) @@ -218,7 +219,7 @@ func TestEncryptAndDecrypt(t *testing.T) { CipherType: v.method, CipherKey: []byte(v.wrongKey), } - decryptData, err = Decrypt(encryptData, &wrongCipher, iv) + decryptData, err = utils.Decrypt(encryptData, &wrongCipher, iv) if len(v.rightKey) != len(v.wrongKey) { require.Error(t, err) } else { diff --git a/br/pkg/metautil/statsfile.go b/br/pkg/metautil/statsfile.go index 6621a87cb6e14..ddccdfe839caf 100644 --- a/br/pkg/metautil/statsfile.go +++ b/br/pkg/metautil/statsfile.go @@ -26,6 +26,7 @@ import ( backuppb "github.com/pingcap/kvproto/pkg/brpb" berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/statistics/handle" statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" @@ -201,7 +202,7 @@ func downloadStats( return errors.Trace(err) } - decryptContent, err := Decrypt(content, cipher, statsFile.CipherIv) + decryptContent, err := utils.Decrypt(content, cipher, 
statsFile.CipherIv) if err != nil { return errors.Trace(err) } diff --git a/br/pkg/mock/mocklocal/local.go b/br/pkg/mock/mocklocal/local.go index b6c70a530b912..81dad11391067 100644 --- a/br/pkg/mock/mocklocal/local.go +++ b/br/pkg/mock/mocklocal/local.go @@ -138,7 +138,7 @@ func (m *MockStoreHelper) EXPECT() *MockStoreHelperMockRecorder { // GetTS mocks base method. func (m *MockStoreHelper) GetTS(arg0 context.Context) (int64, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTS", arg0) + ret := m.ctrl.Call(m, "GetCurrentTsFromPd", arg0) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(int64) ret2, _ := ret[2].(error) @@ -148,7 +148,7 @@ func (m *MockStoreHelper) GetTS(arg0 context.Context) (int64, int64, error) { // GetTS indicates an expected call of GetTS. func (mr *MockStoreHelperMockRecorder) GetTS(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTS", reflect.TypeOf((*MockStoreHelper)(nil).GetTS), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCurrentTsFromPd", reflect.TypeOf((*MockStoreHelper)(nil).GetTS), arg0) } // GetTiKVCodec mocks base method. diff --git a/br/pkg/restore/log_client/BUILD.bazel b/br/pkg/restore/log_client/BUILD.bazel index 76a2f7de49a97..3f000ac4b7db0 100644 --- a/br/pkg/restore/log_client/BUILD.bazel +++ b/br/pkg/restore/log_client/BUILD.bazel @@ -16,6 +16,7 @@ go_library( "//br/pkg/checksum", "//br/pkg/conn", "//br/pkg/conn/util", + "//br/pkg/encryption", "//br/pkg/errors", "//br/pkg/glue", "//br/pkg/logutil", @@ -49,6 +50,7 @@ go_library( "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_kvproto//pkg/brpb", + "@com_github_pingcap_kvproto//pkg/encryptionpb", "@com_github_pingcap_kvproto//pkg/errorpb", "@com_github_pingcap_kvproto//pkg/import_sstpb", "@com_github_pingcap_kvproto//pkg/kvrpcpb", diff --git a/br/pkg/restore/log_client/client.go b/br/pkg/restore/log_client/client.go index 9a7ee8ee9cf8d..db6a25dc6a9c7 100644 --- a/br/pkg/restore/log_client/client.go +++ b/br/pkg/restore/log_client/client.go @@ -33,11 +33,13 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/log" "github.com/pingcap/tidb/br/pkg/checkpoint" "github.com/pingcap/tidb/br/pkg/checksum" "github.com/pingcap/tidb/br/pkg/conn" "github.com/pingcap/tidb/br/pkg/conn/util" + "github.com/pingcap/tidb/br/pkg/encryption" "github.com/pingcap/tidb/br/pkg/glue" "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/metautil" @@ -282,13 +284,15 @@ func (rc *LogClient) InitCheckpointMetadataForLogRestore(ctx context.Context, ta return gcRatio, nil } -func (rc *LogClient) InstallLogFileManager(ctx context.Context, startTS, restoreTS uint64, metadataDownloadBatchSize uint) error { +func (rc *LogClient) InstallLogFileManager(ctx context.Context, startTS, restoreTS uint64, metadataDownloadBatchSize uint, + encryptionManager *encryption.Manager) error { init := LogFileManagerInit{ StartTS: startTS, RestoreTS: restoreTS, Storage: rc.storage, MetadataDownloadBatchSize: metadataDownloadBatchSize, + EncryptionManager: encryptionManager, } var err error rc.LogFileManager, err = CreateLogFileManager(ctx, init) @@ -425,7 +429,7 @@ func ApplyKVFilesWithBatchMethod( return nil } -func ApplyKVFilesWithSingelMethod( +func ApplyKVFilesWithSingleMethod( ctx context.Context, files LogIter, applyFunc func(file []*LogDataFileInfo, 
kvCount int64, size uint64), @@ -466,6 +470,8 @@ func (rc *LogClient) RestoreKVFiles( pitrBatchSize uint32, updateStats func(kvCount uint64, size uint64), onProgress func(cnt int64), + cipherInfo *backuppb.CipherInfo, + masterKeys []*encryptionpb.MasterKey, ) error { var ( err error @@ -477,7 +483,7 @@ func (rc *LogClient) RestoreKVFiles( defer func() { if err == nil { elapsed := time.Since(start) - log.Info("Restore KV files", zap.Duration("take", elapsed)) + log.Info("Restored KV files", zap.Duration("take", elapsed)) summary.CollectSuccessUnit("files", fileCount, elapsed) } }() @@ -537,7 +543,8 @@ func (rc *LogClient) RestoreKVFiles( } }() - return rc.fileImporter.ImportKVFiles(ectx, files, rule, rc.shiftStartTS, rc.startTS, rc.restoreTS, supportBatch) + return rc.fileImporter.ImportKVFiles(ectx, files, rule, rc.shiftStartTS, rc.startTS, rc.restoreTS, + supportBatch, cipherInfo, masterKeys) }) } } @@ -546,7 +553,7 @@ func (rc *LogClient) RestoreKVFiles( if supportBatch { err = ApplyKVFilesWithBatchMethod(ectx, logIter, int(pitrBatchCount), uint64(pitrBatchSize), applyFunc, &applyWg) } else { - err = ApplyKVFilesWithSingelMethod(ectx, logIter, applyFunc, &applyWg) + err = ApplyKVFilesWithSingleMethod(ectx, logIter, applyFunc, &applyWg) } return errors.Trace(err) }) @@ -608,18 +615,25 @@ func initFullBackupTables( ctx context.Context, s storage.ExternalStorage, tableFilter filter.Filter, + cipherInfo *backuppb.CipherInfo, ) (map[int64]*metautil.Table, error) { metaData, err := s.ReadFile(ctx, metautil.MetaFile) if err != nil { return nil, errors.Trace(err) } + + backupMetaBytes, err := metautil.DecryptFullBackupMetaIfNeeded(metaData, cipherInfo) + if err != nil { + return nil, errors.Annotate(err, "decrypt failed with wrong key") + } + backupMeta := &backuppb.BackupMeta{} - if err = backupMeta.Unmarshal(metaData); err != nil { + if err = backupMeta.Unmarshal(backupMetaBytes); err != nil { return nil, errors.Trace(err) } // read full backup databases to get map[table]table.Info - reader := metautil.NewMetaReader(backupMeta, s, nil) + reader := metautil.NewMetaReader(backupMeta, s, cipherInfo) databases, err := metautil.LoadBackupTables(ctx, reader, false) if err != nil { @@ -673,6 +687,7 @@ const UnsafePITRLogRestoreStartBeforeAnyUpstreamUserDDL = "UNSAFE_PITR_LOG_RESTO func (rc *LogClient) generateDBReplacesFromFullBackupStorage( ctx context.Context, cfg *InitSchemaConfig, + cipherInfo *backuppb.CipherInfo, ) (map[stream.UpstreamID]*stream.DBReplace, error) { dbReplaces := make(map[stream.UpstreamID]*stream.DBReplace) if cfg.FullBackupStorage == nil { @@ -687,7 +702,7 @@ func (rc *LogClient) generateDBReplacesFromFullBackupStorage( if err != nil { return nil, errors.Trace(err) } - fullBackupTables, err := initFullBackupTables(ctx, s, cfg.TableFilter) + fullBackupTables, err := initFullBackupTables(ctx, s, cfg.TableFilter, cipherInfo) if err != nil { return nil, errors.Trace(err) } @@ -730,6 +745,7 @@ func (rc *LogClient) generateDBReplacesFromFullBackupStorage( func (rc *LogClient) InitSchemasReplaceForDDL( ctx context.Context, cfg *InitSchemaConfig, + cipherInfo *backuppb.CipherInfo, ) (*stream.SchemasReplace, error) { var ( err error @@ -774,7 +790,7 @@ func (rc *LogClient) InitSchemasReplaceForDDL( if len(dbMaps) <= 0 { log.Info("no id maps, build the table replaces from cluster and full backup schemas") needConstructIdMap = true - dbReplaces, err = rc.generateDBReplacesFromFullBackupStorage(ctx, cfg) + dbReplaces, err = rc.generateDBReplacesFromFullBackupStorage(ctx, cfg, cipherInfo) 
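// the full backup meta may itself be encrypted, so a failure here can also mean the supplied cipher key is wrong rather than the storage being unreadable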
if err != nil { return nil, errors.Trace(err) } diff --git a/br/pkg/restore/log_client/client_test.go b/br/pkg/restore/log_client/client_test.go index 7b00e30e6eaa7..b8cbbd47163f1 100644 --- a/br/pkg/restore/log_client/client_test.go +++ b/br/pkg/restore/log_client/client_test.go @@ -866,7 +866,7 @@ func TestApplyKVFilesWithSingelMethod(t *testing.T) { } } - logclient.ApplyKVFilesWithSingelMethod( + logclient.ApplyKVFilesWithSingleMethod( context.TODO(), toLogDataFileInfoIter(iter.FromSlice(ds)), applyFunc, @@ -1293,7 +1293,7 @@ func TestApplyKVFilesWithBatchMethod5(t *testing.T) { require.Equal(t, backuppb.FileType_Delete, types[len(types)-1]) types = make([]backuppb.FileType, 0) - logclient.ApplyKVFilesWithSingelMethod( + logclient.ApplyKVFilesWithSingleMethod( context.TODO(), toLogDataFileInfoIter(iter.FromSlice(ds)), applyFunc, @@ -1377,7 +1377,7 @@ func TestInitSchemasReplaceForDDL(t *testing.T) { { client := logclient.TEST_NewLogClient(123, 1, 2, 1, domain.NewMockDomain(), fakeSession{}) cfg := &logclient.InitSchemaConfig{IsNewTask: false} - _, err := client.InitSchemasReplaceForDDL(ctx, cfg) + _, err := client.InitSchemasReplaceForDDL(ctx, cfg, nil) require.Error(t, err) require.Regexp(t, "failed to get pitr id map from mysql.tidb_pitr_id_map.* [2, 1]", err.Error()) } @@ -1385,7 +1385,7 @@ { client := logclient.TEST_NewLogClient(123, 1, 2, 1, domain.NewMockDomain(), fakeSession{}) cfg := &logclient.InitSchemaConfig{IsNewTask: true} - _, err := client.InitSchemasReplaceForDDL(ctx, cfg) + _, err := client.InitSchemasReplaceForDDL(ctx, cfg, nil) require.Error(t, err) require.Regexp(t, "failed to get pitr id map from mysql.tidb_pitr_id_map.* [1, 1]", err.Error()) } @@ -1399,7 +1399,7 @@ require.NoError(t, err) client := logclient.TEST_NewLogClient(123, 1, 2, 1, domain.NewMockDomain(), se) cfg := &logclient.InitSchemaConfig{IsNewTask: true} - _, err = client.InitSchemasReplaceForDDL(ctx, cfg) + _, err = client.InitSchemasReplaceForDDL(ctx, cfg, nil) require.Error(t, err) require.Contains(t, err.Error(), "miss upstream table information at `start-ts`(1) but the full backup path is not specified") } diff --git a/br/pkg/restore/log_client/import.go b/br/pkg/restore/log_client/import.go index 41abe7fe5c426..1b67b78f85ab9 100644 --- a/br/pkg/restore/log_client/import.go +++ b/br/pkg/restore/log_client/import.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" @@ -101,6 +102,8 @@ func (importer *LogFileImporter) ImportKVFiles( startTS uint64, restoreTS uint64, supportBatch bool, + cipherInfo *backuppb.CipherInfo, + masterKeys []*encryptionpb.MasterKey, ) error { var ( startKey []byte @@ -111,7 +114,7 @@ if !supportBatch && len(files) > 1 { return errors.Annotatef(berrors.ErrInvalidArgument, - "do not support batch apply but files count:%v > 1", len(files)) + "do not support batch apply, file count: %v > 1", len(files)) } log.Debug("import kv files", zap.Int("batch file count", len(files))) @@ -143,7 +146,8 @@ if len(subfiles) == 0 { return RPCResultOK() } - return importer.importKVFileForRegion(ctx, subfiles, rule, shiftStartTS, startTS, restoreTS, r, supportBatch) + return 
importer.importKVFileForRegion(ctx, subfiles, rule, shiftStartTS, startTS, restoreTS, r, supportBatch, + cipherInfo, masterKeys) }) return errors.Trace(err) } @@ -184,9 +188,11 @@ func (importer *LogFileImporter) importKVFileForRegion( restoreTS uint64, info *split.RegionInfo, supportBatch bool, + cipherInfo *backuppb.CipherInfo, + masterKeys []*encryptionpb.MasterKey, ) RPCResult { // Try to download file. - result := importer.downloadAndApplyKVFile(ctx, files, rule, info, shiftStartTS, startTS, restoreTS, supportBatch) + result := importer.downloadAndApplyKVFile(ctx, files, rule, info, shiftStartTS, startTS, restoreTS, supportBatch, cipherInfo, masterKeys) if !result.OK() { errDownload := result.Err for _, e := range multierr.Errors(errDownload) { @@ -216,7 +222,9 @@ func (importer *LogFileImporter) downloadAndApplyKVFile( startTS uint64, restoreTS uint64, supportBatch bool, -) RPCResult { + cipherInfo *backuppb.CipherInfo, + masterKeys []*encryptionpb.MasterKey) RPCResult { + leader := regionInfo.Leader if leader == nil { return RPCResultFromError(errors.Annotatef(berrors.ErrPDLeaderNotFound, @@ -251,11 +259,12 @@ func (importer *LogFileImporter) downloadAndApplyKVFile( } return startTS }(), - RestoreTs: restoreTS, - StartKey: regionInfo.Region.GetStartKey(), - EndKey: regionInfo.Region.GetEndKey(), - Sha256: file.GetSha256(), - CompressionType: file.CompressionType, + RestoreTs: restoreTS, + StartKey: regionInfo.Region.GetStartKey(), + EndKey: regionInfo.Region.GetEndKey(), + Sha256: file.GetSha256(), + CompressionType: file.CompressionType, + FileEncryptionInfo: file.FileEncryptionInfo, } metas = append(metas, meta) @@ -276,6 +285,8 @@ func (importer *LogFileImporter) downloadAndApplyKVFile( RewriteRules: rewriteRules, Context: reqCtx, StorageCacheId: importer.cacheKey, + CipherInfo: cipherInfo, + MasterKeys: masterKeys, } } else { req = &import_sstpb.ApplyRequest{ @@ -284,16 +295,18 @@ func (importer *LogFileImporter) downloadAndApplyKVFile( RewriteRule: *rewriteRules[0], Context: reqCtx, StorageCacheId: importer.cacheKey, + CipherInfo: cipherInfo, + MasterKeys: masterKeys, } } - log.Debug("apply kv file", logutil.Leader(leader)) + log.Debug("applying kv file", logutil.Leader(leader)) resp, err := importer.importClient.ApplyKVFile(ctx, leader.GetStoreId(), req) if err != nil { return RPCResultFromError(errors.Trace(err)) } if resp.GetError() != nil { - logutil.CL(ctx).Warn("import meet error", zap.Stringer("error", resp.GetError())) + logutil.CL(ctx).Warn("import encountered an error", zap.Stringer("error", resp.GetError())) return RPCResultFromPBError(resp.GetError()) } return RPCResultOK() } diff --git a/br/pkg/restore/log_client/log_file_manager.go b/br/pkg/restore/log_client/log_file_manager.go index d29b4f1f7ab8e..612d10ed64e65 100644 --- a/br/pkg/restore/log_client/log_file_manager.go +++ b/br/pkg/restore/log_client/log_file_manager.go @@ -12,7 +12,9 @@ import ( "github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/encryption" berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/stream" @@ -54,6 +56,7 @@ type streamMetadataHelper interface { length uint64, compressionType backuppb.CompressionType, storage storage.ExternalStorage, + encryptionInfo *encryptionpb.FileEncryptionInfo, ) ([]byte, error) ParseToMetadata(rawMetaData []byte) (*backuppb.Metadata, error) } @@ -85,6 +88,7 @@ type 
LogFileManagerInit struct { Storage storage.ExternalStorage MetadataDownloadBatchSize uint + EncryptionManager *encryption.Manager } type DDLMetaGroup struct { @@ -99,7 +103,7 @@ func CreateLogFileManager(ctx context.Context, init LogFileManagerInit) (*LogFil startTS: init.StartTS, restoreTS: init.RestoreTS, storage: init.Storage, - helper: stream.NewMetadataHelper(), + helper: stream.NewMetadataHelper(init.EncryptionManager), metadataDownloadBatchSize: init.MetadataDownloadBatchSize, } @@ -329,7 +333,8 @@ func (rc *LogFileManager) ReadAllEntries( kvEntries := make([]*KvEntryWithTS, 0) nextKvEntries := make([]*KvEntryWithTS, 0) - buff, err := rc.helper.ReadFile(ctx, file.Path, file.RangeOffset, file.RangeLength, file.CompressionType, rc.storage) + buff, err := rc.helper.ReadFile(ctx, file.Path, file.RangeOffset, file.RangeLength, file.CompressionType, + rc.storage, file.FileEncryptionInfo) if err != nil { return nil, nil, errors.Trace(err) } diff --git a/br/pkg/restore/log_client/log_file_manager_test.go b/br/pkg/restore/log_client/log_file_manager_test.go index 82fcf628d0139..79a31cbc8a2da 100644 --- a/br/pkg/restore/log_client/log_file_manager_test.go +++ b/br/pkg/restore/log_client/log_file_manager_test.go @@ -306,7 +306,7 @@ func testReadFromMetadataWithVersion(t *testing.T, m metaMaker) { }() meta := new(stream.StreamMetadataSet) - meta.Helper = stream.NewMetadataHelper() + meta.Helper = stream.NewMetadataHelper(nil) meta.MetadataDownloadBatchSize = 128 _, err := meta.LoadUntilAndCalculateShiftTS(ctx, loc, c.untilTS) require.NoError(t, err) diff --git a/br/pkg/restore/snap_client/client.go b/br/pkg/restore/snap_client/client.go index be54507c23ed0..2bbc488009a7b 100644 --- a/br/pkg/restore/snap_client/client.go +++ b/br/pkg/restore/snap_client/client.go @@ -633,7 +633,7 @@ func (rc *SnapClient) CreateDatabases(ctx context.Context, dbs []*metautil.Datab return nil } - log.Info("create databases in db pool", zap.Int("pool size", len(rc.dbPool))) + log.Info("create databases in db pool", zap.Int("pool size", len(rc.dbPool)), zap.Int("number of db", len(dbs))) eg, ectx := errgroup.WithContext(ctx) workers := tidbutil.NewWorkerPool(uint(len(rc.dbPool)), "DB DDL workers") for _, db_ := range dbs { diff --git a/br/pkg/stream/BUILD.bazel b/br/pkg/stream/BUILD.bazel index aff6b7ee2a8e2..f2a38a769c45d 100644 --- a/br/pkg/stream/BUILD.bazel +++ b/br/pkg/stream/BUILD.bazel @@ -15,6 +15,7 @@ go_library( importpath = "github.com/pingcap/tidb/br/pkg/stream", visibility = ["//visibility:public"], deps = [ + "//br/pkg/encryption", "//br/pkg/errors", "//br/pkg/glue", "//br/pkg/httputil", @@ -36,6 +37,7 @@ go_library( "@com_github_klauspost_compress//zstd", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_kvproto//pkg/brpb", + "@com_github_pingcap_kvproto//pkg/encryptionpb", "@com_github_pingcap_kvproto//pkg/metapb", "@com_github_pingcap_log//:log", "@com_github_tikv_client_go_v2//oracle", diff --git a/br/pkg/stream/stream_metas_test.go b/br/pkg/stream/stream_metas_test.go index 2545f2fd90097..6efd7ff6adae0 100644 --- a/br/pkg/stream/stream_metas_test.go +++ b/br/pkg/stream/stream_metas_test.go @@ -148,7 +148,7 @@ func TestTruncateLog(t *testing.T) { require.NoError(t, fakeStreamBackup(l)) s := StreamMetadataSet{ - Helper: NewMetadataHelper(), + Helper: NewMetadataHelper(nil), MetadataDownloadBatchSize: 128, } require.NoError(t, s.LoadFrom(ctx, l)) @@ -221,7 +221,7 @@ func TestTruncateLogV2(t *testing.T) { require.NoError(t, fakeStreamBackupV2(l)) s := StreamMetadataSet{ - Helper: 
NewMetadataHelper(), + Helper: NewMetadataHelper(nil), MetadataDownloadBatchSize: 128, } require.NoError(t, s.LoadFrom(ctx, l)) @@ -1190,7 +1190,7 @@ func TestTruncate1(t *testing.T) { for _, until := range ts.until { t.Logf("case %d, param %d, until %d", i, j, until) metas := StreamMetadataSet{ - Helper: NewMetadataHelper(), + Helper: NewMetadataHelper(nil), MetadataDownloadBatchSize: 128, } err := generateFiles(ctx, s, cs.metas, tmpDir) @@ -1706,7 +1706,7 @@ func TestTruncate2(t *testing.T) { for _, until := range ts.until { t.Logf("case %d, param %d, until %d", i, j, until) metas := StreamMetadataSet{ - Helper: NewMetadataHelper(), + Helper: NewMetadataHelper(nil), MetadataDownloadBatchSize: 128, } err := generateFiles(ctx, s, cs.metas, tmpDir) @@ -2090,7 +2090,7 @@ func TestTruncate3(t *testing.T) { for _, until := range ts.until { t.Logf("case %d, param %d, until %d", i, j, until) metas := StreamMetadataSet{ - Helper: NewMetadataHelper(), + Helper: NewMetadataHelper(nil), MetadataDownloadBatchSize: 128, } err := generateFiles(ctx, s, cs.metas, tmpDir) @@ -2303,7 +2303,7 @@ func TestCalculateShiftTS(t *testing.T) { for _, until := range ts.until { t.Logf("case %d, param %d, until %d", i, j, until) metas := StreamMetadataSet{ - Helper: NewMetadataHelper(), + Helper: NewMetadataHelper(nil), MetadataDownloadBatchSize: 128, } err := generateFiles(ctx, s, cs.metas, tmpDir) diff --git a/br/pkg/stream/stream_mgr.go b/br/pkg/stream/stream_mgr.go index d53a66b7f1416..4b977b3243450 100644 --- a/br/pkg/stream/stream_mgr.go +++ b/br/pkg/stream/stream_mgr.go @@ -21,7 +21,9 @@ import ( "github.com/klauspost/compress/zstd" "github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/encryption" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/meta" @@ -157,15 +159,17 @@ type ContentRef struct { // MetadataHelper make restore/truncate compatible with metadataV1 and metadataV2. type MetadataHelper struct { - cache map[string]*ContentRef - decoder *zstd.Decoder + cache map[string]*ContentRef + decoder *zstd.Decoder + encryptionManager *encryption.Manager } -func NewMetadataHelper() *MetadataHelper { +func NewMetadataHelper(encryptionManager *encryption.Manager) *MetadataHelper { decoder, _ := zstd.NewReader(nil) return &MetadataHelper{ - cache: make(map[string]*ContentRef), - decoder: decoder, + cache: make(map[string]*ContentRef), + decoder: decoder, + encryptionManager: encryptionManager, } } @@ -191,6 +195,23 @@ func (m *MetadataHelper) decodeCompressedData(data []byte, compressionType backu "failed to decode compressed data: compression type is unimplemented. 
type id is %d", compressionType) } +func (m *MetadataHelper) decryptIfNeeded(ctx context.Context, data []byte, encryptionInfo *encryptionpb.FileEncryptionInfo) ([]byte, error) { + // no need to decrypt + if encryptionInfo == nil { + return data, nil + } + + if m.encryptionManager == nil { + return data, errors.New("need to decrypt data but encryption manager not set") + } + + decryptedContent, err := m.encryptionManager.Decrypt(ctx, data, encryptionInfo) + if err != nil { + return nil, errors.Trace(err) + } + return decryptedContent, nil +} + func (m *MetadataHelper) ReadFile( ctx context.Context, path string, @@ -198,6 +219,7 @@ func (m *MetadataHelper) ReadFile( length uint64, compressionType backuppb.CompressionType, storage storage.ExternalStorage, + encryptionInfo *encryptionpb.FileEncryptionInfo, ) ([]byte, error) { var err error cref, exist := m.cache[path] @@ -212,7 +234,12 @@ func (m *MetadataHelper) ReadFile( if err != nil { return nil, errors.Trace(err) } - return m.decodeCompressedData(data, compressionType) + // decrypt if needed + decryptedData, err := m.decryptIfNeeded(ctx, data, encryptionInfo) + if err != nil { + return nil, errors.Trace(err) + } + return m.decodeCompressedData(decryptedData, compressionType) } cref.ref -= 1 @@ -223,8 +250,12 @@ func (m *MetadataHelper) ReadFile( return nil, errors.Trace(err) } } - - buf, err := m.decodeCompressedData(cref.data[offset:offset+length], compressionType) + // decrypt if needed + decryptedData, err := m.decryptIfNeeded(ctx, cref.data[offset:offset+length], encryptionInfo) + if err != nil { + return nil, errors.Trace(err) + } + buf, err := m.decodeCompressedData(decryptedData, compressionType) if cref.ref <= 0 { // need reset reference information. diff --git a/br/pkg/stream/stream_misc_test.go b/br/pkg/stream/stream_misc_test.go index 2de4784a5f137..4946682b288c4 100644 --- a/br/pkg/stream/stream_misc_test.go +++ b/br/pkg/stream/stream_misc_test.go @@ -48,7 +48,7 @@ func TestMetadataHelperReadFile(t *testing.T) { tmpdir := t.TempDir() s, err := storage.NewLocalStorage(tmpdir) require.Nil(t, err) - helper := stream.NewMetadataHelper() + helper := stream.NewMetadataHelper(nil) filename1 := "full_data" filename2 := "misc_data" data1 := []byte("Test MetadataHelper. 
diff --git a/br/pkg/task/BUILD.bazel b/br/pkg/task/BUILD.bazel index e62938e7fab66..353d070cbb86e 100644 --- a/br/pkg/task/BUILD.bazel +++ b/br/pkg/task/BUILD.bazel @@ -8,6 +8,7 @@ go_library( "backup_raw.go", "backup_txn.go", "common.go", + "encryption.go", "restore.go", "restore_data.go", "restore_ebs_meta.go", @@ -27,6 +28,7 @@ go_library( "//br/pkg/config", "//br/pkg/conn", "//br/pkg/conn/util", + "//br/pkg/encryption", "//br/pkg/errors", "//br/pkg/glue", "//br/pkg/httputil", @@ -106,6 +108,7 @@ go_test( "backup_test.go", "common_test.go", "config_test.go", + "encryption_test.go", "restore_test.go", "stream_test.go", ], diff --git a/br/pkg/task/common.go b/br/pkg/task/common.go index 94ef67364c376..1de6102abe92e 100644 --- a/br/pkg/task/common.go +++ b/br/pkg/task/common.go @@ -91,9 +91,13 @@ const ( defaultGRPCKeepaliveTimeout = 3 * time.Second defaultCloudAPIConcurrency = 8 - flagCipherType = "crypter.method" - flagCipherKey = "crypter.key" - flagCipherKeyFile = "crypter.key-file" + flagFullBackupCipherType = "crypter.method" + flagFullBackupCipherKey = "crypter.key" + flagFullBackupCipherKeyFile = "crypter.key-file" + + flagLogBackupCipherType = "log.crypter.method" + flagLogBackupCipherKey = "log.crypter.key" + flagLogBackupCipherKeyFile = "log.crypter.key-file" flagMetadataDownloadBatchSize = "metadata-download-batch-size" defaultMetadataDownloadBatchSize = 128 @@ -104,6 +108,10 @@ const ( crypterAES256KeyLen = 32 flagFullBackupType = "type" + + masterKeysDelimiter = "," + flagMasterKeyConfig = "master-key" + flagMasterKeyCipherType = "master-key-crypter-method" ) const ( @@ -260,8 +268,17 @@ type Config struct { // GrpcKeepaliveTimeout is the max time a grpc conn can keep idel before killed. GRPCKeepaliveTimeout time.Duration `json:"grpc-keepalive-timeout" toml:"grpc-keepalive-timeout"` + // Plaintext data key, mainly used for full/snapshot backup and restore. CipherInfo backuppb.CipherInfo `json:"-" toml:"-"` + // Can also be used in log backup and restore, but not recommended for production environments since the data key is stored + // in PD in plaintext. + LogBackupCipherInfo backuppb.CipherInfo `json:"-" toml:"-"` + + // Master key based encryption for log restore. + // More than one can be specified for log restore if the master key was rotated during log backup. + MasterKeyConfig backuppb.MasterKeyConfig `json:"master-key-config" toml:"master-key-config"` + // whether there's explicit filter ExplicitFilter bool `json:"-" toml:"-"`
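The two new Config fields above exist so restore can build a single decryption entry point: restoreStream later in this diff passes them straight into encryption.NewManager and hands the manager to stream.NewMetadataHelper. A hedged wiring sketch; the function name is illustrative and not part of the PR:

```go
package task

import (
	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/br/pkg/encryption"
	"github.com/pingcap/tidb/br/pkg/stream"
)

// newLogRestoreMetaHelper is an illustrative helper, not part of this diff.
func newLogRestoreMetaHelper(cfg *Config) (*stream.MetadataHelper, error) {
	// at most one of LogBackupCipherInfo / MasterKeyConfig is expected to be
	// effective; NewManager resolves which decryption backend applies
	mgr, err := encryption.NewManager(&cfg.LogBackupCipherInfo, &cfg.MasterKeyConfig)
	if err != nil {
		return nil, errors.Annotate(err, "failed to create encryption manager for log restore")
	}
	// truncate paths pass nil instead and fail fast if they meet encrypted data
	return stream.NewMetadataHelper(mgr), nil
}
```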
@@ -310,17 +327,36 @@ func DefineCommonFlags(flags *pflag.FlagSet) { flags.BoolP(flagSkipCheckPath, "", false, "Skip path verification") _ = flags.MarkHidden(flagSkipCheckPath) - flags.String(flagCipherType, "plaintext", "Encrypt/decrypt method, "+ "be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+ "\"plaintext\" represents no encrypt/decrypt") - flags.String(flagCipherKey, "", "aes-crypter key, used to encrypt/decrypt the data "+ "by the hexadecimal string, eg: \"0123456789abcdef0123456789abcdef\"") - flags.String(flagCipherKeyFile, "", "FilePath, its content is used as the cipher-key") + flags.String(flagFullBackupCipherType, "plaintext", "Encrypt/decrypt method, "+ "be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+ "\"plaintext\" represents no encrypt/decrypt") + flags.String(flagFullBackupCipherKey, "", "aes-crypter key, used to encrypt/decrypt the data "+ "by the hexadecimal string, eg: \"0123456789abcdef0123456789abcdef\"") + flags.String(flagFullBackupCipherKeyFile, "", "FilePath, its content is used as the cipher-key") flags.Uint(flagMetadataDownloadBatchSize, defaultMetadataDownloadBatchSize, "the batch size of downloading metadata, such as log restore metadata for truncate or restore") + // log backup plaintext key flags + flags.String(flagLogBackupCipherType, "plaintext", "Encrypt/decrypt method, "+ "be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+ "\"plaintext\" represents no encrypt/decrypt") + flags.String(flagLogBackupCipherKey, "", "aes-crypter key, used to encrypt/decrypt the data "+ "by the hexadecimal string, eg: \"0123456789abcdef0123456789abcdef\"") + flags.String(flagLogBackupCipherKeyFile, "", "FilePath, its content is used as the cipher-key") + + // master key config + flags.String(flagMasterKeyCipherType, "plaintext", "Encrypt/decrypt method, "+ "be one of plaintext|aes128-ctr|aes192-ctr|aes256-ctr case-insensitively, "+ "\"plaintext\" represents no encrypt/decrypt") + flags.String(flagMasterKeyConfig, "", "Master key configs for point-in-time restore. "+ "A comma-separated string can specify multiple master key backends if the log backup had master key rotation. "+ "example: \"local:///path/to/master/key/file,"+ "aws-kms:///{key-id}?AWS_ACCESS_KEY_ID={access-key}&AWS_SECRET_ACCESS_KEY={secret-key}&REGION={region},"+ "azure-kms:///{key-name}/{key-version}?AZURE_TENANT_ID={tenant-id}&AZURE_CLIENT_ID={client-id}&AZURE_CLIENT_SECRET={client-secret}&AZURE_VAULT_NAME={vault-name},"+ "gcp-kms:///projects/{project-id}/locations/{location}/keyRings/{keyring}/cryptoKeys/{key-name}?AUTH=specified&CREDENTIALS={credentials}\"") _ = flags.MarkHidden(flagMetadataDownloadBatchSize) storage.DefineFlags(flags) @@ -334,10 +370,15 @@ func HiddenFlagsForStream(flags *pflag.FlagSet) { _ = flags.MarkHidden(flagRateLimit) _ = flags.MarkHidden(flagRateLimitUnit) _ = flags.MarkHidden(flagRemoveTiFlash) - _ = flags.MarkHidden(flagCipherType) - _ = flags.MarkHidden(flagCipherKey) - _ = flags.MarkHidden(flagCipherKeyFile) + _ = flags.MarkHidden(flagFullBackupCipherType) + _ = flags.MarkHidden(flagFullBackupCipherKey) + _ = flags.MarkHidden(flagFullBackupCipherKeyFile) + _ = flags.MarkHidden(flagLogBackupCipherType) + _ = flags.MarkHidden(flagLogBackupCipherKey) + _ = flags.MarkHidden(flagLogBackupCipherKeyFile) _ = flags.MarkHidden(flagSwitchModeInterval) + _ = flags.MarkHidden(flagMasterKeyConfig) + _ = flags.MarkHidden(flagMasterKeyCipherType) storage.HiddenFlagsForStream(flags) } @@ -456,7 +497,7 @@ func checkCipherKeyMatch(cipher *backuppb.CipherInfo) bool { }
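Both parseCipherInfo below and the new parseLogBackupCipherInfo gate on checkCipherKeyMatch, whose body is elided in the context above; given the crypterAES*KeyLen constants it presumably checks the decoded key length against the method. An illustrative mapping only (the hex strings users pass on the command line are twice these lengths):

```go
// Hedged illustration only -- the real check is checkCipherKeyMatch in
// common.go, elided in the hunk above.
package task

import "github.com/pingcap/kvproto/pkg/encryptionpb"

func expectedKeyLen(method encryptionpb.EncryptionMethod) int {
	switch method {
	case encryptionpb.EncryptionMethod_AES128_CTR:
		return 16 // crypterAES128KeyLen: 32 hex characters
	case encryptionpb.EncryptionMethod_AES192_CTR:
		return 24 // crypterAES192KeyLen: 48 hex characters
	case encryptionpb.EncryptionMethod_AES256_CTR:
		return 32 // crypterAES256KeyLen: 64 hex characters
	default:
		return 0 // plaintext/unknown: no key expected
	}
}
```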
func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error { - crypterStr, err := flags.GetString(flagCipherType) + crypterStr, err := flags.GetString(flagFullBackupCipherType) if err != nil { return errors.Trace(err) } @@ -470,12 +511,12 @@ func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error { return nil } - key, err := flags.GetString(flagCipherKey) + key, err := flags.GetString(flagFullBackupCipherKey) if err != nil { return errors.Trace(err) } - keyFilePath, err := flags.GetString(flagCipherKeyFile) + keyFilePath, err := flags.GetString(flagFullBackupCipherKeyFile) if err != nil { return errors.Trace(err) } @@ -492,6 +533,43 @@ func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error { return nil } +func (cfg *Config) parseLogBackupCipherInfo(flags *pflag.FlagSet) (bool, error) { + crypterStr, err := flags.GetString(flagLogBackupCipherType) + if err != nil { + return false, errors.Trace(err) + } + + cfg.LogBackupCipherInfo.CipherType, err = parseCipherType(crypterStr) + if err != nil { + return false, errors.Trace(err) + } + + if !isEffectiveEncryptionMethod(cfg.LogBackupCipherInfo.CipherType) { + return false, nil + } + + key, err := flags.GetString(flagLogBackupCipherKey) + if err != nil { + return false, errors.Trace(err) + } + + keyFilePath, err := flags.GetString(flagLogBackupCipherKeyFile) + if err != nil { + return false, errors.Trace(err) + } + + cfg.LogBackupCipherInfo.CipherKey, err = GetCipherKeyContent(key, keyFilePath) + if err != nil { + return false, errors.Trace(err) + } + + if !checkCipherKeyMatch(&cfg.LogBackupCipherInfo) { + return false, errors.Annotate(berrors.ErrInvalidArgument, "log backup encryption method and key length do not match") + } + + return true, nil +} + func (cfg *Config) normalizePDURLs() error { for i := range cfg.PD { var err error @@ -618,7 +696,17 @@ func (cfg *Config) ParseFromFlags(flags *pflag.FlagSet) error { log.L().Info("--skip-check-path is deprecated, need explicitly set it anymore") } - if err = cfg.parseCipherInfo(flags); err != nil { + err = cfg.parseCipherInfo(flags) + if err != nil { + return errors.Trace(err) + } + + hasLogBackupPlaintextKey, err := cfg.parseLogBackupCipherInfo(flags) + if err != nil { + return errors.Trace(err) + } + + if err = cfg.parseAndValidateMasterKeyInfo(hasLogBackupPlaintextKey, flags); err != nil { return errors.Trace(err) } @@ -629,6 +717,56 @@ func (cfg *Config) ParseFromFlags(flags *pflag.FlagSet) error { return cfg.normalizePDURLs() }
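A hedged usage sketch of the parser that follows, mirroring how the tests later in this diff drive it; the key paths and the wrapper function are illustrative:

```go
// Written as if inside package task, since parseAndValidateMasterKeyInfo is
// unexported; flag values are illustrative only.
package task

import (
	"fmt"

	"github.com/spf13/pflag"
)

func exampleMasterKeyParsing() error {
	flags := pflag.NewFlagSet("example", pflag.ContinueOnError)
	// two backends: useful when the master key was rotated mid log backup
	flags.String(flagMasterKeyConfig, "local:///keys/old.key,local:///keys/new.key", "")
	flags.String(flagMasterKeyCipherType, "aes256-ctr", "")

	cfg := &Config{}
	// false: no plaintext log-backup key was given, so master keys are allowed
	if err := cfg.parseAndValidateMasterKeyInfo(false, flags); err != nil {
		return err
	}
	fmt.Println(len(cfg.MasterKeyConfig.MasterKeys)) // 2, one per rotated key
	return nil
}
```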
+func (cfg *Config) parseAndValidateMasterKeyInfo(hasPlaintextKey bool, flags *pflag.FlagSet) error { + masterKeyString, err := flags.GetString(flagMasterKeyConfig) + if err != nil { + return errors.Errorf("master key flag '%s' is not defined: %v", flagMasterKeyConfig, err) + } + + if masterKeyString == "" { + return nil + } + + if hasPlaintextKey { + return errors.Errorf("invalid argument: both plaintext data key encryption and master key based encryption are set at the same time") + } + + encryptionMethodString, err := flags.GetString(flagMasterKeyCipherType) + if err != nil { + return errors.Errorf("encryption method flag '%s' is not defined: %v", flagMasterKeyCipherType, err) + } + + encryptionMethod, err := parseCipherType(encryptionMethodString) + if err != nil { + return errors.Errorf("failed to parse encryption method: %v", err) + } + + if !isEffectiveEncryptionMethod(encryptionMethod) { + return errors.Errorf("invalid encryption method: %s", encryptionMethodString) + } + + masterKeyStrings := strings.Split(masterKeyString, masterKeysDelimiter) + cfg.MasterKeyConfig = backuppb.MasterKeyConfig{ + EncryptionType: encryptionMethod, + MasterKeys: make([]*encryptionpb.MasterKey, 0, len(masterKeyStrings)), + } + + for _, keyString := range masterKeyStrings { + masterKey, err := validateAndParseMasterKeyString(strings.TrimSpace(keyString)) + if err != nil { + return errors.Wrapf(err, "invalid master key configuration: %s", keyString) + } + cfg.MasterKeyConfig.MasterKeys = append(cfg.MasterKeyConfig.MasterKeys, &masterKey) + } + + return nil +} + +func isEffectiveEncryptionMethod(method encryptionpb.EncryptionMethod) bool { + return method != encryptionpb.EncryptionMethod_UNKNOWN && + method != encryptionpb.EncryptionMethod_PLAINTEXT +} + // NewMgr creates a new mgr at the given PD address. func NewMgr(ctx context.Context, g glue.Glue, pds []string, @@ -726,7 +864,7 @@ func ReadBackupMeta( if cfg.CipherInfo.CipherType != encryptionpb.EncryptionMethod_PLAINTEXT { iv = metaData[:metautil.CrypterIvLen] } - decryptBackupMeta, err := metautil.Decrypt(metaData[len(iv):], &cfg.CipherInfo, iv) + decryptBackupMeta, err := utils.Decrypt(metaData[len(iv):], &cfg.CipherInfo, iv) if err != nil { return nil, nil, nil, errors.Annotate(err, "decrypt failed with wrong key") } diff --git a/br/pkg/task/common_test.go b/br/pkg/task/common_test.go index c942b96bc531e..c4433da574109 100644 --- a/br/pkg/task/common_test.go +++ b/br/pkg/task/common_test.go @@ -185,6 +185,7 @@ func expectedDefaultConfig() Config { GRPCKeepaliveTime: 10000000000, GRPCKeepaliveTimeout: 3000000000, CipherInfo: backup.CipherInfo{CipherType: 1}, + LogBackupCipherInfo: backup.CipherInfo{CipherType: 1}, MetadataDownloadBatchSize: 0x80, } } @@ -241,3 +242,132 @@ func TestDefaultRestore(t *testing.T) { defaultConfig := expectedDefaultRestoreConfig() require.Equal(t, defaultConfig, def) } + +func TestParseAndValidateMasterKeyInfo(t *testing.T) { + tests := []struct { + name string + input string + expectedKeys []*encryptionpb.MasterKey + expectError bool + }{ + { + name: "Empty input", + input: "", + expectedKeys: nil, + expectError: false, + }, + { + name: "Single local config", + input: "local:///path/to/key", + expectedKeys: []*encryptionpb.MasterKey{ + { + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{Path: "/path/to/key"}, + }, + }, + }, + expectError: false, + }, + { + name: "Single AWS config", + input: "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY&REGION=us-west-2", + expectedKeys: []*encryptionpb.MasterKey{ + { + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "aws", + KeyId: "key-id", + Region: "us-west-2", + }, + }, + }, + }, + expectError: false, + }, + { + name: "Single Azure config", + input: "azure-kms:///key-name/key-version?AZURE_TENANT_ID=tenant-id&AZURE_CLIENT_ID=client-id&AZURE_CLIENT_SECRET=client-secret&AZURE_VAULT_NAME=vault-name", + expectedKeys: []*encryptionpb.MasterKey{ + { + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "azure", + KeyId: "key-name/key-version", + AzureKms: &encryptionpb.AzureKms{ + TenantId: "tenant-id", + ClientId: "client-id", + ClientSecret: "client-secret", + KeyVaultUrl: "vault-name", + }, + }, + }, + }, + }, + expectError: false, + }, + { + name: "Single GCP config", + input: "gcp-kms:///projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name?CREDENTIALS=credentials", + expectedKeys: []*encryptionpb.MasterKey{ + { + Backend: 
&encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "gcp", + KeyId: "projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name", + GcpKms: &encryptionpb.GcpKms{ + Credential: "credentials", + }, + }, + }, + }, + }, + expectError: false, + }, + { + name: "Multiple configs", + input: "local:///path/to/key," + + "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY&REGION=us-west-2", + expectedKeys: []*encryptionpb.MasterKey{ + { + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{Path: "/path/to/key"}, + }, + }, + { + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "aws", + KeyId: "key-id", + Region: "us-west-2", + }, + }, + }, + }, + expectError: false, + }, + { + name: "Invalid config", + input: "invalid:///config", + expectedKeys: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &Config{} + flags := pflag.NewFlagSet("test", pflag.ContinueOnError) + flags.String(flagMasterKeyConfig, tt.input, "") + flags.String(flagMasterKeyCipherType, "aes256-ctr", "") + + err := cfg.parseAndValidateMasterKeyInfo(false, flags) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedKeys, cfg.MasterKeyConfig.MasterKeys) + } + }) + } +} diff --git a/br/pkg/task/encryption.go b/br/pkg/task/encryption.go new file mode 100644 index 0000000000000..69f90c327faaf --- /dev/null +++ b/br/pkg/task/encryption.go @@ -0,0 +1,160 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. + +package task + +import ( + "fmt" + "net/url" + "regexp" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/encryptionpb" +) + +const ( + SchemeLocal = "local" + SchemeAWS = "aws-kms" + SchemeAzure = "azure-kms" + SchemeGCP = "gcp-kms" + + AWSVendor = "aws" + AWSAccessKeyID = "AWS_ACCESS_KEY_ID" + AWSSecretAccessKey = "AWS_SECRET_ACCESS_KEY" + AWSRegion = "REGION" + AWSEndpoint = "ENDPOINT" + + AzureVendor = "azure" + AzureTenantID = "AZURE_TENANT_ID" + AzureClientID = "AZURE_CLIENT_ID" + AzureClientSecret = "AZURE_CLIENT_SECRET" + AzureVaultName = "AZURE_VAULT_NAME" + + GCPVendor = "gcp" + GCPCredentials = "CREDENTIALS" +) + +var ( + localRegex = regexp.MustCompile(`^/.*$`) + awsRegex = regexp.MustCompile(`^/([^/]+)$`) + azureRegex = regexp.MustCompile(`^/(.+)$`) + gcpRegex = regexp.MustCompile(`^/projects/([^/]+)/locations/([^/]+)/keyRings/([^/]+)/cryptoKeys/([^/]+)/?$`) +) + +func validateAndParseMasterKeyString(keyString string) (encryptionpb.MasterKey, error) { + u, err := url.Parse(keyString) + if err != nil { + return encryptionpb.MasterKey{}, errors.Trace(err) + } + + switch u.Scheme { + case SchemeLocal: + return parseLocalDiskConfig(u) + case SchemeAWS: + return parseAwsKmsConfig(u) + case SchemeAzure: + return parseAzureKmsConfig(u) + case SchemeGCP: + return parseGcpKmsConfig(u) + default: + return encryptionpb.MasterKey{}, errors.Errorf("unsupported master key type: %s", u.Scheme) + } +} + +func parseLocalDiskConfig(u *url.URL) (encryptionpb.MasterKey, error) { + if u.Host != "" || !localRegex.MatchString(u.Path) { + return encryptionpb.MasterKey{}, errors.New("local master key path must be absolute") + } + return encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_File{ + File: &encryptionpb.MasterKeyFile{ + Path: u.Path, + }, + }, + }, nil +} + +func parseAwsKmsConfig(u *url.URL) (encryptionpb.MasterKey, error) { + matches := 
awsRegex.FindStringSubmatch(u.Path) + if matches == nil { + return encryptionpb.MasterKey{}, errors.New("invalid AWS KMS key ID format") + } + keyID := matches[1] + + q := u.Query() + accessKey := q.Get(AWSAccessKeyID) + secretKey := q.Get(AWSSecretAccessKey) + region := q.Get(AWSRegion) + + if accessKey == "" || secretKey == "" || region == "" { + return encryptionpb.MasterKey{}, errors.New("missing required AWS KMS parameters") + } + + return encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: AWSVendor, + KeyId: keyID, + Region: region, + Endpoint: q.Get(AWSEndpoint), // Optional + }, + }, + }, nil +} + +func parseAzureKmsConfig(u *url.URL) (encryptionpb.MasterKey, error) { + matches := azureRegex.FindStringSubmatch(u.Path) + if matches == nil { + return encryptionpb.MasterKey{}, errors.New("invalid Azure KMS path format") + } + + keyID := matches[1] // the regex captures the whole remaining path ({key-name}/{key-version}) as the key ID + q := u.Query() + + azureKms := &encryptionpb.AzureKms{ + TenantId: q.Get(AzureTenantID), + ClientId: q.Get(AzureClientID), + ClientSecret: q.Get(AzureClientSecret), + KeyVaultUrl: q.Get(AzureVaultName), + } + + if azureKms.TenantId == "" || azureKms.ClientId == "" || azureKms.ClientSecret == "" || azureKms.KeyVaultUrl == "" { + return encryptionpb.MasterKey{}, errors.New("missing required Azure KMS parameters") + } + + return encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: AzureVendor, + KeyId: keyID, + AzureKms: azureKms, + }, + }, + }, nil +} + +func parseGcpKmsConfig(u *url.URL) (encryptionpb.MasterKey, error) { + matches := gcpRegex.FindStringSubmatch(u.Path) + if matches == nil { + return encryptionpb.MasterKey{}, errors.New("invalid GCP KMS path format") + } + + projectID, location, keyRing, keyName := matches[1], matches[2], matches[3], matches[4] + q := u.Query() + credentials := q.Get(GCPCredentials) + + if credentials == "" { + return encryptionpb.MasterKey{}, errors.Errorf("missing required GCP KMS parameter: %s", GCPCredentials) + } + + return encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: GCPVendor, + KeyId: fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", projectID, location, keyRing, keyName), + GcpKms: &encryptionpb.GcpKms{ + Credential: credentials, + }, + }, + }, + }, nil +}
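One practical caveat with this scheme: credentials ride in URL query parameters, so any secret containing characters like '&', '=' or '/' must be query-escaped or the parse above will mangle it. An illustrative builder, not part of the PR, that composes a safe aws-kms master-key string:

```go
package main

import (
	"fmt"
	"net/url"
)

// buildAwsMasterKeyURI is illustrative only: it assembles an aws-kms
// master-key string with properly escaped query parameters, using the
// parameter names the parser above expects.
func buildAwsMasterKeyURI(keyID, accessKey, secretKey, region string) string {
	q := url.Values{}
	q.Set("AWS_ACCESS_KEY_ID", accessKey)
	q.Set("AWS_SECRET_ACCESS_KEY", secretKey)
	q.Set("REGION", region)
	return fmt.Sprintf("aws-kms:///%s?%s", url.PathEscape(keyID), q.Encode())
}

func main() {
	// a secret containing '&' or '=' survives because q.Encode() escapes it
	fmt.Println(buildAwsMasterKeyURI("key-id", "AKIAIOSFODNN7EXAMPLE", "se&cret=", "us-west-2"))
}
```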
diff --git a/br/pkg/task/encryption_test.go b/br/pkg/task/encryption_test.go new file mode 100644 index 0000000000000..033062511fd31 --- /dev/null +++ b/br/pkg/task/encryption_test.go @@ -0,0 +1,192 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. + +package task + +import ( + "net/url" + "testing" + + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/stretchr/testify/assert" +) + +func TestParseLocalDiskConfig(t *testing.T) { + tests := []struct { + name string + input string + expected encryptionpb.MasterKey + expectError bool + }{ + { + name: "Valid local path", + input: "local:///path/to/key", + expected: encryptionpb.MasterKey{Backend: &encryptionpb.MasterKey_File{File: &encryptionpb.MasterKeyFile{Path: "/path/to/key"}}}, + expectError: false, + }, + { + name: "Invalid local path", + input: "local://relative/path", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, _ := url.Parse(tt.input) + result, err := parseLocalDiskConfig(u) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestParseAwsKmsConfig(t *testing.T) { + tests := []struct { + name string + input string + expected encryptionpb.MasterKey + expectError bool + }{ + { + name: "Valid AWS config", + input: "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY&REGION=us-west-2", + expected: encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "aws", + KeyId: "key-id", + Region: "us-west-2", + }, + }, + }, + expectError: false, + }, + { + name: "Missing key ID", + input: "aws-kms:///?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&AWS_SECRET_ACCESS_KEY=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY&REGION=us-west-2", + expectError: true, + }, + { + name: "Missing required parameter", + input: "aws-kms:///key-id?AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE&REGION=us-west-2", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, _ := url.Parse(tt.input) + result, err := parseAwsKmsConfig(u) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestParseAzureKmsConfig(t *testing.T) { + tests := []struct { + name string + input string + expected encryptionpb.MasterKey + expectError bool + }{ + { + name: "Valid Azure config", + input: "azure-kms:///key-name/key-version?AZURE_TENANT_ID=tenant-id&AZURE_CLIENT_ID=client-id&AZURE_CLIENT_SECRET=client-secret&AZURE_VAULT_NAME=vault-name", + expected: encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "azure", + KeyId: "key-name/key-version", + AzureKms: &encryptionpb.AzureKms{ + TenantId: "tenant-id", + ClientId: "client-id", + ClientSecret: "client-secret", + KeyVaultUrl: "vault-name", + }, + }, + }, + }, + expectError: false, + }, + { + name: "Missing required parameter", + input: "azure-kms:///key-name/key-version?AZURE_TENANT_ID=tenant-id&AZURE_CLIENT_ID=client-id&AZURE_VAULT_NAME=vault-name", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, _ := url.Parse(tt.input) + result, err := parseAzureKmsConfig(u) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestParseGcpKmsConfig(t *testing.T) { + tests := []struct { + name string + input string + expected encryptionpb.MasterKey + expectError bool + }{ + { + name: "Valid GCP config", + input: 
"gcp-kms:///projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name?CREDENTIALS=credentials", + expected: encryptionpb.MasterKey{ + Backend: &encryptionpb.MasterKey_Kms{ + Kms: &encryptionpb.MasterKeyKms{ + Vendor: "gcp", + KeyId: "projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name", + GcpKms: &encryptionpb.GcpKms{ + Credential: "credentials", + }, + }, + }, + }, + expectError: false, + }, + { + name: "Invalid path format", + input: "gcp-kms:///invalid/path?CREDENTIALS=credentials", + expectError: true, + }, + { + name: "Missing credentials", + input: "gcp-kms:///projects/project-id/locations/global/keyRings/ring-name/cryptoKeys/key-name", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, _ := url.Parse(tt.input) + result, err := parseGcpKmsConfig(u) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index 3e45bf6934f1c..7a4d4db78cd27 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -247,7 +247,8 @@ type RestoreConfig struct { AllowPITRFromIncremental bool `json:"allow-pitr-from-incremental" toml:"allow-pitr-from-incremental"` // [startTs, RestoreTS] is used to `restore log` from StartTS to RestoreTS. - StartTS uint64 `json:"start-ts" toml:"start-ts"` + StartTS uint64 `json:"start-ts" toml:"start-ts"` + // if not specified system will restore to the max TS available RestoreTS uint64 `json:"restore-ts" toml:"restore-ts"` tiflashRecorder *tiflashrec.TiFlashRecorder `json:"-" toml:"-"` PitrBatchCount uint32 `json:"pitr-batch-count" toml:"pitr-batch-count"` @@ -717,9 +718,9 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf var restoreError error if IsStreamRestore(cmdName) { - restoreError = RunStreamRestore(c, g, cmdName, cfg) + restoreError = RunStreamRestore(c, g, cfg) } else { - restoreError = runRestore(c, g, cmdName, cfg, nil) + restoreError = runSnapshotRestore(c, g, cmdName, cfg, nil) } if restoreError != nil { return errors.Trace(restoreError) @@ -751,22 +752,22 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf return nil } -func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConfig, checkInfo *PiTRTaskInfo) error { +func runSnapshotRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConfig, checkInfo *PiTRTaskInfo) error { cfg.Adjust() defer summary.Summary(cmdName) ctx, cancel := context.WithCancel(c) defer cancel() + log.Info("starting snapshot restore") if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("task.RunRestore", opentracing.ChildOf(span.Context())) defer span1.Finish() ctx = opentracing.ContextWithSpan(ctx, span1) } - // Restore needs domain to do DDL. - needDomain := true keepaliveCfg := GetKeepalive(&cfg.Config) - mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, keepaliveCfg, cfg.CheckRequirements, needDomain, conn.NormalVersionChecker) + // Restore needs domain to do DDL. 
+ mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, keepaliveCfg, cfg.CheckRequirements, true, conn.NormalVersionChecker) if err != nil { return errors.Trace(err) } @@ -861,7 +862,7 @@ func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf if client.IsIncremental() { // don't support checkpoint for the ddl restore - log.Info("the incremental snapshot restore doesn't support checkpoint mode, so unuse checkpoint.") + log.Info("the incremental snapshot restore doesn't support checkpoint mode, so checkpoint is disabled.") cfg.UseCheckpoint = false } @@ -886,7 +887,7 @@ func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf }() var checkpointTaskName string - var checkpointFirstRun bool = true + var checkpointFirstRun = true if cfg.UseCheckpoint { checkpointTaskName = cfg.generateSnapshotRestoreTaskName(client.GetClusterID(ctx)) // if the checkpoint metadata exists in the external storage, the restore is not diff --git a/br/pkg/task/stream.go b/br/pkg/task/stream.go index dc0a5f863637d..c94577b868001 100644 --- a/br/pkg/task/stream.go +++ b/br/pkg/task/stream.go @@ -36,6 +36,7 @@ import ( "github.com/pingcap/tidb/br/pkg/backup" "github.com/pingcap/tidb/br/pkg/checkpoint" "github.com/pingcap/tidb/br/pkg/conn" + "github.com/pingcap/tidb/br/pkg/encryption" berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/glue" "github.com/pingcap/tidb/br/pkg/httputil" @@ -302,8 +303,8 @@ func NewStreamMgr(ctx context.Context, cfg *StreamConfig, g glue.Glue, isStreamS } }() - // just stream start need Storage - s := &streamMgr{ + // only stream start command needs Storage + streamManager := &streamMgr{ cfg: cfg, mgr: mgr, } @@ -323,12 +324,12 @@ func NewStreamMgr(ctx context.Context, cfg *StreamConfig, g glue.Glue, isStreamS if err = client.SetStorage(ctx, backend, &opts); err != nil { return nil, errors.Trace(err) } - s.bc = client + streamManager.bc = client // create http client to do some requirements check. - s.httpCli = httputil.NewClient(mgr.GetTLSConfig()) + streamManager.httpCli = httputil.NewClient(mgr.GetTLSConfig()) } - return s, nil + return streamManager, nil } func (s *streamMgr) close() { @@ -346,7 +347,7 @@ func (s *streamMgr) setLock(ctx context.Context) error { // adjustAndCheckStartTS checks that startTS should be smaller than currentTS, // and endTS is larger than currentTS. func (s *streamMgr) adjustAndCheckStartTS(ctx context.Context) error { - currentTS, err := s.mgr.GetTS(ctx) + currentTS, err := s.mgr.GetCurrentTsFromPd(ctx) if err != nil { return errors.Trace(err) } @@ -506,7 +507,7 @@ func RunStreamCommand( } if err := commandFn(ctx, g, cmdName, cfg); err != nil { - log.Error("failed to stream", zap.String("command", cmdName), zap.Error(err)) + log.Error("failed to run stream command", zap.String("command", cmdName), zap.Error(err)) summary.SetSuccessStatus(false) summary.CollectFailureUnit(cmdName, err) return err @@ -554,22 +555,28 @@ func RunStreamStart( log.Warn("failed to close etcd client", zap.Error(closeErr)) } }() + + // check if any import/restore task is running; starting log backup + // while a restore is ongoing is not allowed. if err = streamMgr.checkImportTaskRunning(ctx, cli.Client); err != nil { return errors.Trace(err) } + // Only a single log backup task is supported currently. 
if count, err := cli.GetTaskCount(ctx); err != nil { return errors.Trace(err) } else if count > 0 { - return errors.Annotate(berrors.ErrStreamLogTaskExist, "It supports single stream log task currently") + return errors.Annotate(berrors.ErrStreamLogTaskExist, "failed to start the log backup, only one running task is allowed") } - exist, err := streamMgr.checkLock(ctx) + // make sure the external file lock is available + locked, err := streamMgr.checkLock(ctx) if err != nil { return errors.Trace(err) } + - // exist is true, which represents restart a stream task. Or create a new stream task. - if exist { + + // if locked, an existing stream task is being restarted; otherwise create a new stream task. + if locked { logInfo, err := getLogRange(ctx, &cfg.Config) if err != nil { return errors.Trace(err) } @@ -621,6 +628,7 @@ func RunStreamStart( return errors.Annotate(berrors.ErrInvalidArgument, "nothing need to observe") } + securityConfig := generateSecurityConfig(cfg) ti := streamhelper.TaskInfo{ PBInfo: backuppb.StreamBackupTaskInfo{ Storage: streamMgr.bc.GetStorageBackend(), @@ -629,6 +637,7 @@ func RunStreamStart( Name: cfg.TaskName, TableFilter: cfg.FilterStr, CompressionType: backuppb.CompressionType_ZSTD, + SecurityConfig: &securityConfig, }, Ranges: ranges, Pausing: false, @@ -640,6 +649,30 @@ func RunStreamStart( return nil } +func generateSecurityConfig(cfg *StreamConfig) backuppb.StreamBackupTaskSecurityConfig { + if len(cfg.LogBackupCipherInfo.CipherKey) > 0 && isEffectiveEncryptionMethod(cfg.LogBackupCipherInfo.CipherType) { + return backuppb.StreamBackupTaskSecurityConfig{ + Encryption: &backuppb.StreamBackupTaskSecurityConfig_PlaintextDataKey{ + PlaintextDataKey: &backuppb.CipherInfo{ + CipherType: cfg.LogBackupCipherInfo.CipherType, + CipherKey: cfg.LogBackupCipherInfo.CipherKey, + }, + }, + } + } + if len(cfg.MasterKeyConfig.MasterKeys) > 0 && isEffectiveEncryptionMethod(cfg.MasterKeyConfig.EncryptionType) { + return backuppb.StreamBackupTaskSecurityConfig{ + Encryption: &backuppb.StreamBackupTaskSecurityConfig_MasterKeyConfig{ + MasterKeyConfig: &backuppb.MasterKeyConfig{ + EncryptionType: cfg.MasterKeyConfig.EncryptionType, + MasterKeys: cfg.MasterKeyConfig.MasterKeys, + }, + }, + } + } + return backuppb.StreamBackupTaskSecurityConfig{} +} + func RunStreamMetadata( c context.Context, g glue.Glue, @@ -1002,16 +1035,16 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre } if cfg.Until < sp { - console.Println("According to the log, you have truncated backup data before", em(formatTS(sp))) + console.Println("According to the log, you have truncated log backup data before", em(formatTS(sp))) if !cfg.SkipPrompt && !console.PromptBool("Continue? ") { return nil } } - readMetaDone := console.ShowTask("Reading Metadata... ", glue.WithTimeCost()) + readMetaDone := console.ShowTask("Reading log backup metadata... 
", glue.WithTimeCost()) metas := stream.StreamMetadataSet{ MetadataDownloadBatchSize: cfg.MetadataDownloadBatchSize, - Helper: stream.NewMetadataHelper(), + Helper: stream.NewMetadataHelper(nil), DryRun: cfg.DryRun, } shiftUntilTS, err := metas.LoadUntilAndCalculateShiftTS(ctx, extStorage, cfg.Until) @@ -1119,7 +1152,6 @@ func checkIncompatibleChangefeed(ctx context.Context, backupTS uint64, etcdCLI * func RunStreamRestore( c context.Context, g glue.Glue, - cmdName string, cfg *RestoreConfig, ) (err error) { ctx, cancelFn := context.WithCancel(c) @@ -1138,6 +1170,8 @@ func RunStreamRestore( if err != nil { return errors.Trace(err) } + + // if not set by user, restore to the max TS available if cfg.RestoreTS == 0 { cfg.RestoreTS = logInfo.logMaxTS } @@ -1163,7 +1197,7 @@ func RunStreamRestore( } } - log.Info("start restore on point", + log.Info("start point in time restore", zap.Uint64("restore-from", cfg.StartTS), zap.Uint64("restore-to", cfg.RestoreTS), zap.Uint64("log-min-ts", logInfo.logMinTS), zap.Uint64("log-max-ts", logInfo.logMaxTS)) if err := checkLogRange(cfg.StartTS, cfg.RestoreTS, logInfo.logMinTS, logInfo.logMaxTS); err != nil { @@ -1186,7 +1220,7 @@ func RunStreamRestore( logStorage := cfg.Config.Storage cfg.Config.Storage = cfg.FullBackupStorage // TiFlash replica is restored to down-stream on 'pitr' currently. - if err = runRestore(ctx, g, FullRestoreCmd, cfg, checkInfo); err != nil { + if err = runSnapshotRestore(ctx, g, FullRestoreCmd, cfg, checkInfo); err != nil { return errors.Trace(err) } cfg.Config.Storage = logStorage @@ -1197,9 +1231,6 @@ func RunStreamRestore( } if checkInfo.CheckpointInfo != nil && checkInfo.CheckpointInfo.TiFlashItems != nil { log.Info("load tiflash records of snapshot restore from checkpoint") - if err != nil { - return errors.Trace(err) - } cfg.tiflashRecorder.Load(checkInfo.CheckpointInfo.TiFlashItems) } } @@ -1335,7 +1366,11 @@ func restoreStream( }() } - err = client.InstallLogFileManager(ctx, cfg.StartTS, cfg.RestoreTS, cfg.MetadataDownloadBatchSize) + encryptionManager, err := encryption.NewManager(&cfg.LogBackupCipherInfo, &cfg.MasterKeyConfig) + if err != nil { + return errors.Annotate(err, "failed to create encryption manager for log restore") + } + err = client.InstallLogFileManager(ctx, cfg.StartTS, cfg.RestoreTS, cfg.MetadataDownloadBatchSize, encryptionManager) if err != nil { return err } @@ -1356,7 +1391,7 @@ func restoreStream( TableFilter: cfg.TableFilter, TiFlashRecorder: cfg.tiflashRecorder, FullBackupStorage: fullBackupStorage, - }) + }, &cfg.Config.CipherInfo) if err != nil { return errors.Trace(err) } @@ -1432,7 +1467,8 @@ func restoreStream( return errors.Trace(err) } - return client.RestoreKVFiles(ctx, rewriteRules, idrules, logFilesIterWithSplit, checkpointRunner, cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy) + return client.RestoreKVFiles(ctx, rewriteRules, idrules, logFilesIterWithSplit, checkpointRunner, + cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy, &cfg.LogBackupCipherInfo, cfg.MasterKeyConfig.MasterKeys) }) if err != nil { return errors.Annotate(err, "failed to restore kv files") @@ -1548,15 +1584,15 @@ func getExternalStorageOptions(cfg *Config, u *backuppb.StorageBackend) storage. 
} } -func checkLogRange(restoreFrom, restoreTo, logMinTS, logMaxTS uint64) error { - // serveral ts constraint: - // logMinTS <= restoreFrom <= restoreTo <= logMaxTS - if logMinTS > restoreFrom || restoreFrom > restoreTo || restoreTo > logMaxTS { +func checkLogRange(restoreFromTS, restoreToTS, logMinTS, logMaxTS uint64) error { + // several ts constraints: + // logMinTS <= restoreFromTS <= restoreToTS <= logMaxTS + if logMinTS > restoreFromTS || restoreFromTS > restoreToTS || restoreToTS > logMaxTS { return errors.Annotatef(berrors.ErrInvalidArgument, "restore log from %d(%s) to %d(%s), "+ " but the current existed log from %d(%s) to %d(%s)", - restoreFrom, oracle.GetTimeFromTS(restoreFrom), - restoreTo, oracle.GetTimeFromTS(restoreTo), + restoreFromTS, oracle.GetTimeFromTS(restoreFromTS), + restoreToTS, oracle.GetTimeFromTS(restoreToTS), logMinTS, oracle.GetTimeFromTS(logMinTS), logMaxTS, oracle.GetTimeFromTS(logMaxTS), ) @@ -1649,7 +1685,7 @@ func getGlobalCheckpointFromStorage(ctx context.Context, s storage.ExternalStora return globalCheckPointTS, errors.Trace(err) } -// getFullBackupTS gets the snapshot-ts of full bakcup +// getFullBackupTS gets the snapshot-ts of full backup func getFullBackupTS( ctx context.Context, cfg *RestoreConfig, @@ -1664,11 +1700,17 @@ func getFullBackupTS( return 0, 0, errors.Trace(err) } + decryptedMetaData, err := metautil.DecryptFullBackupMetaIfNeeded(metaData, &cfg.CipherInfo) + if err != nil { + return 0, 0, errors.Trace(err) + } + backupmeta := &backuppb.BackupMeta{} - if err = backupmeta.Unmarshal(metaData); err != nil { + if err = backupmeta.Unmarshal(decryptedMetaData); err != nil { return 0, 0, errors.Trace(err) } + // start and end versions are identical in a full backup, pick either one return backupmeta.GetEndVersion(), backupmeta.GetClusterId(), nil } @@ -1758,7 +1800,7 @@ func checkPiTRTaskInfo( cfg *RestoreConfig, ) (*PiTRTaskInfo, error) { var ( - doFullRestore = (len(cfg.FullBackupStorage) > 0) + doFullRestore = len(cfg.FullBackupStorage) > 0 curTaskInfo *checkpoint.CheckpointTaskInfoForLogRestore errTaskMsg string ) diff --git a/br/pkg/task/stream_test.go b/br/pkg/task/stream_test.go index 627924e8239cc..50dbef03e1264 100644 --- a/br/pkg/task/stream_test.go +++ b/br/pkg/task/stream_test.go @@ -24,6 +24,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/metautil" "github.com/pingcap/tidb/br/pkg/storage" @@ -110,35 +111,6 @@ func TestCheckLogRange(t *testing.T) { } } -type fakeResolvedInfo struct { - storeID int64 - resolvedTS uint64 -} - -func fakeMetaFiles(ctx context.Context, tempDir string, infos []fakeResolvedInfo) error { - backupMetaDir := filepath.Join(tempDir, stream.GetStreamBackupMetaPrefix()) - s, err := storage.NewLocalStorage(backupMetaDir) - if err != nil { - return errors.Trace(err) - } - - for _, info := range infos { - meta := &backuppb.Metadata{ - StoreId: info.storeID, - ResolvedTs: info.resolvedTS, - } - buff, err := meta.Marshal() - if err != nil { - return errors.Trace(err) - } - filename := fmt.Sprintf("%d_%d.meta", info.storeID, info.resolvedTS) - if err = s.WriteFile(ctx, filename, buff); err != nil { - return errors.Trace(err) - } - } - return nil -} - func fakeCheckpointFiles( ctx context.Context, tmpDir string, @@ -154,7 +126,7 @@ func fakeCheckpointFiles( for _, info := range infos { filename := fmt.Sprintf("%v.ts", 
info.storeID) buff := make([]byte, 8) - binary.LittleEndian.PutUint64(buff, info.global_checkpoint) + binary.LittleEndian.PutUint64(buff, info.globalCheckpoint) if _, err := s.Create(ctx, filename, nil); err != nil { return errors.Trace(err) } @@ -170,8 +142,8 @@ func fakeCheckpointFiles( } type fakeGlobalCheckPoint struct { - storeID int64 - global_checkpoint uint64 + storeID int64 + globalCheckpoint uint64 } func TestGetGlobalCheckpointFromStorage(t *testing.T) { @@ -182,16 +154,16 @@ func TestGetGlobalCheckpointFromStorage(t *testing.T) { infos := []fakeGlobalCheckPoint{ { - storeID: 1, - global_checkpoint: 98, + storeID: 1, + globalCheckpoint: 98, }, { - storeID: 2, - global_checkpoint: 90, + storeID: 2, + globalCheckpoint: 90, }, { - storeID: 2, - global_checkpoint: 99, + storeID: 2, + globalCheckpoint: 99, }, } @@ -261,3 +233,38 @@ func TestGetExternalStorageOptions(t *testing.T) { options = getExternalStorageOptions(&cfg, u) require.Nil(t, options.HTTPClient) } + +func TestGenerateSecurityConfig(t *testing.T) { + cfg := &StreamConfig{ + Config: Config{ + CipherInfo: backuppb.CipherInfo{ + CipherType: encryptionpb.EncryptionMethod_UNKNOWN, + CipherKey: []byte("12345678901234567890123456789012"), + }, + }, + } + securityConfig := generateSecurityConfig(cfg) + require.Equal(t, securityConfig, backuppb.StreamBackupTaskSecurityConfig{}) + + cfg = &StreamConfig{ + Config: Config{ + CipherInfo: backuppb.CipherInfo{ + CipherType: encryptionpb.EncryptionMethod_PLAINTEXT, + CipherKey: []byte("12345678901234567890123456789012"), + }, + }, + } + securityConfig = generateSecurityConfig(cfg) + require.Equal(t, securityConfig, backuppb.StreamBackupTaskSecurityConfig{}) + + cfg = &StreamConfig{ + Config: Config{ + MasterKeyConfig: backuppb.MasterKeyConfig{ + EncryptionType: encryptionpb.EncryptionMethod_AES256_CTR, + MasterKeys: []*encryptionpb.MasterKey{}, + }, + }, + } + securityConfig = generateSecurityConfig(cfg) + require.Equal(t, securityConfig, backuppb.StreamBackupTaskSecurityConfig{}) +} diff --git a/br/pkg/utils/BUILD.bazel b/br/pkg/utils/BUILD.bazel index d185e669169c4..fa18a8317b234 100644 --- a/br/pkg/utils/BUILD.bazel +++ b/br/pkg/utils/BUILD.bazel @@ -7,6 +7,7 @@ go_library( "db.go", "dyn_pprof_other.go", "dyn_pprof_unix.go", + "encryption.go", "error_handling.go", "json.go", "key.go", @@ -35,6 +36,7 @@ go_library( "//pkg/parser/types", "//pkg/sessionctx", "//pkg/util", + "//pkg/util/encrypt", "//pkg/util/logutil", "//pkg/util/sqlexec", "@com_github_cheggaaa_pb_v3//:pb", @@ -43,6 +45,7 @@ go_library( "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_kvproto//pkg/brpb", + "@com_github_pingcap_kvproto//pkg/encryptionpb", "@com_github_pingcap_kvproto//pkg/metapb", "@com_github_pingcap_log//:log", "@com_github_tikv_client_go_v2//oracle", diff --git a/br/pkg/utils/backoff.go b/br/pkg/utils/backoff.go index fda272606e5c7..a76297c362f02 100644 --- a/br/pkg/utils/backoff.go +++ b/br/pkg/utils/backoff.go @@ -260,7 +260,7 @@ func (bo *pdReqBackoffer) NextBackoff(err error) time.Duration { // If the connection timeout, pd client would cancel the context, and return grpc context cancel error. // So make the codes.Canceled retryable too. // It's OK to retry the grpc context cancel error, because the parent context cancel returns context.Canceled. - // For example, cancel the `ectx` and then pdClient.GetTS(ectx) returns context.Canceled instead of grpc context canceled. 
+ // For example, cancel the `ectx` and then mgr.GetCurrentTsFromPd(ectx), which wraps pdClient.GetTS, returns context.Canceled instead of grpc context canceled. switch status.Code(e) { case codes.DeadlineExceeded, codes.Canceled, codes.NotFound, codes.AlreadyExists, codes.PermissionDenied, codes.ResourceExhausted, codes.Aborted, codes.OutOfRange, codes.Unavailable, codes.DataLoss, codes.Unknown: bo.delayTime = 2 * bo.delayTime diff --git a/br/pkg/utils/encryption.go b/br/pkg/utils/encryption.go new file mode 100644 index 0000000000000..471a57497bd0f --- /dev/null +++ b/br/pkg/utils/encryption.go @@ -0,0 +1,26 @@ +package utils + +import ( + "github.com/pingcap/errors" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/pkg/util/encrypt" +) + +func Decrypt(content []byte, cipher *backuppb.CipherInfo, iv []byte) ([]byte, error) { + if len(content) == 0 || cipher == nil { + return content, nil + } + + switch cipher.CipherType { + case encryptionpb.EncryptionMethod_PLAINTEXT: + return content, nil + case encryptionpb.EncryptionMethod_AES128_CTR, + encryptionpb.EncryptionMethod_AES192_CTR, + encryptionpb.EncryptionMethod_AES256_CTR: + return encrypt.AESDecryptWithCTR(content, cipher.CipherKey, iv) + default: + return content, errors.Annotate(berrors.ErrInvalidArgument, "cipher type invalid") + } +} diff --git a/br/tests/README.md b/br/tests/README.md index 6338d3ebcdf94..009a0230e4c13 100644 --- a/br/tests/README.md +++ b/br/tests/README.md @@ -33,7 +33,7 @@ This folder contains all tests which relies on external processes such as TiDB. ## Preparations -1. The following 9 executables must be copied or linked into these locations: +1. The following 9 executables must be copied or linked into the `bin` folder under the TiDB root dir: * `bin/tidb-server` * `bin/tikv-server` @@ -80,7 +80,7 @@ If you have docker installed, you can skip step 1 and step 2 by running 1. Build `br.test` using `make build_for_br_integration_test` 2. Check that all 9 required executables and `br` executable exist 3. Select the tests to run using `export TEST_NAME=" ..."` -3. Execute `tests/run.sh` +4. Execute `br/tests/run.sh` If the first two steps are done before, you could also run `tests/run.sh` directly. diff --git a/br/tests/br_encryption/run.sh b/br/tests/br_encryption/run.sh new file mode 100755 index 0000000000000..4a43acf602487 --- /dev/null +++ b/br/tests/br_encryption/run.sh @@ -0,0 +1,252 @@ +#!/bin/sh +# +# Copyright 2024 PingCAP, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -eu +. 
+. run_services
+CUR=$(cd "$(dirname "$0")" && pwd)
+
+# const value
+PREFIX="encryption_backup"
+res_file="$TEST_DIR/sql_res.$TEST_NAME.txt"
+DB="$TEST_NAME"
+TABLE="usertable"
+DB_COUNT=3
+
+create_db_with_table() {
+    for i in $(seq $DB_COUNT); do
+        run_sql "CREATE DATABASE $DB${i};"
+        go-ycsb load mysql -P $CUR/workload -p mysql.host=$TIDB_IP -p mysql.port=$TIDB_PORT -p mysql.user=root -p mysql.db=$DB${i} -p recordcount=1000
+    done
+}
+
+start_log_backup() {
+    _storage=$1
+    _encryption_args=$2
+    echo "start log backup task"
+    run_br --pd "$PD_ADDR" log start --task-name encryption_test -s "$_storage" $_encryption_args
+}
+
+drop_db() {
+    for i in $(seq $DB_COUNT); do
+        run_sql "DROP DATABASE IF EXISTS $DB${i};"
+    done
+}
+
+insert_additional_data() {
+    local prefix=$1
+    for i in $(seq $DB_COUNT); do
+        go-ycsb load mysql -P $CUR/workload -p mysql.host=$TIDB_IP -p mysql.port=$TIDB_PORT -p mysql.user=root -p mysql.db=$DB${i} -p insertcount=1000 -p insertstart=1000000 -p recordcount=1001000 -p workload=core
+    done
+}
+
+wait_log_checkpoint_advance() {
+    echo "wait for log checkpoint to advance"
+    sleep 10
+    local current_ts=$(python3 -c "import time; print(int(time.time() * 1000) << 18)")
+    echo "current ts: $current_ts"
+    i=0
+    while true; do
+        # extract the checkpoint ts of the log backup task. If there is some error, the checkpoint ts should be empty
+        log_backup_status=$(unset BR_LOG_TO_TERM && run_br --skip-goleak --pd $PD_ADDR log status --task-name encryption_test --json 2>br.log)
+        echo "log backup status: $log_backup_status"
+        local checkpoint_ts=$(echo "$log_backup_status" | head -n 1 | jq 'if .[0].last_errors | length == 0 then .[0].checkpoint else empty end')
+        echo "checkpoint ts: $checkpoint_ts"
+
+        # check whether the checkpoint ts is a number
+        if [ $checkpoint_ts -gt 0 ] 2>/dev/null; then
+            if [ $checkpoint_ts -gt $current_ts ]; then
+                echo "the checkpoint has advanced"
+                break
+            fi
+            echo "the checkpoint hasn't advanced"
+            i=$((i+1))
+            if [ "$i" -gt 50 ]; then
+                echo 'the checkpoint lag is too large'
+                exit 1
+            fi
+            sleep 10
+        else
+            echo "TEST: [$TEST_NAME] failed to wait checkpoint advance!"
+            exit 1
+        fi
+    done
+}
+
+calculate_checksum() {
+    local db=$1
+    local checksum=$(run_sql "USE $db; ADMIN CHECKSUM TABLE $TABLE;" | awk '/CHECKSUM/{print $2}')
+    echo $checksum
+}
+
+check_db_consistency() {
+    fail=false
+    for i in $(seq $DB_COUNT); do
+        local original_checksum=${checksum_ori[i]}
+        local new_checksum=$(calculate_checksum "$DB${i}")
+
+        if [ "$original_checksum" != "$new_checksum" ]; then
+            fail=true
+            echo "TEST: [$TEST_NAME] checksum mismatch on database $DB${i}"
+            echo "Original checksum: $original_checksum, New checksum: $new_checksum"
+        else
+            echo "Database $DB${i} checksum match: $new_checksum"
+        fi
+    done
+
+    if $fail; then
+        echo "TEST: [$TEST_NAME] data consistency check failed!"
+        return 1
+    fi
+    echo "TEST: [$TEST_NAME] data consistency check passed."
+    return 0
+}
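+
+# run_backup_restore_test drives one full scenario with the given encryption args:
+# start log backup, load data, take a full backup, insert more data, then run a
+# point-in-time restore and verify the checksums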
+run_backup_restore_test() {
+    local encryption_mode=$1
+    local encryption_args=$2
+
+    restart_services || { echo "Failed to restart services"; exit 1; }
+
+    # Drop existing databases before starting the test
+    drop_db || { echo "Failed to drop databases"; exit 1; }
+
+    # Start log backup
+    start_log_backup "local://$TEST_DIR/$PREFIX/log" "$encryption_args" || { echo "Failed to start log backup"; exit 1; }
+
+    # Create test databases and insert initial data
+    create_db_with_table || { echo "Failed to create databases and tables"; exit 1; }
+
+    # Calculate and store original checksums
+    for i in $(seq $DB_COUNT); do
+        checksum_ori[${i}]=$(calculate_checksum "$DB${i}") || { echo "Failed to calculate initial checksum"; exit 1; }
+    done
+
+    # Full backup
+    echo "run full backup with $encryption_mode"
+    run_br --pd "$PD_ADDR" backup full -s "local://$TEST_DIR/$PREFIX/full" $encryption_args || { echo "Full backup failed"; exit 1; }
+
+    # Insert additional test data
+    insert_additional_data "${encryption_mode}_after_full_backup" || { echo "Failed to insert additional data"; exit 1; }
+
+    # Update checksums after inserting additional data
+    for i in $(seq $DB_COUNT); do
+        checksum_ori[${i}]=$(calculate_checksum "$DB${i}") || { echo "Failed to calculate checksum after insertion"; exit 1; }
+    done
+
+    wait_log_checkpoint_advance || { echo "Failed to wait for log checkpoint"; exit 1; }
+
+    restart_services || { echo "Failed to restart services"; exit 1; }
+
+    # Drop databases before restoring
+    drop_db || { echo "Failed to drop databases before restore"; exit 1; }
+
+    # Run pitr restore
+    echo "restore log backup with $encryption_mode"
+    timeout 300 run_br --pd "$PD_ADDR" restore point -s "local://$TEST_DIR/$PREFIX/log" --full-backup-storage "local://$TEST_DIR/$PREFIX/full" $encryption_args || {
+        echo "Log backup restore failed or timed out after 5 minutes"
+        exit 1
+    }
+
+    # Check data consistency after restore
+    echo "check data consistency after restore"
+    check_db_consistency || { echo "TEST: [$TEST_NAME] $encryption_mode backup and restore (including log) failed"; exit 1; }
+
+    # Clean up after the test
+    drop_db || { echo "Failed to drop databases after test"; exit 1; }
+
+    echo "TEST: [$TEST_NAME] $encryption_mode backup and restore (including log) passed"
+}
+
+drop_test_db() {
+    echo "Dropping test database"
+    run_sql "DROP DATABASE IF EXISTS test_db;"
+}
+
+test_plaintext() {
+    run_backup_restore_test "plaintext" ""
+}
+
+test_plaintext_data_key() {
+    run_backup_restore_test "plaintext_data_key" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef --log.crypter.method AES256-CTR --log.crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+}
+
+test_local_master_key() {
+    _MASTER_KEY_DIR="$TEST_DIR/$PREFIX/master_key"
+    mkdir -p "$_MASTER_KEY_DIR"
+    openssl rand -hex 32 > "$_MASTER_KEY_DIR/master.key"
+
+    _MASTER_KEY_PATH="local://$_MASTER_KEY_DIR/master.key"
+
+    run_backup_restore_test "local_master_key" "--master-key-crypter-method AES256-CTR --master-key $_MASTER_KEY_PATH"
+
+    rm -rf "$_MASTER_KEY_DIR"
+}
+
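+# test_aws_kms exercises master-key based encryption against a LocalStack KMS
+# endpoint, so no real AWS account or credentials are needed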
+test_aws_kms() {
+    # Start LocalStack in the background
+    localstack start &
+    LOCALSTACK_PID=$!
+
+    # Wait for LocalStack to be ready
+    while ! nc -z localhost 4566; do
+        sleep 0.1
+    done
+
+    # Replace with your actual AWS KMS key ID if using real AWS KMS
+    AWS_KMS_KEY_ID=$(awslocal kms create-key --query 'KeyMetadata.KeyId' --output text)
+    AWS_ACCESS_KEY_ID="TEST"
+    AWS_SECRET_ACCESS_KEY="TEST"
+    # default to us-east-1 in up.sh
+    REGION="us-east-1"
+    # localstack listening port
+    ENDPOINT="http://localhost:4566"
+
+    AWS_KMS_URI="aws-kms:///${AWS_KMS_KEY_ID}?AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}&AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}&REGION=${REGION}&ENDPOINT=${ENDPOINT}"
+
+    run_backup_restore_test "aws_kms" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef --master-key-crypter-method AES256-CTR --master-key $AWS_KMS_URI"
+
+    # Stop LocalStack
+    kill $LOCALSTACK_PID
+}
+
+test_gcp_kms() {
+    # Ensure GCP credentials are set
+    if [ -z "${GOOGLE_APPLICATION_CREDENTIALS:-}" ]; then
+        echo "GCP credentials not set. Skipping GCP KMS test."
+        return
+    fi
+
+    # Replace these with your actual GCP KMS details
+    GCP_PROJECT_ID="carbide-network-435219-q3"
+    GCP_LOCATION="us-west1"
+    GCP_KEY_RING="local-kms-testing"
+    GCP_KEY_NAME="kms-testing-key"
+    GCP_CREDENTIALS="$GOOGLE_APPLICATION_CREDENTIALS"
+
+    GCP_KMS_URI="gcp-kms:///projects/$GCP_PROJECT_ID/locations/$GCP_LOCATION/keyRings/$GCP_KEY_RING/cryptoKeys/$GCP_KEY_NAME?AUTH=specified&CREDENTIALS=$GCP_CREDENTIALS"
+
+    run_backup_restore_test "gcp_kms" "--crypter.method AES256-CTR --crypter.key 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef --master-key-crypter-method AES256-CTR --master-key $GCP_KMS_URI"
+}
+
+# Run tests
+test_plaintext
+test_plaintext_data_key
+test_local_master_key
+test_aws_kms
+# test_gcp_kms is skipped by default: it needs real GCP KMS credentials
+
+echo "All encryption tests passed successfully"
+
diff --git a/br/tests/br_encryption/workload b/br/tests/br_encryption/workload
new file mode 100644
index 0000000000000..448ca3c1a477f
--- /dev/null
+++ b/br/tests/br_encryption/workload
@@ -0,0 +1,12 @@
+recordcount=1000
+operationcount=0
+workload=core
+
+readallfields=true
+
+readproportion=0
+updateproportion=0
+scanproportion=0
+insertproportion=0
+
+requestdistribution=uniform
\ No newline at end of file
diff --git a/br/tests/config/tikv.toml b/br/tests/config/tikv.toml
index a469b389989e7..22126549ab848 100644
--- a/br/tests/config/tikv.toml
+++ b/br/tests/config/tikv.toml
@@ -36,3 +36,6 @@ path = "/tmp/backup_restore_test/master-key-file"
 [log-backup]
 max-flush-interval = "50s"
 
+[gc]
+ratio-threshold = 1.1
+
diff --git a/br/tests/download_integration_test_binaries.sh b/br/tests/download_integration_test_binaries.sh
index ef04bf04a5e1c..48cbcaa1f3c14 100755
--- a/br/tests/download_integration_test_binaries.sh
+++ b/br/tests/download_integration_test_binaries.sh
@@ -103,6 +103,12 @@ function main() {
     download "$fake_gcs_server_url" "fake-gcs-server" "third_bin/fake-gcs-server"
     download "$brv_url" "brv4.0.8" "third_bin/brv4.0.8"
 
+    # Download and set up LocalStack; extract into a separate directory so that
+    # third_bin/localstack can be the CLI symlink itself
+    download "$localstack_url" "localstack.tar.gz" "tmp/localstack.tar.gz"
+    mkdir -p third_bin/localstack_pkg
+    tar -xzf tmp/localstack.tar.gz -C third_bin/localstack_pkg --strip-components=1
+    ln -sf "$(pwd)/third_bin/localstack_pkg/bin/localstack" third_bin/localstack
 
     chmod +x third_bin/*
     rm -rf tmp
     rm -rf third_bin/bin
diff --git a/br/tests/up.sh b/br/tests/up.sh
index 0dc6cebd69fe1..0750523553f84 100755
--- a/br/tests/up.sh
+++ b/br/tests/up.sh
@@ -137,7 +137,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     vim \
     less \
     jq \
-    default-mysql-client
+    default-mysql-client \
+    python3 \
+    python3-pip \
+    netcat
+
+# Install LocalStack and awscli-local
+RUN pip3 install localstack awscli-local
 
 RUN mkdir -p /br/bin
 COPY --from=tidb-builder /tidb-server /br/bin/tidb-server
@@ -160,6 +166,13 @@ WORKDIR /br
 
 # Required by tiflash
 ENV LD_LIBRARY_PATH=/br/bin
 
+# LocalStack configuration
+ENV SERVICES=kms
+ENV DEFAULT_REGION=us-east-1
+ENV EDGE_PORT=4566
+
+EXPOSE 4566
+
 ENTRYPOINT ["/bin/bash"]
 EOF
diff --git a/build/BUILD.bazel b/build/BUILD.bazel
index 2069bf10e2a5a..e2e4619533f0f 100644
--- a/build/BUILD.bazel
+++ b/build/BUILD.bazel
@@ -1,9 +1,9 @@
-package(default_visibility = ["//visibility:public"])
-
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "nogo")
 load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "nogo")
 load("//build/linter/staticcheck:def.bzl", "staticcheck_analyzers")
 
+package(default_visibility = ["//visibility:public"])
+
 bool_flag(
     name = "with_nogo_flag",
     build_setting_default = False,
diff --git a/tests/_utils/run_services b/tests/_utils/run_services
index 5963ef125d116..8f8a31caf7f96 100644
--- a/tests/_utils/run_services
+++ b/tests/_utils/run_services
@@ -47,8 +47,10 @@ stop() {
 }
 
 restart_services() {
+    echo "Restarting services"
     stop_services
     start_services
+    echo "Services restarted"
 }
 
 stop_services() {