Mirror of https://github.com/coleam00/Archon.git, synced 2025-12-30 21:49:30 -05:00
feat: Add advanced web crawling with domain filtering
- Implement domain filtering for web crawler with whitelist/blacklist support
- Add URL pattern matching (glob-style) for include/exclude patterns
- Create AdvancedCrawlConfig UI component with collapsible panel
- Add domain filter to Knowledge Inspector sidebar for easy filtering
- Implement crawl-v2 API endpoint with backward compatibility
- Add comprehensive unit tests for domain filtering logic

Implements priority-based filtering:
1. Blacklist (excluded_domains) - highest priority
2. Whitelist (allowed_domains) - must match if provided
3. Exclude patterns - glob patterns to exclude
4. Include patterns - glob patterns to include

UI improvements:
- Advanced configuration section in Add Knowledge dialog
- Domain pills in Inspector sidebar showing document distribution
- Visual domain indicators on each document
- Responsive domain filtering with document counts

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
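As a rough illustration (not part of the diff below), a client request to the new endpoint might look like the following sketch. The route path and the crawl_config field names come from this commit; the server base URL, the "/api" prefix, and the use of httpx are assumptions.

import httpx

# Hypothetical payload; field names match CrawlRequestV2 / CrawlConfig added in this commit.
payload = {
    "url": "https://docs.example.com",
    "knowledge_type": "technical",
    "max_depth": 3,
    "crawl_config": {
        "allowed_domains": ["docs.example.com", "*.example.com"],
        "excluded_domains": ["ads.example.com"],
        "include_patterns": ["*/docs/*", "*/api/*"],
        "exclude_patterns": ["*/deprecated/*", "*.pdf"],
    },
}

# Base URL and prefix are placeholders for wherever the API is served.
resp = httpx.post("http://localhost:8181/api/knowledge-items/crawl-v2", json=payload)
print(resp.json())  # expected shape: {"success": True, "progressId": "...", ...}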
@@ -29,6 +29,7 @@ from ..services.search.rag_service import RAGService
from ..services.storage import DocumentStorageService
from ..utils import get_supabase_client
from ..utils.document_processing import extract_text_from_document
from ..utils.progress.progress_tracker import ProgressTracker

# Get logger for this module
logger = get_logger(__name__)
@@ -855,6 +856,135 @@ async def _perform_crawl_with_progress(
    )


@router.post("/knowledge-items/crawl-v2")
async def crawl_knowledge_item_v2(request: dict):
    """
    Crawl a URL with advanced domain filtering configuration.

    This is version 2 of the crawl endpoint that supports domain filtering.
    """
    # Import CrawlRequestV2 model
    from ..models.crawl_models import CrawlRequestV2, CrawlConfig

    # Parse and validate request
    crawl_request = CrawlRequestV2(**request)

    # Validate API key before starting expensive operation
    logger.info("🔍 About to validate API key for crawl-v2...")
    provider_config = await credential_service.get_active_provider("embedding")
    provider = provider_config.get("provider", "openai")
    await _validate_provider_api_key(provider)
    logger.info("✅ API key validation completed successfully")

    try:
        safe_logfire_info(
            f"Starting knowledge item crawl v2 | url={crawl_request.url} | "
            f"knowledge_type={crawl_request.knowledge_type} | "
            f"has_crawl_config={crawl_request.crawl_config is not None}"
        )

        # Generate unique progress ID
        progress_id = str(uuid.uuid4())

        # Create progress tracker for HTTP polling
        tracker = ProgressTracker(progress_id, operation_type="crawl")
        await tracker.start({
            "status": "starting",
            "url": crawl_request.url,
            "has_filters": crawl_request.crawl_config is not None
        })

        # Create async task for crawling
        crawl_task = asyncio.create_task(_run_crawl_v2(request_dict=crawl_request.dict(), progress_id=progress_id))
        active_crawl_tasks[progress_id] = crawl_task

        safe_logfire_info(
            f"Crawl v2 task created | progress_id={progress_id} | url={crawl_request.url}"
        )

        return {
            "success": True,
            "progressId": progress_id,
            "message": "Crawl started with domain filtering",
            "estimatedDuration": "2-10 minutes depending on site size"
        }

    except Exception as e:
        safe_logfire_error(f"Failed to start crawl v2 | error={str(e)}")
        raise HTTPException(status_code=500, detail={"error": str(e)})


async def _run_crawl_v2(request_dict: dict, progress_id: str):
    """Run the crawl v2 with domain filtering in background."""
    tracker = ProgressTracker(progress_id, operation_type="crawl")

    try:
        safe_logfire_info(
            f"Starting crawl v2 with progress tracking | progress_id={progress_id} | url={request_dict['url']}"
        )

        # Get crawler from CrawlerManager
        try:
            crawler = await get_crawler()
            if crawler is None:
                raise Exception("Crawler not available - initialization may have failed")
        except Exception as e:
            safe_logfire_error(f"Failed to get crawler | error={str(e)}")
            await tracker.error(f"Failed to initialize crawler: {str(e)}")
            return

        supabase_client = get_supabase_client()

        # Extract crawl_config if present
        crawl_config_dict = request_dict.get("crawl_config")
        crawl_config = None
        if crawl_config_dict:
            from ..models.crawl_models import CrawlConfig
            crawl_config = CrawlConfig(**crawl_config_dict)

        # Create orchestration service with crawl_config
        orchestration_service = CrawlingService(
            crawler,
            supabase_client,
            crawl_config=crawl_config
        )
        orchestration_service.set_progress_id(progress_id)

        # Add crawl_config to metadata for storage
        if crawl_config:
            request_dict["metadata"] = request_dict.get("metadata", {})
            request_dict["metadata"]["crawl_config"] = crawl_config.dict()

        # Orchestrate the crawl - this returns immediately with task info
        result = await orchestration_service.orchestrate_crawl(request_dict)

        # Store the actual crawl task for proper cancellation
        crawl_task = result.get("task")
        if crawl_task:
            active_crawl_tasks[progress_id] = crawl_task
            safe_logfire_info(
                f"Stored actual crawl v2 task in active_crawl_tasks | progress_id={progress_id}"
            )
        else:
            safe_logfire_error(f"No task returned from orchestrate_crawl v2 | progress_id={progress_id}")

        safe_logfire_info(
            f"Crawl v2 task started | progress_id={progress_id} | task_id={result.get('task_id')}"
        )

    except asyncio.CancelledError:
        safe_logfire_info(f"Crawl v2 cancelled | progress_id={progress_id}")
        raise
    except Exception as e:
        safe_logfire_error(f"Crawl v2 task failed | progress_id={progress_id} | error={str(e)}")
        await tracker.error(str(e))
    finally:
        # Clean up task from registry when done
        if progress_id in active_crawl_tasks:
            del active_crawl_tasks[progress_id]
            safe_logfire_info(f"Cleaned up crawl v2 task from registry | progress_id={progress_id}")


@router.post("/documents/upload")
async def upload_document(
    file: UploadFile = File(...),
python/src/server/models/__init__.py (new file, 0 lines)
python/src/server/models/crawl_models.py (new file, 63 lines)
@@ -0,0 +1,63 @@
"""
Crawling Models Module

This module contains Pydantic models for crawling configuration,
specifically for domain filtering and URL pattern matching.
"""


from pydantic import BaseModel, Field, validator


class CrawlConfig(BaseModel):
    """Configuration for domain filtering during crawl."""

    allowed_domains: list[str] | None = Field(None, description="Whitelist of domains to crawl")
    excluded_domains: list[str] | None = Field(None, description="Blacklist of domains to exclude")
    include_patterns: list[str] | None = Field(None, description="URL patterns to include (glob-style)")
    exclude_patterns: list[str] | None = Field(None, description="URL patterns to exclude (glob-style)")

    @validator("allowed_domains", "excluded_domains", pre=True)
    def normalize_domains(cls, v):
        """Normalize domain formats for consistent matching."""
        if v is None:
            return v
        return [d.lower().strip().replace("http://", "").replace("https://", "").rstrip("/") for d in v]

    @validator("include_patterns", "exclude_patterns", pre=True)
    def validate_patterns(cls, v):
        """Validate URL patterns are valid glob patterns."""
        if v is None:
            return v
        # Ensure patterns are strings and not empty
        return [p.strip() for p in v if p and isinstance(p, str) and p.strip()]


class CrawlRequestV2(BaseModel):
    """Extended crawl request with domain filtering."""

    url: str = Field(..., description="URL to start crawling from")
    knowledge_type: str | None = Field("technical", description="Type of knowledge (technical/business)")
    tags: list[str] | None = Field(default_factory=list, description="Tags to apply to crawled content")
    update_frequency: int | None = Field(None, description="Update frequency in days")
    max_depth: int | None = Field(3, description="Maximum crawl depth")
    crawl_config: CrawlConfig | None = Field(None, description="Domain filtering configuration")
    crawl_options: dict | None = Field(None, description="Additional crawl options")
    extract_code_examples: bool | None = Field(True, description="Whether to extract code examples")

    @validator("url")
    def validate_url(cls, v):
        """Ensure URL is properly formatted."""
        if not v or not v.strip():
            raise ValueError("URL cannot be empty")
        # Add https:// if no protocol specified
        if not v.startswith(("http://", "https://")):
            v = f"https://{v}"
        return v.strip()

    @validator("knowledge_type")
    def validate_knowledge_type(cls, v):
        """Ensure knowledge type is valid."""
        if v and v not in ["technical", "business"]:
            return "technical"  # Default to technical if invalid
        return v or "technical"
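A quick sketch of how these validators behave, based on the model code above. The import path mirrors the test module added later in this commit, and it assumes a Pydantic version that still provides the v1-style validator decorator; the concrete values are illustrative only.

from src.server.models.crawl_models import CrawlConfig, CrawlRequestV2

# normalize_domains lowercases domains and strips protocol and trailing slash;
# validate_patterns drops empty/blank entries.
config = CrawlConfig(allowed_domains=["HTTPS://Docs.Example.com/"], exclude_patterns=["  ", "*.pdf"])
print(config.allowed_domains)   # ["docs.example.com"]
print(config.exclude_patterns)  # ["*.pdf"]

# validate_url prepends https:// when no protocol is given;
# an unknown knowledge_type falls back to "technical".
req = CrawlRequestV2(url="example.com/docs", knowledge_type="marketing", crawl_config=config)
print(req.url)             # "https://example.com/docs"
print(req.knowledge_type)  # "technical"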
@@ -12,12 +12,14 @@ from collections.abc import Awaitable, Callable
from typing import Any, Optional

from ...config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info
from ...models.crawl_models import CrawlConfig
from ...utils import get_supabase_client
from ...utils.progress.progress_tracker import ProgressTracker

# Import strategies
# Import operations
from .document_storage_operations import DocumentStorageOperations
from .domain_filter import DomainFilter
from .helpers.site_config import SiteConfig

# Import helpers
@@ -56,7 +58,7 @@ class CrawlingService:
    Combines functionality from both CrawlingService and CrawlOrchestrationService.
    """

    def __init__(self, crawler=None, supabase_client=None, progress_id=None):
    def __init__(self, crawler=None, supabase_client=None, progress_id=None, crawl_config=None):
        """
        Initialize the crawling service.

@@ -64,21 +66,24 @@
            crawler: The Crawl4AI crawler instance
            supabase_client: The Supabase client for database operations
            progress_id: Optional progress ID for HTTP polling updates
            crawl_config: Optional CrawlConfig for domain filtering
        """
        self.crawler = crawler
        self.supabase_client = supabase_client or get_supabase_client()
        self.progress_id = progress_id
        self.progress_tracker = None
        self.crawl_config = crawl_config

        # Initialize helpers
        self.url_handler = URLHandler()
        self.site_config = SiteConfig()
        self.markdown_generator = self.site_config.get_markdown_generator()
        self.link_pruning_markdown_generator = self.site_config.get_link_pruning_markdown_generator()
        self.domain_filter = DomainFilter()

        # Initialize strategies
        self.batch_strategy = BatchCrawlStrategy(crawler, self.link_pruning_markdown_generator)
        self.recursive_strategy = RecursiveCrawlStrategy(crawler, self.link_pruning_markdown_generator)
        self.recursive_strategy = RecursiveCrawlStrategy(crawler, self.link_pruning_markdown_generator, self.domain_filter)
        self.single_page_strategy = SinglePageCrawlStrategy(crawler, self.markdown_generator)
        self.sitemap_strategy = SitemapCrawlStrategy()

@@ -225,6 +230,7 @@ class CrawlingService:
            max_concurrent,
            progress_callback,
            self._check_cancellation,  # Pass cancellation check
            self.crawl_config,  # Pass crawl config for domain filtering
        )

    # Orchestration methods
python/src/server/services/crawling/domain_filter.py (new file, 169 lines)
@@ -0,0 +1,169 @@
"""
Domain Filtering Module

This module provides domain filtering utilities for web crawling,
allowing users to control which domains and URL patterns are crawled.
"""

import fnmatch
from urllib.parse import urlparse

from ...config.logfire_config import get_logger
from ...models.crawl_models import CrawlConfig

logger = get_logger(__name__)


class DomainFilter:
    """
    Handles domain and URL pattern filtering for crawl operations.

    Priority order:
    1. Blacklist (excluded_domains) - always blocks
    2. Whitelist (allowed_domains) - must match if specified
    3. Exclude patterns - blocks matching URLs
    4. Include patterns - must match if specified
    """

    def is_url_allowed(self, url: str, base_url: str, config: CrawlConfig | None) -> bool:
        """
        Check if a URL should be crawled based on domain filtering configuration.

        Args:
            url: The URL to check
            base_url: The base URL of the crawl (for resolving relative URLs)
            config: The crawl configuration with filtering rules

        Returns:
            True if the URL should be crawled, False otherwise
        """
        if not config:
            # No filtering configured, allow all URLs
            return True

        try:
            # Parse the URL
            parsed = urlparse(url)

            # Handle relative URLs by using base URL's domain
            if not parsed.netloc:
                base_parsed = urlparse(base_url)
                domain = base_parsed.netloc.lower()
                # Construct full URL for pattern matching
                full_url = f"{base_parsed.scheme}://{base_parsed.netloc}{parsed.path or '/'}"
            else:
                domain = parsed.netloc.lower()
                full_url = url

            # Remove www. prefix for consistent matching
            normalized_domain = domain.replace("www.", "")

            # PRIORITY 1: Blacklist always wins
            if config.excluded_domains:
                for excluded in config.excluded_domains:
                    if self._matches_domain(normalized_domain, excluded):
                        logger.debug(f"URL blocked by excluded domain | url={url} | domain={normalized_domain} | excluded={excluded}")
                        return False

            # PRIORITY 2: If whitelist exists, URL must match
            if config.allowed_domains:
                allowed = False
                for allowed_domain in config.allowed_domains:
                    if self._matches_domain(normalized_domain, allowed_domain):
                        allowed = True
                        break

                if not allowed:
                    logger.debug(f"URL blocked - not in allowed domains | url={url} | domain={normalized_domain}")
                    return False

            # PRIORITY 3: Check exclude patterns (glob-style)
            if config.exclude_patterns:
                for pattern in config.exclude_patterns:
                    if fnmatch.fnmatch(full_url, pattern):
                        logger.debug(f"URL blocked by exclude pattern | url={url} | pattern={pattern}")
                        return False

            # PRIORITY 4: Check include patterns if specified
            if config.include_patterns:
                matched = False
                for pattern in config.include_patterns:
                    if fnmatch.fnmatch(full_url, pattern):
                        matched = True
                        break

                if not matched:
                    logger.debug(f"URL blocked - doesn't match include patterns | url={url}")
                    return False

            logger.debug(f"URL allowed | url={url} | domain={normalized_domain}")
            return True

        except Exception as e:
            logger.error(f"Error filtering URL | url={url} | error={str(e)}")
            # On error, be conservative and block the URL
            return False

    def _matches_domain(self, domain: str, pattern: str) -> bool:
        """
        Check if a domain matches a pattern.

        Supports:
        - Exact matches: example.com matches example.com
        - Subdomain wildcards: *.example.com matches sub.example.com
        - Subdomain matching: sub.example.com matches sub.example.com and subsub.sub.example.com

        Args:
            domain: The domain to check (already normalized and lowercase)
            pattern: The pattern to match against (already normalized and lowercase)

        Returns:
            True if the domain matches the pattern
        """
        # Remove any remaining protocol or path from pattern
        pattern = pattern.replace("http://", "").replace("https://", "").split("/")[0]
        pattern = pattern.replace("www.", "")  # Remove www. for consistent matching

        # Exact match
        if domain == pattern:
            return True

        # Wildcard subdomain match (*.example.com)
        if pattern.startswith("*."):
            base_pattern = pattern[2:]  # Remove *.
            # Check if domain ends with the base pattern and has a subdomain
            if domain.endswith(base_pattern):
                # Make sure it's a proper subdomain, not just containing the pattern
                prefix = domain[:-len(base_pattern)]
                if prefix and prefix.endswith("."):
                    return True

        # Subdomain match (allow any subdomain of the pattern)
        # e.g., pattern=example.com should match sub.example.com
        if domain.endswith(f".{pattern}"):
            return True

        return False

    def get_domains_from_urls(self, urls: list[str]) -> set[str]:
        """
        Extract unique domains from a list of URLs.

        Args:
            urls: List of URLs to extract domains from

        Returns:
            Set of unique domains (normalized and lowercase)
        """
        domains = set()
        for url in urls:
            try:
                parsed = urlparse(url)
                if parsed.netloc:
                    domain = parsed.netloc.lower().replace("www.", "")
                    domains.add(domain)
            except Exception as e:
                logger.debug(f"Could not extract domain from URL | url={url} | error={str(e)}")
                continue

        return domains
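A couple of concrete outcomes implied by the matching rules above, showing the difference between a wildcard pattern and a bare-domain pattern. Import paths mirror the test module added later in this commit; the domains are illustrative.

from src.server.models.crawl_models import CrawlConfig
from src.server.services.crawling.domain_filter import DomainFilter

domain_filter = DomainFilter()

# A wildcard pattern requires a real subdomain, so the bare domain is rejected...
cfg_wildcard = CrawlConfig(allowed_domains=["*.example.com"])
assert domain_filter.is_url_allowed("https://example.com/", "https://example.com", cfg_wildcard) is False

# ...while a bare-domain pattern accepts the domain itself and any of its subdomains.
cfg_bare = CrawlConfig(allowed_domains=["example.com"])
assert domain_filter.is_url_allowed("https://example.com/", "https://example.com", cfg_bare) is True
assert domain_filter.is_url_allowed("https://api.example.com/v1", "https://example.com", cfg_bare) is True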
@@ -21,17 +21,19 @@ logger = get_logger(__name__)
class RecursiveCrawlStrategy:
    """Strategy for recursive crawling of websites."""

    def __init__(self, crawler, markdown_generator):
    def __init__(self, crawler, markdown_generator, domain_filter=None):
        """
        Initialize recursive crawl strategy.

        Args:
            crawler (AsyncWebCrawler): The Crawl4AI crawler instance for web crawling operations
            markdown_generator (DefaultMarkdownGenerator): The markdown generator instance for converting HTML to markdown
            domain_filter: Optional DomainFilter instance for URL filtering
        """
        self.crawler = crawler
        self.markdown_generator = markdown_generator
        self.url_handler = URLHandler()
        self.domain_filter = domain_filter

    async def crawl_recursive_with_progress(
        self,
@@ -42,6 +44,7 @@ class RecursiveCrawlStrategy:
        max_concurrent: int | None = None,
        progress_callback: Callable[..., Awaitable[None]] | None = None,
        cancellation_check: Callable[[], None] | None = None,
        crawl_config=None,
    ) -> list[dict[str, Any]]:
        """
        Recursively crawl internal links from start URLs up to a maximum depth with progress reporting.
@@ -291,6 +294,13 @@ class RecursiveCrawlStrategy:
            # Skip binary files and already visited URLs
            is_binary = self.url_handler.is_binary_file(next_url)
            if next_url not in visited and not is_binary:
                # Apply domain filtering if configured
                if self.domain_filter and crawl_config:
                    base_url = start_urls[0] if start_urls else original_url
                    if not self.domain_filter.is_url_allowed(next_url, base_url, crawl_config):
                        logger.debug(f"Filtering URL based on domain rules: {next_url}")
                        continue

                if next_url not in next_level_urls:
                    next_level_urls.add(next_url)
                    total_discovered += 1  # Increment when we discover a new URL
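Conceptually, the per-link check the strategy now performs reduces to a standalone loop like the following sketch. Names such as start_urls and crawl_config mirror the diff above; the URLs and config values are made up.

from src.server.models.crawl_models import CrawlConfig
from src.server.services.crawling.domain_filter import DomainFilter

domain_filter = DomainFilter()
crawl_config = CrawlConfig(excluded_domains=["ads.example.com"])
start_urls = ["https://docs.example.com"]

discovered = ["https://docs.example.com/guide", "https://ads.example.com/banner"]
next_level = []
for next_url in discovered:
    # Links that fail the domain rules are dropped before being queued for the next depth level.
    if not domain_filter.is_url_allowed(next_url, start_urls[0], crawl_config):
        continue
    next_level.append(next_url)

print(next_level)  # ["https://docs.example.com/guide"]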
python/src/server/services/tests/__init__.py (new file, 0 lines)
python/src/server/services/tests/test_domain_filter.py (new file, 204 lines)
@@ -0,0 +1,204 @@
"""
Unit tests for domain filtering functionality
"""

from src.server.models.crawl_models import CrawlConfig
from src.server.services.crawling.domain_filter import DomainFilter


class TestDomainFilter:
    """Test suite for DomainFilter class."""

    def setup_method(self):
        """Set up test fixtures."""
        self.filter = DomainFilter()

    def test_no_config_allows_all(self):
        """Test that no configuration allows all URLs."""
        assert self.filter.is_url_allowed("https://example.com/page", "https://example.com", None) is True
        assert self.filter.is_url_allowed("https://other.com/page", "https://example.com", None) is True

    def test_whitelist_only(self):
        """Test whitelist-only configuration."""
        config = CrawlConfig(
            allowed_domains=["example.com", "docs.example.com"]
        )

        # Should allow whitelisted domains
        assert self.filter.is_url_allowed("https://example.com/page", "https://example.com", config) is True
        assert self.filter.is_url_allowed("https://docs.example.com/api", "https://example.com", config) is True

        # Should block non-whitelisted domains
        assert self.filter.is_url_allowed("https://other.com/page", "https://example.com", config) is False
        assert self.filter.is_url_allowed("https://evil.com", "https://example.com", config) is False

    def test_blacklist_only(self):
        """Test blacklist-only configuration."""
        config = CrawlConfig(
            excluded_domains=["evil.com", "ads.example.com"]
        )

        # Should block blacklisted domains
        assert self.filter.is_url_allowed("https://evil.com/page", "https://example.com", config) is False
        assert self.filter.is_url_allowed("https://ads.example.com/track", "https://example.com", config) is False

        # Should allow non-blacklisted domains
        assert self.filter.is_url_allowed("https://example.com/page", "https://example.com", config) is True
        assert self.filter.is_url_allowed("https://docs.example.com/api", "https://example.com", config) is True

    def test_blacklist_overrides_whitelist(self):
        """Test that blacklist takes priority over whitelist."""
        config = CrawlConfig(
            allowed_domains=["example.com", "blog.example.com"],
            excluded_domains=["blog.example.com"]
        )

        # Blacklist should override whitelist
        assert self.filter.is_url_allowed("https://blog.example.com/post", "https://example.com", config) is False

        # Non-blacklisted whitelisted domain should work
        assert self.filter.is_url_allowed("https://example.com/page", "https://example.com", config) is True

    def test_subdomain_matching(self):
        """Test subdomain matching patterns."""
        config = CrawlConfig(
            allowed_domains=["example.com"]
        )

        # Should match subdomains of allowed domain
        assert self.filter.is_url_allowed("https://docs.example.com/page", "https://example.com", config) is True
        assert self.filter.is_url_allowed("https://api.example.com/v1", "https://example.com", config) is True
        assert self.filter.is_url_allowed("https://sub.sub.example.com", "https://example.com", config) is True

        # Should not match different domains
        assert self.filter.is_url_allowed("https://notexample.com", "https://example.com", config) is False

    def test_wildcard_subdomain_matching(self):
        """Test wildcard subdomain patterns."""
        config = CrawlConfig(
            allowed_domains=["*.example.com"]
        )

        # Should match subdomains
        assert self.filter.is_url_allowed("https://docs.example.com/page", "https://example.com", config) is True
        assert self.filter.is_url_allowed("https://api.example.com/v1", "https://example.com", config) is True

        # Should NOT match the base domain without subdomain
        assert self.filter.is_url_allowed("https://example.com/page", "https://example.com", config) is False

    def test_url_patterns_include(self):
        """Test include URL patterns."""
        config = CrawlConfig(
            include_patterns=["*/api/*", "*/docs/*"]
        )

        # Should match include patterns
        assert self.filter.is_url_allowed("https://example.com/api/v1", "https://example.com", config) is True
        assert self.filter.is_url_allowed("https://example.com/docs/guide", "https://example.com", config) is True

        # Should not match URLs not in patterns
        assert self.filter.is_url_allowed("https://example.com/blog/post", "https://example.com", config) is False
        assert self.filter.is_url_allowed("https://example.com/", "https://example.com", config) is False

    def test_url_patterns_exclude(self):
        """Test exclude URL patterns."""
        config = CrawlConfig(
            exclude_patterns=["*/private/*", "*.pdf", "*/admin/*"]
        )

        # Should block excluded patterns
        assert self.filter.is_url_allowed("https://example.com/private/data", "https://example.com", config) is False
        assert self.filter.is_url_allowed("https://example.com/file.pdf", "https://example.com", config) is False
        assert self.filter.is_url_allowed("https://example.com/admin/panel", "https://example.com", config) is False

        # Should allow non-excluded URLs
        assert self.filter.is_url_allowed("https://example.com/public/page", "https://example.com", config) is True
        assert self.filter.is_url_allowed("https://example.com/file.html", "https://example.com", config) is True

    def test_combined_filters(self):
        """Test combination of all filter types."""
        config = CrawlConfig(
            allowed_domains=["example.com", "docs.example.com"],
            excluded_domains=["ads.example.com"],
            include_patterns=["*/api/*", "*/guide/*"],
            exclude_patterns=["*/deprecated/*"]
        )

        # Should pass all filters
        assert self.filter.is_url_allowed("https://docs.example.com/api/v2", "https://example.com", config) is True
        assert self.filter.is_url_allowed("https://example.com/guide/intro", "https://example.com", config) is True

        # Should fail on blacklist (highest priority)
        assert self.filter.is_url_allowed("https://ads.example.com/api/track", "https://example.com", config) is False

        # Should fail on not in whitelist
        assert self.filter.is_url_allowed("https://other.com/api/v1", "https://example.com", config) is False

        # Should fail on exclude pattern
        assert self.filter.is_url_allowed("https://example.com/api/deprecated/old", "https://example.com", config) is False

        # Should fail on not matching include pattern
        assert self.filter.is_url_allowed("https://example.com/blog/post", "https://example.com", config) is False

    def test_relative_urls(self):
        """Test handling of relative URLs."""
        config = CrawlConfig(
            allowed_domains=["example.com"]
        )

        # Relative URLs should use base URL's domain
        assert self.filter.is_url_allowed("/page/path", "https://example.com", config) is True
        assert self.filter.is_url_allowed("page.html", "https://example.com", config) is True
        assert self.filter.is_url_allowed("../other/page", "https://example.com", config) is True

    def test_domain_normalization(self):
        """Test that domains are properly normalized."""
        config = CrawlConfig(
            allowed_domains=["EXAMPLE.COM", "https://docs.example.com/", "www.test.com"]
        )

        # Should handle different cases and formats
        assert self.filter.is_url_allowed("https://example.com/page", "https://example.com", config) is True
        assert self.filter.is_url_allowed("https://EXAMPLE.COM/PAGE", "https://example.com", config) is True
        assert self.filter.is_url_allowed("https://docs.example.com/api", "https://example.com", config) is True
        assert self.filter.is_url_allowed("https://www.test.com/page", "https://example.com", config) is True
        assert self.filter.is_url_allowed("https://test.com/page", "https://example.com", config) is True

    def test_edge_cases(self):
        """Test edge cases and error handling."""
        config = CrawlConfig(
            allowed_domains=["example.com"]
        )

        # Should handle malformed URLs gracefully
        assert self.filter.is_url_allowed("not-a-url", "https://example.com", config) is True  # Treated as relative
        assert self.filter.is_url_allowed("", "https://example.com", config) is True  # Empty URL
        assert self.filter.is_url_allowed("//example.com/page", "https://example.com", config) is True  # Protocol-relative

    def test_get_domains_from_urls(self):
        """Test extracting domains from URL list."""
        urls = [
            "https://example.com/page1",
            "https://docs.example.com/api",
            "https://example.com/page2",
            "https://other.com/resource",
            "https://WWW.TEST.COM/page",
            "/relative/path",  # Should be skipped
            "invalid-url",  # Should be skipped
        ]

        domains = self.filter.get_domains_from_urls(urls)

        assert domains == {"example.com", "docs.example.com", "other.com", "test.com"}

    def test_empty_filter_lists(self):
        """Test that empty filter lists behave correctly."""
        config = CrawlConfig(
            allowed_domains=[],
            excluded_domains=[],
            include_patterns=[],
            exclude_patterns=[]
        )

        # Empty lists should be ignored (allow all)
        assert self.filter.is_url_allowed("https://any.com/page", "https://example.com", config) is True