//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package classification

import (
	"context"
	"sync"
	"sync/atomic"

	"github.com/pkg/errors"
	"github.com/weaviate/weaviate/entities/errorcompounder"
	"github.com/weaviate/weaviate/entities/models"
	"github.com/weaviate/weaviate/entities/search"
)
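
// runWorker processes the subset of classification jobs assigned to a
// single goroutine. The success and error counters are shared across all
// workers of a run, so they must only be updated atomically.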
type runWorker struct {
	jobs         []search.Result
	successCount *int64
	errorCount   *int64
	ec           *errorcompounder.SafeErrorCompounder
	classify     ClassifyItemFn
	batchWriter  Writer
	params       models.Classification
	filters      Filters
	id           int
	workerCount  int
}
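
// addJob appends a single search result to this worker's private job list.
// It is not safe for concurrent use and is meant to be called via addJobs
// before the workers are started.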
func (w *runWorker) addJob(job search.Result) {
	w.jobs = append(w.jobs, job)
}
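
// work classifies every job assigned to this worker. The context is
// checked before each item, so a cancelled classification run aborts the
// worker at the next item boundary rather than mid-item.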
func (w *runWorker) work(ctx context.Context, wg *sync.WaitGroup) {
	defer wg.Done()

	for i, item := range w.jobs {
		// check if the whole classification operation has been cancelled;
		// if so, abort this worker
		if err := ctx.Err(); err != nil {
			w.ec.Add(err)
			atomic.AddInt64(w.errorCount, 1)
			break
		}

		// jobs were distributed round-robin, so the item's position in the
		// original result set is derived from the worker id and the stride
		originalIndex := (i * w.workerCount) + w.id
		err := w.classify(item, originalIndex, w.params, w.filters, w.batchWriter)
		if err != nil {
			w.ec.Add(err)
			atomic.AddInt64(w.errorCount, 1)
		} else {
			atomic.AddInt64(w.successCount, 1)
		}
	}
}
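
// newRunWorker builds a worker that shares the counters, error compounder,
// and batch writer of its parent runWorkers, but keeps its own job slice.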
func newRunWorker(id int, workerCount int, rw *runWorkers) *runWorker {
	return &runWorker{
		successCount: rw.successCount,
		errorCount:   rw.errorCount,
		ec:           rw.ec,
		params:       rw.params,
		filters:      rw.filters,
		classify:     rw.classify,
		batchWriter:  rw.batchWriter,
		id:           id,
		workerCount:  workerCount,
	}
}
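
// runWorkers fans a classification run out over a fixed number of workers.
// A sketch of the intended call sequence (caller-side names such as
// classifyItem, repo, and items are illustrative, not part of this file):
//
//	workers := newRunWorkers(4, classifyItem, params, filters, repo)
//	workers.addJobs(items)   // round-robin distribution
//	res := workers.work(ctx) // blocks until all workers are done
//	if res.err != nil {
//		// handle compounded error
//	}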
type runWorkers struct {
	workers      []*runWorker
	successCount *int64
	errorCount   *int64
	ec           *errorcompounder.SafeErrorCompounder
	classify     ClassifyItemFn
	params       models.Classification
	filters      Filters
	batchWriter  Writer
}
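
// newRunWorkers constructs the worker pool. All workers share one pair of
// atomic counters, one error compounder, and one batch writer; amount is
// both the number of workers and the stride used to reconstruct original
// item indices.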
func newRunWorkers(amount int, classifyFn ClassifyItemFn,
	params models.Classification, filters Filters, vectorRepo vectorRepo,
) *runWorkers {
	var successCount int64
	var errorCount int64

	rw := &runWorkers{
		workers:      make([]*runWorker, amount),
		successCount: &successCount,
		errorCount:   &errorCount,
		ec:           &errorcompounder.SafeErrorCompounder{},
		classify:     classifyFn,
		params:       params,
		filters:      filters,
		batchWriter:  newBatchWriter(vectorRepo),
	}

	for i := 0; i < amount; i++ {
		rw.workers[i] = newRunWorker(i, amount, rw)
	}

	return rw
}
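
// addJobs distributes jobs round-robin: with n workers, job i goes to
// worker i%n. For example, with 3 workers, jobs 0..5 land on workers
// 0, 1, 2, 0, 1, 2. Combined with originalIndex = (i * workerCount) + id
// in work, this preserves each item's position in the original result set.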
func (ws *runWorkers) addJobs(jobs []search.Result) {
	for i, job := range jobs {
		ws.workers[i%len(ws.workers)].addJob(job)
	}
}
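
// work starts the batch writer and all workers, waits for them to finish,
// then cross-checks the workers' counters against what the batch writer
// actually persisted before folding everything into a single result.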
func (ws *runWorkers) work(ctx context.Context) runWorkerResults {
	ws.batchWriter.Start()

	wg := &sync.WaitGroup{}
	for _, worker := range ws.workers {
		wg.Add(1)
		go worker.work(ctx, wg)
	}
	wg.Wait()

	res := ws.batchWriter.Stop()

	// sanity check: every success and error counted by the workers must
	// also have been observed by the batch writer, otherwise data was lost
	if res.SuccessCount() != *ws.successCount || res.ErrorCount() != *ws.errorCount {
		ws.ec.Add(errors.New("data save error"))
	}

	if res.Err() != nil {
		ws.ec.Add(res.Err())
	}

	return runWorkerResults{
		successCount: *ws.successCount,
		errorCount:   *ws.errorCount,
		err:          ws.ec.ToError(),
	}
}
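
// runWorkerResults aggregates the outcome of a run: how many items were
// classified successfully, how many failed, and the compounded error, if any.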
type runWorkerResults struct {
	successCount int64
	errorCount   int64
	err          error
}