Spaces:
Sleeping
Sleeping
| // _ _ | |
| // __ _____ __ ___ ___ __ _| |_ ___ | |
| // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
| // \ V V / __/ (_| |\ V /| | (_| | || __/ | |
| // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
| // | |
| // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
| // | |
| // CONTACT: hello@weaviate.io | |
| // | |
| package answer | |
| import ( | |
| "context" | |
| "errors" | |
| "sort" | |
| "strings" | |
| "github.com/weaviate/weaviate/entities/models" | |
| "github.com/weaviate/weaviate/entities/search" | |
| qnamodels "github.com/weaviate/weaviate/modules/qna-transformers/additional/models" | |
| "github.com/weaviate/weaviate/modules/qna-transformers/ent" | |
| ) | |
| func (p *AnswerProvider) findAnswer(ctx context.Context, | |
| in []search.Result, params *Params, limit *int, | |
| argumentModuleParams map[string]interface{}, | |
| ) ([]search.Result, error) { | |
| if len(in) > 0 { | |
| question := p.paramsHelper.GetQuestion(argumentModuleParams["ask"]) | |
| if question == "" { | |
| return in, errors.New("empty question") | |
| } | |
| properties := p.paramsHelper.GetProperties(argumentModuleParams["ask"]) | |
| for i := range in { | |
| textProperties := map[string]string{} | |
| schema := in[i].Object().Properties.(map[string]interface{}) | |
| for property, value := range schema { | |
| if p.containsProperty(property, properties) { | |
| if valueString, ok := value.(string); ok && len(valueString) > 0 { | |
| textProperties[property] = valueString | |
| } | |
| } | |
| } | |
| texts := []string{} | |
| for _, value := range textProperties { | |
| texts = append(texts, value) | |
| } | |
| text := strings.Join(texts, " ") | |
| if len(text) == 0 { | |
| return in, errors.New("empty content") | |
| } | |
| answer, err := p.qna.Answer(ctx, text, question) | |
| if err != nil { | |
| return in, err | |
| } | |
| ap := in[i].AdditionalProperties | |
| if ap == nil { | |
| ap = models.AdditionalProperties{} | |
| } | |
| if answerMeetsSimilarityThreshold(argumentModuleParams["ask"], p.paramsHelper, answer) { | |
| propertyName, startPos, endPos := p.findProperty(answer.Answer, textProperties) | |
| ap["answer"] = &qnamodels.Answer{ | |
| Result: answer.Answer, | |
| Property: propertyName, | |
| StartPosition: startPos, | |
| EndPosition: endPos, | |
| Certainty: answer.Certainty, | |
| Distance: answer.Distance, | |
| HasAnswer: answer.Answer != nil, | |
| } | |
| } else { | |
| ap["answer"] = &qnamodels.Answer{ | |
| HasAnswer: false, | |
| } | |
| } | |
| in[i].AdditionalProperties = ap | |
| } | |
| } | |
| rerank := p.paramsHelper.GetRerank(argumentModuleParams["ask"]) | |
| if rerank { | |
| return p.rerank(in), nil | |
| } | |
| return in, nil | |
| } | |
| func answerMeetsSimilarityThreshold(params interface{}, helper paramsHelper, ans *ent.AnswerResult) bool { | |
| certainty := helper.GetCertainty(params) | |
| if certainty > 0 && ans.Certainty != nil && *ans.Certainty < certainty { | |
| return false | |
| } | |
| distance := helper.GetDistance(params) | |
| if distance > 0 && ans.Distance != nil && *ans.Distance > distance { | |
| return false | |
| } | |
| return true | |
| } | |
| func (p *AnswerProvider) rerank(in []search.Result) []search.Result { | |
| if len(in) > 0 { | |
| sort.SliceStable(in, func(i, j int) bool { | |
| return p.getAnswerCertainty(in[i]) > p.getAnswerCertainty(in[j]) | |
| }) | |
| } | |
| return in | |
| } | |
| func (p *AnswerProvider) getAnswerCertainty(result search.Result) float64 { | |
| answerObj, ok := result.AdditionalProperties["answer"] | |
| if ok { | |
| answer, ok := answerObj.(*qnamodels.Answer) | |
| if ok { | |
| if answer.HasAnswer { | |
| return *answer.Certainty | |
| } | |
| } | |
| } | |
| return 0 | |
| } | |
| func (p *AnswerProvider) containsProperty(property string, properties []string) bool { | |
| if len(properties) == 0 { | |
| return true | |
| } | |
| for i := range properties { | |
| if properties[i] == property { | |
| return true | |
| } | |
| } | |
| return false | |
| } | |
| func (p *AnswerProvider) findProperty(answer *string, textProperties map[string]string) (*string, int, int) { | |
| if answer == nil { | |
| return nil, 0, 0 | |
| } | |
| lowercaseAnswer := strings.ToLower(*answer) | |
| if len(lowercaseAnswer) > 0 { | |
| for property, value := range textProperties { | |
| lowercaseValue := strings.ToLower(strings.ReplaceAll(value, "\n", " ")) | |
| if strings.Contains(lowercaseValue, lowercaseAnswer) { | |
| startIndex := strings.Index(lowercaseValue, lowercaseAnswer) | |
| return &property, startIndex, startIndex + len(lowercaseAnswer) | |
| } | |
| } | |
| } | |
| propertyNotFound := "" | |
| return &propertyNotFound, 0, 0 | |
| } | |