mirror of
https://github.com/coleam00/Archon.git
synced 2026-01-01 20:28:43 -05:00
644 lines
23 KiB
Python
644 lines
23 KiB
Python
"""
Threading Service for Archon

This service provides comprehensive threading patterns for high-performance AI operations
while maintaining WebSocket connection health and system stability.

Based on proven patterns from crawl4ai_mcp.py architecture.
"""
|
|
|
|
import asyncio
|
|
import gc
|
|
import threading
|
|
import time
|
|
from collections import deque
|
|
from collections.abc import Callable
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass, field
|
|
|
|
# Removed direct logging import - using unified config
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
import psutil
|
|
from fastapi import WebSocket
|
|
|
|
from ..config.logfire_config import get_logger
|
|
|
|
# Module-wide logger; every class and function in this file logs through it
logfire_logger = get_logger("threading")
|
|
|
|
|
|
class ProcessingMode(str, Enum):
    """Workload categories used to choose a concurrency strategy.

    Members are also plain strings (the class mixes in ``str``), so they
    compare equal to their literal values.
    """

    # AI summaries, embeddings, heavy computation
    CPU_INTENSIVE = "cpu_intensive"

    # Database operations, file I/O
    IO_BOUND = "io_bound"

    # External API calls, web requests
    NETWORK_BOUND = "network_bound"

    # Operations that need to yield for WebSocket health
    WEBSOCKET_SAFE = "websocket_safe"
|
|
|
|
|
|
@dataclass
class RateLimitConfig:
    """Tunable limits for outbound API rate limiting."""

    # Token budget per minute (OpenAI embedding limit)
    tokens_per_minute: int = 200_000

    # Request budget per minute
    requests_per_minute: int = 3000

    # Maximum number of requests allowed in flight at once
    max_concurrent: int = 2

    # Exponential backoff multiplier
    backoff_multiplier: float = 1.5

    # Upper bound on a single backoff delay, in seconds
    max_backoff: float = 60.0
|
|
|
|
|
|
@dataclass
class SystemMetrics:
    """Point-in-time snapshot of host resource usage."""

    # Memory in use, as a percentage
    memory_percent: float

    # CPU utilisation, as a percentage
    cpu_percent: float

    # Memory still available, in GiB
    available_memory_gb: float

    # Number of live threads when the snapshot was taken
    active_threads: int

    # Creation time of the snapshot (time.time() unless supplied explicitly)
    timestamp: float = field(default_factory=time.time)
|
|
|
|
|
|
@dataclass
class ThreadingConfig:
    """Knobs controlling worker counts, load thresholds and pacing."""

    # Starting point for worker-count calculations
    base_workers: int = 4

    # Hard ceiling on workers / thread-pool size
    max_workers: int = 16

    # Fraction of memory in use (0-1) above which workers are reduced
    memory_threshold: float = 0.8

    # Fraction of CPU in use (0-1) above which workers are reduced
    cpu_threshold: float = 0.9

    # Default number of items per batch
    batch_size: int = 15

    # How often to yield for WebSocket health
    yield_interval: float = 0.1

    # System health check frequency
    health_check_interval: float = 30
