Mirror of https://github.com/coleam00/Archon.git (synced 2025-12-24 02:39:17 -05:00)
Fix race condition in concurrent crawling with unique source IDs (#472)
* Fix race condition in concurrent crawling with unique source IDs
  - Add unique hash-based source_id generation to prevent conflicts
  - Separate source identification from display with three fields:
    - source_id: 16-char SHA256 hash for unique identification
    - source_url: Original URL for tracking
    - source_display_name: Human-friendly name for UI
  - Add comprehensive test suite validating the fix
  - Migrate existing data with backward compatibility

* Fix title generation to use source_display_name for better AI context
  - Pass source_display_name to title generation function
  - Use display name in AI prompt instead of hash-based source_id
  - Results in more specific, meaningful titles for each source

* Skip AI title generation when display name is available
  - Use source_display_name directly as title to avoid unnecessary AI calls
  - More efficient and predictable than AI-generated titles
  - Keep AI generation only as fallback for backward compatibility

* Fix critical issues from code review
  - Add missing os import to prevent NameError crash
  - Remove unused imports (pytest, Mock, patch, hashlib, urlparse, etc.)
  - Fix GitHub API capitalization consistency
  - Reuse existing DocumentStorageService instance
  - Update test expectations to match corrected capitalization
  Addresses CodeRabbit review feedback on PR #472

* Add safety improvements from code review
  - Truncate display names to 100 chars when used as titles
  - Document hash collision probability (negligible for <1M sources)
  Simple, pragmatic fixes per KISS principle

* Fix code extraction to use hash-based source_ids and improve display names
  - Fixed critical bug where code extraction was using old domain-based source_ids
  - Updated code extraction service to accept source_id as parameter instead of extracting from URL
  - Added special handling for llms.txt and sitemap.xml files in display names
  - Added comprehensive tests for source_id handling in code extraction
  - Removed unused urlparse import from code_extraction_service.py
  This fixes the foreign key constraint errors that were preventing code examples from being stored after the source_id architecture refactor.
  Co-Authored-By: Claude <noreply@anthropic.com>

* Fix critical variable shadowing and source_type determination issues
  - Fixed variable shadowing in document_storage_operations.py where the source_url parameter was being overwritten by document URLs, causing incorrect source_url in the database
  - Fixed source_type determination to use actual URLs instead of hash-based source_id
  - Added comprehensive tests for source URL preservation
  - Ensure source_type is correctly set to "file" for file uploads, "url" for web crawls
  The variable shadowing bug was causing sitemap sources to have the wrong source_url (last crawled page instead of sitemap URL). The source_type bug would mark all sources as "url" even for file uploads due to hash-based IDs not starting with "file_".
  Co-Authored-By: Claude <noreply@anthropic.com>

* Fix URL canonicalization and document metrics calculation
  - Implement proper URL canonicalization to prevent duplicate sources:
    - Remove trailing slashes (except root)
    - Remove URL fragments
    - Remove tracking parameters (utm_*, gclid, fbclid, etc.)
    - Sort query parameters for consistency
    - Remove default ports (80 for HTTP, 443 for HTTPS)
    - Normalize scheme and domain to lowercase
  - Fix avg_chunks_per_doc calculation to avoid division by zero:
    - Track processed_docs count separately from total crawl_results
    - Handle all-empty document sets gracefully
    - Show processed/total in logs for better visibility
  - Add comprehensive tests for both fixes:
    - 10 test cases for URL canonicalization edge cases
    - 4 test cases for document metrics calculation
  This prevents database constraint violations when crawling the same content with URL variations and provides accurate metrics in logs.

* Fix synchronous extract_source_summary blocking async event loop
  - Run extract_source_summary in thread pool using asyncio.to_thread
  - Prevents blocking the async event loop during AI summary generation
  - Preserves exact error handling and fallback behavior
  - Variables (source_id, combined_content) properly passed to thread
  Added comprehensive tests verifying:
  - Function runs in thread without blocking
  - Error handling works correctly with fallback
  - Multiple sources can be processed
  - Thread safety with variable passing

* Fix synchronous update_source_info blocking async event loop
  - Run update_source_info in thread pool using asyncio.to_thread
  - Prevents blocking the async event loop during database operations
  - Preserves exact error handling and fallback behavior
  - All kwargs properly passed to thread execution
  Added comprehensive tests verifying:
  - Function runs in thread without blocking
  - Error handling triggers fallback correctly
  - All kwargs are preserved when passed to thread
  - Existing extract_source_summary tests still pass

* Fix race condition in source creation using upsert
  - Replace INSERT with UPSERT for new sources to prevent PRIMARY KEY violations
  - Handles concurrent crawls attempting to create the same source
  - Maintains existing UPDATE behavior for sources that already exist
  Added comprehensive tests verifying:
  - Concurrent source creation doesn't fail
  - Upsert is used for new sources (not insert)
  - Update is still used for existing sources
  - Async concurrent operations work correctly
  - Race conditions with delays are handled
  This prevents database constraint errors when multiple crawls target the same URL simultaneously.

* Add migration detection UI components
  Add MigrationBanner component with clear user instructions for database schema updates. Add useMigrationStatus hook for periodic health check monitoring with graceful error handling.

* Integrate migration banner into main app
  Add migration status monitoring and banner display to App.tsx. Shows migration banner when database schema updates are required.

* Enhance backend startup error instructions
  Add detailed Docker restart instructions and migration script guidance. Improves user experience when encountering startup failures.

* Add database schema caching to health endpoint
  Implement smart caching for schema validation to prevent repeated database queries. Cache successful validations permanently and throttle failures to 30-second intervals. Replace debug prints with proper logging.

* Clean up knowledge API imports and logging
  Remove duplicate import statements and redundant logging. Improves code clarity and reduces log noise.

* Remove unused instructions prop from MigrationBanner
  Clean up component API by removing instructions prop that was accepted but never rendered. Simplifies the interface and eliminates dead code while keeping the functional hardcoded migration steps.

* Add schema_valid flag to migration_required health response
  Add schema_valid: false flag to health endpoint response when database schema migration is required. Improves API consistency without changing existing behavior.

---------

Co-authored-by: Claude <noreply@anthropic.com>
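The URLHandler implementation itself is not part of this diff. As a rough illustration only (not the repository's code; the canonicalization details below are assumptions based on the bullets above), a 16-character SHA256-based source_id derived from a canonicalized URL could look like this:

import hashlib
from urllib.parse import urlsplit, urlunsplit, parse_qsl, urlencode

TRACKING_PARAMS = {"gclid", "fbclid"}  # utm_* keys are filtered by prefix below

def canonicalize_url(url: str) -> str:
    """Normalize a URL so equivalent variants hash to the same source_id."""
    parts = urlsplit(url.strip())
    scheme = parts.scheme.lower()
    netloc = parts.netloc.lower()
    # Drop default ports (80 for HTTP, 443 for HTTPS)
    if (scheme, parts.port) in (("http", 80), ("https", 443)):
        netloc = parts.hostname or ""
    # Drop tracking parameters and sort the rest for consistency
    query = sorted((k, v) for k, v in parse_qsl(parts.query)
                   if k not in TRACKING_PARAMS and not k.startswith("utm_"))
    # Remove trailing slash (except root) and discard the fragment entirely
    path = parts.path.rstrip("/") or "/"
    return urlunsplit((scheme, netloc, path, urlencode(query), ""))

def generate_unique_source_id(url: str) -> str:
    """16-char hex prefix of the SHA256 of the canonical URL."""
    return hashlib.sha256(canonicalize_url(url).encode("utf-8")).hexdigest()[:16]

Under a scheme like this, https://github.com/owner1/repo1 and https://github.com/owner1/repo2 hash to different 16-character IDs, while casing, fragment, or tracking-parameter variants of the same page collapse to a single ID.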
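The changed document_storage_operations.py is likewise not shown in this diff; the asyncio.to_thread pattern the message describes reduces to a sketch like the following (the helpers here are simplified stand-ins for the project's real functions, which take more arguments):

import asyncio

def extract_source_summary(source_id: str, content: str) -> str:
    # Stand-in for the real synchronous helper (an AI call in the project)
    return f"Summary for {source_id}: {content[:40]}"

def update_source_info(**fields) -> None:
    # Stand-in for the real synchronous database write
    print("updating source", fields["source_id"])

async def create_source_record(source_id: str, combined_content: str) -> None:
    """Run the blocking helpers in worker threads so the event loop stays responsive."""
    try:
        summary = await asyncio.to_thread(extract_source_summary, source_id, combined_content)
    except Exception:
        # Fallback summary if AI generation fails (mirrors the behavior the tests below assert)
        summary = f"Documentation from {source_id} - 1 pages crawled"
    await asyncio.to_thread(update_source_info, source_id=source_id, summary=summary)

if __name__ == "__main__":
    asyncio.run(create_source_record("test123", "chunk1 chunk2"))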
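The upsert fix can be pictured as swapping an insert() for an upsert() when a source does not exist yet, so two crawls racing to create the same source_id both succeed. A hedged sketch against the supabase-py client (the archon_sources table name is an assumption, not confirmed by this diff):

from supabase import Client

def create_or_update_source(client: Client, record: dict) -> None:
    """Write a source row without risking a PRIMARY KEY violation under concurrency."""
    existing = (
        client.table("archon_sources")
        .select("source_id")
        .eq("source_id", record["source_id"])
        .execute()
    )
    if existing.data:
        # Existing sources keep the original UPDATE path
        client.table("archon_sources").update(record).eq("source_id", record["source_id"]).execute()
    else:
        # New sources use UPSERT so concurrent creators cannot collide on the primary key
        client.table("archon_sources").upsert(record).execute()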
413
python/tests/test_async_source_summary.py
Normal file
@@ -0,0 +1,413 @@
"""
Test async execution of extract_source_summary and update_source_info.

This test ensures that synchronous functions extract_source_summary and
update_source_info are properly executed in thread pools to avoid blocking
the async event loop.
"""

import asyncio
import time
from unittest.mock import Mock, AsyncMock, patch
import pytest

from src.server.services.crawling.document_storage_operations import DocumentStorageOperations


class TestAsyncSourceSummary:
    """Test that extract_source_summary and update_source_info don't block the async event loop."""

    @pytest.mark.asyncio
    async def test_extract_summary_runs_in_thread(self):
        """Test that extract_source_summary is executed in a thread pool."""
        # Create mock supabase client
        mock_supabase = Mock()
        mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()

        doc_storage = DocumentStorageOperations(mock_supabase)

        # Track when extract_source_summary is called
        summary_call_times = []
        original_summary_result = "Test summary from AI"

        def slow_extract_summary(source_id, content):
            """Simulate a slow synchronous function that would block the event loop."""
            summary_call_times.append(time.time())
            # Simulate a blocking operation (like an API call)
            time.sleep(0.1)  # This would block the event loop if not run in thread
            return original_summary_result

        # Mock the storage service
        doc_storage.doc_storage_service.smart_chunk_text = Mock(
            return_value=["chunk1", "chunk2"]
        )

        with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
                   side_effect=slow_extract_summary):
            with patch('src.server.services.crawling.document_storage_operations.update_source_info'):
                with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
                    with patch('src.server.services.crawling.document_storage_operations.safe_logfire_error'):
                        # Create test metadata
                        all_metadatas = [
                            {"source_id": "test123", "word_count": 100},
                            {"source_id": "test123", "word_count": 150},
                        ]
                        all_contents = ["chunk1", "chunk2"]
                        source_word_counts = {"test123": 250}
                        request = {"knowledge_type": "documentation"}

                        # Track async execution
                        start_time = time.time()

                        # This should not block despite the sleep in extract_summary
                        await doc_storage._create_source_records(
                            all_metadatas,
                            all_contents,
                            source_word_counts,
                            request,
                            "https://example.com",
                            "Example Site"
                        )

                        end_time = time.time()

                        # Verify that extract_source_summary was called
                        assert len(summary_call_times) == 1, "extract_source_summary should be called once"

                        # The async function should complete without blocking
                        # Even though extract_summary sleeps for 0.1s, the async function
                        # should not be blocked since it runs in a thread
                        total_time = end_time - start_time

                        # We can't guarantee exact timing, but it should complete
                        # without throwing a timeout error
                        assert total_time < 1.0, "Should complete in reasonable time"

    @pytest.mark.asyncio
    async def test_extract_summary_error_handling(self):
        """Test that errors in extract_source_summary are handled correctly."""
        mock_supabase = Mock()
        mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()

        doc_storage = DocumentStorageOperations(mock_supabase)

        # Mock to raise an exception
        def failing_extract_summary(source_id, content):
            raise RuntimeError("AI service unavailable")

        doc_storage.doc_storage_service.smart_chunk_text = Mock(
            return_value=["chunk1"]
        )

        error_messages = []

        with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
                   side_effect=failing_extract_summary):
            with patch('src.server.services.crawling.document_storage_operations.update_source_info') as mock_update:
                with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
                    with patch('src.server.services.crawling.document_storage_operations.safe_logfire_error') as mock_error:
                        mock_error.side_effect = lambda msg: error_messages.append(msg)

                        all_metadatas = [{"source_id": "test456", "word_count": 100}]
                        all_contents = ["chunk1"]
                        source_word_counts = {"test456": 100}
                        request = {}

                        await doc_storage._create_source_records(
                            all_metadatas,
                            all_contents,
                            source_word_counts,
                            request,
                            None,
                            None
                        )

                        # Verify error was logged
                        assert len(error_messages) == 1
                        assert "Failed to generate AI summary" in error_messages[0]
                        assert "AI service unavailable" in error_messages[0]

                        # Verify fallback summary was used
                        mock_update.assert_called_once()
                        call_args = mock_update.call_args
                        assert call_args.kwargs["summary"] == "Documentation from test456 - 1 pages crawled"

    @pytest.mark.asyncio
    async def test_multiple_sources_concurrent_summaries(self):
        """Test that multiple source summaries are generated concurrently."""
        mock_supabase = Mock()
        mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()

        doc_storage = DocumentStorageOperations(mock_supabase)

        # Track concurrent executions
        execution_order = []

        def track_extract_summary(source_id, content):
            execution_order.append(f"start_{source_id}")
            time.sleep(0.05)  # Simulate work
            execution_order.append(f"end_{source_id}")
            return f"Summary for {source_id}"

        doc_storage.doc_storage_service.smart_chunk_text = Mock(
            return_value=["chunk"]
        )

        with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
                   side_effect=track_extract_summary):
            with patch('src.server.services.crawling.document_storage_operations.update_source_info'):
                with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
                    # Create metadata for multiple sources
                    all_metadatas = [
                        {"source_id": "source1", "word_count": 100},
                        {"source_id": "source2", "word_count": 150},
                        {"source_id": "source3", "word_count": 200},
                    ]
                    all_contents = ["chunk1", "chunk2", "chunk3"]
                    source_word_counts = {
                        "source1": 100,
                        "source2": 150,
                        "source3": 200,
                    }
                    request = {}

                    await doc_storage._create_source_records(
                        all_metadatas,
                        all_contents,
                        source_word_counts,
                        request,
                        None,
                        None
                    )

                    # With threading, sources are processed sequentially in the loop
                    # but the extract_summary calls happen in threads
                    assert len(execution_order) == 6  # 3 sources * 2 events each

                    # Verify all sources were processed
                    processed_sources = set()
                    for event in execution_order:
                        if event.startswith("start_"):
                            processed_sources.add(event.replace("start_", ""))

                    assert processed_sources == {"source1", "source2", "source3"}

    @pytest.mark.asyncio
    async def test_thread_safety_with_variables(self):
        """Test that variables are properly passed to thread execution."""
        mock_supabase = Mock()
        mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()

        doc_storage = DocumentStorageOperations(mock_supabase)

        # Track what gets passed to extract_summary
        captured_calls = []

        def capture_extract_summary(source_id, content):
            captured_calls.append({
                "source_id": source_id,
                "content_len": len(content),
                "content_preview": content[:50] if content else ""
            })
            return f"Summary for {source_id}"

        doc_storage.doc_storage_service.smart_chunk_text = Mock(
            return_value=["This is chunk one with some content",
                          "This is chunk two with more content"]
        )

        with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
                   side_effect=capture_extract_summary):
            with patch('src.server.services.crawling.document_storage_operations.update_source_info'):
                with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
                    all_metadatas = [
                        {"source_id": "test789", "word_count": 100},
                        {"source_id": "test789", "word_count": 150},
                    ]
                    all_contents = [
                        "This is chunk one with some content",
                        "This is chunk two with more content"
                    ]
                    source_word_counts = {"test789": 250}
                    request = {}

                    await doc_storage._create_source_records(
                        all_metadatas,
                        all_contents,
                        source_word_counts,
                        request,
                        None,
                        None
                    )

                    # Verify the correct values were passed to the thread
                    assert len(captured_calls) == 1
                    call = captured_calls[0]
                    assert call["source_id"] == "test789"
                    assert call["content_len"] > 0
                    # Combined content should start with space + first chunk
                    assert "This is chunk one" in call["content_preview"]

    @pytest.mark.asyncio
    async def test_update_source_info_runs_in_thread(self):
        """Test that update_source_info is executed in a thread pool."""
        mock_supabase = Mock()
        mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()

        doc_storage = DocumentStorageOperations(mock_supabase)

        # Track when update_source_info is called
        update_call_times = []

        def slow_update_source_info(**kwargs):
            """Simulate a slow synchronous database operation."""
            update_call_times.append(time.time())
            # Simulate a blocking database operation
            time.sleep(0.1)  # This would block the event loop if not run in thread
            return None  # update_source_info doesn't return anything

        doc_storage.doc_storage_service.smart_chunk_text = Mock(
            return_value=["chunk1"]
        )

        with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
                   return_value="Test summary"):
            with patch('src.server.services.crawling.document_storage_operations.update_source_info',
                       side_effect=slow_update_source_info):
                with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
                    with patch('src.server.services.crawling.document_storage_operations.safe_logfire_error'):
                        all_metadatas = [{"source_id": "test_update", "word_count": 100}]
                        all_contents = ["chunk1"]
                        source_word_counts = {"test_update": 100}
                        request = {"knowledge_type": "documentation", "tags": ["test"]}

                        start_time = time.time()

                        # This should not block despite the sleep in update_source_info
                        await doc_storage._create_source_records(
                            all_metadatas,
                            all_contents,
                            source_word_counts,
                            request,
                            "https://example.com",
                            "Example Site"
                        )

                        end_time = time.time()

                        # Verify that update_source_info was called
                        assert len(update_call_times) == 1, "update_source_info should be called once"

                        # The async function should complete without blocking
                        total_time = end_time - start_time
                        assert total_time < 1.0, "Should complete in reasonable time"

    @pytest.mark.asyncio
    async def test_update_source_info_error_handling(self):
        """Test that errors in update_source_info trigger fallback correctly."""
        mock_supabase = Mock()
        mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()

        doc_storage = DocumentStorageOperations(mock_supabase)

        # Mock to raise an exception
        def failing_update_source_info(**kwargs):
            raise RuntimeError("Database connection failed")

        doc_storage.doc_storage_service.smart_chunk_text = Mock(
            return_value=["chunk1"]
        )

        error_messages = []
        fallback_called = False

        def track_fallback_upsert(data):
            nonlocal fallback_called
            fallback_called = True
            return Mock(execute=Mock())

        mock_supabase.table.return_value.upsert.side_effect = track_fallback_upsert

        with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
                   return_value="Test summary"):
            with patch('src.server.services.crawling.document_storage_operations.update_source_info',
                       side_effect=failing_update_source_info):
                with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
                    with patch('src.server.services.crawling.document_storage_operations.safe_logfire_error') as mock_error:
                        mock_error.side_effect = lambda msg: error_messages.append(msg)

                        all_metadatas = [{"source_id": "test_fail", "word_count": 100}]
                        all_contents = ["chunk1"]
                        source_word_counts = {"test_fail": 100}
                        request = {"knowledge_type": "technical", "tags": ["test"]}

                        await doc_storage._create_source_records(
                            all_metadatas,
                            all_contents,
                            source_word_counts,
                            request,
                            "https://example.com",
                            "Example Site"
                        )

                        # Verify error was logged
                        assert any("Failed to create/update source record" in msg for msg in error_messages)
                        assert any("Database connection failed" in msg for msg in error_messages)

                        # Verify fallback was attempted
                        assert fallback_called, "Fallback upsert should be called"

    @pytest.mark.asyncio
    async def test_update_source_info_preserves_kwargs(self):
        """Test that all kwargs are properly passed to update_source_info in thread."""
        mock_supabase = Mock()
        mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()

        doc_storage = DocumentStorageOperations(mock_supabase)

        # Track what gets passed to update_source_info
        captured_kwargs = {}

        def capture_update_source_info(**kwargs):
            captured_kwargs.update(kwargs)
            return None

        doc_storage.doc_storage_service.smart_chunk_text = Mock(
            return_value=["chunk content"]
        )

        with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
                   return_value="Generated summary"):
            with patch('src.server.services.crawling.document_storage_operations.update_source_info',
                       side_effect=capture_update_source_info):
                with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
                    all_metadatas = [{"source_id": "test_kwargs", "word_count": 250}]
                    all_contents = ["chunk content"]
                    source_word_counts = {"test_kwargs": 250}
                    request = {
                        "knowledge_type": "api_reference",
                        "tags": ["api", "docs"],
                        "url": "https://original.url/crawl"
                    }

                    await doc_storage._create_source_records(
                        all_metadatas,
                        all_contents,
                        source_word_counts,
                        request,
                        "https://source.url",
                        "Source Display Name"
                    )

                    # Verify all kwargs were passed correctly
                    assert captured_kwargs["client"] == mock_supabase
                    assert captured_kwargs["source_id"] == "test_kwargs"
                    assert captured_kwargs["summary"] == "Generated summary"
                    assert captured_kwargs["word_count"] == 250
                    assert "chunk content" in captured_kwargs["content"]
                    assert captured_kwargs["knowledge_type"] == "api_reference"
                    assert captured_kwargs["tags"] == ["api", "docs"]
                    assert captured_kwargs["update_frequency"] == 0
                    assert captured_kwargs["original_url"] == "https://original.url/crawl"
                    assert captured_kwargs["source_url"] == "https://source.url"
                    assert captured_kwargs["source_display_name"] == "Source Display Name"
184
python/tests/test_code_extraction_source_id.py
Normal file
@@ -0,0 +1,184 @@
"""
Test that code extraction uses the correct source_id.

This test ensures that the fix for using hash-based source_ids
instead of domain-based source_ids works correctly.
"""

import pytest
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from src.server.services.crawling.code_extraction_service import CodeExtractionService
from src.server.services.crawling.document_storage_operations import DocumentStorageOperations


class TestCodeExtractionSourceId:
    """Test that code extraction properly uses the provided source_id."""

    @pytest.mark.asyncio
    async def test_code_extraction_uses_provided_source_id(self):
        """Test that code extraction uses the hash-based source_id, not domain."""
        # Create mock supabase client
        mock_supabase = Mock()
        mock_supabase.table.return_value.select.return_value.eq.return_value.execute.return_value.data = []

        # Create service instance
        code_service = CodeExtractionService(mock_supabase)

        # Track what gets passed to the internal extraction method
        extracted_blocks = []

        async def mock_extract_blocks(crawl_results, source_id, progress_callback=None, start=0, end=100):
            # Simulate finding code blocks and verify source_id is passed correctly
            for doc in crawl_results:
                extracted_blocks.append({
                    "block": {"code": "print('hello')", "language": "python"},
                    "source_url": doc["url"],
                    "source_id": source_id  # This should be the provided source_id
                })
            return extracted_blocks

        code_service._extract_code_blocks_from_documents = mock_extract_blocks
        code_service._generate_code_summaries = AsyncMock(return_value=[{"summary": "Test code"}])
        code_service._prepare_code_examples_for_storage = Mock(return_value=[
            {"source_id": extracted_blocks[0]["source_id"] if extracted_blocks else None}
        ])
        code_service._store_code_examples = AsyncMock(return_value=1)

        # Test data
        crawl_results = [
            {
                "url": "https://docs.mem0.ai/example",
                "markdown": "```python\nprint('hello')\n```"
            }
        ]

        url_to_full_document = {
            "https://docs.mem0.ai/example": "Full content with code"
        }

        # The correct hash-based source_id
        correct_source_id = "393224e227ba92eb"

        # Call the method with the correct source_id
        result = await code_service.extract_and_store_code_examples(
            crawl_results,
            url_to_full_document,
            correct_source_id,
            None,
            0,
            100
        )

        # Verify that extracted blocks use the correct source_id
        assert len(extracted_blocks) > 0, "Should have extracted at least one code block"

        for block in extracted_blocks:
            # Check that it's using the hash-based source_id, not the domain
            assert block["source_id"] == correct_source_id, \
                f"Should use hash-based source_id '{correct_source_id}', not domain"
            assert block["source_id"] != "docs.mem0.ai", \
                "Should NOT use domain-based source_id"

    @pytest.mark.asyncio
    async def test_document_storage_passes_source_id(self):
        """Test that DocumentStorageOperations passes source_id to code extraction."""
        # Create mock supabase client
        mock_supabase = Mock()

        # Create DocumentStorageOperations instance
        doc_storage = DocumentStorageOperations(mock_supabase)

        # Mock the code extraction service
        mock_extract = AsyncMock(return_value=5)
        doc_storage.code_extraction_service.extract_and_store_code_examples = mock_extract

        # Test data
        crawl_results = [{"url": "https://example.com", "markdown": "test"}]
        url_to_full_document = {"https://example.com": "test content"}
        source_id = "abc123def456"

        # Call the wrapper method
        result = await doc_storage.extract_and_store_code_examples(
            crawl_results,
            url_to_full_document,
            source_id,
            None,
            0,
            100
        )

        # Verify the correct source_id was passed
        mock_extract.assert_called_once_with(
            crawl_results,
            url_to_full_document,
            source_id,  # This should be the third argument
            None,
            0,
            100
        )
        assert result == 5

    @pytest.mark.asyncio
    async def test_no_domain_extraction_from_url(self):
        """Test that we're NOT extracting domain from URL anymore."""
        mock_supabase = Mock()
        mock_supabase.table.return_value.select.return_value.eq.return_value.execute.return_value.data = []

        code_service = CodeExtractionService(mock_supabase)

        # Patch internal methods
        code_service._get_setting = AsyncMock(return_value=True)

        # Create a mock that will track what source_id is used
        source_ids_seen = []

        original_extract = code_service._extract_code_blocks_from_documents
        async def track_source_id(crawl_results, source_id, progress_callback=None, start=0, end=100):
            source_ids_seen.append(source_id)
            return []  # Return empty list to skip further processing

        code_service._extract_code_blocks_from_documents = track_source_id

        # Test with various URLs that would produce different domains
        test_cases = [
            ("https://github.com/example/repo", "github123abc"),
            ("https://docs.python.org/guide", "python456def"),
            ("https://api.openai.com/v1", "openai789ghi"),
        ]

        for url, expected_source_id in test_cases:
            source_ids_seen.clear()

            crawl_results = [{"url": url, "markdown": "# Test"}]
            url_to_full_document = {url: "Full content"}

            await code_service.extract_and_store_code_examples(
                crawl_results,
                url_to_full_document,
                expected_source_id,
                None,
                0,
                100
            )

            # Verify the provided source_id was used
            assert len(source_ids_seen) == 1
            assert source_ids_seen[0] == expected_source_id
            # Verify it's NOT the domain
            assert "github.com" not in source_ids_seen[0]
            assert "python.org" not in source_ids_seen[0]
            assert "openai.com" not in source_ids_seen[0]

    def test_urlparse_not_imported(self):
        """Test that urlparse is not imported in code_extraction_service."""
        import src.server.services.crawling.code_extraction_service as module

        # Check that urlparse is not in the module's namespace
        assert not hasattr(module, 'urlparse'), \
            "urlparse should not be imported in code_extraction_service"

        # Check the module's actual imports
        import inspect
        source = inspect.getsource(module)
        assert "from urllib.parse import urlparse" not in source, \
            "Should not import urlparse since we don't extract domain from URL anymore"
205
python/tests/test_document_storage_metrics.py
Normal file
@@ -0,0 +1,205 @@
"""
Test document storage metrics calculation.

This test ensures that avg_chunks_per_doc is calculated correctly
and handles edge cases like empty documents.
"""

import pytest
from unittest.mock import Mock, AsyncMock, patch
from src.server.services.crawling.document_storage_operations import DocumentStorageOperations


class TestDocumentStorageMetrics:
    """Test metrics calculation in document storage operations."""

    @pytest.mark.asyncio
    async def test_avg_chunks_calculation_with_empty_docs(self):
        """Test that avg_chunks_per_doc handles empty documents correctly."""
        # Create mock supabase client
        mock_supabase = Mock()
        doc_storage = DocumentStorageOperations(mock_supabase)

        # Mock the storage service
        doc_storage.doc_storage_service.smart_chunk_text = Mock(
            side_effect=lambda text, chunk_size: ["chunk1", "chunk2"] if text else []
        )

        # Mock internal methods
        doc_storage._create_source_records = AsyncMock()

        # Track what gets logged
        logged_messages = []

        with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info') as mock_log:
            mock_log.side_effect = lambda msg: logged_messages.append(msg)

            with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'):
                # Test data with mix of empty and non-empty documents
                crawl_results = [
                    {"url": "https://example.com/page1", "markdown": "Content 1"},
                    {"url": "https://example.com/page2", "markdown": ""},  # Empty
                    {"url": "https://example.com/page3", "markdown": "Content 3"},
                    {"url": "https://example.com/page4", "markdown": ""},  # Empty
                    {"url": "https://example.com/page5", "markdown": "Content 5"},
                ]

                result = await doc_storage.process_and_store_documents(
                    crawl_results=crawl_results,
                    request={},
                    crawl_type="test",
                    original_source_id="test123",
                    source_url="https://example.com",
                    source_display_name="Example"
                )

                # Find the metrics log message
                metrics_log = None
                for msg in logged_messages:
                    if "Document storage | processed=" in msg:
                        metrics_log = msg
                        break

                assert metrics_log is not None, "Should log metrics"

                # Verify metrics are correct
                # 3 documents processed (non-empty), 5 total, 6 chunks (2 per doc), avg = 2.0
                assert "processed=3/5" in metrics_log, "Should show 3 processed out of 5 total"
                assert "chunks=6" in metrics_log, "Should have 6 chunks total"
                assert "avg_chunks_per_doc=2.0" in metrics_log, "Average should be 2.0 (6/3)"

    @pytest.mark.asyncio
    async def test_avg_chunks_all_empty_docs(self):
        """Test that avg_chunks_per_doc handles all empty documents without division by zero."""
        mock_supabase = Mock()
        doc_storage = DocumentStorageOperations(mock_supabase)

        # Mock the storage service
        doc_storage.doc_storage_service.smart_chunk_text = Mock(return_value=[])
        doc_storage._create_source_records = AsyncMock()

        logged_messages = []

        with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info') as mock_log:
            mock_log.side_effect = lambda msg: logged_messages.append(msg)

            with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'):
                # All documents are empty
                crawl_results = [
                    {"url": "https://example.com/page1", "markdown": ""},
                    {"url": "https://example.com/page2", "markdown": ""},
                    {"url": "https://example.com/page3", "markdown": ""},
                ]

                result = await doc_storage.process_and_store_documents(
                    crawl_results=crawl_results,
                    request={},
                    crawl_type="test",
                    original_source_id="test456",
                    source_url="https://example.com",
                    source_display_name="Example"
                )

                # Find the metrics log
                metrics_log = None
                for msg in logged_messages:
                    if "Document storage | processed=" in msg:
                        metrics_log = msg
                        break

                assert metrics_log is not None, "Should log metrics even with no processed docs"

                # Should show 0 processed, 0 chunks, 0.0 average (no division by zero)
                assert "processed=0/3" in metrics_log, "Should show 0 processed out of 3 total"
                assert "chunks=0" in metrics_log, "Should have 0 chunks"
                assert "avg_chunks_per_doc=0.0" in metrics_log, "Average should be 0.0 (no division by zero)"

    @pytest.mark.asyncio
    async def test_avg_chunks_single_doc(self):
        """Test avg_chunks_per_doc with a single document."""
        mock_supabase = Mock()
        doc_storage = DocumentStorageOperations(mock_supabase)

        # Mock to return 5 chunks for content
        doc_storage.doc_storage_service.smart_chunk_text = Mock(
            return_value=["chunk1", "chunk2", "chunk3", "chunk4", "chunk5"]
        )
        doc_storage._create_source_records = AsyncMock()

        logged_messages = []

        with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info') as mock_log:
            mock_log.side_effect = lambda msg: logged_messages.append(msg)

            with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'):
                crawl_results = [
                    {"url": "https://example.com/page", "markdown": "Long content here..."},
                ]

                result = await doc_storage.process_and_store_documents(
                    crawl_results=crawl_results,
                    request={},
                    crawl_type="test",
                    original_source_id="test789",
                    source_url="https://example.com",
                    source_display_name="Example"
                )

                # Find metrics log
                metrics_log = None
                for msg in logged_messages:
                    if "Document storage | processed=" in msg:
                        metrics_log = msg
                        break

                assert metrics_log is not None
                assert "processed=1/1" in metrics_log, "Should show 1 processed out of 1 total"
                assert "chunks=5" in metrics_log, "Should have 5 chunks"
                assert "avg_chunks_per_doc=5.0" in metrics_log, "Average should be 5.0"

    @pytest.mark.asyncio
    async def test_processed_count_accuracy(self):
        """Test that processed_docs count is accurate."""
        mock_supabase = Mock()
        doc_storage = DocumentStorageOperations(mock_supabase)

        # Track which documents are chunked
        chunked_urls = []

        def mock_chunk(text, chunk_size):
            if text:
                return ["chunk"]
            return []

        doc_storage.doc_storage_service.smart_chunk_text = Mock(side_effect=mock_chunk)
        doc_storage._create_source_records = AsyncMock()

        with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
            with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'):
                # Mix of documents with various content states
                crawl_results = [
                    {"url": "https://example.com/1", "markdown": "Content"},
                    {"url": "https://example.com/2", "markdown": ""},  # Empty markdown
                    {"url": "https://example.com/3", "markdown": None},  # None markdown
                    {"url": "https://example.com/4", "markdown": "More content"},
                    {"url": "https://example.com/5"},  # Missing markdown key
                    {"url": "https://example.com/6", "markdown": " "},  # Whitespace (counts as content)
                ]

                result = await doc_storage.process_and_store_documents(
                    crawl_results=crawl_results,
                    request={},
                    crawl_type="test",
                    original_source_id="test999",
                    source_url="https://example.com",
                    source_display_name="Example"
                )

                # Should process documents 1, 4, and 6 (has content including whitespace)
                assert result["chunk_count"] == 3, "Should have 3 chunks (one per processed doc)"

                # Check url_to_full_document only has processed docs
                assert len(result["url_to_full_document"]) == 3
                assert "https://example.com/1" in result["url_to_full_document"]
                assert "https://example.com/4" in result["url_to_full_document"]
                assert "https://example.com/6" in result["url_to_full_document"]
357
python/tests/test_source_id_refactor.py
Normal file
@@ -0,0 +1,357 @@
"""
Test Suite for Source ID Architecture Refactor

Tests the new unique source ID generation and display name extraction
to ensure the race condition fix works correctly.
"""

import time
from concurrent.futures import ThreadPoolExecutor

# Import the URLHandler class
from src.server.services.crawling.helpers.url_handler import URLHandler


class TestSourceIDGeneration:
    """Test the unique source ID generation."""

    def test_unique_id_generation_basic(self):
        """Test basic unique ID generation."""
        handler = URLHandler()

        # Test various URLs
        test_urls = [
            "https://github.com/microsoft/typescript",
            "https://github.com/facebook/react",
            "https://docs.python.org/3/",
            "https://fastapi.tiangolo.com/",
            "https://pydantic.dev/",
        ]

        source_ids = []
        for url in test_urls:
            source_id = handler.generate_unique_source_id(url)
            source_ids.append(source_id)

            # Check that ID is a 16-character hex string
            assert len(source_id) == 16, f"ID should be 16 chars, got {len(source_id)}"
            assert all(c in '0123456789abcdef' for c in source_id), f"ID should be hex: {source_id}"

        # All IDs should be unique
        assert len(set(source_ids)) == len(source_ids), "All source IDs should be unique"

    def test_same_domain_different_ids(self):
        """Test that same domain with different paths generates different IDs."""
        handler = URLHandler()

        # Multiple GitHub repos (same domain, different paths)
        github_urls = [
            "https://github.com/owner1/repo1",
            "https://github.com/owner1/repo2",
            "https://github.com/owner2/repo1",
        ]

        ids = [handler.generate_unique_source_id(url) for url in github_urls]

        # All should be unique despite same domain
        assert len(set(ids)) == len(ids), "Same domain should generate different IDs for different URLs"

    def test_id_consistency(self):
        """Test that the same URL always generates the same ID."""
        handler = URLHandler()
        url = "https://github.com/microsoft/typescript"

        # Generate ID multiple times
        ids = [handler.generate_unique_source_id(url) for _ in range(5)]

        # All should be identical
        assert len(set(ids)) == 1, f"Same URL should always generate same ID, got: {set(ids)}"
        assert ids[0] == ids[4], "First and last ID should match"

    def test_url_normalization(self):
        """Test that URL normalization works correctly."""
        handler = URLHandler()

        # These should all generate the same ID (after normalization)
        url_variations = [
            "https://github.com/Microsoft/TypeScript",
            "HTTPS://GITHUB.COM/MICROSOFT/TYPESCRIPT",
            "https://GitHub.com/Microsoft/TypeScript",
        ]

        ids = [handler.generate_unique_source_id(url) for url in url_variations]

        # All normalized versions should generate the same ID
        assert len(set(ids)) == 1, f"Normalized URLs should generate same ID, got: {set(ids)}"

    def test_concurrent_crawl_simulation(self):
        """Simulate concurrent crawls to verify no race conditions."""
        handler = URLHandler()

        # URLs that would previously conflict
        concurrent_urls = [
            "https://github.com/coleam00/archon",
            "https://github.com/microsoft/typescript",
            "https://github.com/facebook/react",
            "https://github.com/vercel/next.js",
            "https://github.com/vuejs/vue",
        ]

        def generate_id(url):
            """Simulate a crawl generating an ID."""
            time.sleep(0.001)  # Simulate some processing time
            return handler.generate_unique_source_id(url)

        # Run concurrent ID generation
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(generate_id, url) for url in concurrent_urls]
            source_ids = [future.result() for future in futures]

        # All IDs should be unique
        assert len(set(source_ids)) == len(source_ids), "Concurrent crawls should generate unique IDs"

    def test_error_handling(self):
        """Test error handling for edge cases."""
        handler = URLHandler()

        # Test various edge cases
        edge_cases = [
            "",  # Empty string
            "not-a-url",  # Invalid URL
            "https://",  # Incomplete URL
            None,  # None should be handled gracefully in real code
        ]

        for url in edge_cases:
            if url is None:
                continue  # Skip None for this test

            # Should not raise exception
            source_id = handler.generate_unique_source_id(url)
            assert source_id is not None, f"Should generate ID even for edge case: {url}"
            assert len(source_id) == 16, f"Edge case should still generate 16-char ID: {url}"


class TestDisplayNameExtraction:
    """Test the human-readable display name extraction."""

    def test_github_display_names(self):
        """Test GitHub repository display name extraction."""
        handler = URLHandler()

        test_cases = [
            ("https://github.com/microsoft/typescript", "GitHub - microsoft/typescript"),
            ("https://github.com/facebook/react", "GitHub - facebook/react"),
            ("https://github.com/vercel/next.js", "GitHub - vercel/next.js"),
            ("https://github.com/owner", "GitHub - owner"),
            ("https://github.com/", "GitHub"),
        ]

        for url, expected in test_cases:
            display_name = handler.extract_display_name(url)
            assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'"

    def test_documentation_display_names(self):
        """Test documentation site display name extraction."""
        handler = URLHandler()

        test_cases = [
            ("https://docs.python.org/3/", "Python Documentation"),
            ("https://docs.djangoproject.com/", "Djangoproject Documentation"),
            ("https://fastapi.tiangolo.com/", "FastAPI Documentation"),
            ("https://pydantic.dev/", "Pydantic Documentation"),
            ("https://numpy.org/doc/", "NumPy Documentation"),
            ("https://pandas.pydata.org/", "Pandas Documentation"),
            ("https://project.readthedocs.io/", "Project Docs"),
        ]

        for url, expected in test_cases:
            display_name = handler.extract_display_name(url)
            assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'"

    def test_api_display_names(self):
        """Test API endpoint display name extraction."""
        handler = URLHandler()

        test_cases = [
            ("https://api.github.com/", "GitHub API"),
            ("https://api.openai.com/v1/", "Openai API"),
            ("https://example.com/api/v2/", "Example"),
        ]

        for url, expected in test_cases:
            display_name = handler.extract_display_name(url)
            assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'"

    def test_generic_display_names(self):
        """Test generic website display name extraction."""
        handler = URLHandler()

        test_cases = [
            ("https://example.com/", "Example"),
            ("https://my-site.org/", "My Site"),
            ("https://test_project.io/", "Test Project"),
            ("https://some.subdomain.example.com/", "Some Subdomain Example"),
        ]

        for url, expected in test_cases:
            display_name = handler.extract_display_name(url)
            assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'"

    def test_edge_case_display_names(self):
        """Test edge cases for display name extraction."""
        handler = URLHandler()

        # Edge cases
        test_cases = [
            ("", ""),  # Empty URL
            ("not-a-url", "not-a-url"),  # Invalid URL
            ("/local/file/path", "Local: path"),  # Local file path
            ("https://", "https://"),  # Incomplete URL
        ]

        for url, expected_contains in test_cases:
            display_name = handler.extract_display_name(url)
            assert expected_contains in display_name or display_name == expected_contains, \
                f"Edge case {url} handling failed: {display_name}"

    def test_special_file_display_names(self):
        """Test that special files like llms.txt and sitemap.xml are properly displayed."""
        handler = URLHandler()

        test_cases = [
            # llms.txt files
            ("https://docs.mem0.ai/llms-full.txt", "Mem0 - Llms.Txt"),
            ("https://example.com/llms.txt", "Example - Llms.Txt"),
            ("https://api.example.com/llms.txt", "Example API"),  # API takes precedence

            # sitemap.xml files
            ("https://mem0.ai/sitemap.xml", "Mem0 - Sitemap.Xml"),
            ("https://docs.example.com/sitemap.xml", "Example - Sitemap.Xml"),
            ("https://example.org/sitemap.xml", "Example - Sitemap.Xml"),

            # Regular .txt files on docs sites
            ("https://docs.example.com/readme.txt", "Example - Readme.Txt"),

            # Non-special files should not get special treatment
            ("https://docs.example.com/guide", "Example Documentation"),
            ("https://example.com/page.html", "Example - Page.Html"),  # Path gets added for single file
        ]

        for url, expected in test_cases:
            display_name = handler.extract_display_name(url)
            assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'"

    def test_git_extension_removal(self):
        """Test that .git extension is removed from GitHub repos."""
        handler = URLHandler()

        test_cases = [
            ("https://github.com/owner/repo.git", "GitHub - owner/repo"),
            ("https://github.com/owner/repo", "GitHub - owner/repo"),
        ]

        for url, expected in test_cases:
            display_name = handler.extract_display_name(url)
            assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'"


class TestRaceConditionFix:
    """Test that the race condition is actually fixed."""

    def test_no_domain_conflicts(self):
        """Test that multiple sources from same domain don't conflict."""
        handler = URLHandler()

        # These would all have source_id = "github.com" in the old system
        github_urls = [
            "https://github.com/microsoft/typescript",
            "https://github.com/microsoft/vscode",
            "https://github.com/facebook/react",
            "https://github.com/vercel/next.js",
            "https://github.com/vuejs/vue",
        ]

        source_ids = [handler.generate_unique_source_id(url) for url in github_urls]

        # All should be unique
        assert len(set(source_ids)) == len(source_ids), \
            "Race condition not fixed: duplicate source IDs for same domain"

        # None should be just "github.com"
        for source_id in source_ids:
            assert source_id != "github.com", \
                "Source ID should not be just the domain"

    def test_hash_properties(self):
        """Test that the hash has good properties."""
        handler = URLHandler()

        # Similar URLs should still generate very different hashes
        url1 = "https://github.com/owner/repo1"
        url2 = "https://github.com/owner/repo2"  # Only differs by one character

        id1 = handler.generate_unique_source_id(url1)
        id2 = handler.generate_unique_source_id(url2)

        # IDs should be completely different (good hash distribution)
        matching_chars = sum(1 for a, b in zip(id1, id2) if a == b)
        assert matching_chars < 8, \
            f"Similar URLs should generate very different hashes, {matching_chars}/16 chars match"


class TestIntegration:
    """Integration tests for the complete source ID system."""

    def test_full_source_creation_flow(self):
        """Test the complete flow of creating a source with all fields."""
        handler = URLHandler()
        url = "https://github.com/microsoft/typescript"

        # Generate all source fields
        source_id = handler.generate_unique_source_id(url)
        source_display_name = handler.extract_display_name(url)
        source_url = url

        # Verify all fields are populated correctly
        assert len(source_id) == 16, "Source ID should be 16 characters"
        assert source_display_name == "GitHub - microsoft/typescript", \
            f"Display name incorrect: {source_display_name}"
        assert source_url == url, "Source URL should match original"

        # Simulate database record
        source_record = {
            'source_id': source_id,
            'source_url': source_url,
            'source_display_name': source_display_name,
            'title': None,  # Generated later
            'summary': None,  # Generated later
            'metadata': {}
        }

        # Verify record structure
        assert 'source_id' in source_record
        assert 'source_url' in source_record
        assert 'source_display_name' in source_record

    def test_backward_compatibility(self):
        """Test that the system handles existing sources gracefully."""
        handler = URLHandler()

        # Simulate an existing source with old-style source_id
        existing_source = {
            'source_id': 'github.com',  # Old style - just domain
            'source_url': None,  # Not populated in old system
            'source_display_name': None,  # Not populated in old system
        }

        # The migration should handle this by backfilling
        # source_url and source_display_name with source_id value
        migrated_source = {
            'source_id': 'github.com',
            'source_url': 'github.com',  # Backfilled
            'source_display_name': 'github.com',  # Backfilled
        }

        assert migrated_source['source_url'] is not None
        assert migrated_source['source_display_name'] is not None
269
python/tests/test_source_race_condition.py
Normal file
@@ -0,0 +1,269 @@
"""
Test race condition handling in source creation.

This test ensures that concurrent source creation attempts
don't fail with PRIMARY KEY violations.
"""

import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, patch
import pytest

from src.server.services.source_management_service import update_source_info


class TestSourceRaceCondition:
    """Test that concurrent source creation handles race conditions properly."""

    def test_concurrent_source_creation_no_race(self):
        """Test that concurrent attempts to create the same source don't fail."""
        # Track successful operations
        successful_creates = []
        failed_creates = []

        def mock_execute():
            """Mock execute that simulates database operation."""
            return Mock(data=[])

        def track_upsert(data):
            """Track upsert calls."""
            successful_creates.append(data["source_id"])
            return Mock(execute=mock_execute)

        # Mock Supabase client
        mock_client = Mock()

        # Mock the SELECT (existing source check) - always returns empty
        mock_client.table.return_value.select.return_value.eq.return_value.execute.return_value.data = []

        # Mock the UPSERT operation
        mock_client.table.return_value.upsert = track_upsert

        def create_source(thread_id):
            """Simulate creating a source from a thread."""
            try:
                update_source_info(
                    client=mock_client,
                    source_id="test_source_123",
                    summary=f"Summary from thread {thread_id}",
                    word_count=100,
                    content=f"Content from thread {thread_id}",
                    knowledge_type="documentation",
                    tags=["test"],
                    update_frequency=0,
                    source_url="https://example.com",
                    source_display_name=f"Example Site {thread_id}"  # Will be used as title
                )
            except Exception as e:
                failed_creates.append((thread_id, str(e)))

        # Run 5 threads concurrently trying to create the same source
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = []
            for i in range(5):
                futures.append(executor.submit(create_source, i))

            # Wait for all to complete
            for future in futures:
                future.result()

        # All should succeed (no failures due to PRIMARY KEY violation)
        assert len(failed_creates) == 0, f"Some creates failed: {failed_creates}"
        assert len(successful_creates) == 5, "All 5 attempts should succeed"
        assert all(sid == "test_source_123" for sid in successful_creates)

    def test_upsert_vs_insert_behavior(self):
        """Test that upsert is used instead of insert for new sources."""
        mock_client = Mock()

        # Track which method is called
        methods_called = []

        def track_insert(data):
            methods_called.append("insert")
            # Simulate PRIMARY KEY violation
            raise Exception("duplicate key value violates unique constraint")

        def track_upsert(data):
            methods_called.append("upsert")
            return Mock(execute=Mock(return_value=Mock(data=[])))

        # Source doesn't exist
        mock_client.table.return_value.select.return_value.eq.return_value.execute.return_value.data = []

        # Set up mocks
        mock_client.table.return_value.insert = track_insert
        mock_client.table.return_value.upsert = track_upsert

        update_source_info(
            client=mock_client,
            source_id="new_source",
            summary="Test summary",
            word_count=100,
            content="Test content",
            knowledge_type="documentation",
            source_display_name="Test Display Name"  # Will be used as title
        )

        # Should use upsert, not insert
        assert "upsert" in methods_called, "Should use upsert for new sources"
        assert "insert" not in methods_called, "Should not use insert to avoid race conditions"

    def test_existing_source_uses_update(self):
        """Test that existing sources still use UPDATE (not affected by change)."""
        mock_client = Mock()

        methods_called = []

        def track_update(data):
            methods_called.append("update")
            return Mock(eq=Mock(return_value=Mock(execute=Mock(return_value=Mock(data=[])))))

        def track_upsert(data):
            methods_called.append("upsert")
            return Mock(execute=Mock(return_value=Mock(data=[])))

        # Source exists
        existing_source = {
            "source_id": "existing_source",
            "title": "Existing Title",
            "metadata": {"knowledge_type": "api"}
        }
        mock_client.table.return_value.select.return_value.eq.return_value.execute.return_value.data = [existing_source]

        # Set up mocks
        mock_client.table.return_value.update = track_update
        mock_client.table.return_value.upsert = track_upsert

        update_source_info(
            client=mock_client,
            source_id="existing_source",
            summary="Updated summary",
            word_count=200,
            content="Updated content",
            knowledge_type="documentation"
        )

        # Should use update for existing sources
        assert "update" in methods_called, "Should use update for existing sources"
        assert "upsert" not in methods_called, "Should not use upsert for existing sources"

    @pytest.mark.asyncio
    async def test_async_concurrent_creation(self):
        """Test concurrent source creation in async context."""
        mock_client = Mock()

        # Track operations
        operations = []

        def track_upsert(data):
            operations.append(("upsert", data["source_id"]))
            return Mock(execute=Mock(return_value=Mock(data=[])))

        # No existing sources
        mock_client.table.return_value.select.return_value.eq.return_value.execute.return_value.data = []
        mock_client.table.return_value.upsert = track_upsert

        async def create_source_async(task_id):
            """Async wrapper for source creation."""
            await asyncio.to_thread(
                update_source_info,
                client=mock_client,
                source_id=f"async_source_{task_id % 2}",  # Only 2 unique sources
                summary=f"Summary {task_id}",
                word_count=100,
                content=f"Content {task_id}",
                knowledge_type="documentation"
            )

        # Create 10 tasks, but only 2 unique source_ids
        tasks = [create_source_async(i) for i in range(10)]
        await asyncio.gather(*tasks)

        # All operations should succeed
        assert len(operations) == 10, "All 10 operations should complete"

        # Check that we tried to upsert the two sources multiple times
        source_0_count = sum(1 for op, sid in operations if sid == "async_source_0")
        source_1_count = sum(1 for op, sid in operations if sid == "async_source_1")

        assert source_0_count == 5, "async_source_0 should be upserted 5 times"
        assert source_1_count == 5, "async_source_1 should be upserted 5 times"

    def test_race_condition_with_delay(self):
        """Test race condition with simulated delay between check and create."""
        import time

        mock_client = Mock()

        # Track timing of operations
        check_times = []
        create_times = []
        source_created = threading.Event()

        def delayed_select(*args):
            """Return a mock that simulates SELECT with delay."""
            mock_select = Mock()

            def eq_mock(*args):
                mock_eq = Mock()
                mock_eq.execute = lambda: delayed_check()
                return mock_eq

            mock_select.eq = eq_mock
            return mock_select

        def delayed_check():
            """Simulate SELECT execution with delay."""
            check_times.append(time.time())
            result = Mock()
            # First thread doesn't see the source
            if not source_created.is_set():
                time.sleep(0.01)  # Small delay to let both threads check
                result.data = []
            else:
                # Subsequent checks would see it (but we use upsert so this doesn't matter)
                result.data = [{"source_id": "race_source", "title": "Existing", "metadata": {}}]
            return result

        def track_upsert(data):
            """Track upsert and set event."""
            create_times.append(time.time())
            source_created.set()
            return Mock(execute=Mock(return_value=Mock(data=[])))

        # Set up table mock to return our custom select mock
        mock_client.table.return_value.select = delayed_select
        mock_client.table.return_value.upsert = track_upsert

        errors = []

        def create_with_error_tracking(thread_id):
            try:
                update_source_info(
                    client=mock_client,
                    source_id="race_source",
                    summary="Race summary",
                    word_count=100,
                    content="Race content",
                    knowledge_type="documentation",
                    source_display_name="Race Display Name"  # Will be used as title
                )
            except Exception as e:
                errors.append((thread_id, str(e)))

        # Run 2 threads that will both check before either creates
        with ThreadPoolExecutor(max_workers=2) as executor:
            futures = [
                executor.submit(create_with_error_tracking, 1),
                executor.submit(create_with_error_tracking, 2)
            ]
            for future in futures:
                future.result()

        # Both should succeed with upsert (no errors)
        assert len(errors) == 0, f"No errors should occur with upsert: {errors}"
        assert len(check_times) == 2, "Both threads should check"
        assert len(create_times) == 2, "Both threads should attempt create/upsert"
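The tests above assume a check-then-write flow inside update_source_info: a SELECT for an existing row, an UPDATE when one exists, and an UPSERT (never a plain INSERT) when it does not, so two crawls that both miss on the SELECT cannot collide on the PRIMARY KEY. A minimal sketch under those assumptions follows; the table name "archon_sources" and the payload fields are illustrative, not the actual service code.

def update_source_info_sketch(client, source_id: str, summary: str, word_count: int, **fields):
    """Illustrative only: mirrors the select/update/upsert shape the tests mock."""
    existing = (
        client.table("archon_sources")  # table name assumed for illustration
        .select("*")
        .eq("source_id", source_id)
        .execute()
    )
    payload = {"source_id": source_id, "summary": summary, "word_count": word_count, **fields}
    if existing.data:
        # Existing source: plain UPDATE keyed on source_id.
        client.table("archon_sources").update(payload).eq("source_id", source_id).execute()
    else:
        # New source: UPSERT so concurrent creators converge on one row
        # instead of the loser raising a duplicate-key error.
        client.table("archon_sources").upsert(payload).execute()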
python/tests/test_source_url_shadowing.py (new file, +124 lines)
@@ -0,0 +1,124 @@
"""
Test that source_url parameter is not shadowed by document URLs.

This test ensures that the original crawl URL (e.g., sitemap URL)
is correctly passed to _create_source_records and not overwritten
by individual document URLs during processing.
"""

import pytest
from unittest.mock import Mock, AsyncMock, MagicMock, patch
from src.server.services.crawling.document_storage_operations import DocumentStorageOperations


class TestSourceUrlShadowing:
    """Test that source_url parameter is preserved correctly."""

    @pytest.mark.asyncio
    async def test_source_url_not_shadowed(self):
        """Test that the original source_url is passed to _create_source_records."""
        # Create mock supabase client
        mock_supabase = Mock()

        # Create DocumentStorageOperations instance
        doc_storage = DocumentStorageOperations(mock_supabase)

        # Mock the storage service
        doc_storage.doc_storage_service.smart_chunk_text = Mock(return_value=["chunk1", "chunk2"])

        # Track what gets passed to _create_source_records
        captured_source_url = None

        async def mock_create_source_records(all_metadatas, all_contents, source_word_counts,
                                             request, source_url, source_display_name):
            nonlocal captured_source_url
            captured_source_url = source_url

        doc_storage._create_source_records = mock_create_source_records

        # Mock add_documents_to_supabase
        with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase') as mock_add:
            mock_add.return_value = None

            # Test data - simulating a sitemap crawl
            original_source_url = "https://mem0.ai/sitemap.xml"
            crawl_results = [
                {
                    "url": "https://mem0.ai/page1",
                    "markdown": "Content of page 1",
                    "title": "Page 1"
                },
                {
                    "url": "https://mem0.ai/page2",
                    "markdown": "Content of page 2",
                    "title": "Page 2"
                },
                {
                    "url": "https://mem0.ai/models/openai-o3",  # Last document URL
                    "markdown": "Content of models page",
                    "title": "Models"
                }
            ]

            request = {"knowledge_type": "documentation", "tags": []}

            # Call the method
            result = await doc_storage.process_and_store_documents(
                crawl_results=crawl_results,
                request=request,
                crawl_type="sitemap",
                original_source_id="test123",
                progress_callback=None,
                cancellation_check=None,
                source_url=original_source_url,  # This should NOT be overwritten
                source_display_name="Test Sitemap"
            )

            # Verify the original source_url was preserved
            assert captured_source_url == original_source_url, \
                f"source_url should be '{original_source_url}', not '{captured_source_url}'"

            # Verify it's NOT the last document's URL
            assert captured_source_url != "https://mem0.ai/models/openai-o3", \
                "source_url should NOT be overwritten with the last document's URL"

            # Verify url_to_full_document has correct URLs
            assert "https://mem0.ai/page1" in result["url_to_full_document"]
            assert "https://mem0.ai/page2" in result["url_to_full_document"]
            assert "https://mem0.ai/models/openai-o3" in result["url_to_full_document"]

    @pytest.mark.asyncio
    async def test_metadata_uses_document_urls(self):
        """Test that metadata correctly uses individual document URLs."""
        mock_supabase = Mock()
        doc_storage = DocumentStorageOperations(mock_supabase)

        # Mock the storage service
        doc_storage.doc_storage_service.smart_chunk_text = Mock(return_value=["chunk1"])

        # Capture metadata
        captured_metadatas = None

        async def mock_create_source_records(all_metadatas, all_contents, source_word_counts,
                                             request, source_url, source_display_name):
            nonlocal captured_metadatas
            captured_metadatas = all_metadatas

        doc_storage._create_source_records = mock_create_source_records

        with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'):
            crawl_results = [
                {"url": "https://example.com/doc1", "markdown": "Doc 1"},
                {"url": "https://example.com/doc2", "markdown": "Doc 2"}
            ]

            await doc_storage.process_and_store_documents(
                crawl_results=crawl_results,
                request={},
                crawl_type="normal",
                original_source_id="test456",
                source_url="https://example.com",
                source_display_name="Example"
            )

            # Each metadata should have the correct document URL
            assert captured_metadatas[0]["url"] == "https://example.com/doc1"
            assert captured_metadatas[1]["url"] == "https://example.com/doc2"
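The shadowing bug these tests guard against is a processing loop reusing the source_url parameter name for per-document URLs, so that by the time source records are created the parameter holds the last crawled page instead of the original crawl URL. A minimal sketch of the intended shape, with hypothetical names (store_documents_sketch, doc_url) and the source-record call passed in as a callable so the sketch stays self-contained:

from typing import Awaitable, Callable

async def store_documents_sketch(
    crawl_results: list[dict],
    request: dict,
    source_url: str,
    source_display_name: str,
    create_source_records: Callable[..., Awaitable[None]],
) -> dict:
    """Illustrative only: keep the crawl-level source_url distinct from per-document URLs."""
    all_metadatas: list[dict] = []
    all_contents: list[str] = []
    source_word_counts: dict[str, int] = {}
    url_to_full_document: dict[str, str] = {}
    for doc in crawl_results:
        doc_url = doc["url"]  # per-document URL gets its own name; source_url is never reassigned
        url_to_full_document[doc_url] = doc["markdown"]
        all_metadatas.append({"url": doc_url})
        all_contents.append(doc["markdown"])
    # The original crawl URL (e.g. the sitemap) is still intact here.
    await create_source_records(
        all_metadatas, all_contents, source_word_counts, request, source_url, source_display_name
    )
    return {"url_to_full_document": url_to_full_document}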
python/tests/test_url_canonicalization.py (new file, +222 lines)
@@ -0,0 +1,222 @@
"""
Test URL canonicalization in source ID generation.

This test ensures that URLs are properly normalized before hashing
to prevent duplicate sources from URL variations.
"""

import pytest
from src.server.services.crawling.helpers.url_handler import URLHandler


class TestURLCanonicalization:
    """Test that URL canonicalization works correctly for source ID generation."""

    def test_trailing_slash_normalization(self):
        """Test that trailing slashes are handled consistently."""
        handler = URLHandler()

        # These should generate the same ID
        url1 = "https://example.com/path"
        url2 = "https://example.com/path/"

        id1 = handler.generate_unique_source_id(url1)
        id2 = handler.generate_unique_source_id(url2)

        assert id1 == id2, "URLs with/without trailing slash should generate same ID"

        # Root path should keep its slash
        root1 = "https://example.com"
        root2 = "https://example.com/"

        root_id1 = handler.generate_unique_source_id(root1)
        root_id2 = handler.generate_unique_source_id(root2)

        # These should be the same (both normalize to https://example.com/)
        assert root_id1 == root_id2, "Root URLs should normalize consistently"

    def test_fragment_removal(self):
        """Test that URL fragments are removed."""
        handler = URLHandler()

        urls = [
            "https://example.com/page",
            "https://example.com/page#section1",
            "https://example.com/page#section2",
            "https://example.com/page#",
        ]

        ids = [handler.generate_unique_source_id(url) for url in urls]

        # All should generate the same ID
        assert len(set(ids)) == 1, "URLs with different fragments should generate same ID"

    def test_tracking_param_removal(self):
        """Test that tracking parameters are removed."""
        handler = URLHandler()

        # URL without tracking params
        clean_url = "https://example.com/page?important=value"

        # URLs with various tracking params
        tracked_urls = [
            "https://example.com/page?important=value&utm_source=google",
            "https://example.com/page?utm_medium=email&important=value",
            "https://example.com/page?important=value&fbclid=abc123",
            "https://example.com/page?gclid=xyz&important=value&utm_campaign=test",
            "https://example.com/page?important=value&ref=homepage",
            "https://example.com/page?source=newsletter&important=value",
        ]

        clean_id = handler.generate_unique_source_id(clean_url)
        tracked_ids = [handler.generate_unique_source_id(url) for url in tracked_urls]

        # All tracked URLs should generate the same ID as the clean URL
        for tracked_id in tracked_ids:
            assert tracked_id == clean_id, "URLs with tracking params should match clean URL"

    def test_query_param_sorting(self):
        """Test that query parameters are sorted for consistency."""
        handler = URLHandler()

        urls = [
            "https://example.com/page?a=1&b=2&c=3",
            "https://example.com/page?c=3&a=1&b=2",
            "https://example.com/page?b=2&c=3&a=1",
        ]

        ids = [handler.generate_unique_source_id(url) for url in urls]

        # All should generate the same ID
        assert len(set(ids)) == 1, "URLs with reordered query params should generate same ID"

    def test_default_port_removal(self):
        """Test that default ports are removed."""
        handler = URLHandler()

        # HTTP default port (80)
        http_urls = [
            "http://example.com/page",
            "http://example.com:80/page",
        ]

        http_ids = [handler.generate_unique_source_id(url) for url in http_urls]
        assert len(set(http_ids)) == 1, "HTTP URLs with/without :80 should generate same ID"

        # HTTPS default port (443)
        https_urls = [
            "https://example.com/page",
            "https://example.com:443/page",
        ]

        https_ids = [handler.generate_unique_source_id(url) for url in https_urls]
        assert len(set(https_ids)) == 1, "HTTPS URLs with/without :443 should generate same ID"

        # Non-default ports should be preserved
        url1 = "https://example.com:8080/page"
        url2 = "https://example.com:9090/page"

        id1 = handler.generate_unique_source_id(url1)
        id2 = handler.generate_unique_source_id(url2)

        assert id1 != id2, "URLs with different non-default ports should generate different IDs"

    def test_case_normalization(self):
        """Test that scheme and domain are lowercased."""
        handler = URLHandler()

        urls = [
            "https://example.com/Path/To/Page",
            "HTTPS://EXAMPLE.COM/Path/To/Page",
            "https://Example.Com/Path/To/Page",
            "HTTPs://example.COM/Path/To/Page",
        ]

        ids = [handler.generate_unique_source_id(url) for url in urls]

        # All should generate the same ID (path case is preserved)
        assert len(set(ids)) == 1, "URLs with different case in scheme/domain should generate same ID"

        # But different paths should generate different IDs
        path_urls = [
            "https://example.com/path",
            "https://example.com/Path",
            "https://example.com/PATH",
        ]

        path_ids = [handler.generate_unique_source_id(url) for url in path_urls]

        # These should be different (path case matters)
        assert len(set(path_ids)) == 3, "URLs with different path case should generate different IDs"

    def test_complex_canonicalization(self):
        """Test complex URL with multiple normalizations needed."""
        handler = URLHandler()

        urls = [
            "https://example.com/page",
            "HTTPS://EXAMPLE.COM:443/page/",
            "https://Example.com/page#section",
            "https://example.com/page/?utm_source=test",
            "https://example.com:443/page?utm_campaign=abc#footer",
        ]

        ids = [handler.generate_unique_source_id(url) for url in urls]

        # All should generate the same ID
        assert len(set(ids)) == 1, "Complex URLs should normalize to same ID"

    def test_edge_cases(self):
        """Test edge cases and error handling."""
        handler = URLHandler()

        # Empty URL
        empty_id = handler.generate_unique_source_id("")
        assert len(empty_id) == 16, "Empty URL should still generate valid ID"

        # Invalid URL
        invalid_id = handler.generate_unique_source_id("not-a-url")
        assert len(invalid_id) == 16, "Invalid URL should still generate valid ID"

        # URL with special characters
        special_url = "https://example.com/page?key=value%20with%20spaces"
        special_id = handler.generate_unique_source_id(special_url)
        assert len(special_id) == 16, "URL with encoded chars should generate valid ID"

        # Very long URL
        long_url = "https://example.com/" + "a" * 1000
        long_id = handler.generate_unique_source_id(long_url)
        assert len(long_id) == 16, "Long URL should generate valid ID"

    def test_preserves_important_params(self):
        """Test that non-tracking params are preserved."""
        handler = URLHandler()

        # These have different important params, should be different
        url1 = "https://api.example.com/v1/users?page=1"
        url2 = "https://api.example.com/v1/users?page=2"

        id1 = handler.generate_unique_source_id(url1)
        id2 = handler.generate_unique_source_id(url2)

        assert id1 != id2, "URLs with different important params should generate different IDs"

        # But tracking params should still be removed
        url3 = "https://api.example.com/v1/users?page=1&utm_source=docs"
        id3 = handler.generate_unique_source_id(url3)

        assert id3 == id1, "Adding tracking params shouldn't change ID"

    def test_local_file_paths(self):
        """Test handling of local file paths."""
        handler = URLHandler()

        # File URLs
        file_url = "file:///Users/test/document.pdf"
        file_id = handler.generate_unique_source_id(file_url)
        assert len(file_id) == 16, "File URL should generate valid ID"

        # Relative paths
        relative_path = "../documents/file.txt"
        relative_id = handler.generate_unique_source_id(relative_path)
        assert len(relative_id) == 16, "Relative path should generate valid ID"
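Taken together, these cases pin down the canonicalization rules: lowercase the scheme and host, drop default ports, strip fragments and tracking parameters, sort the remaining query parameters, and trim trailing slashes on non-root paths before hashing to a 16-character ID. A minimal sketch of such a function, assuming SHA-256 truncated to 16 hex characters as the PR describes (canonical_source_id_sketch is hypothetical, not the actual URLHandler code):

import hashlib
from urllib.parse import urlsplit, urlunsplit, parse_qsl, urlencode

TRACKING_PARAMS = {"gclid", "fbclid", "ref", "source"}  # plus any utm_* parameter

def canonical_source_id_sketch(url: str) -> str:
    """Illustrative canonicalize-then-hash, mirroring the behavior the tests expect."""
    try:
        parts = urlsplit(url.strip())
        scheme = parts.scheme.lower()
        host = (parts.hostname or "").lower()
        port = parts.port
        # Drop default ports (80 for http, 443 for https); keep any other explicit port.
        if port and not ((scheme == "http" and port == 80) or (scheme == "https" and port == 443)):
            host = f"{host}:{port}"
        path = parts.path or "/"
        if path != "/" and path.endswith("/"):
            path = path.rstrip("/")  # trailing slash is insignificant except at the root
        query_pairs = [
            (k, v)
            for k, v in parse_qsl(parts.query, keep_blank_values=True)
            if k not in TRACKING_PARAMS and not k.startswith("utm_")
        ]
        query = urlencode(sorted(query_pairs))
        canonical = urlunsplit((scheme, host, path, query, ""))  # fragment always dropped
    except ValueError:
        canonical = url  # fall back to hashing the raw string for unparseable input
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]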