# Unit tests for GPP markdown parsing/chunking and the QA retrieval,
# reranking, and answer-generation pipeline.
import json
import os

import numpy as np
import pytest
import torch

from src.gpp import GPP, GPPConfig, parse_markdown_table
from src.qa import AnswerGenerator, Reranker, RerankerConfig, Retriever, RetrieverConfig
from src.utils import LLMClient
# --- Tests for parse_markdown_table ---
def test_parse_markdown_table_valid():
    """A well-formed markdown table is parsed into headers and data rows."""
    markdown = "\n".join([
        "",
        "|h1|h2|",
        "|--|--|",
        "|a|b|",
        "|c|d|",
        "",
    ])
    parsed = parse_markdown_table(markdown)
    assert parsed['headers'] == ['h1', 'h2']
    assert parsed['rows'] == [['a', 'b'], ['c', 'd']]
def test_parse_markdown_table_invalid():
    """Input that is not a markdown table is rejected with None."""
    assert parse_markdown_table("not a table") is None
# --- Tests for GPP.chunk_blocks ---
class DummyGPPConfig(GPPConfig):
    """GPP config with a tiny token threshold so chunk splits trigger in tests."""

    CHUNK_TOKEN_SIZE = 4  # small threshold for testing
@pytest.fixture
def gpp():
    """GPP instance built with the small test chunk size.

    Fix: this is consumed as a fixture parameter by
    test_chunk_blocks_table_isolation, but was missing the
    @pytest.fixture decorator, so pytest reported "fixture 'gpp' not found".
    """
    return GPP(DummyGPPConfig())
@pytest.fixture
def blocks():
    """Three sample blocks: a 4-token text, a one-cell table, a 2-token text.

    Fix: consumed as a fixture parameter by test_chunk_blocks_table_isolation
    but was missing the @pytest.fixture decorator, so the test could not run.
    """
    return [
        {'type': 'text', 'text': 'one two three four'},
        {'type': 'table', 'text': '|h|\n|-|\n|v|'},
        {'type': 'text', 'text': 'five six'},
    ]
def test_chunk_blocks_table_isolation(gpp, blocks):
    """Tables must land in their own chunk and carry a parsed structure."""
    chunks = gpp.chunk_blocks(blocks)
    # Expect 3 chunks: one text (4 tokens), one table, one text (2 tokens)
    assert len(chunks) == 3
    table_chunk = chunks[1]
    assert table_chunk['type'] == 'table'
    assert 'table_structure' in table_chunk
# --- Tests for Retriever.retrieve combining sparse & dense ---
def test_retriever_combine_unique(monkeypatch):
    """Sparse and dense hits are merged in order with duplicates dropped."""
    chunks = [{'narration': 'a'}, {'narration': 'b'}, {'narration': 'c'}]
    retr = Retriever(chunks, RetrieverConfig())
    # Stub both retrieval strategies with overlapping result lists.
    monkeypatch.setattr(Retriever, 'retrieve_sparse', lambda self, q, top_k: chunks[:2])
    monkeypatch.setattr(Retriever, 'retrieve_dense', lambda self, q, top_k: chunks[1:])
    # The shared middle chunk must appear exactly once in the combined output.
    assert retr.retrieve('query', top_k=2) == chunks
# --- Tests for Reranker.rerank with dummy model and tokenizer ---
class DummyTokenizer:
    """Stand-in tokenizer: emits one token per example, ignoring the text."""

    def __call__(self, queries, contexts, padding, truncation, return_tensors):
        n = len(queries)
        return {
            'input_ids': torch.ones((n, 1), dtype=torch.long),
            'attention_mask': torch.ones((n, 1), dtype=torch.long),
        }
class DummyModel:
    """Stand-in cross-encoder: for a 2-item batch, the second scores higher."""

    def __init__(self):
        pass

    def to(self, device):
        # Device placement is a no-op for the dummy.
        return self

    def __call__(self, **kwargs):
        n = kwargs['input_ids'].shape[0]
        if n == 2:
            # Fixed logits: second candidate more relevant than the first.
            logits = torch.tensor([[0.1], [0.9]])
        else:
            logits = torch.rand((n, 1))
        return type('Out', (), {'logits': logits})
@pytest.fixture
def dummy_pretrained(monkeypatch):
    """Replace the HF from_pretrained loaders with in-memory dummies.

    Fix: this was a plain function that nothing called — it lacked
    @pytest.fixture and test_reranker_order never requested it, so building
    a Reranker would try to download the real tokenizer/model.
    """
    import transformers
    monkeypatch.setattr(transformers.AutoTokenizer, 'from_pretrained', lambda name: DummyTokenizer())
    monkeypatch.setattr(transformers.AutoModelForSequenceClassification, 'from_pretrained', lambda name: DummyModel())


def test_reranker_order(dummy_pretrained):
    """rerank orders candidates by dummy-model score: B above A."""
    config = RerankerConfig()
    rer = Reranker(config)
    candidates = [{'narration': 'A'}, {'narration': 'B'}]
    ranked = rer.rerank('q', candidates, top_k=2)
    # B should be ranked higher than A
    assert ranked[0]['narration'] == 'B'
    assert ranked[1]['narration'] == 'A'
# --- Tests for AnswerGenerator end-to-end logic ---
def test_answer_generator(monkeypatch):
    """AnswerGenerator wires retrieval, reranking, and generation together."""
    corpus = [{'narration': 'hello world'}]

    class StubRetriever:
        def __init__(self, chunks, config):
            pass

        def retrieve(self, q, top_k=10):
            return corpus

    class StubReranker:
        def __init__(self, config):
            pass

        def rerank(self, q, cands, top_k):
            return corpus

    # Swap the stubs in where AnswerGenerator resolves its collaborators.
    monkeypatch.setattr('src.qa.Retriever', StubRetriever)
    monkeypatch.setattr('src.qa.Reranker', StubReranker)
    # Short-circuit the LLM call with a canned answer.
    monkeypatch.setattr(LLMClient, 'generate', staticmethod(lambda prompt: 'TEST_ANSWER'))
    ag = AnswerGenerator()
    answer, support = ag.answer(corpus, 'What?')
    assert answer == 'TEST_ANSWER'
    assert support == corpus