The New Archon (Beta) - The Operating System for AI Coding Assistants!

Cole Medin
2025-08-13 07:58:24 -05:00
parent 13e1fc6a0e
commit 59084036f6
603 changed files with 131376 additions and 417 deletions

python/tests/__init__.py Normal file

@@ -0,0 +1 @@
"""Simplified test suite for Archon - Essential tests only."""

python/tests/conftest.py Normal file

@@ -0,0 +1,124 @@
"""Simple test configuration for Archon - Essential tests only."""
import os
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
# Set test environment
os.environ["TEST_MODE"] = "true"
os.environ["TESTING"] = "true"
# Set fake database credentials to prevent connection attempts
os.environ["SUPABASE_URL"] = "https://test.supabase.co"
os.environ["SUPABASE_SERVICE_KEY"] = "test-key"
# Set required port environment variables for ServiceDiscovery
os.environ.setdefault("ARCHON_SERVER_PORT", "8181")
os.environ.setdefault("ARCHON_MCP_PORT", "8051")
os.environ.setdefault("ARCHON_AGENTS_PORT", "8052")
@pytest.fixture(autouse=True)
def prevent_real_db_calls():
"""Automatically prevent any real database calls in all tests."""
with patch("supabase.create_client") as mock_create:
# Make create_client raise an error if called without our mock
mock_create.side_effect = Exception("Real database calls are not allowed in tests!")
yield
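# Because create_client is blocked globally above, tests that need database
# behavior should use the mock_supabase_client / client fixtures defined below
# instead of constructing a real Supabase client.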
@pytest.fixture
def mock_supabase_client():
"""Mock Supabase client for testing."""
mock_client = MagicMock()
# Mock table operations with chaining support
mock_table = MagicMock()
mock_select = MagicMock()
mock_insert = MagicMock()
mock_update = MagicMock()
mock_delete = MagicMock()
# Setup method chaining for select
mock_select.execute.return_value.data = []
mock_select.eq.return_value = mock_select
mock_select.neq.return_value = mock_select
mock_select.order.return_value = mock_select
mock_select.limit.return_value = mock_select
mock_table.select.return_value = mock_select
# Setup method chaining for insert
mock_insert.execute.return_value.data = [{"id": "test-id"}]
mock_table.insert.return_value = mock_insert
# Setup method chaining for update
mock_update.execute.return_value.data = [{"id": "test-id"}]
mock_update.eq.return_value = mock_update
mock_table.update.return_value = mock_update
# Setup method chaining for delete
mock_delete.execute.return_value.data = []
mock_delete.eq.return_value = mock_delete
mock_table.delete.return_value = mock_delete
# Make table() return the mock table
mock_client.table.return_value = mock_table
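# With the wiring above, chained calls resolve through these mocks, e.g.
# (illustrative only): mock_client.table("projects").select("*").eq("id", "x").execute().data == []
# and mock_client.table("projects").insert({...}).execute().data == [{"id": "test-id"}]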
# Mock auth operations
mock_client.auth = MagicMock()
mock_client.auth.get_user.return_value = None
# Mock storage operations
mock_client.storage = MagicMock()
return mock_client
@pytest.fixture
def client(mock_supabase_client):
"""FastAPI test client with mocked database."""
# Patch all the ways Supabase client can be created
with patch(
"src.server.services.client_manager.create_client", return_value=mock_supabase_client
):
with patch(
"src.server.services.credential_service.create_client",
return_value=mock_supabase_client,
):
with patch(
"src.server.services.client_manager.get_supabase_client",
return_value=mock_supabase_client,
):
with patch("supabase.create_client", return_value=mock_supabase_client):
# Import app after patching to ensure mocks are used
from src.server.main import app
return TestClient(app)
@pytest.fixture
def test_project():
"""Simple test project data."""
return {"title": "Test Project", "description": "A test project for essential tests"}
@pytest.fixture
def test_task():
"""Simple test task data."""
return {
"title": "Test Task",
"description": "A test task for essential tests",
"status": "todo",
"assignee": "User",
}
@pytest.fixture
def test_knowledge_item():
"""Simple test knowledge item data."""
return {
"url": "https://example.com/test",
"title": "Test Knowledge Item",
"content": "This is test content for knowledge base",
"source_id": "test-source",
}


@@ -0,0 +1,113 @@
"""Essential API tests - Focus on core functionality that must work."""
def test_health_endpoint(client):
"""Test that health endpoint returns OK status."""
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert "status" in data
assert data["status"] in ["healthy", "initializing"]
def test_create_project(client, test_project, mock_supabase_client):
"""Test creating a new project via API."""
# Set up mock to return a project
mock_supabase_client.table.return_value.insert.return_value.execute.return_value.data = [
{
"id": "test-project-id",
"title": test_project["title"],
"description": test_project["description"],
}
]
response = client.post("/api/projects", json=test_project)
# Endpoint should respond; exact status depends on validation and the mocked data
assert response.status_code in [200, 201, 422, 500] # Allow various responses
# If successful, check response format
if response.status_code in [200, 201]:
data = response.json()
# Check response format - at least one of these should be present
assert (
"title" in data
or "id" in data
or "progress_id" in data
or "status" in data
or "message" in data
)
def test_list_projects(client, mock_supabase_client):
"""Test listing projects endpoint exists and responds."""
# Set up mock to return empty list (no projects)
mock_supabase_client.table.return_value.select.return_value.execute.return_value.data = []
response = client.get("/api/projects")
assert response.status_code in [200, 404, 422, 500] # Allow various responses
# If successful, response should be JSON (list or dict)
if response.status_code == 200:
data = response.json()
assert isinstance(data, (list, dict))
def test_create_task(client, test_task):
"""Test task creation endpoint exists."""
# Try the tasks endpoint directly
response = client.post("/api/tasks", json=test_task)
# Accept various status codes - endpoint exists
assert response.status_code in [200, 201, 400, 422, 405]
def test_list_tasks(client):
"""Test tasks listing endpoint exists."""
response = client.get("/api/tasks")
# Accept 200, 400, 422, or 500 - endpoint exists
assert response.status_code in [200, 400, 422, 500]
def test_start_crawl(client):
"""Test crawl endpoint exists and validates input."""
crawl_request = {"url": "https://example.com", "max_depth": 2, "max_pages": 10}
response = client.post("/api/knowledge/crawl", json=crawl_request)
# Accept various status codes - endpoint exists and processes request
assert response.status_code in [200, 201, 400, 404, 422, 500]
def test_search_knowledge(client):
"""Test knowledge search endpoint exists."""
response = client.post("/api/knowledge/search", json={"query": "test"})
# Accept various status codes - endpoint exists
assert response.status_code in [200, 400, 404, 422, 500]
def test_websocket_connection(client):
"""Test WebSocket/Socket.IO endpoint exists."""
response = client.get("/socket.io/")
# Socket.IO returns specific status codes
assert response.status_code in [200, 400, 404]
def test_authentication(client):
"""Test that API handles auth headers gracefully."""
# Test with no auth header
response = client.get("/api/projects")
assert response.status_code in [200, 401, 403, 500] # 500 is OK in test environment
# Test with invalid auth header
headers = {"Authorization": "Bearer invalid-token"}
response = client.get("/api/projects", headers=headers)
assert response.status_code in [200, 401, 403, 500] # 500 is OK in test environment
def test_error_handling(client):
"""Test API returns proper error responses."""
# Test non-existent endpoint
response = client.get("/api/nonexistent")
assert response.status_code == 404
# Test invalid JSON
response = client.post("/api/projects", data="invalid json")
assert response.status_code in [400, 422]


@@ -0,0 +1,509 @@
"""
Comprehensive Tests for Async Background Task Manager
Tests the pure async background task manager after removal of ThreadPoolExecutor.
Focuses on async task execution, concurrency control, and progress tracking.
"""
import asyncio
from typing import Any
from unittest.mock import AsyncMock
import pytest
from src.server.services.background_task_manager import (
BackgroundTaskManager,
cleanup_task_manager,
get_task_manager,
)
class TestAsyncBackgroundTaskManager:
"""Test suite for async background task manager"""
@pytest.fixture
def task_manager(self):
"""Create a fresh task manager instance for each test"""
return BackgroundTaskManager(max_concurrent_tasks=5)
@pytest.fixture
def mock_progress_callback(self):
"""Mock progress callback function"""
return AsyncMock()
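# The tests below use submit_task(coro_fn, args_tuple, task_id=..., progress_callback=...)
# and poll get_task_status(task_id), which returns a dict with "status", "progress",
# and either "result" or "error" (call signature as exercised by these tests).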
@pytest.mark.asyncio
async def test_task_manager_initialization(self, task_manager):
"""Test task manager initialization"""
assert task_manager.max_concurrent_tasks == 5
assert len(task_manager.active_tasks) == 0
assert len(task_manager.task_metadata) == 0
assert task_manager._task_semaphore._value == 5
@pytest.mark.asyncio
async def test_simple_async_task_execution(self, task_manager, mock_progress_callback):
"""Test execution of a simple async task"""
async def simple_task(message: str):
await asyncio.sleep(0.01) # Simulate async work
return f"Task completed: {message}"
task_id = await task_manager.submit_task(
simple_task, ("Hello World",), progress_callback=mock_progress_callback
)
# Wait for task completion
await asyncio.sleep(0.05)
# Check task status
status = await task_manager.get_task_status(task_id)
assert status["status"] == "complete"
assert status["progress"] == 100
assert status["result"] == "Task completed: Hello World"
# Verify progress callback was called
assert mock_progress_callback.call_count >= 1
@pytest.mark.asyncio
async def test_task_with_error(self, task_manager, mock_progress_callback):
"""Test handling of task that raises an exception"""
async def failing_task():
await asyncio.sleep(0.01)
raise ValueError("Task failed intentionally")
task_id = await task_manager.submit_task(
failing_task, (), progress_callback=mock_progress_callback
)
# Wait for task to fail
await asyncio.sleep(0.05)
# Check task status
status = await task_manager.get_task_status(task_id)
assert status["status"] == "error"
assert status["progress"] == -1
assert "error" in status
assert "Task failed intentionally" in status["error"]
# Verify error was reported via progress callback
error_call = None
for call in mock_progress_callback.call_args_list:
if len(call[0]) >= 2 and call[0][1].get("status") == "error":
error_call = call
break
assert error_call is not None
assert "Task failed intentionally" in error_call[0][1]["error"]
@pytest.mark.asyncio
async def test_concurrent_task_execution(self, task_manager):
"""Test execution of multiple concurrent tasks"""
async def numbered_task(number: int):
await asyncio.sleep(0.01)
return f"Task {number} completed"
# Submit 5 tasks simultaneously
task_ids = []
for i in range(5):
task_id = await task_manager.submit_task(numbered_task, (i,), task_id=f"task-{i}")
task_ids.append(task_id)
# Wait for all tasks to complete
await asyncio.sleep(0.05)
# Check all tasks completed successfully
for i, task_id in enumerate(task_ids):
status = await task_manager.get_task_status(task_id)
assert status["status"] == "complete"
assert status["result"] == f"Task {i} completed"
@pytest.mark.asyncio
async def test_concurrency_limit(self, task_manager):
"""Test that concurrency is limited by semaphore"""
# Use a task manager with limit of 2
limited_manager = BackgroundTaskManager(max_concurrent_tasks=2)
running_tasks = []
completed_tasks = []
async def long_running_task(task_id: int):
running_tasks.append(task_id)
await asyncio.sleep(0.05) # Long enough to test concurrency
completed_tasks.append(task_id)
return f"Task {task_id} completed"
# Submit 4 tasks
task_ids = []
for i in range(4):
task_id = await limited_manager.submit_task(
long_running_task, (i,), task_id=f"concurrent-task-{i}"
)
task_ids.append(task_id)
# Wait a bit and check that only 2 tasks are running
await asyncio.sleep(0.01)
assert len(running_tasks) <= 2
# Wait for all to complete
await asyncio.sleep(0.1)
assert len(completed_tasks) == 4
# Clean up
await limited_manager.cleanup()
@pytest.mark.asyncio
async def test_task_cancellation(self, task_manager):
"""Test cancellation of running task"""
async def long_task():
try:
await asyncio.sleep(1.0) # Long enough to be cancelled
return "Should not complete"
except asyncio.CancelledError:
raise # Re-raise to properly handle cancellation
task_id = await task_manager.submit_task(long_task, (), task_id="cancellable-task")
# Wait a bit, then cancel
await asyncio.sleep(0.01)
cancelled = await task_manager.cancel_task(task_id)
assert cancelled is True
# Check task status
await asyncio.sleep(0.01)
status = await task_manager.get_task_status(task_id)
assert status["status"] == "cancelled"
@pytest.mark.asyncio
async def test_task_not_found(self, task_manager):
"""Test getting status of non-existent task"""
status = await task_manager.get_task_status("non-existent-task")
assert status["error"] == "Task not found"
@pytest.mark.asyncio
async def test_cancel_non_existent_task(self, task_manager):
"""Test cancelling non-existent task"""
cancelled = await task_manager.cancel_task("non-existent-task")
assert cancelled is False
@pytest.mark.asyncio
async def test_progress_callback_execution(self, task_manager):
"""Test that progress callback is properly executed"""
progress_updates = []
async def mock_progress_callback(task_id: str, update: dict[str, Any]):
progress_updates.append((task_id, update))
async def simple_task():
await asyncio.sleep(0.01)
return "completed"
task_id = await task_manager.submit_task(
simple_task, (), task_id="progress-test-task", progress_callback=mock_progress_callback
)
# Wait for completion
await asyncio.sleep(0.05)
# Should have at least one progress update (completion)
assert len(progress_updates) >= 1
# Check that task_id matches
assert all(update[0] == task_id for update in progress_updates)
# Check for completion update
completion_updates = [
update for update in progress_updates if update[1].get("status") == "complete"
]
assert len(completion_updates) >= 1
assert completion_updates[0][1]["percentage"] == 100
@pytest.mark.asyncio
async def test_progress_callback_error_handling(self, task_manager):
"""Test that task continues even if progress callback fails"""
async def failing_progress_callback(task_id: str, update: dict[str, Any]):
raise Exception("Progress callback failed")
async def simple_task():
await asyncio.sleep(0.01)
return "Task completed despite callback failure"
task_id = await task_manager.submit_task(
simple_task, (), progress_callback=failing_progress_callback
)
# Wait for completion
await asyncio.sleep(0.05)
# Task should still complete successfully
status = await task_manager.get_task_status(task_id)
assert status["status"] == "complete"
assert status["result"] == "Task completed despite callback failure"
@pytest.mark.asyncio
async def test_task_metadata_tracking(self, task_manager):
"""Test that task metadata is properly tracked"""
async def simple_task():
await asyncio.sleep(0.01)
return "result"
task_id = await task_manager.submit_task(simple_task, (), task_id="metadata-test")
# Check initial metadata
initial_status = await task_manager.get_task_status(task_id)
assert initial_status["status"] == "running"
assert "created_at" in initial_status
assert initial_status["progress"] == 0
# Wait for completion
await asyncio.sleep(0.05)
# Check final metadata
final_status = await task_manager.get_task_status(task_id)
assert final_status["status"] == "complete"
assert final_status["progress"] == 100
assert final_status["result"] == "result"
@pytest.mark.asyncio
async def test_cleanup_active_tasks(self, task_manager):
"""Test cleanup cancels active tasks"""
async def long_running_task():
try:
await asyncio.sleep(1.0)
return "Should not complete"
except asyncio.CancelledError:
raise
# Submit multiple long-running tasks
task_ids = []
for i in range(3):
task_id = await task_manager.submit_task(
long_running_task, (), task_id=f"cleanup-test-{i}"
)
task_ids.append(task_id)
# Verify tasks are active
await asyncio.sleep(0.01)
assert len(task_manager.active_tasks) == 3
# Cleanup
await task_manager.cleanup()
# Verify all tasks were cancelled and cleaned up
assert len(task_manager.active_tasks) == 0
assert len(task_manager.task_metadata) == 0
@pytest.mark.asyncio
async def test_completed_task_status_after_removal(self, task_manager):
"""Test getting status of completed task after it's removed from active_tasks"""
async def quick_task():
return "quick result"
task_id = await task_manager.submit_task(quick_task, (), task_id="quick-test")
# Wait for completion and removal from active_tasks
await asyncio.sleep(0.05)
# Should still be able to get status from metadata
status = await task_manager.get_task_status(task_id)
assert status["status"] == "complete"
assert status["result"] == "quick result"
def test_set_main_loop_deprecated(self, task_manager):
"""Test that set_main_loop is deprecated but doesn't break"""
# Should not raise an exception but may log a warning
loop = asyncio.new_event_loop()
task_manager.set_main_loop(loop)
loop.close()
class TestGlobalTaskManager:
"""Test the global task manager functions"""
def test_get_task_manager_singleton(self):
"""Test that get_task_manager returns singleton"""
manager1 = get_task_manager()
manager2 = get_task_manager()
assert manager1 is manager2
@pytest.mark.asyncio
async def test_cleanup_task_manager(self):
"""Test cleanup of global task manager"""
# Get the global manager
manager = get_task_manager()
assert manager is not None
# Add a task to make it interesting
async def test_task():
return "test"
task_id = await manager.submit_task(test_task, ())
await asyncio.sleep(0.01)
# Cleanup
await cleanup_task_manager()
# Verify it was cleaned up - getting a new one should be different
new_manager = get_task_manager()
assert new_manager is not manager
class TestAsyncTaskPatterns:
"""Test various async task patterns and edge cases"""
@pytest.fixture
def task_manager(self):
return BackgroundTaskManager(max_concurrent_tasks=3)
@pytest.mark.asyncio
async def test_nested_async_calls(self, task_manager):
"""Test tasks that make nested async calls"""
async def nested_task():
async def inner_task():
await asyncio.sleep(0.01)
return "inner result"
result = await inner_task()
return f"outer: {result}"
task_id = await task_manager.submit_task(nested_task, ())
await asyncio.sleep(0.05)
status = await task_manager.get_task_status(task_id)
assert status["status"] == "complete"
assert status["result"] == "outer: inner result"
@pytest.mark.asyncio
async def test_task_with_async_context_manager(self, task_manager):
"""Test tasks that use async context managers"""
class AsyncResource:
def __init__(self):
self.entered = False
self.exited = False
async def __aenter__(self):
await asyncio.sleep(0.001)
self.entered = True
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await asyncio.sleep(0.001)
self.exited = True
resource = AsyncResource()
async def context_manager_task():
async with resource:
await asyncio.sleep(0.01)
return "context manager used"
task_id = await task_manager.submit_task(context_manager_task, ())
await asyncio.sleep(0.05)
status = await task_manager.get_task_status(task_id)
assert status["status"] == "complete"
assert status["result"] == "context manager used"
assert resource.entered
assert resource.exited
@pytest.mark.asyncio
async def test_task_cancellation_propagation(self, task_manager):
"""Test that cancellation properly propagates through nested calls"""
cancelled_flags = []
async def cancellable_inner():
try:
await asyncio.sleep(1.0)
return "should not complete"
except asyncio.CancelledError:
cancelled_flags.append("inner")
raise
async def cancellable_outer():
try:
result = await cancellable_inner()
return f"outer: {result}"
except asyncio.CancelledError:
cancelled_flags.append("outer")
raise
task_id = await task_manager.submit_task(cancellable_outer, ())
await asyncio.sleep(0.01)
# Cancel the task
cancelled = await task_manager.cancel_task(task_id)
assert cancelled
await asyncio.sleep(0.01)
# Both inner and outer should have been cancelled
assert "inner" in cancelled_flags
assert "outer" in cancelled_flags
@pytest.mark.asyncio
async def test_high_concurrency_stress_test(self, task_manager):
"""Stress test with many concurrent tasks"""
async def stress_task(task_num: int):
await asyncio.sleep(0.001 * (task_num % 10)) # Vary sleep time
return f"stress-{task_num}"
# Submit many tasks
task_ids = []
num_tasks = 20
for i in range(num_tasks):
task_id = await task_manager.submit_task(stress_task, (i,), task_id=f"stress-{i}")
task_ids.append(task_id)
# Wait for all to complete
await asyncio.sleep(0.5)
# Verify all completed successfully
for i, task_id in enumerate(task_ids):
status = await task_manager.get_task_status(task_id)
assert status["status"] == "complete"
assert status["result"] == f"stress-{i}"
@pytest.mark.asyncio
async def test_task_execution_order_with_semaphore(self, task_manager):
"""Test that semaphore properly controls execution order"""
# Use manager with limit of 2
limited_manager = BackgroundTaskManager(max_concurrent_tasks=2)
execution_order = []
async def ordered_task(task_id: int):
execution_order.append(f"start-{task_id}")
await asyncio.sleep(0.02)
execution_order.append(f"end-{task_id}")
return task_id
# Submit 4 tasks
task_ids = []
for i in range(4):
task_id = await limited_manager.submit_task(ordered_task, (i,), task_id=f"order-{i}")
task_ids.append(task_id)
# Wait for completion
await asyncio.sleep(0.2)
# Verify execution pattern - should see at most 2 concurrent executions
for i, event in enumerate(execution_order):
if event.startswith("start-"):
# Count tasks that have started but not yet ended at this point
starts_seen = sum(1 for e in execution_order[: i + 1] if e.startswith("start-"))
ends_seen = sum(1 for e in execution_order[: i + 1] if e.startswith("end-"))
concurrent = starts_seen - ends_seen
assert concurrent <= 2 # Should never exceed semaphore limit
await limited_manager.cleanup()


@@ -0,0 +1,414 @@
"""
Comprehensive Tests for Async Credential Service
Tests the credential service async functions after sync function removal.
Covers credential storage, retrieval, encryption/decryption, and caching.
"""
import asyncio
import os
from unittest.mock import MagicMock, patch
import pytest
from src.server.services.credential_service import (
credential_service,
get_credential,
initialize_credentials,
set_credential,
)
class TestAsyncCredentialService:
"""Test suite for async credential service functions"""
@pytest.fixture(autouse=True)
def setup_credential_service(self):
"""Setup clean credential service for each test"""
# Clear cache and reset state
credential_service._cache.clear()
credential_service._cache_initialized = False
yield
# Cleanup after test
credential_service._cache.clear()
credential_service._cache_initialized = False
@pytest.fixture
def mock_supabase_client(self):
"""Mock Supabase client"""
mock_client = MagicMock()
mock_table = MagicMock()
mock_client.table.return_value = mock_table
return mock_client, mock_table
@pytest.fixture
def sample_credentials_data(self):
"""Sample credentials data from database"""
return [
{
"id": 1,
"key": "OPENAI_API_KEY",
"encrypted_value": "encrypted_openai_key",
"value": None,
"is_encrypted": True,
"category": "api_keys",
"description": "OpenAI API key for LLM access",
},
{
"id": 2,
"key": "MODEL_CHOICE",
"value": "gpt-4.1-nano",
"encrypted_value": None,
"is_encrypted": False,
"category": "rag_strategy",
"description": "Default model choice",
},
{
"id": 3,
"key": "MAX_TOKENS",
"value": "1000",
"encrypted_value": None,
"is_encrypted": False,
"category": "rag_strategy",
"description": "Maximum tokens per request",
},
]
def test_deprecated_functions_removed(self):
"""Test that deprecated sync functions are no longer available"""
import src.server.services.credential_service as cred_module
# The sync function should no longer exist
assert not hasattr(cred_module, "get_credential_sync")
# The async versions should be the primary functions
assert hasattr(cred_module, "get_credential")
assert hasattr(cred_module, "set_credential")
@pytest.mark.asyncio
async def test_get_credential_from_cache(self):
"""Test getting credential from initialized cache"""
# Setup cache
credential_service._cache = {"TEST_KEY": "test_value", "NUMERIC_KEY": "123"}
credential_service._cache_initialized = True
result = await get_credential("TEST_KEY", "default")
assert result == "test_value"
result = await get_credential("NUMERIC_KEY", "default")
assert result == "123"
result = await get_credential("MISSING_KEY", "default_value")
assert result == "default_value"
@pytest.mark.asyncio
async def test_get_credential_encrypted_value(self):
"""Test getting encrypted credential"""
# Setup cache with encrypted value
encrypted_data = {"encrypted_value": "encrypted_test_value", "is_encrypted": True}
credential_service._cache = {"SECRET_KEY": encrypted_data}
credential_service._cache_initialized = True
with patch.object(credential_service, "_decrypt_value", return_value="decrypted_value"):
result = await get_credential("SECRET_KEY", "default")
assert result == "decrypted_value"
credential_service._decrypt_value.assert_called_once_with("encrypted_test_value")
@pytest.mark.asyncio
async def test_get_credential_cache_not_initialized(self, mock_supabase_client):
"""Test getting credential when cache is not initialized"""
mock_client, mock_table = mock_supabase_client
# Mock database response for load_all_credentials (gets ALL settings)
mock_response = MagicMock()
mock_response.data = [
{
"key": "TEST_KEY",
"value": "db_value",
"encrypted_value": None,
"is_encrypted": False,
"category": "test",
"description": "Test key",
}
]
mock_table.select().execute.return_value = mock_response
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
result = await credential_service.get_credential("TEST_KEY", "default")
assert result == "db_value"
# Should have called database to load all credentials
mock_table.select.assert_called_with("*")
# Should have called execute on the query
assert mock_table.select().execute.called
@pytest.mark.asyncio
async def test_get_credential_not_found_in_db(self, mock_supabase_client):
"""Test getting credential that doesn't exist in database"""
mock_client, mock_table = mock_supabase_client
# Mock empty database response
mock_response = MagicMock()
mock_response.data = []
mock_table.select().eq().execute.return_value = mock_response
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
result = await credential_service.get_credential("MISSING_KEY", "default_value")
assert result == "default_value"
@pytest.mark.asyncio
async def test_set_credential_new(self, mock_supabase_client):
"""Test setting a new credential"""
mock_client, mock_table = mock_supabase_client
# Mock successful insert
mock_response = MagicMock()
mock_response.data = [{"id": 1, "key": "NEW_KEY", "value": "new_value"}]
mock_table.insert().execute.return_value = mock_response
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
result = await set_credential("NEW_KEY", "new_value", is_encrypted=False)
assert result is True
# Should have attempted insert
mock_table.insert.assert_called_once()
@pytest.mark.asyncio
async def test_set_credential_encrypted(self, mock_supabase_client):
"""Test setting an encrypted credential"""
mock_client, mock_table = mock_supabase_client
# Mock successful insert
mock_response = MagicMock()
mock_response.data = [{"id": 1, "key": "SECRET_KEY"}]
mock_table.insert().execute.return_value = mock_response
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
with patch.object(credential_service, "_encrypt_value", return_value="encrypted_value"):
result = await set_credential("SECRET_KEY", "secret_value", is_encrypted=True)
assert result is True
# Should have encrypted the value
credential_service._encrypt_value.assert_called_once_with("secret_value")
@pytest.mark.asyncio
async def test_load_all_credentials(self, mock_supabase_client, sample_credentials_data):
"""Test loading all credentials from database"""
mock_client, mock_table = mock_supabase_client
# Mock database response
mock_response = MagicMock()
mock_response.data = sample_credentials_data
mock_table.select().execute.return_value = mock_response
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
result = await credential_service.load_all_credentials()
# Should have loaded credentials into cache
assert credential_service._cache_initialized is True
assert "OPENAI_API_KEY" in credential_service._cache
assert "MODEL_CHOICE" in credential_service._cache
assert "MAX_TOKENS" in credential_service._cache
# Should have stored encrypted values as dict objects (not decrypted yet)
openai_key_cache = credential_service._cache["OPENAI_API_KEY"]
assert isinstance(openai_key_cache, dict)
assert openai_key_cache["encrypted_value"] == "encrypted_openai_key"
assert openai_key_cache["is_encrypted"] is True
# Plain text values should be stored directly
assert credential_service._cache["MODEL_CHOICE"] == "gpt-4.1-nano"
@pytest.mark.asyncio
async def test_get_credentials_by_category(self, mock_supabase_client):
"""Test getting credentials filtered by category"""
mock_client, mock_table = mock_supabase_client
# Mock database response for rag_strategy category
rag_data = [
{
"key": "MODEL_CHOICE",
"value": "gpt-4.1-nano",
"is_encrypted": False,
"description": "Model choice",
},
{
"key": "MAX_TOKENS",
"value": "1000",
"is_encrypted": False,
"description": "Max tokens",
},
]
mock_response = MagicMock()
mock_response.data = rag_data
mock_table.select().eq().execute.return_value = mock_response
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
result = await credential_service.get_credentials_by_category("rag_strategy")
# Should only return rag_strategy credentials
assert "MODEL_CHOICE" in result
assert "MAX_TOKENS" in result
assert result["MODEL_CHOICE"] == "gpt-4.1-nano"
assert result["MAX_TOKENS"] == "1000"
@pytest.mark.asyncio
async def test_get_active_provider_llm(self, mock_supabase_client):
"""Test getting active LLM provider configuration"""
mock_client, mock_table = mock_supabase_client
# Setup cache directly instead of mocking complex database responses
credential_service._cache = {
"LLM_PROVIDER": "openai",
"MODEL_CHOICE": "gpt-4.1-nano",
"OPENAI_API_KEY": {
"encrypted_value": "encrypted_key",
"is_encrypted": True,
"category": "api_keys",
"description": "API key",
},
}
credential_service._cache_initialized = True
# Mock rag_strategy category response
rag_response = MagicMock()
rag_response.data = [
{
"key": "LLM_PROVIDER",
"value": "openai",
"is_encrypted": False,
"description": "LLM provider",
},
{
"key": "MODEL_CHOICE",
"value": "gpt-4.1-nano",
"is_encrypted": False,
"description": "Model choice",
},
]
mock_table.select().eq().execute.return_value = rag_response
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
with patch.object(credential_service, "_decrypt_value", return_value="decrypted_key"):
result = await credential_service.get_active_provider("llm")
assert result["provider"] == "openai"
assert result["api_key"] == "decrypted_key"
assert result["chat_model"] == "gpt-4.1-nano"
@pytest.mark.asyncio
async def test_get_active_provider_basic(self, mock_supabase_client):
"""Test basic provider configuration retrieval"""
mock_client, mock_table = mock_supabase_client
# Simple mock response
mock_response = MagicMock()
mock_response.data = []
mock_table.select().eq().execute.return_value = mock_response
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
result = await credential_service.get_active_provider("llm")
# Should return default values when no settings found
assert "provider" in result
assert "api_key" in result
@pytest.mark.asyncio
async def test_initialize_credentials(self, mock_supabase_client, sample_credentials_data):
"""Test initialize_credentials function"""
mock_client, mock_table = mock_supabase_client
# Mock database response
mock_response = MagicMock()
mock_response.data = sample_credentials_data
mock_table.select().execute.return_value = mock_response
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
with patch.object(credential_service, "_decrypt_value", return_value="decrypted_key"):
with patch.dict(os.environ, {}, clear=True): # Clear environment
await initialize_credentials()
# Should have loaded credentials
assert credential_service._cache_initialized is True
# Should have set infrastructure env vars (like OPENAI_API_KEY)
# Note: this tests the logic; actual env var setting depends on the implementation
@pytest.mark.asyncio
async def test_error_handling_database_failure(self, mock_supabase_client):
"""Test error handling when database fails"""
mock_client, mock_table = mock_supabase_client
# Mock database error
mock_table.select().eq().execute.side_effect = Exception("Database connection failed")
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
result = await credential_service.get_credential("TEST_KEY", "default_value")
assert result == "default_value"
@pytest.mark.asyncio
async def test_encryption_decryption_error_handling(self):
"""Test error handling for encryption/decryption failures"""
# Setup cache with encrypted value that fails to decrypt
encrypted_data = {"encrypted_value": "corrupted_encrypted_value", "is_encrypted": True}
credential_service._cache = {"CORRUPTED_KEY": encrypted_data}
credential_service._cache_initialized = True
with patch.object(
credential_service, "_decrypt_value", side_effect=Exception("Decryption failed")
):
# Should fall back to default when decryption fails
result = await credential_service.get_credential("CORRUPTED_KEY", "fallback_value")
assert result == "fallback_value"
def test_direct_cache_access_fallback(self):
"""Test direct cache access pattern used in converted sync functions"""
# Setup cache
credential_service._cache = {
"MODEL_CHOICE": "gpt-4.1-nano",
"OPENAI_API_KEY": {"encrypted_value": "encrypted_key", "is_encrypted": True},
}
credential_service._cache_initialized = True
# Test simple cache access
if credential_service._cache_initialized and "MODEL_CHOICE" in credential_service._cache:
result = credential_service._cache["MODEL_CHOICE"]
assert result == "gpt-4.1-nano"
# Test encrypted value access
if credential_service._cache_initialized and "OPENAI_API_KEY" in credential_service._cache:
cached_key = credential_service._cache["OPENAI_API_KEY"]
if isinstance(cached_key, dict) and cached_key.get("is_encrypted"):
# Would need to call credential_service._decrypt_value(cached_key["encrypted_value"])
assert cached_key["encrypted_value"] == "encrypted_key"
assert cached_key["is_encrypted"] is True
@pytest.mark.asyncio
async def test_concurrent_access(self):
"""Test concurrent access to credential service"""
credential_service._cache = {"SHARED_KEY": "shared_value"}
credential_service._cache_initialized = True
async def get_credential_task():
return await get_credential("SHARED_KEY", "default")
# Run multiple concurrent requests
tasks = [get_credential_task() for _ in range(10)]
results = await asyncio.gather(*tasks)
# All should return the same value
assert all(result == "shared_value" for result in results)
@pytest.mark.asyncio
async def test_cache_persistence(self):
"""Test that cache persists across calls"""
credential_service._cache = {"PERSISTENT_KEY": "persistent_value"}
credential_service._cache_initialized = True
# First call
result1 = await get_credential("PERSISTENT_KEY", "default")
assert result1 == "persistent_value"
# Second call should use same cache
result2 = await get_credential("PERSISTENT_KEY", "default")
assert result2 == "persistent_value"
assert result1 == result2


@@ -0,0 +1,469 @@
"""
Comprehensive Tests for Async Embedding Service
Tests all aspects of the async embedding service after sync function removal.
Covers both success and error scenarios with thorough edge case testing.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import openai
import pytest
from src.server.services.embeddings.embedding_exceptions import (
EmbeddingAPIError,
)
from src.server.services.embeddings.embedding_service import (
EmbeddingBatchResult,
create_embedding,
create_embeddings_batch,
)
class AsyncContextManager:
"""Helper class for properly mocking async context managers"""
def __init__(self, return_value):
self.return_value = return_value
async def __aenter__(self):
return self.return_value
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
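# Usage note: the service awaits get_llm_client() as an async context manager, so the
# tests below patch it with mock_get_client.return_value = AsyncContextManager(mock_llm_client),
# making "async with get_llm_client() as client" yield the mocked client.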
class TestAsyncEmbeddingService:
"""Test suite for async embedding service functions"""
@pytest.fixture
def mock_llm_client(self):
"""Mock LLM client for testing"""
mock_client = MagicMock()
mock_embeddings = MagicMock()
mock_response = MagicMock()
mock_response.data = [
MagicMock(embedding=[0.1, 0.2, 0.3] + [0.0] * 1533) # 1536 dimensions
]
mock_embeddings.create = AsyncMock(return_value=mock_response)
mock_client.embeddings = mock_embeddings
return mock_client
@pytest.fixture
def mock_threading_service(self):
"""Mock threading service for testing"""
mock_service = MagicMock()
# Create a proper async context manager
rate_limit_ctx = AsyncContextManager(None)
mock_service.rate_limited_operation.return_value = rate_limit_ctx
return mock_service
@pytest.mark.asyncio
async def test_create_embedding_success(self, mock_llm_client, mock_threading_service):
"""Test successful single embedding creation"""
with patch(
"src.server.services.embeddings.embedding_service.get_threading_service",
return_value=mock_threading_service,
):
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_get_client:
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
return_value="text-embedding-3-small",
):
with patch(
"src.server.services.embeddings.embedding_service.credential_service"
) as mock_cred:
# Mock credential service properly
mock_cred.get_credentials_by_category = AsyncMock(
return_value={"EMBEDDING_BATCH_SIZE": "10"}
)
# Setup proper async context manager
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
result = await create_embedding("test text")
# Verify the result
assert len(result) == 1536
assert result[0] == 0.1
assert result[1] == 0.2
assert result[2] == 0.3
# Verify API was called correctly
mock_llm_client.embeddings.create.assert_called_once()
@pytest.mark.asyncio
async def test_create_embedding_empty_text(self, mock_llm_client, mock_threading_service):
"""Test embedding creation with empty text"""
with patch(
"src.server.services.embeddings.embedding_service.get_threading_service",
return_value=mock_threading_service,
):
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_get_client:
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
return_value="text-embedding-3-small",
):
with patch(
"src.server.services.embeddings.embedding_service.credential_service"
) as mock_cred:
mock_cred.get_credentials_by_category = AsyncMock(
return_value={"EMBEDDING_BATCH_SIZE": "10"}
)
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
result = await create_embedding("")
# Should still work with empty text
assert len(result) == 1536
mock_llm_client.embeddings.create.assert_called_once()
@pytest.mark.asyncio
async def test_create_embedding_api_error_raises_exception(self, mock_threading_service):
"""Test embedding creation with API error - should raise exception"""
with patch(
"src.server.services.embeddings.embedding_service.get_threading_service",
return_value=mock_threading_service,
):
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_get_client:
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
return_value="text-embedding-3-small",
):
with patch(
"src.server.services.embeddings.embedding_service.credential_service"
) as mock_cred:
mock_cred.get_credentials_by_category = AsyncMock(
return_value={"EMBEDDING_BATCH_SIZE": "10"}
)
# Setup client to raise an error
mock_client = MagicMock()
mock_client.embeddings.create = AsyncMock(
side_effect=Exception("API Error")
)
mock_get_client.return_value = AsyncContextManager(mock_client)
# Should raise exception now instead of returning zero embeddings
with pytest.raises(EmbeddingAPIError):
await create_embedding("test text")
@pytest.mark.asyncio
async def test_create_embeddings_batch_success(self, mock_llm_client, mock_threading_service):
"""Test successful batch embedding creation"""
# Setup mock response for multiple embeddings
mock_response = MagicMock()
mock_response.data = [
MagicMock(embedding=[0.1, 0.2, 0.3] + [0.0] * 1533),
MagicMock(embedding=[0.4, 0.5, 0.6] + [0.0] * 1533),
]
mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)
with patch(
"src.server.services.embeddings.embedding_service.get_threading_service",
return_value=mock_threading_service,
):
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_get_client:
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
return_value="text-embedding-3-small",
):
with patch(
"src.server.services.embeddings.embedding_service.credential_service"
) as mock_cred:
mock_cred.get_credentials_by_category = AsyncMock(
return_value={"EMBEDDING_BATCH_SIZE": "10"}
)
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
result = await create_embeddings_batch(["text1", "text2"])
# Verify the result is EmbeddingBatchResult
assert isinstance(result, EmbeddingBatchResult)
assert result.success_count == 2
assert result.failure_count == 0
assert len(result.embeddings) == 2
assert len(result.embeddings[0]) == 1536
assert len(result.embeddings[1]) == 1536
assert result.embeddings[0][0] == 0.1
assert result.embeddings[1][0] == 0.4
mock_llm_client.embeddings.create.assert_called_once()
@pytest.mark.asyncio
async def test_create_embeddings_batch_empty_list(self):
"""Test batch embedding with empty list"""
result = await create_embeddings_batch([])
assert isinstance(result, EmbeddingBatchResult)
assert result.success_count == 0
assert result.failure_count == 0
assert result.embeddings == []
@pytest.mark.asyncio
async def test_create_embeddings_batch_rate_limit_error(self, mock_threading_service):
"""Test batch embedding with rate limit error"""
with patch(
"src.server.services.embeddings.embedding_service.get_threading_service",
return_value=mock_threading_service,
):
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_get_client:
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
return_value="text-embedding-3-small",
):
with patch(
"src.server.services.embeddings.embedding_service.credential_service"
) as mock_cred:
mock_cred.get_credentials_by_category = AsyncMock(
return_value={"EMBEDDING_BATCH_SIZE": "10"}
)
# Setup client to raise rate limit error (not quota)
mock_client = MagicMock()
# Create a proper RateLimitError with required attributes
error = openai.RateLimitError(
"Rate limit exceeded",
response=MagicMock(),
body={"error": {"message": "Rate limit exceeded"}},
)
mock_client.embeddings.create = AsyncMock(side_effect=error)
mock_get_client.return_value = AsyncContextManager(mock_client)
result = await create_embeddings_batch(["text1", "text2"])
# Should return result with failures, not zero embeddings
assert isinstance(result, EmbeddingBatchResult)
assert result.success_count == 0
assert result.failure_count == 2
assert len(result.embeddings) == 0
assert len(result.failed_items) == 2
@pytest.mark.asyncio
async def test_create_embeddings_batch_quota_exhausted(self, mock_threading_service):
"""Test batch embedding with quota exhausted error"""
with patch(
"src.server.services.embeddings.embedding_service.get_threading_service",
return_value=mock_threading_service,
):
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_get_client:
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
return_value="text-embedding-3-small",
):
with patch(
"src.server.services.embeddings.embedding_service.credential_service"
) as mock_cred:
mock_cred.get_credentials_by_category = AsyncMock(
return_value={"EMBEDDING_BATCH_SIZE": "10"}
)
# Setup client to raise quota exhausted error
mock_client = MagicMock()
error = openai.RateLimitError(
"insufficient_quota",
response=MagicMock(),
body={"error": {"message": "insufficient_quota"}},
)
mock_client.embeddings.create = AsyncMock(side_effect=error)
mock_get_client.return_value = AsyncContextManager(mock_client)
# Mock progress callback
progress_callback = AsyncMock()
result = await create_embeddings_batch(
["text1", "text2"], progress_callback=progress_callback
)
# Should return result with failures, not zero embeddings
assert isinstance(result, EmbeddingBatchResult)
assert result.success_count == 0
assert result.failure_count == 2
assert len(result.embeddings) == 0
assert len(result.failed_items) == 2
# Verify quota exhausted is in error messages
assert any("quota" in item["error"].lower() for item in result.failed_items)
@pytest.mark.asyncio
async def test_create_embeddings_batch_with_websocket_progress(
self, mock_llm_client, mock_threading_service
):
"""Test batch embedding with WebSocket progress updates"""
mock_response = MagicMock()
mock_response.data = [MagicMock(embedding=[0.1] * 1536)]
mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)
with patch(
"src.server.services.embeddings.embedding_service.get_threading_service",
return_value=mock_threading_service,
):
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_get_client:
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
return_value="text-embedding-3-small",
):
with patch(
"src.server.services.embeddings.embedding_service.credential_service"
) as mock_cred:
mock_cred.get_credentials_by_category = AsyncMock(
return_value={"EMBEDDING_BATCH_SIZE": "1"}
)
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
# Mock WebSocket
mock_websocket = MagicMock()
mock_websocket.send_json = AsyncMock()
result = await create_embeddings_batch(["text1"], websocket=mock_websocket)
# Verify result is correct
assert isinstance(result, EmbeddingBatchResult)
assert result.success_count == 1
# Verify WebSocket was called
mock_websocket.send_json.assert_called()
call_args = mock_websocket.send_json.call_args[0][0]
assert call_args["type"] == "embedding_progress"
assert "processed" in call_args
assert "total" in call_args
@pytest.mark.asyncio
async def test_create_embeddings_batch_with_progress_callback(
self, mock_llm_client, mock_threading_service
):
"""Test batch embedding with progress callback"""
mock_response = MagicMock()
mock_response.data = [MagicMock(embedding=[0.1] * 1536)]
mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)
with patch(
"src.server.services.embeddings.embedding_service.get_threading_service",
return_value=mock_threading_service,
):
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_get_client:
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
return_value="text-embedding-3-small",
):
with patch(
"src.server.services.embeddings.embedding_service.credential_service"
) as mock_cred:
mock_cred.get_credentials_by_category = AsyncMock(
return_value={"EMBEDDING_BATCH_SIZE": "1"}
)
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
# Mock progress callback
progress_callback = AsyncMock()
result = await create_embeddings_batch(
["text1"], progress_callback=progress_callback
)
# Verify result
assert isinstance(result, EmbeddingBatchResult)
assert result.success_count == 1
# Verify progress callback was called
progress_callback.assert_called()
@pytest.mark.asyncio
async def test_provider_override(self, mock_llm_client, mock_threading_service):
"""Test that provider override parameter is properly passed through"""
mock_response = MagicMock()
mock_response.data = [MagicMock(embedding=[0.1] * 1536)]
mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)
with patch(
"src.server.services.embeddings.embedding_service.get_threading_service",
return_value=mock_threading_service,
):
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_get_client:
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model"
) as mock_get_model:
with patch(
"src.server.services.embeddings.embedding_service.credential_service"
) as mock_cred:
mock_cred.get_credentials_by_category = AsyncMock(
return_value={"EMBEDDING_BATCH_SIZE": "10"}
)
mock_get_model.return_value = "custom-model"
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
await create_embedding("test text", provider="custom-provider")
# Verify provider was passed to get_llm_client
mock_get_client.assert_called_with(
provider="custom-provider", use_embedding_provider=True
)
mock_get_model.assert_called_with(provider="custom-provider")
@pytest.mark.asyncio
async def test_create_embeddings_batch_large_batch_splitting(
self, mock_llm_client, mock_threading_service
):
"""Test that large batches are properly split according to batch size settings"""
mock_response = MagicMock()
mock_response.data = [
MagicMock(embedding=[0.1] * 1536) for _ in range(2)
] # 2 embeddings per call
mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)
with patch(
"src.server.services.embeddings.embedding_service.get_threading_service",
return_value=mock_threading_service,
):
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_get_client:
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
return_value="text-embedding-3-small",
):
with patch(
"src.server.services.embeddings.embedding_service.credential_service"
) as mock_cred:
# Set batch size to 2
mock_cred.get_credentials_by_category = AsyncMock(
return_value={"EMBEDDING_BATCH_SIZE": "2"}
)
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
# Test with 5 texts (should require 3 API calls: 2+2+1)
texts = ["text1", "text2", "text3", "text4", "text5"]
result = await create_embeddings_batch(texts)
# Should have made 3 API calls due to batching
assert mock_llm_client.embeddings.create.call_count == 3
# Result should be EmbeddingBatchResult
assert isinstance(result, EmbeddingBatchResult)
# Should have 5 embeddings total (for 5 input texts)
# Even though mock returns 2 per call, we only process as many as we requested
assert result.success_count == 5
assert len(result.embeddings) == 5
assert result.texts_processed == texts


@@ -0,0 +1,474 @@
"""
Comprehensive Tests for Async LLM Provider Service
Tests all aspects of the async LLM provider service after sync function removal.
Covers different providers (OpenAI, Ollama, Google) and error scenarios.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.server.services.llm_provider_service import (
_get_cached_settings,
_set_cached_settings,
get_embedding_model,
get_llm_client,
)
class AsyncContextManager:
"""Helper class for properly mocking async context managers"""
def __init__(self, return_value):
self.return_value = return_value
async def __aenter__(self):
return self.return_value
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
class TestAsyncLLMProviderService:
"""Test suite for async LLM provider service functions"""
@pytest.fixture(autouse=True)
def clear_cache(self):
"""Clear cache before each test"""
import src.server.services.llm_provider_service as llm_module
llm_module._settings_cache.clear()
yield
llm_module._settings_cache.clear()
@pytest.fixture
def mock_credential_service(self):
"""Mock credential service"""
mock_service = MagicMock()
mock_service.get_active_provider = AsyncMock()
mock_service.get_credentials_by_category = AsyncMock()
mock_service._get_provider_api_key = AsyncMock()
mock_service._get_provider_base_url = MagicMock()
return mock_service
@pytest.fixture
def openai_provider_config(self):
"""Standard OpenAI provider config"""
return {
"provider": "openai",
"api_key": "test-openai-key",
"base_url": None,
"chat_model": "gpt-4.1-nano",
"embedding_model": "text-embedding-3-small",
}
@pytest.fixture
def ollama_provider_config(self):
"""Standard Ollama provider config"""
return {
"provider": "ollama",
"api_key": "ollama",
"base_url": "http://localhost:11434/v1",
"chat_model": "llama2",
"embedding_model": "nomic-embed-text",
}
@pytest.fixture
def google_provider_config(self):
"""Standard Google provider config"""
return {
"provider": "google",
"api_key": "test-google-key",
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
"chat_model": "gemini-pro",
"embedding_model": "text-embedding-004",
}
@pytest.mark.asyncio
async def test_get_llm_client_openai_success(
self, mock_credential_service, openai_provider_config
):
"""Test successful OpenAI client creation"""
mock_credential_service.get_active_provider.return_value = openai_provider_config
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_openai.return_value = mock_client
async with get_llm_client() as client:
assert client == mock_client
mock_openai.assert_called_once_with(api_key="test-openai-key")
# Verify provider config was fetched
mock_credential_service.get_active_provider.assert_called_once_with("llm")
@pytest.mark.asyncio
async def test_get_llm_client_ollama_success(
self, mock_credential_service, ollama_provider_config
):
"""Test successful Ollama client creation"""
mock_credential_service.get_active_provider.return_value = ollama_provider_config
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_openai.return_value = mock_client
async with get_llm_client() as client:
assert client == mock_client
mock_openai.assert_called_once_with(
api_key="ollama", base_url="http://localhost:11434/v1"
)
@pytest.mark.asyncio
async def test_get_llm_client_google_success(
self, mock_credential_service, google_provider_config
):
"""Test successful Google client creation"""
mock_credential_service.get_active_provider.return_value = google_provider_config
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_openai.return_value = mock_client
async with get_llm_client() as client:
assert client == mock_client
mock_openai.assert_called_once_with(
api_key="test-google-key",
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
@pytest.mark.asyncio
async def test_get_llm_client_with_provider_override(self, mock_credential_service):
"""Test client creation with explicit provider override (OpenAI)"""
mock_credential_service._get_provider_api_key.return_value = "override-key"
mock_credential_service.get_credentials_by_category.return_value = {"LLM_BASE_URL": ""}
mock_credential_service._get_provider_base_url.return_value = None
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_openai.return_value = mock_client
async with get_llm_client(provider="openai") as client:
assert client == mock_client
mock_openai.assert_called_once_with(api_key="override-key")
# Verify explicit provider API key was requested
mock_credential_service._get_provider_api_key.assert_called_once_with("openai")
@pytest.mark.asyncio
async def test_get_llm_client_use_embedding_provider(self, mock_credential_service):
"""Test client creation with embedding provider preference"""
embedding_config = {
"provider": "openai",
"api_key": "embedding-key",
"base_url": None,
"chat_model": "gpt-4",
"embedding_model": "text-embedding-3-large",
}
mock_credential_service.get_active_provider.return_value = embedding_config
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_openai.return_value = mock_client
async with get_llm_client(use_embedding_provider=True) as client:
assert client == mock_client
mock_openai.assert_called_once_with(api_key="embedding-key")
# Verify embedding provider was requested
mock_credential_service.get_active_provider.assert_called_once_with("embedding")
@pytest.mark.asyncio
async def test_get_llm_client_missing_openai_key(self, mock_credential_service):
"""Test error handling when OpenAI API key is missing"""
config_without_key = {
"provider": "openai",
"api_key": None,
"base_url": None,
"chat_model": "gpt-4",
"embedding_model": "text-embedding-3-small",
}
mock_credential_service.get_active_provider.return_value = config_without_key
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
with pytest.raises(ValueError, match="OpenAI API key not found"):
async with get_llm_client():
pass
@pytest.mark.asyncio
async def test_get_llm_client_missing_google_key(self, mock_credential_service):
"""Test error handling when Google API key is missing"""
config_without_key = {
"provider": "google",
"api_key": None,
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
"chat_model": "gemini-pro",
"embedding_model": "text-embedding-004",
}
mock_credential_service.get_active_provider.return_value = config_without_key
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
with pytest.raises(ValueError, match="Google API key not found"):
async with get_llm_client():
pass
@pytest.mark.asyncio
async def test_get_llm_client_unsupported_provider_error(self, mock_credential_service):
"""Test error when unsupported provider is configured"""
unsupported_config = {
"provider": "unsupported",
"api_key": "some-key",
"base_url": None,
"chat_model": "some-model",
"embedding_model": "",
}
mock_credential_service.get_active_provider.return_value = unsupported_config
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
with pytest.raises(ValueError, match="Unsupported LLM provider: unsupported"):
async with get_llm_client():
pass
@pytest.mark.asyncio
async def test_get_llm_client_with_unsupported_provider_override(self, mock_credential_service):
"""Test error when unsupported provider is explicitly requested"""
mock_credential_service._get_provider_api_key.return_value = "some-key"
mock_credential_service.get_credentials_by_category.return_value = {}
mock_credential_service._get_provider_base_url.return_value = None
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
with pytest.raises(ValueError, match="Unsupported LLM provider: custom-unsupported"):
async with get_llm_client(provider="custom-unsupported"):
pass
@pytest.mark.asyncio
async def test_get_embedding_model_openai_success(
self, mock_credential_service, openai_provider_config
):
"""Test getting embedding model for OpenAI provider"""
mock_credential_service.get_active_provider.return_value = openai_provider_config
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
model = await get_embedding_model()
assert model == "text-embedding-3-small"
mock_credential_service.get_active_provider.assert_called_once_with("embedding")
@pytest.mark.asyncio
async def test_get_embedding_model_ollama_success(
self, mock_credential_service, ollama_provider_config
):
"""Test getting embedding model for Ollama provider"""
mock_credential_service.get_active_provider.return_value = ollama_provider_config
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
model = await get_embedding_model()
assert model == "nomic-embed-text"
@pytest.mark.asyncio
async def test_get_embedding_model_google_success(
self, mock_credential_service, google_provider_config
):
"""Test getting embedding model for Google provider"""
mock_credential_service.get_active_provider.return_value = google_provider_config
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
model = await get_embedding_model()
assert model == "text-embedding-004"
@pytest.mark.asyncio
async def test_get_embedding_model_with_provider_override(self, mock_credential_service):
"""Test getting embedding model with provider override"""
rag_settings = {"EMBEDDING_MODEL": "custom-embedding-model"}
mock_credential_service.get_credentials_by_category.return_value = rag_settings
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
model = await get_embedding_model(provider="custom-provider")
assert model == "custom-embedding-model"
mock_credential_service.get_credentials_by_category.assert_called_once_with(
"rag_strategy"
)
@pytest.mark.asyncio
async def test_get_embedding_model_custom_model_override(self, mock_credential_service):
"""Test custom embedding model override"""
config_with_custom = {
"provider": "openai",
"api_key": "test-key",
"base_url": None,
"chat_model": "gpt-4",
"embedding_model": "text-embedding-custom-large",
}
mock_credential_service.get_active_provider.return_value = config_with_custom
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
model = await get_embedding_model()
assert model == "text-embedding-custom-large"
@pytest.mark.asyncio
async def test_get_embedding_model_error_fallback(self, mock_credential_service):
"""Test fallback when error occurs getting embedding model"""
mock_credential_service.get_active_provider.side_effect = Exception("Database error")
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
model = await get_embedding_model()
# Should fall back to the OpenAI default embedding model
assert model == "text-embedding-3-small"
def test_cache_functionality(self):
"""Test settings cache functionality"""
# Test setting and getting cache
test_value = {"test": "data"}
_set_cached_settings("test_key", test_value)
cached_result = _get_cached_settings("test_key")
assert cached_result == test_value
# Test cache expiry (would require time manipulation in real test)
# For now just test that non-existent key returns None
assert _get_cached_settings("non_existent") is None
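# Sketch of a future expiry test (assuming the cache timestamps entries, e.g. via time.time()):
# monkeypatch the clock or use a freezing helper, advance past the TTL, then assert that
# _get_cached_settings returns None for the expired key. The TTL constant is internal to the
# module, so it is deliberately not asserted here.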
@pytest.mark.asyncio
async def test_cache_usage_in_get_llm_client(
self, mock_credential_service, openai_provider_config
):
"""Test that cache is used to avoid repeated credential service calls"""
mock_credential_service.get_active_provider.return_value = openai_provider_config
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_openai.return_value = mock_client
# First call should hit the credential service
async with get_llm_client():
pass
# Second call should use cache
async with get_llm_client():
pass
# Should only call get_active_provider once due to caching
assert mock_credential_service.get_active_provider.call_count == 1
def test_deprecated_functions_removed(self):
"""Test that deprecated sync functions are no longer available"""
import src.server.services.llm_provider_service as llm_module
# These functions should no longer exist
assert not hasattr(llm_module, "get_llm_client_sync")
assert not hasattr(llm_module, "get_embedding_model_sync")
assert not hasattr(llm_module, "_get_active_provider_sync")
# The async versions should be the primary functions
assert hasattr(llm_module, "get_llm_client")
assert hasattr(llm_module, "get_embedding_model")
@pytest.mark.asyncio
async def test_context_manager_cleanup(self, mock_credential_service, openai_provider_config):
"""Test that async context manager properly handles cleanup"""
mock_credential_service.get_active_provider.return_value = openai_provider_config
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_openai.return_value = mock_client
client_ref = None
async with get_llm_client() as client:
client_ref = client
assert client == mock_client
# After context manager exits, should still have reference to client
assert client_ref == mock_client
@pytest.mark.asyncio
async def test_multiple_providers_in_sequence(self, mock_credential_service):
"""Test creating clients for different providers in sequence"""
configs = [
{"provider": "openai", "api_key": "openai-key", "base_url": None},
{"provider": "ollama", "api_key": "ollama", "base_url": "http://localhost:11434/v1"},
{
"provider": "google",
"api_key": "google-key",
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
},
]
with patch(
"src.server.services.llm_provider_service.credential_service", mock_credential_service
):
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_openai.return_value = mock_client
for config in configs:
# Clear cache between iterations to force fresh credential service calls
import src.server.services.llm_provider_service as llm_module
llm_module._settings_cache.clear()
mock_credential_service.get_active_provider.return_value = config
async with get_llm_client() as client:
assert client == mock_client
# Should have been called once for each provider
assert mock_credential_service.get_active_provider.call_count == 3

@@ -0,0 +1,98 @@
"""Business logic tests - Test core business rules and logic."""
def test_task_status_transitions(client):
"""Test task status update endpoint."""
# Test status update endpoint exists
response = client.patch("/api/tasks/test-id", json={"status": "doing"})
assert response.status_code in [200, 400, 404, 405, 422, 500]
def test_progress_calculation(client):
"""Test project progress endpoint."""
response = client.get("/api/projects/test-id/progress")
assert response.status_code in [200, 404, 500]
def test_rate_limiting(client):
"""Test that API handles multiple requests gracefully."""
# Make several requests
for i in range(5):
response = client.get("/api/projects")
assert response.status_code in [200, 429, 500] # 500 is OK in test environment
def test_data_validation(client):
"""Test input validation on project creation."""
# Empty title
response = client.post("/api/projects", json={"title": ""})
assert response.status_code in [400, 422]
# Missing required fields
response = client.post("/api/projects", json={})
assert response.status_code in [400, 422]
# Valid data
response = client.post("/api/projects", json={"title": "Valid Project"})
assert response.status_code in [200, 201, 422]
def test_permission_checks(client):
"""Test authentication on protected endpoints."""
# Delete without auth
response = client.delete("/api/projects/test-id")
assert response.status_code in [200, 204, 401, 403, 404, 500]
def test_crawl_depth_limits(client):
"""Test crawl depth validation."""
# Too deep
response = client.post(
"/api/knowledge/crawl", json={"url": "https://example.com", "max_depth": 100}
)
assert response.status_code in [200, 400, 404, 422]
# Valid depth
response = client.post(
"/api/knowledge/crawl", json={"url": "https://example.com", "max_depth": 2}
)
assert response.status_code in [200, 201, 400, 404, 422, 500]
def test_document_chunking(client):
"""Test document chunking endpoint."""
response = client.post(
"/api/knowledge/documents/chunk", json={"content": "x" * 1000, "chunk_size": 500}
)
assert response.status_code in [200, 400, 404, 422, 500]
def test_embedding_generation(client):
"""Test embedding generation endpoint."""
response = client.post("/api/knowledge/embeddings", json={"texts": ["Test text for embedding"]})
assert response.status_code in [200, 400, 404, 422, 500]
def test_source_management(client):
"""Test knowledge source management."""
# Create source
response = client.post(
"/api/knowledge/sources",
json={"name": "Test Source", "url": "https://example.com", "type": "documentation"},
)
assert response.status_code in [200, 201, 400, 404, 422, 500]
# List sources
response = client.get("/api/knowledge/sources")
assert response.status_code in [200, 404, 500]
def test_version_control(client):
"""Test document versioning."""
# Create version
response = client.post("/api/documents/test-id/versions", json={"content": "Version 1 content"})
assert response.status_code in [200, 201, 404, 422, 500]
# List versions
response = client.get("/api/documents/test-id/versions")
assert response.status_code in [200, 404, 500]

@@ -0,0 +1,476 @@
"""
Isolated Tests for Async Crawl Orchestration Service
Tests core functionality without circular import dependencies.
"""
import asyncio
from typing import Any
from unittest.mock import MagicMock
import pytest
class MockCrawlOrchestrationService:
"""Mock version of CrawlOrchestrationService for isolated testing"""
def __init__(self, crawler=None, supabase_client=None, progress_id=None):
self.crawler = crawler
self.supabase_client = supabase_client
self.progress_id = progress_id
self.progress_state = {}
self._cancelled = False
def cancel(self):
self._cancelled = True
def is_cancelled(self) -> bool:
return self._cancelled
def _check_cancellation(self):
if self._cancelled:
raise Exception("CrawlCancelledException: Operation was cancelled")
def _is_documentation_site(self, url: str) -> bool:
"""Simple documentation site detection"""
doc_indicators = ["/docs/", "docs.", ".readthedocs.io", "/documentation/"]
return any(indicator in url.lower() for indicator in doc_indicators)
async def _create_crawl_progress_callback(self, base_status: str):
"""Create async progress callback"""
async def callback(status: str, percentage: int, message: str, **kwargs):
if self.progress_id:
self.progress_state.update({
"status": status,
"percentage": percentage,
"log": message,
})
return callback
async def _crawl_by_url_type(self, url: str, request: dict[str, Any]) -> tuple:
"""Mock URL type detection and crawling"""
# Mock different URL types
if url.endswith(".txt"):
return [{"url": url, "markdown": "Text content", "title": "Text File"}], "text_file"
elif "sitemap" in url:
return [
{"url": f"{url}/page1", "markdown": "Page 1 content", "title": "Page 1"},
{"url": f"{url}/page2", "markdown": "Page 2 content", "title": "Page 2"},
], "sitemap"
else:
return [{"url": url, "markdown": "Web content", "title": "Web Page"}], "webpage"
async def _process_and_store_documents(
self,
crawl_results: list[dict],
request: dict[str, Any],
crawl_type: str,
original_source_id: str,
) -> dict[str, Any]:
"""Mock document processing and storage"""
# Check for cancellation
self._check_cancellation()
# Simulate chunking
chunk_count = len(crawl_results) * 3 # Assume 3 chunks per document
total_word_count = chunk_count * 50 # Assume 50 words per chunk
# Build url_to_full_document mapping
url_to_full_document = {}
for doc in crawl_results:
url_to_full_document[doc["url"]] = doc.get("markdown", "")
return {
"chunk_count": chunk_count,
"total_word_count": total_word_count,
"url_to_full_document": url_to_full_document,
}
async def _extract_and_store_code_examples(
self, crawl_results: list[dict], url_to_full_document: dict[str, str]
) -> int:
"""Mock code examples extraction"""
# Count code blocks in markdown
code_examples = 0
for doc in crawl_results:
content = doc.get("markdown", "")
code_examples += content.count("```")
return code_examples // 2 # Each code block has opening and closing
async def _async_orchestrate_crawl(
self, request: dict[str, Any], task_id: str
) -> dict[str, Any]:
"""Mock async orchestration"""
try:
self._check_cancellation()
url = str(request.get("url", ""))
# Mock crawl by URL type
crawl_results, crawl_type = await self._crawl_by_url_type(url, request)
self._check_cancellation()
if not crawl_results:
raise ValueError("No content was crawled from the provided URL")
# Mock document processing
from urllib.parse import urlparse
parsed_url = urlparse(url)
source_id = parsed_url.netloc or parsed_url.path
storage_results = await self._process_and_store_documents(
crawl_results, request, crawl_type, source_id
)
self._check_cancellation()
# Mock code extraction
code_examples_count = 0
if request.get("enable_code_extraction", False):
code_examples_count = await self._extract_and_store_code_examples(
crawl_results, storage_results.get("url_to_full_document", {})
)
return {
"success": True,
"crawl_type": crawl_type,
"chunk_count": storage_results["chunk_count"],
"total_word_count": storage_results["total_word_count"],
"code_examples_stored": code_examples_count,
"processed_pages": len(crawl_results),
"total_pages": len(crawl_results),
}
except Exception as e:
error_msg = str(e)
if "CrawlCancelledException" in error_msg:
return {
"success": False,
"error": error_msg,
"cancelled": True,
"chunk_count": 0,
"code_examples_stored": 0,
}
else:
return {
"success": False,
"error": error_msg,
"cancelled": False,
"chunk_count": 0,
"code_examples_stored": 0,
}
async def orchestrate_crawl(self, request: dict[str, Any]) -> dict[str, Any]:
"""Mock main orchestration entry point"""
import uuid
task_id = str(uuid.uuid4())
# Start async orchestration task (would normally be background)
result = await self._async_orchestrate_crawl(request, task_id)
return {
"task_id": task_id,
"status": "started" if result.get("success") else "failed",
"message": f"Crawl operation for {request.get('url')}",
"progress_id": self.progress_id,
}
class TestAsyncCrawlOrchestration:
"""Test suite for async crawl orchestration behavior"""
@pytest.fixture
def orchestration_service(self):
"""Create mock orchestration service"""
return MockCrawlOrchestrationService(
crawler=MagicMock(), supabase_client=MagicMock(), progress_id="test-progress-123"
)
@pytest.fixture
def sample_request(self):
"""Sample crawl request"""
return {
"url": "https://example.com/docs",
"max_depth": 2,
"knowledge_type": "technical",
"tags": ["test"],
"enable_code_extraction": True,
}
@pytest.mark.asyncio
async def test_async_orchestrate_crawl_success(self, orchestration_service, sample_request):
"""Test successful async orchestration"""
result = await orchestration_service._async_orchestrate_crawl(sample_request, "task-123")
assert result["success"] is True
assert result["crawl_type"] == "webpage"
assert result["chunk_count"] > 0
assert result["total_word_count"] > 0
assert result["processed_pages"] == 1
@pytest.mark.asyncio
async def test_async_orchestrate_crawl_with_code_extraction(self, orchestration_service):
"""Test orchestration with code extraction enabled"""
request = {"url": "https://docs.example.com/api", "enable_code_extraction": True}
result = await orchestration_service._async_orchestrate_crawl(request, "task-456")
assert result["success"] is True
assert "code_examples_stored" in result
assert result["code_examples_stored"] >= 0
@pytest.mark.asyncio
async def test_crawl_by_url_type_text_file(self, orchestration_service):
"""Test text file URL type detection"""
crawl_results, crawl_type = await orchestration_service._crawl_by_url_type(
"https://example.com/readme.txt", {"max_depth": 1}
)
assert crawl_type == "text_file"
assert len(crawl_results) == 1
assert crawl_results[0]["url"] == "https://example.com/readme.txt"
@pytest.mark.asyncio
async def test_crawl_by_url_type_sitemap(self, orchestration_service):
"""Test sitemap URL type detection"""
crawl_results, crawl_type = await orchestration_service._crawl_by_url_type(
"https://example.com/sitemap.xml", {"max_depth": 2}
)
assert crawl_type == "sitemap"
assert len(crawl_results) == 2
@pytest.mark.asyncio
async def test_crawl_by_url_type_regular_webpage(self, orchestration_service):
"""Test regular webpage crawling"""
crawl_results, crawl_type = await orchestration_service._crawl_by_url_type(
"https://example.com/blog/post", {"max_depth": 1}
)
assert crawl_type == "webpage"
assert len(crawl_results) == 1
@pytest.mark.asyncio
async def test_process_and_store_documents(self, orchestration_service):
"""Test document processing and storage"""
crawl_results = [
{"url": "https://example.com/page1", "markdown": "Content 1", "title": "Page 1"},
{"url": "https://example.com/page2", "markdown": "Content 2", "title": "Page 2"},
]
request = {"knowledge_type": "technical", "tags": ["test"]}
result = await orchestration_service._process_and_store_documents(
crawl_results, request, "webpage", "example.com"
)
assert "chunk_count" in result
assert "total_word_count" in result
assert "url_to_full_document" in result
assert result["chunk_count"] == 6 # 2 docs * 3 chunks each
assert len(result["url_to_full_document"]) == 2
@pytest.mark.asyncio
async def test_extract_and_store_code_examples(self, orchestration_service):
"""Test code examples extraction"""
crawl_results = [
{
"url": "https://example.com/api",
"markdown": '# API\n\n```python\ndef hello():\n return "world"\n```\n\n```javascript\nconsole.log("hello");\n```',
"title": "API Docs",
}
]
url_to_full_document = {"https://example.com/api": crawl_results[0]["markdown"]}
result = await orchestration_service._extract_and_store_code_examples(
crawl_results, url_to_full_document
)
assert result == 2 # Two code blocks found
@pytest.mark.asyncio
async def test_cancellation_during_orchestration(self, orchestration_service, sample_request):
"""Test cancellation handling"""
# Cancel before starting
orchestration_service.cancel()
result = await orchestration_service._async_orchestrate_crawl(sample_request, "task-cancel")
assert result["success"] is False
assert result["cancelled"] is True
assert "error" in result
@pytest.mark.asyncio
async def test_cancellation_during_document_processing(self, orchestration_service):
"""Test cancellation during document processing"""
crawl_results = [{"url": "https://example.com", "markdown": "Content"}]
request = {"knowledge_type": "technical"}
# Cancel during processing
orchestration_service.cancel()
with pytest.raises(Exception, match="CrawlCancelledException"):
await orchestration_service._process_and_store_documents(
crawl_results, request, "webpage", "example.com"
)
@pytest.mark.asyncio
async def test_error_handling_in_orchestration(self, orchestration_service):
"""Test error handling during orchestration"""
# Override the method to raise an error
async def failing_crawl_by_url_type(url, request):
raise ValueError("Simulated crawl failure")
orchestration_service._crawl_by_url_type = failing_crawl_by_url_type
request = {"url": "https://example.com", "enable_code_extraction": False}
result = await orchestration_service._async_orchestrate_crawl(request, "task-error")
assert result["success"] is False
assert result["cancelled"] is False
assert "error" in result
def test_documentation_site_detection(self, orchestration_service):
"""Test documentation site URL detection"""
# Test documentation sites
assert orchestration_service._is_documentation_site("https://docs.python.org")
assert orchestration_service._is_documentation_site(
"https://react.dev/docs/getting-started"
)
assert orchestration_service._is_documentation_site(
"https://project.readthedocs.io/en/latest/"
)
assert orchestration_service._is_documentation_site("https://example.com/documentation/api")
# Test non-documentation sites
assert not orchestration_service._is_documentation_site("https://github.com/user/repo")
assert not orchestration_service._is_documentation_site("https://example.com/blog")
assert not orchestration_service._is_documentation_site("https://news.example.com")
def test_cancellation_functionality(self, orchestration_service):
"""Test cancellation state management"""
# Initially not cancelled
assert not orchestration_service.is_cancelled()
# Cancel and verify
orchestration_service.cancel()
assert orchestration_service.is_cancelled()
# Check cancellation raises exception
with pytest.raises(Exception, match="CrawlCancelledException"):
orchestration_service._check_cancellation()
@pytest.mark.asyncio
async def test_progress_callback_creation(self, orchestration_service):
"""Test progress callback functionality"""
callback = await orchestration_service._create_crawl_progress_callback("crawling")
# Execute callback
await callback("test_status", 50, "Test message")
# Verify progress state was updated
assert orchestration_service.progress_state["status"] == "test_status"
assert orchestration_service.progress_state["percentage"] == 50
assert orchestration_service.progress_state["log"] == "Test message"
@pytest.mark.asyncio
async def test_main_orchestrate_crawl_entry_point(self, orchestration_service, sample_request):
"""Test main orchestration entry point"""
result = await orchestration_service.orchestrate_crawl(sample_request)
assert "task_id" in result
assert "status" in result
assert "progress_id" in result
assert result["progress_id"] == "test-progress-123"
@pytest.mark.asyncio
async def test_concurrent_operations(self):
"""Test multiple concurrent orchestrations"""
service1 = MockCrawlOrchestrationService(progress_id="progress-1")
service2 = MockCrawlOrchestrationService(progress_id="progress-2")
request1 = {"url": "https://site1.com", "enable_code_extraction": False}
request2 = {"url": "https://site2.com", "enable_code_extraction": True}
# Run concurrently
results = await asyncio.gather(
service1._async_orchestrate_crawl(request1, "task-1"),
service2._async_orchestrate_crawl(request2, "task-2"),
)
assert len(results) == 2
assert all(result["success"] for result in results)
assert results[0]["code_examples_stored"] == 0 # Code extraction disabled
assert results[1]["code_examples_stored"] >= 0 # Code extraction enabled
class TestAsyncBehaviors:
"""Test async-specific behaviors and patterns"""
@pytest.mark.asyncio
async def test_async_method_chaining(self):
"""Test that async methods properly chain together"""
service = MockCrawlOrchestrationService()
# This chain should complete without blocking
crawl_results, crawl_type = await service._crawl_by_url_type(
"https://example.com", {"max_depth": 1}
)
storage_results = await service._process_and_store_documents(
crawl_results, {"knowledge_type": "technical"}, crawl_type, "example.com"
)
code_count = await service._extract_and_store_code_examples(
crawl_results, storage_results["url_to_full_document"]
)
# All operations should complete successfully
assert crawl_type == "webpage"
assert storage_results["chunk_count"] > 0
assert code_count >= 0
@pytest.mark.asyncio
async def test_asyncio_cancellation_propagation(self):
"""Test that asyncio cancellation properly propagates"""
service = MockCrawlOrchestrationService()
async def long_running_operation():
await asyncio.sleep(0.1) # Simulate work
return await service._async_orchestrate_crawl(
{"url": "https://example.com"}, "task-123"
)
# Start task and cancel it
task = asyncio.create_task(long_running_operation())
await asyncio.sleep(0.01) # Let it start
task.cancel()
# Should raise CancelledError
with pytest.raises(asyncio.CancelledError):
await task
@pytest.mark.asyncio
async def test_no_blocking_operations(self):
"""Test that operations don't block the event loop"""
service = MockCrawlOrchestrationService()
# Start multiple operations concurrently
tasks = []
for i in range(5):
task = service._async_orchestrate_crawl({"url": f"https://example{i}.com"}, f"task-{i}")
tasks.append(task)
# All should complete without blocking
results = await asyncio.gather(*tasks)
assert len(results) == 5
assert all(result["success"] for result in results)

@@ -0,0 +1,332 @@
"""
Tests for embedding service to ensure no zero embeddings are returned.
These tests verify that the embedding service raises appropriate exceptions
instead of returning zero embeddings, following the "fail fast and loud" principle.
"""
from unittest.mock import AsyncMock, Mock, patch
import openai
import pytest
from src.server.services.embeddings.embedding_exceptions import (
EmbeddingAPIError,
EmbeddingQuotaExhaustedError,
EmbeddingRateLimitError,
)
from src.server.services.embeddings.embedding_service import (
EmbeddingBatchResult,
create_embedding,
create_embeddings_batch,
)
class TestNoZeroEmbeddings:
"""Test that no zero embeddings are ever returned."""
# Note: Removed test_sync_from_async_context_raises_exception
# as sync versions no longer exist - everything is async-only now
@pytest.mark.asyncio
async def test_async_quota_exhausted_returns_failure(self) -> None:
"""Test that quota exhaustion returns failure result instead of zeros."""
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_client:
# Mock the client to raise quota error
mock_ctx = AsyncMock()
mock_ctx.__aenter__.return_value.embeddings.create.side_effect = openai.RateLimitError(
"insufficient_quota: You have exceeded your quota", response=Mock(), body=None
)
mock_client.return_value = mock_ctx
# Single embedding still raises for backward compatibility
with pytest.raises(EmbeddingQuotaExhaustedError) as exc_info:
await create_embedding("test text")
assert "quota exhausted" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_async_rate_limit_raises_exception(self) -> None:
"""Test that rate limit errors raise exception after retries."""
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_client:
# Mock the client to raise rate limit error
mock_ctx = AsyncMock()
mock_ctx.__aenter__.return_value.embeddings.create.side_effect = openai.RateLimitError(
"rate_limit_exceeded: Too many requests", response=Mock(), body=None
)
mock_client.return_value = mock_ctx
with pytest.raises(EmbeddingRateLimitError) as exc_info:
await create_embedding("test text")
assert "rate limit" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_async_api_error_raises_exception(self) -> None:
"""Test that API errors raise exception instead of returning zeros."""
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_client:
# Mock the client to raise generic error
mock_ctx = AsyncMock()
mock_ctx.__aenter__.return_value.embeddings.create.side_effect = Exception(
"Network error"
)
mock_client.return_value = mock_ctx
with pytest.raises(EmbeddingAPIError) as exc_info:
await create_embedding("test text")
assert "failed to create embedding" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_batch_handles_partial_failures(self) -> None:
"""Test that batch processing can handle partial failures gracefully."""
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_client:
# Mock successful response for first batch, failure for second
mock_ctx = AsyncMock()
mock_response = Mock()
mock_response.data = [Mock(embedding=[0.1] * 1536), Mock(embedding=[0.2] * 1536)]
# First call succeeds, second fails
mock_ctx.__aenter__.return_value.embeddings.create.side_effect = [
mock_response,
Exception("API Error"),
]
mock_client.return_value = mock_ctx
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
new_callable=AsyncMock,
return_value="text-embedding-ada-002",
):
# Mock credential service to return batch size of 2
with patch(
"src.server.services.embeddings.embedding_service.credential_service.get_credentials_by_category",
new_callable=AsyncMock,
return_value={"EMBEDDING_BATCH_SIZE": "2"},
):
# Process 4 texts (batch size will be 2)
texts = ["text1", "text2", "text3", "text4"]
result = await create_embeddings_batch(texts)
# Check result structure
assert isinstance(result, EmbeddingBatchResult)
assert result.success_count == 2 # First batch succeeded
assert result.failure_count == 2 # Second batch failed
assert len(result.embeddings) == 2
assert len(result.failed_items) == 2
# Verify no zero embeddings were created
for embedding in result.embeddings:
assert not all(v == 0.0 for v in embedding)
@pytest.mark.asyncio
async def test_configurable_embedding_dimensions(self) -> None:
"""Test that embedding dimensions can be configured via settings."""
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_client:
# Mock successful response
mock_ctx = AsyncMock()
mock_create = AsyncMock()
mock_ctx.__aenter__.return_value.embeddings.create = mock_create
# Setup mock response
mock_response = Mock()
mock_response.data = [Mock(embedding=[0.1] * 3072)] # Different dimensions
mock_create.return_value = mock_response
mock_client.return_value = mock_ctx
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
new_callable=AsyncMock,
return_value="text-embedding-3-large",
):
# Mock credential service to return custom dimensions
with patch(
"src.server.services.embeddings.embedding_service.credential_service.get_credentials_by_category",
new_callable=AsyncMock,
return_value={"EMBEDDING_DIMENSIONS": "3072"},
):
result = await create_embeddings_batch(["test text"])
# Verify the dimensions parameter was passed correctly
mock_create.assert_called_once()
call_args = mock_create.call_args
assert call_args.kwargs["dimensions"] == 3072
# Verify result
assert result.success_count == 1
assert len(result.embeddings[0]) == 3072
@pytest.mark.asyncio
async def test_default_embedding_dimensions(self) -> None:
"""Test that default dimensions (1536) are used when not configured."""
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_client:
# Mock successful response
mock_ctx = AsyncMock()
mock_create = AsyncMock()
mock_ctx.__aenter__.return_value.embeddings.create = mock_create
# Setup mock response with default dimensions
mock_response = Mock()
mock_response.data = [Mock(embedding=[0.1] * 1536)]
mock_create.return_value = mock_response
mock_client.return_value = mock_ctx
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
new_callable=AsyncMock,
return_value="text-embedding-3-small",
):
# Mock credential service to return empty settings (no dimensions specified)
with patch(
"src.server.services.embeddings.embedding_service.credential_service.get_credentials_by_category",
new_callable=AsyncMock,
return_value={},
):
result = await create_embeddings_batch(["test text"])
# Verify the default dimensions parameter was used
mock_create.assert_called_once()
call_args = mock_create.call_args
assert call_args.kwargs["dimensions"] == 1536
# Verify result
assert result.success_count == 1
assert len(result.embeddings[0]) == 1536
@pytest.mark.asyncio
async def test_batch_quota_exhausted_stops_process(self) -> None:
"""Test that quota exhaustion stops processing remaining batches."""
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_client:
# Mock quota exhaustion
mock_ctx = AsyncMock()
mock_ctx.__aenter__.return_value.embeddings.create.side_effect = openai.RateLimitError(
"insufficient_quota: Quota exceeded", response=Mock(), body=None
)
mock_client.return_value = mock_ctx
with patch(
"src.server.services.embeddings.embedding_service.get_embedding_model",
new_callable=AsyncMock,
return_value="text-embedding-ada-002",
):
texts = ["text1", "text2", "text3", "text4"]
result = await create_embeddings_batch(texts)
# All should fail due to quota
assert result.success_count == 0
assert result.failure_count == 4
assert len(result.embeddings) == 0
assert all("quota" in item["error"].lower() for item in result.failed_items)
@pytest.mark.asyncio
async def test_no_zero_vectors_in_results(self) -> None:
"""Test that no function ever returns a zero vector [0.0] * 1536."""
# This is a meta-test to ensure our implementation never creates zero vectors
# Helper to check if a value is a zero embedding
def is_zero_embedding(value):
if not isinstance(value, list):
return False
if len(value) != 1536:
return False
return all(v == 0.0 for v in value)
# Test data that should never produce zero embeddings
test_text = "This is a test"
# Test: Batch function with error should return failure result, not zeros
with patch(
"src.server.services.embeddings.embedding_service.get_llm_client"
) as mock_client:
# Mock the client to raise an error
mock_ctx = AsyncMock()
mock_ctx.__aenter__.return_value.embeddings.create.side_effect = Exception("Test error")
mock_client.return_value = mock_ctx
result = await create_embeddings_batch([test_text])
# Should return result with failures, not zeros
assert isinstance(result, EmbeddingBatchResult)
assert len(result.embeddings) == 0
assert result.failure_count == 1
# Verify no zero embeddings in the result
for embedding in result.embeddings:
assert not is_zero_embedding(embedding)
class TestEmbeddingBatchResult:
"""Test the EmbeddingBatchResult dataclass."""
def test_batch_result_initialization(self) -> None:
"""Test that EmbeddingBatchResult initializes correctly."""
result = EmbeddingBatchResult()
assert result.success_count == 0
assert result.failure_count == 0
assert result.embeddings == []
assert result.failed_items == []
assert not result.has_failures
def test_batch_result_add_success(self) -> None:
"""Test adding successful embeddings."""
result = EmbeddingBatchResult()
embedding = [0.1] * 1536
text = "test text"
result.add_success(embedding, text)
assert result.success_count == 1
assert result.failure_count == 0
assert len(result.embeddings) == 1
assert result.embeddings[0] == embedding
assert result.texts_processed[0] == text
assert not result.has_failures
def test_batch_result_add_failure(self) -> None:
"""Test adding failed items."""
result = EmbeddingBatchResult()
error = EmbeddingAPIError("Test error", text_preview="test")
result.add_failure("test text", error, batch_index=0)
assert result.success_count == 0
assert result.failure_count == 1
assert len(result.failed_items) == 1
assert result.has_failures
failed_item = result.failed_items[0]
assert failed_item["error"] == "Test error"
assert failed_item["error_type"] == "EmbeddingAPIError"
# batch_index is carried through the error's to_dict() payload
assert "batch_index" in failed_item # Just check it exists
def test_batch_result_mixed_results(self) -> None:
"""Test batch result with both successes and failures."""
result = EmbeddingBatchResult()
# Add successes
result.add_success([0.1] * 1536, "text1")
result.add_success([0.2] * 1536, "text2")
# Add failures
result.add_failure("text3", Exception("Error 1"), 1)
result.add_failure("text4", Exception("Error 2"), 1)
assert result.success_count == 2
assert result.failure_count == 2
assert result.total_requested == 4
assert result.has_failures
assert len(result.embeddings) == 2
assert len(result.failed_items) == 2

@@ -0,0 +1,213 @@
"""
Tests for keyword extraction and improved hybrid search
"""
import pytest
from src.server.services.search.keyword_extractor import (
KeywordExtractor,
build_search_terms,
extract_keywords,
)
class TestKeywordExtractor:
"""Test keyword extraction functionality"""
@pytest.fixture
def extractor(self):
return KeywordExtractor()
def test_simple_keyword_extraction(self, extractor):
"""Test extraction from simple queries"""
query = "Supabase authentication"
keywords = extractor.extract_keywords(query)
assert "supabase" in keywords
assert "authentication" in keywords
assert len(keywords) >= 2
def test_complex_query_extraction(self, extractor):
"""Test extraction from complex queries"""
query = "Supabase auth flow best practices"
keywords = extractor.extract_keywords(query)
assert "supabase" in keywords
assert "auth" in keywords
assert "flow" in keywords
assert "best_practices" in keywords or "practices" in keywords
def test_stop_word_filtering(self, extractor):
"""Test that stop words are filtered out"""
query = "How to use the React component with the database"
keywords = extractor.extract_keywords(query)
# Stop words should be filtered
assert "how" not in keywords
assert "to" not in keywords
assert "the" not in keywords
assert "with" not in keywords
# Technical terms should remain
assert "react" in keywords
assert "component" in keywords
assert "database" in keywords
def test_technical_terms_preserved(self, extractor):
"""Test that technical terms are preserved"""
query = "PostgreSQL full-text search with Python API"
keywords = extractor.extract_keywords(query)
assert "postgresql" in keywords or "postgres" in keywords
assert "python" in keywords
assert "api" in keywords
def test_compound_terms(self, extractor):
"""Test compound term detection"""
query = "best practices for real-time websocket connections"
keywords = extractor.extract_keywords(query)
# Should detect compound terms
assert "best_practices" in keywords
assert "realtime" in keywords or "real-time" in keywords
assert "websocket" in keywords
def test_empty_query(self, extractor):
"""Test handling of empty query"""
keywords = extractor.extract_keywords("")
assert keywords == []
def test_query_with_only_stopwords(self, extractor):
"""Test query with only stop words"""
query = "the and with for in"
keywords = extractor.extract_keywords(query)
assert keywords == []
def test_keyword_prioritization(self, extractor):
"""Test that keywords are properly prioritized"""
query = "Python Python Django REST API framework Python"
keywords = extractor.extract_keywords(query)
# Python appears 3 times, should be prioritized
assert keywords[0] == "python"
# Technical terms should be high priority
assert "django" in keywords[:3]
assert "api" in keywords[:5] # API should be in top 5
def test_max_keywords_limit(self, extractor):
"""Test that max_keywords parameter is respected"""
query = "Python Django Flask FastAPI React Vue Angular TypeScript JavaScript HTML CSS"
keywords = extractor.extract_keywords(query, max_keywords=5)
assert len(keywords) <= 5
# Most important terms should be included
assert "python" in keywords
assert "django" in keywords
def test_min_length_filtering(self, extractor):
"""Test minimum length filtering"""
query = "a b c API JWT DB SQL"
keywords = extractor.extract_keywords(query, min_length=3)
# Single letters should be filtered
assert "a" not in keywords
assert "b" not in keywords
assert "c" not in keywords
# 3+ letter terms should remain
assert "api" in keywords
assert "jwt" in keywords
assert "sql" in keywords
class TestSearchTermBuilder:
"""Test search term building with variations"""
def test_plural_variations(self):
"""Test plural/singular variations"""
keywords = ["functions", "class", "error"]
terms = build_search_terms(keywords)
# Should include singular of "functions"
assert "function" in terms
# Should include plural of "class"
assert "classes" in terms
# Should include plural of "error"
assert "errors" in terms
def test_verb_variations(self):
"""Test verb form variations"""
keywords = ["creating", "updated", "testing"]
terms = build_search_terms(keywords)
# Should generate base forms
assert "create" in terms or "creat" in terms
assert "update" in terms or "updat" in terms
assert "test" in terms
def test_no_duplicates(self):
"""Test that duplicates are removed"""
keywords = ["test", "tests", "testing"]
terms = build_search_terms(keywords)
# Should have unique terms only
assert len(terms) == len(set(terms))
class TestIntegration:
"""Integration tests for keyword extraction in search context"""
def test_real_world_query_1(self):
"""Test with real-world query example 1"""
query = "How to implement JWT authentication in FastAPI with Supabase"
keywords = extract_keywords(query)
# Should extract the key technical terms
assert "jwt" in keywords
assert "authentication" in keywords
assert "fastapi" in keywords
assert "supabase" in keywords
# Should not include generic words ("implement" is now filtered as a technical stop word)
assert "how" not in keywords
assert "to" not in keywords
def test_real_world_query_2(self):
"""Test with real-world query example 2"""
query = "PostgreSQL full text search vs Elasticsearch performance comparison"
keywords = extract_keywords(query)
assert "postgresql" in keywords or "postgres" in keywords
assert "elasticsearch" in keywords
assert "performance" in keywords
assert "comparison" in keywords
# Should handle "full text" as compound or separate
assert "fulltext" in keywords or ("full" in keywords and "text" in keywords)
def test_real_world_query_3(self):
"""Test with real-world query example 3"""
query = "debugging async await issues in Node.js Express middleware"
keywords = extract_keywords(query)
assert "debugging" in keywords or "debug" in keywords
assert "async" in keywords
assert "await" in keywords
assert "express" in keywords
assert "middleware" in keywords
# Node.js might be split
assert "nodejs" in keywords or "node" in keywords
def test_code_related_query(self):
"""Test with code-related query"""
query = "TypeError cannot read property undefined JavaScript React hooks"
keywords = extract_keywords(query)
assert "typeerror" in keywords or "type" in keywords
assert "property" in keywords
assert "undefined" in keywords
assert "javascript" in keywords
assert "react" in keywords
assert "hooks" in keywords

@@ -0,0 +1,216 @@
"""
Tests for port configuration requirements.
This test file verifies that all services properly require environment variables
for port configuration and fail with clear error messages when not set.
"""
import os
import pytest
class TestPortConfiguration:
"""Test that services require port environment variables."""
def setup_method(self):
"""Save original environment variables before each test."""
self.original_env = os.environ.copy()
def teardown_method(self):
"""Restore original environment variables after each test."""
os.environ.clear()
os.environ.update(self.original_env)
def test_service_discovery_requires_all_ports(self):
"""Test that ServiceDiscovery requires all port environment variables."""
# Clear port environment variables
for key in ["ARCHON_SERVER_PORT", "ARCHON_MCP_PORT", "ARCHON_AGENTS_PORT"]:
os.environ.pop(key, None)
# Constructing ServiceDiscovery should fail without the port environment variables
with pytest.raises(ValueError, match="ARCHON_SERVER_PORT environment variable is required"):
from src.server.config.service_discovery import ServiceDiscovery
ServiceDiscovery()
def test_service_discovery_requires_mcp_port(self):
"""Test that ServiceDiscovery requires MCP port."""
os.environ["ARCHON_SERVER_PORT"] = "8181"
os.environ.pop("ARCHON_MCP_PORT", None)
os.environ["ARCHON_AGENTS_PORT"] = "8052"
with pytest.raises(ValueError, match="ARCHON_MCP_PORT environment variable is required"):
from src.server.config.service_discovery import ServiceDiscovery
ServiceDiscovery()
def test_service_discovery_requires_agents_port(self):
"""Test that ServiceDiscovery requires agents port."""
os.environ["ARCHON_SERVER_PORT"] = "8181"
os.environ["ARCHON_MCP_PORT"] = "8051"
os.environ.pop("ARCHON_AGENTS_PORT", None)
with pytest.raises(ValueError, match="ARCHON_AGENTS_PORT environment variable is required"):
from src.server.config.service_discovery import ServiceDiscovery
ServiceDiscovery()
def test_service_discovery_with_all_ports(self):
"""Test that ServiceDiscovery works with all ports set."""
os.environ["ARCHON_SERVER_PORT"] = "9191"
os.environ["ARCHON_MCP_PORT"] = "9051"
os.environ["ARCHON_AGENTS_PORT"] = "9052"
from src.server.config.service_discovery import ServiceDiscovery
sd = ServiceDiscovery()
assert sd.DEFAULT_PORTS["api"] == 9191
assert sd.DEFAULT_PORTS["mcp"] == 9051
assert sd.DEFAULT_PORTS["agents"] == 9052
def test_mcp_server_requires_port(self):
"""Test that MCP server requires ARCHON_MCP_PORT."""
os.environ.pop("ARCHON_MCP_PORT", None)
# We can't import mcp_server.py directly because it raises at module load time,
# so we replicate its port-check logic inline here
with pytest.raises(ValueError, match="ARCHON_MCP_PORT environment variable is required"):
mcp_port = os.getenv("ARCHON_MCP_PORT")
if not mcp_port:
raise ValueError(
"ARCHON_MCP_PORT environment variable is required. "
"Please set it in your .env file or environment. "
"Default value: 8051"
)
def test_main_server_requires_port(self):
"""Test that main server requires ARCHON_SERVER_PORT when run directly."""
os.environ.pop("ARCHON_SERVER_PORT", None)
# Test the logic that would be in main.py
with pytest.raises(ValueError, match="ARCHON_SERVER_PORT environment variable is required"):
server_port = os.getenv("ARCHON_SERVER_PORT")
if not server_port:
raise ValueError(
"ARCHON_SERVER_PORT environment variable is required. "
"Please set it in your .env file or environment. "
"Default value: 8181"
)
def test_agents_server_requires_port(self):
"""Test that agents server requires ARCHON_AGENTS_PORT."""
os.environ.pop("ARCHON_AGENTS_PORT", None)
# Test the logic that would be in agents/server.py
with pytest.raises(ValueError, match="ARCHON_AGENTS_PORT environment variable is required"):
agents_port = os.getenv("ARCHON_AGENTS_PORT")
if not agents_port:
raise ValueError(
"ARCHON_AGENTS_PORT environment variable is required. "
"Please set it in your .env file or environment. "
"Default value: 8052"
)
def test_agent_chat_api_requires_agents_port(self):
"""Test that agent_chat_api requires ARCHON_AGENTS_PORT for service calls."""
os.environ.pop("ARCHON_AGENTS_PORT", None)
# Test the logic that would be in agent_chat_api
with pytest.raises(ValueError, match="ARCHON_AGENTS_PORT environment variable is required"):
agents_port = os.getenv("ARCHON_AGENTS_PORT")
if not agents_port:
raise ValueError(
"ARCHON_AGENTS_PORT environment variable is required. "
"Please set it in your .env file or environment."
)
def test_config_requires_port_or_archon_mcp_port(self):
"""Test that config.py requires PORT or ARCHON_MCP_PORT."""
from src.server.config.config import ConfigurationError
os.environ.pop("PORT", None)
os.environ.pop("ARCHON_MCP_PORT", None)
# Test the logic from config.py
with pytest.raises(
ConfigurationError, match="PORT or ARCHON_MCP_PORT environment variable is required"
):
port_str = os.getenv("PORT")
if not port_str:
port_str = os.getenv("ARCHON_MCP_PORT")
if not port_str:
raise ConfigurationError(
"PORT or ARCHON_MCP_PORT environment variable is required. "
"Please set it in your .env file or environment. "
"Default value: 8051"
)
def test_custom_port_values(self):
"""Test that services use custom port values when set."""
# Set custom ports
os.environ["ARCHON_SERVER_PORT"] = "9999"
os.environ["ARCHON_MCP_PORT"] = "8888"
os.environ["ARCHON_AGENTS_PORT"] = "7777"
from src.server.config.service_discovery import ServiceDiscovery
sd = ServiceDiscovery()
# Verify custom ports are used
assert sd.DEFAULT_PORTS["api"] == 9999
assert sd.DEFAULT_PORTS["mcp"] == 8888
assert sd.DEFAULT_PORTS["agents"] == 7777
# Verify service URLs use custom ports
if not sd.is_docker:
assert sd.get_service_url("api") == "http://localhost:9999"
assert sd.get_service_url("mcp") == "http://localhost:8888"
assert sd.get_service_url("agents") == "http://localhost:7777"
class TestPortValidation:
"""Test port validation logic."""
def test_invalid_port_values(self):
"""Test that invalid port values are rejected."""
os.environ["ARCHON_SERVER_PORT"] = "not-a-number"
os.environ["ARCHON_MCP_PORT"] = "8051"
os.environ["ARCHON_AGENTS_PORT"] = "8052"
with pytest.raises(ValueError):
from src.server.config.service_discovery import ServiceDiscovery
ServiceDiscovery()
def test_port_out_of_range(self):
"""Test that port values must be valid port numbers."""
test_cases = [
("0", False), # Port 0 is reserved
("1", True), # Valid
("65535", True), # Maximum valid port
("65536", False), # Too high
("-1", False), # Negative
]
for port_value, should_succeed in test_cases:
os.environ["ARCHON_SERVER_PORT"] = port_value
os.environ["ARCHON_MCP_PORT"] = "8051"
os.environ["ARCHON_AGENTS_PORT"] = "8052"
if should_succeed:
# Should not raise
from src.server.config.service_discovery import ServiceDiscovery
sd = ServiceDiscovery()
assert sd.DEFAULT_PORTS["api"] == int(port_value)
else:
# Should raise for invalid ports
with pytest.raises((ValueError, AssertionError)):
from src.server.config.service_discovery import ServiceDiscovery
sd = ServiceDiscovery()
# If ServiceDiscovery itself does not reject the value, fall back to an explicit range check
port = int(port_value)
assert 1 <= port <= 65535, f"Port {port} out of valid range"

@@ -0,0 +1,436 @@
"""
Simple, Fast RAG Tests
Focused tests that avoid complex initialization and database calls.
These tests verify the core RAG functionality without heavy dependencies.
"""
import os
from unittest.mock import MagicMock, patch
import pytest
# Set test environment variables
os.environ.update({
"SUPABASE_URL": "http://test.supabase.co",
"SUPABASE_SERVICE_KEY": "test_key",
"OPENAI_API_KEY": "test_openai_key",
"USE_HYBRID_SEARCH": "false",
"USE_RERANKING": "false",
"USE_AGENTIC_RAG": "false",
})
@pytest.fixture
def mock_supabase():
"""Mock supabase client"""
client = MagicMock()
client.rpc.return_value.execute.return_value.data = []
client.from_.return_value.select.return_value.limit.return_value.execute.return_value.data = []
return client
@pytest.fixture
def rag_service(mock_supabase):
"""Create RAGService with mocked dependencies"""
with patch("src.server.utils.get_supabase_client", return_value=mock_supabase):
with patch("src.server.services.credential_service.credential_service"):
from src.server.services.search.rag_service import RAGService
service = RAGService(supabase_client=mock_supabase)
return service
class TestRAGServiceCore:
"""Core RAGService functionality tests"""
def test_initialization(self, rag_service):
"""Test RAGService initializes correctly"""
assert rag_service is not None
assert hasattr(rag_service, "search_documents")
assert hasattr(rag_service, "search_code_examples")
assert hasattr(rag_service, "perform_rag_query")
def test_settings_methods(self, rag_service):
"""Test settings retrieval methods"""
# Test string setting
result = rag_service.get_setting("TEST_SETTING", "default")
assert isinstance(result, str)
# Test boolean setting
result = rag_service.get_bool_setting("TEST_BOOL", False)
assert isinstance(result, bool)
class TestRAGServiceSearch:
"""Search functionality tests"""
@pytest.mark.asyncio
async def test_basic_vector_search(self, rag_service, mock_supabase):
"""Test basic vector search functionality"""
# Mock the RPC response
mock_response = MagicMock()
mock_response.data = [
{
"id": "1",
"content": "Test content",
"similarity": 0.8,
"metadata": {},
"url": "test.com",
}
]
mock_supabase.rpc.return_value.execute.return_value = mock_response
# Test the search
query_embedding = [0.1] * 1536
results = await rag_service.base_strategy.vector_search(
query_embedding=query_embedding, match_count=5
)
assert isinstance(results, list)
assert len(results) == 1
assert results[0]["content"] == "Test content"
# Verify RPC was called correctly
mock_supabase.rpc.assert_called_once()
call_args = mock_supabase.rpc.call_args[0]
assert call_args[0] == "match_archon_crawled_pages"
@pytest.mark.asyncio
async def test_search_documents_with_embedding(self, rag_service):
"""Test document search with mocked embedding"""
# Patch at the module level where it's called from RAGService
with (
patch("src.server.services.search.rag_service.create_embedding") as mock_embed,
patch.object(rag_service.base_strategy, "vector_search") as mock_search,
):
# Setup mocks
mock_embed.return_value = [0.1] * 1536
mock_search.return_value = [{"content": "Test result", "similarity": 0.9}]
# Test search
results = await rag_service.search_documents(query="test query", match_count=5)
assert isinstance(results, list)
assert len(results) == 1
mock_embed.assert_called_once_with("test query")
mock_search.assert_called_once()
@pytest.mark.asyncio
async def test_perform_rag_query_basic(self, rag_service):
"""Test complete RAG query pipeline"""
with patch.object(rag_service, "search_documents") as mock_search:
mock_search.return_value = [
{"id": "1", "content": "Test content", "similarity": 0.8, "metadata": {}}
]
success, result = await rag_service.perform_rag_query(query="test query", match_count=5)
assert success is True
assert "results" in result
assert len(result["results"]) == 1
assert result["results"][0]["content"] == "Test content"
assert result["query"] == "test query"
@pytest.mark.asyncio
async def test_search_code_examples_delegation(self, rag_service):
"""Test code examples search delegates to agentic strategy"""
with patch.object(rag_service.agentic_strategy, "search_code_examples") as mock_agentic:
mock_agentic.return_value = [
{"content": "def test(): pass", "summary": "Test function", "url": "test.py"}
]
results = await rag_service.search_code_examples(query="test function", match_count=10)
assert isinstance(results, list)
mock_agentic.assert_called_once()
class TestHybridSearchCore:
"""Basic hybrid search tests"""
@pytest.fixture
def hybrid_strategy(self, mock_supabase):
"""Create hybrid search strategy"""
from src.server.services.search.base_search_strategy import BaseSearchStrategy
from src.server.services.search.hybrid_search_strategy import HybridSearchStrategy
base_strategy = BaseSearchStrategy(mock_supabase)
return HybridSearchStrategy(mock_supabase, base_strategy)
def test_initialization(self, hybrid_strategy):
"""Test hybrid strategy initializes"""
assert hybrid_strategy is not None
assert hasattr(hybrid_strategy, "search_documents_hybrid")
assert hasattr(hybrid_strategy, "_merge_search_results")
def test_merge_results_functionality(self, hybrid_strategy):
"""Test result merging logic"""
vector_results = [
{
"id": "1",
"content": "Vector result",
"similarity": 0.9,
"url": "test1.com",
"chunk_number": 1,
"metadata": {},
"source_id": "src1",
}
]
keyword_results = [
{
"id": "2",
"content": "Keyword result",
"url": "test2.com",
"chunk_number": 1,
"metadata": {},
"source_id": "src2",
}
]
merged = hybrid_strategy._merge_search_results(
vector_results, keyword_results, match_count=5
)
assert isinstance(merged, list)
assert len(merged) <= 5
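# _merge_search_results presumably deduplicates on id/url and favours rows found by both
# searches; with the disjoint ids above, both rows can survive up to match_count. How
# keyword-only hits are scored is an implementation detail this test deliberately leaves open.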
class TestRerankingCore:
"""Basic reranking tests"""
@pytest.fixture
def reranking_strategy(self):
"""Create reranking strategy"""
from src.server.services.search.reranking_strategy import RerankingStrategy
return RerankingStrategy()
def test_initialization(self, reranking_strategy):
"""Test reranking strategy initializes"""
assert reranking_strategy is not None
assert hasattr(reranking_strategy, "rerank_results")
assert hasattr(reranking_strategy, "is_available")
def test_availability_check(self, reranking_strategy):
"""Test model availability checking"""
availability = reranking_strategy.is_available()
assert isinstance(availability, bool)
@pytest.mark.asyncio
async def test_rerank_with_no_model(self, reranking_strategy):
"""Test reranking when no model is available"""
# Force model to be None
reranking_strategy.model = None
original_results = [{"content": "Test content", "score": 0.8}]
result = await reranking_strategy.rerank_results(
query="test query", results=original_results
)
# Should return original results when no model
assert result == original_results
@pytest.mark.asyncio
async def test_rerank_with_mock_model(self, reranking_strategy):
"""Test reranking with a mocked model"""
# Create a mock model
mock_model = MagicMock()
mock_model.predict.return_value = [0.95, 0.85, 0.75] # Mock rerank scores
reranking_strategy.model = mock_model
original_results = [
{"content": "Content 1", "similarity": 0.8},
{"content": "Content 2", "similarity": 0.7},
{"content": "Content 3", "similarity": 0.9},
]
result = await reranking_strategy.rerank_results(
query="test query", results=original_results
)
# Should return reranked results
assert isinstance(result, list)
assert len(result) == 3
# Results should be sorted by rerank_score
scores = [r.get("rerank_score", 0) for r in result]
assert scores == sorted(scores, reverse=True)
# Highest rerank score should be first
assert result[0]["rerank_score"] == 0.95
class TestAgenticRAGCore:
"""Basic agentic RAG tests"""
@pytest.fixture
def agentic_strategy(self, mock_supabase):
"""Create agentic RAG strategy"""
from src.server.services.search.agentic_rag_strategy import AgenticRAGStrategy
from src.server.services.search.base_search_strategy import BaseSearchStrategy
base_strategy = BaseSearchStrategy(mock_supabase)
return AgenticRAGStrategy(mock_supabase, base_strategy)
def test_initialization(self, agentic_strategy):
"""Test agentic strategy initializes"""
assert agentic_strategy is not None
assert hasattr(agentic_strategy, "search_code_examples")
assert hasattr(agentic_strategy, "is_enabled")
def test_query_enhancement(self, agentic_strategy):
"""Test code query enhancement"""
original_query = "python function"
analysis = agentic_strategy.analyze_code_query(original_query)
assert isinstance(analysis, dict)
assert "is_code_query" in analysis
assert "confidence" in analysis
assert "languages" in analysis
assert analysis["is_code_query"] is True
assert "python" in analysis["languages"]
class TestRAGIntegrationSimple:
"""Simple integration tests"""
@pytest.mark.asyncio
async def test_error_handling(self, rag_service):
"""Test error handling in RAG pipeline"""
with patch.object(rag_service, "search_documents") as mock_search:
# Simulate an error
mock_search.side_effect = Exception("Test error")
success, result = await rag_service.perform_rag_query(query="test query", match_count=5)
assert success is False
assert "error" in result
assert result["error"] == "Test error"
@pytest.mark.asyncio
async def test_empty_results_handling(self, rag_service):
"""Test handling of empty search results"""
with patch.object(rag_service, "search_documents") as mock_search:
mock_search.return_value = []
success, result = await rag_service.perform_rag_query(
query="empty query", match_count=5
)
assert success is True
assert "results" in result
assert len(result["results"]) == 0
@pytest.mark.asyncio
async def test_full_rag_pipeline_with_reranking(self, rag_service, mock_supabase):
"""Test complete RAG pipeline with reranking enabled"""
# Create a mock reranking model
mock_model = MagicMock()
mock_model.predict.return_value = [0.95, 0.85, 0.75]
# Initialize RAG service with reranking
from src.server.services.search.reranking_strategy import RerankingStrategy
reranking_strategy = RerankingStrategy()
reranking_strategy.model = mock_model
rag_service.reranking_strategy = reranking_strategy
with (
patch.object(rag_service, "search_documents") as mock_search,
patch.object(rag_service, "get_bool_setting") as mock_settings,
):
# Enable reranking
mock_settings.return_value = True
# Mock search results
mock_search.return_value = [
{"id": "1", "content": "Result 1", "similarity": 0.8, "metadata": {}},
{"id": "2", "content": "Result 2", "similarity": 0.7, "metadata": {}},
{"id": "3", "content": "Result 3", "similarity": 0.9, "metadata": {}},
]
success, result = await rag_service.perform_rag_query(query="test query", match_count=5)
assert success is True
assert "results" in result
assert len(result["results"]) == 3
# Verify reranking was applied
assert result["reranking_applied"] is True
# Results should be sorted by rerank_score
results = result["results"]
rerank_scores = [r.get("rerank_score", 0) for r in results]
assert rerank_scores == sorted(rerank_scores, reverse=True)
@pytest.mark.asyncio
async def test_hybrid_search_integration(self, rag_service):
"""Test RAG with hybrid search enabled"""
with (
patch("src.server.services.search.rag_service.create_embedding") as mock_embed,
patch.object(rag_service.hybrid_strategy, "search_documents_hybrid") as mock_hybrid,
patch.object(rag_service, "get_bool_setting") as mock_settings,
):
# Mock embedding and enable hybrid search
mock_embed.return_value = [0.1] * 1536
mock_settings.return_value = True
# Mock hybrid search results
mock_hybrid.return_value = [
{
"id": "1",
"content": "Hybrid result",
"similarity": 0.9,
"metadata": {},
"match_type": "hybrid",
}
]
results = await rag_service.search_documents(
query="test query", use_hybrid_search=True, match_count=5
)
assert isinstance(results, list)
assert len(results) == 1
assert results[0]["content"] == "Hybrid result"
mock_hybrid.assert_called_once()
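# search_documents presumably delegates to hybrid_strategy.search_documents_hybrid when hybrid
# search is enabled; the test turns on both the explicit use_hybrid_search flag and the boolean
# setting, so it does not distinguish which of the two actually gates that code path.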
@pytest.mark.asyncio
async def test_code_search_with_agentic_rag(self, rag_service):
"""Test code search using agentic RAG"""
with (
patch.object(rag_service.agentic_strategy, "is_enabled") as mock_enabled,
patch.object(rag_service.agentic_strategy, "search_code_examples") as mock_agentic,
patch.object(rag_service, "get_bool_setting") as mock_settings,
):
# Enable agentic RAG
mock_enabled.return_value = True
mock_settings.return_value = False # Disable hybrid search for this test
# Mock agentic search results
mock_agentic.return_value = [
{
"content": 'def example_function():\\n return "Hello"',
"summary": "Example function that returns greeting",
"url": "example.py",
"metadata": {"language": "python"},
}
]
success, result = await rag_service.search_code_examples_service(
query="python greeting function", match_count=10
)
assert success is True
assert "results" in result
assert len(result["results"]) == 1
code_result = result["results"][0]
assert "def example_function" in code_result["code"]
assert code_result["summary"] == "Example function that returns greeting"
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,524 @@
"""
Tests for RAG Strategies and Search Functionality
Tests RAGService class, hybrid search, agentic RAG, reranking, and other advanced RAG features.
Updated to match current async-only architecture.
"""
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Mock problematic imports at module level
with patch.dict(
os.environ,
{
"SUPABASE_URL": "http://test.supabase.co",
"SUPABASE_SERVICE_KEY": "test_key",
"OPENAI_API_KEY": "test_openai_key",
},
):
# Mock credential service to prevent database calls
with patch("src.server.services.credential_service.credential_service") as mock_cred:
mock_cred._cache_initialized = False
mock_cred.get_setting.return_value = "false"
mock_cred.get_bool_setting.return_value = False
# Mock supabase client creation
with patch("src.server.utils.get_supabase_client") as mock_supabase:
mock_client = MagicMock()
mock_supabase.return_value = mock_client
# Mock embedding service to prevent API calls
with patch(
"src.server.services.embeddings.embedding_service.create_embedding"
) as mock_embed:
mock_embed.return_value = [0.1] * 1536
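# These module-level patches are presumably kept active around the rest of the module so that
# importing the service modules below never reaches a real database, the credential store, or
# the OpenAI API during test collection.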
# Test RAGService core functionality
class TestRAGService:
"""Test core RAGService functionality"""
@pytest.fixture
def mock_supabase_client(self):
"""Mock Supabase client"""
return MagicMock()
@pytest.fixture
def rag_service(self, mock_supabase_client):
"""Create RAGService instance"""
from src.server.services.search import RAGService
return RAGService(supabase_client=mock_supabase_client)
def test_rag_service_initialization(self, rag_service):
"""Test RAGService initializes correctly"""
assert rag_service is not None
assert hasattr(rag_service, "search_documents")
assert hasattr(rag_service, "search_code_examples")
assert hasattr(rag_service, "perform_rag_query")
def test_get_setting(self, rag_service):
"""Test settings retrieval"""
with patch.dict("os.environ", {"USE_HYBRID_SEARCH": "true"}):
result = rag_service.get_setting("USE_HYBRID_SEARCH", "false")
assert result == "true"
def test_get_bool_setting(self, rag_service):
"""Test boolean settings retrieval"""
with patch.dict("os.environ", {"USE_RERANKING": "true"}):
result = rag_service.get_bool_setting("USE_RERANKING", False)
assert result is True
@pytest.mark.asyncio
async def test_search_code_examples(self, rag_service):
"""Test code examples search"""
with patch.object(
rag_service.agentic_strategy, "search_code_examples"
) as mock_agentic_search:
# Mock agentic search results
mock_agentic_search.return_value = [
{
"content": "def example():\n pass",
"summary": "Python function example",
"url": "test.py",
"metadata": {"language": "python"},
}
]
result = await rag_service.search_code_examples(
query="python function example", match_count=5
)
assert isinstance(result, list)
assert len(result) == 1
mock_agentic_search.assert_called_once()
@pytest.mark.asyncio
async def test_perform_rag_query(self, rag_service):
"""Test complete RAG query flow"""
# Create a mock reranking strategy if it doesn't exist
if rag_service.reranking_strategy is None:
from unittest.mock import Mock
rag_service.reranking_strategy = Mock()
rag_service.reranking_strategy.rerank_results = AsyncMock()
with (
patch.object(rag_service, "search_documents") as mock_search,
patch.object(rag_service.reranking_strategy, "rerank_results") as mock_rerank,
):
mock_search.return_value = [{"content": "Relevant content", "score": 0.90}]
mock_rerank.return_value = [{"content": "Relevant content", "score": 0.95}]
success, result = await rag_service.perform_rag_query(query="test query", match_count=5)
assert success is True
assert "results" in result
assert isinstance(result["results"], list)
@pytest.mark.asyncio
async def test_rerank_results(self, rag_service):
"""Test result reranking via strategy"""
from src.server.services.search import RerankingStrategy
# Create a mock reranking strategy
mock_strategy = MagicMock(spec=RerankingStrategy)
mock_strategy.rerank_results = AsyncMock(
return_value=[{"content": "Reranked content", "score": 0.98}]
)
# Assign the mock strategy to the service
rag_service.reranking_strategy = mock_strategy
original_results = [{"content": "Original content", "score": 0.80}]
# Call the strategy directly (as the service now does internally)
result = await rag_service.reranking_strategy.rerank_results(
query="test query", results=original_results
)
assert isinstance(result, list)
assert result[0]["content"] == "Reranked content"
class TestHybridSearchStrategy:
"""Test hybrid search strategy implementation"""
@pytest.fixture
def mock_supabase_client(self):
"""Mock Supabase client"""
return MagicMock()
@pytest.fixture
def hybrid_strategy(self, mock_supabase_client):
"""Create HybridSearchStrategy instance"""
from src.server.services.search import HybridSearchStrategy
from src.server.services.search.base_search_strategy import BaseSearchStrategy
base_strategy = BaseSearchStrategy(mock_supabase_client)
return HybridSearchStrategy(mock_supabase_client, base_strategy)
def test_hybrid_strategy_initialization(self, hybrid_strategy):
"""Test HybridSearchStrategy initializes correctly"""
assert hybrid_strategy is not None
assert hasattr(hybrid_strategy, "search_documents_hybrid")
assert hasattr(hybrid_strategy, "search_code_examples_hybrid")
def test_merge_search_results(self, hybrid_strategy):
"""Test search result merging"""
vector_results = [
{
"id": "1",
"content": "Vector result 1",
"score": 0.9,
"url": "url1",
"chunk_number": 1,
"metadata": {},
"source_id": "source1",
"similarity": 0.9,
}
]
keyword_results = [
{
"id": "2",
"content": "Keyword result 1",
"score": 0.8,
"url": "url2",
"chunk_number": 1,
"metadata": {},
"source_id": "source2",
}
]
merged = hybrid_strategy._merge_search_results(
vector_results, keyword_results, match_count=5
)
assert isinstance(merged, list)
assert len(merged) <= 5
# Should contain results from both sources
if merged:
assert any("Vector result" in str(r) or "Keyword result" in str(r) for r in merged)
class TestRerankingStrategy:
"""Test reranking strategy implementation"""
@pytest.fixture
def reranking_strategy(self):
"""Create RerankingStrategy instance"""
from src.server.services.search import RerankingStrategy
return RerankingStrategy()
def test_reranking_strategy_initialization(self, reranking_strategy):
"""Test RerankingStrategy initializes correctly"""
assert reranking_strategy is not None
assert hasattr(reranking_strategy, "rerank_results")
assert hasattr(reranking_strategy, "is_available")
def test_model_availability_check(self, reranking_strategy):
"""Test model availability checking"""
# This should not crash even if model not available
availability = reranking_strategy.is_available()
assert isinstance(availability, bool)
@pytest.mark.asyncio
async def test_rerank_results_no_model(self, reranking_strategy):
"""Test reranking when model not available"""
with patch.object(reranking_strategy, "is_available") as mock_available:
mock_available.return_value = False
original_results = [{"content": "Test content", "score": 0.8}]
result = await reranking_strategy.rerank_results(
query="test query", results=original_results
)
# Should return original results when model not available
assert result == original_results
@pytest.mark.asyncio
async def test_rerank_results_with_model(self, reranking_strategy):
"""Test reranking when model is available"""
with patch.object(reranking_strategy, "is_available") as mock_available:
mock_available.return_value = True
mock_model_instance = MagicMock()
mock_model_instance.predict.return_value = [0.95, 0.85]  # Mock rerank scores
reranking_strategy.model = mock_model_instance
original_results = [
{"content": "Content 1", "score": 0.8},
{"content": "Content 2", "score": 0.7},
]
result = await reranking_strategy.rerank_results(
query="test query", results=original_results
)
assert isinstance(result, list)
assert len(result) <= len(original_results)
class TestAgenticRAGStrategy:
"""Test agentic RAG strategy implementation"""
@pytest.fixture
def mock_supabase_client(self):
"""Mock Supabase client"""
return MagicMock()
@pytest.fixture
def agentic_strategy(self, mock_supabase_client):
"""Create AgenticRAGStrategy instance"""
from src.server.services.search import AgenticRAGStrategy
from src.server.services.search.base_search_strategy import BaseSearchStrategy
base_strategy = BaseSearchStrategy(mock_supabase_client)
return AgenticRAGStrategy(mock_supabase_client, base_strategy)
def test_agentic_strategy_initialization(self, agentic_strategy):
"""Test AgenticRAGStrategy initializes correctly"""
assert agentic_strategy is not None
# Check for the expected public API (the same surface exercised elsewhere in this suite)
assert hasattr(agentic_strategy, "search_code_examples")
assert hasattr(agentic_strategy, "is_enabled")
class TestRAGIntegration:
"""Integration tests for RAG strategies working together"""
@pytest.fixture
def mock_supabase_client(self):
"""Mock Supabase client"""
return MagicMock()
@pytest.fixture
def rag_service(self, mock_supabase_client):
"""Create RAGService instance"""
from src.server.services.search import RAGService
return RAGService(supabase_client=mock_supabase_client)
@pytest.mark.asyncio
async def test_full_rag_pipeline(self, rag_service):
"""Test complete RAG pipeline with all strategies"""
# Create a mock reranking strategy if it doesn't exist
if rag_service.reranking_strategy is None:
from unittest.mock import Mock
rag_service.reranking_strategy = Mock()
rag_service.reranking_strategy.rerank_results = AsyncMock()
with (
patch(
"src.server.services.embeddings.embedding_service.create_embedding"
) as mock_embedding,
patch.object(rag_service.base_strategy, "vector_search") as mock_search,
patch.object(rag_service, "get_bool_setting") as mock_settings,
patch.object(rag_service.reranking_strategy, "rerank_results") as mock_rerank,
):
# Mock embedding creation
mock_embedding.return_value = [0.1] * 1536
# Enable all strategies
mock_settings.side_effect = lambda key, default: True
mock_search.return_value = [
{"content": "Test result 1", "similarity": 0.9, "id": "1", "metadata": {}},
{"content": "Test result 2", "similarity": 0.8, "id": "2", "metadata": {}},
]
mock_rerank.return_value = [
{"content": "Reranked result", "similarity": 0.95, "id": "1", "metadata": {}}
]
success, result = await rag_service.perform_rag_query(
query="complex technical query", match_count=10
)
assert success is True
assert "results" in result
assert isinstance(result["results"], list)
@pytest.mark.asyncio
async def test_error_handling_in_rag_pipeline(self, rag_service):
"""Test error handling when strategies fail"""
with patch(
"src.server.services.embeddings.embedding_service.create_embedding"
) as mock_embedding:
# Simulate embedding failure (returns None)
mock_embedding.return_value = None
success, result = await rag_service.perform_rag_query(query="test query", match_count=5)
# Should handle gracefully by returning empty results
assert success is True
assert "results" in result
assert len(result["results"]) == 0 # Empty results due to embedding failure
@pytest.mark.asyncio
async def test_empty_results_handling(self, rag_service):
"""Test handling of empty search results"""
with (
patch(
"src.server.services.embeddings.embedding_service.create_embedding"
) as mock_embedding,
patch.object(rag_service.base_strategy, "vector_search") as mock_search,
):
# Mock embedding creation
mock_embedding.return_value = [0.1] * 1536
mock_search.return_value = [] # Empty results
success, result = await rag_service.perform_rag_query(
query="nonexistent query", match_count=5
)
assert success is True
assert "results" in result
assert len(result["results"]) == 0
class TestRAGPerformance:
"""Test RAG performance and optimization features"""
@pytest.fixture
def rag_service(self):
"""Create RAGService instance"""
from unittest.mock import MagicMock
from src.server.services.search import RAGService
mock_client = MagicMock()
return RAGService(supabase_client=mock_client)
@pytest.mark.asyncio
async def test_concurrent_rag_queries(self, rag_service):
"""Test multiple concurrent RAG queries"""
with (
patch(
"src.server.services.embeddings.embedding_service.create_embedding"
) as mock_embedding,
patch.object(rag_service.base_strategy, "vector_search") as mock_search,
):
# Mock embedding creation
mock_embedding.return_value = [0.1] * 1536
mock_search.return_value = [
{
"content": "Result for concurrent test",
"similarity": 0.9,
"id": "1",
"metadata": {},
}
]
# Run multiple queries concurrently
queries = ["query 1", "query 2", "query 3"]
tasks = [rag_service.perform_rag_query(query, match_count=3) for query in queries]
results = await asyncio.gather(*tasks, return_exceptions=True)
# All should complete successfully
assert len(results) == 3
for result in results:
# With return_exceptions=True, a failed query comes back as an exception instance,
# so fail loudly rather than silently skipping it.
assert not isinstance(result, Exception)
success, data = result
assert success is True or isinstance(data, dict)
@pytest.mark.asyncio
async def test_large_result_set_handling(self, rag_service):
"""Test handling of large result sets"""
with (
patch(
"src.server.services.embeddings.embedding_service.create_embedding"
) as mock_embedding,
patch.object(rag_service.base_strategy, "vector_search") as mock_search,
):
# Mock embedding creation
mock_embedding.return_value = [0.1] * 1536
# Create a large result set (50 items, the same as match_count below)
large_results = [
{
"content": f"Result {i}",
"similarity": 0.9 - (i * 0.01),
"id": str(i),
"metadata": {},
}
for i in range(50)
]
mock_search.return_value = large_results
success, result = await rag_service.perform_rag_query(
query="large query", match_count=50
)
assert success is True
assert "results" in result
# Should respect match_count limit
assert len(result["results"]) <= 50
class TestRAGConfiguration:
"""Test RAG configuration and settings"""
@pytest.fixture
def rag_service(self):
"""Create RAGService instance"""
from unittest.mock import MagicMock
from src.server.services.search import RAGService
mock_client = MagicMock()
return RAGService(supabase_client=mock_client)
def test_environment_variable_settings(self, rag_service):
"""Test reading settings from environment variables"""
with patch.dict(
"os.environ",
{"USE_HYBRID_SEARCH": "true", "USE_RERANKING": "false", "USE_AGENTIC_RAG": "true"},
):
assert rag_service.get_bool_setting("USE_HYBRID_SEARCH") is True
assert rag_service.get_bool_setting("USE_RERANKING") is False
assert rag_service.get_bool_setting("USE_AGENTIC_RAG") is True
def test_default_settings(self, rag_service):
"""Test default settings when environment variables not set"""
with patch.dict("os.environ", {}, clear=True):
assert rag_service.get_bool_setting("NONEXISTENT_SETTING", True) is True
assert rag_service.get_bool_setting("NONEXISTENT_SETTING", False) is False
@pytest.mark.asyncio
async def test_strategy_conditional_execution(self, rag_service):
"""Test that strategies only execute when enabled"""
with (
patch(
"src.server.services.embeddings.embedding_service.create_embedding"
) as mock_embedding,
patch.object(rag_service.base_strategy, "vector_search") as mock_search,
patch.object(rag_service, "get_bool_setting") as mock_setting,
):
# Mock embedding creation
mock_embedding.return_value = [0.1] * 1536
mock_search.return_value = [
{"content": "test", "similarity": 0.9, "id": "1", "metadata": {}}
]
# Disable all strategies
mock_setting.return_value = False
success, result = await rag_service.perform_rag_query(query="test query", match_count=5)
assert success is True
# Should still return results from basic search
assert "results" in result

View File

@@ -0,0 +1,95 @@
"""Service integration tests - Test core service interactions."""
def test_project_with_tasks_flow(client):
"""Test creating a project and adding tasks."""
# Create project
project_response = client.post("/api/projects", json={"title": "Test Project"})
assert project_response.status_code in [200, 201, 422]
# List projects to verify
list_response = client.get("/api/projects")
assert list_response.status_code in [200, 500] # 500 is OK in test environment
def test_crawl_to_knowledge_flow(client):
"""Test crawling workflow."""
# Start crawl
crawl_data = {"url": "https://example.com", "max_depth": 1, "max_pages": 5}
response = client.post("/api/knowledge/crawl", json=crawl_data)
assert response.status_code in [200, 201, 400, 404, 422, 500]
def test_document_storage_flow(client):
"""Test document upload endpoint."""
# Test multipart form upload
files = {"file": ("test.txt", b"Test content", "text/plain")}
response = client.post("/api/knowledge/documents", files=files)
assert response.status_code in [200, 201, 400, 404, 422, 500]
def test_code_extraction_flow(client):
"""Test code extraction endpoint."""
response = client.post(
"/api/knowledge/extract-code", json={"document_id": "test-doc-id", "languages": ["python"]}
)
assert response.status_code in [200, 400, 404, 422, 500]
def test_search_and_retrieve_flow(client):
"""Test search functionality."""
# Search
search_response = client.post("/api/knowledge/search", json={"query": "test"})
assert search_response.status_code in [200, 400, 404, 422, 500]
# Get specific item (might not exist)
item_response = client.get("/api/knowledge/items/test-id")
assert item_response.status_code in [200, 404, 500]
def test_mcp_tool_execution(client):
"""Test MCP tool execution endpoint."""
response = client.post("/api/mcp/tools/execute", json={"tool": "test_tool", "params": {}})
assert response.status_code in [200, 400, 404, 422, 500]
def test_socket_io_events(client):
"""Test Socket.IO connectivity."""
# Just verify the endpoint exists
response = client.get("/socket.io/")
assert response.status_code in [200, 400, 404]
def test_background_task_progress(client):
"""Test background task tracking."""
# Check if task progress endpoint exists
response = client.get("/api/tasks/test-task-id/progress")
assert response.status_code in [200, 404, 500]
def test_database_operations(client):
"""Test pagination and filtering."""
# Test with query params
response = client.get("/api/projects?limit=10&offset=0")
assert response.status_code in [200, 500] # 500 is OK in test environment
# Test filtering
response = client.get("/api/tasks?status=todo")
assert response.status_code in [200, 400, 422, 500]
def test_concurrent_operations(client):
"""Test API handles concurrent requests."""
import concurrent.futures
def make_request():
return client.get("/api/projects")
# Make 3 concurrent requests
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
futures = [executor.submit(make_request) for _ in range(3)]
results = [f.result() for f in futures]
# All should succeed or fail with 500 in test environment
for result in results:
assert result.status_code in [200, 500] # 500 is OK in test environment

View File

@@ -0,0 +1,58 @@
"""
Simple tests for settings API credential handling.
Focus on critical paths for optional settings with defaults.
"""
from unittest.mock import AsyncMock, MagicMock, patch
def test_optional_setting_returns_default(client, mock_supabase_client):
"""Test that optional settings return default values with is_default flag."""
# Mock the entire credential_service instance
mock_service = MagicMock()
mock_service.get_credential = AsyncMock(return_value=None)
with patch("src.server.api_routes.settings_api.credential_service", mock_service):
response = client.get("/api/credentials/DISCONNECT_SCREEN_ENABLED")
assert response.status_code == 200
data = response.json()
assert data["key"] == "DISCONNECT_SCREEN_ENABLED"
assert data["value"] == "true"
assert data["is_default"] is True
assert "category" in data
assert "description" in data
def test_unknown_credential_returns_404(client, mock_supabase_client):
"""Test that unknown credentials still return 404."""
# Mock the entire credential_service instance
mock_service = MagicMock()
mock_service.get_credential = AsyncMock(return_value=None)
with patch("src.server.api_routes.settings_api.credential_service", mock_service):
response = client.get("/api/credentials/UNKNOWN_KEY_THAT_DOES_NOT_EXIST")
assert response.status_code == 404
data = response.json()
assert "error" in data["detail"]
assert "not found" in data["detail"]["error"].lower()
def test_existing_credential_returns_normally(client, mock_supabase_client):
"""Test that existing credentials return without default flag."""
mock_value = "user_configured_value"
# Mock the entire credential_service instance
mock_service = MagicMock()
mock_service.get_credential = AsyncMock(return_value=mock_value)
with patch("src.server.api_routes.settings_api.credential_service", mock_service):
response = client.get("/api/credentials/SOME_EXISTING_KEY")
assert response.status_code == 200
data = response.json()
assert data["key"] == "SOME_EXISTING_KEY"
assert data["value"] == "user_configured_value"
assert data["is_encrypted"] is False
# Should not have is_default flag for real credentials
assert "is_default" not in data