Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package classification | |
import ( | |
"context" | |
"fmt" | |
"runtime" | |
"time" | |
"github.com/go-openapi/strfmt" | |
"github.com/pkg/errors" | |
"github.com/sirupsen/logrus" | |
"github.com/weaviate/weaviate/entities/additional" | |
"github.com/weaviate/weaviate/entities/models" | |
"github.com/weaviate/weaviate/entities/modulecapabilities" | |
"github.com/weaviate/weaviate/entities/schema" | |
"github.com/weaviate/weaviate/entities/search" | |
) | |
// the contents of this file deal with anything about a classification run | |
// which is generic, whereas the individual classify_item fns can be found in | |
// the respective files such as classifier_run_knn.go | |
func (c *Classifier) run(params models.Classification, | |
filters Filters, | |
) { | |
ctx, cancel := contextWithTimeout(30 * time.Minute) | |
defer cancel() | |
go c.monitorClassification(ctx, cancel, schema.ClassName(params.Class)) | |
c.logBegin(params, filters) | |
unclassifiedItems, err := c.vectorRepo.GetUnclassified(ctx, | |
params.Class, params.ClassifyProperties, filters.Source()) | |
if err != nil { | |
c.failRunWithError(params, errors.Wrap(err, "retrieve to-be-classifieds")) | |
return | |
} | |
if len(unclassifiedItems) == 0 { | |
c.failRunWithError(params, | |
fmt.Errorf("no classes to be classified - did you run a previous classification already?")) | |
return | |
} | |
c.logItemsFetched(params, unclassifiedItems) | |
classifyItem, err := c.prepareRun(params, filters, unclassifiedItems) | |
if err != nil { | |
c.failRunWithError(params, errors.Wrap(err, "prepare classification")) | |
return | |
} | |
params, err = c.runItems(ctx, classifyItem, params, filters, unclassifiedItems) | |
if err != nil { | |
c.failRunWithError(params, err) | |
return | |
} | |
c.succeedRun(params) | |
} | |
func (c *Classifier) monitorClassification(ctx context.Context, cancelFn context.CancelFunc, | |
className schema.ClassName, | |
) { | |
ticker := time.NewTicker(100 * time.Millisecond) | |
defer ticker.Stop() | |
for { | |
select { | |
case <-ctx.Done(): | |
return | |
case <-ticker.C: | |
schema := c.schemaGetter.GetSchemaSkipAuth() | |
class := schema.FindClassByName(className) | |
if class == nil { | |
cancelFn() | |
return | |
} | |
} | |
} | |
} | |
func (c *Classifier) prepareRun(params models.Classification, filters Filters, | |
unclassifiedItems []search.Result, | |
) (ClassifyItemFn, error) { | |
c.logBeginPreparation(params) | |
defer c.logFinishPreparation(params) | |
if params.Type == "knn" { | |
return c.classifyItemUsingKNN, nil | |
} | |
if params.Type == "zeroshot" { | |
return c.classifyItemUsingZeroShot, nil | |
} | |
if c.modulesProvider != nil { | |
classifyItemFn, err := c.modulesProvider.GetClassificationFn(params.Class, params.Type, | |
c.getClassifyParams(params, filters, unclassifiedItems)) | |
if err != nil { | |
return nil, errors.Wrapf(err, "cannot classify") | |
} | |
if classifyItemFn == nil { | |
return nil, errors.Errorf("cannot classify: empty classifier for %s", params.Type) | |
} | |
classification := &moduleClassification{classifyItemFn} | |
return classification.classifyFn, nil | |
} | |
return nil, errors.Errorf("unsupported type '%s', have no classify item fn for this", params.Type) | |
} | |
func (c *Classifier) getClassifyParams(params models.Classification, | |
filters Filters, unclassifiedItems []search.Result, | |
) modulecapabilities.ClassifyParams { | |
return modulecapabilities.ClassifyParams{ | |
Schema: c.schemaGetter.GetSchemaSkipAuth(), | |
Params: params, | |
Filters: filters, | |
UnclassifiedItems: unclassifiedItems, | |
VectorRepo: c.vectorClassSearchRepo, | |
} | |
} | |
// runItems splits the job list into batches that can be worked on parallelly | |
// depending on the available CPUs | |
func (c *Classifier) runItems(ctx context.Context, classifyItem ClassifyItemFn, params models.Classification, filters Filters, | |
items []search.Result, | |
) (models.Classification, error) { | |
workerCount := runtime.GOMAXPROCS(0) | |
if len(items) < workerCount { | |
workerCount = len(items) | |
} | |
workers := newRunWorkers(workerCount, classifyItem, params, filters, c.vectorRepo) | |
workers.addJobs(items) | |
res := workers.work(ctx) | |
params.Meta.Completed = strfmt.DateTime(time.Now()) | |
params.Meta.CountSucceeded = res.successCount | |
params.Meta.CountFailed = res.errorCount | |
params.Meta.Count = res.successCount + res.errorCount | |
return params, res.err | |
} | |
func (c *Classifier) succeedRun(params models.Classification) { | |
params.Status = models.ClassificationStatusCompleted | |
ctx, cancel := contextWithTimeout(2 * time.Second) | |
defer cancel() | |
err := c.repo.Put(ctx, params) | |
if err != nil { | |
c.logExecutionError("store succeeded run", err, params) | |
} | |
c.logFinish(params) | |
} | |
func (c *Classifier) failRunWithError(params models.Classification, err error) { | |
params.Status = models.ClassificationStatusFailed | |
params.Error = fmt.Sprintf("classification failed: %v", err) | |
err = c.repo.Put(context.Background(), params) | |
if err != nil { | |
c.logExecutionError("store failed run", err, params) | |
} | |
c.logFinish(params) | |
} | |
func (c *Classifier) extendItemWithObjectMeta(item *search.Result, | |
params models.Classification, classified []string, | |
) { | |
// don't overwrite existing non-classification meta info | |
if item.AdditionalProperties == nil { | |
item.AdditionalProperties = models.AdditionalProperties{} | |
} | |
item.AdditionalProperties["classification"] = additional.Classification{ | |
ID: params.ID, | |
Scope: params.ClassifyProperties, | |
ClassifiedFields: classified, | |
Completed: strfmt.DateTime(time.Now()), | |
} | |
} | |
func contextWithTimeout(d time.Duration) (context.Context, context.CancelFunc) { | |
return context.WithTimeout(context.Background(), d) | |
} | |
// Logging helper methods | |
func (c *Classifier) logBase(params models.Classification, event string) *logrus.Entry { | |
return c.logger.WithField("action", "classification_run"). | |
WithField("event", event). | |
WithField("params", params). | |
WithField("classification_type", params.Type) | |
} | |
func (c *Classifier) logBegin(params models.Classification, filters Filters) { | |
c.logBase(params, "classification_begin"). | |
WithField("filters", filters). | |
Debug("classification started") | |
} | |
func (c *Classifier) logFinish(params models.Classification) { | |
c.logBase(params, "classification_finish"). | |
WithField("status", params.Status). | |
Debug("classification finished") | |
} | |
func (c *Classifier) logItemsFetched(params models.Classification, items search.Results) { | |
c.logBase(params, "classification_items_fetched"). | |
WithField("status", params.Status). | |
WithField("item_count", len(items)). | |
Debug("fetched source items") | |
} | |
func (c *Classifier) logBeginPreparation(params models.Classification) { | |
c.logBase(params, "classification_preparation_begin"). | |
Debug("begin run preparation") | |
} | |
func (c *Classifier) logFinishPreparation(params models.Classification) { | |
c.logBase(params, "classification_preparation_finish"). | |
Debug("finish run preparation") | |
} | |
func (c *Classifier) logExecutionError(event string, err error, params models.Classification) { | |
c.logBase(params, event). | |
WithError(err). | |
Error("classification execution failure") | |
} | |