Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: hello@weaviate.io
// | |
package answer | |
import ( | |
"context" | |
"testing" | |
"github.com/stretchr/testify/assert" | |
"github.com/stretchr/testify/require" | |
"github.com/weaviate/weaviate/entities/additional" | |
"github.com/weaviate/weaviate/entities/search" | |
qnamodels "github.com/weaviate/weaviate/modules/qna-transformers/additional/models" | |
"github.com/weaviate/weaviate/modules/qna-transformers/ent" | |
) | |
func TestAdditionalAnswerProvider(t *testing.T) { | |
t.Run("should fail with empty content", func(t *testing.T) { | |
// given | |
qnaClient := &fakeQnAClient{} | |
fakeHelper := &fakeParamsHelper{} | |
answerProvider := New(qnaClient, fakeHelper) | |
in := []search.Result{ | |
{ | |
ID: "some-uuid", | |
}, | |
} | |
fakeParams := &Params{} | |
limit := 1 | |
argumentModuleParams := map[string]interface{}{} | |
// when | |
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) | |
// then | |
require.NotNil(t, err) | |
require.NotEmpty(t, out) | |
assert.Error(t, err, "empty content") | |
}) | |
t.Run("should fail with empty question", func(t *testing.T) { | |
// given | |
qnaClient := &fakeQnAClient{} | |
fakeHelper := &fakeParamsHelper{} | |
answerProvider := New(qnaClient, fakeHelper) | |
in := []search.Result{ | |
{ | |
ID: "some-uuid", | |
Schema: map[string]interface{}{ | |
"content": "content", | |
}, | |
}, | |
} | |
fakeParams := &Params{} | |
limit := 1 | |
argumentModuleParams := map[string]interface{}{} | |
// when | |
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) | |
// then | |
require.NotNil(t, err) | |
require.NotEmpty(t, out) | |
assert.Error(t, err, "empty content") | |
}) | |
t.Run("should answer", func(t *testing.T) { | |
// given | |
qnaClient := &fakeQnAClient{} | |
fakeHelper := &fakeParamsHelper{} | |
answerProvider := New(qnaClient, fakeHelper) | |
in := []search.Result{ | |
{ | |
ID: "some-uuid", | |
Schema: map[string]interface{}{ | |
"content": "content", | |
}, | |
}, | |
} | |
fakeParams := &Params{} | |
limit := 1 | |
argumentModuleParams := map[string]interface{}{ | |
"ask": map[string]interface{}{ | |
"question": "question", | |
}, | |
} | |
// when | |
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) | |
// then | |
require.Nil(t, err) | |
require.NotEmpty(t, out) | |
assert.Equal(t, 1, len(in)) | |
answer, answerOK := in[0].AdditionalProperties["answer"] | |
assert.True(t, answerOK) | |
assert.NotNil(t, answer) | |
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) | |
assert.True(t, answerAdditionalOK) | |
assert.Equal(t, "answer", *answerAdditional.Result) | |
}) | |
t.Run("should answer with property", func(t *testing.T) { | |
// given | |
qnaClient := &fakeQnAClient{} | |
fakeHelper := &fakeParamsHelper{} | |
answerProvider := New(qnaClient, fakeHelper) | |
in := []search.Result{ | |
{ | |
ID: "some-uuid", | |
Schema: map[string]interface{}{ | |
"content": "content with answer", | |
"content2": "this one is just a title", | |
}, | |
}, | |
} | |
fakeParams := &Params{} | |
limit := 1 | |
argumentModuleParams := map[string]interface{}{ | |
"ask": map[string]interface{}{ | |
"question": "question", | |
"properties": []string{"content", "content2"}, | |
}, | |
} | |
// when | |
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) | |
// then | |
require.Nil(t, err) | |
require.NotEmpty(t, out) | |
assert.Equal(t, 1, len(in)) | |
answer, answerOK := in[0].AdditionalProperties["answer"] | |
assert.True(t, answerOK) | |
assert.NotNil(t, answer) | |
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) | |
assert.True(t, answerAdditionalOK) | |
assert.Equal(t, "answer", *answerAdditional.Result) | |
assert.Equal(t, "content", *answerAdditional.Property) | |
assert.Equal(t, 0.8, *answerAdditional.Certainty) | |
assert.InDelta(t, 0.4, *answerAdditional.Distance, 1e-9) | |
assert.Equal(t, 13, answerAdditional.StartPosition) | |
assert.Equal(t, 19, answerAdditional.EndPosition) | |
assert.Equal(t, true, answerAdditional.HasAnswer) | |
}) | |
t.Run("should answer with similarity set above ask distance", func(t *testing.T) { | |
// given | |
qnaClient := &fakeQnAClient{} | |
fakeHelper := &fakeParamsHelper{} | |
answerProvider := New(qnaClient, fakeHelper) | |
in := []search.Result{ | |
{ | |
ID: "some-uuid", | |
Schema: map[string]interface{}{ | |
"content": "content with answer", | |
"content2": "this one is just a title", | |
}, | |
}, | |
} | |
fakeParams := &Params{} | |
limit := 1 | |
argumentModuleParams := map[string]interface{}{ | |
"ask": map[string]interface{}{ | |
"question": "question", | |
"properties": []string{"content", "content2"}, | |
"distance": float64(0.4), | |
}, | |
} | |
// when | |
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) | |
// then | |
require.Nil(t, err) | |
require.NotEmpty(t, out) | |
assert.Equal(t, 1, len(out)) | |
answer, answerOK := out[0].AdditionalProperties["answer"] | |
assert.True(t, answerOK) | |
assert.NotNil(t, answer) | |
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) | |
assert.True(t, answerAdditionalOK) | |
assert.Equal(t, "answer", *answerAdditional.Result) | |
assert.Equal(t, "content", *answerAdditional.Property) | |
assert.Equal(t, 0.8, *answerAdditional.Certainty) | |
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.8)), *answerAdditional.Distance) | |
assert.Equal(t, 13, answerAdditional.StartPosition) | |
assert.Equal(t, 19, answerAdditional.EndPosition) | |
assert.Equal(t, true, answerAdditional.HasAnswer) | |
}) | |
t.Run("should answer with similarity set above ask certainty", func(t *testing.T) { | |
// given | |
qnaClient := &fakeQnAClient{} | |
fakeHelper := &fakeParamsHelper{} | |
answerProvider := New(qnaClient, fakeHelper) | |
in := []search.Result{ | |
{ | |
ID: "some-uuid", | |
Schema: map[string]interface{}{ | |
"content": "content with answer", | |
"content2": "this one is just a title", | |
}, | |
}, | |
} | |
fakeParams := &Params{} | |
limit := 1 | |
argumentModuleParams := map[string]interface{}{ | |
"ask": map[string]interface{}{ | |
"question": "question", | |
"properties": []string{"content", "content2"}, | |
"certainty": float64(0.8), | |
}, | |
} | |
// when | |
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) | |
// then | |
require.Nil(t, err) | |
require.NotEmpty(t, out) | |
assert.Equal(t, 1, len(out)) | |
answer, answerOK := out[0].AdditionalProperties["answer"] | |
assert.True(t, answerOK) | |
assert.NotNil(t, answer) | |
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) | |
assert.True(t, answerAdditionalOK) | |
assert.Equal(t, "answer", *answerAdditional.Result) | |
assert.Equal(t, "content", *answerAdditional.Property) | |
assert.Equal(t, 0.8, *answerAdditional.Certainty) | |
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.8)), *answerAdditional.Distance) | |
assert.Equal(t, 13, answerAdditional.StartPosition) | |
assert.Equal(t, 19, answerAdditional.EndPosition) | |
assert.Equal(t, true, answerAdditional.HasAnswer) | |
}) | |
t.Run("should not answer with distance set below ask distance", func(t *testing.T) { | |
// given | |
qnaClient := &fakeQnAClient{} | |
fakeHelper := &fakeParamsHelper{} | |
answerProvider := New(qnaClient, fakeHelper) | |
in := []search.Result{ | |
{ | |
ID: "some-uuid", | |
Schema: map[string]interface{}{ | |
"content": "content with answer", | |
"content2": "this one is just a title", | |
}, | |
}, | |
} | |
fakeParams := &Params{} | |
limit := 1 | |
argumentModuleParams := map[string]interface{}{ | |
"ask": map[string]interface{}{ | |
"question": "question", | |
"properties": []string{"content", "content2"}, | |
"distance": float64(0.19), | |
}, | |
} | |
// when | |
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) | |
// then | |
require.Nil(t, err) | |
require.NotEmpty(t, out) | |
assert.Equal(t, 1, len(in)) | |
answer, answerOK := in[0].AdditionalProperties["answer"] | |
assert.True(t, answerOK) | |
assert.NotNil(t, answer) | |
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) | |
assert.True(t, answerAdditionalOK) | |
assert.True(t, answerAdditional.Result == nil) | |
assert.True(t, answerAdditional.Property == nil) | |
assert.True(t, answerAdditional.Certainty == nil) | |
assert.True(t, answerAdditional.Distance == nil) | |
assert.Equal(t, 0, answerAdditional.StartPosition) | |
assert.Equal(t, 0, answerAdditional.EndPosition) | |
assert.Equal(t, false, answerAdditional.HasAnswer) | |
}) | |
t.Run("should not answer with certainty set below ask certainty", func(t *testing.T) { | |
// given | |
qnaClient := &fakeQnAClient{} | |
fakeHelper := &fakeParamsHelper{} | |
answerProvider := New(qnaClient, fakeHelper) | |
in := []search.Result{ | |
{ | |
ID: "some-uuid", | |
Schema: map[string]interface{}{ | |
"content": "content with answer", | |
"content2": "this one is just a title", | |
}, | |
}, | |
} | |
fakeParams := &Params{} | |
limit := 1 | |
argumentModuleParams := map[string]interface{}{ | |
"ask": map[string]interface{}{ | |
"question": "question", | |
"properties": []string{"content", "content2"}, | |
"certainty": float64(0.81), | |
}, | |
} | |
// when | |
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) | |
// then | |
require.Nil(t, err) | |
require.NotEmpty(t, out) | |
assert.Equal(t, 1, len(in)) | |
answer, answerOK := in[0].AdditionalProperties["answer"] | |
assert.True(t, answerOK) | |
assert.NotNil(t, answer) | |
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) | |
assert.True(t, answerAdditionalOK) | |
assert.True(t, answerAdditional.Result == nil) | |
assert.True(t, answerAdditional.Property == nil) | |
assert.True(t, answerAdditional.Certainty == nil) | |
assert.True(t, answerAdditional.Distance == nil) | |
assert.Equal(t, 0, answerAdditional.StartPosition) | |
assert.Equal(t, 0, answerAdditional.EndPosition) | |
assert.Equal(t, false, answerAdditional.HasAnswer) | |
}) | |
t.Run("should answer with certainty set above ask certainty and the results should be reranked", func(t *testing.T) { | |
// given | |
qnaClient := &fakeQnAClient{} | |
fakeHelper := &fakeParamsHelper{} | |
answerProvider := New(qnaClient, fakeHelper) | |
in := []search.Result{ | |
{ | |
ID: "uuid1", | |
Schema: map[string]interface{}{ | |
"content": "rerank 0.5", | |
}, | |
}, | |
{ | |
ID: "uuid2", | |
Schema: map[string]interface{}{ | |
"content": "rerank 0.2", | |
}, | |
}, | |
{ | |
ID: "uuid3", | |
Schema: map[string]interface{}{ | |
"content": "rerank 0.9", | |
}, | |
}, | |
} | |
fakeParams := &Params{} | |
limit := 1 | |
argumentModuleParams := map[string]interface{}{ | |
"ask": map[string]interface{}{ | |
"question": "question", | |
"properties": []string{"content"}, | |
"rerank": true, | |
}, | |
} | |
// when | |
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) | |
// then | |
require.Nil(t, err) | |
require.NotEmpty(t, out) | |
assert.Equal(t, 3, len(in)) | |
answer, answerOK := in[0].AdditionalProperties["answer"] | |
assert.True(t, answerOK) | |
assert.NotNil(t, answer) | |
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) | |
assert.True(t, answerAdditionalOK) | |
assert.Equal(t, "rerank 0.9", *answerAdditional.Result) | |
assert.Equal(t, "content", *answerAdditional.Property) | |
assert.Equal(t, 0.9, *answerAdditional.Certainty) | |
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.9)), *answerAdditional.Distance) | |
assert.Equal(t, 0, answerAdditional.StartPosition) | |
assert.Equal(t, 10, answerAdditional.EndPosition) | |
assert.Equal(t, true, answerAdditional.HasAnswer) | |
answer, answerOK = in[1].AdditionalProperties["answer"] | |
assert.True(t, answerOK) | |
assert.NotNil(t, answer) | |
answerAdditional, answerAdditionalOK = answer.(*qnamodels.Answer) | |
assert.True(t, answerAdditionalOK) | |
assert.Equal(t, "rerank 0.5", *answerAdditional.Result) | |
assert.Equal(t, "content", *answerAdditional.Property) | |
assert.Equal(t, 0.5, *answerAdditional.Certainty) | |
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.5)), *answerAdditional.Distance) | |
assert.Equal(t, 0, answerAdditional.StartPosition) | |
assert.Equal(t, 10, answerAdditional.EndPosition) | |
assert.Equal(t, true, answerAdditional.HasAnswer) | |
answer, answerOK = in[2].AdditionalProperties["answer"] | |
assert.True(t, answerOK) | |
assert.NotNil(t, answer) | |
answerAdditional, answerAdditionalOK = answer.(*qnamodels.Answer) | |
assert.True(t, answerAdditionalOK) | |
assert.Equal(t, "rerank 0.2", *answerAdditional.Result) | |
assert.Equal(t, "content", *answerAdditional.Property) | |
assert.Equal(t, 0.2, *answerAdditional.Certainty) | |
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.2)), *answerAdditional.Distance) | |
assert.Equal(t, 0, answerAdditional.StartPosition) | |
assert.Equal(t, 10, answerAdditional.EndPosition) | |
assert.Equal(t, true, answerAdditional.HasAnswer) | |
}) | |
t.Run("should answer with certainty set above ask certainty and the results should not be reranked", func(t *testing.T) { | |
// given | |
qnaClient := &fakeQnAClient{} | |
fakeHelper := &fakeParamsHelper{} | |
answerProvider := New(qnaClient, fakeHelper) | |
in := []search.Result{ | |
{ | |
ID: "uuid1", | |
Schema: map[string]interface{}{ | |
"content": "rerank 0.5", | |
}, | |
}, | |
{ | |
ID: "uuid2", | |
Schema: map[string]interface{}{ | |
"content": "rerank 0.2", | |
}, | |
}, | |
{ | |
ID: "uuid3", | |
Schema: map[string]interface{}{ | |
"content": "rerank 0.9", | |
}, | |
}, | |
} | |
fakeParams := &Params{} | |
limit := 1 | |
argumentModuleParams := map[string]interface{}{ | |
"ask": map[string]interface{}{ | |
"question": "question", | |
"properties": []string{"content"}, | |
"rerank": false, | |
}, | |
} | |
// when | |
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) | |
// then | |
require.Nil(t, err) | |
require.NotEmpty(t, out) | |
assert.Equal(t, 3, len(in)) | |
answer, answerOK := in[0].AdditionalProperties["answer"] | |
assert.True(t, answerOK) | |
assert.NotNil(t, answer) | |
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) | |
assert.True(t, answerAdditionalOK) | |
assert.Equal(t, "rerank 0.5", *answerAdditional.Result) | |
assert.Equal(t, "content", *answerAdditional.Property) | |
assert.Equal(t, 0.5, *answerAdditional.Certainty) | |
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.5)), *answerAdditional.Distance) | |
assert.Equal(t, 0, answerAdditional.StartPosition) | |
assert.Equal(t, 10, answerAdditional.EndPosition) | |
assert.Equal(t, true, answerAdditional.HasAnswer) | |
answer, answerOK = in[1].AdditionalProperties["answer"] | |
assert.True(t, answerOK) | |
assert.NotNil(t, answer) | |
answerAdditional, answerAdditionalOK = answer.(*qnamodels.Answer) | |
assert.True(t, answerAdditionalOK) | |
assert.Equal(t, "rerank 0.2", *answerAdditional.Result) | |
assert.Equal(t, "content", *answerAdditional.Property) | |
assert.Equal(t, 0.2, *answerAdditional.Certainty) | |
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.2)), *answerAdditional.Distance) | |
assert.Equal(t, 0, answerAdditional.StartPosition) | |
assert.Equal(t, 10, answerAdditional.EndPosition) | |
assert.Equal(t, true, answerAdditional.HasAnswer) | |
answer, answerOK = in[2].AdditionalProperties["answer"] | |
assert.True(t, answerOK) | |
assert.NotNil(t, answer) | |
answerAdditional, answerAdditionalOK = answer.(*qnamodels.Answer) | |
assert.True(t, answerAdditionalOK) | |
assert.Equal(t, "rerank 0.9", *answerAdditional.Result) | |
assert.Equal(t, "content", *answerAdditional.Property) | |
assert.Equal(t, 0.9, *answerAdditional.Certainty) | |
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.9)), *answerAdditional.Distance) | |
assert.Equal(t, 0, answerAdditional.StartPosition) | |
assert.Equal(t, 10, answerAdditional.EndPosition) | |
assert.Equal(t, true, answerAdditional.HasAnswer) | |
}) | |
} | |
type fakeQnAClient struct{} | |
func (c *fakeQnAClient) Answer(ctx context.Context, | |
text, question string, | |
) (*ent.AnswerResult, error) { | |
if text == "rerank 0.9" { | |
return c.getAnswer(question, "rerank 0.9", 0.9), nil | |
} | |
if text == "rerank 0.5" { | |
return c.getAnswer(question, "rerank 0.5", 0.5), nil | |
} | |
if text == "rerank 0.2" { | |
return c.getAnswer(question, "rerank 0.2", 0.2), nil | |
} | |
return c.getAnswer(question, "answer", 0.8), nil | |
} | |
func (c *fakeQnAClient) getAnswer(question, answer string, certainty float64) *ent.AnswerResult { | |
return &ent.AnswerResult{ | |
Text: question, | |
Question: question, | |
Answer: &answer, | |
Certainty: &certainty, | |
Distance: additional.CertaintyToDistPtr(&certainty), | |
} | |
} | |
// fakeParamsHelper is a test double for the provider's params helper. Each
// getter expects params to be the raw "ask" argument, a
// map[string]interface{}, and yields the zero value of its result type when
// params is not such a map, or the key is missing or holds another type.
type fakeParamsHelper struct{}