|
|
|
|
|
|
class RateLimiter:
    """Sliding-window rate limiter for coroutines on a single event loop.

    Each granted request is recorded with its timestamp and estimated token
    cost. Records older than 60 seconds are dropped, so both the request
    limit and the token limit apply to a rolling one-minute window.

    Bookkeeping is guarded by an ``asyncio.Lock``, which serialises
    coroutines; it offers no protection against concurrent use from
    multiple OS threads.
    """

    def __init__(self, config: RateLimitConfig):
        """Create an empty limiter governed by ``config``."""
        self.config = config
        # Timestamps of granted requests, oldest first
        self.request_times: deque[float] = deque()
        # (timestamp, estimated_tokens) per granted request, oldest first
        self.token_usage: deque[tuple[float, int]] = deque()
        # Concurrency cap for callers to hold around a request; acquire()
        # itself never takes it
        self.semaphore = asyncio.Semaphore(config.max_concurrent)
        self._lock = asyncio.Lock()

    async def acquire(self, estimated_tokens: int = 8000) -> bool:
        """Wait until a request of ``estimated_tokens`` fits in the window.

        Args:
            estimated_tokens: Token cost to reserve for the request.

        Returns:
            True once the request has been recorded. False when the request
            does not fit and there is nothing to wait for, i.e. the window
            is empty yet the limits still reject it.
        """
        while True:
            async with self._lock:
                now = time.time()

                # Clean old entries
                self._clean_old_entries(now)

                if self._can_make_request(estimated_tokens):
                    # Record the request
                    self.request_times.append(now)
                    self.token_usage.append((now, estimated_tokens))
                    return True

                wait_time = self._calculate_wait_time(estimated_tokens)
                if wait_time <= 0:
                    return False

                logfire_logger.info(
                    f"Rate limiting: waiting {wait_time:.1f}s",
                    extra={
                        "tokens": estimated_tokens,
                        "current_usage": self._get_current_usage(),
                    }
                )

            # Sleep and retry only after the lock is released. asyncio.Lock
            # is not reentrant, so re-entering acquire() with the lock held
            # would wait forever and stall every other caller.
            await asyncio.sleep(wait_time)

    def _can_make_request(self, estimated_tokens: int) -> bool:
        """Return True when neither the request nor the token limit would be exceeded."""
        # Check request rate limit
        if len(self.request_times) >= self.config.requests_per_minute:
            return False

        # Check token usage limit
        current_tokens = sum(tokens for _, tokens in self.token_usage)
        return current_tokens + estimated_tokens <= self.config.tokens_per_minute

    def _clean_old_entries(self, current_time: float):
        """Drop records more than 60 seconds older than ``current_time``."""
        cutoff_time = current_time - 60

        while self.request_times and self.request_times[0] < cutoff_time:
            self.request_times.popleft()

        while self.token_usage and self.token_usage[0][0] < cutoff_time:
            self.token_usage.popleft()

    def _calculate_wait_time(self, estimated_tokens: int) -> float:
        """Return the seconds until the oldest recorded request leaves the window.

        ``estimated_tokens`` is accepted but not used in the calculation.
        Returns 0 when no requests are recorded or the oldest one is already
        at least 60 seconds old.
        """
        if not self.request_times:
            return 0

        time_since_oldest = time.time() - self.request_times[0]
        if time_since_oldest < 60:
            # The extra 0.1s ensures the entry has expired when we wake up
            return 60 - time_since_oldest + 0.1

        return 0

    def _get_current_usage(self) -> dict[str, int]:
        """Return current request/token counts alongside the configured maxima."""
        current_tokens = sum(tokens for _, tokens in self.token_usage)
        return {
            "requests": len(self.request_times),
            "tokens": current_tokens,
            "max_requests": self.config.requests_per_minute,
            "max_tokens": self.config.tokens_per_minute,
        }
