KevinStephenson
Adding in weaviate code
b110593
raw
history blame
18.7 kB
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: [email protected]
//
package answer
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate/entities/additional"
"github.com/weaviate/weaviate/entities/search"
qnamodels "github.com/weaviate/weaviate/modules/qna-transformers/additional/models"
"github.com/weaviate/weaviate/modules/qna-transformers/ent"
)
func TestAdditionalAnswerProvider(t *testing.T) {
t.Run("should fail with empty content", func(t *testing.T) {
// given
qnaClient := &fakeQnAClient{}
fakeHelper := &fakeParamsHelper{}
answerProvider := New(qnaClient, fakeHelper)
in := []search.Result{
{
ID: "some-uuid",
},
}
fakeParams := &Params{}
limit := 1
argumentModuleParams := map[string]interface{}{}
// when
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
// then
require.NotNil(t, err)
require.NotEmpty(t, out)
assert.Error(t, err, "empty content")
})
t.Run("should fail with empty question", func(t *testing.T) {
// given
qnaClient := &fakeQnAClient{}
fakeHelper := &fakeParamsHelper{}
answerProvider := New(qnaClient, fakeHelper)
in := []search.Result{
{
ID: "some-uuid",
Schema: map[string]interface{}{
"content": "content",
},
},
}
fakeParams := &Params{}
limit := 1
argumentModuleParams := map[string]interface{}{}
// when
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
// then
require.NotNil(t, err)
require.NotEmpty(t, out)
assert.Error(t, err, "empty content")
})
t.Run("should answer", func(t *testing.T) {
// given
qnaClient := &fakeQnAClient{}
fakeHelper := &fakeParamsHelper{}
answerProvider := New(qnaClient, fakeHelper)
in := []search.Result{
{
ID: "some-uuid",
Schema: map[string]interface{}{
"content": "content",
},
},
}
fakeParams := &Params{}
limit := 1
argumentModuleParams := map[string]interface{}{
"ask": map[string]interface{}{
"question": "question",
},
}
// when
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
// then
require.Nil(t, err)
require.NotEmpty(t, out)
assert.Equal(t, 1, len(in))
answer, answerOK := in[0].AdditionalProperties["answer"]
assert.True(t, answerOK)
assert.NotNil(t, answer)
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer)
assert.True(t, answerAdditionalOK)
assert.Equal(t, "answer", *answerAdditional.Result)
})
t.Run("should answer with property", func(t *testing.T) {
// given
qnaClient := &fakeQnAClient{}
fakeHelper := &fakeParamsHelper{}
answerProvider := New(qnaClient, fakeHelper)
in := []search.Result{
{
ID: "some-uuid",
Schema: map[string]interface{}{
"content": "content with answer",
"content2": "this one is just a title",
},
},
}
fakeParams := &Params{}
limit := 1
argumentModuleParams := map[string]interface{}{
"ask": map[string]interface{}{
"question": "question",
"properties": []string{"content", "content2"},
},
}
// when
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
// then
require.Nil(t, err)
require.NotEmpty(t, out)
assert.Equal(t, 1, len(in))
answer, answerOK := in[0].AdditionalProperties["answer"]
assert.True(t, answerOK)
assert.NotNil(t, answer)
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer)
assert.True(t, answerAdditionalOK)
assert.Equal(t, "answer", *answerAdditional.Result)
assert.Equal(t, "content", *answerAdditional.Property)
assert.Equal(t, 0.8, *answerAdditional.Certainty)
assert.InDelta(t, 0.4, *answerAdditional.Distance, 1e-9)
assert.Equal(t, 13, answerAdditional.StartPosition)
assert.Equal(t, 19, answerAdditional.EndPosition)
assert.Equal(t, true, answerAdditional.HasAnswer)
})
t.Run("should answer with similarity set above ask distance", func(t *testing.T) {
// given
qnaClient := &fakeQnAClient{}
fakeHelper := &fakeParamsHelper{}
answerProvider := New(qnaClient, fakeHelper)
in := []search.Result{
{
ID: "some-uuid",
Schema: map[string]interface{}{
"content": "content with answer",
"content2": "this one is just a title",
},
},
}
fakeParams := &Params{}
limit := 1
argumentModuleParams := map[string]interface{}{
"ask": map[string]interface{}{
"question": "question",
"properties": []string{"content", "content2"},
"distance": float64(0.4),
},
}
// when
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
// then
require.Nil(t, err)
require.NotEmpty(t, out)
assert.Equal(t, 1, len(out))
answer, answerOK := out[0].AdditionalProperties["answer"]
assert.True(t, answerOK)
assert.NotNil(t, answer)
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer)
assert.True(t, answerAdditionalOK)
assert.Equal(t, "answer", *answerAdditional.Result)
assert.Equal(t, "content", *answerAdditional.Property)
assert.Equal(t, 0.8, *answerAdditional.Certainty)
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.8)), *answerAdditional.Distance)
assert.Equal(t, 13, answerAdditional.StartPosition)
assert.Equal(t, 19, answerAdditional.EndPosition)
assert.Equal(t, true, answerAdditional.HasAnswer)
})
t.Run("should answer with similarity set above ask certainty", func(t *testing.T) {
// given
qnaClient := &fakeQnAClient{}
fakeHelper := &fakeParamsHelper{}
answerProvider := New(qnaClient, fakeHelper)
in := []search.Result{
{
ID: "some-uuid",
Schema: map[string]interface{}{
"content": "content with answer",
"content2": "this one is just a title",
},
},
}
fakeParams := &Params{}
limit := 1
argumentModuleParams := map[string]interface{}{
"ask": map[string]interface{}{
"question": "question",
"properties": []string{"content", "content2"},
"certainty": float64(0.8),
},
}
// when
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
// then
require.Nil(t, err)
require.NotEmpty(t, out)
assert.Equal(t, 1, len(out))
answer, answerOK := out[0].AdditionalProperties["answer"]
assert.True(t, answerOK)
assert.NotNil(t, answer)
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer)
assert.True(t, answerAdditionalOK)
assert.Equal(t, "answer", *answerAdditional.Result)
assert.Equal(t, "content", *answerAdditional.Property)
assert.Equal(t, 0.8, *answerAdditional.Certainty)
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.8)), *answerAdditional.Distance)
assert.Equal(t, 13, answerAdditional.StartPosition)
assert.Equal(t, 19, answerAdditional.EndPosition)
assert.Equal(t, true, answerAdditional.HasAnswer)
})
t.Run("should not answer with distance set below ask distance", func(t *testing.T) {
// given
qnaClient := &fakeQnAClient{}
fakeHelper := &fakeParamsHelper{}
answerProvider := New(qnaClient, fakeHelper)
in := []search.Result{
{
ID: "some-uuid",
Schema: map[string]interface{}{
"content": "content with answer",
"content2": "this one is just a title",
},
},
}
fakeParams := &Params{}
limit := 1
argumentModuleParams := map[string]interface{}{
"ask": map[string]interface{}{
"question": "question",
"properties": []string{"content", "content2"},
"distance": float64(0.19),
},
}
// when
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
// then
require.Nil(t, err)
require.NotEmpty(t, out)
assert.Equal(t, 1, len(in))
answer, answerOK := in[0].AdditionalProperties["answer"]
assert.True(t, answerOK)
assert.NotNil(t, answer)
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer)
assert.True(t, answerAdditionalOK)
assert.True(t, answerAdditional.Result == nil)
assert.True(t, answerAdditional.Property == nil)
assert.True(t, answerAdditional.Certainty == nil)
assert.True(t, answerAdditional.Distance == nil)
assert.Equal(t, 0, answerAdditional.StartPosition)
assert.Equal(t, 0, answerAdditional.EndPosition)
assert.Equal(t, false, answerAdditional.HasAnswer)
})
t.Run("should not answer with certainty set below ask certainty", func(t *testing.T) {
// given
qnaClient := &fakeQnAClient{}
fakeHelper := &fakeParamsHelper{}
answerProvider := New(qnaClient, fakeHelper)
in := []search.Result{
{
ID: "some-uuid",
Schema: map[string]interface{}{
"content": "content with answer",
"content2": "this one is just a title",
},
},
}
fakeParams := &Params{}
limit := 1
argumentModuleParams := map[string]interface{}{
"ask": map[string]interface{}{
"question": "question",
"properties": []string{"content", "content2"},
"certainty": float64(0.81),
},
}
// when
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
// then
require.Nil(t, err)
require.NotEmpty(t, out)
assert.Equal(t, 1, len(in))
answer, answerOK := in[0].AdditionalProperties["answer"]
assert.True(t, answerOK)
assert.NotNil(t, answer)
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer)
assert.True(t, answerAdditionalOK)
assert.True(t, answerAdditional.Result == nil)
assert.True(t, answerAdditional.Property == nil)
assert.True(t, answerAdditional.Certainty == nil)
assert.True(t, answerAdditional.Distance == nil)
assert.Equal(t, 0, answerAdditional.StartPosition)
assert.Equal(t, 0, answerAdditional.EndPosition)
assert.Equal(t, false, answerAdditional.HasAnswer)
})
t.Run("should answer with certainty set above ask certainty and the results should be reranked", func(t *testing.T) {
// given
qnaClient := &fakeQnAClient{}
fakeHelper := &fakeParamsHelper{}
answerProvider := New(qnaClient, fakeHelper)
in := []search.Result{
{
ID: "uuid1",
Schema: map[string]interface{}{
"content": "rerank 0.5",
},
},
{
ID: "uuid2",
Schema: map[string]interface{}{
"content": "rerank 0.2",
},
},
{
ID: "uuid3",
Schema: map[string]interface{}{
"content": "rerank 0.9",
},
},
}
fakeParams := &Params{}
limit := 1
argumentModuleParams := map[string]interface{}{
"ask": map[string]interface{}{
"question": "question",
"properties": []string{"content"},
"rerank": true,
},
}
// when
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
// then
require.Nil(t, err)
require.NotEmpty(t, out)
assert.Equal(t, 3, len(in))
answer, answerOK := in[0].AdditionalProperties["answer"]
assert.True(t, answerOK)
assert.NotNil(t, answer)
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer)
assert.True(t, answerAdditionalOK)
assert.Equal(t, "rerank 0.9", *answerAdditional.Result)
assert.Equal(t, "content", *answerAdditional.Property)
assert.Equal(t, 0.9, *answerAdditional.Certainty)
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.9)), *answerAdditional.Distance)
assert.Equal(t, 0, answerAdditional.StartPosition)
assert.Equal(t, 10, answerAdditional.EndPosition)
assert.Equal(t, true, answerAdditional.HasAnswer)
answer, answerOK = in[1].AdditionalProperties["answer"]
assert.True(t, answerOK)
assert.NotNil(t, answer)
answerAdditional, answerAdditionalOK = answer.(*qnamodels.Answer)
assert.True(t, answerAdditionalOK)
assert.Equal(t, "rerank 0.5", *answerAdditional.Result)
assert.Equal(t, "content", *answerAdditional.Property)
assert.Equal(t, 0.5, *answerAdditional.Certainty)
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.5)), *answerAdditional.Distance)
assert.Equal(t, 0, answerAdditional.StartPosition)
assert.Equal(t, 10, answerAdditional.EndPosition)
assert.Equal(t, true, answerAdditional.HasAnswer)
answer, answerOK = in[2].AdditionalProperties["answer"]
assert.True(t, answerOK)
assert.NotNil(t, answer)
answerAdditional, answerAdditionalOK = answer.(*qnamodels.Answer)
assert.True(t, answerAdditionalOK)
assert.Equal(t, "rerank 0.2", *answerAdditional.Result)
assert.Equal(t, "content", *answerAdditional.Property)
assert.Equal(t, 0.2, *answerAdditional.Certainty)
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.2)), *answerAdditional.Distance)
assert.Equal(t, 0, answerAdditional.StartPosition)
assert.Equal(t, 10, answerAdditional.EndPosition)
assert.Equal(t, true, answerAdditional.HasAnswer)
})
t.Run("should answer with certainty set above ask certainty and the results should not be reranked", func(t *testing.T) {
// given
qnaClient := &fakeQnAClient{}
fakeHelper := &fakeParamsHelper{}
answerProvider := New(qnaClient, fakeHelper)
in := []search.Result{
{
ID: "uuid1",
Schema: map[string]interface{}{
"content": "rerank 0.5",
},
},
{
ID: "uuid2",
Schema: map[string]interface{}{
"content": "rerank 0.2",
},
},
{
ID: "uuid3",
Schema: map[string]interface{}{
"content": "rerank 0.9",
},
},
}
fakeParams := &Params{}
limit := 1
argumentModuleParams := map[string]interface{}{
"ask": map[string]interface{}{
"question": "question",
"properties": []string{"content"},
"rerank": false,
},
}
// when
out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
// then
require.Nil(t, err)
require.NotEmpty(t, out)
assert.Equal(t, 3, len(in))
answer, answerOK := in[0].AdditionalProperties["answer"]
assert.True(t, answerOK)
assert.NotNil(t, answer)
answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer)
assert.True(t, answerAdditionalOK)
assert.Equal(t, "rerank 0.5", *answerAdditional.Result)
assert.Equal(t, "content", *answerAdditional.Property)
assert.Equal(t, 0.5, *answerAdditional.Certainty)
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.5)), *answerAdditional.Distance)
assert.Equal(t, 0, answerAdditional.StartPosition)
assert.Equal(t, 10, answerAdditional.EndPosition)
assert.Equal(t, true, answerAdditional.HasAnswer)
answer, answerOK = in[1].AdditionalProperties["answer"]
assert.True(t, answerOK)
assert.NotNil(t, answer)
answerAdditional, answerAdditionalOK = answer.(*qnamodels.Answer)
assert.True(t, answerAdditionalOK)
assert.Equal(t, "rerank 0.2", *answerAdditional.Result)
assert.Equal(t, "content", *answerAdditional.Property)
assert.Equal(t, 0.2, *answerAdditional.Certainty)
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.2)), *answerAdditional.Distance)
assert.Equal(t, 0, answerAdditional.StartPosition)
assert.Equal(t, 10, answerAdditional.EndPosition)
assert.Equal(t, true, answerAdditional.HasAnswer)
answer, answerOK = in[2].AdditionalProperties["answer"]
assert.True(t, answerOK)
assert.NotNil(t, answer)
answerAdditional, answerAdditionalOK = answer.(*qnamodels.Answer)
assert.True(t, answerAdditionalOK)
assert.Equal(t, "rerank 0.9", *answerAdditional.Result)
assert.Equal(t, "content", *answerAdditional.Property)
assert.Equal(t, 0.9, *answerAdditional.Certainty)
assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.9)), *answerAdditional.Distance)
assert.Equal(t, 0, answerAdditional.StartPosition)
assert.Equal(t, 10, answerAdditional.EndPosition)
assert.Equal(t, true, answerAdditional.HasAnswer)
})
}
// fakeQnAClient is a stub QnA client for the tests in this file. Texts
// "rerank 0.9", "rerank 0.5" and "rerank 0.2" are echoed back as the answer
// with the matching certainty; any other text yields "answer" at 0.8.
type fakeQnAClient struct{}

