Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package modstgazure | |
import ( | |
"bytes" | |
"context" | |
"fmt" | |
"io" | |
"os" | |
"path" | |
"strings" | |
"time" | |
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" | |
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to" | |
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" | |
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" | |
"github.com/pkg/errors" | |
"github.com/weaviate/weaviate/entities/backup" | |
) | |
type azureClient struct { | |
client *azblob.Client | |
config clientConfig | |
serviceURL string | |
dataPath string | |
} | |
func newClient(ctx context.Context, config *clientConfig, dataPath string) (*azureClient, error) { | |
connectionString := os.Getenv("AZURE_STORAGE_CONNECTION_STRING") | |
if connectionString != "" { | |
client, err := azblob.NewClientFromConnectionString(connectionString, nil) | |
if err != nil { | |
return nil, errors.Wrap(err, "create client using connection string") | |
} | |
serviceURL := "" | |
connectionStrings := strings.Split(connectionString, ";") | |
for _, str := range connectionStrings { | |
if strings.HasPrefix(str, "BlobEndpoint") { | |
blobEndpoint := strings.Split(str, "=") | |
if len(blobEndpoint) > 1 { | |
serviceURL = blobEndpoint[1] | |
if !strings.HasSuffix(serviceURL, "/") { | |
serviceURL = serviceURL + "/" | |
} | |
} | |
} | |
} | |
return &azureClient{client, *config, serviceURL, dataPath}, nil | |
} | |
// Your account name and key can be obtained from the Azure Portal. | |
accountName := os.Getenv("AZURE_STORAGE_ACCOUNT") | |
accountKey := os.Getenv("AZURE_STORAGE_KEY") | |
if accountName == "" { | |
return nil, errors.New("AZURE_STORAGE_ACCOUNT must be set") | |
} | |
// The service URL for blob endpoints is usually in the form: http(s)://<account>.blob.core.windows.net/ | |
serviceURL := fmt.Sprintf("https://%s.blob.core.windows.net/", accountName) | |
if accountKey != "" { | |
cred, err := azblob.NewSharedKeyCredential(accountName, accountKey) | |
if err != nil { | |
return nil, err | |
} | |
client, err := azblob.NewClientWithSharedKeyCredential(serviceURL, cred, nil) | |
if err != nil { | |
return nil, err | |
} | |
return &azureClient{client, *config, serviceURL, dataPath}, nil | |
} | |
options := &azblob.ClientOptions{ | |
ClientOptions: policy.ClientOptions{ | |
Retry: policy.RetryOptions{ | |
MaxRetries: 3, | |
RetryDelay: 4 * time.Second, | |
MaxRetryDelay: 120 * time.Second, | |
}, | |
}, | |
} | |
client, err := azblob.NewClientWithNoCredential(serviceURL, options) | |
if err != nil { | |
return nil, err | |
} | |
return &azureClient{client, *config, serviceURL, dataPath}, nil | |
} | |
func (a *azureClient) HomeDir(backupID string) string { | |
return a.serviceURL + path.Join(a.config.Container, a.makeObjectName(backupID)) | |
} | |
func (a *azureClient) makeObjectName(parts ...string) string { | |
base := path.Join(parts...) | |
return path.Join(a.config.BackupPath, base) | |
} | |
func (a *azureClient) GetObject(ctx context.Context, backupID, key string) ([]byte, error) { | |
objectName := a.makeObjectName(backupID, key) | |
blobDownloadResponse, err := a.client.DownloadStream(ctx, a.config.Container, objectName, nil) | |
if err != nil { | |
if bloberror.HasCode(err, bloberror.BlobNotFound) { | |
return nil, backup.NewErrNotFound(errors.Wrapf(err, "get object '%s'", objectName)) | |
} | |
return nil, backup.NewErrInternal(errors.Wrapf(err, "download stream for object '%s'", objectName)) | |
} | |
reader := blobDownloadResponse.Body | |
downloadData, err := io.ReadAll(reader) | |
errClose := reader.Close() | |
if errClose != nil { | |
return nil, backup.NewErrInternal(errors.Wrapf(errClose, "close stream for object '%s'", objectName)) | |
} | |
if err != nil { | |
return nil, backup.NewErrInternal(errors.Wrapf(err, "read stream for object '%s'", objectName)) | |
} | |
return downloadData, nil | |
} | |
func (a *azureClient) PutFile(ctx context.Context, backupID, key, srcPath string) error { | |
filePath := path.Join(a.dataPath, srcPath) | |
file, err := os.Open(filePath) | |
if err != nil { | |
return backup.NewErrInternal(errors.Wrapf(err, "open file: %q", filePath)) | |
} | |
defer file.Close() | |
objectName := a.makeObjectName(backupID, key) | |
_, err = a.client.UploadFile(ctx, | |
a.config.Container, | |
objectName, | |
file, | |
&azblob.UploadFileOptions{ | |
Metadata: map[string]*string{"backupid": to.Ptr(backupID)}, | |
Tags: map[string]string{"backupid": backupID}, | |
}) | |
if err != nil { | |
return backup.NewErrInternal(errors.Wrapf(err, "upload file for object '%s'", objectName)) | |
} | |
return nil | |
} | |
func (a *azureClient) PutObject(ctx context.Context, backupID, key string, data []byte) error { | |
objectName := a.makeObjectName(backupID, key) | |
reader := bytes.NewReader(data) | |
_, err := a.client.UploadStream(ctx, | |
a.config.Container, | |
objectName, | |
reader, | |
&azblob.UploadStreamOptions{ | |
Metadata: map[string]*string{"backupid": to.Ptr(backupID)}, | |
Tags: map[string]string{"backupid": backupID}, | |
}) | |
if err != nil { | |
return backup.NewErrInternal(errors.Wrapf(err, "upload stream for object '%s'", objectName)) | |
} | |
return nil | |
} | |
func (a *azureClient) Initialize(ctx context.Context, backupID string) error { | |
key := "access-check" | |
if err := a.PutObject(ctx, backupID, key, []byte("")); err != nil { | |
return errors.Wrap(err, "failed to access-check Azure backup module") | |
} | |
objectName := a.makeObjectName(backupID, key) | |
if _, err := a.client.DeleteBlob(ctx, a.config.Container, objectName, nil); err != nil { | |
return errors.Wrap(err, "failed to remove access-check Azure backup module") | |
} | |
return nil | |
} | |
func (a *azureClient) WriteToFile(ctx context.Context, backupID, key, destPath string) error { | |
dir := path.Dir(destPath) | |
if err := os.MkdirAll(dir, os.ModePerm); err != nil { | |
return errors.Wrapf(err, "make dir '%s'", dir) | |
} | |
file, err := os.Create(destPath) | |
if err != nil { | |
return backup.NewErrInternal(errors.Wrapf(err, "create file: %q", destPath)) | |
} | |
defer file.Close() | |
objectName := a.makeObjectName(backupID, key) | |
_, err = a.client.DownloadFile(ctx, a.config.Container, objectName, file, nil) | |
if err != nil { | |
if bloberror.HasCode(err, bloberror.BlobNotFound) { | |
return backup.NewErrNotFound(errors.Wrapf(err, "get object '%s'", objectName)) | |
} | |
return backup.NewErrInternal(errors.Wrapf(err, "download file for object '%s'", objectName)) | |
} | |
return nil | |
} | |
func (a *azureClient) Write(ctx context.Context, backupID, key string, r io.ReadCloser) (written int64, err error) { | |
path := a.makeObjectName(backupID, key) | |
reader := &reader{src: r} | |
defer func() { | |
r.Close() | |
written = int64(reader.count) | |
}() | |
if _, err = a.client.UploadStream(ctx, | |
a.config.Container, | |
path, | |
reader, | |
&azblob.UploadStreamOptions{ | |
Metadata: map[string]*string{"backupid": to.Ptr(backupID)}, | |
Tags: map[string]string{"backupid": backupID}, | |
}); err != nil { | |
err = fmt.Errorf("upload stream %q: %w", path, err) | |
} | |
return | |
} | |
func (a *azureClient) Read(ctx context.Context, backupID, key string, w io.WriteCloser) (int64, error) { | |
defer w.Close() | |
path := a.makeObjectName(backupID, key) | |
resp, err := a.client.DownloadStream(ctx, a.config.Container, path, nil) | |
if err != nil { | |
err = fmt.Errorf("find object %q: %w", path, err) | |
if bloberror.HasCode(err, bloberror.BlobNotFound) { | |
err = backup.NewErrNotFound(err) | |
} | |
return 0, err | |
} | |
defer resp.Body.Close() | |
read, err := io.Copy(w, resp.Body) | |
if err != nil { | |
return read, fmt.Errorf("io.copy %q: %w", path, err) | |
} | |
return read, nil | |
} | |
func (a *azureClient) SourceDataPath() string { | |
return a.dataPath | |
} | |
// reader is a wrapper used to count number of written bytes | |
// Unlike GCS and S3 Azure Interface does not provide this information | |
type reader struct { | |
src io.Reader | |
count int | |
} | |
func (r *reader) Read(p []byte) (n int, err error) { | |
n, err = r.src.Read(p) | |
r.count += n | |
return | |
} | |