Spaces:
Running
Running
// _ _ | |
// __ _____ __ ___ ___ __ _| |_ ___ | |
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ | |
// \ V V / __/ (_| |\ V /| | (_| | || __/ | |
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| | |
// | |
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. | |
// | |
// CONTACT: [email protected] | |
// | |
package clients | |
import ( | |
"net/http" | |
"net/http/httptest" | |
"testing" | |
"time" | |
"github.com/stretchr/testify/assert" | |
) | |
func TestGetMeta(t *testing.T) { | |
t.Run("when common server is providing meta", func(t *testing.T) { | |
server := httptest.NewServer(&testMetaHandler{t: t}) | |
defer server.Close() | |
v := New(server.URL, server.URL, 0, nullLogger()) | |
meta, err := v.MetaInfo() | |
assert.Nil(t, err) | |
assert.NotNil(t, meta) | |
model := extractChildMap(t, meta, "model") | |
assert.NotNil(t, model["_name_or_path"]) | |
assert.NotNil(t, model["architectures"]) | |
assert.Contains(t, model["architectures"], "DistilBertModel") | |
ID2Label := extractChildMap(t, model, "id2label") | |
assert.NotNil(t, ID2Label["0"]) | |
assert.NotNil(t, ID2Label["1"]) | |
}) | |
t.Run("when passage and query servers are providing meta", func(t *testing.T) { | |
serverPassage := httptest.NewServer(&testMetaHandler{t: t, modelType: "passage"}) | |
serverQuery := httptest.NewServer(&testMetaHandler{t: t, modelType: "query"}) | |
defer serverPassage.Close() | |
defer serverQuery.Close() | |
v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger()) | |
meta, err := v.MetaInfo() | |
assert.Nil(t, err) | |
assert.NotNil(t, meta) | |
passage := extractChildMap(t, meta, "passage") | |
passageModel := extractChildMap(t, passage, "model") | |
assert.NotNil(t, passageModel["_name_or_path"]) | |
assert.NotNil(t, passageModel["architectures"]) | |
assert.Contains(t, passageModel["architectures"], "DPRContextEncoder") | |
passageID2Label := extractChildMap(t, passageModel, "id2label") | |
assert.NotNil(t, passageID2Label["0"]) | |
assert.NotNil(t, passageID2Label["1"]) | |
query := extractChildMap(t, meta, "query") | |
queryModel := extractChildMap(t, query, "model") | |
assert.NotNil(t, queryModel["_name_or_path"]) | |
assert.NotNil(t, queryModel["architectures"]) | |
assert.Contains(t, queryModel["architectures"], "DPRQuestionEncoder") | |
queryID2Label := extractChildMap(t, queryModel, "id2label") | |
assert.NotNil(t, queryID2Label["0"]) | |
assert.NotNil(t, queryID2Label["1"]) | |
}) | |
t.Run("when passage and query servers are unavailable", func(t *testing.T) { | |
rt := time.Now().Add(time.Hour) | |
serverPassage := httptest.NewServer(&testMetaHandler{t: t, modelType: "passage", readyTime: rt}) | |
serverQuery := httptest.NewServer(&testMetaHandler{t: t, modelType: "query", readyTime: rt}) | |
defer serverPassage.Close() | |
defer serverQuery.Close() | |
v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger()) | |
meta, err := v.MetaInfo() | |
assert.NotNil(t, err) | |
assert.Contains(t, err.Error(), "[passage] unexpected status code '503' of meta request") | |
assert.Contains(t, err.Error(), "[query] unexpected status code '503' of meta request") | |
assert.Nil(t, meta) | |
}) | |
} | |
type testMetaHandler struct { | |
t *testing.T | |
// the test handler will report as not ready before the time has passed | |
readyTime time.Time | |
modelType string | |
} | |
func (h *testMetaHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |
assert.Equal(h.t, "/meta", r.URL.String()) | |
assert.Equal(h.t, http.MethodGet, r.Method) | |
if time.Since(h.readyTime) < 0 { | |
w.WriteHeader(http.StatusServiceUnavailable) | |
return | |
} | |
w.Write([]byte(h.metaInfo())) | |
} | |
func (h *testMetaHandler) metaInfo() string { | |
switch h.modelType { | |
case "passage": | |
return `{ | |
"model": { | |
"return_dict": true, | |
"output_hidden_states": false, | |
"output_attentions": false, | |
"torchscript": false, | |
"torch_dtype": "float32", | |
"use_bfloat16": false, | |
"pruned_heads": {}, | |
"tie_word_embeddings": true, | |
"is_encoder_decoder": false, | |
"is_decoder": false, | |
"cross_attention_hidden_size": null, | |
"add_cross_attention": false, | |
"tie_encoder_decoder": false, | |
"max_length": 20, | |
"min_length": 0, | |
"do_sample": false, | |
"early_stopping": false, | |
"num_beams": 1, | |
"num_beam_groups": 1, | |
"diversity_penalty": 0, | |
"temperature": 1, | |
"top_k": 50, | |
"top_p": 1, | |
"repetition_penalty": 1, | |
"length_penalty": 1, | |
"no_repeat_ngram_size": 0, | |
"encoder_no_repeat_ngram_size": 0, | |
"bad_words_ids": null, | |
"num_return_sequences": 1, | |
"chunk_size_feed_forward": 0, | |
"output_scores": false, | |
"return_dict_in_generate": false, | |
"forced_bos_token_id": null, | |
"forced_eos_token_id": null, | |
"remove_invalid_values": false, | |
"architectures": [ | |
"DPRContextEncoder" | |
], | |
"finetuning_task": null, | |
"id2label": { | |
"0": "LABEL_0", | |
"1": "LABEL_1" | |
}, | |
"label2id": { | |
"LABEL_0": 0, | |
"LABEL_1": 1 | |
}, | |
"tokenizer_class": null, | |
"prefix": null, | |
"bos_token_id": null, | |
"pad_token_id": 0, | |
"eos_token_id": null, | |
"sep_token_id": null, | |
"decoder_start_token_id": null, | |
"task_specific_params": null, | |
"problem_type": null, | |
"_name_or_path": "./models/model", | |
"transformers_version": "4.16.2", | |
"gradient_checkpointing": false, | |
"model_type": "dpr", | |
"vocab_size": 30522, | |
"hidden_size": 768, | |
"num_hidden_layers": 12, | |
"num_attention_heads": 12, | |
"hidden_act": "gelu", | |
"intermediate_size": 3072, | |
"hidden_dropout_prob": 0.1, | |
"attention_probs_dropout_prob": 0.1, | |
"max_position_embeddings": 512, | |
"type_vocab_size": 2, | |
"initializer_range": 0.02, | |
"layer_norm_eps": 1e-12, | |
"projection_dim": 0, | |
"position_embedding_type": "absolute" | |
} | |
}` | |
case "query": | |
return `{ | |
"model": { | |
"return_dict": true, | |
"output_hidden_states": false, | |
"output_attentions": false, | |
"torchscript": false, | |
"torch_dtype": "float32", | |
"use_bfloat16": false, | |
"pruned_heads": {}, | |
"tie_word_embeddings": true, | |
"is_encoder_decoder": false, | |
"is_decoder": false, | |
"cross_attention_hidden_size": null, | |
"add_cross_attention": false, | |
"tie_encoder_decoder": false, | |
"max_length": 20, | |
"min_length": 0, | |
"do_sample": false, | |
"early_stopping": false, | |
"num_beams": 1, | |
"num_beam_groups": 1, | |
"diversity_penalty": 0, | |
"temperature": 1, | |
"top_k": 50, | |
"top_p": 1, | |
"repetition_penalty": 1, | |
"length_penalty": 1, | |
"no_repeat_ngram_size": 0, | |
"encoder_no_repeat_ngram_size": 0, | |
"bad_words_ids": null, | |
"num_return_sequences": 1, | |
"chunk_size_feed_forward": 0, | |
"output_scores": false, | |
"return_dict_in_generate": false, | |
"forced_bos_token_id": null, | |
"forced_eos_token_id": null, | |
"remove_invalid_values": false, | |
"architectures": [ | |
"DPRQuestionEncoder" | |
], | |
"finetuning_task": null, | |
"id2label": { | |
"0": "LABEL_0", | |
"1": "LABEL_1" | |
}, | |
"label2id": { | |
"LABEL_0": 0, | |
"LABEL_1": 1 | |
}, | |
"tokenizer_class": null, | |
"prefix": null, | |
"bos_token_id": null, | |
"pad_token_id": 0, | |
"eos_token_id": null, | |
"sep_token_id": null, | |
"decoder_start_token_id": null, | |
"task_specific_params": null, | |
"problem_type": null, | |
"_name_or_path": "./models/model", | |
"transformers_version": "4.16.2", | |
"gradient_checkpointing": false, | |
"model_type": "dpr", | |
"vocab_size": 30522, | |
"hidden_size": 768, | |
"num_hidden_layers": 12, | |
"num_attention_heads": 12, | |
"hidden_act": "gelu", | |
"intermediate_size": 3072, | |
"hidden_dropout_prob": 0.1, | |
"attention_probs_dropout_prob": 0.1, | |
"max_position_embeddings": 512, | |
"type_vocab_size": 2, | |
"initializer_range": 0.02, | |
"layer_norm_eps": 1e-12, | |
"projection_dim": 0, | |
"position_embedding_type": "absolute" | |
} | |
}` | |
default: | |
return `{ | |
"model": { | |
"_name_or_path": "distilbert-base-uncased", | |
"activation": "gelu", | |
"add_cross_attention": false, | |
"architectures": [ | |
"DistilBertModel" | |
], | |
"attention_dropout": 0.1, | |
"bad_words_ids": null, | |
"bos_token_id": null, | |
"chunk_size_feed_forward": 0, | |
"decoder_start_token_id": null, | |
"dim": 768, | |
"diversity_penalty": 0, | |
"do_sample": false, | |
"dropout": 0.1, | |
"early_stopping": false, | |
"encoder_no_repeat_ngram_size": 0, | |
"eos_token_id": null, | |
"finetuning_task": null, | |
"hidden_dim": 3072, | |
"id2label": { | |
"0": "LABEL_0", | |
"1": "LABEL_1" | |
}, | |
"initializer_range": 0.02, | |
"is_decoder": false, | |
"is_encoder_decoder": false, | |
"label2id": { | |
"LABEL_0": 0, | |
"LABEL_1": 1 | |
}, | |
"length_penalty": 1, | |
"max_length": 20, | |
"max_position_embeddings": 512, | |
"min_length": 0, | |
"model_type": "distilbert", | |
"n_heads": 12, | |
"n_layers": 6, | |
"no_repeat_ngram_size": 0, | |
"num_beam_groups": 1, | |
"num_beams": 1, | |
"num_return_sequences": 1, | |
"output_attentions": false, | |
"output_hidden_states": false, | |
"output_scores": false, | |
"pad_token_id": 0, | |
"prefix": null, | |
"pruned_heads": {}, | |
"qa_dropout": 0.1, | |
"repetition_penalty": 1, | |
"return_dict": true, | |
"return_dict_in_generate": false, | |
"sep_token_id": null, | |
"seq_classif_dropout": 0.2, | |
"sinusoidal_pos_embds": false, | |
"task_specific_params": null, | |
"temperature": 1, | |
"tie_encoder_decoder": false, | |
"tie_weights_": true, | |
"tie_word_embeddings": true, | |
"tokenizer_class": null, | |
"top_k": 50, | |
"top_p": 1, | |
"torchscript": false, | |
"transformers_version": "4.3.2", | |
"use_bfloat16": false, | |
"vocab_size": 30522, | |
"xla_device": null | |
} | |
}` | |
} | |
} | |
func extractChildMap(t *testing.T, parent map[string]interface{}, name string) map[string]interface{} { | |
assert.NotNil(t, parent[name]) | |
child, ok := parent[name].(map[string]interface{}) | |
assert.True(t, ok) | |
assert.NotNil(t, child) | |
return child | |
} | |