SemanticSearchPOC / usecases /objects /batch_references_add.go
KevinStephenson
Adding in weaviate code
b110593
raw
history blame
9.3 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package objects
import (
"context"
"fmt"
"strings"
"sync"
"github.com/go-openapi/strfmt"
"github.com/weaviate/weaviate/entities/additional"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/schema"
"github.com/weaviate/weaviate/entities/schema/crossref"
)
// AddReferences validates a batch of cross-references and hands the valid
// ones to the connected vector repository.
//
// The principal must be authorized for "update" on "batch/*". The schema lock
// obtained from b.locks.LockSchema is held until the call returns, and the
// batch-reference metric is incremented for the duration of the call.
//
// Per-reference validation failures are reported in the returned
// BatchReferences (via each entry's Err); the returned error is reserved for
// failures of the request as a whole (authorization, locking, empty input,
// schema lookup, or the repository write).
func (b *BatchManager) AddReferences(ctx context.Context, principal *models.Principal,
	refs []*models.BatchReference, repl *additional.ReplicationProperties,
) (BatchReferences, error) {
	if authErr := b.authorizer.Authorize(principal, "update", "batch/*"); authErr != nil {
		return nil, authErr
	}

	unlock, lockErr := b.locks.LockSchema()
	if lockErr != nil {
		return nil, NewErrInternal("could not acquire lock: %v", lockErr)
	}
	// Deferred first so it runs last: the metric is decremented before the
	// schema lock is released, matching the registration order below.
	defer unlock()

	b.metrics.BatchRefInc()
	defer b.metrics.BatchRefDec()

	return b.addReferences(ctx, principal, refs, repl)
}
// addReferences is the lock-free core of AddReferences. It rejects an empty
// batch, validates every reference concurrently, fills in target classes that
// can be derived from the schema, and finally writes the batch through the
// vector repository.
//
// Invalid individual references do not fail the call; they travel to the
// repository with their Err field set. A non-nil error is returned only for
// an empty batch, a schema lookup failure, or a repository failure.
func (b *BatchManager) addReferences(ctx context.Context, principal *models.Principal,
	refs []*models.BatchReference, repl *additional.ReplicationProperties,
) (BatchReferences, error) {
	if err := b.validateReferenceForm(refs); err != nil {
		return nil, NewErrInvalidUserInput("invalid params: %v", err)
	}

	batchReferences := b.validateReferencesConcurrently(ctx, principal, refs)
	if err := b.autodetectToClass(ctx, principal, batchReferences); err != nil {
		return nil, err
	}

	res, err := b.vectorRepo.AddBatchReferences(ctx, batchReferences, repl)
	if err != nil {
		return nil, NewErrInternal("could not add batch request to connector: %v", err)
	}
	return res, nil
}
// validateReferenceForm performs the cheap, whole-batch checks that run
// before any per-reference work starts. At present the only rule is that the
// batch must contain at least one reference.
func (b *BatchManager) validateReferenceForm(refs []*models.BatchReference) error {
	if len(refs) > 0 {
		return nil
	}
	return fmt.Errorf("length cannot be 0, need at least one reference for batching")
}
// validateReferencesConcurrently runs validateReference for every entry of
// refs, one goroutine per reference, waits for all of them, and returns the
// results ordered to match refs (each result carries its OriginalIndex).
func (b *BatchManager) validateReferencesConcurrently(ctx context.Context,
	principal *models.Principal, refs []*models.BatchReference,
) BatchReferences {
	// Sized to len(refs) so every worker can deliver its single result
	// without waiting for a reader; the channel is only drained after Wait.
	results := make(chan BatchReference, len(refs))

	var pending sync.WaitGroup
	pending.Add(len(refs))
	for idx := range refs {
		go b.validateReference(ctx, principal, &pending, refs[idx], idx, &results)
	}
	pending.Wait()

	close(results)
	return referencesChanToSlice(results)
}
// autodetectToClass gets the class name of the referenced class through the
// schema definition, for references whose beacon did not name a target class.
//
// For each such reference it looks up the source class and property in the
// current schema. If the property has exactly one data type, that data type is
// the referenced class name and is written into To.Class. Properties with
// several data types (multi-target) cannot be auto-detected and are left
// unchanged. A missing source class or property is recorded on the reference's
// Err instead of failing the whole batch.
//
// References that already carry an error, or already name a target class, are
// skipped. The only error returned is a failure to fetch the schema.
func (b *BatchManager) autodetectToClass(ctx context.Context,
	principal *models.Principal, batchReferences BatchReferences,
) error {
	// classProp identifies one property of one class. A struct key is used
	// instead of concatenating the names, which would collide for pairs such
	// as class "Ab"/property "cd" and class "A"/property "bcd".
	type classProp struct {
		class, prop string
	}
	// Memoized target class per (class, property). An empty string records a
	// property whose target cannot be auto-detected.
	classPropTarget := make(map[classProp]string)

	scheme, err := b.schemaManager.GetSchema(principal)
	if err != nil {
		return NewErrInvalidUserInput("get schema: %v", err)
	}

	for i, ref := range batchReferences {
		// Check Err first: a reference that failed beacon parsing may have a
		// nil From/To, so touching ref.To before this would dereference nil.
		if ref.Err != nil || ref.To.Class != "" {
			continue
		}

		className := string(ref.From.Class)
		propName := schema.LowercaseFirstLetter(string(ref.From.Property))
		key := classProp{class: className, prop: propName}

		target, cached := classPropTarget[key]
		if !cached {
			class := scheme.FindClassByName(ref.From.Class)
			if class == nil {
				batchReferences[i].Err = fmt.Errorf("class %s does not exist", className)
				continue
			}
			prop, err := schema.GetPropertyByName(class, propName)
			if err != nil {
				batchReferences[i].Err = fmt.Errorf("property %s does not exist for class %s", propName, className)
				continue
			}
			// Only a single data type determines the target unambiguously; the
			// length check also prevents an index panic on an empty DataType.
			if len(prop.DataType) == 1 {
				target = prop.DataType[0] // datatype is the name of the class that is referenced
			}
			classPropTarget[key] = target
		}

		if target == "" {
			continue // can't auto-detect for multi-target
		}
		batchReferences[i].To.Class = target
	}
	return nil
}
// validateReference parses and validates a single batch reference and sends
// exactly one BatchReference describing the outcome to *resultsC. It is run
// as a goroutine by validateReferencesConcurrently and calls wg.Done when
// finished.
//
// Checks, all of which are collected into the result's Err:
//   - the source beacon parses and points to the local peer
//   - the target beacon parses and points to the local peer
//   - only if the above succeed and shouldValidateMultiTenantRef agrees,
//     the multi-tenancy constraints checked by validateReferenceMultiTenancy
//
// i is echoed back as OriginalIndex so results can be reassembled in
// request order.
func (b *BatchManager) validateReference(ctx context.Context, principal *models.Principal,
	wg *sync.WaitGroup, ref *models.BatchReference, i int, resultsC *chan BatchReference,
) {
	defer wg.Done()
	var validateErrors []error

	source, err := crossref.ParseSource(string(ref.From))
	if err != nil {
		validateErrors = append(validateErrors, err)
	} else if !source.Local {
		validateErrors = append(validateErrors, fmt.Errorf("source class must always point to the local peer, but got %s",
			source.PeerName))
	}

	target, err := crossref.Parse(string(ref.To))
	if err != nil {
		validateErrors = append(validateErrors, err)
	} else if !target.Local {
		validateErrors = append(validateErrors, fmt.Errorf("importing network references in batch is not possible. "+
			"Please perform a regular non-batch import for network references, got peer %s",
			target.PeerName))
	}

	// target id must be lowercase. A failed parse may leave target nil;
	// dereferencing it here would panic inside this goroutine and take down
	// the process instead of reporting the parse error on the result.
	if target != nil {
		target.TargetID = strfmt.UUID(strings.ToLower(target.TargetID.String()))
	}

	// joinErrors yields nil when no validation error was collected.
	err = joinErrors(validateErrors)
	if err == nil && shouldValidateMultiTenantRef(ref.Tenant, source, target) {
		// can only validate multi-tenancy when everything above succeeds
		err = validateReferenceMultiTenancy(ctx, principal,
			b.schemaManager, b.vectorRepo, source, target, ref.Tenant)
	}

	*resultsC <- BatchReference{
		From:          source,
		To:            target,
		Err:           err,
		OriginalIndex: i,
		Tenant:        ref.Tenant,
	}
}
// validateReferenceMultiTenancy enforces the tenancy rules for a reference
// between source and target:
//
//   - neither class multi-tenant: nothing to check
//   - only the target multi-tenant: rejected
//   - only the source multi-tenant: the source object must exist in tenant,
//     and the target object must exist without a tenant
//   - both multi-tenant: both objects must exist in the same tenant, which
//     prevents references that cross tenants
//
// Both refs must be non-nil and name their classes (see getReferenceClasses).
func validateReferenceMultiTenancy(ctx context.Context,
	principal *models.Principal, schemaManager schemaManager,
	repo VectorRepo, source *crossref.RefSource, target *crossref.Ref,
	tenant string,
) error {
	if source == nil || target == nil {
		return fmt.Errorf("can't validate multi-tenancy for nil refs")
	}

	sourceClass, targetClass, err := getReferenceClasses(
		ctx, principal, schemaManager, source.Class.String(), target.Class)
	if err != nil {
		return err
	}

	srcMT := schema.MultiTenancyEnabled(sourceClass)
	dstMT := schema.MultiTenancyEnabled(targetClass)

	switch {
	case !srcMT && dstMT:
		return fmt.Errorf("invalid reference: cannot reference a multi-tenant " +
			"enabled class from a non multi-tenant enabled class")
	case !srcMT:
		// Neither side is multi-tenant: there is no tenant to validate against.
		return nil
	}

	// From here on the source class is multi-tenant.
	if srcErr := validateTenantRefObject(ctx, repo, sourceClass, source.TargetID, tenant); srcErr != nil {
		return fmt.Errorf("source: %w", srcErr)
	}

	// A multi-tenant target must live in the source's tenant; a
	// non-multi-tenant target is looked up with no tenant at all.
	targetTenant := ""
	if dstMT {
		targetTenant = tenant
	}
	if dstErr := validateTenantRefObject(ctx, repo, targetClass, target.TargetID, targetTenant); dstErr != nil {
		return fmt.Errorf("target: %w", dstErr)
	}
	return nil
}
// getReferenceClasses resolves the source and target class definitions of a
// reference through schemaManager.
//
// Both class names must be non-empty. An error is returned if a lookup fails
// or a class is not present in the schema. On any error both returned
// classes are nil, so callers never see a half-populated result.
func getReferenceClasses(ctx context.Context,
	principal *models.Principal, schemaManager schemaManager,
	classFrom, classTo string,
) (sourceClass *models.Class, targetClass *models.Class, err error) {
	if classFrom == "" || classTo == "" {
		return nil, nil, fmt.Errorf("references involving a multi-tenancy enabled class " +
			"requires class name in the beacon url")
	}

	sourceClass, err = schemaManager.GetClass(ctx, principal, classFrom)
	if err != nil {
		return nil, nil, fmt.Errorf("get source class %q: %w", classFrom, err)
	}
	if sourceClass == nil {
		return nil, nil, fmt.Errorf("source class %q not found in schema", classFrom)
	}

	targetClass, err = schemaManager.GetClass(ctx, principal, classTo)
	if err != nil {
		return nil, nil, fmt.Errorf("get target class %q: %w", classTo, err)
	}
	if targetClass == nil {
		return nil, nil, fmt.Errorf("target class %q not found in schema", classTo)
	}

	return sourceClass, targetClass, nil
}
// validateTenantRefObject checks through repo.Exists that the object ID of
// the given class exists for tenant (an empty tenant is passed through
// unchanged). Because the lookup is scoped to the tenant, an ID that belongs
// to a different tenant is reported as not found, which is what rules out
// cross-tenant references.
//
// class must be non-nil; its Class field names the collection to search.
func validateTenantRefObject(ctx context.Context, repo VectorRepo,
	class *models.Class, ID strfmt.UUID, tenant string,
) error {
	className := class.Class
	found, lookupErr := repo.Exists(ctx, className, ID, nil, tenant)
	switch {
	case lookupErr != nil:
		return fmt.Errorf("get object %s/%s: %w", className, ID, lookupErr)
	case !found:
		return fmt.Errorf("object %s/%s not found for tenant %q", className, ID, tenant)
	default:
		return nil
	}
}
// referencesChanToSlice drains c into a slice, placing each result at its
// OriginalIndex. The slice length is len(c), measured before draining, so all
// results must already be buffered in c; c must also be closed by the caller,
// otherwise the drain loop never terminates.
func referencesChanToSlice(c chan BatchReference) BatchReferences {
	ordered := make([]BatchReference, len(c))
	for {
		ref, open := <-c
		if !open {
			break
		}
		ordered[ref.OriginalIndex] = ref
	}
	return ordered
}
// joinErrors combines the messages of all non-nil errors in errs into a
// single error, separated by ", ". It returns nil when errs contains no
// non-nil error.
//
// The joined text is passed as an argument, never as the format string, so
// any '%' inside an error message (e.g. URL-escaped beacons) is preserved
// verbatim instead of being interpreted as a formatting verb.
func joinErrors(errs []error) error {
	msgs := make([]string, 0, len(errs))
	for _, err := range errs {
		if err != nil {
			msgs = append(msgs, err.Error())
		}
	}
	if len(msgs) == 0 {
		return nil
	}
	return fmt.Errorf("%s", strings.Join(msgs, ", "))
}