File size: 2,647 Bytes
41d470a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""test_rag.py module."""

import pytest
from src.rag.retriever import FinancialDataRetriever
from src.rag.generator import RAGGenerator
import yaml

@pytest.fixture
def rag_config():
    """Load the server config and pin the RAG settings the tests rely on.

    Reads ``config/server_config.yaml``. The path is relative to the current
    working directory, so pytest is expected to be run from the directory
    that contains ``config/``. The fixture then installs a fixed ``rag``
    section.

    Returns:
        dict: the parsed configuration, with ``config['rag']`` set to the
        FAISS retriever settings below.
    """
    with open('config/server_config.yaml', 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document. Fall back to an empty
        # mapping so the 'rag' assignment below does not raise TypeError.
        config = yaml.safe_load(f) or {}
    # NOTE(review): this replaces any existing 'rag' section wholesale rather
    # than merging into it. Keys defined under 'rag' in the YAML file are not
    # visible to the retriever/generator under test.
    config['rag'] = {
        'retriever': 'faiss',
        'max_documents': 5,
        'similarity_threshold': 0.7,
    }
    return config

@pytest.fixture
def retriever(rag_config):
    """Provide a FinancialDataRetriever built from the test RAG config."""
    instance = FinancialDataRetriever(rag_config)
    return instance

@pytest.fixture
def generator(rag_config):
    """Provide a RAGGenerator built from the test RAG config."""
    instance = RAGGenerator(rag_config)
    return instance

def test_retriever_initialization(retriever, rag_config):
    """The retriever should pick up its type and document cap from config."""
    expected = rag_config['rag']
    assert retriever.retriever_type == expected['retriever']
    assert retriever.max_documents == expected['max_documents']

def test_document_indexing(retriever):
    """Indexing N documents should leave exactly N entries in ``index.ntotal``."""
    docs = [
        {'text': text, 'id': doc_id}
        for doc_id, text in enumerate(
            ('Financial report 2023', 'Market analysis Q4', 'Investment strategy'),
            start=1,
        )
    ]

    retriever.index_documents(docs)

    assert retriever.index.ntotal == len(docs)

def test_document_retrieval(retriever):
    """After indexing, a query returns hits that each carry a document and score."""
    # Build the same three-document corpus used by the indexing test.
    docs = [
        {'text': text, 'id': doc_id}
        for doc_id, text in enumerate(
            ('Financial report 2023', 'Market analysis Q4', 'Investment strategy'),
            start=1,
        )
    ]
    retriever.index_documents(docs)

    hits = retriever.retrieve("financial report")

    assert len(hits) > 0
    for required_key in ('document', 'score'):
        assert all(required_key in hit for hit in hits)

def test_generator_initialization(generator):
    """The generator must expose both a ``model`` and a ``tokenizer`` attribute."""
    for attribute in ('model', 'tokenizer'):
        assert hasattr(generator, attribute)

def test_text_generation(generator):
    """Generating from a single retrieved document yields a non-empty string."""
    source_doc = {'text': 'Financial market analysis shows positive trends'}
    retrieved = [{'document': source_doc, 'score': 0.9}]

    output = generator.generate(
        query="Summarize market trends",
        retrieved_docs=retrieved,
    )

    assert isinstance(output, str)
    # For a str, truthiness is equivalent to len(output) > 0.
    assert output

def test_context_preparation(generator):
    """Every retrieved document's text should appear in the prepared context."""
    texts_and_scores = [('Doc 1 content', 0.9), ('Doc 2 content', 0.8)]
    retrieved = [
        {'document': {'text': text}, 'score': score}
        for text, score in texts_and_scores
    ]

    context = generator.prepare_context(retrieved)

    assert isinstance(context, str)
    for text, _score in texts_and_scores:
        assert text in context