mirror of
https://github.com/coleam00/Archon.git
synced 2025-12-27 04:00:29 -05:00
The New Archon (Beta) - The Operating System for AI Coding Assistants!
This commit is contained in:
1
python/tests/__init__.py
Normal file
1
python/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Simplified test suite for Archon - Essential tests only."""
|
||||
124
python/tests/conftest.py
Normal file
124
python/tests/conftest.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Simple test configuration for Archon - Essential tests only."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Set test environment
|
||||
os.environ["TEST_MODE"] = "true"
|
||||
os.environ["TESTING"] = "true"
|
||||
# Set fake database credentials to prevent connection attempts
|
||||
os.environ["SUPABASE_URL"] = "https://test.supabase.co"
|
||||
os.environ["SUPABASE_SERVICE_KEY"] = "test-key"
|
||||
# Set required port environment variables for ServiceDiscovery
|
||||
os.environ.setdefault("ARCHON_SERVER_PORT", "8181")
|
||||
os.environ.setdefault("ARCHON_MCP_PORT", "8051")
|
||||
os.environ.setdefault("ARCHON_AGENTS_PORT", "8052")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def prevent_real_db_calls():
|
||||
"""Automatically prevent any real database calls in all tests."""
|
||||
with patch("supabase.create_client") as mock_create:
|
||||
# Make create_client raise an error if called without our mock
|
||||
mock_create.side_effect = Exception("Real database calls are not allowed in tests!")
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_supabase_client():
|
||||
"""Mock Supabase client for testing."""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock table operations with chaining support
|
||||
mock_table = MagicMock()
|
||||
mock_select = MagicMock()
|
||||
mock_insert = MagicMock()
|
||||
mock_update = MagicMock()
|
||||
mock_delete = MagicMock()
|
||||
|
||||
# Setup method chaining for select
|
||||
mock_select.execute.return_value.data = []
|
||||
mock_select.eq.return_value = mock_select
|
||||
mock_select.neq.return_value = mock_select
|
||||
mock_select.order.return_value = mock_select
|
||||
mock_select.limit.return_value = mock_select
|
||||
mock_table.select.return_value = mock_select
|
||||
|
||||
# Setup method chaining for insert
|
||||
mock_insert.execute.return_value.data = [{"id": "test-id"}]
|
||||
mock_table.insert.return_value = mock_insert
|
||||
|
||||
# Setup method chaining for update
|
||||
mock_update.execute.return_value.data = [{"id": "test-id"}]
|
||||
mock_update.eq.return_value = mock_update
|
||||
mock_table.update.return_value = mock_update
|
||||
|
||||
# Setup method chaining for delete
|
||||
mock_delete.execute.return_value.data = []
|
||||
mock_delete.eq.return_value = mock_delete
|
||||
mock_table.delete.return_value = mock_delete
|
||||
|
||||
# Make table() return the mock table
|
||||
mock_client.table.return_value = mock_table
|
||||
|
||||
# Mock auth operations
|
||||
mock_client.auth = MagicMock()
|
||||
mock_client.auth.get_user.return_value = None
|
||||
|
||||
# Mock storage operations
|
||||
mock_client.storage = MagicMock()
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_supabase_client):
|
||||
"""FastAPI test client with mocked database."""
|
||||
# Patch all the ways Supabase client can be created
|
||||
with patch(
|
||||
"src.server.services.client_manager.create_client", return_value=mock_supabase_client
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.credential_service.create_client",
|
||||
return_value=mock_supabase_client,
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.client_manager.get_supabase_client",
|
||||
return_value=mock_supabase_client,
|
||||
):
|
||||
with patch("supabase.create_client", return_value=mock_supabase_client):
|
||||
# Import app after patching to ensure mocks are used
|
||||
from src.server.main import app
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_project():
|
||||
"""Simple test project data."""
|
||||
return {"title": "Test Project", "description": "A test project for essential tests"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_task():
|
||||
"""Simple test task data."""
|
||||
return {
|
||||
"title": "Test Task",
|
||||
"description": "A test task for essential tests",
|
||||
"status": "todo",
|
||||
"assignee": "User",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_knowledge_item():
|
||||
"""Simple test knowledge item data."""
|
||||
return {
|
||||
"url": "https://example.com/test",
|
||||
"title": "Test Knowledge Item",
|
||||
"content": "This is test content for knowledge base",
|
||||
"source_id": "test-source",
|
||||
}
|
||||
113
python/tests/test_api_essentials.py
Normal file
113
python/tests/test_api_essentials.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Essential API tests - Focus on core functionality that must work."""
|
||||
|
||||
|
||||
def test_health_endpoint(client):
|
||||
"""Test that health endpoint returns OK status."""
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "status" in data
|
||||
assert data["status"] in ["healthy", "initializing"]
|
||||
|
||||
|
||||
def test_create_project(client, test_project, mock_supabase_client):
|
||||
"""Test creating a new project via API."""
|
||||
# Set up mock to return a project
|
||||
mock_supabase_client.table.return_value.insert.return_value.execute.return_value.data = [
|
||||
{
|
||||
"id": "test-project-id",
|
||||
"title": test_project["title"],
|
||||
"description": test_project["description"],
|
||||
}
|
||||
]
|
||||
|
||||
response = client.post("/api/projects", json=test_project)
|
||||
# Should succeed with mocked data
|
||||
assert response.status_code in [200, 201, 422, 500] # Allow various responses
|
||||
|
||||
# If successful, check response format
|
||||
if response.status_code in [200, 201]:
|
||||
data = response.json()
|
||||
# Check response format - at least one of these should be present
|
||||
assert (
|
||||
"title" in data
|
||||
or "id" in data
|
||||
or "progress_id" in data
|
||||
or "status" in data
|
||||
or "message" in data
|
||||
)
|
||||
|
||||
|
||||
def test_list_projects(client, mock_supabase_client):
|
||||
"""Test listing projects endpoint exists and responds."""
|
||||
# Set up mock to return empty list (no projects)
|
||||
mock_supabase_client.table.return_value.select.return_value.execute.return_value.data = []
|
||||
|
||||
response = client.get("/api/projects")
|
||||
assert response.status_code in [200, 404, 422, 500] # Allow various responses
|
||||
|
||||
# If successful, response should be JSON (list or dict)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
assert isinstance(data, (list, dict))
|
||||
|
||||
|
||||
def test_create_task(client, test_task):
|
||||
"""Test task creation endpoint exists."""
|
||||
# Try the tasks endpoint directly
|
||||
response = client.post("/api/tasks", json=test_task)
|
||||
# Accept various status codes - endpoint exists
|
||||
assert response.status_code in [200, 201, 400, 422, 405]
|
||||
|
||||
|
||||
def test_list_tasks(client):
|
||||
"""Test tasks listing endpoint exists."""
|
||||
response = client.get("/api/tasks")
|
||||
# Accept 200, 400, 422, or 500 - endpoint exists
|
||||
assert response.status_code in [200, 400, 422, 500]
|
||||
|
||||
|
||||
def test_start_crawl(client):
|
||||
"""Test crawl endpoint exists and validates input."""
|
||||
crawl_request = {"url": "https://example.com", "max_depth": 2, "max_pages": 10}
|
||||
|
||||
response = client.post("/api/knowledge/crawl", json=crawl_request)
|
||||
# Accept various status codes - endpoint exists and processes request
|
||||
assert response.status_code in [200, 201, 400, 404, 422, 500]
|
||||
|
||||
|
||||
def test_search_knowledge(client):
|
||||
"""Test knowledge search endpoint exists."""
|
||||
response = client.post("/api/knowledge/search", json={"query": "test"})
|
||||
# Accept various status codes - endpoint exists
|
||||
assert response.status_code in [200, 400, 404, 422, 500]
|
||||
|
||||
|
||||
def test_websocket_connection(client):
|
||||
"""Test WebSocket/Socket.IO endpoint exists."""
|
||||
response = client.get("/socket.io/")
|
||||
# Socket.IO returns specific status codes
|
||||
assert response.status_code in [200, 400, 404]
|
||||
|
||||
|
||||
def test_authentication(client):
|
||||
"""Test that API handles auth headers gracefully."""
|
||||
# Test with no auth header
|
||||
response = client.get("/api/projects")
|
||||
assert response.status_code in [200, 401, 403, 500] # 500 is OK in test environment
|
||||
|
||||
# Test with invalid auth header
|
||||
headers = {"Authorization": "Bearer invalid-token"}
|
||||
response = client.get("/api/projects", headers=headers)
|
||||
assert response.status_code in [200, 401, 403, 500] # 500 is OK in test environment
|
||||
|
||||
|
||||
def test_error_handling(client):
|
||||
"""Test API returns proper error responses."""
|
||||
# Test non-existent endpoint
|
||||
response = client.get("/api/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
# Test invalid JSON
|
||||
response = client.post("/api/projects", data="invalid json")
|
||||
assert response.status_code in [400, 422]
|
||||
509
python/tests/test_async_background_task_manager.py
Normal file
509
python/tests/test_async_background_task_manager.py
Normal file
@@ -0,0 +1,509 @@
|
||||
"""
|
||||
Comprehensive Tests for Async Background Task Manager
|
||||
|
||||
Tests the pure async background task manager after removal of ThreadPoolExecutor.
|
||||
Focuses on async task execution, concurrency control, and progress tracking.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.server.services.background_task_manager import (
|
||||
BackgroundTaskManager,
|
||||
cleanup_task_manager,
|
||||
get_task_manager,
|
||||
)
|
||||
|
||||
|
||||
class TestAsyncBackgroundTaskManager:
|
||||
"""Test suite for async background task manager"""
|
||||
|
||||
@pytest.fixture
|
||||
def task_manager(self):
|
||||
"""Create a fresh task manager instance for each test"""
|
||||
return BackgroundTaskManager(max_concurrent_tasks=5)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_progress_callback(self):
|
||||
"""Mock progress callback function"""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_manager_initialization(self, task_manager):
|
||||
"""Test task manager initialization"""
|
||||
assert task_manager.max_concurrent_tasks == 5
|
||||
assert len(task_manager.active_tasks) == 0
|
||||
assert len(task_manager.task_metadata) == 0
|
||||
assert task_manager._task_semaphore._value == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_async_task_execution(self, task_manager, mock_progress_callback):
|
||||
"""Test execution of a simple async task"""
|
||||
|
||||
async def simple_task(message: str):
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return f"Task completed: {message}"
|
||||
|
||||
task_id = await task_manager.submit_task(
|
||||
simple_task, ("Hello World",), progress_callback=mock_progress_callback
|
||||
)
|
||||
|
||||
# Wait for task completion
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Check task status
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
assert status["status"] == "complete"
|
||||
assert status["progress"] == 100
|
||||
assert status["result"] == "Task completed: Hello World"
|
||||
|
||||
# Verify progress callback was called
|
||||
assert mock_progress_callback.call_count >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_with_error(self, task_manager, mock_progress_callback):
|
||||
"""Test handling of task that raises an exception"""
|
||||
|
||||
async def failing_task():
|
||||
await asyncio.sleep(0.01)
|
||||
raise ValueError("Task failed intentionally")
|
||||
|
||||
task_id = await task_manager.submit_task(
|
||||
failing_task, (), progress_callback=mock_progress_callback
|
||||
)
|
||||
|
||||
# Wait for task to fail
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Check task status
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
assert status["status"] == "error"
|
||||
assert status["progress"] == -1
|
||||
assert "error" in status
|
||||
assert "Task failed intentionally" in status["error"]
|
||||
|
||||
# Verify error was reported via progress callback
|
||||
error_call = None
|
||||
for call in mock_progress_callback.call_args_list:
|
||||
if len(call[0]) >= 2 and call[0][1].get("status") == "error":
|
||||
error_call = call
|
||||
break
|
||||
|
||||
assert error_call is not None
|
||||
assert "Task failed intentionally" in error_call[0][1]["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_task_execution(self, task_manager):
|
||||
"""Test execution of multiple concurrent tasks"""
|
||||
|
||||
async def numbered_task(number: int):
|
||||
await asyncio.sleep(0.01)
|
||||
return f"Task {number} completed"
|
||||
|
||||
# Submit 5 tasks simultaneously
|
||||
task_ids = []
|
||||
for i in range(5):
|
||||
task_id = await task_manager.submit_task(numbered_task, (i,), task_id=f"task-{i}")
|
||||
task_ids.append(task_id)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Check all tasks completed successfully
|
||||
for i, task_id in enumerate(task_ids):
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
assert status["status"] == "complete"
|
||||
assert status["result"] == f"Task {i} completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrency_limit(self, task_manager):
|
||||
"""Test that concurrency is limited by semaphore"""
|
||||
# Use a task manager with limit of 2
|
||||
limited_manager = BackgroundTaskManager(max_concurrent_tasks=2)
|
||||
|
||||
running_tasks = []
|
||||
completed_tasks = []
|
||||
|
||||
async def long_running_task(task_id: int):
|
||||
running_tasks.append(task_id)
|
||||
await asyncio.sleep(0.05) # Long enough to test concurrency
|
||||
completed_tasks.append(task_id)
|
||||
return f"Task {task_id} completed"
|
||||
|
||||
# Submit 4 tasks
|
||||
task_ids = []
|
||||
for i in range(4):
|
||||
task_id = await limited_manager.submit_task(
|
||||
long_running_task, (i,), task_id=f"concurrent-task-{i}"
|
||||
)
|
||||
task_ids.append(task_id)
|
||||
|
||||
# Wait a bit and check that only 2 tasks are running
|
||||
await asyncio.sleep(0.01)
|
||||
assert len(running_tasks) <= 2
|
||||
|
||||
# Wait for all to complete
|
||||
await asyncio.sleep(0.1)
|
||||
assert len(completed_tasks) == 4
|
||||
|
||||
# Clean up
|
||||
await limited_manager.cleanup()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cancellation(self, task_manager):
|
||||
"""Test cancellation of running task"""
|
||||
|
||||
async def long_task():
|
||||
try:
|
||||
await asyncio.sleep(1.0) # Long enough to be cancelled
|
||||
return "Should not complete"
|
||||
except asyncio.CancelledError:
|
||||
raise # Re-raise to properly handle cancellation
|
||||
|
||||
task_id = await task_manager.submit_task(long_task, (), task_id="cancellable-task")
|
||||
|
||||
# Wait a bit, then cancel
|
||||
await asyncio.sleep(0.01)
|
||||
cancelled = await task_manager.cancel_task(task_id)
|
||||
assert cancelled is True
|
||||
|
||||
# Check task status
|
||||
await asyncio.sleep(0.01)
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
assert status["status"] == "cancelled"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_not_found(self, task_manager):
|
||||
"""Test getting status of non-existent task"""
|
||||
status = await task_manager.get_task_status("non-existent-task")
|
||||
assert status["error"] == "Task not found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_non_existent_task(self, task_manager):
|
||||
"""Test cancelling non-existent task"""
|
||||
cancelled = await task_manager.cancel_task("non-existent-task")
|
||||
assert cancelled is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_progress_callback_execution(self, task_manager):
|
||||
"""Test that progress callback is properly executed"""
|
||||
progress_updates = []
|
||||
|
||||
async def mock_progress_callback(task_id: str, update: dict[str, Any]):
|
||||
progress_updates.append((task_id, update))
|
||||
|
||||
async def simple_task():
|
||||
await asyncio.sleep(0.01)
|
||||
return "completed"
|
||||
|
||||
task_id = await task_manager.submit_task(
|
||||
simple_task, (), task_id="progress-test-task", progress_callback=mock_progress_callback
|
||||
)
|
||||
|
||||
# Wait for completion
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Should have at least one progress update (completion)
|
||||
assert len(progress_updates) >= 1
|
||||
|
||||
# Check that task_id matches
|
||||
assert all(update[0] == task_id for update in progress_updates)
|
||||
|
||||
# Check for completion update
|
||||
completion_updates = [
|
||||
update for update in progress_updates if update[1].get("status") == "complete"
|
||||
]
|
||||
assert len(completion_updates) >= 1
|
||||
assert completion_updates[0][1]["percentage"] == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_progress_callback_error_handling(self, task_manager):
|
||||
"""Test that task continues even if progress callback fails"""
|
||||
|
||||
async def failing_progress_callback(task_id: str, update: dict[str, Any]):
|
||||
raise Exception("Progress callback failed")
|
||||
|
||||
async def simple_task():
|
||||
await asyncio.sleep(0.01)
|
||||
return "Task completed despite callback failure"
|
||||
|
||||
task_id = await task_manager.submit_task(
|
||||
simple_task, (), progress_callback=failing_progress_callback
|
||||
)
|
||||
|
||||
# Wait for completion
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Task should still complete successfully
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
assert status["status"] == "complete"
|
||||
assert status["result"] == "Task completed despite callback failure"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_metadata_tracking(self, task_manager):
|
||||
"""Test that task metadata is properly tracked"""
|
||||
|
||||
async def simple_task():
|
||||
await asyncio.sleep(0.01)
|
||||
return "result"
|
||||
|
||||
task_id = await task_manager.submit_task(simple_task, (), task_id="metadata-test")
|
||||
|
||||
# Check initial metadata
|
||||
initial_status = await task_manager.get_task_status(task_id)
|
||||
assert initial_status["status"] == "running"
|
||||
assert "created_at" in initial_status
|
||||
assert initial_status["progress"] == 0
|
||||
|
||||
# Wait for completion
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Check final metadata
|
||||
final_status = await task_manager.get_task_status(task_id)
|
||||
assert final_status["status"] == "complete"
|
||||
assert final_status["progress"] == 100
|
||||
assert final_status["result"] == "result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_active_tasks(self, task_manager):
|
||||
"""Test cleanup cancels active tasks"""
|
||||
|
||||
async def long_running_task():
|
||||
try:
|
||||
await asyncio.sleep(1.0)
|
||||
return "Should not complete"
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
# Submit multiple long-running tasks
|
||||
task_ids = []
|
||||
for i in range(3):
|
||||
task_id = await task_manager.submit_task(
|
||||
long_running_task, (), task_id=f"cleanup-test-{i}"
|
||||
)
|
||||
task_ids.append(task_id)
|
||||
|
||||
# Verify tasks are active
|
||||
await asyncio.sleep(0.01)
|
||||
assert len(task_manager.active_tasks) == 3
|
||||
|
||||
# Cleanup
|
||||
await task_manager.cleanup()
|
||||
|
||||
# Verify all tasks were cancelled and cleaned up
|
||||
assert len(task_manager.active_tasks) == 0
|
||||
assert len(task_manager.task_metadata) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completed_task_status_after_removal(self, task_manager):
|
||||
"""Test getting status of completed task after it's removed from active_tasks"""
|
||||
|
||||
async def quick_task():
|
||||
return "quick result"
|
||||
|
||||
task_id = await task_manager.submit_task(quick_task, (), task_id="quick-test")
|
||||
|
||||
# Wait for completion and removal from active_tasks
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Should still be able to get status from metadata
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
assert status["status"] == "complete"
|
||||
assert status["result"] == "quick result"
|
||||
|
||||
def test_set_main_loop_deprecated(self, task_manager):
|
||||
"""Test that set_main_loop is deprecated but doesn't break"""
|
||||
# Should not raise an exception but may log a warning
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
task_manager.set_main_loop(loop)
|
||||
loop.close()
|
||||
|
||||
|
||||
class TestGlobalTaskManager:
|
||||
"""Test the global task manager functions"""
|
||||
|
||||
def test_get_task_manager_singleton(self):
|
||||
"""Test that get_task_manager returns singleton"""
|
||||
manager1 = get_task_manager()
|
||||
manager2 = get_task_manager()
|
||||
assert manager1 is manager2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_task_manager(self):
|
||||
"""Test cleanup of global task manager"""
|
||||
# Get the global manager
|
||||
manager = get_task_manager()
|
||||
assert manager is not None
|
||||
|
||||
# Add a task to make it interesting
|
||||
async def test_task():
|
||||
return "test"
|
||||
|
||||
task_id = await manager.submit_task(test_task, ())
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Cleanup
|
||||
await cleanup_task_manager()
|
||||
|
||||
# Verify it was cleaned up - getting a new one should be different
|
||||
new_manager = get_task_manager()
|
||||
assert new_manager is not manager
|
||||
|
||||
|
||||
class TestAsyncTaskPatterns:
|
||||
"""Test various async task patterns and edge cases"""
|
||||
|
||||
@pytest.fixture
|
||||
def task_manager(self):
|
||||
return BackgroundTaskManager(max_concurrent_tasks=3)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_async_calls(self, task_manager):
|
||||
"""Test tasks that make nested async calls"""
|
||||
|
||||
async def nested_task():
|
||||
async def inner_task():
|
||||
await asyncio.sleep(0.01)
|
||||
return "inner result"
|
||||
|
||||
result = await inner_task()
|
||||
return f"outer: {result}"
|
||||
|
||||
task_id = await task_manager.submit_task(nested_task, ())
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
assert status["status"] == "complete"
|
||||
assert status["result"] == "outer: inner result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_with_async_context_manager(self, task_manager):
|
||||
"""Test tasks that use async context managers"""
|
||||
|
||||
class AsyncResource:
|
||||
def __init__(self):
|
||||
self.entered = False
|
||||
self.exited = False
|
||||
|
||||
async def __aenter__(self):
|
||||
await asyncio.sleep(0.001)
|
||||
self.entered = True
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await asyncio.sleep(0.001)
|
||||
self.exited = True
|
||||
|
||||
resource = AsyncResource()
|
||||
|
||||
async def context_manager_task():
|
||||
async with resource:
|
||||
await asyncio.sleep(0.01)
|
||||
return "context manager used"
|
||||
|
||||
task_id = await task_manager.submit_task(context_manager_task, ())
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
assert status["status"] == "complete"
|
||||
assert status["result"] == "context manager used"
|
||||
assert resource.entered
|
||||
assert resource.exited
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cancellation_propagation(self, task_manager):
|
||||
"""Test that cancellation properly propagates through nested calls"""
|
||||
cancelled_flags = []
|
||||
|
||||
async def cancellable_inner():
|
||||
try:
|
||||
await asyncio.sleep(1.0)
|
||||
return "should not complete"
|
||||
except asyncio.CancelledError:
|
||||
cancelled_flags.append("inner")
|
||||
raise
|
||||
|
||||
async def cancellable_outer():
|
||||
try:
|
||||
result = await cancellable_inner()
|
||||
return f"outer: {result}"
|
||||
except asyncio.CancelledError:
|
||||
cancelled_flags.append("outer")
|
||||
raise
|
||||
|
||||
task_id = await task_manager.submit_task(cancellable_outer, ())
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Cancel the task
|
||||
cancelled = await task_manager.cancel_task(task_id)
|
||||
assert cancelled
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Both inner and outer should have been cancelled
|
||||
assert "inner" in cancelled_flags
|
||||
assert "outer" in cancelled_flags
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_concurrency_stress_test(self, task_manager):
|
||||
"""Stress test with many concurrent tasks"""
|
||||
|
||||
async def stress_task(task_num: int):
|
||||
await asyncio.sleep(0.001 * (task_num % 10)) # Vary sleep time
|
||||
return f"stress-{task_num}"
|
||||
|
||||
# Submit many tasks
|
||||
task_ids = []
|
||||
num_tasks = 20
|
||||
|
||||
for i in range(num_tasks):
|
||||
task_id = await task_manager.submit_task(stress_task, (i,), task_id=f"stress-{i}")
|
||||
task_ids.append(task_id)
|
||||
|
||||
# Wait for all to complete
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Verify all completed successfully
|
||||
for i, task_id in enumerate(task_ids):
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
assert status["status"] == "complete"
|
||||
assert status["result"] == f"stress-{i}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_execution_order_with_semaphore(self, task_manager):
|
||||
"""Test that semaphore properly controls execution order"""
|
||||
# Use manager with limit of 2
|
||||
limited_manager = BackgroundTaskManager(max_concurrent_tasks=2)
|
||||
execution_order = []
|
||||
|
||||
async def ordered_task(task_id: int):
|
||||
execution_order.append(f"start-{task_id}")
|
||||
await asyncio.sleep(0.02)
|
||||
execution_order.append(f"end-{task_id}")
|
||||
return task_id
|
||||
|
||||
# Submit 4 tasks
|
||||
task_ids = []
|
||||
for i in range(4):
|
||||
task_id = await limited_manager.submit_task(ordered_task, (i,), task_id=f"order-{i}")
|
||||
task_ids.append(task_id)
|
||||
|
||||
# Wait for completion
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Verify execution pattern - should see at most 2 concurrent executions
|
||||
starts_before_ends = 0
|
||||
for i, event in enumerate(execution_order):
|
||||
if event.startswith("start-"):
|
||||
# Count how many starts we've seen before the first end
|
||||
starts_seen = sum(1 for e in execution_order[: i + 1] if e.startswith("start-"))
|
||||
ends_seen = sum(1 for e in execution_order[: i + 1] if e.startswith("end-"))
|
||||
concurrent = starts_seen - ends_seen
|
||||
assert concurrent <= 2 # Should never exceed semaphore limit
|
||||
|
||||
await limited_manager.cleanup()
|
||||
414
python/tests/test_async_credential_service.py
Normal file
414
python/tests/test_async_credential_service.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""
|
||||
Comprehensive Tests for Async Credential Service
|
||||
|
||||
Tests the credential service async functions after sync function removal.
|
||||
Covers credential storage, retrieval, encryption/decryption, and caching.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.server.services.credential_service import (
|
||||
credential_service,
|
||||
get_credential,
|
||||
initialize_credentials,
|
||||
set_credential,
|
||||
)
|
||||
|
||||
|
||||
class TestAsyncCredentialService:
|
||||
"""Test suite for async credential service functions"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_credential_service(self):
|
||||
"""Setup clean credential service for each test"""
|
||||
# Clear cache and reset state
|
||||
credential_service._cache.clear()
|
||||
credential_service._cache_initialized = False
|
||||
yield
|
||||
# Cleanup after test
|
||||
credential_service._cache.clear()
|
||||
credential_service._cache_initialized = False
|
||||
|
||||
@pytest.fixture
|
||||
def mock_supabase_client(self):
|
||||
"""Mock Supabase client"""
|
||||
mock_client = MagicMock()
|
||||
mock_table = MagicMock()
|
||||
mock_client.table.return_value = mock_table
|
||||
return mock_client, mock_table
|
||||
|
||||
@pytest.fixture
|
||||
def sample_credentials_data(self):
|
||||
"""Sample credentials data from database"""
|
||||
return [
|
||||
{
|
||||
"id": 1,
|
||||
"key": "OPENAI_API_KEY",
|
||||
"encrypted_value": "encrypted_openai_key",
|
||||
"value": None,
|
||||
"is_encrypted": True,
|
||||
"category": "api_keys",
|
||||
"description": "OpenAI API key for LLM access",
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"key": "MODEL_CHOICE",
|
||||
"value": "gpt-4.1-nano",
|
||||
"encrypted_value": None,
|
||||
"is_encrypted": False,
|
||||
"category": "rag_strategy",
|
||||
"description": "Default model choice",
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"key": "MAX_TOKENS",
|
||||
"value": "1000",
|
||||
"encrypted_value": None,
|
||||
"is_encrypted": False,
|
||||
"category": "rag_strategy",
|
||||
"description": "Maximum tokens per request",
|
||||
},
|
||||
]
|
||||
|
||||
def test_deprecated_functions_removed(self):
|
||||
"""Test that deprecated sync functions are no longer available"""
|
||||
import src.server.services.credential_service as cred_module
|
||||
|
||||
# The sync function should no longer exist
|
||||
assert not hasattr(cred_module, "get_credential_sync")
|
||||
|
||||
# The async versions should be the primary functions
|
||||
assert hasattr(cred_module, "get_credential")
|
||||
assert hasattr(cred_module, "set_credential")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credential_from_cache(self):
|
||||
"""Test getting credential from initialized cache"""
|
||||
# Setup cache
|
||||
credential_service._cache = {"TEST_KEY": "test_value", "NUMERIC_KEY": "123"}
|
||||
credential_service._cache_initialized = True
|
||||
|
||||
result = await get_credential("TEST_KEY", "default")
|
||||
assert result == "test_value"
|
||||
|
||||
result = await get_credential("NUMERIC_KEY", "default")
|
||||
assert result == "123"
|
||||
|
||||
result = await get_credential("MISSING_KEY", "default_value")
|
||||
assert result == "default_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credential_encrypted_value(self):
|
||||
"""Test getting encrypted credential"""
|
||||
# Setup cache with encrypted value
|
||||
encrypted_data = {"encrypted_value": "encrypted_test_value", "is_encrypted": True}
|
||||
credential_service._cache = {"SECRET_KEY": encrypted_data}
|
||||
credential_service._cache_initialized = True
|
||||
|
||||
with patch.object(credential_service, "_decrypt_value", return_value="decrypted_value"):
|
||||
result = await get_credential("SECRET_KEY", "default")
|
||||
assert result == "decrypted_value"
|
||||
credential_service._decrypt_value.assert_called_once_with("encrypted_test_value")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credential_cache_not_initialized(self, mock_supabase_client):
|
||||
"""Test getting credential when cache is not initialized"""
|
||||
mock_client, mock_table = mock_supabase_client
|
||||
|
||||
# Mock database response for load_all_credentials (gets ALL settings)
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [
|
||||
{
|
||||
"key": "TEST_KEY",
|
||||
"value": "db_value",
|
||||
"encrypted_value": None,
|
||||
"is_encrypted": False,
|
||||
"category": "test",
|
||||
"description": "Test key",
|
||||
}
|
||||
]
|
||||
mock_table.select().execute.return_value = mock_response
|
||||
|
||||
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
|
||||
result = await credential_service.get_credential("TEST_KEY", "default")
|
||||
assert result == "db_value"
|
||||
|
||||
# Should have called database to load all credentials
|
||||
mock_table.select.assert_called_with("*")
|
||||
# Should have called execute on the query
|
||||
assert mock_table.select().execute.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credential_not_found_in_db(self, mock_supabase_client):
|
||||
"""Test getting credential that doesn't exist in database"""
|
||||
mock_client, mock_table = mock_supabase_client
|
||||
|
||||
# Mock empty database response
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = []
|
||||
mock_table.select().eq().execute.return_value = mock_response
|
||||
|
||||
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
|
||||
result = await credential_service.get_credential("MISSING_KEY", "default_value")
|
||||
assert result == "default_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_credential_new(self, mock_supabase_client):
|
||||
"""Test setting a new credential"""
|
||||
mock_client, mock_table = mock_supabase_client
|
||||
|
||||
# Mock successful insert
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [{"id": 1, "key": "NEW_KEY", "value": "new_value"}]
|
||||
mock_table.insert().execute.return_value = mock_response
|
||||
|
||||
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
|
||||
result = await set_credential("NEW_KEY", "new_value", is_encrypted=False)
|
||||
assert result is True
|
||||
|
||||
# Should have attempted insert
|
||||
mock_table.insert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_credential_encrypted(self, mock_supabase_client):
|
||||
"""Test setting an encrypted credential"""
|
||||
mock_client, mock_table = mock_supabase_client
|
||||
|
||||
# Mock successful insert
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [{"id": 1, "key": "SECRET_KEY"}]
|
||||
mock_table.insert().execute.return_value = mock_response
|
||||
|
||||
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
|
||||
with patch.object(credential_service, "_encrypt_value", return_value="encrypted_value"):
|
||||
result = await set_credential("SECRET_KEY", "secret_value", is_encrypted=True)
|
||||
assert result is True
|
||||
|
||||
# Should have encrypted the value
|
||||
credential_service._encrypt_value.assert_called_once_with("secret_value")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_all_credentials(self, mock_supabase_client, sample_credentials_data):
|
||||
"""Test loading all credentials from database"""
|
||||
mock_client, mock_table = mock_supabase_client
|
||||
|
||||
# Mock database response
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = sample_credentials_data
|
||||
mock_table.select().execute.return_value = mock_response
|
||||
|
||||
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
|
||||
result = await credential_service.load_all_credentials()
|
||||
|
||||
# Should have loaded credentials into cache
|
||||
assert credential_service._cache_initialized is True
|
||||
assert "OPENAI_API_KEY" in credential_service._cache
|
||||
assert "MODEL_CHOICE" in credential_service._cache
|
||||
assert "MAX_TOKENS" in credential_service._cache
|
||||
|
||||
# Should have stored encrypted values as dict objects (not decrypted yet)
|
||||
openai_key_cache = credential_service._cache["OPENAI_API_KEY"]
|
||||
assert isinstance(openai_key_cache, dict)
|
||||
assert openai_key_cache["encrypted_value"] == "encrypted_openai_key"
|
||||
assert openai_key_cache["is_encrypted"] is True
|
||||
|
||||
# Plain text values should be stored directly
|
||||
assert credential_service._cache["MODEL_CHOICE"] == "gpt-4.1-nano"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credentials_by_category(self, mock_supabase_client):
|
||||
"""Test getting credentials filtered by category"""
|
||||
mock_client, mock_table = mock_supabase_client
|
||||
|
||||
# Mock database response for rag_strategy category
|
||||
rag_data = [
|
||||
{
|
||||
"key": "MODEL_CHOICE",
|
||||
"value": "gpt-4.1-nano",
|
||||
"is_encrypted": False,
|
||||
"description": "Model choice",
|
||||
},
|
||||
{
|
||||
"key": "MAX_TOKENS",
|
||||
"value": "1000",
|
||||
"is_encrypted": False,
|
||||
"description": "Max tokens",
|
||||
},
|
||||
]
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = rag_data
|
||||
mock_table.select().eq().execute.return_value = mock_response
|
||||
|
||||
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
|
||||
result = await credential_service.get_credentials_by_category("rag_strategy")
|
||||
|
||||
# Should only return rag_strategy credentials
|
||||
assert "MODEL_CHOICE" in result
|
||||
assert "MAX_TOKENS" in result
|
||||
assert result["MODEL_CHOICE"] == "gpt-4.1-nano"
|
||||
assert result["MAX_TOKENS"] == "1000"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_provider_llm(self, mock_supabase_client):
|
||||
"""Test getting active LLM provider configuration"""
|
||||
mock_client, mock_table = mock_supabase_client
|
||||
|
||||
# Setup cache directly instead of mocking complex database responses
|
||||
credential_service._cache = {
|
||||
"LLM_PROVIDER": "openai",
|
||||
"MODEL_CHOICE": "gpt-4.1-nano",
|
||||
"OPENAI_API_KEY": {
|
||||
"encrypted_value": "encrypted_key",
|
||||
"is_encrypted": True,
|
||||
"category": "api_keys",
|
||||
"description": "API key",
|
||||
},
|
||||
}
|
||||
credential_service._cache_initialized = True
|
||||
|
||||
# Mock rag_strategy category response
|
||||
rag_response = MagicMock()
|
||||
rag_response.data = [
|
||||
{
|
||||
"key": "LLM_PROVIDER",
|
||||
"value": "openai",
|
||||
"is_encrypted": False,
|
||||
"description": "LLM provider",
|
||||
},
|
||||
{
|
||||
"key": "MODEL_CHOICE",
|
||||
"value": "gpt-4.1-nano",
|
||||
"is_encrypted": False,
|
||||
"description": "Model choice",
|
||||
},
|
||||
]
|
||||
mock_table.select().eq().execute.return_value = rag_response
|
||||
|
||||
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
|
||||
with patch.object(credential_service, "_decrypt_value", return_value="decrypted_key"):
|
||||
result = await credential_service.get_active_provider("llm")
|
||||
|
||||
assert result["provider"] == "openai"
|
||||
assert result["api_key"] == "decrypted_key"
|
||||
assert result["chat_model"] == "gpt-4.1-nano"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_provider_basic(self, mock_supabase_client):
|
||||
"""Test basic provider configuration retrieval"""
|
||||
mock_client, mock_table = mock_supabase_client
|
||||
|
||||
# Simple mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = []
|
||||
mock_table.select().eq().execute.return_value = mock_response
|
||||
|
||||
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
|
||||
result = await credential_service.get_active_provider("llm")
|
||||
# Should return default values when no settings found
|
||||
assert "provider" in result
|
||||
assert "api_key" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_credentials(self, mock_supabase_client, sample_credentials_data):
|
||||
"""Test initialize_credentials function"""
|
||||
mock_client, mock_table = mock_supabase_client
|
||||
|
||||
# Mock database response
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = sample_credentials_data
|
||||
mock_table.select().execute.return_value = mock_response
|
||||
|
||||
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
|
||||
with patch.object(credential_service, "_decrypt_value", return_value="decrypted_key"):
|
||||
with patch.dict(os.environ, {}, clear=True): # Clear environment
|
||||
await initialize_credentials()
|
||||
|
||||
# Should have loaded credentials
|
||||
assert credential_service._cache_initialized is True
|
||||
|
||||
# Should have set infrastructure env vars (like OPENAI_API_KEY)
|
||||
# Note: This tests the logic, actual env var setting depends on implementation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_database_failure(self, mock_supabase_client):
|
||||
"""Test error handling when database fails"""
|
||||
mock_client, mock_table = mock_supabase_client
|
||||
|
||||
# Mock database error
|
||||
mock_table.select().eq().execute.side_effect = Exception("Database connection failed")
|
||||
|
||||
with patch.object(credential_service, "_get_supabase_client", return_value=mock_client):
|
||||
result = await credential_service.get_credential("TEST_KEY", "default_value")
|
||||
assert result == "default_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encryption_decryption_error_handling(self):
|
||||
"""Test error handling for encryption/decryption failures"""
|
||||
# Setup cache with encrypted value that fails to decrypt
|
||||
encrypted_data = {"encrypted_value": "corrupted_encrypted_value", "is_encrypted": True}
|
||||
credential_service._cache = {"CORRUPTED_KEY": encrypted_data}
|
||||
credential_service._cache_initialized = True
|
||||
|
||||
with patch.object(
|
||||
credential_service, "_decrypt_value", side_effect=Exception("Decryption failed")
|
||||
):
|
||||
# Should fall back to default when decryption fails
|
||||
result = await credential_service.get_credential("CORRUPTED_KEY", "fallback_value")
|
||||
assert result == "fallback_value"
|
||||
|
||||
def test_direct_cache_access_fallback(self):
|
||||
"""Test direct cache access pattern used in converted sync functions"""
|
||||
# Setup cache
|
||||
credential_service._cache = {
|
||||
"MODEL_CHOICE": "gpt-4.1-nano",
|
||||
"OPENAI_API_KEY": {"encrypted_value": "encrypted_key", "is_encrypted": True},
|
||||
}
|
||||
credential_service._cache_initialized = True
|
||||
|
||||
# Test simple cache access
|
||||
if credential_service._cache_initialized and "MODEL_CHOICE" in credential_service._cache:
|
||||
result = credential_service._cache["MODEL_CHOICE"]
|
||||
assert result == "gpt-4.1-nano"
|
||||
|
||||
# Test encrypted value access
|
||||
if credential_service._cache_initialized and "OPENAI_API_KEY" in credential_service._cache:
|
||||
cached_key = credential_service._cache["OPENAI_API_KEY"]
|
||||
if isinstance(cached_key, dict) and cached_key.get("is_encrypted"):
|
||||
# Would need to call credential_service._decrypt_value(cached_key["encrypted_value"])
|
||||
assert cached_key["encrypted_value"] == "encrypted_key"
|
||||
assert cached_key["is_encrypted"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_access(self):
|
||||
"""Test concurrent access to credential service"""
|
||||
credential_service._cache = {"SHARED_KEY": "shared_value"}
|
||||
credential_service._cache_initialized = True
|
||||
|
||||
async def get_credential_task():
|
||||
return await get_credential("SHARED_KEY", "default")
|
||||
|
||||
# Run multiple concurrent requests
|
||||
tasks = [get_credential_task() for _ in range(10)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All should return the same value
|
||||
assert all(result == "shared_value" for result in results)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_persistence(self):
|
||||
"""Test that cache persists across calls"""
|
||||
credential_service._cache = {"PERSISTENT_KEY": "persistent_value"}
|
||||
credential_service._cache_initialized = True
|
||||
|
||||
# First call
|
||||
result1 = await get_credential("PERSISTENT_KEY", "default")
|
||||
assert result1 == "persistent_value"
|
||||
|
||||
# Second call should use same cache
|
||||
result2 = await get_credential("PERSISTENT_KEY", "default")
|
||||
assert result2 == "persistent_value"
|
||||
assert result1 == result2
|
||||
469
python/tests/test_async_embedding_service.py
Normal file
469
python/tests/test_async_embedding_service.py
Normal file
@@ -0,0 +1,469 @@
|
||||
"""
|
||||
Comprehensive Tests for Async Embedding Service
|
||||
|
||||
Tests all aspects of the async embedding service after sync function removal.
|
||||
Covers both success and error scenarios with thorough edge case testing.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from src.server.services.embeddings.embedding_exceptions import (
|
||||
EmbeddingAPIError,
|
||||
)
|
||||
from src.server.services.embeddings.embedding_service import (
|
||||
EmbeddingBatchResult,
|
||||
create_embedding,
|
||||
create_embeddings_batch,
|
||||
)
|
||||
|
||||
|
||||
class AsyncContextManager:
|
||||
"""Helper class for properly mocking async context managers"""
|
||||
|
||||
def __init__(self, return_value):
|
||||
self.return_value = return_value
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.return_value
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
|
||||
class TestAsyncEmbeddingService:
|
||||
"""Test suite for async embedding service functions"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_client(self):
|
||||
"""Mock LLM client for testing"""
|
||||
mock_client = MagicMock()
|
||||
mock_embeddings = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [
|
||||
MagicMock(embedding=[0.1, 0.2, 0.3] + [0.0] * 1533) # 1536 dimensions
|
||||
]
|
||||
mock_embeddings.create = AsyncMock(return_value=mock_response)
|
||||
mock_client.embeddings = mock_embeddings
|
||||
return mock_client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_threading_service(self):
|
||||
"""Mock threading service for testing"""
|
||||
mock_service = MagicMock()
|
||||
# Create a proper async context manager
|
||||
rate_limit_ctx = AsyncContextManager(None)
|
||||
mock_service.rate_limited_operation.return_value = rate_limit_ctx
|
||||
return mock_service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_embedding_success(self, mock_llm_client, mock_threading_service):
|
||||
"""Test successful single embedding creation"""
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_threading_service",
|
||||
return_value=mock_threading_service,
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_get_client:
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_embedding_model",
|
||||
return_value="text-embedding-3-small",
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.credential_service"
|
||||
) as mock_cred:
|
||||
# Mock credential service properly
|
||||
mock_cred.get_credentials_by_category = AsyncMock(
|
||||
return_value={"EMBEDDING_BATCH_SIZE": "10"}
|
||||
)
|
||||
|
||||
# Setup proper async context manager
|
||||
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
|
||||
|
||||
result = await create_embedding("test text")
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 1536
|
||||
assert result[0] == 0.1
|
||||
assert result[1] == 0.2
|
||||
assert result[2] == 0.3
|
||||
|
||||
# Verify API was called correctly
|
||||
mock_llm_client.embeddings.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_embedding_empty_text(self, mock_llm_client, mock_threading_service):
|
||||
"""Test embedding creation with empty text"""
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_threading_service",
|
||||
return_value=mock_threading_service,
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_get_client:
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_embedding_model",
|
||||
return_value="text-embedding-3-small",
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.credential_service"
|
||||
) as mock_cred:
|
||||
mock_cred.get_credentials_by_category = AsyncMock(
|
||||
return_value={"EMBEDDING_BATCH_SIZE": "10"}
|
||||
)
|
||||
|
||||
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
|
||||
|
||||
result = await create_embedding("")
|
||||
|
||||
# Should still work with empty text
|
||||
assert len(result) == 1536
|
||||
mock_llm_client.embeddings.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_embedding_api_error_raises_exception(self, mock_threading_service):
|
||||
"""Test embedding creation with API error - should raise exception"""
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_threading_service",
|
||||
return_value=mock_threading_service,
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_get_client:
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_embedding_model",
|
||||
return_value="text-embedding-3-small",
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.credential_service"
|
||||
) as mock_cred:
|
||||
mock_cred.get_credentials_by_category = AsyncMock(
|
||||
return_value={"EMBEDDING_BATCH_SIZE": "10"}
|
||||
)
|
||||
|
||||
# Setup client to raise an error
|
||||
mock_client = MagicMock()
|
||||
mock_client.embeddings.create = AsyncMock(
|
||||
side_effect=Exception("API Error")
|
||||
)
|
||||
mock_get_client.return_value = AsyncContextManager(mock_client)
|
||||
|
||||
# Should raise exception now instead of returning zero embeddings
|
||||
with pytest.raises(EmbeddingAPIError):
|
||||
await create_embedding("test text")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_embeddings_batch_success(self, mock_llm_client, mock_threading_service):
|
||||
"""Test successful batch embedding creation"""
|
||||
# Setup mock response for multiple embeddings
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [
|
||||
MagicMock(embedding=[0.1, 0.2, 0.3] + [0.0] * 1533),
|
||||
MagicMock(embedding=[0.4, 0.5, 0.6] + [0.0] * 1533),
|
||||
]
|
||||
mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_threading_service",
|
||||
return_value=mock_threading_service,
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_get_client:
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_embedding_model",
|
||||
return_value="text-embedding-3-small",
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.credential_service"
|
||||
) as mock_cred:
|
||||
mock_cred.get_credentials_by_category = AsyncMock(
|
||||
return_value={"EMBEDDING_BATCH_SIZE": "10"}
|
||||
)
|
||||
|
||||
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
|
||||
|
||||
result = await create_embeddings_batch(["text1", "text2"])
|
||||
|
||||
# Verify the result is EmbeddingBatchResult
|
||||
assert isinstance(result, EmbeddingBatchResult)
|
||||
assert result.success_count == 2
|
||||
assert result.failure_count == 0
|
||||
assert len(result.embeddings) == 2
|
||||
assert len(result.embeddings[0]) == 1536
|
||||
assert len(result.embeddings[1]) == 1536
|
||||
assert result.embeddings[0][0] == 0.1
|
||||
assert result.embeddings[1][0] == 0.4
|
||||
|
||||
mock_llm_client.embeddings.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_embeddings_batch_empty_list(self):
|
||||
"""Test batch embedding with empty list"""
|
||||
result = await create_embeddings_batch([])
|
||||
assert isinstance(result, EmbeddingBatchResult)
|
||||
assert result.success_count == 0
|
||||
assert result.failure_count == 0
|
||||
assert result.embeddings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_embeddings_batch_rate_limit_error(self, mock_threading_service):
|
||||
"""Test batch embedding with rate limit error"""
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_threading_service",
|
||||
return_value=mock_threading_service,
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_get_client:
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_embedding_model",
|
||||
return_value="text-embedding-3-small",
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.credential_service"
|
||||
) as mock_cred:
|
||||
mock_cred.get_credentials_by_category = AsyncMock(
|
||||
return_value={"EMBEDDING_BATCH_SIZE": "10"}
|
||||
)
|
||||
|
||||
# Setup client to raise rate limit error (not quota)
|
||||
mock_client = MagicMock()
|
||||
# Create a proper RateLimitError with required attributes
|
||||
error = openai.RateLimitError(
|
||||
"Rate limit exceeded",
|
||||
response=MagicMock(),
|
||||
body={"error": {"message": "Rate limit exceeded"}},
|
||||
)
|
||||
mock_client.embeddings.create = AsyncMock(side_effect=error)
|
||||
mock_get_client.return_value = AsyncContextManager(mock_client)
|
||||
|
||||
result = await create_embeddings_batch(["text1", "text2"])
|
||||
|
||||
# Should return result with failures, not zero embeddings
|
||||
assert isinstance(result, EmbeddingBatchResult)
|
||||
assert result.success_count == 0
|
||||
assert result.failure_count == 2
|
||||
assert len(result.embeddings) == 0
|
||||
assert len(result.failed_items) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_embeddings_batch_quota_exhausted(self, mock_threading_service):
|
||||
"""Test batch embedding with quota exhausted error"""
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_threading_service",
|
||||
return_value=mock_threading_service,
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_get_client:
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_embedding_model",
|
||||
return_value="text-embedding-3-small",
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.credential_service"
|
||||
) as mock_cred:
|
||||
mock_cred.get_credentials_by_category = AsyncMock(
|
||||
return_value={"EMBEDDING_BATCH_SIZE": "10"}
|
||||
)
|
||||
|
||||
# Setup client to raise quota exhausted error
|
||||
mock_client = MagicMock()
|
||||
error = openai.RateLimitError(
|
||||
"insufficient_quota",
|
||||
response=MagicMock(),
|
||||
body={"error": {"message": "insufficient_quota"}},
|
||||
)
|
||||
mock_client.embeddings.create = AsyncMock(side_effect=error)
|
||||
mock_get_client.return_value = AsyncContextManager(mock_client)
|
||||
|
||||
# Mock progress callback
|
||||
progress_callback = AsyncMock()
|
||||
|
||||
result = await create_embeddings_batch(
|
||||
["text1", "text2"], progress_callback=progress_callback
|
||||
)
|
||||
|
||||
# Should return result with failures, not zero embeddings
|
||||
assert isinstance(result, EmbeddingBatchResult)
|
||||
assert result.success_count == 0
|
||||
assert result.failure_count == 2
|
||||
assert len(result.embeddings) == 0
|
||||
assert len(result.failed_items) == 2
|
||||
# Verify quota exhausted is in error messages
|
||||
assert any("quota" in item["error"].lower() for item in result.failed_items)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_embeddings_batch_with_websocket_progress(
|
||||
self, mock_llm_client, mock_threading_service
|
||||
):
|
||||
"""Test batch embedding with WebSocket progress updates"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock(embedding=[0.1] * 1536)]
|
||||
mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_threading_service",
|
||||
return_value=mock_threading_service,
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_get_client:
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_embedding_model",
|
||||
return_value="text-embedding-3-small",
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.credential_service"
|
||||
) as mock_cred:
|
||||
mock_cred.get_credentials_by_category = AsyncMock(
|
||||
return_value={"EMBEDDING_BATCH_SIZE": "1"}
|
||||
)
|
||||
|
||||
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
|
||||
|
||||
# Mock WebSocket
|
||||
mock_websocket = MagicMock()
|
||||
mock_websocket.send_json = AsyncMock()
|
||||
|
||||
result = await create_embeddings_batch(["text1"], websocket=mock_websocket)
|
||||
|
||||
# Verify result is correct
|
||||
assert isinstance(result, EmbeddingBatchResult)
|
||||
assert result.success_count == 1
|
||||
|
||||
# Verify WebSocket was called
|
||||
mock_websocket.send_json.assert_called()
|
||||
call_args = mock_websocket.send_json.call_args[0][0]
|
||||
assert call_args["type"] == "embedding_progress"
|
||||
assert "processed" in call_args
|
||||
assert "total" in call_args
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_embeddings_batch_with_progress_callback(
|
||||
self, mock_llm_client, mock_threading_service
|
||||
):
|
||||
"""Test batch embedding with progress callback"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock(embedding=[0.1] * 1536)]
|
||||
mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_threading_service",
|
||||
return_value=mock_threading_service,
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_get_client:
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_embedding_model",
|
||||
return_value="text-embedding-3-small",
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.credential_service"
|
||||
) as mock_cred:
|
||||
mock_cred.get_credentials_by_category = AsyncMock(
|
||||
return_value={"EMBEDDING_BATCH_SIZE": "1"}
|
||||
)
|
||||
|
||||
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
|
||||
|
||||
# Mock progress callback
|
||||
progress_callback = AsyncMock()
|
||||
|
||||
result = await create_embeddings_batch(
|
||||
["text1"], progress_callback=progress_callback
|
||||
)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, EmbeddingBatchResult)
|
||||
assert result.success_count == 1
|
||||
|
||||
# Verify progress callback was called
|
||||
progress_callback.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_override(self, mock_llm_client, mock_threading_service):
|
||||
"""Test that provider override parameter is properly passed through"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock(embedding=[0.1] * 1536)]
|
||||
mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_threading_service",
|
||||
return_value=mock_threading_service,
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_get_client:
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_embedding_model"
|
||||
) as mock_get_model:
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.credential_service"
|
||||
) as mock_cred:
|
||||
mock_cred.get_credentials_by_category = AsyncMock(
|
||||
return_value={"EMBEDDING_BATCH_SIZE": "10"}
|
||||
)
|
||||
mock_get_model.return_value = "custom-model"
|
||||
|
||||
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
|
||||
|
||||
await create_embedding("test text", provider="custom-provider")
|
||||
|
||||
# Verify provider was passed to get_llm_client
|
||||
mock_get_client.assert_called_with(
|
||||
provider="custom-provider", use_embedding_provider=True
|
||||
)
|
||||
mock_get_model.assert_called_with(provider="custom-provider")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_embeddings_batch_large_batch_splitting(
|
||||
self, mock_llm_client, mock_threading_service
|
||||
):
|
||||
"""Test that large batches are properly split according to batch size settings"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [
|
||||
MagicMock(embedding=[0.1] * 1536) for _ in range(2)
|
||||
] # 2 embeddings per call
|
||||
mock_llm_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_threading_service",
|
||||
return_value=mock_threading_service,
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_get_client:
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_embedding_model",
|
||||
return_value="text-embedding-3-small",
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.credential_service"
|
||||
) as mock_cred:
|
||||
# Set batch size to 2
|
||||
mock_cred.get_credentials_by_category = AsyncMock(
|
||||
return_value={"EMBEDDING_BATCH_SIZE": "2"}
|
||||
)
|
||||
|
||||
mock_get_client.return_value = AsyncContextManager(mock_llm_client)
|
||||
|
||||
# Test with 5 texts (should require 3 API calls: 2+2+1)
|
||||
texts = ["text1", "text2", "text3", "text4", "text5"]
|
||||
result = await create_embeddings_batch(texts)
|
||||
|
||||
# Should have made 3 API calls due to batching
|
||||
assert mock_llm_client.embeddings.create.call_count == 3
|
||||
|
||||
# Result should be EmbeddingBatchResult
|
||||
assert isinstance(result, EmbeddingBatchResult)
|
||||
# Should have 5 embeddings total (for 5 input texts)
|
||||
# Even though the mock returns 2 embeddings per call, the service only keeps as many as were requested per batch
|
||||
assert result.success_count == 5
|
||||
assert len(result.embeddings) == 5
|
||||
assert result.texts_processed == texts
|
||||
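For reference, a minimal sketch of the EmbeddingBatchResult interface that the assertions above exercise, inferred only from the fields and methods these tests touch (the real class in embedding_service.py may store more detail per failed item; this stand-in is hypothetical):

from dataclasses import dataclass, field

@dataclass
class EmbeddingBatchResultSketch:
    # Hypothetical stand-in for the real EmbeddingBatchResult; names mirror the test assertions.
    embeddings: list[list[float]] = field(default_factory=list)
    texts_processed: list[str] = field(default_factory=list)
    failed_items: list[dict] = field(default_factory=list)
    success_count: int = 0
    failure_count: int = 0

    @property
    def has_failures(self) -> bool:
        return self.failure_count > 0

    @property
    def total_requested(self) -> int:
        return self.success_count + self.failure_count

    def add_success(self, embedding: list[float], text: str) -> None:
        # Record one successful embedding alongside the text it was produced for.
        self.embeddings.append(embedding)
        self.texts_processed.append(text)
        self.success_count += 1

    def add_failure(self, text: str, error: Exception, batch_index: int) -> None:
        # Record a failure; no zero vector is ever appended to embeddings.
        self.failed_items.append({
            "error": str(error),
            "error_type": type(error).__name__,
            "batch_index": batch_index,
        })
        self.failure_count += 1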
474
python/tests/test_async_llm_provider_service.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""
|
||||
Comprehensive Tests for Async LLM Provider Service
|
||||
|
||||
Tests all aspects of the async LLM provider service after sync function removal.
|
||||
Covers different providers (OpenAI, Ollama, Google) and error scenarios.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.server.services.llm_provider_service import (
|
||||
_get_cached_settings,
|
||||
_set_cached_settings,
|
||||
get_embedding_model,
|
||||
get_llm_client,
|
||||
)
|
||||
|
||||
|
||||
class AsyncContextManager:
|
||||
"""Helper class for properly mocking async context managers"""
|
||||
|
||||
def __init__(self, return_value):
|
||||
self.return_value = return_value
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.return_value
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
|
||||
class TestAsyncLLMProviderService:
|
||||
"""Test suite for async LLM provider service functions"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache(self):
|
||||
"""Clear cache before each test"""
|
||||
import src.server.services.llm_provider_service as llm_module
|
||||
|
||||
llm_module._settings_cache.clear()
|
||||
yield
|
||||
llm_module._settings_cache.clear()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credential_service(self):
|
||||
"""Mock credential service"""
|
||||
mock_service = MagicMock()
|
||||
mock_service.get_active_provider = AsyncMock()
|
||||
mock_service.get_credentials_by_category = AsyncMock()
|
||||
mock_service._get_provider_api_key = AsyncMock()
|
||||
mock_service._get_provider_base_url = MagicMock()
|
||||
return mock_service
|
||||
|
||||
@pytest.fixture
|
||||
def openai_provider_config(self):
|
||||
"""Standard OpenAI provider config"""
|
||||
return {
|
||||
"provider": "openai",
|
||||
"api_key": "test-openai-key",
|
||||
"base_url": None,
|
||||
"chat_model": "gpt-4.1-nano",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def ollama_provider_config(self):
|
||||
"""Standard Ollama provider config"""
|
||||
return {
|
||||
"provider": "ollama",
|
||||
"api_key": "ollama",
|
||||
"base_url": "http://localhost:11434/v1",
|
||||
"chat_model": "llama2",
|
||||
"embedding_model": "nomic-embed-text",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def google_provider_config(self):
|
||||
"""Standard Google provider config"""
|
||||
return {
|
||||
"provider": "google",
|
||||
"api_key": "test-google-key",
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"chat_model": "gemini-pro",
|
||||
"embedding_model": "text-embedding-004",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_llm_client_openai_success(
|
||||
self, mock_credential_service, openai_provider_config
|
||||
):
|
||||
"""Test successful OpenAI client creation"""
|
||||
mock_credential_service.get_active_provider.return_value = openai_provider_config
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
async with get_llm_client() as client:
|
||||
assert client == mock_client
|
||||
mock_openai.assert_called_once_with(api_key="test-openai-key")
|
||||
|
||||
# Verify provider config was fetched
|
||||
mock_credential_service.get_active_provider.assert_called_once_with("llm")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_llm_client_ollama_success(
|
||||
self, mock_credential_service, ollama_provider_config
|
||||
):
|
||||
"""Test successful Ollama client creation"""
|
||||
mock_credential_service.get_active_provider.return_value = ollama_provider_config
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
async with get_llm_client() as client:
|
||||
assert client == mock_client
|
||||
mock_openai.assert_called_once_with(
|
||||
api_key="ollama", base_url="http://localhost:11434/v1"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_llm_client_google_success(
|
||||
self, mock_credential_service, google_provider_config
|
||||
):
|
||||
"""Test successful Google client creation"""
|
||||
mock_credential_service.get_active_provider.return_value = google_provider_config
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
async with get_llm_client() as client:
|
||||
assert client == mock_client
|
||||
mock_openai.assert_called_once_with(
|
||||
api_key="test-google-key",
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_llm_client_with_provider_override(self, mock_credential_service):
|
||||
"""Test client creation with explicit provider override (OpenAI)"""
|
||||
mock_credential_service._get_provider_api_key.return_value = "override-key"
|
||||
mock_credential_service.get_credentials_by_category.return_value = {"LLM_BASE_URL": ""}
|
||||
mock_credential_service._get_provider_base_url.return_value = None
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
async with get_llm_client(provider="openai") as client:
|
||||
assert client == mock_client
|
||||
mock_openai.assert_called_once_with(api_key="override-key")
|
||||
|
||||
# Verify explicit provider API key was requested
|
||||
mock_credential_service._get_provider_api_key.assert_called_once_with("openai")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_llm_client_use_embedding_provider(self, mock_credential_service):
|
||||
"""Test client creation with embedding provider preference"""
|
||||
embedding_config = {
|
||||
"provider": "openai",
|
||||
"api_key": "embedding-key",
|
||||
"base_url": None,
|
||||
"chat_model": "gpt-4",
|
||||
"embedding_model": "text-embedding-3-large",
|
||||
}
|
||||
mock_credential_service.get_active_provider.return_value = embedding_config
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
async with get_llm_client(use_embedding_provider=True) as client:
|
||||
assert client == mock_client
|
||||
mock_openai.assert_called_once_with(api_key="embedding-key")
|
||||
|
||||
# Verify embedding provider was requested
|
||||
mock_credential_service.get_active_provider.assert_called_once_with("embedding")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_llm_client_missing_openai_key(self, mock_credential_service):
|
||||
"""Test error handling when OpenAI API key is missing"""
|
||||
config_without_key = {
|
||||
"provider": "openai",
|
||||
"api_key": None,
|
||||
"base_url": None,
|
||||
"chat_model": "gpt-4",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
}
|
||||
mock_credential_service.get_active_provider.return_value = config_without_key
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
with pytest.raises(ValueError, match="OpenAI API key not found"):
|
||||
async with get_llm_client():
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_llm_client_missing_google_key(self, mock_credential_service):
|
||||
"""Test error handling when Google API key is missing"""
|
||||
config_without_key = {
|
||||
"provider": "google",
|
||||
"api_key": None,
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"chat_model": "gemini-pro",
|
||||
"embedding_model": "text-embedding-004",
|
||||
}
|
||||
mock_credential_service.get_active_provider.return_value = config_without_key
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
with pytest.raises(ValueError, match="Google API key not found"):
|
||||
async with get_llm_client():
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_llm_client_unsupported_provider_error(self, mock_credential_service):
|
||||
"""Test error when unsupported provider is configured"""
|
||||
unsupported_config = {
|
||||
"provider": "unsupported",
|
||||
"api_key": "some-key",
|
||||
"base_url": None,
|
||||
"chat_model": "some-model",
|
||||
"embedding_model": "",
|
||||
}
|
||||
mock_credential_service.get_active_provider.return_value = unsupported_config
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
with pytest.raises(ValueError, match="Unsupported LLM provider: unsupported"):
|
||||
async with get_llm_client():
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_llm_client_with_unsupported_provider_override(self, mock_credential_service):
|
||||
"""Test error when unsupported provider is explicitly requested"""
|
||||
mock_credential_service._get_provider_api_key.return_value = "some-key"
|
||||
mock_credential_service.get_credentials_by_category.return_value = {}
|
||||
mock_credential_service._get_provider_base_url.return_value = None
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
with pytest.raises(ValueError, match="Unsupported LLM provider: custom-unsupported"):
|
||||
async with get_llm_client(provider="custom-unsupported"):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_embedding_model_openai_success(
|
||||
self, mock_credential_service, openai_provider_config
|
||||
):
|
||||
"""Test getting embedding model for OpenAI provider"""
|
||||
mock_credential_service.get_active_provider.return_value = openai_provider_config
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
model = await get_embedding_model()
|
||||
assert model == "text-embedding-3-small"
|
||||
|
||||
mock_credential_service.get_active_provider.assert_called_once_with("embedding")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_embedding_model_ollama_success(
|
||||
self, mock_credential_service, ollama_provider_config
|
||||
):
|
||||
"""Test getting embedding model for Ollama provider"""
|
||||
mock_credential_service.get_active_provider.return_value = ollama_provider_config
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
model = await get_embedding_model()
|
||||
assert model == "nomic-embed-text"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_embedding_model_google_success(
|
||||
self, mock_credential_service, google_provider_config
|
||||
):
|
||||
"""Test getting embedding model for Google provider"""
|
||||
mock_credential_service.get_active_provider.return_value = google_provider_config
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
model = await get_embedding_model()
|
||||
assert model == "text-embedding-004"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_embedding_model_with_provider_override(self, mock_credential_service):
|
||||
"""Test getting embedding model with provider override"""
|
||||
rag_settings = {"EMBEDDING_MODEL": "custom-embedding-model"}
|
||||
mock_credential_service.get_credentials_by_category.return_value = rag_settings
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
model = await get_embedding_model(provider="custom-provider")
|
||||
assert model == "custom-embedding-model"
|
||||
|
||||
mock_credential_service.get_credentials_by_category.assert_called_once_with(
|
||||
"rag_strategy"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_embedding_model_custom_model_override(self, mock_credential_service):
|
||||
"""Test custom embedding model override"""
|
||||
config_with_custom = {
|
||||
"provider": "openai",
|
||||
"api_key": "test-key",
|
||||
"base_url": None,
|
||||
"chat_model": "gpt-4",
|
||||
"embedding_model": "text-embedding-custom-large",
|
||||
}
|
||||
mock_credential_service.get_active_provider.return_value = config_with_custom
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
model = await get_embedding_model()
|
||||
assert model == "text-embedding-custom-large"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_embedding_model_error_fallback(self, mock_credential_service):
|
||||
"""Test fallback when error occurs getting embedding model"""
|
||||
mock_credential_service.get_active_provider.side_effect = Exception("Database error")
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
model = await get_embedding_model()
|
||||
# Should fallback to OpenAI default
|
||||
assert model == "text-embedding-3-small"
|
||||
|
||||
def test_cache_functionality(self):
|
||||
"""Test settings cache functionality"""
|
||||
# Test setting and getting cache
|
||||
test_value = {"test": "data"}
|
||||
_set_cached_settings("test_key", test_value)
|
||||
|
||||
cached_result = _get_cached_settings("test_key")
|
||||
assert cached_result == test_value
|
||||
|
||||
# Test cache expiry (would require time manipulation in real test)
|
||||
# For now just test that non-existent key returns None
|
||||
assert _get_cached_settings("non_existent") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_usage_in_get_llm_client(
|
||||
self, mock_credential_service, openai_provider_config
|
||||
):
|
||||
"""Test that cache is used to avoid repeated credential service calls"""
|
||||
mock_credential_service.get_active_provider.return_value = openai_provider_config
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
# First call should hit the credential service
|
||||
async with get_llm_client():
|
||||
pass
|
||||
|
||||
# Second call should use cache
|
||||
async with get_llm_client():
|
||||
pass
|
||||
|
||||
# Should only call get_active_provider once due to caching
|
||||
assert mock_credential_service.get_active_provider.call_count == 1
|
||||
|
||||
def test_deprecated_functions_removed(self):
|
||||
"""Test that deprecated sync functions are no longer available"""
|
||||
import src.server.services.llm_provider_service as llm_module
|
||||
|
||||
# These functions should no longer exist
|
||||
assert not hasattr(llm_module, "get_llm_client_sync")
|
||||
assert not hasattr(llm_module, "get_embedding_model_sync")
|
||||
assert not hasattr(llm_module, "_get_active_provider_sync")
|
||||
|
||||
# The async versions should be the primary functions
|
||||
assert hasattr(llm_module, "get_llm_client")
|
||||
assert hasattr(llm_module, "get_embedding_model")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_manager_cleanup(self, mock_credential_service, openai_provider_config):
|
||||
"""Test that async context manager properly handles cleanup"""
|
||||
mock_credential_service.get_active_provider.return_value = openai_provider_config
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
client_ref = None
|
||||
async with get_llm_client() as client:
|
||||
client_ref = client
|
||||
assert client == mock_client
|
||||
|
||||
# After context manager exits, should still have reference to client
|
||||
assert client_ref == mock_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_providers_in_sequence(self, mock_credential_service):
|
||||
"""Test creating clients for different providers in sequence"""
|
||||
configs = [
|
||||
{"provider": "openai", "api_key": "openai-key", "base_url": None},
|
||||
{"provider": "ollama", "api_key": "ollama", "base_url": "http://localhost:11434/v1"},
|
||||
{
|
||||
"provider": "google",
|
||||
"api_key": "google-key",
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.credential_service", mock_credential_service
|
||||
):
|
||||
with patch(
|
||||
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
for config in configs:
|
||||
# Clear cache between tests to force fresh credential service calls
|
||||
import src.server.services.llm_provider_service as llm_module
|
||||
|
||||
llm_module._settings_cache.clear()
|
||||
|
||||
mock_credential_service.get_active_provider.return_value = config
|
||||
|
||||
async with get_llm_client() as client:
|
||||
assert client == mock_client
|
||||
|
||||
# Should have been called once for each provider
|
||||
assert mock_credential_service.get_active_provider.call_count == 3
|
||||
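The tests above only ever consume the provider service through its async context manager; as a minimal usage sketch, assuming the standard AsyncOpenAI embeddings call shape that the embedding tests elsewhere in this commit mock out:

from src.server.services.llm_provider_service import get_embedding_model, get_llm_client

async def embed_query(text: str) -> list[float]:
    # Resolve the configured embedding model, then open a client for the embedding provider.
    model = await get_embedding_model()
    async with get_llm_client(use_embedding_provider=True) as client:
        response = await client.embeddings.create(model=model, input=[text])
        return response.data[0].embedding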
98
python/tests/test_business_logic.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Business logic tests - Test core business rules and logic."""
|
||||
|
||||
|
||||
def test_task_status_transitions(client):
|
||||
"""Test task status update endpoint."""
|
||||
# Test status update endpoint exists
|
||||
response = client.patch("/api/tasks/test-id", json={"status": "doing"})
|
||||
assert response.status_code in [200, 400, 404, 405, 422, 500]
|
||||
|
||||
|
||||
def test_progress_calculation(client):
|
||||
"""Test project progress endpoint."""
|
||||
response = client.get("/api/projects/test-id/progress")
|
||||
assert response.status_code in [200, 404, 500]
|
||||
|
||||
|
||||
def test_rate_limiting(client):
|
||||
"""Test that API handles multiple requests gracefully."""
|
||||
# Make several requests
|
||||
for i in range(5):
|
||||
response = client.get("/api/projects")
|
||||
assert response.status_code in [200, 429, 500] # 500 is OK in test environment
|
||||
|
||||
|
||||
def test_data_validation(client):
|
||||
"""Test input validation on project creation."""
|
||||
# Empty title
|
||||
response = client.post("/api/projects", json={"title": ""})
|
||||
assert response.status_code in [400, 422]
|
||||
|
||||
# Missing required fields
|
||||
response = client.post("/api/projects", json={})
|
||||
assert response.status_code in [400, 422]
|
||||
|
||||
# Valid data
|
||||
response = client.post("/api/projects", json={"title": "Valid Project"})
|
||||
assert response.status_code in [200, 201, 422]
|
||||
|
||||
|
||||
def test_permission_checks(client):
|
||||
"""Test authentication on protected endpoints."""
|
||||
# Delete without auth
|
||||
response = client.delete("/api/projects/test-id")
|
||||
assert response.status_code in [200, 204, 401, 403, 404, 500]
|
||||
|
||||
|
||||
def test_crawl_depth_limits(client):
|
||||
"""Test crawl depth validation."""
|
||||
# Too deep
|
||||
response = client.post(
|
||||
"/api/knowledge/crawl", json={"url": "https://example.com", "max_depth": 100}
|
||||
)
|
||||
assert response.status_code in [200, 400, 404, 422]
|
||||
|
||||
# Valid depth
|
||||
response = client.post(
|
||||
"/api/knowledge/crawl", json={"url": "https://example.com", "max_depth": 2}
|
||||
)
|
||||
assert response.status_code in [200, 201, 400, 404, 422, 500]
|
||||
|
||||
|
||||
def test_document_chunking(client):
|
||||
"""Test document chunking endpoint."""
|
||||
response = client.post(
|
||||
"/api/knowledge/documents/chunk", json={"content": "x" * 1000, "chunk_size": 500}
|
||||
)
|
||||
assert response.status_code in [200, 400, 404, 422, 500]
|
||||
|
||||
|
||||
def test_embedding_generation(client):
|
||||
"""Test embedding generation endpoint."""
|
||||
response = client.post("/api/knowledge/embeddings", json={"texts": ["Test text for embedding"]})
|
||||
assert response.status_code in [200, 400, 404, 422, 500]
|
||||
|
||||
|
||||
def test_source_management(client):
|
||||
"""Test knowledge source management."""
|
||||
# Create source
|
||||
response = client.post(
|
||||
"/api/knowledge/sources",
|
||||
json={"name": "Test Source", "url": "https://example.com", "type": "documentation"},
|
||||
)
|
||||
assert response.status_code in [200, 201, 400, 404, 422, 500]
|
||||
|
||||
# List sources
|
||||
response = client.get("/api/knowledge/sources")
|
||||
assert response.status_code in [200, 404, 500]
|
||||
|
||||
|
||||
def test_version_control(client):
|
||||
"""Test document versioning."""
|
||||
# Create version
|
||||
response = client.post("/api/documents/test-id/versions", json={"content": "Version 1 content"})
|
||||
assert response.status_code in [200, 201, 404, 422, 500]
|
||||
|
||||
# List versions
|
||||
response = client.get("/api/documents/test-id/versions")
|
||||
assert response.status_code in [200, 404, 500]
|
||||
476
python/tests/test_crawl_orchestration_isolated.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""
|
||||
Isolated Tests for Async Crawl Orchestration Service
|
||||
|
||||
Tests core functionality without circular import dependencies.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class MockCrawlOrchestrationService:
|
||||
"""Mock version of CrawlOrchestrationService for isolated testing"""
|
||||
|
||||
def __init__(self, crawler=None, supabase_client=None, progress_id=None):
|
||||
self.crawler = crawler
|
||||
self.supabase_client = supabase_client
|
||||
self.progress_id = progress_id
|
||||
self.progress_state = {}
|
||||
self._cancelled = False
|
||||
|
||||
def cancel(self):
|
||||
self._cancelled = True
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
return self._cancelled
|
||||
|
||||
def _check_cancellation(self):
|
||||
if self._cancelled:
|
||||
raise Exception("CrawlCancelledException: Operation was cancelled")
|
||||
|
||||
def _is_documentation_site(self, url: str) -> bool:
|
||||
"""Simple documentation site detection"""
|
||||
doc_indicators = ["/docs/", "docs.", ".readthedocs.io", "/documentation/"]
|
||||
return any(indicator in url.lower() for indicator in doc_indicators)
|
||||
|
||||
async def _create_crawl_progress_callback(self, base_status: str):
|
||||
"""Create async progress callback"""
|
||||
|
||||
async def callback(status: str, percentage: int, message: str, **kwargs):
|
||||
if self.progress_id:
|
||||
self.progress_state.update({
|
||||
"status": status,
|
||||
"percentage": percentage,
|
||||
"log": message,
|
||||
})
|
||||
|
||||
return callback
|
||||
|
||||
async def _crawl_by_url_type(self, url: str, request: dict[str, Any]) -> tuple:
|
||||
"""Mock URL type detection and crawling"""
|
||||
# Mock different URL types
|
||||
if url.endswith(".txt"):
|
||||
return [{"url": url, "markdown": "Text content", "title": "Text File"}], "text_file"
|
||||
elif "sitemap" in url:
|
||||
return [
|
||||
{"url": f"{url}/page1", "markdown": "Page 1 content", "title": "Page 1"},
|
||||
{"url": f"{url}/page2", "markdown": "Page 2 content", "title": "Page 2"},
|
||||
], "sitemap"
|
||||
else:
|
||||
return [{"url": url, "markdown": "Web content", "title": "Web Page"}], "webpage"
|
||||
|
||||
async def _process_and_store_documents(
|
||||
self,
|
||||
crawl_results: list[dict],
|
||||
request: dict[str, Any],
|
||||
crawl_type: str,
|
||||
original_source_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Mock document processing and storage"""
|
||||
# Check for cancellation
|
||||
self._check_cancellation()
|
||||
|
||||
# Simulate chunking
|
||||
chunk_count = len(crawl_results) * 3 # Assume 3 chunks per document
|
||||
total_word_count = chunk_count * 50 # Assume 50 words per chunk
|
||||
|
||||
# Build url_to_full_document mapping
|
||||
url_to_full_document = {}
|
||||
for doc in crawl_results:
|
||||
url_to_full_document[doc["url"]] = doc.get("markdown", "")
|
||||
|
||||
return {
|
||||
"chunk_count": chunk_count,
|
||||
"total_word_count": total_word_count,
|
||||
"url_to_full_document": url_to_full_document,
|
||||
}
|
||||
|
||||
async def _extract_and_store_code_examples(
|
||||
self, crawl_results: list[dict], url_to_full_document: dict[str, str]
|
||||
) -> int:
|
||||
"""Mock code examples extraction"""
|
||||
# Count code blocks in markdown
|
||||
code_examples = 0
|
||||
for doc in crawl_results:
|
||||
content = doc.get("markdown", "")
|
||||
code_examples += content.count("```")
|
||||
return code_examples // 2 # Each code block has opening and closing
|
||||
|
||||
async def _async_orchestrate_crawl(
|
||||
self, request: dict[str, Any], task_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""Mock async orchestration"""
|
||||
try:
|
||||
self._check_cancellation()
|
||||
|
||||
url = str(request.get("url", ""))
|
||||
|
||||
# Mock crawl by URL type
|
||||
crawl_results, crawl_type = await self._crawl_by_url_type(url, request)
|
||||
|
||||
self._check_cancellation()
|
||||
|
||||
if not crawl_results:
|
||||
raise ValueError("No content was crawled from the provided URL")
|
||||
|
||||
# Mock document processing
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
source_id = parsed_url.netloc or parsed_url.path
|
||||
|
||||
storage_results = await self._process_and_store_documents(
|
||||
crawl_results, request, crawl_type, source_id
|
||||
)
|
||||
|
||||
self._check_cancellation()
|
||||
|
||||
# Mock code extraction
|
||||
code_examples_count = 0
|
||||
if request.get("enable_code_extraction", False):
|
||||
code_examples_count = await self._extract_and_store_code_examples(
|
||||
crawl_results, storage_results.get("url_to_full_document", {})
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"crawl_type": crawl_type,
|
||||
"chunk_count": storage_results["chunk_count"],
|
||||
"total_word_count": storage_results["total_word_count"],
|
||||
"code_examples_stored": code_examples_count,
|
||||
"processed_pages": len(crawl_results),
|
||||
"total_pages": len(crawl_results),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if "CrawlCancelledException" in error_msg:
|
||||
return {
|
||||
"success": False,
|
||||
"error": error_msg,
|
||||
"cancelled": True,
|
||||
"chunk_count": 0,
|
||||
"code_examples_stored": 0,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": error_msg,
|
||||
"cancelled": False,
|
||||
"chunk_count": 0,
|
||||
"code_examples_stored": 0,
|
||||
}
|
||||
|
||||
async def orchestrate_crawl(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Mock main orchestration entry point"""
|
||||
import uuid
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# Start async orchestration task (would normally be background)
|
||||
result = await self._async_orchestrate_crawl(request, task_id)
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": "started" if result.get("success") else "failed",
|
||||
"message": f"Crawl operation for {request.get('url')}",
|
||||
"progress_id": self.progress_id,
|
||||
}
|
||||
|
||||
|
||||
class TestAsyncCrawlOrchestration:
|
||||
"""Test suite for async crawl orchestration behavior"""
|
||||
|
||||
@pytest.fixture
|
||||
def orchestration_service(self):
|
||||
"""Create mock orchestration service"""
|
||||
return MockCrawlOrchestrationService(
|
||||
crawler=MagicMock(), supabase_client=MagicMock(), progress_id="test-progress-123"
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_request(self):
|
||||
"""Sample crawl request"""
|
||||
return {
|
||||
"url": "https://example.com/docs",
|
||||
"max_depth": 2,
|
||||
"knowledge_type": "technical",
|
||||
"tags": ["test"],
|
||||
"enable_code_extraction": True,
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_orchestrate_crawl_success(self, orchestration_service, sample_request):
|
||||
"""Test successful async orchestration"""
|
||||
result = await orchestration_service._async_orchestrate_crawl(sample_request, "task-123")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["crawl_type"] == "webpage"
|
||||
assert result["chunk_count"] > 0
|
||||
assert result["total_word_count"] > 0
|
||||
assert result["processed_pages"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_orchestrate_crawl_with_code_extraction(self, orchestration_service):
|
||||
"""Test orchestration with code extraction enabled"""
|
||||
request = {"url": "https://docs.example.com/api", "enable_code_extraction": True}
|
||||
|
||||
result = await orchestration_service._async_orchestrate_crawl(request, "task-456")
|
||||
|
||||
assert result["success"] is True
|
||||
assert "code_examples_stored" in result
|
||||
assert result["code_examples_stored"] >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crawl_by_url_type_text_file(self, orchestration_service):
|
||||
"""Test text file URL type detection"""
|
||||
crawl_results, crawl_type = await orchestration_service._crawl_by_url_type(
|
||||
"https://example.com/readme.txt", {"max_depth": 1}
|
||||
)
|
||||
|
||||
assert crawl_type == "text_file"
|
||||
assert len(crawl_results) == 1
|
||||
assert crawl_results[0]["url"] == "https://example.com/readme.txt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crawl_by_url_type_sitemap(self, orchestration_service):
|
||||
"""Test sitemap URL type detection"""
|
||||
crawl_results, crawl_type = await orchestration_service._crawl_by_url_type(
|
||||
"https://example.com/sitemap.xml", {"max_depth": 2}
|
||||
)
|
||||
|
||||
assert crawl_type == "sitemap"
|
||||
assert len(crawl_results) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crawl_by_url_type_regular_webpage(self, orchestration_service):
|
||||
"""Test regular webpage crawling"""
|
||||
crawl_results, crawl_type = await orchestration_service._crawl_by_url_type(
|
||||
"https://example.com/blog/post", {"max_depth": 1}
|
||||
)
|
||||
|
||||
assert crawl_type == "webpage"
|
||||
assert len(crawl_results) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_and_store_documents(self, orchestration_service):
|
||||
"""Test document processing and storage"""
|
||||
crawl_results = [
|
||||
{"url": "https://example.com/page1", "markdown": "Content 1", "title": "Page 1"},
|
||||
{"url": "https://example.com/page2", "markdown": "Content 2", "title": "Page 2"},
|
||||
]
|
||||
|
||||
request = {"knowledge_type": "technical", "tags": ["test"]}
|
||||
|
||||
result = await orchestration_service._process_and_store_documents(
|
||||
crawl_results, request, "webpage", "example.com"
|
||||
)
|
||||
|
||||
assert "chunk_count" in result
|
||||
assert "total_word_count" in result
|
||||
assert "url_to_full_document" in result
|
||||
assert result["chunk_count"] == 6 # 2 docs * 3 chunks each
|
||||
assert len(result["url_to_full_document"]) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_and_store_code_examples(self, orchestration_service):
|
||||
"""Test code examples extraction"""
|
||||
crawl_results = [
|
||||
{
|
||||
"url": "https://example.com/api",
|
||||
"markdown": '# API\n\n```python\ndef hello():\n return "world"\n```\n\n```javascript\nconsole.log("hello");\n```',
|
||||
"title": "API Docs",
|
||||
}
|
||||
]
|
||||
|
||||
url_to_full_document = {"https://example.com/api": crawl_results[0]["markdown"]}
|
||||
|
||||
result = await orchestration_service._extract_and_store_code_examples(
|
||||
crawl_results, url_to_full_document
|
||||
)
|
||||
|
||||
assert result == 2 # Two code blocks found
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancellation_during_orchestration(self, orchestration_service, sample_request):
|
||||
"""Test cancellation handling"""
|
||||
# Cancel before starting
|
||||
orchestration_service.cancel()
|
||||
|
||||
result = await orchestration_service._async_orchestrate_crawl(sample_request, "task-cancel")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["cancelled"] is True
|
||||
assert "error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancellation_during_document_processing(self, orchestration_service):
|
||||
"""Test cancellation during document processing"""
|
||||
crawl_results = [{"url": "https://example.com", "markdown": "Content"}]
|
||||
request = {"knowledge_type": "technical"}
|
||||
|
||||
# Cancel during processing
|
||||
orchestration_service.cancel()
|
||||
|
||||
with pytest.raises(Exception, match="CrawlCancelledException"):
|
||||
await orchestration_service._process_and_store_documents(
|
||||
crawl_results, request, "webpage", "example.com"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_in_orchestration(self, orchestration_service):
|
||||
"""Test error handling during orchestration"""
|
||||
|
||||
# Override the method to raise an error
|
||||
async def failing_crawl_by_url_type(url, request):
|
||||
raise ValueError("Simulated crawl failure")
|
||||
|
||||
orchestration_service._crawl_by_url_type = failing_crawl_by_url_type
|
||||
|
||||
request = {"url": "https://example.com", "enable_code_extraction": False}
|
||||
|
||||
result = await orchestration_service._async_orchestrate_crawl(request, "task-error")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["cancelled"] is False
|
||||
assert "error" in result
|
||||
|
||||
def test_documentation_site_detection(self, orchestration_service):
|
||||
"""Test documentation site URL detection"""
|
||||
# Test documentation sites
|
||||
assert orchestration_service._is_documentation_site("https://docs.python.org")
|
||||
assert orchestration_service._is_documentation_site(
|
||||
"https://react.dev/docs/getting-started"
|
||||
)
|
||||
assert orchestration_service._is_documentation_site(
|
||||
"https://project.readthedocs.io/en/latest/"
|
||||
)
|
||||
assert orchestration_service._is_documentation_site("https://example.com/documentation/api")
|
||||
|
||||
# Test non-documentation sites
|
||||
assert not orchestration_service._is_documentation_site("https://github.com/user/repo")
|
||||
assert not orchestration_service._is_documentation_site("https://example.com/blog")
|
||||
assert not orchestration_service._is_documentation_site("https://news.example.com")
|
||||
|
||||
def test_cancellation_functionality(self, orchestration_service):
|
||||
"""Test cancellation state management"""
|
||||
# Initially not cancelled
|
||||
assert not orchestration_service.is_cancelled()
|
||||
|
||||
# Cancel and verify
|
||||
orchestration_service.cancel()
|
||||
assert orchestration_service.is_cancelled()
|
||||
|
||||
# Check cancellation raises exception
|
||||
with pytest.raises(Exception, match="CrawlCancelledException"):
|
||||
orchestration_service._check_cancellation()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_progress_callback_creation(self, orchestration_service):
|
||||
"""Test progress callback functionality"""
|
||||
callback = await orchestration_service._create_crawl_progress_callback("crawling")
|
||||
|
||||
# Execute callback
|
||||
await callback("test_status", 50, "Test message")
|
||||
|
||||
# Verify progress state was updated
|
||||
assert orchestration_service.progress_state["status"] == "test_status"
|
||||
assert orchestration_service.progress_state["percentage"] == 50
|
||||
assert orchestration_service.progress_state["log"] == "Test message"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_orchestrate_crawl_entry_point(self, orchestration_service, sample_request):
|
||||
"""Test main orchestration entry point"""
|
||||
result = await orchestration_service.orchestrate_crawl(sample_request)
|
||||
|
||||
assert "task_id" in result
|
||||
assert "status" in result
|
||||
assert "progress_id" in result
|
||||
assert result["progress_id"] == "test-progress-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_operations(self):
|
||||
"""Test multiple concurrent orchestrations"""
|
||||
service1 = MockCrawlOrchestrationService(progress_id="progress-1")
|
||||
service2 = MockCrawlOrchestrationService(progress_id="progress-2")
|
||||
|
||||
request1 = {"url": "https://site1.com", "enable_code_extraction": False}
|
||||
request2 = {"url": "https://site2.com", "enable_code_extraction": True}
|
||||
|
||||
# Run concurrently
|
||||
results = await asyncio.gather(
|
||||
service1._async_orchestrate_crawl(request1, "task-1"),
|
||||
service2._async_orchestrate_crawl(request2, "task-2"),
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert all(result["success"] for result in results)
|
||||
assert results[0]["code_examples_stored"] == 0 # Code extraction disabled
|
||||
assert results[1]["code_examples_stored"] >= 0 # Code extraction enabled
|
||||
|
||||
|
||||
class TestAsyncBehaviors:
|
||||
"""Test async-specific behaviors and patterns"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_method_chaining(self):
|
||||
"""Test that async methods properly chain together"""
|
||||
service = MockCrawlOrchestrationService()
|
||||
|
||||
# This chain should complete without blocking
|
||||
crawl_results, crawl_type = await service._crawl_by_url_type(
|
||||
"https://example.com", {"max_depth": 1}
|
||||
)
|
||||
|
||||
storage_results = await service._process_and_store_documents(
|
||||
crawl_results, {"knowledge_type": "technical"}, crawl_type, "example.com"
|
||||
)
|
||||
|
||||
code_count = await service._extract_and_store_code_examples(
|
||||
crawl_results, storage_results["url_to_full_document"]
|
||||
)
|
||||
|
||||
# All operations should complete successfully
|
||||
assert crawl_type == "webpage"
|
||||
assert storage_results["chunk_count"] > 0
|
||||
assert code_count >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asyncio_cancellation_propagation(self):
|
||||
"""Test that asyncio cancellation properly propagates"""
|
||||
service = MockCrawlOrchestrationService()
|
||||
|
||||
async def long_running_operation():
|
||||
await asyncio.sleep(0.1) # Simulate work
|
||||
return await service._async_orchestrate_crawl(
|
||||
{"url": "https://example.com"}, "task-123"
|
||||
)
|
||||
|
||||
# Start task and cancel it
|
||||
task = asyncio.create_task(long_running_operation())
|
||||
await asyncio.sleep(0.01) # Let it start
|
||||
task.cancel()
|
||||
|
||||
# Should raise CancelledError
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_blocking_operations(self):
|
||||
"""Test that operations don't block the event loop"""
|
||||
service = MockCrawlOrchestrationService()
|
||||
|
||||
# Start multiple operations concurrently
|
||||
tasks = []
|
||||
for i in range(5):
|
||||
task = service._async_orchestrate_crawl({"url": f"https://example{i}.com"}, f"task-{i}")
|
||||
tasks.append(task)
|
||||
|
||||
# All should complete without blocking
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(results) == 5
|
||||
assert all(result["success"] for result in results)
|
||||
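Because the mock service above is self-contained, it can also be exercised directly; a minimal sketch, assuming it is run from within python/tests/test_crawl_orchestration_isolated.py (the __main__ wrapper is illustrative and not part of this commit):

import asyncio

async def _demo() -> None:
    service = MockCrawlOrchestrationService(progress_id="demo-progress")
    summary = await service.orchestrate_crawl(
        {"url": "https://example.com/docs", "enable_code_extraction": True}
    )
    # orchestrate_crawl returns task metadata; detailed chunk/code counts come from
    # _async_orchestrate_crawl, which the tests above call directly.
    print(summary["status"], summary["task_id"], summary["progress_id"])

if __name__ == "__main__":
    asyncio.run(_demo())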
332
python/tests/test_embedding_service_no_zeros.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""
|
||||
Tests for embedding service to ensure no zero embeddings are returned.
|
||||
|
||||
These tests verify that the embedding service raises appropriate exceptions
|
||||
instead of returning zero embeddings, following the "fail fast and loud" principle.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from src.server.services.embeddings.embedding_exceptions import (
|
||||
EmbeddingAPIError,
|
||||
EmbeddingQuotaExhaustedError,
|
||||
EmbeddingRateLimitError,
|
||||
)
|
||||
from src.server.services.embeddings.embedding_service import (
|
||||
EmbeddingBatchResult,
|
||||
create_embedding,
|
||||
create_embeddings_batch,
|
||||
)
|
||||
|
||||
|
||||
class TestNoZeroEmbeddings:
|
||||
"""Test that no zero embeddings are ever returned."""
|
||||
|
||||
# Note: Removed test_sync_from_async_context_raises_exception
|
||||
# as sync versions no longer exist - everything is async-only now
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_quota_exhausted_returns_failure(self) -> None:
|
||||
"""Test that quota exhaustion returns failure result instead of zeros."""
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_client:
|
||||
# Mock the client to raise a quota-exhausted error
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__.return_value.embeddings.create.side_effect = openai.RateLimitError(
|
||||
"insufficient_quota: You have exceeded your quota", response=Mock(), body=None
|
||||
)
|
||||
mock_client.return_value = mock_ctx
|
||||
|
||||
# Single embedding still raises for backward compatibility
|
||||
with pytest.raises(EmbeddingQuotaExhaustedError) as exc_info:
|
||||
await create_embedding("test text")
|
||||
|
||||
assert "quota exhausted" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_rate_limit_raises_exception(self) -> None:
|
||||
"""Test that rate limit errors raise exception after retries."""
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_client:
|
||||
# Mock the client to raise a rate limit error
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__.return_value.embeddings.create.side_effect = openai.RateLimitError(
|
||||
"rate_limit_exceeded: Too many requests", response=Mock(), body=None
|
||||
)
|
||||
mock_client.return_value = mock_ctx
|
||||
|
||||
with pytest.raises(EmbeddingRateLimitError) as exc_info:
|
||||
await create_embedding("test text")
|
||||
|
||||
assert "rate limit" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_api_error_raises_exception(self) -> None:
|
||||
"""Test that API errors raise exception instead of returning zeros."""
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_client:
|
||||
# Mock the client to raise a generic error
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__.return_value.embeddings.create.side_effect = Exception(
|
||||
"Network error"
|
||||
)
|
||||
mock_client.return_value = mock_ctx
|
||||
|
||||
with pytest.raises(EmbeddingAPIError) as exc_info:
|
||||
await create_embedding("test text")
|
||||
|
||||
assert "failed to create embedding" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_handles_partial_failures(self) -> None:
|
||||
"""Test that batch processing can handle partial failures gracefully."""
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_client:
|
||||
# Mock successful response for first batch, failure for second
|
||||
mock_ctx = AsyncMock()
|
||||
mock_response = Mock()
|
||||
mock_response.data = [Mock(embedding=[0.1] * 1536), Mock(embedding=[0.2] * 1536)]
|
||||
|
||||
# First call succeeds, second fails
|
||||
mock_ctx.__aenter__.return_value.embeddings.create.side_effect = [
|
||||
mock_response,
|
||||
Exception("API Error"),
|
||||
]
|
||||
mock_client.return_value = mock_ctx
|
||||
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_embedding_model",
|
||||
new_callable=AsyncMock,
|
||||
return_value="text-embedding-ada-002",
|
||||
):
|
||||
# Mock credential service to return batch size of 2
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.credential_service.get_credentials_by_category",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"EMBEDDING_BATCH_SIZE": "2"},
|
||||
):
|
||||
# Process 4 texts (batch size will be 2)
|
||||
texts = ["text1", "text2", "text3", "text4"]
|
||||
result = await create_embeddings_batch(texts)
|
||||
|
||||
# Check result structure
|
||||
assert isinstance(result, EmbeddingBatchResult)
|
||||
assert result.success_count == 2 # First batch succeeded
|
||||
assert result.failure_count == 2 # Second batch failed
|
||||
assert len(result.embeddings) == 2
|
||||
assert len(result.failed_items) == 2
|
||||
|
||||
# Verify no zero embeddings were created
|
||||
for embedding in result.embeddings:
|
||||
assert not all(v == 0.0 for v in embedding)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configurable_embedding_dimensions(self) -> None:
|
||||
"""Test that embedding dimensions can be configured via settings."""
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_llm_client"
|
||||
) as mock_client:
|
||||
# Mock successful response
|
||||
mock_ctx = AsyncMock()
|
||||
mock_create = AsyncMock()
|
||||
mock_ctx.__aenter__.return_value.embeddings.create = mock_create
|
||||
|
||||
# Setup mock response
|
||||
mock_response = Mock()
|
||||
mock_response.data = [Mock(embedding=[0.1] * 3072)] # Different dimensions
|
||||
mock_create.return_value = mock_response
|
||||
mock_client.return_value = mock_ctx
|
||||
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.get_embedding_model",
|
||||
new_callable=AsyncMock,
|
||||
return_value="text-embedding-3-large",
|
||||
):
|
||||
# Mock credential service to return custom dimensions
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.credential_service.get_credentials_by_category",
|
||||
                    new_callable=AsyncMock,
                    return_value={"EMBEDDING_DIMENSIONS": "3072"},
                ):
                    result = await create_embeddings_batch(["test text"])

                    # Verify the dimensions parameter was passed correctly
                    mock_create.assert_called_once()
                    call_args = mock_create.call_args
                    assert call_args.kwargs["dimensions"] == 3072

                    # Verify result
                    assert result.success_count == 1
                    assert len(result.embeddings[0]) == 3072

    @pytest.mark.asyncio
    async def test_default_embedding_dimensions(self) -> None:
        """Test that default dimensions (1536) are used when not configured."""
        with patch(
            "src.server.services.embeddings.embedding_service.get_llm_client"
        ) as mock_client:
            # Mock successful response
            mock_ctx = AsyncMock()
            mock_create = AsyncMock()
            mock_ctx.__aenter__.return_value.embeddings.create = mock_create

            # Setup mock response with default dimensions
            mock_response = Mock()
            mock_response.data = [Mock(embedding=[0.1] * 1536)]
            mock_create.return_value = mock_response
            mock_client.return_value = mock_ctx

            with patch(
                "src.server.services.embeddings.embedding_service.get_embedding_model",
                new_callable=AsyncMock,
                return_value="text-embedding-3-small",
            ):
                # Mock credential service to return empty settings (no dimensions specified)
                with patch(
                    "src.server.services.embeddings.embedding_service.credential_service.get_credentials_by_category",
                    new_callable=AsyncMock,
                    return_value={},
                ):
                    result = await create_embeddings_batch(["test text"])

                    # Verify the default dimensions parameter was used
                    mock_create.assert_called_once()
                    call_args = mock_create.call_args
                    assert call_args.kwargs["dimensions"] == 1536

                    # Verify result
                    assert result.success_count == 1
                    assert len(result.embeddings[0]) == 1536

    @pytest.mark.asyncio
    async def test_batch_quota_exhausted_stops_process(self) -> None:
        """Test that quota exhaustion stops processing remaining batches."""
        with patch(
            "src.server.services.embeddings.embedding_service.get_llm_client"
        ) as mock_client:
            # Mock quota exhaustion
            mock_ctx = AsyncMock()
            mock_ctx.__aenter__.return_value.embeddings.create.side_effect = openai.RateLimitError(
                "insufficient_quota: Quota exceeded", response=Mock(), body=None
            )
            mock_client.return_value = mock_ctx

            with patch(
                "src.server.services.embeddings.embedding_service.get_embedding_model",
                new_callable=AsyncMock,
                return_value="text-embedding-ada-002",
            ):
                texts = ["text1", "text2", "text3", "text4"]
                result = await create_embeddings_batch(texts)

                # All should fail due to quota
                assert result.success_count == 0
                assert result.failure_count == 4
                assert len(result.embeddings) == 0
                assert all("quota" in item["error"].lower() for item in result.failed_items)

    @pytest.mark.asyncio
    async def test_no_zero_vectors_in_results(self) -> None:
        """Test that no function ever returns a zero vector [0.0] * 1536."""
        # This is a meta-test to ensure our implementation never creates zero vectors

        # Helper to check if a value is a zero embedding
        def is_zero_embedding(value):
            if not isinstance(value, list):
                return False
            if len(value) != 1536:
                return False
            return all(v == 0.0 for v in value)

        # Test data that should never produce zero embeddings
        test_text = "This is a test"

        # Test: Batch function with error should return failure result, not zeros
        with patch(
            "src.server.services.embeddings.embedding_service.get_llm_client"
        ) as mock_client:
            # Mock the client to raise an error
            mock_ctx = AsyncMock()
            mock_ctx.__aenter__.return_value.embeddings.create.side_effect = Exception("Test error")
            mock_client.return_value = mock_ctx

            result = await create_embeddings_batch([test_text])
            # Should return result with failures, not zeros
            assert isinstance(result, EmbeddingBatchResult)
            assert len(result.embeddings) == 0
            assert result.failure_count == 1
            # Verify no zero embeddings in the result
            for embedding in result.embeddings:
                assert not is_zero_embedding(embedding)


class TestEmbeddingBatchResult:
    """Test the EmbeddingBatchResult dataclass."""

    def test_batch_result_initialization(self) -> None:
        """Test that EmbeddingBatchResult initializes correctly."""
        result = EmbeddingBatchResult()
        assert result.success_count == 0
        assert result.failure_count == 0
        assert result.embeddings == []
        assert result.failed_items == []
        assert not result.has_failures

    def test_batch_result_add_success(self) -> None:
        """Test adding successful embeddings."""
        result = EmbeddingBatchResult()
        embedding = [0.1] * 1536
        text = "test text"

        result.add_success(embedding, text)

        assert result.success_count == 1
        assert result.failure_count == 0
        assert len(result.embeddings) == 1
        assert result.embeddings[0] == embedding
        assert result.texts_processed[0] == text
        assert not result.has_failures

    def test_batch_result_add_failure(self) -> None:
        """Test adding failed items."""
        result = EmbeddingBatchResult()
        error = EmbeddingAPIError("Test error", text_preview="test")

        result.add_failure("test text", error, batch_index=0)

        assert result.success_count == 0
        assert result.failure_count == 1
        assert len(result.failed_items) == 1
        assert result.has_failures

        failed_item = result.failed_items[0]
        assert failed_item["error"] == "Test error"
        assert failed_item["error_type"] == "EmbeddingAPIError"
        # batch_index comes from the error's to_dict() method which includes it
        assert "batch_index" in failed_item  # Just check it exists

    def test_batch_result_mixed_results(self) -> None:
        """Test batch result with both successes and failures."""
        result = EmbeddingBatchResult()

        # Add successes
        result.add_success([0.1] * 1536, "text1")
        result.add_success([0.2] * 1536, "text2")

        # Add failures
        result.add_failure("text3", Exception("Error 1"), 1)
        result.add_failure("text4", Exception("Error 2"), 1)

        assert result.success_count == 2
        assert result.failure_count == 2
        assert result.total_requested == 4
        assert result.has_failures
        assert len(result.embeddings) == 2
        assert len(result.failed_items) == 2
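For readers without the service code at hand, the behaviour asserted above pins down roughly what EmbeddingBatchResult must look like. The sketch below is an illustrative reconstruction inferred only from these tests (field and method names such as texts_processed, total_requested, add_success and add_failure come from the assertions; everything else is an assumption, not the shipped implementation):

# Illustrative sketch only - the real EmbeddingBatchResult lives in the embedding service.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class EmbeddingBatchResultSketch:
    embeddings: list[list[float]] = field(default_factory=list)
    texts_processed: list[str] = field(default_factory=list)
    failed_items: list[dict[str, Any]] = field(default_factory=list)
    success_count: int = 0
    failure_count: int = 0

    @property
    def total_requested(self) -> int:
        return self.success_count + self.failure_count

    @property
    def has_failures(self) -> bool:
        return self.failure_count > 0

    def add_success(self, embedding: list[float], text: str) -> None:
        self.embeddings.append(embedding)
        self.texts_processed.append(text)
        self.success_count += 1

    def add_failure(self, text: str, error: Exception, batch_index: int = 0) -> None:
        self.failed_items.append({
            "text": text[:100],
            "error": str(error),
            "error_type": type(error).__name__,
            "batch_index": batch_index,
        })
        self.failure_count += 1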
213
python/tests/test_keyword_extraction.py
Normal file
@@ -0,0 +1,213 @@
"""
|
||||
Tests for keyword extraction and improved hybrid search
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.server.services.search.keyword_extractor import (
|
||||
KeywordExtractor,
|
||||
build_search_terms,
|
||||
extract_keywords,
|
||||
)
|
||||
|
||||
|
||||
class TestKeywordExtractor:
|
||||
"""Test keyword extraction functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self):
|
||||
return KeywordExtractor()
|
||||
|
||||
def test_simple_keyword_extraction(self, extractor):
|
||||
"""Test extraction from simple queries"""
|
||||
query = "Supabase authentication"
|
||||
keywords = extractor.extract_keywords(query)
|
||||
|
||||
assert "supabase" in keywords
|
||||
assert "authentication" in keywords
|
||||
assert len(keywords) >= 2
|
||||
|
||||
def test_complex_query_extraction(self, extractor):
|
||||
"""Test extraction from complex queries"""
|
||||
query = "Supabase auth flow best practices"
|
||||
keywords = extractor.extract_keywords(query)
|
||||
|
||||
assert "supabase" in keywords
|
||||
assert "auth" in keywords
|
||||
assert "flow" in keywords
|
||||
assert "best_practices" in keywords or "practices" in keywords
|
||||
|
||||
def test_stop_word_filtering(self, extractor):
|
||||
"""Test that stop words are filtered out"""
|
||||
query = "How to use the React component with the database"
|
||||
keywords = extractor.extract_keywords(query)
|
||||
|
||||
# Stop words should be filtered
|
||||
assert "how" not in keywords
|
||||
assert "to" not in keywords
|
||||
assert "the" not in keywords
|
||||
assert "with" not in keywords
|
||||
|
||||
# Technical terms should remain
|
||||
assert "react" in keywords
|
||||
assert "component" in keywords
|
||||
assert "database" in keywords
|
||||
|
||||
def test_technical_terms_preserved(self, extractor):
|
||||
"""Test that technical terms are preserved"""
|
||||
query = "PostgreSQL full-text search with Python API"
|
||||
keywords = extractor.extract_keywords(query)
|
||||
|
||||
assert "postgresql" in keywords or "postgres" in keywords
|
||||
assert "python" in keywords
|
||||
assert "api" in keywords
|
||||
|
||||
def test_compound_terms(self, extractor):
|
||||
"""Test compound term detection"""
|
||||
query = "best practices for real-time websocket connections"
|
||||
keywords = extractor.extract_keywords(query)
|
||||
|
||||
# Should detect compound terms
|
||||
assert "best_practices" in keywords
|
||||
assert "realtime" in keywords or "real-time" in keywords
|
||||
assert "websocket" in keywords
|
||||
|
||||
def test_empty_query(self, extractor):
|
||||
"""Test handling of empty query"""
|
||||
keywords = extractor.extract_keywords("")
|
||||
assert keywords == []
|
||||
|
||||
def test_query_with_only_stopwords(self, extractor):
|
||||
"""Test query with only stop words"""
|
||||
query = "the and with for in"
|
||||
keywords = extractor.extract_keywords(query)
|
||||
assert keywords == []
|
||||
|
||||
def test_keyword_prioritization(self, extractor):
|
||||
"""Test that keywords are properly prioritized"""
|
||||
query = "Python Python Django REST API framework Python"
|
||||
keywords = extractor.extract_keywords(query)
|
||||
|
||||
# Python appears 3 times, should be prioritized
|
||||
assert keywords[0] == "python"
|
||||
|
||||
# Technical terms should be high priority
|
||||
assert "django" in keywords[:3]
|
||||
assert "api" in keywords[:5] # API should be in top 5
|
||||
|
||||
def test_max_keywords_limit(self, extractor):
|
||||
"""Test that max_keywords parameter is respected"""
|
||||
query = "Python Django Flask FastAPI React Vue Angular TypeScript JavaScript HTML CSS"
|
||||
keywords = extractor.extract_keywords(query, max_keywords=5)
|
||||
|
||||
assert len(keywords) <= 5
|
||||
# Most important terms should be included
|
||||
assert "python" in keywords
|
||||
assert "django" in keywords
|
||||
|
||||
def test_min_length_filtering(self, extractor):
|
||||
"""Test minimum length filtering"""
|
||||
query = "a b c API JWT DB SQL"
|
||||
keywords = extractor.extract_keywords(query, min_length=3)
|
||||
|
||||
# Single letters should be filtered
|
||||
assert "a" not in keywords
|
||||
assert "b" not in keywords
|
||||
assert "c" not in keywords
|
||||
|
||||
# 3+ letter terms should remain
|
||||
assert "api" in keywords
|
||||
assert "jwt" in keywords
|
||||
assert "sql" in keywords
|
||||
|
||||
|
||||
class TestSearchTermBuilder:
|
||||
"""Test search term building with variations"""
|
||||
|
||||
def test_plural_variations(self):
|
||||
"""Test plural/singular variations"""
|
||||
keywords = ["functions", "class", "error"]
|
||||
terms = build_search_terms(keywords)
|
||||
|
||||
# Should include singular of "functions"
|
||||
assert "function" in terms
|
||||
# Should include plural of "class"
|
||||
assert "classes" in terms
|
||||
# Should include plural of "error"
|
||||
assert "errors" in terms
|
||||
|
||||
def test_verb_variations(self):
|
||||
"""Test verb form variations"""
|
||||
keywords = ["creating", "updated", "testing"]
|
||||
terms = build_search_terms(keywords)
|
||||
|
||||
# Should generate base forms
|
||||
assert "create" in terms or "creat" in terms
|
||||
assert "update" in terms or "updat" in terms
|
||||
assert "test" in terms
|
||||
|
||||
def test_no_duplicates(self):
|
||||
"""Test that duplicates are removed"""
|
||||
keywords = ["test", "tests", "testing"]
|
||||
terms = build_search_terms(keywords)
|
||||
|
||||
# Should have unique terms only
|
||||
assert len(terms) == len(set(terms))
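To make the intended use of these helpers concrete, here is a rough sketch of how extracted keywords and their variations could feed the keyword side of a hybrid search. The helper name and the OR-joined query format are illustrative assumptions; this commit does not show how the hybrid strategy actually consumes the terms:

# Hypothetical usage sketch - the hybrid strategy may build its query differently.
from src.server.services.search.keyword_extractor import build_search_terms, extract_keywords


def build_keyword_query(user_query: str) -> str:
    """Turn a natural-language query into an OR-joined keyword expression."""
    keywords = extract_keywords(user_query)
    terms = build_search_terms(keywords)
    # Any variation matching in a full-text index counts as a keyword hit.
    return " OR ".join(terms)


# For example, "How to implement JWT authentication in FastAPI" would yield terms like
# "jwt", "authentication", "fastapi" plus plural/verb variations, joined with OR.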
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for keyword extraction in search context"""
|
||||
|
||||
def test_real_world_query_1(self):
|
||||
"""Test with real-world query example 1"""
|
||||
query = "How to implement JWT authentication in FastAPI with Supabase"
|
||||
keywords = extract_keywords(query)
|
||||
|
||||
# Should extract the key technical terms
|
||||
assert "jwt" in keywords
|
||||
assert "authentication" in keywords
|
||||
assert "fastapi" in keywords
|
||||
assert "supabase" in keywords
|
||||
|
||||
# Generic words should be filtered out ("implement" is also dropped as a technical stop word)
|
||||
assert "how" not in keywords
|
||||
assert "to" not in keywords
|
||||
|
||||
def test_real_world_query_2(self):
|
||||
"""Test with real-world query example 2"""
|
||||
query = "PostgreSQL full text search vs Elasticsearch performance comparison"
|
||||
keywords = extract_keywords(query)
|
||||
|
||||
assert "postgresql" in keywords or "postgres" in keywords
|
||||
assert "elasticsearch" in keywords
|
||||
assert "performance" in keywords
|
||||
assert "comparison" in keywords
|
||||
|
||||
# Should handle "full text" as compound or separate
|
||||
assert "fulltext" in keywords or ("full" in keywords and "text" in keywords)
|
||||
|
||||
def test_real_world_query_3(self):
|
||||
"""Test with real-world query example 3"""
|
||||
query = "debugging async await issues in Node.js Express middleware"
|
||||
keywords = extract_keywords(query)
|
||||
|
||||
assert "debugging" in keywords or "debug" in keywords
|
||||
assert "async" in keywords
|
||||
assert "await" in keywords
|
||||
assert "express" in keywords
|
||||
assert "middleware" in keywords
|
||||
|
||||
# Node.js might be split
|
||||
assert "nodejs" in keywords or "node" in keywords
|
||||
|
||||
def test_code_related_query(self):
|
||||
"""Test with code-related query"""
|
||||
query = "TypeError cannot read property undefined JavaScript React hooks"
|
||||
keywords = extract_keywords(query)
|
||||
|
||||
assert "typeerror" in keywords or "type" in keywords
|
||||
assert "property" in keywords
|
||||
assert "undefined" in keywords
|
||||
assert "javascript" in keywords
|
||||
assert "react" in keywords
|
||||
assert "hooks" in keywords
216
python/tests/test_port_configuration.py
Normal file
@@ -0,0 +1,216 @@
"""
|
||||
Tests for port configuration requirements.
|
||||
|
||||
This test file verifies that all services properly require environment variables
|
||||
for port configuration and fail with clear error messages when not set.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestPortConfiguration:
|
||||
"""Test that services require port environment variables."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Save original environment variables before each test."""
|
||||
self.original_env = os.environ.copy()
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore original environment variables after each test."""
|
||||
os.environ.clear()
|
||||
os.environ.update(self.original_env)
|
||||
|
||||
def test_service_discovery_requires_all_ports(self):
|
||||
"""Test that ServiceDiscovery requires all port environment variables."""
|
||||
# Clear port environment variables
|
||||
for key in ["ARCHON_SERVER_PORT", "ARCHON_MCP_PORT", "ARCHON_AGENTS_PORT"]:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
# Import should fail without environment variables
|
||||
with pytest.raises(ValueError, match="ARCHON_SERVER_PORT environment variable is required"):
|
||||
from src.server.config.service_discovery import ServiceDiscovery
|
||||
|
||||
ServiceDiscovery()
|
||||
|
||||
def test_service_discovery_requires_mcp_port(self):
|
||||
"""Test that ServiceDiscovery requires MCP port."""
|
||||
os.environ["ARCHON_SERVER_PORT"] = "8181"
|
||||
os.environ.pop("ARCHON_MCP_PORT", None)
|
||||
os.environ["ARCHON_AGENTS_PORT"] = "8052"
|
||||
|
||||
with pytest.raises(ValueError, match="ARCHON_MCP_PORT environment variable is required"):
|
||||
from src.server.config.service_discovery import ServiceDiscovery
|
||||
|
||||
ServiceDiscovery()
|
||||
|
||||
def test_service_discovery_requires_agents_port(self):
|
||||
"""Test that ServiceDiscovery requires agents port."""
|
||||
os.environ["ARCHON_SERVER_PORT"] = "8181"
|
||||
os.environ["ARCHON_MCP_PORT"] = "8051"
|
||||
os.environ.pop("ARCHON_AGENTS_PORT", None)
|
||||
|
||||
with pytest.raises(ValueError, match="ARCHON_AGENTS_PORT environment variable is required"):
|
||||
from src.server.config.service_discovery import ServiceDiscovery
|
||||
|
||||
ServiceDiscovery()
|
||||
|
||||
def test_service_discovery_with_all_ports(self):
|
||||
"""Test that ServiceDiscovery works with all ports set."""
|
||||
os.environ["ARCHON_SERVER_PORT"] = "9191"
|
||||
os.environ["ARCHON_MCP_PORT"] = "9051"
|
||||
os.environ["ARCHON_AGENTS_PORT"] = "9052"
|
||||
|
||||
from src.server.config.service_discovery import ServiceDiscovery
|
||||
|
||||
sd = ServiceDiscovery()
|
||||
|
||||
assert sd.DEFAULT_PORTS["api"] == 9191
|
||||
assert sd.DEFAULT_PORTS["mcp"] == 9051
|
||||
assert sd.DEFAULT_PORTS["agents"] == 9052
|
||||
|
||||
def test_mcp_server_requires_port(self):
|
||||
"""Test that MCP server requires ARCHON_MCP_PORT."""
|
||||
os.environ.pop("ARCHON_MCP_PORT", None)
|
||||
|
||||
# We can't directly import mcp_server.py as it will raise at module level
|
||||
# So we test the specific logic
|
||||
with pytest.raises(ValueError, match="ARCHON_MCP_PORT environment variable is required"):
|
||||
mcp_port = os.getenv("ARCHON_MCP_PORT")
|
||||
if not mcp_port:
|
||||
raise ValueError(
|
||||
"ARCHON_MCP_PORT environment variable is required. "
|
||||
"Please set it in your .env file or environment. "
|
||||
"Default value: 8051"
|
||||
)
|
||||
|
||||
def test_main_server_requires_port(self):
|
||||
"""Test that main server requires ARCHON_SERVER_PORT when run directly."""
|
||||
os.environ.pop("ARCHON_SERVER_PORT", None)
|
||||
|
||||
# Test the logic that would be in main.py
|
||||
with pytest.raises(ValueError, match="ARCHON_SERVER_PORT environment variable is required"):
|
||||
server_port = os.getenv("ARCHON_SERVER_PORT")
|
||||
if not server_port:
|
||||
raise ValueError(
|
||||
"ARCHON_SERVER_PORT environment variable is required. "
|
||||
"Please set it in your .env file or environment. "
|
||||
"Default value: 8181"
|
||||
)
|
||||
|
||||
def test_agents_server_requires_port(self):
|
||||
"""Test that agents server requires ARCHON_AGENTS_PORT."""
|
||||
os.environ.pop("ARCHON_AGENTS_PORT", None)
|
||||
|
||||
# Test the logic that would be in agents/server.py
|
||||
with pytest.raises(ValueError, match="ARCHON_AGENTS_PORT environment variable is required"):
|
||||
agents_port = os.getenv("ARCHON_AGENTS_PORT")
|
||||
if not agents_port:
|
||||
raise ValueError(
|
||||
"ARCHON_AGENTS_PORT environment variable is required. "
|
||||
"Please set it in your .env file or environment. "
|
||||
"Default value: 8052"
|
||||
)
|
||||
|
||||
def test_agent_chat_api_requires_agents_port(self):
|
||||
"""Test that agent_chat_api requires ARCHON_AGENTS_PORT for service calls."""
|
||||
os.environ.pop("ARCHON_AGENTS_PORT", None)
|
||||
|
||||
# Test the logic that would be in agent_chat_api
|
||||
with pytest.raises(ValueError, match="ARCHON_AGENTS_PORT environment variable is required"):
|
||||
agents_port = os.getenv("ARCHON_AGENTS_PORT")
|
||||
if not agents_port:
|
||||
raise ValueError(
|
||||
"ARCHON_AGENTS_PORT environment variable is required. "
|
||||
"Please set it in your .env file or environment."
|
||||
)
|
||||
|
||||
def test_config_requires_port_or_archon_mcp_port(self):
|
||||
"""Test that config.py requires PORT or ARCHON_MCP_PORT."""
|
||||
from src.server.config.config import ConfigurationError
|
||||
|
||||
os.environ.pop("PORT", None)
|
||||
os.environ.pop("ARCHON_MCP_PORT", None)
|
||||
|
||||
# Test the logic from config.py
|
||||
with pytest.raises(
|
||||
ConfigurationError, match="PORT or ARCHON_MCP_PORT environment variable is required"
|
||||
):
|
||||
port_str = os.getenv("PORT")
|
||||
if not port_str:
|
||||
port_str = os.getenv("ARCHON_MCP_PORT")
|
||||
if not port_str:
|
||||
raise ConfigurationError(
|
||||
"PORT or ARCHON_MCP_PORT environment variable is required. "
|
||||
"Please set it in your .env file or environment. "
|
||||
"Default value: 8051"
|
||||
)
|
||||
|
||||
def test_custom_port_values(self):
|
||||
"""Test that services use custom port values when set."""
|
||||
# Set custom ports
|
||||
os.environ["ARCHON_SERVER_PORT"] = "9999"
|
||||
os.environ["ARCHON_MCP_PORT"] = "8888"
|
||||
os.environ["ARCHON_AGENTS_PORT"] = "7777"
|
||||
|
||||
from src.server.config.service_discovery import ServiceDiscovery
|
||||
|
||||
sd = ServiceDiscovery()
|
||||
|
||||
# Verify custom ports are used
|
||||
assert sd.DEFAULT_PORTS["api"] == 9999
|
||||
assert sd.DEFAULT_PORTS["mcp"] == 8888
|
||||
assert sd.DEFAULT_PORTS["agents"] == 7777
|
||||
|
||||
# Verify service URLs use custom ports
|
||||
if not sd.is_docker:
|
||||
assert sd.get_service_url("api") == "http://localhost:9999"
|
||||
assert sd.get_service_url("mcp") == "http://localhost:8888"
|
||||
assert sd.get_service_url("agents") == "http://localhost:7777"
|
||||
|
||||
|
||||
class TestPortValidation:
|
||||
"""Test port validation logic."""
|
||||
|
||||
def test_invalid_port_values(self):
|
||||
"""Test that invalid port values are rejected."""
|
||||
os.environ["ARCHON_SERVER_PORT"] = "not-a-number"
|
||||
os.environ["ARCHON_MCP_PORT"] = "8051"
|
||||
os.environ["ARCHON_AGENTS_PORT"] = "8052"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
from src.server.config.service_discovery import ServiceDiscovery
|
||||
|
||||
ServiceDiscovery()
|
||||
|
||||
def test_port_out_of_range(self):
|
||||
"""Test that port values must be valid port numbers."""
|
||||
test_cases = [
|
||||
("0", False), # Port 0 is reserved
|
||||
("1", True), # Valid
|
||||
("65535", True), # Maximum valid port
|
||||
("65536", False), # Too high
|
||||
("-1", False), # Negative
|
||||
]
|
||||
|
||||
for port_value, should_succeed in test_cases:
|
||||
os.environ["ARCHON_SERVER_PORT"] = port_value
|
||||
os.environ["ARCHON_MCP_PORT"] = "8051"
|
||||
os.environ["ARCHON_AGENTS_PORT"] = "8052"
|
||||
|
||||
if should_succeed:
|
||||
# Should not raise
|
||||
from src.server.config.service_discovery import ServiceDiscovery
|
||||
|
||||
sd = ServiceDiscovery()
|
||||
assert sd.DEFAULT_PORTS["api"] == int(port_value)
|
||||
else:
|
||||
# Should raise for invalid ports
|
||||
with pytest.raises((ValueError, AssertionError)):
|
||||
from src.server.config.service_discovery import ServiceDiscovery
|
||||
|
||||
sd = ServiceDiscovery()
|
||||
# Additional validation might be needed
|
||||
port = int(port_value)
|
||||
assert 1 <= port <= 65535, f"Port {port} out of valid range"
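The "additional validation" hinted at above could live in a small helper that both reads and range-checks a port. The following is only a sketch of such a helper, consistent with the error messages these tests expect; it is not code from this commit:

# Illustrative helper only - mirrors the checks exercised by the tests above.
import os


def read_required_port(var_name: str, default_hint: str = "") -> int:
    """Read a port from the environment, failing fast with a clear error message."""
    raw = os.getenv(var_name)
    if not raw:
        hint = f" Default value: {default_hint}" if default_hint else ""
        raise ValueError(
            f"{var_name} environment variable is required. "
            f"Please set it in your .env file or environment.{hint}"
        )
    port = int(raw)  # raises ValueError for values like "not-a-number"
    if not 1 <= port <= 65535:
        raise ValueError(f"{var_name} must be between 1 and 65535, got {port}")
    return port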
436
python/tests/test_rag_simple.py
Normal file
@@ -0,0 +1,436 @@
"""
|
||||
Simple, Fast RAG Tests
|
||||
|
||||
Focused tests that avoid complex initialization and database calls.
|
||||
These tests verify the core RAG functionality without heavy dependencies.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Set test environment variables
|
||||
os.environ.update({
|
||||
"SUPABASE_URL": "http://test.supabase.co",
|
||||
"SUPABASE_SERVICE_KEY": "test_key",
|
||||
"OPENAI_API_KEY": "test_openai_key",
|
||||
"USE_HYBRID_SEARCH": "false",
|
||||
"USE_RERANKING": "false",
|
||||
"USE_AGENTIC_RAG": "false",
|
||||
})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_supabase():
|
||||
"""Mock supabase client"""
|
||||
client = MagicMock()
|
||||
client.rpc.return_value.execute.return_value.data = []
|
||||
client.from_.return_value.select.return_value.limit.return_value.execute.return_value.data = []
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rag_service(mock_supabase):
|
||||
"""Create RAGService with mocked dependencies"""
|
||||
with patch("src.server.utils.get_supabase_client", return_value=mock_supabase):
|
||||
with patch("src.server.services.credential_service.credential_service"):
|
||||
from src.server.services.search.rag_service import RAGService
|
||||
|
||||
service = RAGService(supabase_client=mock_supabase)
|
||||
return service
|
||||
|
||||
|
||||
class TestRAGServiceCore:
|
||||
"""Core RAGService functionality tests"""
|
||||
|
||||
def test_initialization(self, rag_service):
|
||||
"""Test RAGService initializes correctly"""
|
||||
assert rag_service is not None
|
||||
assert hasattr(rag_service, "search_documents")
|
||||
assert hasattr(rag_service, "search_code_examples")
|
||||
assert hasattr(rag_service, "perform_rag_query")
|
||||
|
||||
def test_settings_methods(self, rag_service):
|
||||
"""Test settings retrieval methods"""
|
||||
# Test string setting
|
||||
result = rag_service.get_setting("TEST_SETTING", "default")
|
||||
assert isinstance(result, str)
|
||||
|
||||
# Test boolean setting
|
||||
result = rag_service.get_bool_setting("TEST_BOOL", False)
|
||||
assert isinstance(result, bool)
|
||||
|
||||
|
||||
class TestRAGServiceSearch:
|
||||
"""Search functionality tests"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_vector_search(self, rag_service, mock_supabase):
|
||||
"""Test basic vector search functionality"""
|
||||
# Mock the RPC response
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [
|
||||
{
|
||||
"id": "1",
|
||||
"content": "Test content",
|
||||
"similarity": 0.8,
|
||||
"metadata": {},
|
||||
"url": "test.com",
|
||||
}
|
||||
]
|
||||
mock_supabase.rpc.return_value.execute.return_value = mock_response
|
||||
|
||||
# Test the search
|
||||
query_embedding = [0.1] * 1536
|
||||
results = await rag_service.base_strategy.vector_search(
|
||||
query_embedding=query_embedding, match_count=5
|
||||
)
|
||||
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == "Test content"
|
||||
|
||||
# Verify RPC was called correctly
|
||||
mock_supabase.rpc.assert_called_once()
|
||||
call_args = mock_supabase.rpc.call_args[0]
|
||||
assert call_args[0] == "match_archon_crawled_pages"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_documents_with_embedding(self, rag_service):
|
||||
"""Test document search with mocked embedding"""
|
||||
# Patch at the module level where it's called from RAGService
|
||||
with (
|
||||
patch("src.server.services.search.rag_service.create_embedding") as mock_embed,
|
||||
patch.object(rag_service.base_strategy, "vector_search") as mock_search,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_search.return_value = [{"content": "Test result", "similarity": 0.9}]
|
||||
|
||||
# Test search
|
||||
results = await rag_service.search_documents(query="test query", match_count=5)
|
||||
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == 1
|
||||
mock_embed.assert_called_once_with("test query")
|
||||
mock_search.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_rag_query_basic(self, rag_service):
|
||||
"""Test complete RAG query pipeline"""
|
||||
with patch.object(rag_service, "search_documents") as mock_search:
|
||||
mock_search.return_value = [
|
||||
{"id": "1", "content": "Test content", "similarity": 0.8, "metadata": {}}
|
||||
]
|
||||
|
||||
success, result = await rag_service.perform_rag_query(query="test query", match_count=5)
|
||||
|
||||
assert success is True
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 1
|
||||
assert result["results"][0]["content"] == "Test content"
|
||||
assert result["query"] == "test query"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_code_examples_delegation(self, rag_service):
|
||||
"""Test code examples search delegates to agentic strategy"""
|
||||
with patch.object(rag_service.agentic_strategy, "search_code_examples") as mock_agentic:
|
||||
mock_agentic.return_value = [
|
||||
{"content": "def test(): pass", "summary": "Test function", "url": "test.py"}
|
||||
]
|
||||
|
||||
results = await rag_service.search_code_examples(query="test function", match_count=10)
|
||||
|
||||
assert isinstance(results, list)
|
||||
mock_agentic.assert_called_once()
|
||||
|
||||
|
||||
class TestHybridSearchCore:
|
||||
"""Basic hybrid search tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def hybrid_strategy(self, mock_supabase):
|
||||
"""Create hybrid search strategy"""
|
||||
from src.server.services.search.base_search_strategy import BaseSearchStrategy
|
||||
from src.server.services.search.hybrid_search_strategy import HybridSearchStrategy
|
||||
|
||||
base_strategy = BaseSearchStrategy(mock_supabase)
|
||||
return HybridSearchStrategy(mock_supabase, base_strategy)
|
||||
|
||||
def test_initialization(self, hybrid_strategy):
|
||||
"""Test hybrid strategy initializes"""
|
||||
assert hybrid_strategy is not None
|
||||
assert hasattr(hybrid_strategy, "search_documents_hybrid")
|
||||
assert hasattr(hybrid_strategy, "_merge_search_results")
|
||||
|
||||
def test_merge_results_functionality(self, hybrid_strategy):
|
||||
"""Test result merging logic"""
|
||||
vector_results = [
|
||||
{
|
||||
"id": "1",
|
||||
"content": "Vector result",
|
||||
"similarity": 0.9,
|
||||
"url": "test1.com",
|
||||
"chunk_number": 1,
|
||||
"metadata": {},
|
||||
"source_id": "src1",
|
||||
}
|
||||
]
|
||||
keyword_results = [
|
||||
{
|
||||
"id": "2",
|
||||
"content": "Keyword result",
|
||||
"url": "test2.com",
|
||||
"chunk_number": 1,
|
||||
"metadata": {},
|
||||
"source_id": "src2",
|
||||
}
|
||||
]
|
||||
|
||||
merged = hybrid_strategy._merge_search_results(
|
||||
vector_results, keyword_results, match_count=5
|
||||
)
|
||||
|
||||
assert isinstance(merged, list)
|
||||
assert len(merged) <= 5
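The merge behaviour these assertions rely on can be pictured roughly as follows. The actual _merge_search_results implementation is not shown in this commit, so treat this as a sketch of one common interleaving approach rather than the shipped logic:

# Sketch of one possible merge approach (not the actual implementation).
def merge_results_sketch(vector_results, keyword_results, match_count):
    """Interleave vector and keyword hits, de-duplicating by id, up to match_count."""
    merged, seen_ids = [], set()
    # Alternate between the two result lists so both signals are represented.
    for vec, kw in zip(vector_results, keyword_results):
        for candidate in (vec, kw):
            if candidate["id"] not in seen_ids and len(merged) < match_count:
                seen_ids.add(candidate["id"])
                merged.append(candidate)
    # Top up from whichever list still has unseen items.
    for candidate in vector_results + keyword_results:
        if len(merged) >= match_count:
            break
        if candidate["id"] not in seen_ids:
            seen_ids.add(candidate["id"])
            merged.append(candidate)
    return merged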
|
||||
|
||||
|
||||
class TestRerankingCore:
|
||||
"""Basic reranking tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def reranking_strategy(self):
|
||||
"""Create reranking strategy"""
|
||||
from src.server.services.search.reranking_strategy import RerankingStrategy
|
||||
|
||||
return RerankingStrategy()
|
||||
|
||||
def test_initialization(self, reranking_strategy):
|
||||
"""Test reranking strategy initializes"""
|
||||
assert reranking_strategy is not None
|
||||
assert hasattr(reranking_strategy, "rerank_results")
|
||||
assert hasattr(reranking_strategy, "is_available")
|
||||
|
||||
def test_availability_check(self, reranking_strategy):
|
||||
"""Test model availability checking"""
|
||||
availability = reranking_strategy.is_available()
|
||||
assert isinstance(availability, bool)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_with_no_model(self, reranking_strategy):
|
||||
"""Test reranking when no model is available"""
|
||||
# Force model to be None
|
||||
reranking_strategy.model = None
|
||||
|
||||
original_results = [{"content": "Test content", "score": 0.8}]
|
||||
|
||||
result = await reranking_strategy.rerank_results(
|
||||
query="test query", results=original_results
|
||||
)
|
||||
|
||||
# Should return original results when no model
|
||||
assert result == original_results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_with_mock_model(self, reranking_strategy):
|
||||
"""Test reranking with a mocked model"""
|
||||
# Create a mock model
|
||||
mock_model = MagicMock()
|
||||
mock_model.predict.return_value = [0.95, 0.85, 0.75] # Mock rerank scores
|
||||
reranking_strategy.model = mock_model
|
||||
|
||||
original_results = [
|
||||
{"content": "Content 1", "similarity": 0.8},
|
||||
{"content": "Content 2", "similarity": 0.7},
|
||||
{"content": "Content 3", "similarity": 0.9},
|
||||
]
|
||||
|
||||
result = await reranking_strategy.rerank_results(
|
||||
query="test query", results=original_results
|
||||
)
|
||||
|
||||
# Should return reranked results
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 3
|
||||
|
||||
# Results should be sorted by rerank_score
|
||||
scores = [r.get("rerank_score", 0) for r in result]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
|
||||
# Highest rerank score should be first
|
||||
assert result[0]["rerank_score"] == 0.95
|
||||
|
||||
|
||||
class TestAgenticRAGCore:
|
||||
"""Basic agentic RAG tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def agentic_strategy(self, mock_supabase):
|
||||
"""Create agentic RAG strategy"""
|
||||
from src.server.services.search.agentic_rag_strategy import AgenticRAGStrategy
|
||||
from src.server.services.search.base_search_strategy import BaseSearchStrategy
|
||||
|
||||
base_strategy = BaseSearchStrategy(mock_supabase)
|
||||
return AgenticRAGStrategy(mock_supabase, base_strategy)
|
||||
|
||||
def test_initialization(self, agentic_strategy):
|
||||
"""Test agentic strategy initializes"""
|
||||
assert agentic_strategy is not None
|
||||
assert hasattr(agentic_strategy, "search_code_examples")
|
||||
assert hasattr(agentic_strategy, "is_enabled")
|
||||
|
||||
def test_query_enhancement(self, agentic_strategy):
|
||||
"""Test code query enhancement"""
|
||||
original_query = "python function"
|
||||
analysis = agentic_strategy.analyze_code_query(original_query)
|
||||
|
||||
assert isinstance(analysis, dict)
|
||||
assert "is_code_query" in analysis
|
||||
assert "confidence" in analysis
|
||||
assert "languages" in analysis
|
||||
assert analysis["is_code_query"] is True
|
||||
assert "python" in analysis["languages"]
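For orientation, the analysis dict these assertions describe would look something like the following; the values are invented for illustration and only the keys are implied by the test:

# Hypothetical example of the shape returned by analyze_code_query (values invented).
example_analysis = {
    "is_code_query": True,    # the query mentions code concepts ("function")
    "confidence": 0.8,        # heuristic confidence score
    "languages": ["python"],  # programming languages detected in the query
}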
|
||||
|
||||
|
||||
class TestRAGIntegrationSimple:
|
||||
"""Simple integration tests"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling(self, rag_service):
|
||||
"""Test error handling in RAG pipeline"""
|
||||
with patch.object(rag_service, "search_documents") as mock_search:
|
||||
# Simulate an error
|
||||
mock_search.side_effect = Exception("Test error")
|
||||
|
||||
success, result = await rag_service.perform_rag_query(query="test query", match_count=5)
|
||||
|
||||
assert success is False
|
||||
assert "error" in result
|
||||
assert result["error"] == "Test error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_results_handling(self, rag_service):
|
||||
"""Test handling of empty search results"""
|
||||
with patch.object(rag_service, "search_documents") as mock_search:
|
||||
mock_search.return_value = []
|
||||
|
||||
success, result = await rag_service.perform_rag_query(
|
||||
query="empty query", match_count=5
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_rag_pipeline_with_reranking(self, rag_service, mock_supabase):
|
||||
"""Test complete RAG pipeline with reranking enabled"""
|
||||
# Create a mock reranking model
|
||||
mock_model = MagicMock()
|
||||
mock_model.predict.return_value = [0.95, 0.85, 0.75]
|
||||
|
||||
# Initialize RAG service with reranking
|
||||
from src.server.services.search.reranking_strategy import RerankingStrategy
|
||||
|
||||
reranking_strategy = RerankingStrategy()
|
||||
reranking_strategy.model = mock_model
|
||||
rag_service.reranking_strategy = reranking_strategy
|
||||
|
||||
with (
|
||||
patch.object(rag_service, "search_documents") as mock_search,
|
||||
patch.object(rag_service, "get_bool_setting") as mock_settings,
|
||||
):
|
||||
# Enable reranking
|
||||
mock_settings.return_value = True
|
||||
|
||||
# Mock search results
|
||||
mock_search.return_value = [
|
||||
{"id": "1", "content": "Result 1", "similarity": 0.8, "metadata": {}},
|
||||
{"id": "2", "content": "Result 2", "similarity": 0.7, "metadata": {}},
|
||||
{"id": "3", "content": "Result 3", "similarity": 0.9, "metadata": {}},
|
||||
]
|
||||
|
||||
success, result = await rag_service.perform_rag_query(query="test query", match_count=5)
|
||||
|
||||
assert success is True
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 3
|
||||
|
||||
# Verify reranking was applied
|
||||
assert result["reranking_applied"] is True
|
||||
|
||||
# Results should be sorted by rerank_score
|
||||
results = result["results"]
|
||||
rerank_scores = [r.get("rerank_score", 0) for r in results]
|
||||
assert rerank_scores == sorted(rerank_scores, reverse=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hybrid_search_integration(self, rag_service):
|
||||
"""Test RAG with hybrid search enabled"""
|
||||
with (
|
||||
patch("src.server.services.search.rag_service.create_embedding") as mock_embed,
|
||||
patch.object(rag_service.hybrid_strategy, "search_documents_hybrid") as mock_hybrid,
|
||||
patch.object(rag_service, "get_bool_setting") as mock_settings,
|
||||
):
|
||||
# Mock embedding and enable hybrid search
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_settings.return_value = True
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_hybrid.return_value = [
|
||||
{
|
||||
"id": "1",
|
||||
"content": "Hybrid result",
|
||||
"similarity": 0.9,
|
||||
"metadata": {},
|
||||
"match_type": "hybrid",
|
||||
}
|
||||
]
|
||||
|
||||
results = await rag_service.search_documents(
|
||||
query="test query", use_hybrid_search=True, match_count=5
|
||||
)
|
||||
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == "Hybrid result"
|
||||
mock_hybrid.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_search_with_agentic_rag(self, rag_service):
|
||||
"""Test code search using agentic RAG"""
|
||||
with (
|
||||
patch.object(rag_service.agentic_strategy, "is_enabled") as mock_enabled,
|
||||
patch.object(rag_service.agentic_strategy, "search_code_examples") as mock_agentic,
|
||||
patch.object(rag_service, "get_bool_setting") as mock_settings,
|
||||
):
|
||||
# Enable agentic RAG
|
||||
mock_enabled.return_value = True
|
||||
mock_settings.return_value = False # Disable hybrid search for this test
|
||||
|
||||
# Mock agentic search results
|
||||
mock_agentic.return_value = [
|
||||
{
|
||||
"content": 'def example_function():\\n return "Hello"',
|
||||
"summary": "Example function that returns greeting",
|
||||
"url": "example.py",
|
||||
"metadata": {"language": "python"},
|
||||
}
|
||||
]
|
||||
|
||||
success, result = await rag_service.search_code_examples_service(
|
||||
query="python greeting function", match_count=10
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 1
|
||||
|
||||
code_result = result["results"][0]
|
||||
assert "def example_function" in code_result["code"]
|
||||
assert code_result["summary"] == "Example function that returns greeting"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
524
python/tests/test_rag_strategies.py
Normal file
@@ -0,0 +1,524 @@
"""
|
||||
Tests for RAG Strategies and Search Functionality
|
||||
|
||||
Tests RAGService class, hybrid search, agentic RAG, reranking, and other advanced RAG features.
|
||||
Updated to match current async-only architecture.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Mock problematic imports at module level
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"SUPABASE_URL": "http://test.supabase.co",
|
||||
"SUPABASE_SERVICE_KEY": "test_key",
|
||||
"OPENAI_API_KEY": "test_openai_key",
|
||||
},
|
||||
):
|
||||
# Mock credential service to prevent database calls
|
||||
with patch("src.server.services.credential_service.credential_service") as mock_cred:
|
||||
mock_cred._cache_initialized = False
|
||||
mock_cred.get_setting.return_value = "false"
|
||||
mock_cred.get_bool_setting.return_value = False
|
||||
|
||||
# Mock supabase client creation
|
||||
with patch("src.server.utils.get_supabase_client") as mock_supabase:
|
||||
mock_client = MagicMock()
|
||||
mock_supabase.return_value = mock_client
|
||||
|
||||
# Mock embedding service to prevent API calls
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.create_embedding"
|
||||
) as mock_embed:
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
|
||||
|
||||
# Test RAGService core functionality
|
||||
class TestRAGService:
|
||||
"""Test core RAGService functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_supabase_client(self):
|
||||
"""Mock Supabase client"""
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def rag_service(self, mock_supabase_client):
|
||||
"""Create RAGService instance"""
|
||||
from src.server.services.search import RAGService
|
||||
|
||||
return RAGService(supabase_client=mock_supabase_client)
|
||||
|
||||
def test_rag_service_initialization(self, rag_service):
|
||||
"""Test RAGService initializes correctly"""
|
||||
assert rag_service is not None
|
||||
assert hasattr(rag_service, "search_documents")
|
||||
assert hasattr(rag_service, "search_code_examples")
|
||||
assert hasattr(rag_service, "perform_rag_query")
|
||||
|
||||
def test_get_setting(self, rag_service):
|
||||
"""Test settings retrieval"""
|
||||
with patch.dict("os.environ", {"USE_HYBRID_SEARCH": "true"}):
|
||||
result = rag_service.get_setting("USE_HYBRID_SEARCH", "false")
|
||||
assert result == "true"
|
||||
|
||||
def test_get_bool_setting(self, rag_service):
|
||||
"""Test boolean settings retrieval"""
|
||||
with patch.dict("os.environ", {"USE_RERANKING": "true"}):
|
||||
result = rag_service.get_bool_setting("USE_RERANKING", False)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_code_examples(self, rag_service):
|
||||
"""Test code examples search"""
|
||||
with patch.object(
|
||||
rag_service.agentic_strategy, "search_code_examples"
|
||||
) as mock_agentic_search:
|
||||
# Mock agentic search results
|
||||
mock_agentic_search.return_value = [
|
||||
{
|
||||
"content": "def example():\n pass",
|
||||
"summary": "Python function example",
|
||||
"url": "test.py",
|
||||
"metadata": {"language": "python"},
|
||||
}
|
||||
]
|
||||
|
||||
result = await rag_service.search_code_examples(
|
||||
query="python function example", match_count=5
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
mock_agentic_search.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_rag_query(self, rag_service):
|
||||
"""Test complete RAG query flow"""
|
||||
# Create a mock reranking strategy if it doesn't exist
|
||||
if rag_service.reranking_strategy is None:
|
||||
from unittest.mock import Mock
|
||||
|
||||
rag_service.reranking_strategy = Mock()
|
||||
rag_service.reranking_strategy.rerank_results = AsyncMock()
|
||||
|
||||
with (
|
||||
patch.object(rag_service, "search_documents") as mock_search,
|
||||
patch.object(rag_service.reranking_strategy, "rerank_results") as mock_rerank,
|
||||
):
|
||||
mock_search.return_value = [{"content": "Relevant content", "score": 0.90}]
|
||||
mock_rerank.return_value = [{"content": "Relevant content", "score": 0.95}]
|
||||
|
||||
success, result = await rag_service.perform_rag_query(query="test query", match_count=5)
|
||||
|
||||
assert success is True
|
||||
assert "results" in result
|
||||
assert isinstance(result["results"], list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_results(self, rag_service):
|
||||
"""Test result reranking via strategy"""
|
||||
from src.server.services.search import RerankingStrategy
|
||||
|
||||
# Create a mock reranking strategy
|
||||
mock_strategy = MagicMock(spec=RerankingStrategy)
|
||||
mock_strategy.rerank_results = AsyncMock(
|
||||
return_value=[{"content": "Reranked content", "score": 0.98}]
|
||||
)
|
||||
|
||||
# Assign the mock strategy to the service
|
||||
rag_service.reranking_strategy = mock_strategy
|
||||
|
||||
original_results = [{"content": "Original content", "score": 0.80}]
|
||||
|
||||
# Call the strategy directly (as the service now does internally)
|
||||
result = await rag_service.reranking_strategy.rerank_results(
|
||||
query="test query", results=original_results
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert result[0]["content"] == "Reranked content"
|
||||
|
||||
|
||||
class TestHybridSearchStrategy:
|
||||
"""Test hybrid search strategy implementation"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_supabase_client(self):
|
||||
"""Mock Supabase client"""
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def hybrid_strategy(self, mock_supabase_client):
|
||||
"""Create HybridSearchStrategy instance"""
|
||||
from src.server.services.search import HybridSearchStrategy
|
||||
from src.server.services.search.base_search_strategy import BaseSearchStrategy
|
||||
|
||||
base_strategy = BaseSearchStrategy(mock_supabase_client)
|
||||
return HybridSearchStrategy(mock_supabase_client, base_strategy)
|
||||
|
||||
def test_hybrid_strategy_initialization(self, hybrid_strategy):
|
||||
"""Test HybridSearchStrategy initializes correctly"""
|
||||
assert hybrid_strategy is not None
|
||||
assert hasattr(hybrid_strategy, "search_documents_hybrid")
|
||||
assert hasattr(hybrid_strategy, "search_code_examples_hybrid")
|
||||
|
||||
def test_merge_search_results(self, hybrid_strategy):
|
||||
"""Test search result merging"""
|
||||
vector_results = [
|
||||
{
|
||||
"id": "1",
|
||||
"content": "Vector result 1",
|
||||
"score": 0.9,
|
||||
"url": "url1",
|
||||
"chunk_number": 1,
|
||||
"metadata": {},
|
||||
"source_id": "source1",
|
||||
"similarity": 0.9,
|
||||
}
|
||||
]
|
||||
keyword_results = [
|
||||
{
|
||||
"id": "2",
|
||||
"content": "Keyword result 1",
|
||||
"score": 0.8,
|
||||
"url": "url2",
|
||||
"chunk_number": 1,
|
||||
"metadata": {},
|
||||
"source_id": "source2",
|
||||
}
|
||||
]
|
||||
|
||||
merged = hybrid_strategy._merge_search_results(
|
||||
vector_results, keyword_results, match_count=5
|
||||
)
|
||||
|
||||
assert isinstance(merged, list)
|
||||
assert len(merged) <= 5
|
||||
# Should contain results from both sources
|
||||
if merged:
|
||||
assert any("Vector result" in str(r) or "Keyword result" in str(r) for r in merged)
|
||||
|
||||
|
||||
class TestRerankingStrategy:
|
||||
"""Test reranking strategy implementation"""
|
||||
|
||||
@pytest.fixture
|
||||
def reranking_strategy(self):
|
||||
"""Create RerankingStrategy instance"""
|
||||
from src.server.services.search import RerankingStrategy
|
||||
|
||||
return RerankingStrategy()
|
||||
|
||||
def test_reranking_strategy_initialization(self, reranking_strategy):
|
||||
"""Test RerankingStrategy initializes correctly"""
|
||||
assert reranking_strategy is not None
|
||||
assert hasattr(reranking_strategy, "rerank_results")
|
||||
assert hasattr(reranking_strategy, "is_available")
|
||||
|
||||
def test_model_availability_check(self, reranking_strategy):
|
||||
"""Test model availability checking"""
|
||||
# This should not crash even if model not available
|
||||
availability = reranking_strategy.is_available()
|
||||
assert isinstance(availability, bool)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_results_no_model(self, reranking_strategy):
|
||||
"""Test reranking when model not available"""
|
||||
with patch.object(reranking_strategy, "is_available") as mock_available:
|
||||
mock_available.return_value = False
|
||||
|
||||
original_results = [{"content": "Test content", "score": 0.8}]
|
||||
|
||||
result = await reranking_strategy.rerank_results(
|
||||
query="test query", results=original_results
|
||||
)
|
||||
|
||||
# Should return original results when model not available
|
||||
assert result == original_results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_results_with_model(self, reranking_strategy):
|
||||
"""Test reranking when model is available"""
|
||||
with (
|
||||
patch.object(reranking_strategy, "is_available") as mock_available,
|
||||
patch.object(reranking_strategy, "model") as mock_model,
|
||||
):
|
||||
mock_available.return_value = True
|
||||
mock_model_instance = MagicMock()
mock_model_instance.predict.return_value = [0.95, 0.85]  # Mock rerank scores
reranking_strategy.model = mock_model_instance
|
||||
|
||||
original_results = [
|
||||
{"content": "Content 1", "score": 0.8},
|
||||
{"content": "Content 2", "score": 0.7},
|
||||
]
|
||||
|
||||
result = await reranking_strategy.rerank_results(
|
||||
query="test query", results=original_results
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) <= len(original_results)
|
||||
|
||||
|
||||
class TestAgenticRAGStrategy:
|
||||
"""Test agentic RAG strategy implementation"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_supabase_client(self):
|
||||
"""Mock Supabase client"""
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def agentic_strategy(self, mock_supabase_client):
|
||||
"""Create AgenticRAGStrategy instance"""
|
||||
from src.server.services.search import AgenticRAGStrategy
|
||||
from src.server.services.search.base_search_strategy import BaseSearchStrategy
|
||||
|
||||
base_strategy = BaseSearchStrategy(mock_supabase_client)
|
||||
return AgenticRAGStrategy(mock_supabase_client, base_strategy)
|
||||
|
||||
def test_agentic_strategy_initialization(self, agentic_strategy):
|
||||
"""Test AgenticRAGStrategy initializes correctly"""
|
||||
assert agentic_strategy is not None
|
||||
# Check for expected methods
|
||||
methods = dir(agentic_strategy)
|
||||
assert any("search" in method.lower() for method in methods)
|
||||
|
||||
|
||||
class TestRAGIntegration:
|
||||
"""Integration tests for RAG strategies working together"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_supabase_client(self):
|
||||
"""Mock Supabase client"""
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def rag_service(self, mock_supabase_client):
|
||||
"""Create RAGService instance"""
|
||||
from src.server.services.search import RAGService
|
||||
|
||||
return RAGService(supabase_client=mock_supabase_client)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_rag_pipeline(self, rag_service):
|
||||
"""Test complete RAG pipeline with all strategies"""
|
||||
# Create a mock reranking strategy if it doesn't exist
|
||||
if rag_service.reranking_strategy is None:
|
||||
from unittest.mock import Mock
|
||||
|
||||
rag_service.reranking_strategy = Mock()
|
||||
rag_service.reranking_strategy.rerank_results = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"src.server.services.embeddings.embedding_service.create_embedding"
|
||||
) as mock_embedding,
|
||||
patch.object(rag_service.base_strategy, "vector_search") as mock_search,
|
||||
patch.object(rag_service, "get_bool_setting") as mock_settings,
|
||||
patch.object(rag_service.reranking_strategy, "rerank_results") as mock_rerank,
|
||||
):
|
||||
# Mock embedding creation
|
||||
mock_embedding.return_value = [0.1] * 1536
|
||||
|
||||
# Enable all strategies
|
||||
mock_settings.side_effect = lambda key, default: True
|
||||
|
||||
mock_search.return_value = [
|
||||
{"content": "Test result 1", "similarity": 0.9, "id": "1", "metadata": {}},
|
||||
{"content": "Test result 2", "similarity": 0.8, "id": "2", "metadata": {}},
|
||||
]
|
||||
|
||||
mock_rerank.return_value = [
|
||||
{"content": "Reranked result", "similarity": 0.95, "id": "1", "metadata": {}}
|
||||
]
|
||||
|
||||
success, result = await rag_service.perform_rag_query(
|
||||
query="complex technical query", match_count=10
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert "results" in result
|
||||
assert isinstance(result["results"], list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_in_rag_pipeline(self, rag_service):
|
||||
"""Test error handling when strategies fail"""
|
||||
with patch(
|
||||
"src.server.services.embeddings.embedding_service.create_embedding"
|
||||
) as mock_embedding:
|
||||
# Simulate embedding failure (returns None)
|
||||
mock_embedding.return_value = None
|
||||
|
||||
success, result = await rag_service.perform_rag_query(query="test query", match_count=5)
|
||||
|
||||
# Should handle gracefully by returning empty results
|
||||
assert success is True
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 0 # Empty results due to embedding failure
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_results_handling(self, rag_service):
|
||||
"""Test handling of empty search results"""
|
||||
with (
|
||||
patch(
|
||||
"src.server.services.embeddings.embedding_service.create_embedding"
|
||||
) as mock_embedding,
|
||||
patch.object(rag_service.base_strategy, "vector_search") as mock_search,
|
||||
):
|
||||
# Mock embedding creation
|
||||
mock_embedding.return_value = [0.1] * 1536
|
||||
mock_search.return_value = [] # Empty results
|
||||
|
||||
success, result = await rag_service.perform_rag_query(
|
||||
query="nonexistent query", match_count=5
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 0
|
||||
|
||||
|
||||
class TestRAGPerformance:
|
||||
"""Test RAG performance and optimization features"""
|
||||
|
||||
@pytest.fixture
|
||||
def rag_service(self):
|
||||
"""Create RAGService instance"""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from src.server.services.search import RAGService
|
||||
|
||||
mock_client = MagicMock()
|
||||
return RAGService(supabase_client=mock_client)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_rag_queries(self, rag_service):
|
||||
"""Test multiple concurrent RAG queries"""
|
||||
with (
|
||||
patch(
|
||||
"src.server.services.embeddings.embedding_service.create_embedding"
|
||||
) as mock_embedding,
|
||||
patch.object(rag_service.base_strategy, "vector_search") as mock_search,
|
||||
):
|
||||
# Mock embedding creation
|
||||
mock_embedding.return_value = [0.1] * 1536
|
||||
|
||||
mock_search.return_value = [
|
||||
{
|
||||
"content": "Result for concurrent test",
|
||||
"similarity": 0.9,
|
||||
"id": "1",
|
||||
"metadata": {},
|
||||
}
|
||||
]
|
||||
|
||||
# Run multiple queries concurrently
|
||||
queries = ["query 1", "query 2", "query 3"]
|
||||
tasks = [rag_service.perform_rag_query(query, match_count=3) for query in queries]
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# All should complete successfully
|
||||
assert len(results) == 3
|
||||
for result in results:
|
||||
if isinstance(result, tuple):
|
||||
success, data = result
|
||||
assert success is True or isinstance(data, dict)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_result_set_handling(self, rag_service):
|
||||
"""Test handling of large result sets"""
|
||||
with (
|
||||
patch(
|
||||
"src.server.services.embeddings.embedding_service.create_embedding"
|
||||
) as mock_embedding,
|
||||
patch.object(rag_service.base_strategy, "vector_search") as mock_search,
|
||||
):
|
||||
# Mock embedding creation
|
||||
mock_embedding.return_value = [0.1] * 1536
|
||||
|
||||
# Create large result set, but limit to match_count
|
||||
large_results = [
|
||||
{
|
||||
"content": f"Result {i}",
|
||||
"similarity": 0.9 - (i * 0.01),
|
||||
"id": str(i),
|
||||
"metadata": {},
|
||||
}
|
||||
for i in range(50) # Only return up to match_count results
|
||||
]
|
||||
mock_search.return_value = large_results
|
||||
|
||||
success, result = await rag_service.perform_rag_query(
|
||||
query="large query", match_count=50
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert "results" in result
|
||||
# Should respect match_count limit
|
||||
assert len(result["results"]) <= 50
|
||||
|
||||
|
||||
class TestRAGConfiguration:
|
||||
"""Test RAG configuration and settings"""
|
||||
|
||||
@pytest.fixture
|
||||
def rag_service(self):
|
||||
"""Create RAGService instance"""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from src.server.services.search import RAGService
|
||||
|
||||
mock_client = MagicMock()
|
||||
return RAGService(supabase_client=mock_client)
|
||||
|
||||
def test_environment_variable_settings(self, rag_service):
|
||||
"""Test reading settings from environment variables"""
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{"USE_HYBRID_SEARCH": "true", "USE_RERANKING": "false", "USE_AGENTIC_RAG": "true"},
|
||||
):
|
||||
assert rag_service.get_bool_setting("USE_HYBRID_SEARCH") is True
|
||||
assert rag_service.get_bool_setting("USE_RERANKING") is False
|
||||
assert rag_service.get_bool_setting("USE_AGENTIC_RAG") is True
|
||||
|
||||
def test_default_settings(self, rag_service):
|
||||
"""Test default settings when environment variables not set"""
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
assert rag_service.get_bool_setting("NONEXISTENT_SETTING", True) is True
|
||||
assert rag_service.get_bool_setting("NONEXISTENT_SETTING", False) is False
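A boolean-setting helper consistent with these assertions presumably normalises common truthy strings from the environment; the minimal sketch below is an assumption for illustration, not the shipped implementation:

# Minimal sketch of a get_bool_setting-style helper (illustrative only).
import os


def get_bool_setting_sketch(key: str, default: bool = False) -> bool:
    value = os.getenv(key)
    if value is None:
        return default
    return value.strip().lower() in ("true", "1", "yes", "on")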
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_strategy_conditional_execution(self, rag_service):
|
||||
"""Test that strategies only execute when enabled"""
|
||||
with (
|
||||
patch(
|
||||
"src.server.services.embeddings.embedding_service.create_embedding"
|
||||
) as mock_embedding,
|
||||
patch.object(rag_service.base_strategy, "vector_search") as mock_search,
|
||||
patch.object(rag_service, "get_bool_setting") as mock_setting,
|
||||
):
|
||||
# Mock embedding creation
|
||||
mock_embedding.return_value = [0.1] * 1536
|
||||
|
||||
mock_search.return_value = [
|
||||
{"content": "test", "similarity": 0.9, "id": "1", "metadata": {}}
|
||||
]
|
||||
|
||||
# Disable all strategies
|
||||
mock_setting.return_value = False
|
||||
|
||||
success, result = await rag_service.perform_rag_query(query="test query", match_count=5)
|
||||
|
||||
assert success is True
|
||||
# Should still return results from basic search
|
||||
assert "results" in result
95
python/tests/test_service_integration.py
Normal file
@@ -0,0 +1,95 @@
"""Service integration tests - Test core service interactions."""


def test_project_with_tasks_flow(client):
    """Test creating a project and adding tasks."""
    # Create project
    project_response = client.post("/api/projects", json={"title": "Test Project"})
    assert project_response.status_code in [200, 201, 422]

    # List projects to verify
    list_response = client.get("/api/projects")
    assert list_response.status_code in [200, 500]  # 500 is OK in test environment


def test_crawl_to_knowledge_flow(client):
    """Test crawling workflow."""
    # Start crawl
    crawl_data = {"url": "https://example.com", "max_depth": 1, "max_pages": 5}
    response = client.post("/api/knowledge/crawl", json=crawl_data)
    assert response.status_code in [200, 201, 400, 404, 422, 500]


def test_document_storage_flow(client):
    """Test document upload endpoint."""
    # Test multipart form upload
    files = {"file": ("test.txt", b"Test content", "text/plain")}
    response = client.post("/api/knowledge/documents", files=files)
    assert response.status_code in [200, 201, 400, 404, 422, 500]


def test_code_extraction_flow(client):
    """Test code extraction endpoint."""
    response = client.post(
        "/api/knowledge/extract-code", json={"document_id": "test-doc-id", "languages": ["python"]}
    )
    assert response.status_code in [200, 400, 404, 422, 500]


def test_search_and_retrieve_flow(client):
    """Test search functionality."""
    # Search
    search_response = client.post("/api/knowledge/search", json={"query": "test"})
    assert search_response.status_code in [200, 400, 404, 422, 500]

    # Get specific item (might not exist)
    item_response = client.get("/api/knowledge/items/test-id")
    assert item_response.status_code in [200, 404, 500]


def test_mcp_tool_execution(client):
    """Test MCP tool execution endpoint."""
    response = client.post("/api/mcp/tools/execute", json={"tool": "test_tool", "params": {}})
    assert response.status_code in [200, 400, 404, 422, 500]


def test_socket_io_events(client):
    """Test Socket.IO connectivity."""
    # Just verify the endpoint exists
    response = client.get("/socket.io/")
    assert response.status_code in [200, 400, 404]


def test_background_task_progress(client):
    """Test background task tracking."""
    # Check if task progress endpoint exists
    response = client.get("/api/tasks/test-task-id/progress")
    assert response.status_code in [200, 404, 500]


def test_database_operations(client):
    """Test pagination and filtering."""
    # Test with query params
    response = client.get("/api/projects?limit=10&offset=0")
    assert response.status_code in [200, 500]  # 500 is OK in test environment

    # Test filtering
    response = client.get("/api/tasks?status=todo")
    assert response.status_code in [200, 400, 422, 500]


def test_concurrent_operations(client):
    """Test API handles concurrent requests."""
    import concurrent.futures

    def make_request():
        return client.get("/api/projects")

    # Make 3 concurrent requests
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        futures = [executor.submit(make_request) for _ in range(3)]
        results = [f.result() for f in futures]

    # All should succeed or fail with 500 in test environment
    for result in results:
        assert result.status_code in [200, 500]  # 500 is OK in test environment
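
The broad status-code sets above are deliberately tolerant of 500s from the mocked backend, which can make a failing check hard to read. A small helper like the following (hypothetical, not part of this commit) would keep the same tolerance while producing clearer failure messages:

def assert_status_in(response, allowed: set[int]) -> None:
    """Assert the response status is one of the tolerated codes, with a readable message."""
    assert response.status_code in allowed, (
        f"unexpected status {response.status_code}; expected one of {sorted(allowed)}"
    )


# Example usage inside a test:
#     assert_status_in(client.get("/api/projects"), {200, 500})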
58
python/tests/test_settings_api.py
Normal file
@@ -0,0 +1,58 @@
"""
Simple tests for settings API credential handling.
Focus on critical paths for optional settings with defaults.
"""

from unittest.mock import AsyncMock, MagicMock, patch


def test_optional_setting_returns_default(client, mock_supabase_client):
    """Test that optional settings return default values with the is_default flag."""
    # Mock the entire credential_service instance
    mock_service = MagicMock()
    mock_service.get_credential = AsyncMock(return_value=None)

    with patch("src.server.api_routes.settings_api.credential_service", mock_service):
        response = client.get("/api/credentials/DISCONNECT_SCREEN_ENABLED")

        assert response.status_code == 200
        data = response.json()
        assert data["key"] == "DISCONNECT_SCREEN_ENABLED"
        assert data["value"] == "true"
        assert data["is_default"] is True
        assert "category" in data
        assert "description" in data


def test_unknown_credential_returns_404(client, mock_supabase_client):
    """Test that unknown credentials still return 404."""
    # Mock the entire credential_service instance
    mock_service = MagicMock()
    mock_service.get_credential = AsyncMock(return_value=None)

    with patch("src.server.api_routes.settings_api.credential_service", mock_service):
        response = client.get("/api/credentials/UNKNOWN_KEY_THAT_DOES_NOT_EXIST")

        assert response.status_code == 404
        data = response.json()
        assert "error" in data["detail"]
        assert "not found" in data["detail"]["error"].lower()


def test_existing_credential_returns_normally(client, mock_supabase_client):
    """Test that existing credentials return without the default flag."""
    mock_value = "user_configured_value"
    # Mock the entire credential_service instance
    mock_service = MagicMock()
    mock_service.get_credential = AsyncMock(return_value=mock_value)

    with patch("src.server.api_routes.settings_api.credential_service", mock_service):
        response = client.get("/api/credentials/SOME_EXISTING_KEY")

        assert response.status_code == 200
        data = response.json()
        assert data["key"] == "SOME_EXISTING_KEY"
        assert data["value"] == "user_configured_value"
        assert data["is_encrypted"] is False
        # Should not have is_default flag for real credentials
        assert "is_default" not in data
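
Taken together, these three tests pin down the endpoint's lookup order: a stored credential passes through untouched, a known optional key falls back to a default payload flagged with is_default, and an unknown key yields 404. A minimal sketch of that branching is shown below; the route fields and flag names come from the tests, while the OPTIONAL_DEFAULTS table and the function itself are hypothetical illustrations, not the actual settings API code.

# Hypothetical defaults registry - the real one lives in the settings API module.
OPTIONAL_DEFAULTS = {
    "DISCONNECT_SCREEN_ENABLED": {
        "value": "true",
        "category": "features",  # assumed category for illustration
        "description": "Show the disconnect screen when the server is unreachable",
    },
}


async def get_credential_response(key: str, credential_service) -> tuple[int, dict]:
    """Sketch of the order the tests assert: stored value, then known default, then 404."""
    stored = await credential_service.get_credential(key)
    if stored is not None:
        return 200, {"key": key, "value": stored, "is_encrypted": False}
    if key in OPTIONAL_DEFAULTS:
        return 200, {"key": key, "is_default": True, **OPTIONAL_DEFAULTS[key]}
    return 404, {"detail": {"error": f"Credential {key} not found"}}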