File size: 3,661 Bytes
b110593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: [email protected]
//

package classification

import (
	"context"
	"sync"
	"sync/atomic"

	"github.com/pkg/errors"
	"github.com/weaviate/weaviate/entities/errorcompounder"
	"github.com/weaviate/weaviate/entities/models"
	"github.com/weaviate/weaviate/entities/search"
)

// runWorker classifies its share of the items sequentially. Counters, the
// error compounder and the batch writer are shared with the owning
// runWorkers pool (see newRunWorker), so all workers aggregate into them.
type runWorker struct {
	// jobs are the items assigned to this worker, processed in order.
	jobs []search.Result
	// successCount and errorCount are shared counters, updated atomically.
	successCount, errorCount *int64
	// ec collects every error encountered by any worker of the pool.
	ec *errorcompounder.SafeErrorCompounder
	// classify is invoked once per job.
	classify ClassifyItemFn
	// batchWriter is handed to classify for each job.
	batchWriter Writer
	// params and filters are passed through to classify unchanged.
	params  models.Classification
	filters Filters
	// id is this worker's position in the pool and workerCount the pool
	// size; together they recover each job's index in the input slice.
	id          int
	workerCount int
}

// addJob enqueues job at the end of this worker's queue. Jobs are later
// processed by work in the order they were added.
func (w *runWorker) addJob(job search.Result) {
	w.jobs = append(w.jobs, job)
}

// work classifies every queued job in order.
//
// Each job's outcome is recorded in the shared counters: errors go to the
// shared error compounder and bump errorCount; successes bump successCount.
// If ctx is cancelled, the worker stops before the next job and records
// ctx.Err() as a single error. The remaining jobs are neither classified
// nor counted. wg.Done is called when the worker returns.
func (w *runWorker) work(ctx context.Context, wg *sync.WaitGroup) {
	defer wg.Done()

	for n, job := range w.jobs {
		// Abort the whole worker once the classification is cancelled.
		if ctxErr := ctx.Err(); ctxErr != nil {
			w.ec.Add(ctxErr)
			atomic.AddInt64(w.errorCount, 1)
			return
		}

		// Jobs were dealt round-robin, so the n-th job of worker id was at
		// position n*workerCount + id in the original input.
		globalIdx := w.id + n*w.workerCount

		classifyErr := w.classify(job, globalIdx, w.params, w.filters, w.batchWriter)
		if classifyErr != nil {
			w.ec.Add(classifyErr)
			atomic.AddInt64(w.errorCount, 1)
			continue
		}

		atomic.AddInt64(w.successCount, 1)
	}
}

// newRunWorker creates the worker at position id in a pool of workerCount
// workers. The worker shares rw's counters, error compounder and batch
// writer, and uses rw's classify function, params and filters.
func newRunWorker(id int, workerCount int, rw *runWorkers) *runWorker {
	w := &runWorker{id: id, workerCount: workerCount}

	// Shared state: every worker points at the same counters and error sink.
	w.successCount, w.errorCount = rw.successCount, rw.errorCount
	w.ec = rw.ec
	w.batchWriter = rw.batchWriter

	// Per-item classification inputs, identical for every worker.
	w.classify = rw.classify
	w.params = rw.params
	w.filters = rw.filters

	return w
}

// runWorkers is a fixed pool of runWorker instances plus the state they
// share: result counters, an error compounder and a single batch writer.
type runWorkers struct {
	// workers holds the pool; worker i has id i.
	workers []*runWorker
	// successCount and errorCount are the shared counters handed to every
	// worker and read back in work.
	successCount, errorCount *int64
	// ec aggregates errors from all workers and from the batch writer.
	ec *errorcompounder.SafeErrorCompounder
	// classify, params and filters are copied into each worker.
	classify ClassifyItemFn
	params   models.Classification
	filters  Filters
	// batchWriter is started and stopped by work and shared by all workers.
	batchWriter Writer
}

// newRunWorkers builds a pool of amount workers. The workers share one pair
// of success/error counters, one error compounder and one batch writer
// created from vectorRepo. Each worker receives classifyFn, params and
// filters.
//
// amount values below 1 are treated as 1. With zero workers, addJobs would
// panic on a modulo by zero; a negative amount would panic in make.
func newRunWorkers(amount int, classifyFn ClassifyItemFn,
	params models.Classification, filters Filters, vectorRepo vectorRepo,
) *runWorkers {
	// A pool must contain at least one worker so jobs can be distributed.
	if amount < 1 {
		amount = 1
	}

	var successCount, errorCount int64

	rw := &runWorkers{
		workers:      make([]*runWorker, amount),
		successCount: &successCount,
		errorCount:   &errorCount,
		ec:           &errorcompounder.SafeErrorCompounder{},
		classify:     classifyFn,
		params:       params,
		filters:      filters,
		batchWriter:  newBatchWriter(vectorRepo),
	}

	for i := range rw.workers {
		rw.workers[i] = newRunWorker(i, amount, rw)
	}

	return rw
}

// addJobs deals jobs to the workers round-robin: jobs[i] goes to worker
// i mod len(ws.workers). runWorker.work relies on this layout to recover
// each job's original index.
func (ws *runWorkers) addJobs(jobs []search.Result) {
	poolSize := len(ws.workers)
	for idx := range jobs {
		ws.workers[idx%poolSize].addJob(jobs[idx])
	}
}

func (ws *runWorkers) work(ctx context.Context) runWorkerResults {
	ws.batchWriter.Start()

	wg := &sync.WaitGroup{}
	for _, worker := range ws.workers {
		wg.Add(1)
		go worker.work(ctx, wg)
	}

	wg.Wait()

	res := ws.batchWriter.Stop()

	if res.SuccessCount() != *ws.successCount || res.ErrorCount() != *ws.errorCount {
		ws.ec.Add(errors.New("data save error"))
	}

	if res.Err() != nil {
		ws.ec.Add(res.Err())
	}

	return runWorkerResults{
		successCount: *ws.successCount,
		errorCount:   *ws.errorCount,
		err:          ws.ec.ToError(),
	}
}

// runWorkerResults is the outcome of runWorkers.work: how many items were
// classified successfully, how many failed, and the combined error (nil
// when no error was collected).
type runWorkerResults struct {
	successCount, errorCount int64
	err                      error
}