// Answer returns a canned result selected by exact match on text. The
// context is ignored and the error is always nil.
func (c *fakeQnAClient) Answer(ctx context.Context,
	text, question string,
) (*ent.AnswerResult, error) {
	switch text {
	case "rerank 0.9":
		return c.getAnswer(question, text, 0.9), nil
	case "rerank 0.5":
		return c.getAnswer(question, text, 0.5), nil
	case "rerank 0.2":
		return c.getAnswer(question, text, 0.2), nil
	default:
		return c.getAnswer(question, "answer", 0.8), nil
	}
}

// getAnswer builds an AnswerResult whose Distance is derived from certainty
// via additional.CertaintyToDistPtr.
// NOTE(review): Text is filled with the question, not the answered text —
// confirm nothing under test reads AnswerResult.Text.
func (c *fakeQnAClient) getAnswer(question, answer string, certainty float64) *ent.AnswerResult {
	result := &ent.AnswerResult{Text: question, Question: question}
	result.Answer = &answer
	result.Certainty = &certainty
	result.Distance = additional.CertaintyToDistPtr(&certainty)
	return result
}
// fakeParamsHelper is a stub params helper that reads the "ask" module
// argument as a plain map[string]interface{}. A non-map argument, a missing
// key, or a value of an unexpected type yields the zero value.
type fakeParamsHelper struct{}

