diff --git a/archon-ui-main/src/App.tsx b/archon-ui-main/src/App.tsx
index c09fc539..42af02ac 100644
--- a/archon-ui-main/src/App.tsx
+++ b/archon-ui-main/src/App.tsx
@@ -11,7 +11,9 @@ import { SettingsProvider, useSettings } from './contexts/SettingsContext';
 import { ProjectPage } from './pages/ProjectPage';
 import { DisconnectScreenOverlay } from './components/DisconnectScreenOverlay';
 import { ErrorBoundaryWithBugReport } from './components/bug-report/ErrorBoundaryWithBugReport';
+import { MigrationBanner } from './components/ui/MigrationBanner';
 import { serverHealthService } from './services/serverHealthService';
+import { useMigrationStatus } from './hooks/useMigrationStatus';

 const AppRoutes = () => {
   const { projectsEnabled } = useSettings();
@@ -38,6 +40,8 @@ const AppContent = () => {
     enabled: true,
     delay: 10000
   });
+  const [migrationBannerDismissed, setMigrationBannerDismissed] = useState(false);
+  const migrationStatus = useMigrationStatus();

   useEffect(() => {
     // Load initial settings
@@ -77,6 +81,13 @@ const AppContent = () => {
+      {/* Migration Banner - shows when backend is up but DB schema needs work */}
+      {migrationStatus.migrationRequired && !migrationBannerDismissed && (
+        <MigrationBanner
+          message={migrationStatus.message || 'Database migration required'}
+          onDismiss={() => setMigrationBannerDismissed(true)}
+        />
+      )}
diff --git a/archon-ui-main/src/components/BackendStartupError.tsx b/archon-ui-main/src/components/BackendStartupError.tsx
index 8959bfc1..9a3a03d9 100644
--- a/archon-ui-main/src/components/BackendStartupError.tsx
+++ b/archon-ui-main/src/components/BackendStartupError.tsx
@@ -40,8 +40,12 @@ export const BackendStartupError: React.FC = () => {

-                Common issue: Using an ANON key instead of SERVICE key in your .env file
+                Common issues:
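Several hunks below (the /api/health check in main.py, the knowledge API health endpoint, and the MigrationBanner that follows) all key off the presence of the new archon_sources columns. As a reviewer-side aid, a minimal verification query against the standard Postgres information_schema (not part of the patch, names assumed from the migration script) is:

-- Hypothetical manual check: run in the Supabase SQL Editor after applying
-- migration/add_source_url_display_name.sql; both column names should be returned.
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'archon_sources'
  AND column_name IN ('source_url', 'source_display_name');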
diff --git a/archon-ui-main/src/components/ui/MigrationBanner.tsx b/archon-ui-main/src/components/ui/MigrationBanner.tsx
new file mode 100644
index 00000000..6618e2f7
--- /dev/null
+++ b/archon-ui-main/src/components/ui/MigrationBanner.tsx
@@ -0,0 +1,60 @@
+import React from 'react';
+import { AlertTriangle, ExternalLink } from 'lucide-react';
+import { Card } from './Card';
+
+interface MigrationBannerProps {
+  message: string;
+  onDismiss?: () => void;
+}
+
+export const MigrationBanner: React.FC<MigrationBannerProps> = ({
+  message,
+  onDismiss
+}) => {
+  return (
+    <Card>
+      <div>
+        <AlertTriangle />
+        <div>
+          <h3>Database Migration Required</h3>
+          <p>{message}</p>
+          <p>Follow these steps:</p>
+          <ol>
+            <li>Open your Supabase project dashboard</li>
+            <li>Navigate to the SQL Editor</li>
+            <li>Copy and run the migration script from: migration/add_source_url_display_name.sql</li>
+            <li>Restart Docker containers: docker compose down && docker compose up --build -d</li>
+            <li>If you used a profile, add it: --profile full</li>
+          </ol>
+          <div>
+            <a
+              href="https://supabase.com/dashboard"
+              target="_blank"
+              rel="noopener noreferrer"
+            >
+              Open Supabase Dashboard
+              <ExternalLink />
+            </a>
+            {onDismiss && (
+              <button onClick={onDismiss}>Dismiss</button>
+            )}
+          </div>
+        </div>
+      </div>
+    </Card>
+ ); +}; \ No newline at end of file diff --git a/archon-ui-main/src/hooks/useMigrationStatus.ts b/archon-ui-main/src/hooks/useMigrationStatus.ts new file mode 100644 index 00000000..3e63c016 --- /dev/null +++ b/archon-ui-main/src/hooks/useMigrationStatus.ts @@ -0,0 +1,51 @@ +import { useState, useEffect } from 'react'; + +interface MigrationStatus { + migrationRequired: boolean; + message?: string; + loading: boolean; +} + +export const useMigrationStatus = (): MigrationStatus => { + const [status, setStatus] = useState({ + migrationRequired: false, + loading: true, + }); + + useEffect(() => { + const checkMigrationStatus = async () => { + try { + const response = await fetch('/api/health'); + const healthData = await response.json(); + + if (healthData.status === 'migration_required') { + setStatus({ + migrationRequired: true, + message: healthData.message, + loading: false, + }); + } else { + setStatus({ + migrationRequired: false, + loading: false, + }); + } + } catch (error) { + console.error('Failed to check migration status:', error); + setStatus({ + migrationRequired: false, + loading: false, + }); + } + }; + + checkMigrationStatus(); + + // Check periodically (every 30 seconds) to detect when migration is complete + const interval = setInterval(checkMigrationStatus, 30000); + + return () => clearInterval(interval); + }, []); + + return status; +}; \ No newline at end of file diff --git a/migration/add_source_url_display_name.sql b/migration/add_source_url_display_name.sql new file mode 100644 index 00000000..bf40b417 --- /dev/null +++ b/migration/add_source_url_display_name.sql @@ -0,0 +1,36 @@ +-- ===================================================== +-- Add source_url and source_display_name columns +-- ===================================================== +-- This migration adds two new columns to better identify sources: +-- - source_url: The original URL that was crawled +-- - source_display_name: Human-readable name for UI display +-- +-- This solves the race condition issue where multiple crawls +-- to the same domain would conflict by using domain as source_id +-- ===================================================== + +-- Add new columns to archon_sources table +ALTER TABLE archon_sources +ADD COLUMN IF NOT EXISTS source_url TEXT, +ADD COLUMN IF NOT EXISTS source_display_name TEXT; + +-- Add indexes for the new columns for better query performance +CREATE INDEX IF NOT EXISTS idx_archon_sources_url ON archon_sources(source_url); +CREATE INDEX IF NOT EXISTS idx_archon_sources_display_name ON archon_sources(source_display_name); + +-- Add comments to document the new columns +COMMENT ON COLUMN archon_sources.source_url IS 'The original URL that was crawled to create this source'; +COMMENT ON COLUMN archon_sources.source_display_name IS 'Human-readable name for UI display (e.g., "GitHub - microsoft/typescript")'; + +-- Backfill existing data +-- For existing sources, copy source_id to both new fields as a fallback +UPDATE archon_sources +SET + source_url = COALESCE(source_url, source_id), + source_display_name = COALESCE(source_display_name, source_id) +WHERE + source_url IS NULL + OR source_display_name IS NULL; + +-- Note: source_id will now contain a unique hash instead of domain +-- This ensures no conflicts when multiple sources from same domain are crawled \ No newline at end of file diff --git a/migration/complete_setup.sql b/migration/complete_setup.sql index 94a1778f..4b3550bd 100644 --- a/migration/complete_setup.sql +++ b/migration/complete_setup.sql @@ -170,6 
+170,8 @@ COMMENT ON TABLE archon_settings IS 'Stores application configuration including -- Create the sources table CREATE TABLE IF NOT EXISTS archon_sources ( source_id TEXT PRIMARY KEY, + source_url TEXT, + source_display_name TEXT, summary TEXT, total_word_count INTEGER DEFAULT 0, title TEXT, @@ -180,10 +182,15 @@ CREATE TABLE IF NOT EXISTS archon_sources ( -- Create indexes for better query performance CREATE INDEX IF NOT EXISTS idx_archon_sources_title ON archon_sources(title); +CREATE INDEX IF NOT EXISTS idx_archon_sources_url ON archon_sources(source_url); +CREATE INDEX IF NOT EXISTS idx_archon_sources_display_name ON archon_sources(source_display_name); CREATE INDEX IF NOT EXISTS idx_archon_sources_metadata ON archon_sources USING GIN(metadata); CREATE INDEX IF NOT EXISTS idx_archon_sources_knowledge_type ON archon_sources((metadata->>'knowledge_type')); --- Add comments to document the new columns +-- Add comments to document the columns +COMMENT ON COLUMN archon_sources.source_id IS 'Unique hash identifier for the source (16-char SHA256 hash of URL)'; +COMMENT ON COLUMN archon_sources.source_url IS 'The original URL that was crawled to create this source'; +COMMENT ON COLUMN archon_sources.source_display_name IS 'Human-readable name for UI display (e.g., "GitHub - microsoft/typescript")'; COMMENT ON COLUMN archon_sources.title IS 'Descriptive title for the source (e.g., "Pydantic AI API Reference")'; COMMENT ON COLUMN archon_sources.metadata IS 'JSONB field storing knowledge_type, tags, and other metadata'; diff --git a/python/src/server/api_routes/knowledge_api.py b/python/src/server/api_routes/knowledge_api.py index 37eeffc4..11e1f13f 100644 --- a/python/src/server/api_routes/knowledge_api.py +++ b/python/src/server/api_routes/knowledge_api.py @@ -27,10 +27,6 @@ from ..services.crawler_manager import get_crawler # Import unified logging from ..config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info -from ..services.crawler_manager import get_crawler -from ..services.search.rag_service import RAGService -from ..services.storage import DocumentStorageService -from ..utils import get_supabase_client from ..utils.document_processing import extract_text_from_document # Get logger for this module @@ -513,11 +509,6 @@ async def upload_document( ): """Upload and process a document with progress tracking.""" try: - # DETAILED LOGGING: Track knowledge_type parameter flow - safe_logfire_info( - f"📋 UPLOAD: Starting document upload | filename={file.filename} | content_type={file.content_type} | knowledge_type={knowledge_type}" - ) - safe_logfire_info( f"Starting document upload | filename={file.filename} | content_type={file.content_type} | knowledge_type={knowledge_type}" ) @@ -871,7 +862,22 @@ async def get_database_metrics(): @router.get("/health") async def knowledge_health(): - """Knowledge API health check.""" + """Knowledge API health check with migration detection.""" + # Check for database migration needs + from ..main import _check_database_schema + + schema_status = await _check_database_schema() + if not schema_status["valid"]: + return { + "status": "migration_required", + "service": "knowledge-api", + "timestamp": datetime.now().isoformat(), + "ready": False, + "migration_required": True, + "message": schema_status["message"], + "migration_instructions": "Open Supabase Dashboard → SQL Editor → Run: migration/add_source_url_display_name.sql" + } + # Removed health check logging to reduce console noise result = { "status": "healthy", diff --git 
a/python/src/server/main.py b/python/src/server/main.py index a278e3cc..cfb06722 100644 --- a/python/src/server/main.py +++ b/python/src/server/main.py @@ -246,12 +246,27 @@ async def health_check(): "ready": False, } + # Check for required database schema + schema_status = await _check_database_schema() + if not schema_status["valid"]: + return { + "status": "migration_required", + "service": "archon-backend", + "timestamp": datetime.now().isoformat(), + "ready": False, + "migration_required": True, + "message": schema_status["message"], + "migration_instructions": "Open Supabase Dashboard → SQL Editor → Run: migration/add_source_url_display_name.sql", + "schema_valid": False + } + return { "status": "healthy", "service": "archon-backend", "timestamp": datetime.now().isoformat(), "ready": True, "credentials_loaded": True, + "schema_valid": True, } @@ -262,6 +277,78 @@ async def api_health_check(): return await health_check() +# Cache schema check result to avoid repeated database queries +_schema_check_cache = {"valid": None, "checked_at": 0} + +async def _check_database_schema(): + """Check if required database schema exists - only for existing users who need migration.""" + import time + + # If we've already confirmed schema is valid, don't check again + if _schema_check_cache["valid"] is True: + return {"valid": True, "message": "Schema is up to date (cached)"} + + # If we recently failed, don't spam the database (wait at least 30 seconds) + current_time = time.time() + if (_schema_check_cache["valid"] is False and + current_time - _schema_check_cache["checked_at"] < 30): + return _schema_check_cache["result"] + + try: + from .services.client_manager import get_supabase_client + + client = get_supabase_client() + + # Try to query the new columns directly - if they exist, schema is up to date + test_query = client.table('archon_sources').select('source_url, source_display_name').limit(1).execute() + + # Cache successful result permanently + _schema_check_cache["valid"] = True + _schema_check_cache["checked_at"] = current_time + + return {"valid": True, "message": "Schema is up to date"} + + except Exception as e: + error_msg = str(e).lower() + + # Log schema check error for debugging + api_logger.debug(f"Schema check error: {type(e).__name__}: {str(e)}") + + # Check for specific error types based on PostgreSQL error codes and messages + + # Check for missing columns first (more specific than table check) + missing_source_url = 'source_url' in error_msg and ('column' in error_msg or 'does not exist' in error_msg) + missing_source_display = 'source_display_name' in error_msg and ('column' in error_msg or 'does not exist' in error_msg) + + # Also check for PostgreSQL error code 42703 (undefined column) + is_column_error = '42703' in error_msg or 'column' in error_msg + + if (missing_source_url or missing_source_display) and is_column_error: + result = { + "valid": False, + "message": "Database schema outdated - missing required columns from recent updates" + } + # Cache failed result with timestamp + _schema_check_cache["valid"] = False + _schema_check_cache["checked_at"] = current_time + _schema_check_cache["result"] = result + return result + + # Check for table doesn't exist (less specific, only if column check didn't match) + # Look for relation/table errors specifically + if ('relation' in error_msg and 'does not exist' in error_msg) or ('table' in error_msg and 'does not exist' in error_msg): + # Table doesn't exist - not a migration issue, it's a setup issue + return {"valid": True, 
"message": "Table doesn't exist - handled by startup error"} + + # Other errors don't necessarily mean migration needed + result = {"valid": True, "message": f"Schema check inconclusive: {str(e)}"} + # Don't cache inconclusive results - allow retry + return result + + +# Export for Socket.IO + + # Create Socket.IO app wrapper # This wraps the FastAPI app with Socket.IO functionality socket_app = create_socketio_app(app) diff --git a/python/src/server/services/crawling/code_extraction_service.py b/python/src/server/services/crawling/code_extraction_service.py index 71e12ebe..e88cb7b4 100644 --- a/python/src/server/services/crawling/code_extraction_service.py +++ b/python/src/server/services/crawling/code_extraction_service.py @@ -7,7 +7,6 @@ Handles extraction, processing, and storage of code examples from documents. import re from collections.abc import Callable from typing import Any -from urllib.parse import urlparse from ...config.logfire_config import safe_logfire_error, safe_logfire_info from ...services.credential_service import credential_service @@ -136,6 +135,7 @@ class CodeExtractionService: self, crawl_results: list[dict[str, Any]], url_to_full_document: dict[str, str], + source_id: str, progress_callback: Callable | None = None, start_progress: int = 0, end_progress: int = 100, @@ -146,6 +146,7 @@ class CodeExtractionService: Args: crawl_results: List of crawled documents with url and markdown content url_to_full_document: Mapping of URLs to full document content + source_id: The unique source_id for all documents progress_callback: Optional async callback for progress updates start_progress: Starting progress percentage (default: 0) end_progress: Ending progress percentage (default: 100) @@ -163,7 +164,7 @@ class CodeExtractionService: # Extract code blocks from all documents all_code_blocks = await self._extract_code_blocks_from_documents( - crawl_results, progress_callback, start_progress, extract_end + crawl_results, source_id, progress_callback, start_progress, extract_end ) if not all_code_blocks: @@ -201,6 +202,7 @@ class CodeExtractionService: async def _extract_code_blocks_from_documents( self, crawl_results: list[dict[str, Any]], + source_id: str, progress_callback: Callable | None = None, start_progress: int = 0, end_progress: int = 100, @@ -208,6 +210,10 @@ class CodeExtractionService: """ Extract code blocks from all documents. 
+ Args: + crawl_results: List of crawled documents + source_id: The unique source_id for all documents + Returns: List of code blocks with metadata """ @@ -306,10 +312,7 @@ class CodeExtractionService: ) if code_blocks: - # Always extract source_id from URL - parsed_url = urlparse(source_url) - source_id = parsed_url.netloc or parsed_url.path - + # Use the provided source_id for all code blocks for block in code_blocks: all_code_blocks.append({ "block": block, diff --git a/python/src/server/services/crawling/crawling_service.py b/python/src/server/services/crawling/crawling_service.py index 5b5d4304..e1b5159b 100644 --- a/python/src/server/services/crawling/crawling_service.py +++ b/python/src/server/services/crawling/crawling_service.py @@ -304,10 +304,12 @@ class CrawlingService: url = str(request.get("url", "")) safe_logfire_info(f"Starting async crawl orchestration | url={url} | task_id={task_id}") - # Extract source_id from the original URL - parsed_original_url = urlparse(url) - original_source_id = parsed_original_url.netloc or parsed_original_url.path - safe_logfire_info(f"Using source_id '{original_source_id}' from original URL '{url}'") + # Generate unique source_id and display name from the original URL + original_source_id = self.url_handler.generate_unique_source_id(url) + source_display_name = self.url_handler.extract_display_name(url) + safe_logfire_info( + f"Generated unique source_id '{original_source_id}' and display name '{source_display_name}' from URL '{url}'" + ) # Helper to update progress with mapper async def update_mapped_progress( @@ -386,6 +388,8 @@ class CrawlingService: original_source_id, doc_storage_callback, self._check_cancellation, + source_url=url, + source_display_name=source_display_name, ) # Check for cancellation after document storage @@ -410,6 +414,7 @@ class CrawlingService: code_examples_count = await self.doc_storage_ops.extract_and_store_code_examples( crawl_results, storage_results["url_to_full_document"], + storage_results["source_id"], code_progress_callback, 85, 95, @@ -558,7 +563,7 @@ class CrawlingService: max_depth = request.get("max_depth", 1) # Let the strategy handle concurrency from settings # This will use CRAWL_MAX_CONCURRENT from database (default: 10) - + crawl_results = await self.crawl_recursive_with_progress( [url], max_depth=max_depth, diff --git a/python/src/server/services/crawling/document_storage_operations.py b/python/src/server/services/crawling/document_storage_operations.py index 90624a20..c6d42bc5 100644 --- a/python/src/server/services/crawling/document_storage_operations.py +++ b/python/src/server/services/crawling/document_storage_operations.py @@ -4,17 +4,13 @@ Document Storage Operations Handles the storage and processing of crawled documents. Extracted from crawl_orchestration_service.py for better modularity. """ + import asyncio from typing import Dict, Any, List, Optional, Callable -from urllib.parse import urlparse from ...config.logfire_config import safe_logfire_info, safe_logfire_error from ..storage.storage_services import DocumentStorageService from ..storage.document_storage_service import add_documents_to_supabase -from ..storage.code_storage_service import ( - generate_code_summaries_batch, - add_code_examples_to_supabase -) from ..source_management_service import update_source_info, extract_source_summary from .code_extraction_service import CodeExtractionService @@ -23,18 +19,18 @@ class DocumentStorageOperations: """ Handles document storage operations for crawled content. 
""" - + def __init__(self, supabase_client): """ Initialize document storage operations. - + Args: supabase_client: The Supabase client for database operations """ self.supabase_client = supabase_client self.doc_storage_service = DocumentStorageService(supabase_client) self.code_extraction_service = CodeExtractionService(supabase_client) - + async def process_and_store_documents( self, crawl_results: List[Dict], @@ -42,11 +38,13 @@ class DocumentStorageOperations: crawl_type: str, original_source_id: str, progress_callback: Optional[Callable] = None, - cancellation_check: Optional[Callable] = None + cancellation_check: Optional[Callable] = None, + source_url: Optional[str] = None, + source_display_name: Optional[str] = None, ) -> Dict[str, Any]: """ Process crawled documents and store them in the database. - + Args: crawl_results: List of crawled documents request: The original crawl request @@ -54,13 +52,15 @@ class DocumentStorageOperations: original_source_id: The source ID for all documents progress_callback: Optional callback for progress updates cancellation_check: Optional function to check for cancellation - + source_url: Optional original URL that was crawled + source_display_name: Optional human-readable name for the source + Returns: Dict containing storage statistics and document mappings """ - # Initialize storage service for chunking - storage_service = DocumentStorageService(self.supabase_client) - + # Reuse initialized storage service for chunking + storage_service = self.doc_storage_service + # Prepare data for chunked storage all_urls = [] all_chunk_numbers = [] @@ -68,77 +68,85 @@ class DocumentStorageOperations: all_metadatas = [] source_word_counts = {} url_to_full_document = {} - + processed_docs = 0 + # Process and chunk each document for doc_index, doc in enumerate(crawl_results): # Check for cancellation during document processing if cancellation_check: cancellation_check() - - source_url = doc.get('url', '') - markdown_content = doc.get('markdown', '') - + + doc_url = doc.get("url", "") + markdown_content = doc.get("markdown", "") + if not markdown_content: continue - + + # Increment processed document count + processed_docs += 1 + # Store full document for code extraction context - url_to_full_document[source_url] = markdown_content - + url_to_full_document[doc_url] = markdown_content + # CHUNK THE CONTENT chunks = storage_service.smart_chunk_text(markdown_content, chunk_size=5000) - + # Use the original source_id for all documents source_id = original_source_id - safe_logfire_info(f"Using original source_id '{source_id}' for URL '{source_url}'") - + safe_logfire_info(f"Using original source_id '{source_id}' for URL '{doc_url}'") + # Process each chunk for i, chunk in enumerate(chunks): # Check for cancellation during chunk processing if cancellation_check and i % 10 == 0: # Check every 10 chunks cancellation_check() - - all_urls.append(source_url) + + all_urls.append(doc_url) all_chunk_numbers.append(i) all_contents.append(chunk) - + # Create metadata for each chunk word_count = len(chunk.split()) metadata = { - 'url': source_url, - 'title': doc.get('title', ''), - 'description': doc.get('description', ''), - 'source_id': source_id, - 'knowledge_type': request.get('knowledge_type', 'documentation'), - 'crawl_type': crawl_type, - 'word_count': word_count, - 'char_count': len(chunk), - 'chunk_index': i, - 'tags': request.get('tags', []) + "url": doc_url, + "title": doc.get("title", ""), + "description": doc.get("description", ""), + "source_id": source_id, + 
"knowledge_type": request.get("knowledge_type", "documentation"), + "crawl_type": crawl_type, + "word_count": word_count, + "char_count": len(chunk), + "chunk_index": i, + "tags": request.get("tags", []), } all_metadatas.append(metadata) - + # Accumulate word count source_word_counts[source_id] = source_word_counts.get(source_id, 0) + word_count - + # Yield control every 10 chunks to prevent event loop blocking if i > 0 and i % 10 == 0: await asyncio.sleep(0) - + # Yield control after processing each document if doc_index > 0 and doc_index % 5 == 0: await asyncio.sleep(0) - + # Create/update source record FIRST before storing documents if all_contents and all_metadatas: await self._create_source_records( - all_metadatas, all_contents, source_word_counts, request + all_metadatas, all_contents, source_word_counts, request, + source_url, source_display_name ) - + safe_logfire_info(f"url_to_full_document keys: {list(url_to_full_document.keys())[:5]}") - + # Log chunking results - safe_logfire_info(f"Document storage | documents={len(crawl_results)} | chunks={len(all_contents)} | avg_chunks_per_doc={len(all_contents)/len(crawl_results):.1f}") - + avg_chunks = (len(all_contents) / processed_docs) if processed_docs > 0 else 0.0 + safe_logfire_info( + f"Document storage | processed={processed_docs}/{len(crawl_results)} | chunks={len(all_contents)} | avg_chunks_per_doc={avg_chunks:.1f}" + ) + # Call add_documents_to_supabase with the correct parameters await add_documents_to_supabase( client=self.supabase_client, @@ -151,29 +159,31 @@ class DocumentStorageOperations: progress_callback=progress_callback, # Pass the callback for progress updates enable_parallel_batches=True, # Enable parallel processing provider=None, # Use configured provider - cancellation_check=cancellation_check # Pass cancellation check + cancellation_check=cancellation_check, # Pass cancellation check ) - + # Calculate actual chunk count chunk_count = len(all_contents) - + return { - 'chunk_count': chunk_count, - 'total_word_count': sum(source_word_counts.values()), - 'url_to_full_document': url_to_full_document, - 'source_id': original_source_id + "chunk_count": chunk_count, + "total_word_count": sum(source_word_counts.values()), + "url_to_full_document": url_to_full_document, + "source_id": original_source_id, } - + async def _create_source_records( self, all_metadatas: List[Dict], all_contents: List[str], source_word_counts: Dict[str, int], - request: Dict[str, Any] + request: Dict[str, Any], + source_url: Optional[str] = None, + source_display_name: Optional[str] = None, ): """ Create or update source records in the database. 
- + Args: all_metadatas: List of metadata for all chunks all_contents: List of all chunk contents @@ -184,121 +194,155 @@ class DocumentStorageOperations: unique_source_ids = set() source_id_contents = {} source_id_word_counts = {} - + for i, metadata in enumerate(all_metadatas): - source_id = metadata['source_id'] + source_id = metadata["source_id"] unique_source_ids.add(source_id) - + # Group content by source_id for better summaries if source_id not in source_id_contents: source_id_contents[source_id] = [] source_id_contents[source_id].append(all_contents[i]) - + # Track word counts per source_id if source_id not in source_id_word_counts: source_id_word_counts[source_id] = 0 - source_id_word_counts[source_id] += metadata.get('word_count', 0) - - safe_logfire_info(f"Found {len(unique_source_ids)} unique source_ids: {list(unique_source_ids)}") - + source_id_word_counts[source_id] += metadata.get("word_count", 0) + + safe_logfire_info( + f"Found {len(unique_source_ids)} unique source_ids: {list(unique_source_ids)}" + ) + # Create source records for ALL unique source_ids for source_id in unique_source_ids: # Get combined content for this specific source_id source_contents = source_id_contents[source_id] - combined_content = '' + combined_content = "" for chunk in source_contents[:3]: # First 3 chunks for this source if len(combined_content) + len(chunk) < 15000: - combined_content += ' ' + chunk + combined_content += " " + chunk else: break - - # Generate summary with fallback + + # Generate summary with fallback (run in thread to avoid blocking async loop) try: - summary = extract_source_summary(source_id, combined_content) + # Run synchronous extract_source_summary in a thread pool + summary = await asyncio.to_thread( + extract_source_summary, source_id, combined_content + ) except Exception as e: - safe_logfire_error(f"Failed to generate AI summary for '{source_id}': {str(e)}, using fallback") + safe_logfire_error( + f"Failed to generate AI summary for '{source_id}': {str(e)}, using fallback" + ) # Fallback to simple summary summary = f"Documentation from {source_id} - {len(source_contents)} pages crawled" - + # Update source info in database BEFORE storing documents - safe_logfire_info(f"About to create/update source record for '{source_id}' (word count: {source_id_word_counts[source_id]})") + safe_logfire_info( + f"About to create/update source record for '{source_id}' (word count: {source_id_word_counts[source_id]})" + ) try: - update_source_info( + # Run synchronous update_source_info in a thread pool + await asyncio.to_thread( + update_source_info, client=self.supabase_client, source_id=source_id, summary=summary, word_count=source_id_word_counts[source_id], content=combined_content, - knowledge_type=request.get('knowledge_type', 'technical'), - tags=request.get('tags', []), + knowledge_type=request.get("knowledge_type", "technical"), + tags=request.get("tags", []), update_frequency=0, # Set to 0 since we're using manual refresh - original_url=request.get('url') # Store the original crawl URL + original_url=request.get("url"), # Store the original crawl URL + source_url=source_url, + source_display_name=source_display_name, ) safe_logfire_info(f"Successfully created/updated source record for '{source_id}'") except Exception as e: - safe_logfire_error(f"Failed to create/update source record for '{source_id}': {str(e)}") + safe_logfire_error( + f"Failed to create/update source record for '{source_id}': {str(e)}" + ) # Try a simpler approach with minimal data try: 
safe_logfire_info(f"Attempting fallback source creation for '{source_id}'") - self.supabase_client.table('archon_sources').upsert({ - 'source_id': source_id, - 'title': source_id, # Use source_id as title fallback - 'summary': summary, - 'total_word_count': source_id_word_counts[source_id], - 'metadata': { - 'knowledge_type': request.get('knowledge_type', 'technical'), - 'tags': request.get('tags', []), - 'auto_generated': True, - 'fallback_creation': True, - 'original_url': request.get('url') - } - }).execute() + fallback_data = { + "source_id": source_id, + "title": source_id, # Use source_id as title fallback + "summary": summary, + "total_word_count": source_id_word_counts[source_id], + "metadata": { + "knowledge_type": request.get("knowledge_type", "technical"), + "tags": request.get("tags", []), + "auto_generated": True, + "fallback_creation": True, + "original_url": request.get("url"), + }, + } + + # Add new fields if provided + if source_url: + fallback_data["source_url"] = source_url + if source_display_name: + fallback_data["source_display_name"] = source_display_name + + self.supabase_client.table("archon_sources").upsert(fallback_data).execute() safe_logfire_info(f"Fallback source creation succeeded for '{source_id}'") except Exception as fallback_error: - safe_logfire_error(f"Both source creation attempts failed for '{source_id}': {str(fallback_error)}") - raise Exception(f"Unable to create source record for '{source_id}'. This will cause foreign key violations. Error: {str(fallback_error)}") - + safe_logfire_error( + f"Both source creation attempts failed for '{source_id}': {str(fallback_error)}" + ) + raise Exception( + f"Unable to create source record for '{source_id}'. This will cause foreign key violations. Error: {str(fallback_error)}" + ) + # Verify ALL source records exist before proceeding with document storage if unique_source_ids: for source_id in unique_source_ids: try: - source_check = self.supabase_client.table('archon_sources').select('source_id').eq('source_id', source_id).execute() + source_check = ( + self.supabase_client.table("archon_sources") + .select("source_id") + .eq("source_id", source_id) + .execute() + ) if not source_check.data: - raise Exception(f"Source record verification failed - '{source_id}' does not exist in sources table") + raise Exception( + f"Source record verification failed - '{source_id}' does not exist in sources table" + ) safe_logfire_info(f"Source record verified for '{source_id}'") except Exception as e: safe_logfire_error(f"Source verification failed for '{source_id}': {str(e)}") raise - - safe_logfire_info(f"All {len(unique_source_ids)} source records verified - proceeding with document storage") - + + safe_logfire_info( + f"All {len(unique_source_ids)} source records verified - proceeding with document storage" + ) + async def extract_and_store_code_examples( self, crawl_results: List[Dict], url_to_full_document: Dict[str, str], + source_id: str, progress_callback: Optional[Callable] = None, start_progress: int = 85, - end_progress: int = 95 + end_progress: int = 95, ) -> int: """ Extract code examples from crawled documents and store them. 
- + Args: crawl_results: List of crawled documents url_to_full_document: Mapping of URLs to full document content + source_id: The unique source_id for all documents progress_callback: Optional callback for progress updates start_progress: Starting progress percentage end_progress: Ending progress percentage - + Returns: Number of code examples stored """ result = await self.code_extraction_service.extract_and_store_code_examples( - crawl_results, - url_to_full_document, - progress_callback, - start_progress, - end_progress + crawl_results, url_to_full_document, source_id, progress_callback, start_progress, end_progress ) - - return result \ No newline at end of file + + return result diff --git a/python/src/server/services/crawling/helpers/url_handler.py b/python/src/server/services/crawling/helpers/url_handler.py index d66a2a82..b116233f 100644 --- a/python/src/server/services/crawling/helpers/url_handler.py +++ b/python/src/server/services/crawling/helpers/url_handler.py @@ -3,6 +3,8 @@ URL Handler Helper Handles URL transformations and validations. """ + +import hashlib import re from urllib.parse import urlparse @@ -13,49 +15,49 @@ logger = get_logger(__name__) class URLHandler: """Helper class for URL operations.""" - + @staticmethod def is_sitemap(url: str) -> bool: """ Check if a URL is a sitemap with error handling. - + Args: url: URL to check - + Returns: True if URL is a sitemap, False otherwise """ try: - return url.endswith('sitemap.xml') or 'sitemap' in urlparse(url).path + return url.endswith("sitemap.xml") or "sitemap" in urlparse(url).path except Exception as e: logger.warning(f"Error checking if URL is sitemap: {e}") return False - + @staticmethod def is_txt(url: str) -> bool: """ Check if a URL is a text file with error handling. - + Args: url: URL to check - + Returns: True if URL is a text file, False otherwise """ try: - return url.endswith('.txt') + return url.endswith(".txt") except Exception as e: logger.warning(f"Error checking if URL is text file: {e}") return False - + @staticmethod def is_binary_file(url: str) -> bool: """ Check if a URL points to a binary file that shouldn't be crawled. 
- + Args: url: URL to check - + Returns: True if URL is a binary file, False otherwise """ @@ -63,65 +65,338 @@ class URLHandler: # Remove query parameters and fragments for cleaner extension checking parsed = urlparse(url) path = parsed.path.lower() - + # Comprehensive list of binary and non-HTML file extensions binary_extensions = { # Archives - '.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz', '.tgz', + ".zip", + ".tar", + ".gz", + ".rar", + ".7z", + ".bz2", + ".xz", + ".tgz", # Executables and installers - '.exe', '.dmg', '.pkg', '.deb', '.rpm', '.msi', '.app', '.appimage', + ".exe", + ".dmg", + ".pkg", + ".deb", + ".rpm", + ".msi", + ".app", + ".appimage", # Documents (non-HTML) - '.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.odt', '.ods', + ".pdf", + ".doc", + ".docx", + ".xls", + ".xlsx", + ".ppt", + ".pptx", + ".odt", + ".ods", # Images - '.jpg', '.jpeg', '.png', '.gif', '.svg', '.webp', '.ico', '.bmp', '.tiff', + ".jpg", + ".jpeg", + ".png", + ".gif", + ".svg", + ".webp", + ".ico", + ".bmp", + ".tiff", # Audio/Video - '.mp3', '.mp4', '.avi', '.mov', '.wmv', '.flv', '.webm', '.mkv', '.wav', '.flac', + ".mp3", + ".mp4", + ".avi", + ".mov", + ".wmv", + ".flv", + ".webm", + ".mkv", + ".wav", + ".flac", # Data files - '.csv', '.sql', '.db', '.sqlite', + ".csv", + ".sql", + ".db", + ".sqlite", # Binary data - '.iso', '.img', '.bin', '.dat', + ".iso", + ".img", + ".bin", + ".dat", # Development files (usually not meant to be crawled as pages) - '.wasm', '.pyc', '.jar', '.war', '.class', '.dll', '.so', '.dylib' + ".wasm", + ".pyc", + ".jar", + ".war", + ".class", + ".dll", + ".so", + ".dylib", } - + # Check if the path ends with any binary extension for ext in binary_extensions: if path.endswith(ext): logger.debug(f"Skipping binary file: {url} (matched extension: {ext})") return True - + return False except Exception as e: logger.warning(f"Error checking if URL is binary file: {e}") # In case of error, don't skip the URL (safer to attempt crawl than miss content) return False - + @staticmethod def transform_github_url(url: str) -> str: """ Transform GitHub URLs to raw content URLs for better content extraction. - + Args: url: URL to transform - + Returns: Transformed URL (or original if not a GitHub file URL) """ # Pattern for GitHub file URLs - github_file_pattern = r'https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.+)' + github_file_pattern = r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.+)" match = re.match(github_file_pattern, url) if match: owner, repo, branch, path = match.groups() - raw_url = f'https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}' + raw_url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}" logger.info(f"Transformed GitHub file URL to raw: {url} -> {raw_url}") return raw_url - + # Pattern for GitHub directory URLs - github_dir_pattern = r'https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.+)' + github_dir_pattern = r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.+)" match = re.match(github_dir_pattern, url) if match: # For directories, we can't directly get raw content # Return original URL but log a warning - logger.warning(f"GitHub directory URL detected: {url} - consider using specific file URLs or GitHub API") + logger.warning( + f"GitHub directory URL detected: {url} - consider using specific file URLs or GitHub API" + ) + + return url + + @staticmethod + def generate_unique_source_id(url: str) -> str: + """ + Generate a unique source ID from URL using hash. 
+ + This creates a 16-character hash that is extremely unlikely to collide + for distinct canonical URLs, solving race condition issues when multiple crawls + target the same domain. - return url \ No newline at end of file + Uses 16-char SHA256 prefix (64 bits) which provides + ~18 quintillion unique values. Collision probability + is negligible for realistic usage (<1M sources). + + Args: + url: The URL to generate an ID for + + Returns: + A 16-character hexadecimal hash string + """ + try: + from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode + + # Canonicalize URL for consistent hashing + parsed = urlparse(url.strip()) + + # Normalize scheme and netloc to lowercase + scheme = (parsed.scheme or "").lower() + netloc = (parsed.netloc or "").lower() + + # Remove default ports + if netloc.endswith(":80") and scheme == "http": + netloc = netloc[:-3] + if netloc.endswith(":443") and scheme == "https": + netloc = netloc[:-4] + + # Normalize path (remove trailing slash except for root) + path = parsed.path or "/" + if path.endswith("/") and len(path) > 1: + path = path.rstrip("/") + + # Remove common tracking parameters and sort remaining + tracking_params = { + "utm_source", "utm_medium", "utm_campaign", "utm_term", "utm_content", + "gclid", "fbclid", "ref", "source" + } + query_items = [ + (k, v) for k, v in parse_qsl(parsed.query, keep_blank_values=True) + if k not in tracking_params + ] + query = urlencode(sorted(query_items)) + + # Reconstruct canonical URL (fragment is dropped) + canonical = urlunparse((scheme, netloc, path, "", query, "")) + + # Generate SHA256 hash and take first 16 characters + return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16] + + except Exception as e: + # Redact sensitive query params from error logs + try: + redacted = url.split("?", 1)[0] if "?" in url else url + except Exception: + redacted = "" + + logger.error(f"Error generating unique source ID for {redacted}: {e}", exc_info=True) + + # Fallback: use a hash of the error message + url to still get something unique + fallback = f"error_{redacted}_{str(e)}" + return hashlib.sha256(fallback.encode("utf-8")).hexdigest()[:16] + + @staticmethod + def extract_display_name(url: str) -> str: + """ + Extract a human-readable display name from URL. + + This creates user-friendly names for common source patterns + while falling back to the domain for unknown patterns. + + Args: + url: The URL to extract a display name from + + Returns: + A human-readable string suitable for UI display + """ + try: + parsed = urlparse(url) + domain = parsed.netloc.lower() + + # Remove www prefix for cleaner display + if domain.startswith("www."): + domain = domain[4:] + + # Handle empty domain (might be a file path or malformed URL) + if not domain: + if url.startswith("/"): + return f"Local: {url.split('/')[-1] if '/' in url else url}" + return url[:50] + "..." 
if len(url) > 50 else url + + path = parsed.path.strip("/") + + # Special handling for GitHub repositories and API + if "github.com" in domain: + # Check if it's an API endpoint + if domain.startswith("api."): + return "GitHub API" + + parts = path.split("/") + if len(parts) >= 2: + owner = parts[0] + repo = parts[1].replace(".git", "") # Remove .git extension if present + return f"GitHub - {owner}/{repo}" + elif len(parts) == 1 and parts[0]: + return f"GitHub - {parts[0]}" + return "GitHub" + + # Special handling for documentation sites + if domain.startswith("docs."): + # Extract the service name from docs.X.com/org + service_name = domain.replace("docs.", "").split(".")[0] + base_name = f"{service_name.title()}" if service_name else "Documentation" + + # Special handling for special files - preserve the filename + if path: + # Check for llms.txt files + if "llms" in path.lower() and path.endswith(".txt"): + return f"{base_name} - Llms.Txt" + # Check for sitemap files + elif "sitemap" in path.lower() and path.endswith(".xml"): + return f"{base_name} - Sitemap.Xml" + # Check for any other special .txt files + elif path.endswith(".txt"): + filename = path.split("/")[-1] if "/" in path else path + return f"{base_name} - {filename.title()}" + + return f"{base_name} Documentation" if service_name else "Documentation" + + # Handle readthedocs.io subdomains + if domain.endswith(".readthedocs.io"): + project = domain.replace(".readthedocs.io", "") + return f"{project.title()} Docs" + + # Handle common documentation patterns + doc_patterns = [ + ("fastapi.tiangolo.com", "FastAPI Documentation"), + ("pydantic.dev", "Pydantic Documentation"), + ("python.org", "Python Documentation"), + ("djangoproject.com", "Django Documentation"), + ("flask.palletsprojects.com", "Flask Documentation"), + ("numpy.org", "NumPy Documentation"), + ("pandas.pydata.org", "Pandas Documentation"), + ] + + for pattern, name in doc_patterns: + if pattern in domain: + # Add path context if available + if path and len(path) > 1: + # Get first meaningful path segment + path_segment = path.split("/")[0] if "/" in path else path + if path_segment and path_segment not in [ + "docs", + "doc", # Added "doc" to filter list + "documentation", + "api", + "en", + ]: + return f"{name} - {path_segment.title()}" + return name + + # For API endpoints + if "api." 
in domain or "/api" in path: + service = domain.replace("api.", "").split(".")[0] + return f"{service.title()} API" + + # Special handling for sitemap.xml and llms.txt on any site + if path: + if "sitemap" in path.lower() and path.endswith(".xml"): + # Get base domain name + display = domain + for tld in [".com", ".org", ".io", ".dev", ".net", ".ai", ".app"]: + if display.endswith(tld): + display = display[:-len(tld)] + break + display_parts = display.replace("-", " ").replace("_", " ").split(".") + formatted = " ".join(part.title() for part in display_parts) + return f"{formatted} - Sitemap.Xml" + elif "llms" in path.lower() and path.endswith(".txt"): + # Get base domain name + display = domain + for tld in [".com", ".org", ".io", ".dev", ".net", ".ai", ".app"]: + if display.endswith(tld): + display = display[:-len(tld)] + break + display_parts = display.replace("-", " ").replace("_", " ").split(".") + formatted = " ".join(part.title() for part in display_parts) + return f"{formatted} - Llms.Txt" + + # Default: Use domain with nice formatting + # Remove common TLDs for cleaner display + display = domain + for tld in [".com", ".org", ".io", ".dev", ".net", ".ai", ".app"]: + if display.endswith(tld): + display = display[: -len(tld)] + break + + # Capitalize first letter of each word + display_parts = display.replace("-", " ").replace("_", " ").split(".") + formatted = " ".join(part.title() for part in display_parts) + + # Add path context if it's meaningful + if path and len(path) > 1 and "/" not in path: + formatted += f" - {path.title()}" + + return formatted + + except Exception as e: + logger.warning(f"Error extracting display name for {url}: {e}, using URL") + # Fallback: return truncated URL + return url[:50] + "..." if len(url) > 50 else url diff --git a/python/src/server/services/source_management_service.py b/python/src/server/services/source_management_service.py index bd1a65d3..3e082a33 100644 --- a/python/src/server/services/source_management_service.py +++ b/python/src/server/services/source_management_service.py @@ -5,6 +5,7 @@ Handles source metadata, summaries, and management. Consolidates both utility functions and class-based service. """ +import os from typing import Any from supabase import Client @@ -145,6 +146,7 @@ def generate_source_title_and_metadata( knowledge_type: str = "technical", tags: list[str] | None = None, provider: str = None, + source_display_name: str | None = None, ) -> tuple[str, dict[str, Any]]: """ Generate a user-friendly title and metadata for a source based on its content. 
@@ -203,8 +205,11 @@ def generate_source_title_and_metadata( # Limit content for prompt sample_content = content[:3000] if len(content) > 3000 else content + + # Use display name if available for better context + source_context = source_display_name if source_display_name else source_id - prompt = f"""Based on this content from {source_id}, generate a concise, descriptive title (3-6 words) that captures what this source is about: + prompt = f"""Based on this content from {source_context}, generate a concise, descriptive title (3-6 words) that captures what this source is about: {sample_content} @@ -230,12 +235,12 @@ Provide only the title, nothing else.""" except Exception as e: search_logger.error(f"Error generating title for {source_id}: {e}") - # Build metadata - determine source_type from source_id pattern - source_type = "file" if source_id.startswith("file_") else "url" + # Build metadata - source_type will be determined by caller based on actual URL + # Default to "url" but this should be overridden by the caller metadata = { "knowledge_type": knowledge_type, "tags": tags or [], - "source_type": source_type, + "source_type": "url", # Default, should be overridden by caller based on actual URL "auto_generated": True } @@ -252,6 +257,8 @@ def update_source_info( tags: list[str] | None = None, update_frequency: int = 7, original_url: str | None = None, + source_url: str | None = None, + source_display_name: str | None = None, ): """ Update or insert source information in the sources table. @@ -279,7 +286,14 @@ def update_source_info( search_logger.info(f"Preserving existing title for {source_id}: {existing_title}") # Update metadata while preserving title - source_type = "file" if source_id.startswith("file_") else "url" + # Determine source_type based on source_url or original_url + if source_url and source_url.startswith("file://"): + source_type = "file" + elif original_url and original_url.startswith("file://"): + source_type = "file" + else: + source_type = "url" + metadata = { "knowledge_type": knowledge_type, "tags": tags or [], @@ -292,14 +306,22 @@ def update_source_info( metadata["original_url"] = original_url # Update existing source (preserving title) + update_data = { + "summary": summary, + "total_word_count": word_count, + "metadata": metadata, + "updated_at": "now()", + } + + # Add new fields if provided + if source_url: + update_data["source_url"] = source_url + if source_display_name: + update_data["source_display_name"] = source_display_name + result = ( client.table("archon_sources") - .update({ - "summary": summary, - "total_word_count": word_count, - "metadata": metadata, - "updated_at": "now()", - }) + .update(update_data) .eq("source_id", source_id) .execute() ) @@ -308,10 +330,38 @@ def update_source_info( f"Updated source {source_id} while preserving title: {existing_title}" ) else: - # New source - generate title and metadata - title, metadata = generate_source_title_and_metadata( - source_id, content, knowledge_type, tags - ) + # New source - use display name as title if available, otherwise generate + if source_display_name: + # Use the display name directly as the title (truncated to prevent DB issues) + title = source_display_name[:100].strip() + + # Determine source_type based on source_url or original_url + if source_url and source_url.startswith("file://"): + source_type = "file" + elif original_url and original_url.startswith("file://"): + source_type = "file" + else: + source_type = "url" + + metadata = { + "knowledge_type": knowledge_type, + "tags": 
tags or [], + "source_type": source_type, + "auto_generated": False, + } + else: + # Fallback to AI generation only if no display name + title, metadata = generate_source_title_and_metadata( + source_id, content, knowledge_type, tags, None, source_display_name + ) + + # Override the source_type from AI with actual URL-based determination + if source_url and source_url.startswith("file://"): + metadata["source_type"] = "file" + elif original_url and original_url.startswith("file://"): + metadata["source_type"] = "file" + else: + metadata["source_type"] = "url" # Add update_frequency and original_url to metadata metadata["update_frequency"] = update_frequency @@ -319,15 +369,23 @@ def update_source_info( metadata["original_url"] = original_url search_logger.info(f"Creating new source {source_id} with knowledge_type={knowledge_type}") - # Insert new source - client.table("archon_sources").insert({ + # Use upsert to avoid race conditions with concurrent crawls + upsert_data = { "source_id": source_id, "title": title, "summary": summary, "total_word_count": word_count, "metadata": metadata, - }).execute() - search_logger.info(f"Created new source {source_id} with title: {title}") + } + + # Add new fields if provided + if source_url: + upsert_data["source_url"] = source_url + if source_display_name: + upsert_data["source_display_name"] = source_display_name + + client.table("archon_sources").upsert(upsert_data).execute() + search_logger.info(f"Created/updated source {source_id} with title: {title}") except Exception as e: search_logger.error(f"Error updating source {source_id}: {e}") diff --git a/python/tests/test_async_source_summary.py b/python/tests/test_async_source_summary.py new file mode 100644 index 00000000..1744a95d --- /dev/null +++ b/python/tests/test_async_source_summary.py @@ -0,0 +1,413 @@ +""" +Test async execution of extract_source_summary and update_source_info. + +This test ensures that synchronous functions extract_source_summary and +update_source_info are properly executed in thread pools to avoid blocking +the async event loop. 
+""" + +import asyncio +import time +from unittest.mock import Mock, AsyncMock, patch +import pytest + +from src.server.services.crawling.document_storage_operations import DocumentStorageOperations + + +class TestAsyncSourceSummary: + """Test that extract_source_summary and update_source_info don't block the async event loop.""" + + @pytest.mark.asyncio + async def test_extract_summary_runs_in_thread(self): + """Test that extract_source_summary is executed in a thread pool.""" + # Create mock supabase client + mock_supabase = Mock() + mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock() + + doc_storage = DocumentStorageOperations(mock_supabase) + + # Track when extract_source_summary is called + summary_call_times = [] + original_summary_result = "Test summary from AI" + + def slow_extract_summary(source_id, content): + """Simulate a slow synchronous function that would block the event loop.""" + summary_call_times.append(time.time()) + # Simulate a blocking operation (like an API call) + time.sleep(0.1) # This would block the event loop if not run in thread + return original_summary_result + + # Mock the storage service + doc_storage.doc_storage_service.smart_chunk_text = Mock( + return_value=["chunk1", "chunk2"] + ) + + with patch('src.server.services.crawling.document_storage_operations.extract_source_summary', + side_effect=slow_extract_summary): + with patch('src.server.services.crawling.document_storage_operations.update_source_info'): + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'): + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_error'): + # Create test metadata + all_metadatas = [ + {"source_id": "test123", "word_count": 100}, + {"source_id": "test123", "word_count": 150}, + ] + all_contents = ["chunk1", "chunk2"] + source_word_counts = {"test123": 250} + request = {"knowledge_type": "documentation"} + + # Track async execution + start_time = time.time() + + # This should not block despite the sleep in extract_summary + await doc_storage._create_source_records( + all_metadatas, + all_contents, + source_word_counts, + request, + "https://example.com", + "Example Site" + ) + + end_time = time.time() + + # Verify that extract_source_summary was called + assert len(summary_call_times) == 1, "extract_source_summary should be called once" + + # The async function should complete without blocking + # Even though extract_summary sleeps for 0.1s, the async function + # should not be blocked since it runs in a thread + total_time = end_time - start_time + + # We can't guarantee exact timing, but it should complete + # without throwing a timeout error + assert total_time < 1.0, "Should complete in reasonable time" + + @pytest.mark.asyncio + async def test_extract_summary_error_handling(self): + """Test that errors in extract_source_summary are handled correctly.""" + mock_supabase = Mock() + mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock() + + doc_storage = DocumentStorageOperations(mock_supabase) + + # Mock to raise an exception + def failing_extract_summary(source_id, content): + raise RuntimeError("AI service unavailable") + + doc_storage.doc_storage_service.smart_chunk_text = Mock( + return_value=["chunk1"] + ) + + error_messages = [] + + with patch('src.server.services.crawling.document_storage_operations.extract_source_summary', + side_effect=failing_extract_summary): + with 
patch('src.server.services.crawling.document_storage_operations.update_source_info') as mock_update: + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'): + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_error') as mock_error: + mock_error.side_effect = lambda msg: error_messages.append(msg) + + all_metadatas = [{"source_id": "test456", "word_count": 100}] + all_contents = ["chunk1"] + source_word_counts = {"test456": 100} + request = {} + + await doc_storage._create_source_records( + all_metadatas, + all_contents, + source_word_counts, + request, + None, + None + ) + + # Verify error was logged + assert len(error_messages) == 1 + assert "Failed to generate AI summary" in error_messages[0] + assert "AI service unavailable" in error_messages[0] + + # Verify fallback summary was used + mock_update.assert_called_once() + call_args = mock_update.call_args + assert call_args.kwargs["summary"] == "Documentation from test456 - 1 pages crawled" + + @pytest.mark.asyncio + async def test_multiple_sources_concurrent_summaries(self): + """Test that multiple source summaries are generated concurrently.""" + mock_supabase = Mock() + mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock() + + doc_storage = DocumentStorageOperations(mock_supabase) + + # Track concurrent executions + execution_order = [] + + def track_extract_summary(source_id, content): + execution_order.append(f"start_{source_id}") + time.sleep(0.05) # Simulate work + execution_order.append(f"end_{source_id}") + return f"Summary for {source_id}" + + doc_storage.doc_storage_service.smart_chunk_text = Mock( + return_value=["chunk"] + ) + + with patch('src.server.services.crawling.document_storage_operations.extract_source_summary', + side_effect=track_extract_summary): + with patch('src.server.services.crawling.document_storage_operations.update_source_info'): + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'): + # Create metadata for multiple sources + all_metadatas = [ + {"source_id": "source1", "word_count": 100}, + {"source_id": "source2", "word_count": 150}, + {"source_id": "source3", "word_count": 200}, + ] + all_contents = ["chunk1", "chunk2", "chunk3"] + source_word_counts = { + "source1": 100, + "source2": 150, + "source3": 200, + } + request = {} + + await doc_storage._create_source_records( + all_metadatas, + all_contents, + source_word_counts, + request, + None, + None + ) + + # With threading, sources are processed sequentially in the loop + # but the extract_summary calls happen in threads + assert len(execution_order) == 6 # 3 sources * 2 events each + + # Verify all sources were processed + processed_sources = set() + for event in execution_order: + if event.startswith("start_"): + processed_sources.add(event.replace("start_", "")) + + assert processed_sources == {"source1", "source2", "source3"} + + @pytest.mark.asyncio + async def test_thread_safety_with_variables(self): + """Test that variables are properly passed to thread execution.""" + mock_supabase = Mock() + mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock() + + doc_storage = DocumentStorageOperations(mock_supabase) + + # Track what gets passed to extract_summary + captured_calls = [] + + def capture_extract_summary(source_id, content): + captured_calls.append({ + "source_id": source_id, + "content_len": len(content), + "content_preview": content[:50] if content else "" + }) + return f"Summary 
for {source_id}" + + doc_storage.doc_storage_service.smart_chunk_text = Mock( + return_value=["This is chunk one with some content", + "This is chunk two with more content"] + ) + + with patch('src.server.services.crawling.document_storage_operations.extract_source_summary', + side_effect=capture_extract_summary): + with patch('src.server.services.crawling.document_storage_operations.update_source_info'): + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'): + all_metadatas = [ + {"source_id": "test789", "word_count": 100}, + {"source_id": "test789", "word_count": 150}, + ] + all_contents = [ + "This is chunk one with some content", + "This is chunk two with more content" + ] + source_word_counts = {"test789": 250} + request = {} + + await doc_storage._create_source_records( + all_metadatas, + all_contents, + source_word_counts, + request, + None, + None + ) + + # Verify the correct values were passed to the thread + assert len(captured_calls) == 1 + call = captured_calls[0] + assert call["source_id"] == "test789" + assert call["content_len"] > 0 + # Combined content should start with space + first chunk + assert "This is chunk one" in call["content_preview"] + + @pytest.mark.asyncio + async def test_update_source_info_runs_in_thread(self): + """Test that update_source_info is executed in a thread pool.""" + mock_supabase = Mock() + mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock() + + doc_storage = DocumentStorageOperations(mock_supabase) + + # Track when update_source_info is called + update_call_times = [] + + def slow_update_source_info(**kwargs): + """Simulate a slow synchronous database operation.""" + update_call_times.append(time.time()) + # Simulate a blocking database operation + time.sleep(0.1) # This would block the event loop if not run in thread + return None # update_source_info doesn't return anything + + doc_storage.doc_storage_service.smart_chunk_text = Mock( + return_value=["chunk1"] + ) + + with patch('src.server.services.crawling.document_storage_operations.extract_source_summary', + return_value="Test summary"): + with patch('src.server.services.crawling.document_storage_operations.update_source_info', + side_effect=slow_update_source_info): + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'): + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_error'): + all_metadatas = [{"source_id": "test_update", "word_count": 100}] + all_contents = ["chunk1"] + source_word_counts = {"test_update": 100} + request = {"knowledge_type": "documentation", "tags": ["test"]} + + start_time = time.time() + + # This should not block despite the sleep in update_source_info + await doc_storage._create_source_records( + all_metadatas, + all_contents, + source_word_counts, + request, + "https://example.com", + "Example Site" + ) + + end_time = time.time() + + # Verify that update_source_info was called + assert len(update_call_times) == 1, "update_source_info should be called once" + + # The async function should complete without blocking + total_time = end_time - start_time + assert total_time < 1.0, "Should complete in reasonable time" + + @pytest.mark.asyncio + async def test_update_source_info_error_handling(self): + """Test that errors in update_source_info trigger fallback correctly.""" + mock_supabase = Mock() + mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock() + + doc_storage = 
DocumentStorageOperations(mock_supabase) + + # Mock to raise an exception + def failing_update_source_info(**kwargs): + raise RuntimeError("Database connection failed") + + doc_storage.doc_storage_service.smart_chunk_text = Mock( + return_value=["chunk1"] + ) + + error_messages = [] + fallback_called = False + + def track_fallback_upsert(data): + nonlocal fallback_called + fallback_called = True + return Mock(execute=Mock()) + + mock_supabase.table.return_value.upsert.side_effect = track_fallback_upsert + + with patch('src.server.services.crawling.document_storage_operations.extract_source_summary', + return_value="Test summary"): + with patch('src.server.services.crawling.document_storage_operations.update_source_info', + side_effect=failing_update_source_info): + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'): + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_error') as mock_error: + mock_error.side_effect = lambda msg: error_messages.append(msg) + + all_metadatas = [{"source_id": "test_fail", "word_count": 100}] + all_contents = ["chunk1"] + source_word_counts = {"test_fail": 100} + request = {"knowledge_type": "technical", "tags": ["test"]} + + await doc_storage._create_source_records( + all_metadatas, + all_contents, + source_word_counts, + request, + "https://example.com", + "Example Site" + ) + + # Verify error was logged + assert any("Failed to create/update source record" in msg for msg in error_messages) + assert any("Database connection failed" in msg for msg in error_messages) + + # Verify fallback was attempted + assert fallback_called, "Fallback upsert should be called" + + @pytest.mark.asyncio + async def test_update_source_info_preserves_kwargs(self): + """Test that all kwargs are properly passed to update_source_info in thread.""" + mock_supabase = Mock() + mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock() + + doc_storage = DocumentStorageOperations(mock_supabase) + + # Track what gets passed to update_source_info + captured_kwargs = {} + + def capture_update_source_info(**kwargs): + captured_kwargs.update(kwargs) + return None + + doc_storage.doc_storage_service.smart_chunk_text = Mock( + return_value=["chunk content"] + ) + + with patch('src.server.services.crawling.document_storage_operations.extract_source_summary', + return_value="Generated summary"): + with patch('src.server.services.crawling.document_storage_operations.update_source_info', + side_effect=capture_update_source_info): + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'): + all_metadatas = [{"source_id": "test_kwargs", "word_count": 250}] + all_contents = ["chunk content"] + source_word_counts = {"test_kwargs": 250} + request = { + "knowledge_type": "api_reference", + "tags": ["api", "docs"], + "url": "https://original.url/crawl" + } + + await doc_storage._create_source_records( + all_metadatas, + all_contents, + source_word_counts, + request, + "https://source.url", + "Source Display Name" + ) + + # Verify all kwargs were passed correctly + assert captured_kwargs["client"] == mock_supabase + assert captured_kwargs["source_id"] == "test_kwargs" + assert captured_kwargs["summary"] == "Generated summary" + assert captured_kwargs["word_count"] == 250 + assert "chunk content" in captured_kwargs["content"] + assert captured_kwargs["knowledge_type"] == "api_reference" + assert captured_kwargs["tags"] == ["api", "docs"] + assert captured_kwargs["update_frequency"] 
== 0 + assert captured_kwargs["original_url"] == "https://original.url/crawl" + assert captured_kwargs["source_url"] == "https://source.url" + assert captured_kwargs["source_display_name"] == "Source Display Name" \ No newline at end of file diff --git a/python/tests/test_code_extraction_source_id.py b/python/tests/test_code_extraction_source_id.py new file mode 100644 index 00000000..5ae87b9f --- /dev/null +++ b/python/tests/test_code_extraction_source_id.py @@ -0,0 +1,184 @@ +""" +Test that code extraction uses the correct source_id. + +This test ensures that the fix for using hash-based source_ids +instead of domain-based source_ids works correctly. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch, MagicMock +from src.server.services.crawling.code_extraction_service import CodeExtractionService +from src.server.services.crawling.document_storage_operations import DocumentStorageOperations + + +class TestCodeExtractionSourceId: + """Test that code extraction properly uses the provided source_id.""" + + @pytest.mark.asyncio + async def test_code_extraction_uses_provided_source_id(self): + """Test that code extraction uses the hash-based source_id, not domain.""" + # Create mock supabase client + mock_supabase = Mock() + mock_supabase.table.return_value.select.return_value.eq.return_value.execute.return_value.data = [] + + # Create service instance + code_service = CodeExtractionService(mock_supabase) + + # Track what gets passed to the internal extraction method + extracted_blocks = [] + + async def mock_extract_blocks(crawl_results, source_id, progress_callback=None, start=0, end=100): + # Simulate finding code blocks and verify source_id is passed correctly + for doc in crawl_results: + extracted_blocks.append({ + "block": {"code": "print('hello')", "language": "python"}, + "source_url": doc["url"], + "source_id": source_id # This should be the provided source_id + }) + return extracted_blocks + + code_service._extract_code_blocks_from_documents = mock_extract_blocks + code_service._generate_code_summaries = AsyncMock(return_value=[{"summary": "Test code"}]) + # Resolve at call time: with return_value the extracted_blocks list would be read while still empty + code_service._prepare_code_examples_for_storage = Mock( + side_effect=lambda *args, **kwargs: [{"source_id": block["source_id"]} for block in extracted_blocks] + ) + code_service._store_code_examples = AsyncMock(return_value=1) + + # Test data + crawl_results = [ + { + "url": "https://docs.mem0.ai/example", + "markdown": "```python\nprint('hello')\n```" + } + ] + + url_to_full_document = { + "https://docs.mem0.ai/example": "Full content with code" + } + + # The correct hash-based source_id + correct_source_id = "393224e227ba92eb" + + # Call the method with the correct source_id + result = await code_service.extract_and_store_code_examples( + crawl_results, + url_to_full_document, + correct_source_id, + None, + 0, + 100 + ) + + # Verify that extracted blocks use the correct source_id + assert len(extracted_blocks) > 0, "Should have extracted at least one code block" + + for block in extracted_blocks: + # Check that it's using the hash-based source_id, not the domain + assert block["source_id"] == correct_source_id, \ + f"Should use hash-based source_id '{correct_source_id}', not domain" + assert block["source_id"] != "docs.mem0.ai", \ + "Should NOT use domain-based source_id" + + @pytest.mark.asyncio + async def test_document_storage_passes_source_id(self): + """Test that DocumentStorageOperations passes source_id to code extraction.""" + # Create mock supabase client + mock_supabase = Mock() + + # Create 
DocumentStorageOperations instance + doc_storage = DocumentStorageOperations(mock_supabase) + + # Mock the code extraction service + mock_extract = AsyncMock(return_value=5) + doc_storage.code_extraction_service.extract_and_store_code_examples = mock_extract + + # Test data + crawl_results = [{"url": "https://example.com", "markdown": "test"}] + url_to_full_document = {"https://example.com": "test content"} + source_id = "abc123def456" + + # Call the wrapper method + result = await doc_storage.extract_and_store_code_examples( + crawl_results, + url_to_full_document, + source_id, + None, + 0, + 100 + ) + + # Verify the correct source_id was passed + mock_extract.assert_called_once_with( + crawl_results, + url_to_full_document, + source_id, # This should be the third argument + None, + 0, + 100 + ) + assert result == 5 + + @pytest.mark.asyncio + async def test_no_domain_extraction_from_url(self): + """Test that we're NOT extracting domain from URL anymore.""" + mock_supabase = Mock() + mock_supabase.table.return_value.select.return_value.eq.return_value.execute.return_value.data = [] + + code_service = CodeExtractionService(mock_supabase) + + # Patch internal methods + code_service._get_setting = AsyncMock(return_value=True) + + # Create a mock that will track what source_id is used + source_ids_seen = [] + + original_extract = code_service._extract_code_blocks_from_documents + async def track_source_id(crawl_results, source_id, progress_callback=None, start=0, end=100): + source_ids_seen.append(source_id) + return [] # Return empty list to skip further processing + + code_service._extract_code_blocks_from_documents = track_source_id + + # Test with various URLs that would produce different domains + test_cases = [ + ("https://github.com/example/repo", "github123abc"), + ("https://docs.python.org/guide", "python456def"), + ("https://api.openai.com/v1", "openai789ghi"), + ] + + for url, expected_source_id in test_cases: + source_ids_seen.clear() + + crawl_results = [{"url": url, "markdown": "# Test"}] + url_to_full_document = {url: "Full content"} + + await code_service.extract_and_store_code_examples( + crawl_results, + url_to_full_document, + expected_source_id, + None, + 0, + 100 + ) + + # Verify the provided source_id was used + assert len(source_ids_seen) == 1 + assert source_ids_seen[0] == expected_source_id + # Verify it's NOT the domain + assert "github.com" not in source_ids_seen[0] + assert "python.org" not in source_ids_seen[0] + assert "openai.com" not in source_ids_seen[0] + + def test_urlparse_not_imported(self): + """Test that urlparse is not imported in code_extraction_service.""" + import src.server.services.crawling.code_extraction_service as module + + # Check that urlparse is not in the module's namespace + assert not hasattr(module, 'urlparse'), \ + "urlparse should not be imported in code_extraction_service" + + # Check the module's actual imports + import inspect + source = inspect.getsource(module) + assert "from urllib.parse import urlparse" not in source, \ + "Should not import urlparse since we don't extract domain from URL anymore" \ No newline at end of file diff --git a/python/tests/test_document_storage_metrics.py b/python/tests/test_document_storage_metrics.py new file mode 100644 index 00000000..5ab8c9fc --- /dev/null +++ b/python/tests/test_document_storage_metrics.py @@ -0,0 +1,205 @@ +""" +Test document storage metrics calculation. + +This test ensures that avg_chunks_per_doc is calculated correctly +and handles edge cases like empty documents. 
+""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch +from src.server.services.crawling.document_storage_operations import DocumentStorageOperations + + +class TestDocumentStorageMetrics: + """Test metrics calculation in document storage operations.""" + + @pytest.mark.asyncio + async def test_avg_chunks_calculation_with_empty_docs(self): + """Test that avg_chunks_per_doc handles empty documents correctly.""" + # Create mock supabase client + mock_supabase = Mock() + doc_storage = DocumentStorageOperations(mock_supabase) + + # Mock the storage service + doc_storage.doc_storage_service.smart_chunk_text = Mock( + side_effect=lambda text, chunk_size: ["chunk1", "chunk2"] if text else [] + ) + + # Mock internal methods + doc_storage._create_source_records = AsyncMock() + + # Track what gets logged + logged_messages = [] + + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info') as mock_log: + mock_log.side_effect = lambda msg: logged_messages.append(msg) + + with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'): + # Test data with mix of empty and non-empty documents + crawl_results = [ + {"url": "https://example.com/page1", "markdown": "Content 1"}, + {"url": "https://example.com/page2", "markdown": ""}, # Empty + {"url": "https://example.com/page3", "markdown": "Content 3"}, + {"url": "https://example.com/page4", "markdown": ""}, # Empty + {"url": "https://example.com/page5", "markdown": "Content 5"}, + ] + + result = await doc_storage.process_and_store_documents( + crawl_results=crawl_results, + request={}, + crawl_type="test", + original_source_id="test123", + source_url="https://example.com", + source_display_name="Example" + ) + + # Find the metrics log message + metrics_log = None + for msg in logged_messages: + if "Document storage | processed=" in msg: + metrics_log = msg + break + + assert metrics_log is not None, "Should log metrics" + + # Verify metrics are correct + # 3 documents processed (non-empty), 5 total, 6 chunks (2 per doc), avg = 2.0 + assert "processed=3/5" in metrics_log, "Should show 3 processed out of 5 total" + assert "chunks=6" in metrics_log, "Should have 6 chunks total" + assert "avg_chunks_per_doc=2.0" in metrics_log, "Average should be 2.0 (6/3)" + + @pytest.mark.asyncio + async def test_avg_chunks_all_empty_docs(self): + """Test that avg_chunks_per_doc handles all empty documents without division by zero.""" + mock_supabase = Mock() + doc_storage = DocumentStorageOperations(mock_supabase) + + # Mock the storage service + doc_storage.doc_storage_service.smart_chunk_text = Mock(return_value=[]) + doc_storage._create_source_records = AsyncMock() + + logged_messages = [] + + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info') as mock_log: + mock_log.side_effect = lambda msg: logged_messages.append(msg) + + with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'): + # All documents are empty + crawl_results = [ + {"url": "https://example.com/page1", "markdown": ""}, + {"url": "https://example.com/page2", "markdown": ""}, + {"url": "https://example.com/page3", "markdown": ""}, + ] + + result = await doc_storage.process_and_store_documents( + crawl_results=crawl_results, + request={}, + crawl_type="test", + original_source_id="test456", + source_url="https://example.com", + source_display_name="Example" + ) + + # Find the metrics log + metrics_log = None + for msg in logged_messages: + if "Document 
storage | processed=" in msg: + metrics_log = msg + break + + assert metrics_log is not None, "Should log metrics even with no processed docs" + + # Should show 0 processed, 0 chunks, 0.0 average (no division by zero) + assert "processed=0/3" in metrics_log, "Should show 0 processed out of 3 total" + assert "chunks=0" in metrics_log, "Should have 0 chunks" + assert "avg_chunks_per_doc=0.0" in metrics_log, "Average should be 0.0 (no division by zero)" + + @pytest.mark.asyncio + async def test_avg_chunks_single_doc(self): + """Test avg_chunks_per_doc with a single document.""" + mock_supabase = Mock() + doc_storage = DocumentStorageOperations(mock_supabase) + + # Mock to return 5 chunks for content + doc_storage.doc_storage_service.smart_chunk_text = Mock( + return_value=["chunk1", "chunk2", "chunk3", "chunk4", "chunk5"] + ) + doc_storage._create_source_records = AsyncMock() + + logged_messages = [] + + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info') as mock_log: + mock_log.side_effect = lambda msg: logged_messages.append(msg) + + with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'): + crawl_results = [ + {"url": "https://example.com/page", "markdown": "Long content here..."}, + ] + + result = await doc_storage.process_and_store_documents( + crawl_results=crawl_results, + request={}, + crawl_type="test", + original_source_id="test789", + source_url="https://example.com", + source_display_name="Example" + ) + + # Find metrics log + metrics_log = None + for msg in logged_messages: + if "Document storage | processed=" in msg: + metrics_log = msg + break + + assert metrics_log is not None + assert "processed=1/1" in metrics_log, "Should show 1 processed out of 1 total" + assert "chunks=5" in metrics_log, "Should have 5 chunks" + assert "avg_chunks_per_doc=5.0" in metrics_log, "Average should be 5.0" + + @pytest.mark.asyncio + async def test_processed_count_accuracy(self): + """Test that processed_docs count is accurate.""" + mock_supabase = Mock() + doc_storage = DocumentStorageOperations(mock_supabase) + + # Track which documents are chunked + chunked_urls = [] + + def mock_chunk(text, chunk_size): + if text: + return ["chunk"] + return [] + + doc_storage.doc_storage_service.smart_chunk_text = Mock(side_effect=mock_chunk) + doc_storage._create_source_records = AsyncMock() + + with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'): + with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'): + # Mix of documents with various content states + crawl_results = [ + {"url": "https://example.com/1", "markdown": "Content"}, + {"url": "https://example.com/2", "markdown": ""}, # Empty markdown + {"url": "https://example.com/3", "markdown": None}, # None markdown + {"url": "https://example.com/4", "markdown": "More content"}, + {"url": "https://example.com/5"}, # Missing markdown key + {"url": "https://example.com/6", "markdown": " "}, # Whitespace (counts as content) + ] + + result = await doc_storage.process_and_store_documents( + crawl_results=crawl_results, + request={}, + crawl_type="test", + original_source_id="test999", + source_url="https://example.com", + source_display_name="Example" + ) + + # Should process documents 1, 4, and 6 (has content including whitespace) + assert result["chunk_count"] == 3, "Should have 3 chunks (one per processed doc)" + + # Check url_to_full_document only has processed docs + assert 
len(result["url_to_full_document"]) == 3 + assert "https://example.com/1" in result["url_to_full_document"] + assert "https://example.com/4" in result["url_to_full_document"] + assert "https://example.com/6" in result["url_to_full_document"] \ No newline at end of file diff --git a/python/tests/test_source_id_refactor.py b/python/tests/test_source_id_refactor.py new file mode 100644 index 00000000..3ff796f9 --- /dev/null +++ b/python/tests/test_source_id_refactor.py @@ -0,0 +1,357 @@ +""" +Test Suite for Source ID Architecture Refactor + +Tests the new unique source ID generation and display name extraction +to ensure the race condition fix works correctly. +""" + +import time +from concurrent.futures import ThreadPoolExecutor + +# Import the URLHandler class +from src.server.services.crawling.helpers.url_handler import URLHandler + + +class TestSourceIDGeneration: + """Test the unique source ID generation.""" + + def test_unique_id_generation_basic(self): + """Test basic unique ID generation.""" + handler = URLHandler() + + # Test various URLs + test_urls = [ + "https://github.com/microsoft/typescript", + "https://github.com/facebook/react", + "https://docs.python.org/3/", + "https://fastapi.tiangolo.com/", + "https://pydantic.dev/", + ] + + source_ids = [] + for url in test_urls: + source_id = handler.generate_unique_source_id(url) + source_ids.append(source_id) + + # Check that ID is a 16-character hex string + assert len(source_id) == 16, f"ID should be 16 chars, got {len(source_id)}" + assert all(c in '0123456789abcdef' for c in source_id), f"ID should be hex: {source_id}" + + # All IDs should be unique + assert len(set(source_ids)) == len(source_ids), "All source IDs should be unique" + + def test_same_domain_different_ids(self): + """Test that same domain with different paths generates different IDs.""" + handler = URLHandler() + + # Multiple GitHub repos (same domain, different paths) + github_urls = [ + "https://github.com/owner1/repo1", + "https://github.com/owner1/repo2", + "https://github.com/owner2/repo1", + ] + + ids = [handler.generate_unique_source_id(url) for url in github_urls] + + # All should be unique despite same domain + assert len(set(ids)) == len(ids), "Same domain should generate different IDs for different URLs" + + def test_id_consistency(self): + """Test that the same URL always generates the same ID.""" + handler = URLHandler() + url = "https://github.com/microsoft/typescript" + + # Generate ID multiple times + ids = [handler.generate_unique_source_id(url) for _ in range(5)] + + # All should be identical + assert len(set(ids)) == 1, f"Same URL should always generate same ID, got: {set(ids)}" + assert ids[0] == ids[4], "First and last ID should match" + + def test_url_normalization(self): + """Test that URL normalization works correctly.""" + handler = URLHandler() + + # These should all generate the same ID (after normalization) + url_variations = [ + "https://github.com/Microsoft/TypeScript", + "HTTPS://GITHUB.COM/MICROSOFT/TYPESCRIPT", + "https://GitHub.com/Microsoft/TypeScript", + ] + + ids = [handler.generate_unique_source_id(url) for url in url_variations] + + # All normalized versions should generate the same ID + assert len(set(ids)) == 1, f"Normalized URLs should generate same ID, got: {set(ids)}" + + def test_concurrent_crawl_simulation(self): + """Simulate concurrent crawls to verify no race conditions.""" + handler = URLHandler() + + # URLs that would previously conflict + concurrent_urls = [ + "https://github.com/coleam00/archon", + 
"https://github.com/microsoft/typescript", + "https://github.com/facebook/react", + "https://github.com/vercel/next.js", + "https://github.com/vuejs/vue", + ] + + def generate_id(url): + """Simulate a crawl generating an ID.""" + time.sleep(0.001) # Simulate some processing time + return handler.generate_unique_source_id(url) + + # Run concurrent ID generation + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(generate_id, url) for url in concurrent_urls] + source_ids = [future.result() for future in futures] + + # All IDs should be unique + assert len(set(source_ids)) == len(source_ids), "Concurrent crawls should generate unique IDs" + + def test_error_handling(self): + """Test error handling for edge cases.""" + handler = URLHandler() + + # Test various edge cases + edge_cases = [ + "", # Empty string + "not-a-url", # Invalid URL + "https://", # Incomplete URL + None, # None should be handled gracefully in real code + ] + + for url in edge_cases: + if url is None: + continue # Skip None for this test + + # Should not raise exception + source_id = handler.generate_unique_source_id(url) + assert source_id is not None, f"Should generate ID even for edge case: {url}" + assert len(source_id) == 16, f"Edge case should still generate 16-char ID: {url}" + + +class TestDisplayNameExtraction: + """Test the human-readable display name extraction.""" + + def test_github_display_names(self): + """Test GitHub repository display name extraction.""" + handler = URLHandler() + + test_cases = [ + ("https://github.com/microsoft/typescript", "GitHub - microsoft/typescript"), + ("https://github.com/facebook/react", "GitHub - facebook/react"), + ("https://github.com/vercel/next.js", "GitHub - vercel/next.js"), + ("https://github.com/owner", "GitHub - owner"), + ("https://github.com/", "GitHub"), + ] + + for url, expected in test_cases: + display_name = handler.extract_display_name(url) + assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'" + + def test_documentation_display_names(self): + """Test documentation site display name extraction.""" + handler = URLHandler() + + test_cases = [ + ("https://docs.python.org/3/", "Python Documentation"), + ("https://docs.djangoproject.com/", "Djangoproject Documentation"), + ("https://fastapi.tiangolo.com/", "FastAPI Documentation"), + ("https://pydantic.dev/", "Pydantic Documentation"), + ("https://numpy.org/doc/", "NumPy Documentation"), + ("https://pandas.pydata.org/", "Pandas Documentation"), + ("https://project.readthedocs.io/", "Project Docs"), + ] + + for url, expected in test_cases: + display_name = handler.extract_display_name(url) + assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'" + + def test_api_display_names(self): + """Test API endpoint display name extraction.""" + handler = URLHandler() + + test_cases = [ + ("https://api.github.com/", "GitHub API"), + ("https://api.openai.com/v1/", "Openai API"), + ("https://example.com/api/v2/", "Example"), + ] + + for url, expected in test_cases: + display_name = handler.extract_display_name(url) + assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'" + + def test_generic_display_names(self): + """Test generic website display name extraction.""" + handler = URLHandler() + + test_cases = [ + ("https://example.com/", "Example"), + ("https://my-site.org/", "My Site"), + ("https://test_project.io/", "Test Project"), + 
("https://some.subdomain.example.com/", "Some Subdomain Example"), + ] + + for url, expected in test_cases: + display_name = handler.extract_display_name(url) + assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'" + + def test_edge_case_display_names(self): + """Test edge cases for display name extraction.""" + handler = URLHandler() + + # Edge cases + test_cases = [ + ("", ""), # Empty URL + ("not-a-url", "not-a-url"), # Invalid URL + ("/local/file/path", "Local: path"), # Local file path + ("https://", "https://"), # Incomplete URL + ] + + for url, expected_contains in test_cases: + display_name = handler.extract_display_name(url) + assert expected_contains in display_name or display_name == expected_contains, \ + f"Edge case {url} handling failed: {display_name}" + + def test_special_file_display_names(self): + """Test that special files like llms.txt and sitemap.xml are properly displayed.""" + handler = URLHandler() + + test_cases = [ + # llms.txt files + ("https://docs.mem0.ai/llms-full.txt", "Mem0 - Llms.Txt"), + ("https://example.com/llms.txt", "Example - Llms.Txt"), + ("https://api.example.com/llms.txt", "Example API"), # API takes precedence + + # sitemap.xml files + ("https://mem0.ai/sitemap.xml", "Mem0 - Sitemap.Xml"), + ("https://docs.example.com/sitemap.xml", "Example - Sitemap.Xml"), + ("https://example.org/sitemap.xml", "Example - Sitemap.Xml"), + + # Regular .txt files on docs sites + ("https://docs.example.com/readme.txt", "Example - Readme.Txt"), + + # Non-special files should not get special treatment + ("https://docs.example.com/guide", "Example Documentation"), + ("https://example.com/page.html", "Example - Page.Html"), # Path gets added for single file + ] + + for url, expected in test_cases: + display_name = handler.extract_display_name(url) + assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'" + + def test_git_extension_removal(self): + """Test that .git extension is removed from GitHub repos.""" + handler = URLHandler() + + test_cases = [ + ("https://github.com/owner/repo.git", "GitHub - owner/repo"), + ("https://github.com/owner/repo", "GitHub - owner/repo"), + ] + + for url, expected in test_cases: + display_name = handler.extract_display_name(url) + assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'" + + +class TestRaceConditionFix: + """Test that the race condition is actually fixed.""" + + def test_no_domain_conflicts(self): + """Test that multiple sources from same domain don't conflict.""" + handler = URLHandler() + + # These would all have source_id = "github.com" in the old system + github_urls = [ + "https://github.com/microsoft/typescript", + "https://github.com/microsoft/vscode", + "https://github.com/facebook/react", + "https://github.com/vercel/next.js", + "https://github.com/vuejs/vue", + ] + + source_ids = [handler.generate_unique_source_id(url) for url in github_urls] + + # All should be unique + assert len(set(source_ids)) == len(source_ids), \ + "Race condition not fixed: duplicate source IDs for same domain" + + # None should be just "github.com" + for source_id in source_ids: + assert source_id != "github.com", \ + "Source ID should not be just the domain" + + def test_hash_properties(self): + """Test that the hash has good properties.""" + handler = URLHandler() + + # Similar URLs should still generate very different hashes + url1 = "https://github.com/owner/repo1" + url2 = 
"https://github.com/owner/repo2" # Only differs by one character + + id1 = handler.generate_unique_source_id(url1) + id2 = handler.generate_unique_source_id(url2) + + # IDs should be completely different (good hash distribution) + matching_chars = sum(1 for a, b in zip(id1, id2) if a == b) + assert matching_chars < 8, \ + f"Similar URLs should generate very different hashes, {matching_chars}/16 chars match" + + +class TestIntegration: + """Integration tests for the complete source ID system.""" + + def test_full_source_creation_flow(self): + """Test the complete flow of creating a source with all fields.""" + handler = URLHandler() + url = "https://github.com/microsoft/typescript" + + # Generate all source fields + source_id = handler.generate_unique_source_id(url) + source_display_name = handler.extract_display_name(url) + source_url = url + + # Verify all fields are populated correctly + assert len(source_id) == 16, "Source ID should be 16 characters" + assert source_display_name == "GitHub - microsoft/typescript", \ + f"Display name incorrect: {source_display_name}" + assert source_url == url, "Source URL should match original" + + # Simulate database record + source_record = { + 'source_id': source_id, + 'source_url': source_url, + 'source_display_name': source_display_name, + 'title': None, # Generated later + 'summary': None, # Generated later + 'metadata': {} + } + + # Verify record structure + assert 'source_id' in source_record + assert 'source_url' in source_record + assert 'source_display_name' in source_record + + def test_backward_compatibility(self): + """Test that the system handles existing sources gracefully.""" + handler = URLHandler() + + # Simulate an existing source with old-style source_id + existing_source = { + 'source_id': 'github.com', # Old style - just domain + 'source_url': None, # Not populated in old system + 'source_display_name': None, # Not populated in old system + } + + # The migration should handle this by backfilling + # source_url and source_display_name with source_id value + migrated_source = { + 'source_id': 'github.com', + 'source_url': 'github.com', # Backfilled + 'source_display_name': 'github.com', # Backfilled + } + + assert migrated_source['source_url'] is not None + assert migrated_source['source_display_name'] is not None \ No newline at end of file diff --git a/python/tests/test_source_race_condition.py b/python/tests/test_source_race_condition.py new file mode 100644 index 00000000..0905e9fb --- /dev/null +++ b/python/tests/test_source_race_condition.py @@ -0,0 +1,269 @@ +""" +Test race condition handling in source creation. + +This test ensures that concurrent source creation attempts +don't fail with PRIMARY KEY violations. 
+""" + +import asyncio +import threading +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import Mock, patch +import pytest + +from src.server.services.source_management_service import update_source_info + + +class TestSourceRaceCondition: + """Test that concurrent source creation handles race conditions properly.""" + + def test_concurrent_source_creation_no_race(self): + """Test that concurrent attempts to create the same source don't fail.""" + # Track successful operations + successful_creates = [] + failed_creates = [] + + def mock_execute(): + """Mock execute that simulates database operation.""" + return Mock(data=[]) + + def track_upsert(data): + """Track upsert calls.""" + successful_creates.append(data["source_id"]) + return Mock(execute=mock_execute) + + # Mock Supabase client + mock_client = Mock() + + # Mock the SELECT (existing source check) - always returns empty + mock_client.table.return_value.select.return_value.eq.return_value.execute.return_value.data = [] + + # Mock the UPSERT operation + mock_client.table.return_value.upsert = track_upsert + + def create_source(thread_id): + """Simulate creating a source from a thread.""" + try: + update_source_info( + client=mock_client, + source_id="test_source_123", + summary=f"Summary from thread {thread_id}", + word_count=100, + content=f"Content from thread {thread_id}", + knowledge_type="documentation", + tags=["test"], + update_frequency=0, + source_url="https://example.com", + source_display_name=f"Example Site {thread_id}" # Will be used as title + ) + except Exception as e: + failed_creates.append((thread_id, str(e))) + + # Run 5 threads concurrently trying to create the same source + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [] + for i in range(5): + futures.append(executor.submit(create_source, i)) + + # Wait for all to complete + for future in futures: + future.result() + + # All should succeed (no failures due to PRIMARY KEY violation) + assert len(failed_creates) == 0, f"Some creates failed: {failed_creates}" + assert len(successful_creates) == 5, "All 5 attempts should succeed" + assert all(sid == "test_source_123" for sid in successful_creates) + + def test_upsert_vs_insert_behavior(self): + """Test that upsert is used instead of insert for new sources.""" + mock_client = Mock() + + # Track which method is called + methods_called = [] + + def track_insert(data): + methods_called.append("insert") + # Simulate PRIMARY KEY violation + raise Exception("duplicate key value violates unique constraint") + + def track_upsert(data): + methods_called.append("upsert") + return Mock(execute=Mock(return_value=Mock(data=[]))) + + # Source doesn't exist + mock_client.table.return_value.select.return_value.eq.return_value.execute.return_value.data = [] + + # Set up mocks + mock_client.table.return_value.insert = track_insert + mock_client.table.return_value.upsert = track_upsert + + update_source_info( + client=mock_client, + source_id="new_source", + summary="Test summary", + word_count=100, + content="Test content", + knowledge_type="documentation", + source_display_name="Test Display Name" # Will be used as title + ) + + # Should use upsert, not insert + assert "upsert" in methods_called, "Should use upsert for new sources" + assert "insert" not in methods_called, "Should not use insert to avoid race conditions" + + def test_existing_source_uses_update(self): + """Test that existing sources still use UPDATE (not affected by change).""" + mock_client = Mock() + + methods_called = [] + + 
def track_update(data): + methods_called.append("update") + return Mock(eq=Mock(return_value=Mock(execute=Mock(return_value=Mock(data=[]))))) + + def track_upsert(data): + methods_called.append("upsert") + return Mock(execute=Mock(return_value=Mock(data=[]))) + + # Source exists + existing_source = { + "source_id": "existing_source", + "title": "Existing Title", + "metadata": {"knowledge_type": "api"} + } + mock_client.table.return_value.select.return_value.eq.return_value.execute.return_value.data = [existing_source] + + # Set up mocks + mock_client.table.return_value.update = track_update + mock_client.table.return_value.upsert = track_upsert + + update_source_info( + client=mock_client, + source_id="existing_source", + summary="Updated summary", + word_count=200, + content="Updated content", + knowledge_type="documentation" + ) + + # Should use update for existing sources + assert "update" in methods_called, "Should use update for existing sources" + assert "upsert" not in methods_called, "Should not use upsert for existing sources" + + @pytest.mark.asyncio + async def test_async_concurrent_creation(self): + """Test concurrent source creation in async context.""" + mock_client = Mock() + + # Track operations + operations = [] + + def track_upsert(data): + operations.append(("upsert", data["source_id"])) + return Mock(execute=Mock(return_value=Mock(data=[]))) + + # No existing sources + mock_client.table.return_value.select.return_value.eq.return_value.execute.return_value.data = [] + mock_client.table.return_value.upsert = track_upsert + + async def create_source_async(task_id): + """Async wrapper for source creation.""" + await asyncio.to_thread( + update_source_info, + client=mock_client, + source_id=f"async_source_{task_id % 2}", # Only 2 unique sources + summary=f"Summary {task_id}", + word_count=100, + content=f"Content {task_id}", + knowledge_type="documentation" + ) + + # Create 10 tasks, but only 2 unique source_ids + tasks = [create_source_async(i) for i in range(10)] + await asyncio.gather(*tasks) + + # All operations should succeed + assert len(operations) == 10, "All 10 operations should complete" + + # Check that we tried to upsert the two sources multiple times + source_0_count = sum(1 for op, sid in operations if sid == "async_source_0") + source_1_count = sum(1 for op, sid in operations if sid == "async_source_1") + + assert source_0_count == 5, "async_source_0 should be upserted 5 times" + assert source_1_count == 5, "async_source_1 should be upserted 5 times" + + def test_race_condition_with_delay(self): + """Test race condition with simulated delay between check and create.""" + import time + + mock_client = Mock() + + # Track timing of operations + check_times = [] + create_times = [] + source_created = threading.Event() + + def delayed_select(*args): + """Return a mock that simulates SELECT with delay.""" + mock_select = Mock() + + def eq_mock(*args): + mock_eq = Mock() + mock_eq.execute = lambda: delayed_check() + return mock_eq + + mock_select.eq = eq_mock + return mock_select + + def delayed_check(): + """Simulate SELECT execution with delay.""" + check_times.append(time.time()) + result = Mock() + # First thread doesn't see the source + if not source_created.is_set(): + time.sleep(0.01) # Small delay to let both threads check + result.data = [] + else: + # Subsequent checks would see it (but we use upsert so this doesn't matter) + result.data = [{"source_id": "race_source", "title": "Existing", "metadata": {}}] + return result + + def track_upsert(data): + 
"""Track upsert and set event.""" + create_times.append(time.time()) + source_created.set() + return Mock(execute=Mock(return_value=Mock(data=[]))) + + # Set up table mock to return our custom select mock + mock_client.table.return_value.select = delayed_select + mock_client.table.return_value.upsert = track_upsert + + errors = [] + + def create_with_error_tracking(thread_id): + try: + update_source_info( + client=mock_client, + source_id="race_source", + summary="Race summary", + word_count=100, + content="Race content", + knowledge_type="documentation", + source_display_name="Race Display Name" # Will be used as title + ) + except Exception as e: + errors.append((thread_id, str(e))) + + # Run 2 threads that will both check before either creates + with ThreadPoolExecutor(max_workers=2) as executor: + futures = [ + executor.submit(create_with_error_tracking, 1), + executor.submit(create_with_error_tracking, 2) + ] + for future in futures: + future.result() + + # Both should succeed with upsert (no errors) + assert len(errors) == 0, f"No errors should occur with upsert: {errors}" + assert len(check_times) == 2, "Both threads should check" + assert len(create_times) == 2, "Both threads should attempt create/upsert" \ No newline at end of file diff --git a/python/tests/test_source_url_shadowing.py b/python/tests/test_source_url_shadowing.py new file mode 100644 index 00000000..014c357d --- /dev/null +++ b/python/tests/test_source_url_shadowing.py @@ -0,0 +1,124 @@ +""" +Test that source_url parameter is not shadowed by document URLs. + +This test ensures that the original crawl URL (e.g., sitemap URL) +is correctly passed to _create_source_records and not overwritten +by individual document URLs during processing. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, MagicMock, patch +from src.server.services.crawling.document_storage_operations import DocumentStorageOperations + + +class TestSourceUrlShadowing: + """Test that source_url parameter is preserved correctly.""" + + @pytest.mark.asyncio + async def test_source_url_not_shadowed(self): + """Test that the original source_url is passed to _create_source_records.""" + # Create mock supabase client + mock_supabase = Mock() + + # Create DocumentStorageOperations instance + doc_storage = DocumentStorageOperations(mock_supabase) + + # Mock the storage service + doc_storage.doc_storage_service.smart_chunk_text = Mock(return_value=["chunk1", "chunk2"]) + + # Track what gets passed to _create_source_records + captured_source_url = None + async def mock_create_source_records(all_metadatas, all_contents, source_word_counts, + request, source_url, source_display_name): + nonlocal captured_source_url + captured_source_url = source_url + + doc_storage._create_source_records = mock_create_source_records + + # Mock add_documents_to_supabase + with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase') as mock_add: + mock_add.return_value = None + + # Test data - simulating a sitemap crawl + original_source_url = "https://mem0.ai/sitemap.xml" + crawl_results = [ + { + "url": "https://mem0.ai/page1", + "markdown": "Content of page 1", + "title": "Page 1" + }, + { + "url": "https://mem0.ai/page2", + "markdown": "Content of page 2", + "title": "Page 2" + }, + { + "url": "https://mem0.ai/models/openai-o3", # Last document URL + "markdown": "Content of models page", + "title": "Models" + } + ] + + request = {"knowledge_type": "documentation", "tags": []} + + # Call the method + result = await 
doc_storage.process_and_store_documents( + crawl_results=crawl_results, + request=request, + crawl_type="sitemap", + original_source_id="test123", + progress_callback=None, + cancellation_check=None, + source_url=original_source_url, # This should NOT be overwritten + source_display_name="Test Sitemap" + ) + + # Verify the original source_url was preserved + assert captured_source_url == original_source_url, \ + f"source_url should be '{original_source_url}', not '{captured_source_url}'" + + # Verify it's NOT the last document's URL + assert captured_source_url != "https://mem0.ai/models/openai-o3", \ + "source_url should NOT be overwritten with the last document's URL" + + # Verify url_to_full_document has correct URLs + assert "https://mem0.ai/page1" in result["url_to_full_document"] + assert "https://mem0.ai/page2" in result["url_to_full_document"] + assert "https://mem0.ai/models/openai-o3" in result["url_to_full_document"] + + @pytest.mark.asyncio + async def test_metadata_uses_document_urls(self): + """Test that metadata correctly uses individual document URLs.""" + mock_supabase = Mock() + doc_storage = DocumentStorageOperations(mock_supabase) + + # Mock the storage service + doc_storage.doc_storage_service.smart_chunk_text = Mock(return_value=["chunk1"]) + + # Capture metadata + captured_metadatas = None + async def mock_create_source_records(all_metadatas, all_contents, source_word_counts, + request, source_url, source_display_name): + nonlocal captured_metadatas + captured_metadatas = all_metadatas + + doc_storage._create_source_records = mock_create_source_records + + with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'): + crawl_results = [ + {"url": "https://example.com/doc1", "markdown": "Doc 1"}, + {"url": "https://example.com/doc2", "markdown": "Doc 2"} + ] + + await doc_storage.process_and_store_documents( + crawl_results=crawl_results, + request={}, + crawl_type="normal", + original_source_id="test456", + source_url="https://example.com", + source_display_name="Example" + ) + + # Each metadata should have the correct document URL + assert captured_metadatas[0]["url"] == "https://example.com/doc1" + assert captured_metadatas[1]["url"] == "https://example.com/doc2" \ No newline at end of file diff --git a/python/tests/test_url_canonicalization.py b/python/tests/test_url_canonicalization.py new file mode 100644 index 00000000..5ab6311f --- /dev/null +++ b/python/tests/test_url_canonicalization.py @@ -0,0 +1,222 @@ +""" +Test URL canonicalization in source ID generation. + +This test ensures that URLs are properly normalized before hashing +to prevent duplicate sources from URL variations. 
+""" + +import pytest +from src.server.services.crawling.helpers.url_handler import URLHandler + + +class TestURLCanonicalization: + """Test that URL canonicalization works correctly for source ID generation.""" + + def test_trailing_slash_normalization(self): + """Test that trailing slashes are handled consistently.""" + handler = URLHandler() + + # These should generate the same ID + url1 = "https://example.com/path" + url2 = "https://example.com/path/" + + id1 = handler.generate_unique_source_id(url1) + id2 = handler.generate_unique_source_id(url2) + + assert id1 == id2, "URLs with/without trailing slash should generate same ID" + + # Root path should keep its slash + root1 = "https://example.com" + root2 = "https://example.com/" + + root_id1 = handler.generate_unique_source_id(root1) + root_id2 = handler.generate_unique_source_id(root2) + + # These should be the same (both normalize to https://example.com/) + assert root_id1 == root_id2, "Root URLs should normalize consistently" + + def test_fragment_removal(self): + """Test that URL fragments are removed.""" + handler = URLHandler() + + urls = [ + "https://example.com/page", + "https://example.com/page#section1", + "https://example.com/page#section2", + "https://example.com/page#", + ] + + ids = [handler.generate_unique_source_id(url) for url in urls] + + # All should generate the same ID + assert len(set(ids)) == 1, "URLs with different fragments should generate same ID" + + def test_tracking_param_removal(self): + """Test that tracking parameters are removed.""" + handler = URLHandler() + + # URL without tracking params + clean_url = "https://example.com/page?important=value" + + # URLs with various tracking params + tracked_urls = [ + "https://example.com/page?important=value&utm_source=google", + "https://example.com/page?utm_medium=email&important=value", + "https://example.com/page?important=value&fbclid=abc123", + "https://example.com/page?gclid=xyz&important=value&utm_campaign=test", + "https://example.com/page?important=value&ref=homepage", + "https://example.com/page?source=newsletter&important=value", + ] + + clean_id = handler.generate_unique_source_id(clean_url) + tracked_ids = [handler.generate_unique_source_id(url) for url in tracked_urls] + + # All tracked URLs should generate the same ID as the clean URL + for tracked_id in tracked_ids: + assert tracked_id == clean_id, "URLs with tracking params should match clean URL" + + def test_query_param_sorting(self): + """Test that query parameters are sorted for consistency.""" + handler = URLHandler() + + urls = [ + "https://example.com/page?a=1&b=2&c=3", + "https://example.com/page?c=3&a=1&b=2", + "https://example.com/page?b=2&c=3&a=1", + ] + + ids = [handler.generate_unique_source_id(url) for url in urls] + + # All should generate the same ID + assert len(set(ids)) == 1, "URLs with reordered query params should generate same ID" + + def test_default_port_removal(self): + """Test that default ports are removed.""" + handler = URLHandler() + + # HTTP default port (80) + http_urls = [ + "http://example.com/page", + "http://example.com:80/page", + ] + + http_ids = [handler.generate_unique_source_id(url) for url in http_urls] + assert len(set(http_ids)) == 1, "HTTP URLs with/without :80 should generate same ID" + + # HTTPS default port (443) + https_urls = [ + "https://example.com/page", + "https://example.com:443/page", + ] + + https_ids = [handler.generate_unique_source_id(url) for url in https_urls] + assert len(set(https_ids)) == 1, "HTTPS URLs with/without :443 should 
generate same ID" + + # Non-default ports should be preserved + url1 = "https://example.com:8080/page" + url2 = "https://example.com:9090/page" + + id1 = handler.generate_unique_source_id(url1) + id2 = handler.generate_unique_source_id(url2) + + assert id1 != id2, "URLs with different non-default ports should generate different IDs" + + def test_case_normalization(self): + """Test that scheme and domain are lowercased.""" + handler = URLHandler() + + urls = [ + "https://example.com/Path/To/Page", + "HTTPS://EXAMPLE.COM/Path/To/Page", + "https://Example.Com/Path/To/Page", + "HTTPs://example.COM/Path/To/Page", + ] + + ids = [handler.generate_unique_source_id(url) for url in urls] + + # All should generate the same ID (path case is preserved) + assert len(set(ids)) == 1, "URLs with different case in scheme/domain should generate same ID" + + # But different paths should generate different IDs + path_urls = [ + "https://example.com/path", + "https://example.com/Path", + "https://example.com/PATH", + ] + + path_ids = [handler.generate_unique_source_id(url) for url in path_urls] + + # These should be different (path case matters) + assert len(set(path_ids)) == 3, "URLs with different path case should generate different IDs" + + def test_complex_canonicalization(self): + """Test complex URL with multiple normalizations needed.""" + handler = URLHandler() + + urls = [ + "https://example.com/page", + "HTTPS://EXAMPLE.COM:443/page/", + "https://Example.com/page#section", + "https://example.com/page/?utm_source=test", + "https://example.com:443/page?utm_campaign=abc#footer", + ] + + ids = [handler.generate_unique_source_id(url) for url in urls] + + # All should generate the same ID + assert len(set(ids)) == 1, "Complex URLs should normalize to same ID" + + def test_edge_cases(self): + """Test edge cases and error handling.""" + handler = URLHandler() + + # Empty URL + empty_id = handler.generate_unique_source_id("") + assert len(empty_id) == 16, "Empty URL should still generate valid ID" + + # Invalid URL + invalid_id = handler.generate_unique_source_id("not-a-url") + assert len(invalid_id) == 16, "Invalid URL should still generate valid ID" + + # URL with special characters + special_url = "https://example.com/page?key=value%20with%20spaces" + special_id = handler.generate_unique_source_id(special_url) + assert len(special_id) == 16, "URL with encoded chars should generate valid ID" + + # Very long URL + long_url = "https://example.com/" + "a" * 1000 + long_id = handler.generate_unique_source_id(long_url) + assert len(long_id) == 16, "Long URL should generate valid ID" + + def test_preserves_important_params(self): + """Test that non-tracking params are preserved.""" + handler = URLHandler() + + # These have different important params, should be different + url1 = "https://api.example.com/v1/users?page=1" + url2 = "https://api.example.com/v1/users?page=2" + + id1 = handler.generate_unique_source_id(url1) + id2 = handler.generate_unique_source_id(url2) + + assert id1 != id2, "URLs with different important params should generate different IDs" + + # But tracking params should still be removed + url3 = "https://api.example.com/v1/users?page=1&utm_source=docs" + id3 = handler.generate_unique_source_id(url3) + + assert id3 == id1, "Adding tracking params shouldn't change ID" + + def test_local_file_paths(self): + """Test handling of local file paths.""" + handler = URLHandler() + + # File URLs + file_url = "file:///Users/test/document.pdf" + file_id = handler.generate_unique_source_id(file_url) + assert 
len(file_id) == 16, "File URL should generate valid ID"
+
+        # Relative paths
+        relative_path = "../documents/file.txt"
+        relative_id = handler.generate_unique_source_id(relative_path)
+        assert len(relative_id) == 16, "Relative path should generate valid ID"
\ No newline at end of file