|
|
|
|
|
|
class MemoryAdaptiveDispatcher:
    """Pick a worker count from live memory/CPU readings and fan work out under it."""

    def __init__(self, config: ThreadingConfig):
        """Start at ``config.base_workers`` with no metrics sampled yet."""
        self.config = config
        self.current_workers = config.base_workers
        # Most recent snapshot taken by calculate_optimal_workers()
        self.last_metrics: SystemMetrics | None = None

    def get_system_metrics(self) -> SystemMetrics:
        """Sample memory, CPU and thread count.

        Blocks the calling thread for about one second, because
        ``psutil.cpu_percent`` is called with ``interval=1``.
        """
        memory = psutil.virtual_memory()
        cpu_percent = psutil.cpu_percent(interval=1)
        active_threads = threading.active_count()

        return SystemMetrics(
            memory_percent=memory.percent,
            cpu_percent=cpu_percent,
            available_memory_gb=memory.available / (1024**3),
            active_threads=active_threads,
        )

    def calculate_optimal_workers(self, mode: ProcessingMode = ProcessingMode.CPU_INTENSIVE) -> int:
        """Calculate the worker count for ``mode`` under the current system load.

        Takes a fresh (blocking) metrics sample, stores it in
        ``last_metrics`` and stores the result in ``current_workers``.

        Returns:
            A worker count of at least 1.
        """
        metrics = self.get_system_metrics()
        self.last_metrics = metrics
        base_workers = self.config.base_workers

        # Base worker count depends on processing mode
        if mode == ProcessingMode.CPU_INTENSIVE:
            # psutil.cpu_count() returns None when the count cannot be
            # determined; fall back to the configured base in that case
            cpu_count = psutil.cpu_count()
            base = base_workers if cpu_count is None else min(base_workers, cpu_count)
        elif mode == ProcessingMode.IO_BOUND:
            base = base_workers * 2
        elif mode == ProcessingMode.NETWORK_BOUND:
            base = base_workers
        else:  # WEBSOCKET_SAFE
            base = max(1, base_workers // 2)

        # Adjust based on system load (thresholds are fractions, metrics are percentages)
        if metrics.memory_percent > self.config.memory_threshold * 100:
            # Reduce workers when memory is high
            workers = max(1, base // 2)
            logfire_logger.warning(
                "High memory usage detected, reducing workers",
                extra={
                    "memory_percent": metrics.memory_percent,
                    "workers": workers,
                }
            )
        elif metrics.cpu_percent > self.config.cpu_threshold * 100:
            # Reduce workers when CPU is high
            workers = max(1, base // 2)
            logfire_logger.warning(
                "High CPU usage detected, reducing workers",
                extra={
                    "cpu_percent": metrics.cpu_percent,
                    "workers": workers,
                }
            )
        elif metrics.memory_percent < 50 and metrics.cpu_percent < 50:
            # Increase workers when resources are available
            workers = min(self.config.max_workers, base * 2)
        else:
            # Use base worker count
            workers = base

        # A count of 0 would create a semaphore that never admits a task
        workers = max(1, workers)

        self.current_workers = workers
        return workers

    async def process_with_adaptive_concurrency(
        self,
        items: list[Any],
        process_func: Callable,
        mode: ProcessingMode = ProcessingMode.CPU_INTENSIVE,
        websocket: WebSocket | None = None,
        progress_callback: Callable | None = None,
        enable_worker_tracking: bool = False,
    ) -> list[Any]:
        """Process ``items`` concurrently under a load-derived worker limit.

        Args:
            items: Inputs, each passed individually to ``process_func``.
            process_func: Per-item function. In CPU_INTENSIVE mode it runs
                in the loop's default executor. In other modes it is awaited
                if it is a coroutine function, otherwise called inline.
            mode: Workload type, used for worker sizing and dispatch.
            websocket: When truthy and mode is WEBSOCKET_SAFE, the task
                sleeps ``yield_interval`` after every item whose index is a
                multiple of 10.
            progress_callback: Optional awaitable callback receiving
                ``worker_started`` / ``worker_completed`` dicts.
            enable_worker_tracking: Accepted for interface compatibility;
                this method does not read it.

        Returns:
            Results of the items that succeeded. Items that raised, and
            items whose result is None, are omitted, so positions do not
            necessarily line up with ``items``.
        """
        if not items:
            return []

        # Sampling metrics blocks for ~1s inside psutil; do it in a worker
        # thread so the event loop keeps running
        optimal_workers = await asyncio.to_thread(self.calculate_optimal_workers, mode)
        semaphore = asyncio.Semaphore(optimal_workers)
        total_items = len(items)

        logfire_logger.info(
            "Starting adaptive processing",
            extra={
                "items_count": total_items,
                "workers": optimal_workers,
                "mode": mode,
                "memory_percent": self.last_metrics.memory_percent,
                "cpu_percent": self.last_metrics.cpu_percent,
            }
        )

        # worker_id -> index of the item that worker is handling
        active_workers: dict[int, int] = {}
        completed_count = 0
        lock = asyncio.Lock()

        async def process_single(item: Any, index: int) -> Any:
            """Run one item in a worker slot; return None if it fails."""
            nonlocal completed_count

            worker_id = None
            async with semaphore:
                try:
                    # Claim an ID only once a slot is held. Claiming before
                    # the semaphore would exhaust the IDs on the first
                    # `optimal_workers` tasks and leave the rest without one.
                    async with lock:
                        worker_id = next(
                            (i for i in range(1, optimal_workers + 1) if i not in active_workers),
                            None,
                        )
                        if worker_id is not None:
                            active_workers[worker_id] = index

                    # Report worker started
                    if progress_callback and worker_id:
                        await progress_callback({
                            "type": "worker_started",
                            "worker_id": worker_id,
                            "item_index": index,
                            "total_items": total_items,
                            "message": f"Worker {worker_id} processing item {index + 1}",
                        })

                    if mode == ProcessingMode.CPU_INTENSIVE:
                        # For CPU-intensive work, run in thread pool
                        loop = asyncio.get_running_loop()
                        result = await loop.run_in_executor(None, process_func, item)
                    elif asyncio.iscoroutinefunction(process_func):
                        result = await process_func(item)
                    else:
                        # Synchronous function in a non-CPU mode runs inline
                        result = process_func(item)

                    # Update completed count and free the worker ID
                    async with lock:
                        completed_count += 1
                        active_workers.pop(worker_id, None)

                    # Progress reporting with worker info
                    if progress_callback:
                        await progress_callback({
                            "type": "worker_completed",
                            "worker_id": worker_id,
                            "item_index": index,
                            "completed_count": completed_count,
                            "total_items": total_items,
                            "message": f"Worker {worker_id} completed item {index + 1}",
                        })

                    # WebSocket health check: pause every 10 items
                    if websocket and mode == ProcessingMode.WEBSOCKET_SAFE and index % 10 == 0:
                        await asyncio.sleep(self.config.yield_interval)

                    return result

                except Exception as e:
                    # Clean up worker on error
                    async with lock:
                        active_workers.pop(worker_id, None)

                    logfire_logger.error(
                        f"Processing failed for item {index}",
                        extra={"error": str(e), "item_index": index}
                    )
                    return None

        # Create one coroutine per item; the semaphore bounds how many run at once
        tasks = [process_single(item, idx) for idx, item in enumerate(items)]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Filter out failed results and exceptions
        successful_results = [r for r in results if r is not None and not isinstance(r, Exception)]

        success_rate = len(successful_results) / total_items * 100
        logfire_logger.info(
            "Adaptive processing completed",
            extra={
                "total_items": total_items,
                "successful": len(successful_results),
                "success_rate": f"{success_rate:.1f}%",
                "workers_used": optimal_workers,
            }
        )

        return successful_results
|
|
|
|
|
|
class WebSocketSafeProcessor:
    """Sequential item processing that broadcasts progress to WebSocket clients."""

    def __init__(self, config: ThreadingConfig):
        """Start with no connected clients."""
        self.config = config
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept ``websocket`` and register it for progress broadcasts."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logfire_logger.info(
            "WebSocket client connected",
            extra={"total_connections": len(self.active_connections)}
        )

    def disconnect(self, websocket: WebSocket):
        """Unregister ``websocket``; does nothing if it is not registered."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            logfire_logger.info(
                "WebSocket client disconnected",
                extra={"remaining_connections": len(self.active_connections)}
            )

    async def broadcast_progress(self, message: dict[str, Any]):
        """Send ``message`` as JSON to every registered client concurrently.

        Clients whose send fails are unregistered. Send errors are not
        propagated to the caller.
        """
        if not self.active_connections:
            return

        # Snapshot the list: disconnect() below mutates the live one
        connections = list(self.active_connections)
        outcomes = await asyncio.gather(
            *(connection.send_json(message) for connection in connections),
            return_exceptions=True,
        )

        # send_json() only fails once awaited, so failures appear here as
        # gathered results rather than as exceptions at call time
        for connection, outcome in zip(connections, outcomes):
            if isinstance(outcome, BaseException):
                self.disconnect(connection)

    async def process_with_progress(
        self,
        items: list[Any],
        process_func: Callable,
        operation_name: str = "processing",
        batch_size: int | None = None,
    ) -> list[Any]:
        """Process ``items`` one at a time, broadcasting progress after each.

        Args:
            items: Inputs, each passed individually to ``process_func``.
            process_func: Awaited if it is a coroutine function, otherwise
                run in the loop's default executor.
            operation_name: Label included in every broadcast message.
            batch_size: Items per batch. A falsy value (None or 0) selects
                ``config.batch_size``.

        Returns:
            One result per item, in input order. An exception raised by
            ``process_func`` propagates to the caller.
        """
        if not items:
            return []

        batch_size = batch_size or self.config.batch_size
        total_items = len(items)
        is_coroutine = asyncio.iscoroutinefunction(process_func)
        results = []

        for batch_start in range(0, total_items, batch_size):
            batch = items[batch_start:batch_start + batch_size]

            # Process batch
            for i, item in enumerate(batch):
                if is_coroutine:
                    result = await process_func(item)
                else:
                    # Run in thread pool for CPU-intensive work
                    loop = asyncio.get_running_loop()
                    result = await loop.run_in_executor(None, process_func, item)

                results.append(result)

                # Calculate progress
                items_processed = batch_start + i + 1
                progress = (items_processed / total_items) * 100

                # Broadcast progress
                await self.broadcast_progress({
                    "type": "progress",
                    "operation": operation_name,
                    "progress": progress,
                    "processed": items_processed,
                    "total": total_items,
                    "batch": f"Batch {batch_start // batch_size + 1}",
                    # Falls back to the item's position within its batch
                    "current_item": str(getattr(item, "id", i)),
                })

            # Yield control for WebSocket health
            await asyncio.sleep(self.config.yield_interval)

        # Final completion message
        await self.broadcast_progress({
            "type": "complete",
            "operation": operation_name,
            "total_processed": len(results),
            "success_rate": f"{len(results) / total_items * 100:.1f}%",
        })

        return results
|
|
|
|
|
|
class ThreadingService:
    """Main threading service that coordinates all threading operations.

    Owns a rate limiter, an adaptive dispatcher, a WebSocket-aware
    processor and two thread pools (CPU and I/O), plus a background
    health-check task that runs between ``start()`` and ``stop()``.
    """

    def __init__(
        self,
        threading_config: ThreadingConfig | None = None,
        rate_limit_config: RateLimitConfig | None = None,
    ):
        """Build the service; default configs are used for omitted arguments."""
        self.config = threading_config or ThreadingConfig()
        self.rate_limiter = RateLimiter(rate_limit_config or RateLimitConfig())
        self.memory_dispatcher = MemoryAdaptiveDispatcher(self.config)
        self.websocket_processor = WebSocketSafeProcessor(self.config)

        # Thread pools for different workload types
        self.cpu_executor = ThreadPoolExecutor(
            max_workers=self.config.max_workers, thread_name_prefix="archon-cpu"
        )
        self.io_executor = ThreadPoolExecutor(
            max_workers=self.config.max_workers * 2, thread_name_prefix="archon-io"
        )

        self._running = False
        self._health_check_task = None

    async def start(self):
        """Start the health-check task; a no-op if already running."""
        if self._running:
            return

        self._running = True
        self._health_check_task = asyncio.create_task(self._health_check_loop())
        logfire_logger.info("Threading service started", extra={"config": self.config.__dict__})

    async def stop(self):
        """Cancel the health-check task and shut down both thread pools.

        A no-op if the service is not running. ``shutdown(wait=True)``
        waits for work already submitted to the pools to finish.
        """
        if not self._running:
            return

        self._running = False

        if self._health_check_task:
            self._health_check_task.cancel()
            try:
                await self._health_check_task
            except asyncio.CancelledError:
                pass

        # Shutdown thread pools
        self.cpu_executor.shutdown(wait=True)
        self.io_executor.shutdown(wait=True)

        logfire_logger.info("Threading service stopped")

    @asynccontextmanager
    async def rate_limited_operation(self, estimated_tokens: int = 8000):
        """Context manager that runs its body under the rate limiter.

        Holds the limiter's concurrency semaphore for the whole body and
        reserves ``estimated_tokens`` before entering it.

        Raises:
            Exception: "Rate limit exceeded" when the limiter refuses the
                request.
        """
        async with self.rate_limiter.semaphore:
            can_proceed = await self.rate_limiter.acquire(estimated_tokens)
            if not can_proceed:
                raise Exception("Rate limit exceeded")

            start_time = time.time()
            try:
                yield
            finally:
                duration = time.time() - start_time
                logfire_logger.debug(
                    "Rate limited operation completed",
                    extra={"duration": duration, "tokens": estimated_tokens},
                )

    async def run_cpu_intensive(self, func: Callable, *args, **kwargs) -> Any:
        """Run ``func(*args, **kwargs)`` in the CPU thread pool and return its result."""
        loop = asyncio.get_running_loop()
        # run_in_executor() forwards positional arguments only, so bind all
        # arguments in a closure to keep keyword arguments working
        return await loop.run_in_executor(self.cpu_executor, lambda: func(*args, **kwargs))

    async def run_io_bound(self, func: Callable, *args, **kwargs) -> Any:
        """Run ``func(*args, **kwargs)`` in the I/O thread pool and return its result."""
        loop = asyncio.get_running_loop()
        # Same closure as run_cpu_intensive: run_in_executor() rejects kwargs
        return await loop.run_in_executor(self.io_executor, lambda: func(*args, **kwargs))

    async def batch_process(
        self,
        items: list[Any],
        process_func: Callable,
        mode: ProcessingMode = ProcessingMode.CPU_INTENSIVE,
        websocket: WebSocket | None = None,
        progress_callback: Callable | None = None,
        enable_worker_tracking: bool = False,
    ) -> list[Any]:
        """Delegate to the dispatcher's ``process_with_adaptive_concurrency``.

        All arguments are forwarded unchanged and the dispatcher's result
        is returned as-is.
        """
        return await self.memory_dispatcher.process_with_adaptive_concurrency(
            items=items,
            process_func=process_func,
            mode=mode,
            websocket=websocket,
            progress_callback=progress_callback,
            enable_worker_tracking=enable_worker_tracking,
        )

    async def websocket_safe_process(
        self, items: list[Any], process_func: Callable, operation_name: str = "processing"
    ) -> list[Any]:
        """Delegate to the WebSocket processor's ``process_with_progress``."""
        return await self.websocket_processor.process_with_progress(
            items=items, process_func=process_func, operation_name=operation_name
        )

    def get_system_metrics(self) -> SystemMetrics:
        """Return a fresh metrics snapshot from the memory dispatcher."""
        return self.memory_dispatcher.get_system_metrics()

    async def _health_check_loop(self):
        """Log system metrics every ``health_check_interval`` seconds while running.

        Warns on critical memory (>90%), critical CPU (>95%) and on a
        thread count above three times ``max_workers``. A failed check is
        logged and the loop continues.
        """
        while self._running:
            try:
                # Sampling goes through the dispatcher, whose
                # psutil.cpu_percent(interval=1) call blocks for about a
                # second; run it in a worker thread so the event loop stays
                # responsive
                metrics = await asyncio.to_thread(self.get_system_metrics)

                # Log system metrics
                logfire_logger.info(
                    "System health check",
                    extra={
                        "memory_percent": metrics.memory_percent,
                        "cpu_percent": metrics.cpu_percent,
                        "available_memory_gb": metrics.available_memory_gb,
                        "active_threads": metrics.active_threads,
                        "active_websockets": len(self.websocket_processor.active_connections),
                    }
                )

                # Alert on critical thresholds
                if metrics.memory_percent > 90:
                    logfire_logger.warning(
                        "Critical memory usage",
                        extra={"memory_percent": metrics.memory_percent}
                    )
                    # Force garbage collection
                    gc.collect()

                if metrics.cpu_percent > 95:
                    logfire_logger.warning(
                        "Critical CPU usage", extra={"cpu_percent": metrics.cpu_percent}
                    )

                # Check for a possible thread leak (too many threads)
                max_expected = self.config.max_workers * 3
                if metrics.active_threads > max_expected:
                    logfire_logger.warning(
                        "High thread count detected",
                        extra={
                            "active_threads": metrics.active_threads,
                            "max_expected": max_expected,
                        }
                    )

            except Exception as e:
                logfire_logger.error("Health check failed", extra={"error": str(e)})

            # Sleep after both a successful and a failed check
            await asyncio.sleep(self.config.health_check_interval)
|
|
|
|
|
|
# Shared ThreadingService; stays None until first requested
_threading_service: ThreadingService | None = None


def get_threading_service() -> ThreadingService:
    """Return the shared ThreadingService, creating it on the first call."""
    global _threading_service
    if _threading_service is not None:
        return _threading_service

    _threading_service = ThreadingService()
    return _threading_service
|
|
|
|
|
|
async def start_threading_service() -> ThreadingService:
    """Start the shared ThreadingService and return it."""
    shared = get_threading_service()
    await shared.start()
    return shared
|
|
|
|
|
|
async def stop_threading_service():
    """Stop the shared ThreadingService, if one exists, and clear the global reference."""
    global _threading_service
    if not _threading_service:
        # Nothing was ever created (or it was already stopped)
        return

    await _threading_service.stop()
    _threading_service = None