# chatui-helper / test_connection_fix.py (5.03 kB)
# Last commit a1588ad by milwright:
#   "Fix RAG processing crashes with multiprocessing and memory optimizations"
#!/usr/bin/env python3
"""
Test RAG connection error fix
Tests the specific multiprocessing and connection timeout issues
"""
import os
import tempfile
import warnings
# Force single-threaded tokenizer / OpenMP / MKL behaviour. This must happen
# at module import time, before rag_tool and its ML dependencies are imported.
os.environ.update(
    TOKENIZERS_PARALLELISM='false',
    OMP_NUM_THREADS='1',
    MKL_NUM_THREADS='1',
)

# Hide user/future warnings from third-party libraries so the test's own
# progress output is easy to read.
warnings.filterwarnings(action="ignore", category=UserWarning)
warnings.filterwarnings(action="ignore", category=FutureWarning)
def test_connection_fix():
    """Smoke-test RAG document processing end to end.

    Writes a short temporary ``.txt`` file and feeds it to
    ``RAGTool.process_uploaded_files``, the call that previously failed with
    connection errors. Then runs a one-chunk ``get_relevant_context`` query.
    Progress and failures are printed to stdout; nothing is raised.

    Returns:
        bool: True when processing reports success and the retrieval query
        completes. False when ``rag_tool`` cannot be imported, when processing
        reports failure, or when any exception is raised along the way.

    NOTE(review): despite the ``test_`` prefix, this reports through its return
    value rather than ``assert``, so a pytest run will not fail on False.
    """
    print("Testing RAG connection error fix...")
    try:
        # rag_tool (and the ML stack behind it) is optional. Report its
        # absence as a failed check instead of crashing.
        try:
            from rag_tool import RAGTool
        except ImportError:
            print("✗ RAG dependencies not available")
            return False
        print("✓ RAG dependencies available")

        # Create a test document
        test_content = """This is a test document for connection error testing.
It contains multiple sentences to test the embedding process.
The document should be processed without connection errors.
This tests multiprocessing fixes and memory management."""

        # delete=False keeps the file on disk after the ``with`` closes it, so
        # RAGTool can open it by path. It is removed in the ``finally`` below.
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.txt', delete=False, encoding='utf-8'
        ) as f:
            f.write(test_content)
            test_file = f.name

        try:
            print("✓ Test document created")

            # The threading environment variables were set at module import
            # time, before RAGTool was imported.
            print("Initializing RAG tool with connection fixes...")
            rag_tool = RAGTool()
            print("✓ RAG tool initialized successfully")

            # Process document - this was causing the connection error
            print("Processing document (this was causing connection errors)...")
            result = rag_tool.process_uploaded_files([test_file])

            # .get() reports a malformed result dict as a processing failure
            # instead of a bare KeyError message.
            if not result.get('success'):
                print(f"✗ Document processing failed: {result.get('message')}")
                return False

            print(f"✓ Document processed successfully: {result.get('message')}")
            # `or {}` also tolerates an explicit None for index_stats.
            stats = result.get('index_stats') or {}
            print(f" - Chunks created: {stats.get('total_chunks', 'unknown')}")

            # Test search to ensure embeddings work
            context = rag_tool.get_relevant_context("test document", max_chunks=1)
            print(f"✓ Search test successful, context length: {len(context)}")
            return True
        finally:
            # Remove the temporary document whatever the outcome.
            if os.path.exists(test_file):
                os.unlink(test_file)
                print("✓ Test file cleaned up")
    except Exception as e:
        print(f"✗ Test failed with error: {e}")
        return False
def test_gradio_integration():
    """Check that a minimal Gradio UI wired to RAGTool can be constructed.

    Builds, but never launches, a ``gr.Blocks`` layout containing:
    - a multi-file upload,
    - a result textbox,
    - a "Process" button bound to a cut-down document-processing handler.

    The handler is only registered here, not invoked.

    Returns:
        bool: True if the interface was built. False if anything raised,
        including ``gradio`` not being installed.
    """
    print("\nTesting Gradio integration...")
    try:
        import gradio as gr

        def test_process_documents(files):
            """Minimal version of process_documents for testing.

            Returns a status string. Errors are reported in that string rather
            than raised, so the UI always has a value to display.
            """
            if not files:
                return "No files uploaded"
            try:
                from rag_tool import RAGTool
                rag_tool = RAGTool()
                # Uploaded items may be objects exposing ``.name`` (used as the
                # path) or plain values that are converted with str().
                file_paths = [f.name if hasattr(f, 'name') else str(f) for f in files]
                result = rag_tool.process_uploaded_files(file_paths)
                if result.get('success'):
                    return f"✓ Success: {result.get('message')}"
                return f"✗ Failed: {result.get('message')}"
            except Exception as e:
                return f"✗ Error: {str(e)}"

        # Create interface without launching
        with gr.Blocks():
            file_input = gr.File(file_count="multiple", label="Test Documents")
            output = gr.Textbox(label="Result")
            process_btn = gr.Button("Process")
            process_btn.click(
                test_process_documents,
                inputs=[file_input],
                outputs=[output],
            )

        print("✓ Gradio interface created successfully")
        # NOTE(review): the interface is only built, never launched, so a live
        # launch is not actually exercised by this check.
        print(" Interface can be launched without connection errors")
        return True
    except Exception as e:
        print(f"✗ Gradio integration test failed: {e}")
        return False
if __name__ == "__main__":
    # Run the Gradio check only when the core RAG check has passed.
    success = test_connection_fix()
    if success:
        success = test_gradio_integration()

    if success:
        print("\n🎉 All connection error fixes are working!")
        print("The RAG processing should now work without connection timeouts.")
    else:
        print("\n❌ Some tests failed. Check the error messages above.")

    # Propagate the result as the process exit status so shells and CI can
    # detect a failed run (previously the script always exited with 0).
    raise SystemExit(0 if success else 1)