Fix race condition in concurrent crawling with unique source IDs (#472)

* Fix race condition in concurrent crawling with unique source IDs

- Add unique hash-based source_id generation to prevent conflicts
- Separate source identification from display with three fields:
  - source_id: 16-char SHA256 hash for unique identification
  - source_url: Original URL for tracking
  - source_display_name: Human-friendly name for UI
- Add comprehensive test suite validating the fix
- Migrate existing data with backward compatibility
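
A condensed sketch of the split (the full URLHandler implementation, with canonicalization and the display-name heuristics, appears in the diff further down; the two helpers are shown here only in simplified form):

```python
import hashlib
from urllib.parse import urlparse

def generate_unique_source_id(url: str) -> str:
    # 16-char SHA-256 prefix of the (ideally canonicalized) URL, so two crawls
    # hitting different pages on the same domain no longer collide on source_id.
    return hashlib.sha256(url.strip().encode("utf-8")).hexdigest()[:16]

def extract_display_name(url: str) -> str:
    # Simplified: the real helper special-cases GitHub repos, docs sites,
    # llms.txt and sitemap.xml before falling back to the domain.
    return urlparse(url).netloc or url
```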

* Fix title generation to use source_display_name for better AI context

- Pass source_display_name to title generation function
- Use display name in AI prompt instead of hash-based source_id
- Results in more specific, meaningful titles for each source

* Skip AI title generation when display name is available

- Use source_display_name directly as title to avoid unnecessary AI calls
- More efficient and predictable than AI-generated titles
- Keep AI generation only as fallback for backward compatibility
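
A minimal sketch of the new-source title logic (choose_title is an illustrative name; generate_source_title_and_metadata is the existing AI path imported from the source management service, as in the diff):

```python
def choose_title(source_display_name: str | None, source_id: str, content: str) -> str:
    if source_display_name:
        # Use the display name directly; truncate so it stays a reasonable title.
        return source_display_name[:100].strip()
    # Backward-compatible fallback: only call the AI title generator when no
    # display name was supplied.
    title, _metadata = generate_source_title_and_metadata(source_id, content)
    return title
```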

* Fix critical issues from code review

- Add missing os import to prevent NameError crash
- Remove unused imports (pytest, Mock, patch, hashlib, urlparse, etc.)
- Fix GitHub API capitalization consistency
- Reuse existing DocumentStorageService instance
- Update test expectations to match corrected capitalization

Addresses CodeRabbit review feedback on PR #472

* Add safety improvements from code review

- Truncate display names to 100 chars when used as titles
- Document hash collision probability (negligible for <1M sources)
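
For reference, a 16-hex-character source_id is a 64-bit value, so by the birthday bound the probability of any collision among n sources is roughly n^2 / 2^65; at n = 1,000,000 that is about 10^12 / 3.7e19 ≈ 3e-8, which is why the docs call it negligible.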

Simple, pragmatic fixes per KISS principle

* Fix code extraction to use hash-based source_ids and improve display names

- Fix critical bug where code extraction was still using the old domain-based source_ids
- Update the code extraction service to accept source_id as a parameter instead of extracting it from the URL
- Add special handling for llms.txt and sitemap.xml files in display names
- Add comprehensive tests for source_id handling in code extraction
- Remove unused urlparse import from code_extraction_service.py

This fixes the foreign key constraint errors that were preventing code examples
from being stored after the source_id architecture refactor.
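
Abridged signature of the updated extraction entry point (per the diff below); the key change is that source_id arrives as a parameter instead of being re-derived from each page URL:

```python
async def extract_and_store_code_examples(
    self,
    crawl_results: list[dict],
    url_to_full_document: dict[str, str],
    source_id: str,  # the unique hash-based id; previously urlparse(url).netloc
    progress_callback=None,
    start_progress: int = 0,
    end_progress: int = 100,
) -> int:
    ...
```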

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix critical variable shadowing and source_type determination issues

- Fix variable shadowing in document_storage_operations.py where the source_url parameter
  was overwritten by per-document URLs, causing an incorrect source_url in the database
- Fix source_type determination to use actual URLs instead of the hash-based source_id
- Add comprehensive tests for source URL preservation
- Ensure source_type is correctly set to "file" for file uploads and "url" for web crawls

The variable shadowing bug caused sitemap sources to have the wrong source_url
(the last crawled page instead of the sitemap URL). The source_type bug marked all
sources as "url" even for file uploads, because hash-based IDs no longer start with "file_".

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix URL canonicalization and document metrics calculation

- Implement proper URL canonicalization to prevent duplicate sources
  - Remove trailing slashes (except root)
  - Remove URL fragments
  - Remove tracking parameters (utm_*, gclid, fbclid, etc.)
  - Sort query parameters for consistency
  - Remove default ports (80 for HTTP, 443 for HTTPS)
  - Normalize scheme and domain to lowercase

- Fix avg_chunks_per_doc calculation to avoid division by zero
  - Track processed_docs count separately from total crawl_results
  - Handle all-empty document sets gracefully
  - Show processed/total in logs for better visibility

- Add comprehensive tests for both fixes
  - 10 test cases for URL canonicalization edge cases
  - 4 test cases for document metrics calculation

This prevents database constraint violations when crawling the same
content with URL variations and provides accurate metrics in logs.
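
A condensed sketch of the canonicalization rules above (the real helper also strips a few extra params such as ref and source):

```python
from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode

TRACKING_PARAMS = {"utm_source", "utm_medium", "utm_campaign", "utm_term",
                   "utm_content", "gclid", "fbclid"}

def canonicalize_url(url: str) -> str:
    parsed = urlparse(url.strip())
    scheme = (parsed.scheme or "").lower()
    netloc = (parsed.netloc or "").lower()
    # Drop default ports (80 for HTTP, 443 for HTTPS).
    if scheme == "http" and netloc.endswith(":80"):
        netloc = netloc[:-3]
    if scheme == "https" and netloc.endswith(":443"):
        netloc = netloc[:-4]
    # Strip the trailing slash except for the root path.
    path = parsed.path or "/"
    if len(path) > 1:
        path = path.rstrip("/")
    # Drop tracking parameters and sort the rest for a stable order.
    query = urlencode(sorted(
        (k, v) for k, v in parse_qsl(parsed.query, keep_blank_values=True)
        if k not in TRACKING_PARAMS
    ))
    # The fragment is dropped entirely.
    return urlunparse((scheme, netloc, path, "", query, ""))
```

With this, https://Example.com:443/docs/?utm_source=x#top and https://example.com/docs both canonicalize to the same URL and therefore hash to the same source_id.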

* Fix synchronous extract_source_summary blocking async event loop

- Run extract_source_summary in thread pool using asyncio.to_thread
- Prevents blocking the async event loop during AI summary generation
- Preserves exact error handling and fallback behavior
- Variables (source_id, combined_content) properly passed to thread

Added comprehensive tests verifying:
- Function runs in thread without blocking
- Error handling works correctly with fallback
- Multiple sources can be processed
- Thread safety with variable passing
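
In essence the change is one asyncio.to_thread wrapper around the blocking call; a sketch assuming extract_source_summary is imported from the source management service as in the diff (the wrapper function name is illustrative):

```python
import asyncio

async def summarize_source(source_id: str, combined_content: str, page_count: int) -> str:
    try:
        # extract_source_summary is synchronous (it makes an LLM call), so run it
        # in a worker thread instead of blocking the event loop.
        return await asyncio.to_thread(extract_source_summary, source_id, combined_content)
    except Exception:
        # Same fallback wording as before the change.
        return f"Documentation from {source_id} - {page_count} pages crawled"
```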

* Fix synchronous update_source_info blocking async event loop

- Run update_source_info in thread pool using asyncio.to_thread
- Prevents blocking the async event loop during database operations
- Preserves exact error handling and fallback behavior
- All kwargs properly passed to thread execution

Added comprehensive tests verifying:
- Function runs in thread without blocking
- Error handling triggers fallback correctly
- All kwargs are preserved when passed to thread
- Existing extract_source_summary tests still pass
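
The same wrapper works for the kwargs-heavy call, since asyncio.to_thread forwards keyword arguments unchanged (sketch; store_source_info is an illustrative name, update_source_info is the existing synchronous helper):

```python
import asyncio

async def store_source_info(client, **source_fields) -> None:
    # Keyword arguments (summary, word_count, tags, source_url, ...) pass through
    # to the synchronous update_source_info exactly as before.
    await asyncio.to_thread(update_source_info, client=client, **source_fields)
```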

* Fix race condition in source creation using upsert

- Replace INSERT with UPSERT for new sources to prevent PRIMARY KEY violations
- Handles concurrent crawls attempting to create the same source
- Maintains existing UPDATE behavior for sources that already exist

Added comprehensive tests verifying:
- Concurrent source creation doesn't fail
- Upsert is used for new sources (not insert)
- Update is still used for existing sources
- Async concurrent operations work correctly
- Race conditions with delays are handled

This prevents database constraint errors when multiple crawls target
the same URL simultaneously.
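
A minimal sketch of the insert-to-upsert switch, using the same Supabase client call shown in the diff (the wrapper function is illustrative):

```python
def create_source_record(client, source_id: str, title: str, summary: str,
                         word_count: int, metadata: dict) -> None:
    row = {
        "source_id": source_id,
        "title": title,
        "summary": summary,
        "total_word_count": word_count,
        "metadata": metadata,
    }
    # upsert() instead of insert(): if two concurrent crawls race to create the
    # same source_id, the later write updates the row instead of raising a
    # PRIMARY KEY violation.
    client.table("archon_sources").upsert(row).execute()
```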

* Add migration detection UI components

Add MigrationBanner component with clear user instructions for database schema updates. Add useMigrationStatus hook for periodic health check monitoring with graceful error handling.

* Integrate migration banner into main app

Add migration status monitoring and banner display to App.tsx. Shows migration banner when database schema updates are required.

* Enhance backend startup error instructions

Add detailed Docker restart instructions and migration script guidance. Improves user experience when encountering startup failures.

* Add database schema caching to health endpoint

Implement smart caching for schema validation to prevent repeated database queries. Cache successful validations permanently and throttle failures to 30-second intervals. Replace debug prints with proper logging.
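
The caching policy, roughly: a successful validation is cached for good, a failed one is re-checked at most every 30 seconds. A simplified sketch (run_check stands in for the real query against archon_sources; the real check also leaves genuinely inconclusive errors uncached so they can be retried):

```python
import time

_schema_check_cache = {"valid": None, "checked_at": 0.0, "result": None}

async def check_schema_cached(run_check):
    """run_check: async callable returning {"valid": bool, "message": str}."""
    now = time.time()
    if _schema_check_cache["valid"] is True:
        # A successful validation never needs to be repeated.
        return {"valid": True, "message": "Schema is up to date (cached)"}
    if _schema_check_cache["valid"] is False and now - _schema_check_cache["checked_at"] < 30:
        # Don't hammer the database while the schema is known to be outdated.
        return _schema_check_cache["result"]
    result = await run_check()
    _schema_check_cache.update(valid=result["valid"], checked_at=now, result=result)
    return result
```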

* Clean up knowledge API imports and logging

Remove duplicate import statements and redundant logging. Improves code clarity and reduces log noise.

* Remove unused instructions prop from MigrationBanner

Clean up the component API by removing the instructions prop, which was accepted but never rendered. Simplifies the interface and eliminates dead code while keeping the functional hardcoded migration steps.

* Add schema_valid flag to migration_required health response

Add schema_valid: false flag to health endpoint response when database schema migration is required. Improves API consistency without changing existing behavior.

---------

Co-authored-by: Claude <noreply@anthropic.com>
Author: Wirasm
Date: 2025-08-29 14:54:16 +03:00
Committed by: GitHub
Parent: 02e72d9107
Commit: 3e204b0be1
20 files changed, 2603 insertions(+), 182 deletions(-)

View File

@@ -11,7 +11,9 @@ import { SettingsProvider, useSettings } from './contexts/SettingsContext';
import { ProjectPage } from './pages/ProjectPage';
import { DisconnectScreenOverlay } from './components/DisconnectScreenOverlay';
import { ErrorBoundaryWithBugReport } from './components/bug-report/ErrorBoundaryWithBugReport';
import { MigrationBanner } from './components/ui/MigrationBanner';
import { serverHealthService } from './services/serverHealthService';
import { useMigrationStatus } from './hooks/useMigrationStatus';
const AppRoutes = () => {
const { projectsEnabled } = useSettings();
@@ -38,6 +40,8 @@ const AppContent = () => {
enabled: true,
delay: 10000
});
const [migrationBannerDismissed, setMigrationBannerDismissed] = useState(false);
const migrationStatus = useMigrationStatus();
useEffect(() => {
// Load initial settings
@@ -77,6 +81,13 @@ const AppContent = () => {
<Router>
<ErrorBoundaryWithBugReport>
<MainLayout>
{/* Migration Banner - shows when backend is up but DB schema needs work */}
{migrationStatus.migrationRequired && !migrationBannerDismissed && (
<MigrationBanner
message={migrationStatus.message || "Database migration required"}
onDismiss={() => setMigrationBannerDismissed(true)}
/>
)}
<AppRoutes />
</MainLayout>
</ErrorBoundaryWithBugReport>

View File

@@ -40,8 +40,12 @@ export const BackendStartupError: React.FC = () => {
<div className="bg-yellow-950/30 border border-yellow-700/30 rounded-lg p-3">
<p className="text-yellow-200 text-sm">
<strong>Common issue:</strong> Using an ANON key instead of SERVICE key in your .env file
<strong>Common issues:</strong>
</p>
<ul className="text-yellow-200 text-sm mt-1 space-y-1 list-disc list-inside">
<li>Using an ANON key instead of SERVICE key in your .env file</li>
<li>Database not set up - run <code className="bg-yellow-800/50 px-1 rounded">migration/complete_setup.sql</code> in Supabase SQL Editor</li>
</ul>
</div>
<div className="pt-4 border-t border-red-900/30">

View File

@@ -0,0 +1,60 @@
import React from 'react';
import { AlertTriangle, ExternalLink } from 'lucide-react';
import { Card } from './Card';
interface MigrationBannerProps {
message: string;
onDismiss?: () => void;
}
export const MigrationBanner: React.FC<MigrationBannerProps> = ({
message,
onDismiss
}) => {
return (
<Card className="bg-red-50 border-red-200 dark:bg-red-900/20 dark:border-red-800 mb-6">
<div className="flex items-start gap-3 p-4">
<AlertTriangle className="w-6 h-6 text-red-500 flex-shrink-0 mt-0.5" />
<div className="flex-1">
<h3 className="text-lg font-semibold text-red-800 dark:text-red-300 mb-2">
Database Migration Required
</h3>
<p className="text-red-700 dark:text-red-400 mb-3">
{message}
</p>
<div className="bg-red-100 dark:bg-red-900/40 border border-red-200 dark:border-red-800 rounded-lg p-3 mb-3">
<p className="text-sm font-medium text-red-800 dark:text-red-300 mb-2">
Follow these steps:
</p>
<ol className="text-sm text-red-700 dark:text-red-400 space-y-1 list-decimal list-inside">
<li>Open your Supabase project dashboard</li>
<li>Navigate to the SQL Editor</li>
<li>Copy and run the migration script from: <code className="bg-red-200 dark:bg-red-800 px-1 rounded">migration/add_source_url_display_name.sql</code></li>
<li>Restart Docker containers: <code className="bg-red-200 dark:bg-red-800 px-1 rounded">docker compose down && docker compose up --build -d</code></li>
<li>If you used a profile, add it: <code className="bg-red-200 dark:bg-red-800 px-1 rounded">--profile full</code></li>
</ol>
</div>
<div className="flex items-center gap-3">
<a
href="https://supabase.com/dashboard"
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center gap-2 bg-red-600 hover:bg-red-700 text-white px-4 py-2 rounded-lg text-sm font-medium transition-colors"
>
<ExternalLink className="w-4 h-4" />
Open Supabase Dashboard
</a>
{onDismiss && (
<button
onClick={onDismiss}
className="text-red-600 dark:text-red-400 hover:text-red-800 dark:hover:text-red-200 text-sm font-medium"
>
Dismiss (temporarily)
</button>
)}
</div>
</div>
</div>
</Card>
);
};

View File

@@ -0,0 +1,51 @@
import { useState, useEffect } from 'react';
interface MigrationStatus {
migrationRequired: boolean;
message?: string;
loading: boolean;
}
export const useMigrationStatus = (): MigrationStatus => {
const [status, setStatus] = useState<MigrationStatus>({
migrationRequired: false,
loading: true,
});
useEffect(() => {
const checkMigrationStatus = async () => {
try {
const response = await fetch('/api/health');
const healthData = await response.json();
if (healthData.status === 'migration_required') {
setStatus({
migrationRequired: true,
message: healthData.message,
loading: false,
});
} else {
setStatus({
migrationRequired: false,
loading: false,
});
}
} catch (error) {
console.error('Failed to check migration status:', error);
setStatus({
migrationRequired: false,
loading: false,
});
}
};
checkMigrationStatus();
// Check periodically (every 30 seconds) to detect when migration is complete
const interval = setInterval(checkMigrationStatus, 30000);
return () => clearInterval(interval);
}, []);
return status;
};

View File

@@ -0,0 +1,36 @@
-- =====================================================
-- Add source_url and source_display_name columns
-- =====================================================
-- This migration adds two new columns to better identify sources:
-- - source_url: The original URL that was crawled
-- - source_display_name: Human-readable name for UI display
--
-- This solves the race condition issue where multiple crawls
-- to the same domain would conflict by using domain as source_id
-- =====================================================
-- Add new columns to archon_sources table
ALTER TABLE archon_sources
ADD COLUMN IF NOT EXISTS source_url TEXT,
ADD COLUMN IF NOT EXISTS source_display_name TEXT;
-- Add indexes for the new columns for better query performance
CREATE INDEX IF NOT EXISTS idx_archon_sources_url ON archon_sources(source_url);
CREATE INDEX IF NOT EXISTS idx_archon_sources_display_name ON archon_sources(source_display_name);
-- Add comments to document the new columns
COMMENT ON COLUMN archon_sources.source_url IS 'The original URL that was crawled to create this source';
COMMENT ON COLUMN archon_sources.source_display_name IS 'Human-readable name for UI display (e.g., "GitHub - microsoft/typescript")';
-- Backfill existing data
-- For existing sources, copy source_id to both new fields as a fallback
UPDATE archon_sources
SET
source_url = COALESCE(source_url, source_id),
source_display_name = COALESCE(source_display_name, source_id)
WHERE
source_url IS NULL
OR source_display_name IS NULL;
-- Note: source_id will now contain a unique hash instead of domain
-- This ensures no conflicts when multiple sources from same domain are crawled

View File

@@ -170,6 +170,8 @@ COMMENT ON TABLE archon_settings IS 'Stores application configuration including
-- Create the sources table
CREATE TABLE IF NOT EXISTS archon_sources (
source_id TEXT PRIMARY KEY,
source_url TEXT,
source_display_name TEXT,
summary TEXT,
total_word_count INTEGER DEFAULT 0,
title TEXT,
@@ -180,10 +182,15 @@ CREATE TABLE IF NOT EXISTS archon_sources (
-- Create indexes for better query performance
CREATE INDEX IF NOT EXISTS idx_archon_sources_title ON archon_sources(title);
CREATE INDEX IF NOT EXISTS idx_archon_sources_url ON archon_sources(source_url);
CREATE INDEX IF NOT EXISTS idx_archon_sources_display_name ON archon_sources(source_display_name);
CREATE INDEX IF NOT EXISTS idx_archon_sources_metadata ON archon_sources USING GIN(metadata);
CREATE INDEX IF NOT EXISTS idx_archon_sources_knowledge_type ON archon_sources((metadata->>'knowledge_type'));
-- Add comments to document the new columns
-- Add comments to document the columns
COMMENT ON COLUMN archon_sources.source_id IS 'Unique hash identifier for the source (16-char SHA256 hash of URL)';
COMMENT ON COLUMN archon_sources.source_url IS 'The original URL that was crawled to create this source';
COMMENT ON COLUMN archon_sources.source_display_name IS 'Human-readable name for UI display (e.g., "GitHub - microsoft/typescript")';
COMMENT ON COLUMN archon_sources.title IS 'Descriptive title for the source (e.g., "Pydantic AI API Reference")';
COMMENT ON COLUMN archon_sources.metadata IS 'JSONB field storing knowledge_type, tags, and other metadata';

View File

@@ -27,10 +27,6 @@ from ..services.crawler_manager import get_crawler
# Import unified logging
from ..config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info
from ..services.crawler_manager import get_crawler
from ..services.search.rag_service import RAGService
from ..services.storage import DocumentStorageService
from ..utils import get_supabase_client
from ..utils.document_processing import extract_text_from_document
# Get logger for this module
@@ -513,11 +509,6 @@ async def upload_document(
):
"""Upload and process a document with progress tracking."""
try:
# DETAILED LOGGING: Track knowledge_type parameter flow
safe_logfire_info(
f"📋 UPLOAD: Starting document upload | filename={file.filename} | content_type={file.content_type} | knowledge_type={knowledge_type}"
)
safe_logfire_info(
f"Starting document upload | filename={file.filename} | content_type={file.content_type} | knowledge_type={knowledge_type}"
)
@@ -871,7 +862,22 @@ async def get_database_metrics():
@router.get("/health")
async def knowledge_health():
"""Knowledge API health check."""
"""Knowledge API health check with migration detection."""
# Check for database migration needs
from ..main import _check_database_schema
schema_status = await _check_database_schema()
if not schema_status["valid"]:
return {
"status": "migration_required",
"service": "knowledge-api",
"timestamp": datetime.now().isoformat(),
"ready": False,
"migration_required": True,
"message": schema_status["message"],
"migration_instructions": "Open Supabase Dashboard → SQL Editor → Run: migration/add_source_url_display_name.sql"
}
# Removed health check logging to reduce console noise
result = {
"status": "healthy",

View File

@@ -246,12 +246,27 @@ async def health_check():
"ready": False,
}
# Check for required database schema
schema_status = await _check_database_schema()
if not schema_status["valid"]:
return {
"status": "migration_required",
"service": "archon-backend",
"timestamp": datetime.now().isoformat(),
"ready": False,
"migration_required": True,
"message": schema_status["message"],
"migration_instructions": "Open Supabase Dashboard → SQL Editor → Run: migration/add_source_url_display_name.sql",
"schema_valid": False
}
return {
"status": "healthy",
"service": "archon-backend",
"timestamp": datetime.now().isoformat(),
"ready": True,
"credentials_loaded": True,
"schema_valid": True,
}
@@ -262,6 +277,78 @@ async def api_health_check():
return await health_check()
# Cache schema check result to avoid repeated database queries
_schema_check_cache = {"valid": None, "checked_at": 0}
async def _check_database_schema():
"""Check if required database schema exists - only for existing users who need migration."""
import time
# If we've already confirmed schema is valid, don't check again
if _schema_check_cache["valid"] is True:
return {"valid": True, "message": "Schema is up to date (cached)"}
# If we recently failed, don't spam the database (wait at least 30 seconds)
current_time = time.time()
if (_schema_check_cache["valid"] is False and
current_time - _schema_check_cache["checked_at"] < 30):
return _schema_check_cache["result"]
try:
from .services.client_manager import get_supabase_client
client = get_supabase_client()
# Try to query the new columns directly - if they exist, schema is up to date
test_query = client.table('archon_sources').select('source_url, source_display_name').limit(1).execute()
# Cache successful result permanently
_schema_check_cache["valid"] = True
_schema_check_cache["checked_at"] = current_time
return {"valid": True, "message": "Schema is up to date"}
except Exception as e:
error_msg = str(e).lower()
# Log schema check error for debugging
api_logger.debug(f"Schema check error: {type(e).__name__}: {str(e)}")
# Check for specific error types based on PostgreSQL error codes and messages
# Check for missing columns first (more specific than table check)
missing_source_url = 'source_url' in error_msg and ('column' in error_msg or 'does not exist' in error_msg)
missing_source_display = 'source_display_name' in error_msg and ('column' in error_msg or 'does not exist' in error_msg)
# Also check for PostgreSQL error code 42703 (undefined column)
is_column_error = '42703' in error_msg or 'column' in error_msg
if (missing_source_url or missing_source_display) and is_column_error:
result = {
"valid": False,
"message": "Database schema outdated - missing required columns from recent updates"
}
# Cache failed result with timestamp
_schema_check_cache["valid"] = False
_schema_check_cache["checked_at"] = current_time
_schema_check_cache["result"] = result
return result
# Check for table doesn't exist (less specific, only if column check didn't match)
# Look for relation/table errors specifically
if ('relation' in error_msg and 'does not exist' in error_msg) or ('table' in error_msg and 'does not exist' in error_msg):
# Table doesn't exist - not a migration issue, it's a setup issue
return {"valid": True, "message": "Table doesn't exist - handled by startup error"}
# Other errors don't necessarily mean migration needed
result = {"valid": True, "message": f"Schema check inconclusive: {str(e)}"}
# Don't cache inconclusive results - allow retry
return result
# Export for Socket.IO
# Create Socket.IO app wrapper
# This wraps the FastAPI app with Socket.IO functionality
socket_app = create_socketio_app(app)

View File

@@ -7,7 +7,6 @@ Handles extraction, processing, and storage of code examples from documents.
import re
from collections.abc import Callable
from typing import Any
from urllib.parse import urlparse
from ...config.logfire_config import safe_logfire_error, safe_logfire_info
from ...services.credential_service import credential_service
@@ -136,6 +135,7 @@ class CodeExtractionService:
self,
crawl_results: list[dict[str, Any]],
url_to_full_document: dict[str, str],
source_id: str,
progress_callback: Callable | None = None,
start_progress: int = 0,
end_progress: int = 100,
@@ -146,6 +146,7 @@ class CodeExtractionService:
Args:
crawl_results: List of crawled documents with url and markdown content
url_to_full_document: Mapping of URLs to full document content
source_id: The unique source_id for all documents
progress_callback: Optional async callback for progress updates
start_progress: Starting progress percentage (default: 0)
end_progress: Ending progress percentage (default: 100)
@@ -163,7 +164,7 @@ class CodeExtractionService:
# Extract code blocks from all documents
all_code_blocks = await self._extract_code_blocks_from_documents(
crawl_results, progress_callback, start_progress, extract_end
crawl_results, source_id, progress_callback, start_progress, extract_end
)
if not all_code_blocks:
@@ -201,6 +202,7 @@ class CodeExtractionService:
async def _extract_code_blocks_from_documents(
self,
crawl_results: list[dict[str, Any]],
source_id: str,
progress_callback: Callable | None = None,
start_progress: int = 0,
end_progress: int = 100,
@@ -208,6 +210,10 @@ class CodeExtractionService:
"""
Extract code blocks from all documents.
Args:
crawl_results: List of crawled documents
source_id: The unique source_id for all documents
Returns:
List of code blocks with metadata
"""
@@ -306,10 +312,7 @@ class CodeExtractionService:
)
if code_blocks:
# Always extract source_id from URL
parsed_url = urlparse(source_url)
source_id = parsed_url.netloc or parsed_url.path
# Use the provided source_id for all code blocks
for block in code_blocks:
all_code_blocks.append({
"block": block,

View File

@@ -304,10 +304,12 @@ class CrawlingService:
url = str(request.get("url", ""))
safe_logfire_info(f"Starting async crawl orchestration | url={url} | task_id={task_id}")
# Extract source_id from the original URL
parsed_original_url = urlparse(url)
original_source_id = parsed_original_url.netloc or parsed_original_url.path
safe_logfire_info(f"Using source_id '{original_source_id}' from original URL '{url}'")
# Generate unique source_id and display name from the original URL
original_source_id = self.url_handler.generate_unique_source_id(url)
source_display_name = self.url_handler.extract_display_name(url)
safe_logfire_info(
f"Generated unique source_id '{original_source_id}' and display name '{source_display_name}' from URL '{url}'"
)
# Helper to update progress with mapper
async def update_mapped_progress(
@@ -386,6 +388,8 @@ class CrawlingService:
original_source_id,
doc_storage_callback,
self._check_cancellation,
source_url=url,
source_display_name=source_display_name,
)
# Check for cancellation after document storage
@@ -410,6 +414,7 @@ class CrawlingService:
code_examples_count = await self.doc_storage_ops.extract_and_store_code_examples(
crawl_results,
storage_results["url_to_full_document"],
storage_results["source_id"],
code_progress_callback,
85,
95,
@@ -558,7 +563,7 @@ class CrawlingService:
max_depth = request.get("max_depth", 1)
# Let the strategy handle concurrency from settings
# This will use CRAWL_MAX_CONCURRENT from database (default: 10)
crawl_results = await self.crawl_recursive_with_progress(
[url],
max_depth=max_depth,

View File

@@ -4,17 +4,13 @@ Document Storage Operations
Handles the storage and processing of crawled documents.
Extracted from crawl_orchestration_service.py for better modularity.
"""
import asyncio
from typing import Dict, Any, List, Optional, Callable
from urllib.parse import urlparse
from ...config.logfire_config import safe_logfire_info, safe_logfire_error
from ..storage.storage_services import DocumentStorageService
from ..storage.document_storage_service import add_documents_to_supabase
from ..storage.code_storage_service import (
generate_code_summaries_batch,
add_code_examples_to_supabase
)
from ..source_management_service import update_source_info, extract_source_summary
from .code_extraction_service import CodeExtractionService
@@ -23,18 +19,18 @@ class DocumentStorageOperations:
"""
Handles document storage operations for crawled content.
"""
def __init__(self, supabase_client):
"""
Initialize document storage operations.
Args:
supabase_client: The Supabase client for database operations
"""
self.supabase_client = supabase_client
self.doc_storage_service = DocumentStorageService(supabase_client)
self.code_extraction_service = CodeExtractionService(supabase_client)
async def process_and_store_documents(
self,
crawl_results: List[Dict],
@@ -42,11 +38,13 @@ class DocumentStorageOperations:
crawl_type: str,
original_source_id: str,
progress_callback: Optional[Callable] = None,
cancellation_check: Optional[Callable] = None
cancellation_check: Optional[Callable] = None,
source_url: Optional[str] = None,
source_display_name: Optional[str] = None,
) -> Dict[str, Any]:
"""
Process crawled documents and store them in the database.
Args:
crawl_results: List of crawled documents
request: The original crawl request
@@ -54,13 +52,15 @@ class DocumentStorageOperations:
original_source_id: The source ID for all documents
progress_callback: Optional callback for progress updates
cancellation_check: Optional function to check for cancellation
source_url: Optional original URL that was crawled
source_display_name: Optional human-readable name for the source
Returns:
Dict containing storage statistics and document mappings
"""
# Initialize storage service for chunking
storage_service = DocumentStorageService(self.supabase_client)
# Reuse initialized storage service for chunking
storage_service = self.doc_storage_service
# Prepare data for chunked storage
all_urls = []
all_chunk_numbers = []
@@ -68,77 +68,85 @@ class DocumentStorageOperations:
all_metadatas = []
source_word_counts = {}
url_to_full_document = {}
processed_docs = 0
# Process and chunk each document
for doc_index, doc in enumerate(crawl_results):
# Check for cancellation during document processing
if cancellation_check:
cancellation_check()
source_url = doc.get('url', '')
markdown_content = doc.get('markdown', '')
doc_url = doc.get("url", "")
markdown_content = doc.get("markdown", "")
if not markdown_content:
continue
# Increment processed document count
processed_docs += 1
# Store full document for code extraction context
url_to_full_document[source_url] = markdown_content
url_to_full_document[doc_url] = markdown_content
# CHUNK THE CONTENT
chunks = storage_service.smart_chunk_text(markdown_content, chunk_size=5000)
# Use the original source_id for all documents
source_id = original_source_id
safe_logfire_info(f"Using original source_id '{source_id}' for URL '{source_url}'")
safe_logfire_info(f"Using original source_id '{source_id}' for URL '{doc_url}'")
# Process each chunk
for i, chunk in enumerate(chunks):
# Check for cancellation during chunk processing
if cancellation_check and i % 10 == 0: # Check every 10 chunks
cancellation_check()
all_urls.append(source_url)
all_urls.append(doc_url)
all_chunk_numbers.append(i)
all_contents.append(chunk)
# Create metadata for each chunk
word_count = len(chunk.split())
metadata = {
'url': source_url,
'title': doc.get('title', ''),
'description': doc.get('description', ''),
'source_id': source_id,
'knowledge_type': request.get('knowledge_type', 'documentation'),
'crawl_type': crawl_type,
'word_count': word_count,
'char_count': len(chunk),
'chunk_index': i,
'tags': request.get('tags', [])
"url": doc_url,
"title": doc.get("title", ""),
"description": doc.get("description", ""),
"source_id": source_id,
"knowledge_type": request.get("knowledge_type", "documentation"),
"crawl_type": crawl_type,
"word_count": word_count,
"char_count": len(chunk),
"chunk_index": i,
"tags": request.get("tags", []),
}
all_metadatas.append(metadata)
# Accumulate word count
source_word_counts[source_id] = source_word_counts.get(source_id, 0) + word_count
# Yield control every 10 chunks to prevent event loop blocking
if i > 0 and i % 10 == 0:
await asyncio.sleep(0)
# Yield control after processing each document
if doc_index > 0 and doc_index % 5 == 0:
await asyncio.sleep(0)
# Create/update source record FIRST before storing documents
if all_contents and all_metadatas:
await self._create_source_records(
all_metadatas, all_contents, source_word_counts, request
all_metadatas, all_contents, source_word_counts, request,
source_url, source_display_name
)
safe_logfire_info(f"url_to_full_document keys: {list(url_to_full_document.keys())[:5]}")
# Log chunking results
safe_logfire_info(f"Document storage | documents={len(crawl_results)} | chunks={len(all_contents)} | avg_chunks_per_doc={len(all_contents)/len(crawl_results):.1f}")
avg_chunks = (len(all_contents) / processed_docs) if processed_docs > 0 else 0.0
safe_logfire_info(
f"Document storage | processed={processed_docs}/{len(crawl_results)} | chunks={len(all_contents)} | avg_chunks_per_doc={avg_chunks:.1f}"
)
# Call add_documents_to_supabase with the correct parameters
await add_documents_to_supabase(
client=self.supabase_client,
@@ -151,29 +159,31 @@ class DocumentStorageOperations:
progress_callback=progress_callback, # Pass the callback for progress updates
enable_parallel_batches=True, # Enable parallel processing
provider=None, # Use configured provider
cancellation_check=cancellation_check # Pass cancellation check
cancellation_check=cancellation_check, # Pass cancellation check
)
# Calculate actual chunk count
chunk_count = len(all_contents)
return {
'chunk_count': chunk_count,
'total_word_count': sum(source_word_counts.values()),
'url_to_full_document': url_to_full_document,
'source_id': original_source_id
"chunk_count": chunk_count,
"total_word_count": sum(source_word_counts.values()),
"url_to_full_document": url_to_full_document,
"source_id": original_source_id,
}
async def _create_source_records(
self,
all_metadatas: List[Dict],
all_contents: List[str],
source_word_counts: Dict[str, int],
request: Dict[str, Any]
request: Dict[str, Any],
source_url: Optional[str] = None,
source_display_name: Optional[str] = None,
):
"""
Create or update source records in the database.
Args:
all_metadatas: List of metadata for all chunks
all_contents: List of all chunk contents
@@ -184,121 +194,155 @@ class DocumentStorageOperations:
unique_source_ids = set()
source_id_contents = {}
source_id_word_counts = {}
for i, metadata in enumerate(all_metadatas):
source_id = metadata['source_id']
source_id = metadata["source_id"]
unique_source_ids.add(source_id)
# Group content by source_id for better summaries
if source_id not in source_id_contents:
source_id_contents[source_id] = []
source_id_contents[source_id].append(all_contents[i])
# Track word counts per source_id
if source_id not in source_id_word_counts:
source_id_word_counts[source_id] = 0
source_id_word_counts[source_id] += metadata.get('word_count', 0)
safe_logfire_info(f"Found {len(unique_source_ids)} unique source_ids: {list(unique_source_ids)}")
source_id_word_counts[source_id] += metadata.get("word_count", 0)
safe_logfire_info(
f"Found {len(unique_source_ids)} unique source_ids: {list(unique_source_ids)}"
)
# Create source records for ALL unique source_ids
for source_id in unique_source_ids:
# Get combined content for this specific source_id
source_contents = source_id_contents[source_id]
combined_content = ''
combined_content = ""
for chunk in source_contents[:3]: # First 3 chunks for this source
if len(combined_content) + len(chunk) < 15000:
combined_content += ' ' + chunk
combined_content += " " + chunk
else:
break
# Generate summary with fallback
# Generate summary with fallback (run in thread to avoid blocking async loop)
try:
summary = extract_source_summary(source_id, combined_content)
# Run synchronous extract_source_summary in a thread pool
summary = await asyncio.to_thread(
extract_source_summary, source_id, combined_content
)
except Exception as e:
safe_logfire_error(f"Failed to generate AI summary for '{source_id}': {str(e)}, using fallback")
safe_logfire_error(
f"Failed to generate AI summary for '{source_id}': {str(e)}, using fallback"
)
# Fallback to simple summary
summary = f"Documentation from {source_id} - {len(source_contents)} pages crawled"
# Update source info in database BEFORE storing documents
safe_logfire_info(f"About to create/update source record for '{source_id}' (word count: {source_id_word_counts[source_id]})")
safe_logfire_info(
f"About to create/update source record for '{source_id}' (word count: {source_id_word_counts[source_id]})"
)
try:
update_source_info(
# Run synchronous update_source_info in a thread pool
await asyncio.to_thread(
update_source_info,
client=self.supabase_client,
source_id=source_id,
summary=summary,
word_count=source_id_word_counts[source_id],
content=combined_content,
knowledge_type=request.get('knowledge_type', 'technical'),
tags=request.get('tags', []),
knowledge_type=request.get("knowledge_type", "technical"),
tags=request.get("tags", []),
update_frequency=0, # Set to 0 since we're using manual refresh
original_url=request.get('url') # Store the original crawl URL
original_url=request.get("url"), # Store the original crawl URL
source_url=source_url,
source_display_name=source_display_name,
)
safe_logfire_info(f"Successfully created/updated source record for '{source_id}'")
except Exception as e:
safe_logfire_error(f"Failed to create/update source record for '{source_id}': {str(e)}")
safe_logfire_error(
f"Failed to create/update source record for '{source_id}': {str(e)}"
)
# Try a simpler approach with minimal data
try:
safe_logfire_info(f"Attempting fallback source creation for '{source_id}'")
self.supabase_client.table('archon_sources').upsert({
'source_id': source_id,
'title': source_id, # Use source_id as title fallback
'summary': summary,
'total_word_count': source_id_word_counts[source_id],
'metadata': {
'knowledge_type': request.get('knowledge_type', 'technical'),
'tags': request.get('tags', []),
'auto_generated': True,
'fallback_creation': True,
'original_url': request.get('url')
}
}).execute()
fallback_data = {
"source_id": source_id,
"title": source_id, # Use source_id as title fallback
"summary": summary,
"total_word_count": source_id_word_counts[source_id],
"metadata": {
"knowledge_type": request.get("knowledge_type", "technical"),
"tags": request.get("tags", []),
"auto_generated": True,
"fallback_creation": True,
"original_url": request.get("url"),
},
}
# Add new fields if provided
if source_url:
fallback_data["source_url"] = source_url
if source_display_name:
fallback_data["source_display_name"] = source_display_name
self.supabase_client.table("archon_sources").upsert(fallback_data).execute()
safe_logfire_info(f"Fallback source creation succeeded for '{source_id}'")
except Exception as fallback_error:
safe_logfire_error(f"Both source creation attempts failed for '{source_id}': {str(fallback_error)}")
raise Exception(f"Unable to create source record for '{source_id}'. This will cause foreign key violations. Error: {str(fallback_error)}")
safe_logfire_error(
f"Both source creation attempts failed for '{source_id}': {str(fallback_error)}"
)
raise Exception(
f"Unable to create source record for '{source_id}'. This will cause foreign key violations. Error: {str(fallback_error)}"
)
# Verify ALL source records exist before proceeding with document storage
if unique_source_ids:
for source_id in unique_source_ids:
try:
source_check = self.supabase_client.table('archon_sources').select('source_id').eq('source_id', source_id).execute()
source_check = (
self.supabase_client.table("archon_sources")
.select("source_id")
.eq("source_id", source_id)
.execute()
)
if not source_check.data:
raise Exception(f"Source record verification failed - '{source_id}' does not exist in sources table")
raise Exception(
f"Source record verification failed - '{source_id}' does not exist in sources table"
)
safe_logfire_info(f"Source record verified for '{source_id}'")
except Exception as e:
safe_logfire_error(f"Source verification failed for '{source_id}': {str(e)}")
raise
safe_logfire_info(f"All {len(unique_source_ids)} source records verified - proceeding with document storage")
safe_logfire_info(
f"All {len(unique_source_ids)} source records verified - proceeding with document storage"
)
async def extract_and_store_code_examples(
self,
crawl_results: List[Dict],
url_to_full_document: Dict[str, str],
source_id: str,
progress_callback: Optional[Callable] = None,
start_progress: int = 85,
end_progress: int = 95
end_progress: int = 95,
) -> int:
"""
Extract code examples from crawled documents and store them.
Args:
crawl_results: List of crawled documents
url_to_full_document: Mapping of URLs to full document content
source_id: The unique source_id for all documents
progress_callback: Optional callback for progress updates
start_progress: Starting progress percentage
end_progress: Ending progress percentage
Returns:
Number of code examples stored
"""
result = await self.code_extraction_service.extract_and_store_code_examples(
crawl_results,
url_to_full_document,
progress_callback,
start_progress,
end_progress
crawl_results, url_to_full_document, source_id, progress_callback, start_progress, end_progress
)
return result
return result

View File

@@ -3,6 +3,8 @@ URL Handler Helper
Handles URL transformations and validations.
"""
import hashlib
import re
from urllib.parse import urlparse
@@ -13,49 +15,49 @@ logger = get_logger(__name__)
class URLHandler:
"""Helper class for URL operations."""
@staticmethod
def is_sitemap(url: str) -> bool:
"""
Check if a URL is a sitemap with error handling.
Args:
url: URL to check
Returns:
True if URL is a sitemap, False otherwise
"""
try:
return url.endswith('sitemap.xml') or 'sitemap' in urlparse(url).path
return url.endswith("sitemap.xml") or "sitemap" in urlparse(url).path
except Exception as e:
logger.warning(f"Error checking if URL is sitemap: {e}")
return False
@staticmethod
def is_txt(url: str) -> bool:
"""
Check if a URL is a text file with error handling.
Args:
url: URL to check
Returns:
True if URL is a text file, False otherwise
"""
try:
return url.endswith('.txt')
return url.endswith(".txt")
except Exception as e:
logger.warning(f"Error checking if URL is text file: {e}")
return False
@staticmethod
def is_binary_file(url: str) -> bool:
"""
Check if a URL points to a binary file that shouldn't be crawled.
Args:
url: URL to check
Returns:
True if URL is a binary file, False otherwise
"""
@@ -63,65 +65,338 @@ class URLHandler:
# Remove query parameters and fragments for cleaner extension checking
parsed = urlparse(url)
path = parsed.path.lower()
# Comprehensive list of binary and non-HTML file extensions
binary_extensions = {
# Archives
'.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz', '.tgz',
".zip",
".tar",
".gz",
".rar",
".7z",
".bz2",
".xz",
".tgz",
# Executables and installers
'.exe', '.dmg', '.pkg', '.deb', '.rpm', '.msi', '.app', '.appimage',
".exe",
".dmg",
".pkg",
".deb",
".rpm",
".msi",
".app",
".appimage",
# Documents (non-HTML)
'.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.odt', '.ods',
".pdf",
".doc",
".docx",
".xls",
".xlsx",
".ppt",
".pptx",
".odt",
".ods",
# Images
'.jpg', '.jpeg', '.png', '.gif', '.svg', '.webp', '.ico', '.bmp', '.tiff',
".jpg",
".jpeg",
".png",
".gif",
".svg",
".webp",
".ico",
".bmp",
".tiff",
# Audio/Video
'.mp3', '.mp4', '.avi', '.mov', '.wmv', '.flv', '.webm', '.mkv', '.wav', '.flac',
".mp3",
".mp4",
".avi",
".mov",
".wmv",
".flv",
".webm",
".mkv",
".wav",
".flac",
# Data files
'.csv', '.sql', '.db', '.sqlite',
".csv",
".sql",
".db",
".sqlite",
# Binary data
'.iso', '.img', '.bin', '.dat',
".iso",
".img",
".bin",
".dat",
# Development files (usually not meant to be crawled as pages)
'.wasm', '.pyc', '.jar', '.war', '.class', '.dll', '.so', '.dylib'
".wasm",
".pyc",
".jar",
".war",
".class",
".dll",
".so",
".dylib",
}
# Check if the path ends with any binary extension
for ext in binary_extensions:
if path.endswith(ext):
logger.debug(f"Skipping binary file: {url} (matched extension: {ext})")
return True
return False
except Exception as e:
logger.warning(f"Error checking if URL is binary file: {e}")
# In case of error, don't skip the URL (safer to attempt crawl than miss content)
return False
@staticmethod
def transform_github_url(url: str) -> str:
"""
Transform GitHub URLs to raw content URLs for better content extraction.
Args:
url: URL to transform
Returns:
Transformed URL (or original if not a GitHub file URL)
"""
# Pattern for GitHub file URLs
github_file_pattern = r'https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.+)'
github_file_pattern = r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.+)"
match = re.match(github_file_pattern, url)
if match:
owner, repo, branch, path = match.groups()
raw_url = f'https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}'
raw_url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
logger.info(f"Transformed GitHub file URL to raw: {url} -> {raw_url}")
return raw_url
# Pattern for GitHub directory URLs
github_dir_pattern = r'https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.+)'
github_dir_pattern = r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.+)"
match = re.match(github_dir_pattern, url)
if match:
# For directories, we can't directly get raw content
# Return original URL but log a warning
logger.warning(f"GitHub directory URL detected: {url} - consider using specific file URLs or GitHub API")
logger.warning(
f"GitHub directory URL detected: {url} - consider using specific file URLs or GitHub API"
)
return url
@staticmethod
def generate_unique_source_id(url: str) -> str:
"""
Generate a unique source ID from URL using hash.
This creates a 16-character hash that is extremely unlikely to collide
for distinct canonical URLs, solving race condition issues when multiple crawls
target the same domain.
return url
Uses 16-char SHA256 prefix (64 bits) which provides
~18 quintillion unique values. Collision probability
is negligible for realistic usage (<1M sources).
Args:
url: The URL to generate an ID for
Returns:
A 16-character hexadecimal hash string
"""
try:
from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
# Canonicalize URL for consistent hashing
parsed = urlparse(url.strip())
# Normalize scheme and netloc to lowercase
scheme = (parsed.scheme or "").lower()
netloc = (parsed.netloc or "").lower()
# Remove default ports
if netloc.endswith(":80") and scheme == "http":
netloc = netloc[:-3]
if netloc.endswith(":443") and scheme == "https":
netloc = netloc[:-4]
# Normalize path (remove trailing slash except for root)
path = parsed.path or "/"
if path.endswith("/") and len(path) > 1:
path = path.rstrip("/")
# Remove common tracking parameters and sort remaining
tracking_params = {
"utm_source", "utm_medium", "utm_campaign", "utm_term", "utm_content",
"gclid", "fbclid", "ref", "source"
}
query_items = [
(k, v) for k, v in parse_qsl(parsed.query, keep_blank_values=True)
if k not in tracking_params
]
query = urlencode(sorted(query_items))
# Reconstruct canonical URL (fragment is dropped)
canonical = urlunparse((scheme, netloc, path, "", query, ""))
# Generate SHA256 hash and take first 16 characters
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]
except Exception as e:
# Redact sensitive query params from error logs
try:
redacted = url.split("?", 1)[0] if "?" in url else url
except Exception:
redacted = "<unparseable-url>"
logger.error(f"Error generating unique source ID for {redacted}: {e}", exc_info=True)
# Fallback: use a hash of the error message + url to still get something unique
fallback = f"error_{redacted}_{str(e)}"
return hashlib.sha256(fallback.encode("utf-8")).hexdigest()[:16]
@staticmethod
def extract_display_name(url: str) -> str:
"""
Extract a human-readable display name from URL.
This creates user-friendly names for common source patterns
while falling back to the domain for unknown patterns.
Args:
url: The URL to extract a display name from
Returns:
A human-readable string suitable for UI display
"""
try:
parsed = urlparse(url)
domain = parsed.netloc.lower()
# Remove www prefix for cleaner display
if domain.startswith("www."):
domain = domain[4:]
# Handle empty domain (might be a file path or malformed URL)
if not domain:
if url.startswith("/"):
return f"Local: {url.split('/')[-1] if '/' in url else url}"
return url[:50] + "..." if len(url) > 50 else url
path = parsed.path.strip("/")
# Special handling for GitHub repositories and API
if "github.com" in domain:
# Check if it's an API endpoint
if domain.startswith("api."):
return "GitHub API"
parts = path.split("/")
if len(parts) >= 2:
owner = parts[0]
repo = parts[1].replace(".git", "") # Remove .git extension if present
return f"GitHub - {owner}/{repo}"
elif len(parts) == 1 and parts[0]:
return f"GitHub - {parts[0]}"
return "GitHub"
# Special handling for documentation sites
if domain.startswith("docs."):
# Extract the service name from docs.X.com/org
service_name = domain.replace("docs.", "").split(".")[0]
base_name = f"{service_name.title()}" if service_name else "Documentation"
# Special handling for special files - preserve the filename
if path:
# Check for llms.txt files
if "llms" in path.lower() and path.endswith(".txt"):
return f"{base_name} - Llms.Txt"
# Check for sitemap files
elif "sitemap" in path.lower() and path.endswith(".xml"):
return f"{base_name} - Sitemap.Xml"
# Check for any other special .txt files
elif path.endswith(".txt"):
filename = path.split("/")[-1] if "/" in path else path
return f"{base_name} - {filename.title()}"
return f"{base_name} Documentation" if service_name else "Documentation"
# Handle readthedocs.io subdomains
if domain.endswith(".readthedocs.io"):
project = domain.replace(".readthedocs.io", "")
return f"{project.title()} Docs"
# Handle common documentation patterns
doc_patterns = [
("fastapi.tiangolo.com", "FastAPI Documentation"),
("pydantic.dev", "Pydantic Documentation"),
("python.org", "Python Documentation"),
("djangoproject.com", "Django Documentation"),
("flask.palletsprojects.com", "Flask Documentation"),
("numpy.org", "NumPy Documentation"),
("pandas.pydata.org", "Pandas Documentation"),
]
for pattern, name in doc_patterns:
if pattern in domain:
# Add path context if available
if path and len(path) > 1:
# Get first meaningful path segment
path_segment = path.split("/")[0] if "/" in path else path
if path_segment and path_segment not in [
"docs",
"doc", # Added "doc" to filter list
"documentation",
"api",
"en",
]:
return f"{name} - {path_segment.title()}"
return name
# For API endpoints
if "api." in domain or "/api" in path:
service = domain.replace("api.", "").split(".")[0]
return f"{service.title()} API"
# Special handling for sitemap.xml and llms.txt on any site
if path:
if "sitemap" in path.lower() and path.endswith(".xml"):
# Get base domain name
display = domain
for tld in [".com", ".org", ".io", ".dev", ".net", ".ai", ".app"]:
if display.endswith(tld):
display = display[:-len(tld)]
break
display_parts = display.replace("-", " ").replace("_", " ").split(".")
formatted = " ".join(part.title() for part in display_parts)
return f"{formatted} - Sitemap.Xml"
elif "llms" in path.lower() and path.endswith(".txt"):
# Get base domain name
display = domain
for tld in [".com", ".org", ".io", ".dev", ".net", ".ai", ".app"]:
if display.endswith(tld):
display = display[:-len(tld)]
break
display_parts = display.replace("-", " ").replace("_", " ").split(".")
formatted = " ".join(part.title() for part in display_parts)
return f"{formatted} - Llms.Txt"
# Default: Use domain with nice formatting
# Remove common TLDs for cleaner display
display = domain
for tld in [".com", ".org", ".io", ".dev", ".net", ".ai", ".app"]:
if display.endswith(tld):
display = display[: -len(tld)]
break
# Capitalize first letter of each word
display_parts = display.replace("-", " ").replace("_", " ").split(".")
formatted = " ".join(part.title() for part in display_parts)
# Add path context if it's meaningful
if path and len(path) > 1 and "/" not in path:
formatted += f" - {path.title()}"
return formatted
except Exception as e:
logger.warning(f"Error extracting display name for {url}: {e}, using URL")
# Fallback: return truncated URL
return url[:50] + "..." if len(url) > 50 else url

View File

@@ -5,6 +5,7 @@ Handles source metadata, summaries, and management.
Consolidates both utility functions and class-based service.
"""
import os
from typing import Any
from supabase import Client
@@ -145,6 +146,7 @@ def generate_source_title_and_metadata(
knowledge_type: str = "technical",
tags: list[str] | None = None,
provider: str = None,
source_display_name: str | None = None,
) -> tuple[str, dict[str, Any]]:
"""
Generate a user-friendly title and metadata for a source based on its content.
@@ -203,8 +205,11 @@ def generate_source_title_and_metadata(
# Limit content for prompt
sample_content = content[:3000] if len(content) > 3000 else content
# Use display name if available for better context
source_context = source_display_name if source_display_name else source_id
prompt = f"""Based on this content from {source_id}, generate a concise, descriptive title (3-6 words) that captures what this source is about:
prompt = f"""Based on this content from {source_context}, generate a concise, descriptive title (3-6 words) that captures what this source is about:
{sample_content}
@@ -230,12 +235,12 @@ Provide only the title, nothing else."""
except Exception as e:
search_logger.error(f"Error generating title for {source_id}: {e}")
# Build metadata - determine source_type from source_id pattern
source_type = "file" if source_id.startswith("file_") else "url"
# Build metadata - source_type will be determined by caller based on actual URL
# Default to "url" but this should be overridden by the caller
metadata = {
"knowledge_type": knowledge_type,
"tags": tags or [],
"source_type": source_type,
"source_type": "url", # Default, should be overridden by caller based on actual URL
"auto_generated": True
}
@@ -252,6 +257,8 @@ def update_source_info(
tags: list[str] | None = None,
update_frequency: int = 7,
original_url: str | None = None,
source_url: str | None = None,
source_display_name: str | None = None,
):
"""
Update or insert source information in the sources table.
@@ -279,7 +286,14 @@ def update_source_info(
search_logger.info(f"Preserving existing title for {source_id}: {existing_title}")
# Update metadata while preserving title
source_type = "file" if source_id.startswith("file_") else "url"
# Determine source_type based on source_url or original_url
if source_url and source_url.startswith("file://"):
source_type = "file"
elif original_url and original_url.startswith("file://"):
source_type = "file"
else:
source_type = "url"
metadata = {
"knowledge_type": knowledge_type,
"tags": tags or [],
@@ -292,14 +306,22 @@ def update_source_info(
metadata["original_url"] = original_url
# Update existing source (preserving title)
update_data = {
"summary": summary,
"total_word_count": word_count,
"metadata": metadata,
"updated_at": "now()",
}
# Add new fields if provided
if source_url:
update_data["source_url"] = source_url
if source_display_name:
update_data["source_display_name"] = source_display_name
result = (
client.table("archon_sources")
.update({
"summary": summary,
"total_word_count": word_count,
"metadata": metadata,
"updated_at": "now()",
})
.update(update_data)
.eq("source_id", source_id)
.execute()
)
@@ -308,10 +330,38 @@ def update_source_info(
f"Updated source {source_id} while preserving title: {existing_title}"
)
else:
# New source - generate title and metadata
title, metadata = generate_source_title_and_metadata(
source_id, content, knowledge_type, tags
)
# New source - use display name as title if available, otherwise generate
if source_display_name:
# Use the display name directly as the title (truncated to prevent DB issues)
title = source_display_name[:100].strip()
# Determine source_type based on source_url or original_url
if source_url and source_url.startswith("file://"):
source_type = "file"
elif original_url and original_url.startswith("file://"):
source_type = "file"
else:
source_type = "url"
metadata = {
"knowledge_type": knowledge_type,
"tags": tags or [],
"source_type": source_type,
"auto_generated": False,
}
else:
# Fallback to AI generation only if no display name
title, metadata = generate_source_title_and_metadata(
source_id, content, knowledge_type, tags, None, source_display_name
)
# Override the source_type from AI with actual URL-based determination
if source_url and source_url.startswith("file://"):
metadata["source_type"] = "file"
elif original_url and original_url.startswith("file://"):
metadata["source_type"] = "file"
else:
metadata["source_type"] = "url"
# Add update_frequency and original_url to metadata
metadata["update_frequency"] = update_frequency
@@ -319,15 +369,23 @@ def update_source_info(
metadata["original_url"] = original_url
search_logger.info(f"Creating new source {source_id} with knowledge_type={knowledge_type}")
# Insert new source
client.table("archon_sources").insert({
# Use upsert to avoid race conditions with concurrent crawls
upsert_data = {
"source_id": source_id,
"title": title,
"summary": summary,
"total_word_count": word_count,
"metadata": metadata,
}).execute()
search_logger.info(f"Created new source {source_id} with title: {title}")
}
# Add new fields if provided
if source_url:
upsert_data["source_url"] = source_url
if source_display_name:
upsert_data["source_display_name"] = source_display_name
client.table("archon_sources").upsert(upsert_data).execute()
search_logger.info(f"Created/updated source {source_id} with title: {title}")
except Exception as e:
search_logger.error(f"Error updating source {source_id}: {e}")

View File

@@ -0,0 +1,413 @@
"""
Test async execution of extract_source_summary and update_source_info.
This test ensures that synchronous functions extract_source_summary and
update_source_info are properly executed in thread pools to avoid blocking
the async event loop.
"""
import asyncio
import time
from unittest.mock import Mock, AsyncMock, patch
import pytest
from src.server.services.crawling.document_storage_operations import DocumentStorageOperations
class TestAsyncSourceSummary:
"""Test that extract_source_summary and update_source_info don't block the async event loop."""
@pytest.mark.asyncio
async def test_extract_summary_runs_in_thread(self):
"""Test that extract_source_summary is executed in a thread pool."""
# Create mock supabase client
mock_supabase = Mock()
mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()
doc_storage = DocumentStorageOperations(mock_supabase)
# Track when extract_source_summary is called
summary_call_times = []
original_summary_result = "Test summary from AI"
def slow_extract_summary(source_id, content):
"""Simulate a slow synchronous function that would block the event loop."""
summary_call_times.append(time.time())
# Simulate a blocking operation (like an API call)
time.sleep(0.1) # This would block the event loop if not run in thread
return original_summary_result
# Mock the storage service
doc_storage.doc_storage_service.smart_chunk_text = Mock(
return_value=["chunk1", "chunk2"]
)
with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
side_effect=slow_extract_summary):
with patch('src.server.services.crawling.document_storage_operations.update_source_info'):
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_error'):
# Create test metadata
all_metadatas = [
{"source_id": "test123", "word_count": 100},
{"source_id": "test123", "word_count": 150},
]
all_contents = ["chunk1", "chunk2"]
source_word_counts = {"test123": 250}
request = {"knowledge_type": "documentation"}
# Track async execution
start_time = time.time()
# This should not block despite the sleep in extract_summary
await doc_storage._create_source_records(
all_metadatas,
all_contents,
source_word_counts,
request,
"https://example.com",
"Example Site"
)
end_time = time.time()
# Verify that extract_source_summary was called
assert len(summary_call_times) == 1, "extract_source_summary should be called once"
# The async function should complete without blocking
# Even though extract_summary sleeps for 0.1s, the async function
# should not be blocked since it runs in a thread
total_time = end_time - start_time
# We can't guarantee exact timing, but it should complete
# without throwing a timeout error
assert total_time < 1.0, "Should complete in reasonable time"
@pytest.mark.asyncio
async def test_extract_summary_error_handling(self):
"""Test that errors in extract_source_summary are handled correctly."""
mock_supabase = Mock()
mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()
doc_storage = DocumentStorageOperations(mock_supabase)
# Mock to raise an exception
def failing_extract_summary(source_id, content):
raise RuntimeError("AI service unavailable")
doc_storage.doc_storage_service.smart_chunk_text = Mock(
return_value=["chunk1"]
)
error_messages = []
with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
side_effect=failing_extract_summary):
with patch('src.server.services.crawling.document_storage_operations.update_source_info') as mock_update:
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_error') as mock_error:
mock_error.side_effect = lambda msg: error_messages.append(msg)
all_metadatas = [{"source_id": "test456", "word_count": 100}]
all_contents = ["chunk1"]
source_word_counts = {"test456": 100}
request = {}
await doc_storage._create_source_records(
all_metadatas,
all_contents,
source_word_counts,
request,
None,
None
)
# Verify error was logged
assert len(error_messages) == 1
assert "Failed to generate AI summary" in error_messages[0]
assert "AI service unavailable" in error_messages[0]
# Verify fallback summary was used
mock_update.assert_called_once()
call_args = mock_update.call_args
assert call_args.kwargs["summary"] == "Documentation from test456 - 1 pages crawled"
@pytest.mark.asyncio
async def test_multiple_sources_concurrent_summaries(self):
"""Test that multiple source summaries are generated concurrently."""
mock_supabase = Mock()
mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()
doc_storage = DocumentStorageOperations(mock_supabase)
# Track concurrent executions
execution_order = []
def track_extract_summary(source_id, content):
execution_order.append(f"start_{source_id}")
time.sleep(0.05) # Simulate work
execution_order.append(f"end_{source_id}")
return f"Summary for {source_id}"
doc_storage.doc_storage_service.smart_chunk_text = Mock(
return_value=["chunk"]
)
with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
side_effect=track_extract_summary):
with patch('src.server.services.crawling.document_storage_operations.update_source_info'):
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
# Create metadata for multiple sources
all_metadatas = [
{"source_id": "source1", "word_count": 100},
{"source_id": "source2", "word_count": 150},
{"source_id": "source3", "word_count": 200},
]
all_contents = ["chunk1", "chunk2", "chunk3"]
source_word_counts = {
"source1": 100,
"source2": 150,
"source3": 200,
}
request = {}
await doc_storage._create_source_records(
all_metadatas,
all_contents,
source_word_counts,
request,
None,
None
)
# With threading, sources are processed sequentially in the loop
# but the extract_summary calls happen in threads
assert len(execution_order) == 6 # 3 sources * 2 events each
# Verify all sources were processed
processed_sources = set()
for event in execution_order:
if event.startswith("start_"):
processed_sources.add(event.replace("start_", ""))
assert processed_sources == {"source1", "source2", "source3"}
@pytest.mark.asyncio
async def test_thread_safety_with_variables(self):
"""Test that variables are properly passed to thread execution."""
mock_supabase = Mock()
mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()
doc_storage = DocumentStorageOperations(mock_supabase)
# Track what gets passed to extract_summary
captured_calls = []
def capture_extract_summary(source_id, content):
captured_calls.append({
"source_id": source_id,
"content_len": len(content),
"content_preview": content[:50] if content else ""
})
return f"Summary for {source_id}"
doc_storage.doc_storage_service.smart_chunk_text = Mock(
return_value=["This is chunk one with some content",
"This is chunk two with more content"]
)
with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
side_effect=capture_extract_summary):
with patch('src.server.services.crawling.document_storage_operations.update_source_info'):
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
all_metadatas = [
{"source_id": "test789", "word_count": 100},
{"source_id": "test789", "word_count": 150},
]
all_contents = [
"This is chunk one with some content",
"This is chunk two with more content"
]
source_word_counts = {"test789": 250}
request = {}
await doc_storage._create_source_records(
all_metadatas,
all_contents,
source_word_counts,
request,
None,
None
)
# Verify the correct values were passed to the thread
assert len(captured_calls) == 1
call = captured_calls[0]
assert call["source_id"] == "test789"
assert call["content_len"] > 0
# Combined content passed to the thread should include the first chunk
assert "This is chunk one" in call["content_preview"]
@pytest.mark.asyncio
async def test_update_source_info_runs_in_thread(self):
"""Test that update_source_info is executed in a thread pool."""
mock_supabase = Mock()
mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()
doc_storage = DocumentStorageOperations(mock_supabase)
# Track when update_source_info is called
update_call_times = []
def slow_update_source_info(**kwargs):
"""Simulate a slow synchronous database operation."""
update_call_times.append(time.time())
# Simulate a blocking database operation
time.sleep(0.1) # This would block the event loop if not run in thread
return None # update_source_info doesn't return anything
doc_storage.doc_storage_service.smart_chunk_text = Mock(
return_value=["chunk1"]
)
with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
return_value="Test summary"):
with patch('src.server.services.crawling.document_storage_operations.update_source_info',
side_effect=slow_update_source_info):
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_error'):
all_metadatas = [{"source_id": "test_update", "word_count": 100}]
all_contents = ["chunk1"]
source_word_counts = {"test_update": 100}
request = {"knowledge_type": "documentation", "tags": ["test"]}
start_time = time.time()
# This should not block despite the sleep in update_source_info
await doc_storage._create_source_records(
all_metadatas,
all_contents,
source_word_counts,
request,
"https://example.com",
"Example Site"
)
end_time = time.time()
# Verify that update_source_info was called
assert len(update_call_times) == 1, "update_source_info should be called once"
# The async function should complete without blocking
total_time = end_time - start_time
assert total_time < 1.0, "Should complete in reasonable time"
@pytest.mark.asyncio
async def test_update_source_info_error_handling(self):
"""Test that errors in update_source_info trigger fallback correctly."""
mock_supabase = Mock()
mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()
doc_storage = DocumentStorageOperations(mock_supabase)
# Mock to raise an exception
def failing_update_source_info(**kwargs):
raise RuntimeError("Database connection failed")
doc_storage.doc_storage_service.smart_chunk_text = Mock(
return_value=["chunk1"]
)
error_messages = []
fallback_called = False
def track_fallback_upsert(data):
nonlocal fallback_called
fallback_called = True
return Mock(execute=Mock())
mock_supabase.table.return_value.upsert.side_effect = track_fallback_upsert
with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
return_value="Test summary"):
with patch('src.server.services.crawling.document_storage_operations.update_source_info',
side_effect=failing_update_source_info):
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_error') as mock_error:
mock_error.side_effect = lambda msg: error_messages.append(msg)
all_metadatas = [{"source_id": "test_fail", "word_count": 100}]
all_contents = ["chunk1"]
source_word_counts = {"test_fail": 100}
request = {"knowledge_type": "technical", "tags": ["test"]}
await doc_storage._create_source_records(
all_metadatas,
all_contents,
source_word_counts,
request,
"https://example.com",
"Example Site"
)
# Verify error was logged
assert any("Failed to create/update source record" in msg for msg in error_messages)
assert any("Database connection failed" in msg for msg in error_messages)
# Verify fallback was attempted
assert fallback_called, "Fallback upsert should be called"
@pytest.mark.asyncio
async def test_update_source_info_preserves_kwargs(self):
"""Test that all kwargs are properly passed to update_source_info in thread."""
mock_supabase = Mock()
mock_supabase.table.return_value.upsert.return_value.execute.return_value = Mock()
doc_storage = DocumentStorageOperations(mock_supabase)
# Track what gets passed to update_source_info
captured_kwargs = {}
def capture_update_source_info(**kwargs):
captured_kwargs.update(kwargs)
return None
doc_storage.doc_storage_service.smart_chunk_text = Mock(
return_value=["chunk content"]
)
with patch('src.server.services.crawling.document_storage_operations.extract_source_summary',
return_value="Generated summary"):
with patch('src.server.services.crawling.document_storage_operations.update_source_info',
side_effect=capture_update_source_info):
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
all_metadatas = [{"source_id": "test_kwargs", "word_count": 250}]
all_contents = ["chunk content"]
source_word_counts = {"test_kwargs": 250}
request = {
"knowledge_type": "api_reference",
"tags": ["api", "docs"],
"url": "https://original.url/crawl"
}
await doc_storage._create_source_records(
all_metadatas,
all_contents,
source_word_counts,
request,
"https://source.url",
"Source Display Name"
)
# Verify all kwargs were passed correctly
assert captured_kwargs["client"] == mock_supabase
assert captured_kwargs["source_id"] == "test_kwargs"
assert captured_kwargs["summary"] == "Generated summary"
assert captured_kwargs["word_count"] == 250
assert "chunk content" in captured_kwargs["content"]
assert captured_kwargs["knowledge_type"] == "api_reference"
assert captured_kwargs["tags"] == ["api", "docs"]
assert captured_kwargs["update_frequency"] == 0
assert captured_kwargs["original_url"] == "https://original.url/crawl"
assert captured_kwargs["source_url"] == "https://source.url"
assert captured_kwargs["source_display_name"] == "Source Display Name"


@@ -0,0 +1,184 @@
"""
Test that code extraction uses the correct source_id.
This test ensures that the fix for using hash-based source_ids
instead of domain-based source_ids works correctly.
"""
import pytest
from unittest.mock import Mock, AsyncMock
from src.server.services.crawling.code_extraction_service import CodeExtractionService
from src.server.services.crawling.document_storage_operations import DocumentStorageOperations
class TestCodeExtractionSourceId:
"""Test that code extraction properly uses the provided source_id."""
@pytest.mark.asyncio
async def test_code_extraction_uses_provided_source_id(self):
"""Test that code extraction uses the hash-based source_id, not domain."""
# Create mock supabase client
mock_supabase = Mock()
mock_supabase.table.return_value.select.return_value.eq.return_value.execute.return_value.data = []
# Create service instance
code_service = CodeExtractionService(mock_supabase)
# Track what gets passed to the internal extraction method
extracted_blocks = []
async def mock_extract_blocks(crawl_results, source_id, progress_callback=None, start=0, end=100):
# Simulate finding code blocks and verify source_id is passed correctly
for doc in crawl_results:
extracted_blocks.append({
"block": {"code": "print('hello')", "language": "python"},
"source_url": doc["url"],
"source_id": source_id # This should be the provided source_id
})
return extracted_blocks
code_service._extract_code_blocks_from_documents = mock_extract_blocks
code_service._generate_code_summaries = AsyncMock(return_value=[{"summary": "Test code"}])
code_service._prepare_code_examples_for_storage = Mock(return_value=[
{"source_id": extracted_blocks[0]["source_id"] if extracted_blocks else None}
])
code_service._store_code_examples = AsyncMock(return_value=1)
# Test data
crawl_results = [
{
"url": "https://docs.mem0.ai/example",
"markdown": "```python\nprint('hello')\n```"
}
]
url_to_full_document = {
"https://docs.mem0.ai/example": "Full content with code"
}
# The correct hash-based source_id
correct_source_id = "393224e227ba92eb"
# Call the method with the correct source_id
result = await code_service.extract_and_store_code_examples(
crawl_results,
url_to_full_document,
correct_source_id,
None,
0,
100
)
# Verify that extracted blocks use the correct source_id
assert len(extracted_blocks) > 0, "Should have extracted at least one code block"
for block in extracted_blocks:
# Check that it's using the hash-based source_id, not the domain
assert block["source_id"] == correct_source_id, \
f"Should use hash-based source_id '{correct_source_id}', not domain"
assert block["source_id"] != "docs.mem0.ai", \
"Should NOT use domain-based source_id"
@pytest.mark.asyncio
async def test_document_storage_passes_source_id(self):
"""Test that DocumentStorageOperations passes source_id to code extraction."""
# Create mock supabase client
mock_supabase = Mock()
# Create DocumentStorageOperations instance
doc_storage = DocumentStorageOperations(mock_supabase)
# Mock the code extraction service
mock_extract = AsyncMock(return_value=5)
doc_storage.code_extraction_service.extract_and_store_code_examples = mock_extract
# Test data
crawl_results = [{"url": "https://example.com", "markdown": "test"}]
url_to_full_document = {"https://example.com": "test content"}
source_id = "abc123def456"
# Call the wrapper method
result = await doc_storage.extract_and_store_code_examples(
crawl_results,
url_to_full_document,
source_id,
None,
0,
100
)
# Verify the correct source_id was passed
mock_extract.assert_called_once_with(
crawl_results,
url_to_full_document,
source_id, # This should be the third argument
None,
0,
100
)
assert result == 5
@pytest.mark.asyncio
async def test_no_domain_extraction_from_url(self):
"""Test that we're NOT extracting domain from URL anymore."""
mock_supabase = Mock()
mock_supabase.table.return_value.select.return_value.eq.return_value.execute.return_value.data = []
code_service = CodeExtractionService(mock_supabase)
# Patch internal methods
code_service._get_setting = AsyncMock(return_value=True)
# Create a mock that will track what source_id is used
source_ids_seen = []
original_extract = code_service._extract_code_blocks_from_documents
async def track_source_id(crawl_results, source_id, progress_callback=None, start=0, end=100):
source_ids_seen.append(source_id)
return [] # Return empty list to skip further processing
code_service._extract_code_blocks_from_documents = track_source_id
# Test with various URLs that would produce different domains
test_cases = [
("https://github.com/example/repo", "github123abc"),
("https://docs.python.org/guide", "python456def"),
("https://api.openai.com/v1", "openai789ghi"),
]
for url, expected_source_id in test_cases:
source_ids_seen.clear()
crawl_results = [{"url": url, "markdown": "# Test"}]
url_to_full_document = {url: "Full content"}
await code_service.extract_and_store_code_examples(
crawl_results,
url_to_full_document,
expected_source_id,
None,
0,
100
)
# Verify the provided source_id was used
assert len(source_ids_seen) == 1
assert source_ids_seen[0] == expected_source_id
# Verify it's NOT the domain
assert "github.com" not in source_ids_seen[0]
assert "python.org" not in source_ids_seen[0]
assert "openai.com" not in source_ids_seen[0]
def test_urlparse_not_imported(self):
"""Test that urlparse is not imported in code_extraction_service."""
import src.server.services.crawling.code_extraction_service as module
# Check that urlparse is not in the module's namespace
assert not hasattr(module, 'urlparse'), \
"urlparse should not be imported in code_extraction_service"
# Check the module's actual imports
import inspect
source = inspect.getsource(module)
assert "from urllib.parse import urlparse" not in source, \
"Should not import urlparse since we don't extract domain from URL anymore"


@@ -0,0 +1,205 @@
"""
Test document storage metrics calculation.
This test ensures that avg_chunks_per_doc is calculated correctly
and handles edge cases like empty documents.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from src.server.services.crawling.document_storage_operations import DocumentStorageOperations
class TestDocumentStorageMetrics:
"""Test metrics calculation in document storage operations."""
@pytest.mark.asyncio
async def test_avg_chunks_calculation_with_empty_docs(self):
"""Test that avg_chunks_per_doc handles empty documents correctly."""
# Create mock supabase client
mock_supabase = Mock()
doc_storage = DocumentStorageOperations(mock_supabase)
# Mock the storage service
doc_storage.doc_storage_service.smart_chunk_text = Mock(
side_effect=lambda text, chunk_size: ["chunk1", "chunk2"] if text else []
)
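# With this mock, every non-empty document yields exactly two chunks and empty
# documents yield none, so the expected counts below (6 chunks from 3 processed
# docs, average 2.0) follow directly.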
# Mock internal methods
doc_storage._create_source_records = AsyncMock()
# Track what gets logged
logged_messages = []
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info') as mock_log:
mock_log.side_effect = lambda msg: logged_messages.append(msg)
with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'):
# Test data with mix of empty and non-empty documents
crawl_results = [
{"url": "https://example.com/page1", "markdown": "Content 1"},
{"url": "https://example.com/page2", "markdown": ""}, # Empty
{"url": "https://example.com/page3", "markdown": "Content 3"},
{"url": "https://example.com/page4", "markdown": ""}, # Empty
{"url": "https://example.com/page5", "markdown": "Content 5"},
]
result = await doc_storage.process_and_store_documents(
crawl_results=crawl_results,
request={},
crawl_type="test",
original_source_id="test123",
source_url="https://example.com",
source_display_name="Example"
)
# Find the metrics log message
metrics_log = None
for msg in logged_messages:
if "Document storage | processed=" in msg:
metrics_log = msg
break
assert metrics_log is not None, "Should log metrics"
# Verify metrics are correct
# 3 documents processed (non-empty), 5 total, 6 chunks (2 per doc), avg = 2.0
assert "processed=3/5" in metrics_log, "Should show 3 processed out of 5 total"
assert "chunks=6" in metrics_log, "Should have 6 chunks total"
assert "avg_chunks_per_doc=2.0" in metrics_log, "Average should be 2.0 (6/3)"
@pytest.mark.asyncio
async def test_avg_chunks_all_empty_docs(self):
"""Test that avg_chunks_per_doc handles all empty documents without division by zero."""
mock_supabase = Mock()
doc_storage = DocumentStorageOperations(mock_supabase)
# Mock the storage service
doc_storage.doc_storage_service.smart_chunk_text = Mock(return_value=[])
doc_storage._create_source_records = AsyncMock()
logged_messages = []
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info') as mock_log:
mock_log.side_effect = lambda msg: logged_messages.append(msg)
with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'):
# All documents are empty
crawl_results = [
{"url": "https://example.com/page1", "markdown": ""},
{"url": "https://example.com/page2", "markdown": ""},
{"url": "https://example.com/page3", "markdown": ""},
]
result = await doc_storage.process_and_store_documents(
crawl_results=crawl_results,
request={},
crawl_type="test",
original_source_id="test456",
source_url="https://example.com",
source_display_name="Example"
)
# Find the metrics log
metrics_log = None
for msg in logged_messages:
if "Document storage | processed=" in msg:
metrics_log = msg
break
assert metrics_log is not None, "Should log metrics even with no processed docs"
# Should show 0 processed, 0 chunks, 0.0 average (no division by zero)
assert "processed=0/3" in metrics_log, "Should show 0 processed out of 3 total"
assert "chunks=0" in metrics_log, "Should have 0 chunks"
assert "avg_chunks_per_doc=0.0" in metrics_log, "Average should be 0.0 (no division by zero)"
@pytest.mark.asyncio
async def test_avg_chunks_single_doc(self):
"""Test avg_chunks_per_doc with a single document."""
mock_supabase = Mock()
doc_storage = DocumentStorageOperations(mock_supabase)
# Mock to return 5 chunks for content
doc_storage.doc_storage_service.smart_chunk_text = Mock(
return_value=["chunk1", "chunk2", "chunk3", "chunk4", "chunk5"]
)
doc_storage._create_source_records = AsyncMock()
logged_messages = []
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info') as mock_log:
mock_log.side_effect = lambda msg: logged_messages.append(msg)
with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'):
crawl_results = [
{"url": "https://example.com/page", "markdown": "Long content here..."},
]
result = await doc_storage.process_and_store_documents(
crawl_results=crawl_results,
request={},
crawl_type="test",
original_source_id="test789",
source_url="https://example.com",
source_display_name="Example"
)
# Find metrics log
metrics_log = None
for msg in logged_messages:
if "Document storage | processed=" in msg:
metrics_log = msg
break
assert metrics_log is not None
assert "processed=1/1" in metrics_log, "Should show 1 processed out of 1 total"
assert "chunks=5" in metrics_log, "Should have 5 chunks"
assert "avg_chunks_per_doc=5.0" in metrics_log, "Average should be 5.0"
@pytest.mark.asyncio
async def test_processed_count_accuracy(self):
"""Test that processed_docs count is accurate."""
mock_supabase = Mock()
doc_storage = DocumentStorageOperations(mock_supabase)
# Track which documents are chunked
chunked_urls = []
def mock_chunk(text, chunk_size):
if text:
return ["chunk"]
return []
doc_storage.doc_storage_service.smart_chunk_text = Mock(side_effect=mock_chunk)
doc_storage._create_source_records = AsyncMock()
with patch('src.server.services.crawling.document_storage_operations.safe_logfire_info'):
with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'):
# Mix of documents with various content states
crawl_results = [
{"url": "https://example.com/1", "markdown": "Content"},
{"url": "https://example.com/2", "markdown": ""}, # Empty markdown
{"url": "https://example.com/3", "markdown": None}, # None markdown
{"url": "https://example.com/4", "markdown": "More content"},
{"url": "https://example.com/5"}, # Missing markdown key
{"url": "https://example.com/6", "markdown": " "}, # Whitespace (counts as content)
]
result = await doc_storage.process_and_store_documents(
crawl_results=crawl_results,
request={},
crawl_type="test",
original_source_id="test999",
source_url="https://example.com",
source_display_name="Example"
)
# Documents 1, 4, and 6 should be processed (they have content; whitespace-only markdown counts as content)
assert result["chunk_count"] == 3, "Should have 3 chunks (one per processed doc)"
# Check url_to_full_document only has processed docs
assert len(result["url_to_full_document"]) == 3
assert "https://example.com/1" in result["url_to_full_document"]
assert "https://example.com/4" in result["url_to_full_document"]
assert "https://example.com/6" in result["url_to_full_document"]


@@ -0,0 +1,357 @@
"""
Test Suite for Source ID Architecture Refactor
Tests the new unique source ID generation and display name extraction
to ensure the race condition fix works correctly.
"""
import time
from concurrent.futures import ThreadPoolExecutor
# Import the URLHandler class
from src.server.services.crawling.helpers.url_handler import URLHandler
class TestSourceIDGeneration:
"""Test the unique source ID generation."""
def test_unique_id_generation_basic(self):
"""Test basic unique ID generation."""
handler = URLHandler()
# Test various URLs
test_urls = [
"https://github.com/microsoft/typescript",
"https://github.com/facebook/react",
"https://docs.python.org/3/",
"https://fastapi.tiangolo.com/",
"https://pydantic.dev/",
]
source_ids = []
for url in test_urls:
source_id = handler.generate_unique_source_id(url)
source_ids.append(source_id)
# Check that ID is a 16-character hex string
assert len(source_id) == 16, f"ID should be 16 chars, got {len(source_id)}"
assert all(c in '0123456789abcdef' for c in source_id), f"ID should be hex: {source_id}"
# All IDs should be unique
assert len(set(source_ids)) == len(source_ids), "All source IDs should be unique"
def test_same_domain_different_ids(self):
"""Test that same domain with different paths generates different IDs."""
handler = URLHandler()
# Multiple GitHub repos (same domain, different paths)
github_urls = [
"https://github.com/owner1/repo1",
"https://github.com/owner1/repo2",
"https://github.com/owner2/repo1",
]
ids = [handler.generate_unique_source_id(url) for url in github_urls]
# All should be unique despite same domain
assert len(set(ids)) == len(ids), "Same domain should generate different IDs for different URLs"
def test_id_consistency(self):
"""Test that the same URL always generates the same ID."""
handler = URLHandler()
url = "https://github.com/microsoft/typescript"
# Generate ID multiple times
ids = [handler.generate_unique_source_id(url) for _ in range(5)]
# All should be identical
assert len(set(ids)) == 1, f"Same URL should always generate same ID, got: {set(ids)}"
assert ids[0] == ids[4], "First and last ID should match"
def test_url_normalization(self):
"""Test that URL normalization works correctly."""
handler = URLHandler()
# These should all generate the same ID (after normalization)
url_variations = [
"https://github.com/Microsoft/TypeScript",
"HTTPS://GITHUB.COM/MICROSOFT/TYPESCRIPT",
"https://GitHub.com/Microsoft/TypeScript",
]
ids = [handler.generate_unique_source_id(url) for url in url_variations]
# All normalized versions should generate the same ID
assert len(set(ids)) == 1, f"Normalized URLs should generate same ID, got: {set(ids)}"
def test_concurrent_crawl_simulation(self):
"""Simulate concurrent crawls to verify no race conditions."""
handler = URLHandler()
# URLs that would previously conflict
concurrent_urls = [
"https://github.com/coleam00/archon",
"https://github.com/microsoft/typescript",
"https://github.com/facebook/react",
"https://github.com/vercel/next.js",
"https://github.com/vuejs/vue",
]
def generate_id(url):
"""Simulate a crawl generating an ID."""
time.sleep(0.001) # Simulate some processing time
return handler.generate_unique_source_id(url)
# Run concurrent ID generation
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(generate_id, url) for url in concurrent_urls]
source_ids = [future.result() for future in futures]
# All IDs should be unique
assert len(set(source_ids)) == len(source_ids), "Concurrent crawls should generate unique IDs"
def test_error_handling(self):
"""Test error handling for edge cases."""
handler = URLHandler()
# Test various edge cases
edge_cases = [
"", # Empty string
"not-a-url", # Invalid URL
"https://", # Incomplete URL
None, # None should be handled gracefully in real code
]
for url in edge_cases:
if url is None:
continue # Skip None for this test
# Should not raise exception
source_id = handler.generate_unique_source_id(url)
assert source_id is not None, f"Should generate ID even for edge case: {url}"
assert len(source_id) == 16, f"Edge case should still generate 16-char ID: {url}"
class TestDisplayNameExtraction:
"""Test the human-readable display name extraction."""
def test_github_display_names(self):
"""Test GitHub repository display name extraction."""
handler = URLHandler()
test_cases = [
("https://github.com/microsoft/typescript", "GitHub - microsoft/typescript"),
("https://github.com/facebook/react", "GitHub - facebook/react"),
("https://github.com/vercel/next.js", "GitHub - vercel/next.js"),
("https://github.com/owner", "GitHub - owner"),
("https://github.com/", "GitHub"),
]
for url, expected in test_cases:
display_name = handler.extract_display_name(url)
assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'"
def test_documentation_display_names(self):
"""Test documentation site display name extraction."""
handler = URLHandler()
test_cases = [
("https://docs.python.org/3/", "Python Documentation"),
("https://docs.djangoproject.com/", "Djangoproject Documentation"),
("https://fastapi.tiangolo.com/", "FastAPI Documentation"),
("https://pydantic.dev/", "Pydantic Documentation"),
("https://numpy.org/doc/", "NumPy Documentation"),
("https://pandas.pydata.org/", "Pandas Documentation"),
("https://project.readthedocs.io/", "Project Docs"),
]
for url, expected in test_cases:
display_name = handler.extract_display_name(url)
assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'"
def test_api_display_names(self):
"""Test API endpoint display name extraction."""
handler = URLHandler()
test_cases = [
("https://api.github.com/", "GitHub API"),
("https://api.openai.com/v1/", "Openai API"),
("https://example.com/api/v2/", "Example"),
]
for url, expected in test_cases:
display_name = handler.extract_display_name(url)
assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'"
def test_generic_display_names(self):
"""Test generic website display name extraction."""
handler = URLHandler()
test_cases = [
("https://example.com/", "Example"),
("https://my-site.org/", "My Site"),
("https://test_project.io/", "Test Project"),
("https://some.subdomain.example.com/", "Some Subdomain Example"),
]
for url, expected in test_cases:
display_name = handler.extract_display_name(url)
assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'"
def test_edge_case_display_names(self):
"""Test edge cases for display name extraction."""
handler = URLHandler()
# Edge cases
test_cases = [
("", ""), # Empty URL
("not-a-url", "not-a-url"), # Invalid URL
("/local/file/path", "Local: path"), # Local file path
("https://", "https://"), # Incomplete URL
]
for url, expected_contains in test_cases:
display_name = handler.extract_display_name(url)
assert expected_contains in display_name or display_name == expected_contains, \
f"Edge case {url} handling failed: {display_name}"
def test_special_file_display_names(self):
"""Test that special files like llms.txt and sitemap.xml are properly displayed."""
handler = URLHandler()
test_cases = [
# llms.txt files
("https://docs.mem0.ai/llms-full.txt", "Mem0 - Llms.Txt"),
("https://example.com/llms.txt", "Example - Llms.Txt"),
("https://api.example.com/llms.txt", "Example API"), # API takes precedence
# sitemap.xml files
("https://mem0.ai/sitemap.xml", "Mem0 - Sitemap.Xml"),
("https://docs.example.com/sitemap.xml", "Example - Sitemap.Xml"),
("https://example.org/sitemap.xml", "Example - Sitemap.Xml"),
# Regular .txt files on docs sites
("https://docs.example.com/readme.txt", "Example - Readme.Txt"),
# Non-special files should not get special treatment
("https://docs.example.com/guide", "Example Documentation"),
("https://example.com/page.html", "Example - Page.Html"), # Path gets added for single file
]
for url, expected in test_cases:
display_name = handler.extract_display_name(url)
assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'"
def test_git_extension_removal(self):
"""Test that .git extension is removed from GitHub repos."""
handler = URLHandler()
test_cases = [
("https://github.com/owner/repo.git", "GitHub - owner/repo"),
("https://github.com/owner/repo", "GitHub - owner/repo"),
]
for url, expected in test_cases:
display_name = handler.extract_display_name(url)
assert display_name == expected, f"URL {url} should display as '{expected}', got '{display_name}'"
class TestRaceConditionFix:
"""Test that the race condition is actually fixed."""
def test_no_domain_conflicts(self):
"""Test that multiple sources from same domain don't conflict."""
handler = URLHandler()
# These would all have source_id = "github.com" in the old system
github_urls = [
"https://github.com/microsoft/typescript",
"https://github.com/microsoft/vscode",
"https://github.com/facebook/react",
"https://github.com/vercel/next.js",
"https://github.com/vuejs/vue",
]
source_ids = [handler.generate_unique_source_id(url) for url in github_urls]
# All should be unique
assert len(set(source_ids)) == len(source_ids), \
"Race condition not fixed: duplicate source IDs for same domain"
# None should be just "github.com"
for source_id in source_ids:
assert source_id != "github.com", \
"Source ID should not be just the domain"
def test_hash_properties(self):
"""Test that the hash has good properties."""
handler = URLHandler()
# Similar URLs should still generate very different hashes
url1 = "https://github.com/owner/repo1"
url2 = "https://github.com/owner/repo2" # Only differs by one character
id1 = handler.generate_unique_source_id(url1)
id2 = handler.generate_unique_source_id(url2)
# IDs should be completely different (good hash distribution)
matching_chars = sum(1 for a, b in zip(id1, id2) if a == b)
assert matching_chars < 8, \
f"Similar URLs should generate very different hashes, {matching_chars}/16 chars match"
class TestIntegration:
"""Integration tests for the complete source ID system."""
def test_full_source_creation_flow(self):
"""Test the complete flow of creating a source with all fields."""
handler = URLHandler()
url = "https://github.com/microsoft/typescript"
# Generate all source fields
source_id = handler.generate_unique_source_id(url)
source_display_name = handler.extract_display_name(url)
source_url = url
# Verify all fields are populated correctly
assert len(source_id) == 16, "Source ID should be 16 characters"
assert source_display_name == "GitHub - microsoft/typescript", \
f"Display name incorrect: {source_display_name}"
assert source_url == url, "Source URL should match original"
# Simulate database record
source_record = {
'source_id': source_id,
'source_url': source_url,
'source_display_name': source_display_name,
'title': None, # Generated later
'summary': None, # Generated later
'metadata': {}
}
# Verify record structure
assert 'source_id' in source_record
assert 'source_url' in source_record
assert 'source_display_name' in source_record
def test_backward_compatibility(self):
"""Test that the system handles existing sources gracefully."""
handler = URLHandler()
# Simulate an existing source with old-style source_id
existing_source = {
'source_id': 'github.com', # Old style - just domain
'source_url': None, # Not populated in old system
'source_display_name': None, # Not populated in old system
}
# The migration should handle this by backfilling
# source_url and source_display_name with source_id value
migrated_source = {
'source_id': 'github.com',
'source_url': 'github.com', # Backfilled
'source_display_name': 'github.com', # Backfilled
}
assert migrated_source['source_url'] is not None
assert migrated_source['source_display_name'] is not None


@@ -0,0 +1,269 @@
"""
Test race condition handling in source creation.
This test ensures that concurrent source creation attempts
don't fail with PRIMARY KEY violations.
"""
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock
import pytest
from src.server.services.source_management_service import update_source_info
class TestSourceRaceCondition:
"""Test that concurrent source creation handles race conditions properly."""
def test_concurrent_source_creation_no_race(self):
"""Test that concurrent attempts to create the same source don't fail."""
# Track successful operations
successful_creates = []
failed_creates = []
def mock_execute():
"""Mock execute that simulates database operation."""
return Mock(data=[])
def track_upsert(data):
"""Track upsert calls."""
successful_creates.append(data["source_id"])
return Mock(execute=mock_execute)
# Mock Supabase client
mock_client = Mock()
# Mock the SELECT (existing source check) - always returns empty
mock_client.table.return_value.select.return_value.eq.return_value.execute.return_value.data = []
# Mock the UPSERT operation
mock_client.table.return_value.upsert = track_upsert
def create_source(thread_id):
"""Simulate creating a source from a thread."""
try:
update_source_info(
client=mock_client,
source_id="test_source_123",
summary=f"Summary from thread {thread_id}",
word_count=100,
content=f"Content from thread {thread_id}",
knowledge_type="documentation",
tags=["test"],
update_frequency=0,
source_url="https://example.com",
source_display_name=f"Example Site {thread_id}" # Will be used as title
)
except Exception as e:
failed_creates.append((thread_id, str(e)))
# Run 5 threads concurrently trying to create the same source
with ThreadPoolExecutor(max_workers=5) as executor:
futures = []
for i in range(5):
futures.append(executor.submit(create_source, i))
# Wait for all to complete
for future in futures:
future.result()
# All should succeed (no failures due to PRIMARY KEY violation)
assert len(failed_creates) == 0, f"Some creates failed: {failed_creates}"
assert len(successful_creates) == 5, "All 5 attempts should succeed"
assert all(sid == "test_source_123" for sid in successful_creates)
def test_upsert_vs_insert_behavior(self):
"""Test that upsert is used instead of insert for new sources."""
mock_client = Mock()
# Track which method is called
methods_called = []
def track_insert(data):
methods_called.append("insert")
# Simulate PRIMARY KEY violation
raise Exception("duplicate key value violates unique constraint")
def track_upsert(data):
methods_called.append("upsert")
return Mock(execute=Mock(return_value=Mock(data=[])))
# Source doesn't exist
mock_client.table.return_value.select.return_value.eq.return_value.execute.return_value.data = []
# Set up mocks
mock_client.table.return_value.insert = track_insert
mock_client.table.return_value.upsert = track_upsert
update_source_info(
client=mock_client,
source_id="new_source",
summary="Test summary",
word_count=100,
content="Test content",
knowledge_type="documentation",
source_display_name="Test Display Name" # Will be used as title
)
# Should use upsert, not insert
assert "upsert" in methods_called, "Should use upsert for new sources"
assert "insert" not in methods_called, "Should not use insert to avoid race conditions"
def test_existing_source_uses_update(self):
"""Test that existing sources still use UPDATE (not affected by change)."""
mock_client = Mock()
methods_called = []
def track_update(data):
methods_called.append("update")
return Mock(eq=Mock(return_value=Mock(execute=Mock(return_value=Mock(data=[])))))
def track_upsert(data):
methods_called.append("upsert")
return Mock(execute=Mock(return_value=Mock(data=[])))
# Source exists
existing_source = {
"source_id": "existing_source",
"title": "Existing Title",
"metadata": {"knowledge_type": "api"}
}
mock_client.table.return_value.select.return_value.eq.return_value.execute.return_value.data = [existing_source]
# Set up mocks
mock_client.table.return_value.update = track_update
mock_client.table.return_value.upsert = track_upsert
update_source_info(
client=mock_client,
source_id="existing_source",
summary="Updated summary",
word_count=200,
content="Updated content",
knowledge_type="documentation"
)
# Should use update for existing sources
assert "update" in methods_called, "Should use update for existing sources"
assert "upsert" not in methods_called, "Should not use upsert for existing sources"
@pytest.mark.asyncio
async def test_async_concurrent_creation(self):
"""Test concurrent source creation in async context."""
mock_client = Mock()
# Track operations
operations = []
def track_upsert(data):
operations.append(("upsert", data["source_id"]))
return Mock(execute=Mock(return_value=Mock(data=[])))
# No existing sources
mock_client.table.return_value.select.return_value.eq.return_value.execute.return_value.data = []
mock_client.table.return_value.upsert = track_upsert
async def create_source_async(task_id):
"""Async wrapper for source creation."""
await asyncio.to_thread(
update_source_info,
client=mock_client,
source_id=f"async_source_{task_id % 2}", # Only 2 unique sources
summary=f"Summary {task_id}",
word_count=100,
content=f"Content {task_id}",
knowledge_type="documentation"
)
# Create 10 tasks, but only 2 unique source_ids
tasks = [create_source_async(i) for i in range(10)]
await asyncio.gather(*tasks)
# All operations should succeed
assert len(operations) == 10, "All 10 operations should complete"
# Check that we tried to upsert the two sources multiple times
source_0_count = sum(1 for op, sid in operations if sid == "async_source_0")
source_1_count = sum(1 for op, sid in operations if sid == "async_source_1")
assert source_0_count == 5, "async_source_0 should be upserted 5 times"
assert source_1_count == 5, "async_source_1 should be upserted 5 times"
def test_race_condition_with_delay(self):
"""Test race condition with simulated delay between check and create."""
import time
mock_client = Mock()
# Track timing of operations
check_times = []
create_times = []
source_created = threading.Event()
def delayed_select(*args):
"""Return a mock that simulates SELECT with delay."""
mock_select = Mock()
def eq_mock(*args):
mock_eq = Mock()
mock_eq.execute = lambda: delayed_check()
return mock_eq
mock_select.eq = eq_mock
return mock_select
def delayed_check():
"""Simulate SELECT execution with delay."""
check_times.append(time.time())
result = Mock()
# First thread doesn't see the source
if not source_created.is_set():
time.sleep(0.01) # Small delay to let both threads check
result.data = []
else:
# Subsequent checks would see it (but we use upsert so this doesn't matter)
result.data = [{"source_id": "race_source", "title": "Existing", "metadata": {}}]
return result
def track_upsert(data):
"""Track upsert and set event."""
create_times.append(time.time())
source_created.set()
return Mock(execute=Mock(return_value=Mock(data=[])))
# Set up table mock to return our custom select mock
mock_client.table.return_value.select = delayed_select
mock_client.table.return_value.upsert = track_upsert
errors = []
def create_with_error_tracking(thread_id):
try:
update_source_info(
client=mock_client,
source_id="race_source",
summary="Race summary",
word_count=100,
content="Race content",
knowledge_type="documentation",
source_display_name="Race Display Name" # Will be used as title
)
except Exception as e:
errors.append((thread_id, str(e)))
# Run 2 threads that will both check before either creates
with ThreadPoolExecutor(max_workers=2) as executor:
futures = [
executor.submit(create_with_error_tracking, 1),
executor.submit(create_with_error_tracking, 2)
]
for future in futures:
future.result()
# Both should succeed with upsert (no errors)
assert len(errors) == 0, f"No errors should occur with upsert: {errors}"
assert len(check_times) == 2, "Both threads should check"
assert len(create_times) == 2, "Both threads should attempt create/upsert"


@@ -0,0 +1,124 @@
"""
Test that source_url parameter is not shadowed by document URLs.
This test ensures that the original crawl URL (e.g., sitemap URL)
is correctly passed to _create_source_records and not overwritten
by individual document URLs during processing.
"""
import pytest
from unittest.mock import Mock, patch
from src.server.services.crawling.document_storage_operations import DocumentStorageOperations
class TestSourceUrlShadowing:
"""Test that source_url parameter is preserved correctly."""
@pytest.mark.asyncio
async def test_source_url_not_shadowed(self):
"""Test that the original source_url is passed to _create_source_records."""
# Create mock supabase client
mock_supabase = Mock()
# Create DocumentStorageOperations instance
doc_storage = DocumentStorageOperations(mock_supabase)
# Mock the storage service
doc_storage.doc_storage_service.smart_chunk_text = Mock(return_value=["chunk1", "chunk2"])
# Track what gets passed to _create_source_records
captured_source_url = None
async def mock_create_source_records(all_metadatas, all_contents, source_word_counts,
request, source_url, source_display_name):
nonlocal captured_source_url
captured_source_url = source_url
doc_storage._create_source_records = mock_create_source_records
# Mock add_documents_to_supabase
with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase') as mock_add:
mock_add.return_value = None
# Test data - simulating a sitemap crawl
original_source_url = "https://mem0.ai/sitemap.xml"
crawl_results = [
{
"url": "https://mem0.ai/page1",
"markdown": "Content of page 1",
"title": "Page 1"
},
{
"url": "https://mem0.ai/page2",
"markdown": "Content of page 2",
"title": "Page 2"
},
{
"url": "https://mem0.ai/models/openai-o3", # Last document URL
"markdown": "Content of models page",
"title": "Models"
}
]
request = {"knowledge_type": "documentation", "tags": []}
# Call the method
result = await doc_storage.process_and_store_documents(
crawl_results=crawl_results,
request=request,
crawl_type="sitemap",
original_source_id="test123",
progress_callback=None,
cancellation_check=None,
source_url=original_source_url, # This should NOT be overwritten
source_display_name="Test Sitemap"
)
# Verify the original source_url was preserved
assert captured_source_url == original_source_url, \
f"source_url should be '{original_source_url}', not '{captured_source_url}'"
# Verify it's NOT the last document's URL
assert captured_source_url != "https://mem0.ai/models/openai-o3", \
"source_url should NOT be overwritten with the last document's URL"
# Verify url_to_full_document has correct URLs
assert "https://mem0.ai/page1" in result["url_to_full_document"]
assert "https://mem0.ai/page2" in result["url_to_full_document"]
assert "https://mem0.ai/models/openai-o3" in result["url_to_full_document"]
@pytest.mark.asyncio
async def test_metadata_uses_document_urls(self):
"""Test that metadata correctly uses individual document URLs."""
mock_supabase = Mock()
doc_storage = DocumentStorageOperations(mock_supabase)
# Mock the storage service
doc_storage.doc_storage_service.smart_chunk_text = Mock(return_value=["chunk1"])
# Capture metadata
captured_metadatas = None
async def mock_create_source_records(all_metadatas, all_contents, source_word_counts,
request, source_url, source_display_name):
nonlocal captured_metadatas
captured_metadatas = all_metadatas
doc_storage._create_source_records = mock_create_source_records
with patch('src.server.services.crawling.document_storage_operations.add_documents_to_supabase'):
crawl_results = [
{"url": "https://example.com/doc1", "markdown": "Doc 1"},
{"url": "https://example.com/doc2", "markdown": "Doc 2"}
]
await doc_storage.process_and_store_documents(
crawl_results=crawl_results,
request={},
crawl_type="normal",
original_source_id="test456",
source_url="https://example.com",
source_display_name="Example"
)
# Each metadata should have the correct document URL
assert captured_metadatas[0]["url"] == "https://example.com/doc1"
assert captured_metadatas[1]["url"] == "https://example.com/doc2"


@@ -0,0 +1,222 @@
"""
Test URL canonicalization in source ID generation.
This test ensures that URLs are properly normalized before hashing
to prevent duplicate sources from URL variations.
"""
from src.server.services.crawling.helpers.url_handler import URLHandler
class TestURLCanonicalization:
"""Test that URL canonicalization works correctly for source ID generation."""
def test_trailing_slash_normalization(self):
"""Test that trailing slashes are handled consistently."""
handler = URLHandler()
# These should generate the same ID
url1 = "https://example.com/path"
url2 = "https://example.com/path/"
id1 = handler.generate_unique_source_id(url1)
id2 = handler.generate_unique_source_id(url2)
assert id1 == id2, "URLs with/without trailing slash should generate same ID"
# Root path should keep its slash
root1 = "https://example.com"
root2 = "https://example.com/"
root_id1 = handler.generate_unique_source_id(root1)
root_id2 = handler.generate_unique_source_id(root2)
# These should be the same (both normalize to https://example.com/)
assert root_id1 == root_id2, "Root URLs should normalize consistently"
def test_fragment_removal(self):
"""Test that URL fragments are removed."""
handler = URLHandler()
urls = [
"https://example.com/page",
"https://example.com/page#section1",
"https://example.com/page#section2",
"https://example.com/page#",
]
ids = [handler.generate_unique_source_id(url) for url in urls]
# All should generate the same ID
assert len(set(ids)) == 1, "URLs with different fragments should generate same ID"
def test_tracking_param_removal(self):
"""Test that tracking parameters are removed."""
handler = URLHandler()
# URL without tracking params
clean_url = "https://example.com/page?important=value"
# URLs with various tracking params
tracked_urls = [
"https://example.com/page?important=value&utm_source=google",
"https://example.com/page?utm_medium=email&important=value",
"https://example.com/page?important=value&fbclid=abc123",
"https://example.com/page?gclid=xyz&important=value&utm_campaign=test",
"https://example.com/page?important=value&ref=homepage",
"https://example.com/page?source=newsletter&important=value",
]
clean_id = handler.generate_unique_source_id(clean_url)
tracked_ids = [handler.generate_unique_source_id(url) for url in tracked_urls]
# All tracked URLs should generate the same ID as the clean URL
for tracked_id in tracked_ids:
assert tracked_id == clean_id, "URLs with tracking params should match clean URL"
def test_query_param_sorting(self):
"""Test that query parameters are sorted for consistency."""
handler = URLHandler()
urls = [
"https://example.com/page?a=1&b=2&c=3",
"https://example.com/page?c=3&a=1&b=2",
"https://example.com/page?b=2&c=3&a=1",
]
ids = [handler.generate_unique_source_id(url) for url in urls]
# All should generate the same ID
assert len(set(ids)) == 1, "URLs with reordered query params should generate same ID"
def test_default_port_removal(self):
"""Test that default ports are removed."""
handler = URLHandler()
# HTTP default port (80)
http_urls = [
"http://example.com/page",
"http://example.com:80/page",
]
http_ids = [handler.generate_unique_source_id(url) for url in http_urls]
assert len(set(http_ids)) == 1, "HTTP URLs with/without :80 should generate same ID"
# HTTPS default port (443)
https_urls = [
"https://example.com/page",
"https://example.com:443/page",
]
https_ids = [handler.generate_unique_source_id(url) for url in https_urls]
assert len(set(https_ids)) == 1, "HTTPS URLs with/without :443 should generate same ID"
# Non-default ports should be preserved
url1 = "https://example.com:8080/page"
url2 = "https://example.com:9090/page"
id1 = handler.generate_unique_source_id(url1)
id2 = handler.generate_unique_source_id(url2)
assert id1 != id2, "URLs with different non-default ports should generate different IDs"
def test_case_normalization(self):
"""Test that scheme and domain are lowercased."""
handler = URLHandler()
urls = [
"https://example.com/Path/To/Page",
"HTTPS://EXAMPLE.COM/Path/To/Page",
"https://Example.Com/Path/To/Page",
"HTTPs://example.COM/Path/To/Page",
]
ids = [handler.generate_unique_source_id(url) for url in urls]
# All should generate the same ID (path case is preserved)
assert len(set(ids)) == 1, "URLs with different case in scheme/domain should generate same ID"
# But different paths should generate different IDs
path_urls = [
"https://example.com/path",
"https://example.com/Path",
"https://example.com/PATH",
]
path_ids = [handler.generate_unique_source_id(url) for url in path_urls]
# These should be different (path case matters)
assert len(set(path_ids)) == 3, "URLs with different path case should generate different IDs"
def test_complex_canonicalization(self):
"""Test complex URL with multiple normalizations needed."""
handler = URLHandler()
urls = [
"https://example.com/page",
"HTTPS://EXAMPLE.COM:443/page/",
"https://Example.com/page#section",
"https://example.com/page/?utm_source=test",
"https://example.com:443/page?utm_campaign=abc#footer",
]
ids = [handler.generate_unique_source_id(url) for url in urls]
# All should generate the same ID
assert len(set(ids)) == 1, "Complex URLs should normalize to same ID"
def test_edge_cases(self):
"""Test edge cases and error handling."""
handler = URLHandler()
# Empty URL
empty_id = handler.generate_unique_source_id("")
assert len(empty_id) == 16, "Empty URL should still generate valid ID"
# Invalid URL
invalid_id = handler.generate_unique_source_id("not-a-url")
assert len(invalid_id) == 16, "Invalid URL should still generate valid ID"
# URL with special characters
special_url = "https://example.com/page?key=value%20with%20spaces"
special_id = handler.generate_unique_source_id(special_url)
assert len(special_id) == 16, "URL with encoded chars should generate valid ID"
# Very long URL
long_url = "https://example.com/" + "a" * 1000
long_id = handler.generate_unique_source_id(long_url)
assert len(long_id) == 16, "Long URL should generate valid ID"
def test_preserves_important_params(self):
"""Test that non-tracking params are preserved."""
handler = URLHandler()
# These have different important params, should be different
url1 = "https://api.example.com/v1/users?page=1"
url2 = "https://api.example.com/v1/users?page=2"
id1 = handler.generate_unique_source_id(url1)
id2 = handler.generate_unique_source_id(url2)
assert id1 != id2, "URLs with different important params should generate different IDs"
# But tracking params should still be removed
url3 = "https://api.example.com/v1/users?page=1&utm_source=docs"
id3 = handler.generate_unique_source_id(url3)
assert id3 == id1, "Adding tracking params shouldn't change ID"
def test_local_file_paths(self):
"""Test handling of local file paths."""
handler = URLHandler()
# File URLs
file_url = "file:///Users/test/document.pdf"
file_id = handler.generate_unique_source_id(file_url)
assert len(file_id) == 16, "File URL should generate valid ID"
# Relative paths
relative_path = "../documents/file.txt"
relative_id = handler.generate_unique_source_id(relative_path)
assert len(relative_id) == 16, "Relative path should generate valid ID"