""" Threading Service for Archon This service provides comprehensive threading patterns for high-performance AI operations while maintaining WebSocket connection health and system stability. Based on proven patterns from crawl4ai_mcp.py architecture. """ import asyncio import gc import threading import time from collections import deque from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from dataclasses import dataclass, field # Removed direct logging import - using unified config from enum import Enum from typing import Any import psutil from fastapi import WebSocket from ..config.logfire_config import get_logger # Get logger for this module logfire_logger = get_logger("threading") class ProcessingMode(str, Enum): """Processing modes for different workload types""" CPU_INTENSIVE = "cpu_intensive" # AI summaries, embeddings, heavy computation IO_BOUND = "io_bound" # Database operations, file I/O NETWORK_BOUND = "network_bound" # External API calls, web requests WEBSOCKET_SAFE = "websocket_safe" # Operations that need to yield for WebSocket health @dataclass class RateLimitConfig: """Configuration for rate limiting""" tokens_per_minute: int = 200_000 # OpenAI embedding limit requests_per_minute: int = 3000 # Request rate limit max_concurrent: int = 2 # Concurrent request limit backoff_multiplier: float = 1.5 # Exponential backoff multiplier max_backoff: float = 60.0 # Maximum backoff delay in seconds @dataclass class SystemMetrics: """Current system performance metrics""" memory_percent: float cpu_percent: float available_memory_gb: float active_threads: int timestamp: float = field(default_factory=time.time) @dataclass class ThreadingConfig: """Configuration for threading behavior""" base_workers: int = 4 max_workers: int = 16 memory_threshold: float = 0.8 cpu_threshold: float = 0.9 batch_size: int = 15 yield_interval: float = 0.1 # How often to yield for WebSocket health health_check_interval: float = 30 # System health check frequency class RateLimiter: """Thread-safe rate limiter with token bucket algorithm""" def __init__(self, config: RateLimitConfig): self.config = config self.request_times = deque() self.token_usage = deque() self.semaphore = asyncio.Semaphore(config.max_concurrent) self._lock = asyncio.Lock() async def acquire(self, estimated_tokens: int = 8000) -> bool: """Acquire permission to make API call with token awareness""" async with self._lock: now = time.time() # Clean old entries self._clean_old_entries(now) # Check if we can make the request if not self._can_make_request(estimated_tokens): wait_time = self._calculate_wait_time(estimated_tokens) if wait_time > 0: logfire_logger.info( f"Rate limiting: waiting {wait_time:.1f}s", extra={ "tokens": estimated_tokens, "current_usage": self._get_current_usage(), } ) await asyncio.sleep(wait_time) return await self.acquire(estimated_tokens) return False # Record the request self.request_times.append(now) self.token_usage.append((now, estimated_tokens)) return True def _can_make_request(self, estimated_tokens: int) -> bool: """Check if request can be made within limits""" # Check request rate limit if len(self.request_times) >= self.config.requests_per_minute: return False # Check token usage limit current_tokens = sum(tokens for _, tokens in self.token_usage) if current_tokens + estimated_tokens > self.config.tokens_per_minute: return False return True def _clean_old_entries(self, current_time: float): """Remove entries older than 1 minute""" 
class MemoryAdaptiveDispatcher:
    """Dynamically adjusts concurrency based on system load"""

    def __init__(self, config: ThreadingConfig):
        self.config = config
        self.current_workers = config.base_workers
        self.last_metrics: SystemMetrics | None = None

    def get_system_metrics(self) -> SystemMetrics:
        """Get current system performance metrics"""
        memory = psutil.virtual_memory()
        # interval=None is non-blocking: it reports utilization since the
        # previous call instead of sleeping for a sampling window, which
        # matters because this method runs on the event loop
        cpu_percent = psutil.cpu_percent(interval=None)
        active_threads = threading.active_count()

        return SystemMetrics(
            memory_percent=memory.percent,
            cpu_percent=cpu_percent,
            available_memory_gb=memory.available / (1024**3),
            active_threads=active_threads,
        )

    def calculate_optimal_workers(self, mode: ProcessingMode = ProcessingMode.CPU_INTENSIVE) -> int:
        """Calculate the optimal worker count for the current load and mode"""
        metrics = self.get_system_metrics()
        self.last_metrics = metrics

        # Base worker count depends on processing mode
        if mode == ProcessingMode.CPU_INTENSIVE:
            base = min(self.config.base_workers, psutil.cpu_count() or 1)
        elif mode == ProcessingMode.IO_BOUND:
            base = self.config.base_workers * 2
        elif mode == ProcessingMode.NETWORK_BOUND:
            base = self.config.base_workers
        else:  # WEBSOCKET_SAFE
            base = max(1, self.config.base_workers // 2)

        # Adjust based on system load
        if metrics.memory_percent > self.config.memory_threshold * 100:
            # Reduce workers when memory is high
            workers = max(1, base // 2)
            logfire_logger.warning(
                "High memory usage detected, reducing workers",
                extra={
                    "memory_percent": metrics.memory_percent,
                    "workers": workers,
                },
            )
        elif metrics.cpu_percent > self.config.cpu_threshold * 100:
            # Reduce workers when CPU is high
            workers = max(1, base // 2)
            logfire_logger.warning(
                "High CPU usage detected, reducing workers",
                extra={
                    "cpu_percent": metrics.cpu_percent,
                    "workers": workers,
                },
            )
        elif metrics.memory_percent < 50 and metrics.cpu_percent < 50:
            # Increase workers when resources are plentiful
            workers = min(self.config.max_workers, base * 2)
        else:
            # Use the base worker count
            workers = base

        self.current_workers = workers
        return workers

    async def process_with_adaptive_concurrency(
        self,
        items: list[Any],
        process_func: Callable,
        mode: ProcessingMode = ProcessingMode.CPU_INTENSIVE,
        websocket: WebSocket | None = None,
        progress_callback: Callable | None = None,
        enable_worker_tracking: bool = False,
    ) -> list[Any]:
        """Process items with adaptive concurrency control.

        Note: worker slots are always tracked; enable_worker_tracking is
        currently accepted for API compatibility only.
        """
        if not items:
            return []

        optimal_workers = self.calculate_optimal_workers(mode)
        semaphore = asyncio.Semaphore(optimal_workers)

        logfire_logger.info(
            "Starting adaptive processing",
            extra={
                "items_count": len(items),
                "workers": optimal_workers,
                "mode": mode,
                "memory_percent": self.last_metrics.memory_percent,
                "cpu_percent": self.last_metrics.cpu_percent,
            },
        )

        # Track active workers by slot ID so progress callbacks can report
        # which worker is handling which item
        active_workers: dict[int, int] = {}
        completed_count = 0
        lock = asyncio.Lock()

        async def process_single(item: Any, index: int) -> Any:
            nonlocal completed_count

            async with semaphore:
                # Assign a worker slot only after acquiring the semaphore, so
                # a free slot always exists for an item actually in flight
                worker_id = None
                async with lock:
                    for i in range(1, optimal_workers + 1):
                        if i not in active_workers:
                            worker_id = i
                            active_workers[worker_id] = index
                            break

                try:
                    # Report worker started
                    if progress_callback and worker_id:
                        await progress_callback({
                            "type": "worker_started",
                            "worker_id": worker_id,
                            "item_index": index,
                            "total_items": len(items),
                            "message": f"Worker {worker_id} processing item {index + 1}",
                        })

                    # For CPU-intensive work, run in a thread pool
                    if mode == ProcessingMode.CPU_INTENSIVE:
                        loop = asyncio.get_running_loop()
                        result = await loop.run_in_executor(None, process_func, item)
                    # For other modes, run directly (async or sync)
                    elif asyncio.iscoroutinefunction(process_func):
                        result = await process_func(item)
                    else:
                        result = process_func(item)

                    # Update completed count and release the worker slot
                    async with lock:
                        completed_count += 1
                        if worker_id in active_workers:
                            del active_workers[worker_id]

                    # Progress reporting with worker info
                    if progress_callback:
                        await progress_callback({
                            "type": "worker_completed",
                            "worker_id": worker_id,
                            "item_index": index,
                            "completed_count": completed_count,
                            "total_items": len(items),
                            "message": f"Worker {worker_id} completed item {index + 1}",
                        })

                    # WebSocket health: yield periodically
                    if websocket and mode == ProcessingMode.WEBSOCKET_SAFE:
                        if index % 10 == 0:  # Every 10 items
                            await asyncio.sleep(self.config.yield_interval)

                    return result

                except Exception as e:
                    # Release the worker slot on error
                    async with lock:
                        if worker_id and worker_id in active_workers:
                            del active_workers[worker_id]

                    logfire_logger.error(
                        f"Processing failed for item {index}",
                        extra={"error": str(e), "item_index": index},
                    )
                    return None

        # Create tasks for all items; the semaphore bounds real concurrency
        tasks = [process_single(item, idx) for idx, item in enumerate(items)]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Filter out failed results and exceptions
        successful_results = [r for r in results if r is not None and not isinstance(r, Exception)]
        success_rate = len(successful_results) / len(items) * 100

        logfire_logger.info(
            "Adaptive processing completed",
            extra={
                "total_items": len(items),
                "successful": len(successful_results),
                "success_rate": f"{success_rate:.1f}%",
                "workers_used": optimal_workers,
            },
        )

        return successful_results
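
# Illustrative sketch (assumed names): running a blocking `summarize` callable
# over a batch with adaptive concurrency, forwarding worker events to the
# module logger. Any async callable that accepts the event dict works as a
# progress callback.
async def _example_adaptive_batch(
    dispatcher: MemoryAdaptiveDispatcher,
    docs: list[str],
    summarize: Callable,
) -> list[Any]:
    async def log_progress(event: dict[str, Any]) -> None:
        # Log the human-readable message; keep the full event as structured extra
        logfire_logger.debug(event.get("message", ""), extra=event)

    return await dispatcher.process_with_adaptive_concurrency(
        items=docs,
        process_func=summarize,  # offloaded to a thread pool in this mode
        mode=ProcessingMode.CPU_INTENSIVE,
        progress_callback=log_progress,
    )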
class WebSocketSafeProcessor:
    """WebSocket-safe processing with progress updates"""

    def __init__(self, config: ThreadingConfig):
        self.config = config
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Connect a WebSocket client"""
        await websocket.accept()
        self.active_connections.append(websocket)
        logfire_logger.info(
            "WebSocket client connected",
            extra={"total_connections": len(self.active_connections)},
        )

    def disconnect(self, websocket: WebSocket):
        """Disconnect a WebSocket client"""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            logfire_logger.info(
                "WebSocket client disconnected",
                extra={"remaining_connections": len(self.active_connections)},
            )

    async def broadcast_progress(self, message: dict[str, Any]):
        """Broadcast progress to all connected clients"""
        if not self.active_connections:
            return

        # Send to all clients concurrently. Send failures surface in the
        # gathered results, not at coroutine creation, so inspect the results
        # to prune dead connections.
        connections = self.active_connections.copy()
        results = await asyncio.gather(
            *(connection.send_json(message) for connection in connections),
            return_exceptions=True,
        )
        for connection, result in zip(connections, results):
            if isinstance(result, Exception):
                self.disconnect(connection)

    async def process_with_progress(
        self,
        items: list[Any],
        process_func: Callable,
        operation_name: str = "processing",
        batch_size: int | None = None,
    ) -> list[Any]:
        """Process items sequentially with WebSocket progress updates"""
        if not items:
            return []

        batch_size = batch_size or self.config.batch_size
        total_items = len(items)
        results = []

        for batch_start in range(0, total_items, batch_size):
            batch_end = min(batch_start + batch_size, total_items)
            batch = items[batch_start:batch_end]

            # Process batch
            for i, item in enumerate(batch):
                if asyncio.iscoroutinefunction(process_func):
                    result = await process_func(item)
                else:
                    # Run in a thread pool for CPU-intensive work
                    loop = asyncio.get_running_loop()
                    result = await loop.run_in_executor(None, process_func, item)

                results.append(result)

                # Calculate progress
                items_processed = batch_start + i + 1
                progress = (items_processed / total_items) * 100

                # Broadcast progress
                await self.broadcast_progress({
                    "type": "progress",
                    "operation": operation_name,
                    "progress": progress,
                    "processed": items_processed,
                    "total": total_items,
                    "batch": f"Batch {batch_start // batch_size + 1}",
                    "current_item": str(getattr(item, "id", i)),
                })

                # Yield control for WebSocket health
                await asyncio.sleep(self.config.yield_interval)

        # Final completion message
        await self.broadcast_progress({
            "type": "complete",
            "operation": operation_name,
            "total_processed": len(results),
            "success_rate": f"{len(results) / total_items * 100:.1f}%",
        })

        return results
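
# Illustrative sketch of wiring the processor into a FastAPI WebSocket route,
# kept as a comment so this module stays import-safe. The `app` instance and
# the "/ws/progress" path are assumptions for the example; any route that
# calls connect()/disconnect() works the same way.
#
# from fastapi import FastAPI, WebSocketDisconnect
#
# app = FastAPI()
# processor = WebSocketSafeProcessor(ThreadingConfig())
#
# @app.websocket("/ws/progress")
# async def progress_endpoint(websocket: WebSocket):
#     await processor.connect(websocket)
#     try:
#         while True:
#             await websocket.receive_text()  # keep the connection alive
#     except WebSocketDisconnect:
#         processor.disconnect(websocket)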
class ThreadingService:
    """Main threading service that coordinates all threading operations"""

    def __init__(
        self,
        threading_config: ThreadingConfig | None = None,
        rate_limit_config: RateLimitConfig | None = None,
    ):
        self.config = threading_config or ThreadingConfig()
        self.rate_limiter = RateLimiter(rate_limit_config or RateLimitConfig())
        self.memory_dispatcher = MemoryAdaptiveDispatcher(self.config)
        self.websocket_processor = WebSocketSafeProcessor(self.config)

        # Thread pools for different workload types
        self.cpu_executor = ThreadPoolExecutor(
            max_workers=self.config.max_workers, thread_name_prefix="archon-cpu"
        )
        self.io_executor = ThreadPoolExecutor(
            max_workers=self.config.max_workers * 2, thread_name_prefix="archon-io"
        )

        self._running = False
        self._health_check_task: asyncio.Task | None = None

    async def start(self):
        """Start the threading service"""
        if self._running:
            return

        self._running = True
        self._health_check_task = asyncio.create_task(self._health_check_loop())

        logfire_logger.info("Threading service started", extra={"config": self.config.__dict__})

    async def stop(self):
        """Stop the threading service"""
        if not self._running:
            return

        self._running = False

        if self._health_check_task:
            self._health_check_task.cancel()
            try:
                await self._health_check_task
            except asyncio.CancelledError:
                pass

        # Shut down thread pools
        self.cpu_executor.shutdown(wait=True)
        self.io_executor.shutdown(wait=True)

        logfire_logger.info("Threading service stopped")

    @asynccontextmanager
    async def rate_limited_operation(self, estimated_tokens: int = 8000):
        """Context manager for rate-limited operations"""
        async with self.rate_limiter.semaphore:
            can_proceed = await self.rate_limiter.acquire(estimated_tokens)
            if not can_proceed:
                raise Exception("Rate limit exceeded")

            start_time = time.time()
            try:
                yield
            finally:
                duration = time.time() - start_time
                logfire_logger.debug(
                    "Rate limited operation completed",
                    extra={"duration": duration, "tokens": estimated_tokens},
                )

    async def run_cpu_intensive(self, func: Callable, *args, **kwargs) -> Any:
        """Run a CPU-intensive function in the CPU thread pool"""
        loop = asyncio.get_running_loop()
        # run_in_executor does not forward keyword arguments, so bind them
        # with functools.partial first
        return await loop.run_in_executor(
            self.cpu_executor, functools.partial(func, *args, **kwargs)
        )

    async def run_io_bound(self, func: Callable, *args, **kwargs) -> Any:
        """Run an I/O-bound function in the I/O thread pool"""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.io_executor, functools.partial(func, *args, **kwargs)
        )

    async def batch_process(
        self,
        items: list[Any],
        process_func: Callable,
        mode: ProcessingMode = ProcessingMode.CPU_INTENSIVE,
        websocket: WebSocket | None = None,
        progress_callback: Callable | None = None,
        enable_worker_tracking: bool = False,
    ) -> list[Any]:
        """Process items in batches with optimal threading"""
        return await self.memory_dispatcher.process_with_adaptive_concurrency(
            items=items,
            process_func=process_func,
            mode=mode,
            websocket=websocket,
            progress_callback=progress_callback,
            enable_worker_tracking=enable_worker_tracking,
        )

    async def websocket_safe_process(
        self, items: list[Any], process_func: Callable, operation_name: str = "processing"
    ) -> list[Any]:
        """Process items with WebSocket safety guarantees"""
        return await self.websocket_processor.process_with_progress(
            items=items, process_func=process_func, operation_name=operation_name
        )

    def get_system_metrics(self) -> SystemMetrics:
        """Get current system performance metrics"""
        return self.memory_dispatcher.get_system_metrics()

    async def _health_check_loop(self):
        """Monitor system health and adjust threading parameters"""
        while self._running:
            try:
                metrics = self.get_system_metrics()

                # Log system metrics
                logfire_logger.info(
                    "System health check",
                    extra={
                        "memory_percent": metrics.memory_percent,
                        "cpu_percent": metrics.cpu_percent,
                        "available_memory_gb": metrics.available_memory_gb,
                        "active_threads": metrics.active_threads,
                        "active_websockets": len(self.websocket_processor.active_connections),
                    },
                )

                # Alert on critical thresholds
                if metrics.memory_percent > 90:
                    logfire_logger.warning(
                        "Critical memory usage", extra={"memory_percent": metrics.memory_percent}
                    )
                    # Force garbage collection
                    gc.collect()

                if metrics.cpu_percent > 95:
                    logfire_logger.warning(
                        "Critical CPU usage", extra={"cpu_percent": metrics.cpu_percent}
                    )

                # Check for runaway thread growth (possible leak)
                if metrics.active_threads > self.config.max_workers * 3:
                    logfire_logger.warning(
                        "High thread count detected",
                        extra={
                            "active_threads": metrics.active_threads,
                            "max_expected": self.config.max_workers * 3,
                        },
                    )

                await asyncio.sleep(self.config.health_check_interval)

            except Exception as e:
                logfire_logger.error("Health check failed", extra={"error": str(e)})
                await asyncio.sleep(self.config.health_check_interval)
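
# Illustrative sketch: combining the rate-limit context manager with the CPU
# executor for a single summarization call. `generate_summary` is a
# hypothetical blocking callable supplied by the caller, and the token
# estimate is an assumption for the example.
async def _example_summarize(
    service: ThreadingService,
    text: str,
    generate_summary: Callable,
) -> Any:
    async with service.rate_limited_operation(estimated_tokens=4000):
        # Offload the blocking call so the event loop stays responsive
        return await service.run_cpu_intensive(generate_summary, text)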
# Global threading service instance
_threading_service: ThreadingService | None = None


def get_threading_service() -> ThreadingService:
    """Get the global threading service instance"""
    global _threading_service
    if _threading_service is None:
        _threading_service = ThreadingService()
    return _threading_service


async def start_threading_service() -> ThreadingService:
    """Start the global threading service"""
    service = get_threading_service()
    await service.start()
    return service


async def stop_threading_service():
    """Stop the global threading service"""
    global _threading_service
    if _threading_service:
        await _threading_service.stop()
        _threading_service = None
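
# Illustrative sketch of tying the service lifecycle to a FastAPI app via a
# lifespan handler, kept as a comment so this module stays import-safe. The
# `app` object is an assumption for the example.
#
# from fastapi import FastAPI
#
# @asynccontextmanager
# async def lifespan(app: FastAPI):
#     await start_threading_service()
#     yield
#     await stop_threading_service()
#
# app = FastAPI(lifespan=lifespan)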