// lookup returns params[key] when params is a map[string]interface{}, and
// nil otherwise.
func (h *fakeParamsHelper) lookup(params interface{}, key string) interface{} {
	m, isMap := params.(map[string]interface{})
	if !isMap {
		return nil
	}
	return m[key]
}

// GetQuestion returns the "question" string, or "".
func (h *fakeParamsHelper) GetQuestion(params interface{}) string {
	q, _ := h.lookup(params, "question").(string)
	return q
}

// GetProperties returns the "properties" []string, or nil.
func (h *fakeParamsHelper) GetProperties(params interface{}) []string {
	props, _ := h.lookup(params, "properties").([]string)
	return props
}

// GetCertainty returns the "certainty" float64, or 0.
func (h *fakeParamsHelper) GetCertainty(params interface{}) float64 {
	c, _ := h.lookup(params, "certainty").(float64)
	return c
}

// GetDistance returns the "distance" float64, or 0.
func (h *fakeParamsHelper) GetDistance(params interface{}) float64 {
	d, _ := h.lookup(params, "distance").(float64)
	return d
}

// GetRerank returns the "rerank" bool, or false.
func (h *fakeParamsHelper) GetRerank(params interface{}) bool {
	r, _ := h.lookup(params, "rerank").(bool)
	return r
}
// ptFloat returns a pointer to a freshly allocated copy of f.
func ptFloat(f float64) *float64 {
	p := new(float64)
	*p = f
	return p
}