# Spaces:
# Running
# Running
#!/usr/bin/env python3
"""
Regression check for the RAG connection-error fix.

Targets the multiprocessing and connection-timeout problems seen while
processing uploaded documents, plus a minimal Gradio wiring check.
"""
import os
import tempfile
import warnings

# Pin parallelism / thread settings at import time. rag_tool and gradio are
# imported lazily inside the test functions, so these are already in place
# when those heavier dependencies load.
os.environ.update({
    'TOKENIZERS_PARALLELISM': 'false',
    'OMP_NUM_THREADS': '1',
    'MKL_NUM_THREADS': '1',
})

# Hide UserWarning / FutureWarning noise so the test output stays readable.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
def test_connection_fix():
    """Run the RAG pipeline end to end on a throwaway document.

    Imports ``RAGTool`` lazily (so the module-level environment settings are
    already applied), writes a small text file, indexes it with
    ``process_uploaded_files`` and then runs one ``get_relevant_context``
    query. Progress is printed to stdout.

    Returns:
        bool: ``True`` when processing reports success and the follow-up
        search call completes; ``False`` when ``rag_tool`` cannot be
        imported, processing reports failure, or any step raises.
    """
    import traceback

    print("Testing RAG connection error fix...")
    try:
        # A missing RAG stack is reported as a failed check, not raised.
        try:
            from rag_tool import RAGTool
        except ImportError:
            print("❌ RAG dependencies not available")
            return False
        print("✅ RAG dependencies available")

        # Create a test document
        test_content = """This is a test document for connection error testing.
It contains multiple sentences to test the embedding process.
The document should be processed without connection errors.
This tests multiprocessing fixes and memory management."""
        # delete=False keeps the file on disk after the ``with`` closes it so
        # RAGTool can open it by path; it is removed in the ``finally`` below.
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.txt', delete=False, encoding='utf-8'
        ) as f:
            f.write(test_content)
            test_file = f.name

        try:
            print("✅ Test document created")
            # Initialize RAG tool with environment variables already set
            print("Initializing RAG tool with connection fixes...")
            rag_tool = RAGTool()
            print("✅ RAG tool initialized successfully")

            # Document processing is the step that used to hit the connection error.
            print("Processing document (this was causing connection errors)...")
            result = rag_tool.process_uploaded_files([test_file])
            if not result['success']:
                print(f"❌ Document processing failed: {result['message']}")
                return False

            print(f"✅ Document processed successfully: {result['message']}")
            chunks = result.get('index_stats', {}).get('total_chunks', 'unknown')
            print(f" - Chunks created: {chunks}")

            # Confirm the new index is queryable, not just built.
            context = rag_tool.get_relevant_context("test document", max_chunks=1)
            print(f"✅ Search test successful, context length: {len(context)}")
            return True
        finally:
            # EAFP cleanup avoids the exists()/unlink() race.
            try:
                os.unlink(test_file)
            except FileNotFoundError:
                pass
            else:
                print("✅ Test file cleaned up")
    except Exception as e:
        print(f"❌ Test failed with error: {e}")
        # The one-line message alone rarely pinpoints where a connection
        # error originated, so show the full stack as well (on stderr).
        traceback.print_exc()
        return False
def test_gradio_integration():
    """Check that a Gradio UI wired to RAGTool can be constructed.

    Builds a minimal ``gr.Blocks`` layout modelled on the main app (file
    upload, result textbox, process button) and registers the click handler.
    The interface is never launched, so the handler itself is not invoked.

    Returns:
        bool: ``True`` if the interface was built; ``False`` if Gradio cannot
        be imported or building the interface raised.
    """
    print("\nTesting Gradio integration...")
    try:
        # Same policy as test_connection_fix: a missing dependency is a
        # failed check with a clear message, not a generic error.
        try:
            import gradio as gr
        except ImportError:
            print("❌ Gradio not available")
            return False

        def test_process_documents(files):
            """Minimal version of process_documents for testing.

            Returns a status string; errors are reported in that string
            rather than raised, so the UI always receives a value.
            """
            if not files:
                return "No files uploaded"
            try:
                from rag_tool import RAGTool
                rag_tool = RAGTool()
                # Items may be upload wrappers exposing a ``.name`` path
                # (presumably Gradio temp files) or plain path-like values.
                file_paths = [f.name if hasattr(f, 'name') else str(f) for f in files]
                result = rag_tool.process_uploaded_files(file_paths)
                if result['success']:
                    return f"✅ Success: {result['message']}"
                return f"❌ Failed: {result['message']}"
            except Exception as e:
                return f"❌ Error: {e}"

        # Build the interface without launching a server.
        with gr.Blocks():
            file_input = gr.File(file_count="multiple", label="Test Documents")
            output = gr.Textbox(label="Result")
            process_btn = gr.Button("Process")
            process_btn.click(
                test_process_documents,
                inputs=[file_input],
                outputs=[output],
            )
        print("✅ Gradio interface created successfully")
        print(" Interface can be launched without connection errors")
        return True
    except Exception as e:
        print(f"❌ Gradio integration test failed: {e}")
        return False
if __name__ == "__main__":
    import sys

    # The Gradio check only runs if the core RAG check passed
    # (both functions return a plain bool).
    success = test_connection_fix() and test_gradio_integration()
    if success:
        print("\n🎉 All connection error fixes are working!")
        print("The RAG processing should now work without connection timeouts.")
    else:
        print("\n❌ Some tests failed. Check the error messages above.")
    # A non-zero exit status lets CI / shell callers detect failure.
    sys.exit(0 if success else 1)