KevinStephenson
Adding in weaviate code
b110593
raw
history blame
4.8 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
package classification
import (
"fmt"
"time"
libfilters "github.com/weaviate/weaviate/entities/filters"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/modulecapabilities"
"github.com/weaviate/weaviate/entities/schema"
"github.com/weaviate/weaviate/entities/search"
libclassification "github.com/weaviate/weaviate/usecases/classification"
)
// tfidfScorer is the read-only view of a finished tf-idf calculation that the
// contextual classifier keeps per basedOn property.
type tfidfScorer interface {
// GetAllTerms returns the terms with their tf-idf values for the document at
// docIndex. NOTE(review): docIndex presumably matches the order in which
// documents were added to the calculator — confirm against the implementation.
GetAllTerms(docIndex int) []TermWithTfIdf
}
// contextualPreparationContext bundles everything that is computed once, up
// front, before contextual classification runs: one tf-idf scorer per basedOn
// property and the candidate target objects per property to classify.
type contextualPreparationContext struct {
// tfidf is keyed by the basedOn property name.
tfidf map[string]tfidfScorer // map[basedOnProp]scorer
// targets is keyed by the name of the property being classified.
targets map[string]search.Results // map[classifyProp]targets
}
// prepareContextualClassification gathers the shared state needed for a
// contextual classification run. It packs the arguments into a
// contextualPreparer and returns whatever that preparer produces; the
// receiver itself is not consulted.
func (c *Classifier) prepareContextualClassification(schema schema.Schema,
	vectorRepo modulecapabilities.VectorClassSearchRepo, params models.Classification,
	filters libclassification.Filters, items search.Results,
) (contextualPreparationContext, error) {
	preparer := contextualPreparer{
		schema:     schema,
		repo:       vectorRepo,
		params:     params,
		filters:    filters,
		inputItems: items,
	}

	// preparer is addressable, so the pointer-receiver method can be called
	// on it directly.
	return preparer.do()
}
// contextualPreparer holds the inputs of a single preparation run. It is
// populated by prepareContextualClassification and executed via do().
type contextualPreparer struct {
// inputItems are the source objects; their basedOn properties form the
// tf-idf corpus.
inputItems []search.Result
// params supplies Class, BasedOnProperties and ClassifyProperties.
params models.Classification
// repo is used to search for potential target objects.
repo modulecapabilities.VectorClassSearchRepo
// filters provides the target filter applied to that search.
filters libclassification.Filters
// schema is used to resolve the class a classify property points to.
schema schema.Schema
}
// do runs both preparation steps in order: first the target lookup for every
// classify property, then the tf-idf calculation for every basedOn property.
// The first failing step aborts the run and its error is returned unchanged,
// together with whatever part of the context had been filled in by then.
func (p *contextualPreparer) do() (contextualPreparationContext, error) {
	var (
		prepared contextualPreparationContext
		err      error
	)

	if prepared.targets, err = p.findTargetsForProps(); err != nil {
		return prepared, err
	}

	if prepared.tfidf, err = p.calculateTfidfForProps(); err != nil {
		return prepared, err
	}

	return prepared, nil
}
// calculateTfidfForProps builds one tf-idf calculator per basedOn property,
// feeding it the value of that property from every input item.
//
// Every input item contributes exactly one document per property: an item
// that lacks the property is added as an empty document, so document indices
// stay aligned with the positions in inputItems. An item whose schema is not
// a map, or whose property value is present but not a string, aborts the
// whole calculation with an error.
func (p *contextualPreparer) calculateTfidfForProps() (map[string]tfidfScorer, error) {
	scorers := map[string]tfidfScorer{}
	docCount := len(p.inputItems)

	for _, propName := range p.params.BasedOnProperties {
		calculator := NewTfIdfCalculator(docCount)

		for i := range p.inputItems {
			item := &p.inputItems[i]

			properties, isMap := item.Schema.(map[string]interface{})
			if !isMap {
				return nil, fmt.Errorf("no or incorrect schema map present on source object '%s': %T", item.ID, item.Schema)
			}

			// Stays empty when the property is absent on this item.
			var text string
			if raw, present := properties[propName]; present {
				switch value := raw.(type) {
				case string:
					text = value
				default:
					return nil, fmt.Errorf("property '%s' present on %s, but of unexpected type: want string, got %T",
						propName, item.ID, raw)
				}
			}

			calculator.AddDoc(text)
		}

		calculator.Calculate()
		scorers[propName] = calculator
	}

	return scorers, nil
}
// findTargetsForProps resolves, for every property listed in
// params.ClassifyProperties, the class that property points to and fetches
// the candidate target objects of that class.
//
// The returned map is keyed by the classify property name. The first property
// that fails aborts the lookup; the returned error names the property and
// wraps the underlying cause so callers can inspect it with errors.Is/As.
func (p *contextualPreparer) findTargetsForProps() (map[string]search.Results, error) {
	targetsMap := make(map[string]search.Results, len(p.params.ClassifyProperties))

	for _, targetProp := range p.params.ClassifyProperties {
		class, err := p.classAndKindOfTarget(targetProp)
		if err != nil {
			return nil, fmt.Errorf("target prop '%s': find target class: %w", targetProp, err)
		}

		targets, err := p.findTargets(class)
		if err != nil {
			return nil, fmt.Errorf("target prop '%s': find targets: %w", targetProp, err)
		}

		targetsMap[targetProp] = targets
	}

	return targetsMap, nil
}
// findTargets fetches the potential target objects of the given class through
// the vector repo, applying the configured target filter.
//
// The search runs under a 30 second timeout, requests only the "id" property
// and is capped at 10000 results; anything beyond that cap is not returned.
// An empty result is reported as an error. A repo failure is wrapped, so
// the underlying cause stays reachable via errors.Is/As.
func (p *contextualPreparer) findTargets(class schema.ClassName) (search.Results, error) {
	ctx, cancel := contextWithTimeout(30 * time.Second)
	defer cancel()

	res, err := p.repo.VectorClassSearch(ctx, modulecapabilities.VectorClassSearchParams{
		Filters: p.filters.Target(),
		Pagination: &libfilters.Pagination{
			Limit: 10000,
		},
		ClassName:  string(class),
		Properties: []string{"id"},
	})
	if err != nil {
		return nil, fmt.Errorf("search closest target: %w", err)
	}

	if len(res) == 0 {
		return nil, fmt.Errorf("no potential targets found of class '%s'", class)
	}

	return res, nil
}
// classAndKindOfTarget returns the class that the reference property propName
// of params.Class points to.
//
// It looks the property up in the schema, resolves its data type and returns
// the first target class. Lookup and resolution failures are wrapped, so the
// underlying cause stays reachable via errors.Is/As.
func (p *contextualPreparer) classAndKindOfTarget(propName string) (schema.ClassName, error) {
	prop, err := p.schema.GetProperty(schema.ClassName(p.params.Class), schema.PropertyName(propName))
	if err != nil {
		return "", fmt.Errorf("get target prop '%s': %w", propName, err)
	}

	dataType, err := p.schema.FindPropertyDataType(prop.DataType)
	if err != nil {
		return "", fmt.Errorf("extract dataType of prop '%s': %w", propName, err)
	}

	// we have passed validation, so it is safe to assume that this is a ref prop
	targetClasses := dataType.Classes()

	// Validation is expected to guarantee exactly one target class. Guard the
	// index anyway so that a bypassed or changed validation surfaces as an
	// error instead of an index-out-of-range panic.
	if len(targetClasses) == 0 {
		return "", fmt.Errorf("prop '%s' has no target classes", propName)
	}

	return targetClasses[0], nil
}