// askMap returns params as a map, or nil when it is not a
// map[string]interface{}; lookups on a nil map simply find nothing.
func (h *fakeParamsHelper) askMap(params interface{}) map[string]interface{} {
	m, _ := params.(map[string]interface{})
	return m
}

// GetQuestion returns the "question" string, or "".
func (h *fakeParamsHelper) GetQuestion(params interface{}) string {
	question, _ := h.askMap(params)["question"].(string)
	return question
}

// GetProperties returns the "properties" []string, or nil.
func (h *fakeParamsHelper) GetProperties(params interface{}) []string {
	properties, _ := h.askMap(params)["properties"].([]string)
	return properties
}

// GetCertainty returns the "certainty" float64, or 0.
func (h *fakeParamsHelper) GetCertainty(params interface{}) float64 {
	certainty, _ := h.askMap(params)["certainty"].(float64)
	return certainty
}

// GetDistance returns the "distance" float64, or 0.
func (h *fakeParamsHelper) GetDistance(params interface{}) float64 {
	distance, _ := h.askMap(params)["distance"].(float64)
	return distance
}

// GetRerank returns the "rerank" bool, or false.
func (h *fakeParamsHelper) GetRerank(params interface{}) bool {
	rerank, _ := h.askMap(params)["rerank"].(bool)
	return rerank
}
// ptFloat returns a pointer to its own copy of f, for passing literals to
// APIs that take *float64 such as additional.CertaintyToDistPtr.
func ptFloat(f float64) *float64 {
	p := new(float64)
	*p = f
	return p
}