Merge branch 'main' into feature/automatic-discovery-llms-sitemap-430

This commit is contained in:
leex279
2025-09-22 22:35:36 +02:00
127 changed files with 6076 additions and 2018 deletions

View File

@@ -16,7 +16,6 @@ import os
from urllib.parse import urljoin
import httpx
from mcp.server.fastmcp import Context, FastMCP
# Import service discovery for HTTP communication
@@ -78,15 +77,18 @@ def register_rag_tools(mcp: FastMCP):
@mcp.tool()
async def rag_search_knowledge_base(
ctx: Context, query: str, source_domain: str | None = None, match_count: int = 5
ctx: Context, query: str, source_id: str | None = None, match_count: int = 5
) -> str:
"""
Search knowledge base for relevant content using RAG.
Args:
query: Search query
source_domain: Optional domain filter (e.g., 'docs.anthropic.com').
Note: This is a domain name, not the source_id from get_available_sources.
query: Search query - Keep it SHORT and FOCUSED (2-5 keywords).
Good: "vector search", "authentication JWT", "React hooks"
Bad: "how to implement user authentication with JWT tokens in React with TypeScript and handle refresh tokens"
source_id: Optional source ID filter from rag_get_available_sources().
This is the 'id' field from available sources, NOT a URL or domain name.
Example: "src_1234abcd" not "docs.anthropic.com"
match_count: Max results (default: 5)
Returns:
@@ -102,8 +104,8 @@ def register_rag_tools(mcp: FastMCP):
async with httpx.AsyncClient(timeout=timeout) as client:
request_data = {"query": query, "match_count": match_count}
if source_domain:
request_data["source"] = source_domain
if source_id:
request_data["source"] = source_id
response = await client.post(urljoin(api_url, "/api/rag/query"), json=request_data)
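For reference, a minimal sketch of the request this tool builds against the knowledge-base API, assuming only what the hunk above shows (the /api/rag/query endpoint and the "source" payload key); the timeout value here is arbitrary rather than the configured default:

import httpx
from urllib.parse import urljoin

async def query_rag(api_url: str, query: str, source_id: str | None = None, match_count: int = 5) -> dict:
    # Mirror the payload built above: short keyword query, optional source filter by id.
    request_data = {"query": query, "match_count": match_count}
    if source_id:
        request_data["source"] = source_id  # the server still expects the key "source"
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(urljoin(api_url, "/api/rag/query"), json=request_data)
        response.raise_for_status()
        return response.json()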
@@ -135,15 +137,18 @@ def register_rag_tools(mcp: FastMCP):
@mcp.tool()
async def rag_search_code_examples(
ctx: Context, query: str, source_domain: str | None = None, match_count: int = 5
ctx: Context, query: str, source_id: str | None = None, match_count: int = 5
) -> str:
"""
Search for relevant code examples in the knowledge base.
Args:
query: Search query
source_domain: Optional domain filter (e.g., 'docs.anthropic.com').
Note: This is a domain name, not the source_id from get_available_sources.
query: Search query - Keep it SHORT and FOCUSED (2-5 keywords).
Good: "React useState", "FastAPI middleware", "vector pgvector"
Bad: "React hooks useState useEffect useContext useReducer useMemo useCallback"
source_id: Optional source ID filter from rag_get_available_sources().
This is the 'id' field from available sources, NOT a URL or domain name.
Example: "src_1234abcd" not "docs.anthropic.com"
match_count: Max results (default: 5)
Returns:
@@ -159,8 +164,8 @@ def register_rag_tools(mcp: FastMCP):
async with httpx.AsyncClient(timeout=timeout) as client:
request_data = {"query": query, "match_count": match_count}
if source_domain:
request_data["source"] = source_domain
if source_id:
request_data["source"] = source_id
# Call the dedicated code examples endpoint
response = await client.post(

View File

@@ -10,8 +10,8 @@ from typing import Any
from urllib.parse import urljoin
import httpx
from mcp.server.fastmcp import Context, FastMCP
from src.mcp_server.utils.error_handling import MCPErrorFormatter
from src.mcp_server.utils.timeout_config import get_default_timeout
from src.server.config.service_discovery import get_api_url
@@ -31,20 +31,20 @@ def truncate_text(text: str, max_length: int = MAX_DESCRIPTION_LENGTH) -> str:
def optimize_task_response(task: dict) -> dict:
"""Optimize task object for MCP response."""
task = task.copy() # Don't modify original
# Truncate description if present
if "description" in task and task["description"]:
task["description"] = truncate_text(task["description"])
# Replace arrays with counts
if "sources" in task and isinstance(task["sources"], list):
task["sources_count"] = len(task["sources"])
del task["sources"]
if "code_examples" in task and isinstance(task["code_examples"], list):
task["code_examples_count"] = len(task["code_examples"])
del task["code_examples"]
return task
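To make the shape change concrete, an illustrative call with made-up values (relies only on the function above):

sample = {
    "id": "t-1",
    "description": "x" * 500,
    "sources": ["doc-a", "doc-b"],
    "code_examples": [],
}
slim = optimize_task_response(sample)
# slim["description"] has been shortened by truncate_text(),
# slim["sources_count"] == 2 and slim["code_examples_count"] == 0,
# and the original "sources" / "code_examples" arrays are removed.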
@@ -88,12 +88,12 @@ def register_task_tools(mcp: FastMCP):
try:
api_url = get_api_url()
timeout = get_default_timeout()
# Single task get mode
if task_id:
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(urljoin(api_url, f"/api/tasks/{task_id}"))
if response.status_code == 200:
task = response.json()
# Don't optimize single task get - return full details
@@ -107,18 +107,18 @@ def register_task_tools(mcp: FastMCP):
)
else:
return MCPErrorFormatter.from_http_error(response, "get task")
# List mode with search and filters
params: dict[str, Any] = {
"page": page,
"per_page": per_page,
"exclude_large_fields": True, # Always exclude large fields in MCP responses
}
# Add search query if provided
if query:
params["q"] = query
if filter_by == "project" and filter_value:
# Use project-specific endpoint for project filtering
url = urljoin(api_url, f"/api/projects/{filter_value}/tasks")
@@ -146,13 +146,13 @@ def register_task_tools(mcp: FastMCP):
# No specific filters - get all tasks
url = urljoin(api_url, "/api/tasks")
params["include_closed"] = include_closed
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(url, params=params)
response.raise_for_status()
result = response.json()
# Normalize response format
if isinstance(result, list):
tasks = result
@@ -176,10 +176,10 @@ def register_task_tools(mcp: FastMCP):
message="Invalid response type from API",
details={"response_type": type(result).__name__},
)
# Optimize task responses
optimized_tasks = [optimize_task_response(task) for task in tasks]
return json.dumps({
"success": True,
"tasks": optimized_tasks,
@@ -187,7 +187,7 @@ def register_task_tools(mcp: FastMCP):
"count": len(optimized_tasks),
"query": query, # Include search query in response
})
except httpx.RequestError as e:
return MCPErrorFormatter.from_exception(
e, "list tasks", {"filter_by": filter_by, "filter_value": filter_value}
@@ -211,13 +211,19 @@ def register_task_tools(mcp: FastMCP):
) -> str:
"""
Manage tasks (consolidated: create/update/delete).
TASK GRANULARITY GUIDANCE:
- For feature-specific projects: Create detailed implementation tasks (setup, implement, test, document)
- For codebase-wide projects: Create feature-level tasks
- Default to more granular tasks when project scope is unclear
- Each task should represent 30 minutes to 4 hours of work
Args:
action: "create" | "update" | "delete"
task_id: Task UUID for update/delete
project_id: Project UUID for create
title: Task title text
description: Detailed task description
description: Detailed task description with clear completion criteria
status: "todo" | "doing" | "review" | "done"
assignee: String name of the assignee. Can be any agent name,
"User" for human assignment, or custom agent identifiers
@@ -228,16 +234,17 @@ def register_task_tools(mcp: FastMCP):
feature: Feature label for grouping
Examples:
manage_task("create", project_id="p-1", title="Fix auth bug", assignee="CodeAnalyzer-v2")
manage_task("create", project_id="p-1", title="Research existing patterns", description="Study codebase for similar implementations")
manage_task("create", project_id="p-1", title="Write unit tests", description="Cover all edge cases with 80% coverage")
manage_task("update", task_id="t-1", status="doing", assignee="User")
manage_task("delete", task_id="t-1")
Returns: {success: bool, task?: object, message: string}
"""
try:
api_url = get_api_url()
timeout = get_default_timeout()
async with httpx.AsyncClient(timeout=timeout) as client:
if action == "create":
if not project_id or not title:
@@ -246,7 +253,7 @@ def register_task_tools(mcp: FastMCP):
"project_id and title required for create",
suggestion="Provide both project_id and title"
)
response = await client.post(
urljoin(api_url, "/api/tasks"),
json={
@@ -260,15 +267,15 @@ def register_task_tools(mcp: FastMCP):
"code_examples": [],
},
)
if response.status_code == 200:
result = response.json()
task = result.get("task")
# Optimize task response
if task:
task = optimize_task_response(task)
return json.dumps({
"success": True,
"task": task,
@@ -277,7 +284,7 @@ def register_task_tools(mcp: FastMCP):
})
else:
return MCPErrorFormatter.from_http_error(response, "create task")
elif action == "update":
if not task_id:
return MCPErrorFormatter.format_error(
@@ -285,7 +292,7 @@ def register_task_tools(mcp: FastMCP):
"task_id required for update",
suggestion="Provide task_id to update"
)
# Build update fields
update_fields = {}
if title is not None:
@@ -300,27 +307,27 @@ def register_task_tools(mcp: FastMCP):
update_fields["task_order"] = task_order
if feature is not None:
update_fields["feature"] = feature
if not update_fields:
return MCPErrorFormatter.format_error(
error_type="validation_error",
message="No fields to update",
suggestion="Provide at least one field to update",
)
response = await client.put(
urljoin(api_url, f"/api/tasks/{task_id}"),
json=update_fields
)
if response.status_code == 200:
result = response.json()
task = result.get("task")
# Optimize task response
if task:
task = optimize_task_response(task)
return json.dumps({
"success": True,
"task": task,
@@ -328,7 +335,7 @@ def register_task_tools(mcp: FastMCP):
})
else:
return MCPErrorFormatter.from_http_error(response, "update task")
elif action == "delete":
if not task_id:
return MCPErrorFormatter.format_error(
@@ -336,11 +343,11 @@ def register_task_tools(mcp: FastMCP):
"task_id required for delete",
suggestion="Provide task_id to delete"
)
response = await client.delete(
urljoin(api_url, f"/api/tasks/{task_id}")
)
if response.status_code == 200:
result = response.json()
return json.dumps({
@@ -349,14 +356,14 @@ def register_task_tools(mcp: FastMCP):
})
else:
return MCPErrorFormatter.from_http_error(response, "delete task")
else:
return MCPErrorFormatter.format_error(
"invalid_action",
f"Unknown action: {action}",
suggestion="Use 'create', 'update', or 'delete'"
)
except httpx.RequestError as e:
return MCPErrorFormatter.from_exception(
e, f"{action} task", {"task_id": task_id, "project_id": project_id}

View File

@@ -194,12 +194,30 @@ MCP_INSTRUCTIONS = """
## 🚨 CRITICAL RULES (ALWAYS FOLLOW)
1. **Task Management**: ALWAYS use Archon MCP tools for task management.
- Combine with your local TODO tools for granular tracking
- First TODO: Update Archon task status
- Last TODO: Update Archon with findings/completion
2. **Research First**: Before implementing, use rag_search_knowledge_base and rag_search_code_examples
3. **Task-Driven Development**: Never code without checking current tasks first
## 🎯 Targeted Documentation Search
When searching specific documentation (very common!):
1. **Get available sources**: `rag_get_available_sources()` - Returns list with id, title, url
2. **Find source ID**: Match user's request to source title (e.g., "PydanticAI docs" -> find ID)
3. **Filter search**: `rag_search_knowledge_base(query="...", source_id="src_xxx", match_count=5)`
Examples:
- User: "Search the Supabase docs for vector functions"
1. Call `rag_get_available_sources()`
2. Find Supabase source ID from results (e.g., "src_abc123")
3. Call `rag_search_knowledge_base(query="vector functions", source_id="src_abc123")`
- User: "Find authentication examples in the MCP documentation"
1. Call `rag_get_available_sources()`
2. Find MCP docs source ID
3. Call `rag_search_code_examples(query="authentication", source_id="src_def456")`
IMPORTANT: Always use source_id (not URLs or domain names) for filtering!
## 📋 Core Workflow
### Task Management Cycle
@@ -215,9 +233,9 @@ MCP_INSTRUCTIONS = """
### Consolidated Task Tools (Optimized ~2 tools from 5)
- `list_tasks(query=None, task_id=None, filter_by=None, filter_value=None, per_page=10)`
- **Consolidated**: list + search + get in one tool
- **NEW**: Search with keyword query parameter
- **NEW**: task_id parameter for getting single task (full details)
- list + search + get in one tool
- Search with keyword query parameter (optional)
- task_id parameter for getting single task (full details)
- Filter by status, project, or assignee
- **Optimized**: Returns truncated descriptions and array counts (lists only)
- **Default**: 10 items per page (was 50)
@@ -231,23 +249,38 @@ MCP_INSTRUCTIONS = """
## 🏗️ Project Management
### Project Tools (Consolidated)
### Project Tools
- `list_projects(project_id=None, query=None, page=1, per_page=10)`
- List all projects, search by query, or get specific project by ID
- `manage_project(action, project_id=None, title=None, description=None, github_repo=None)`
- Actions: "create", "update", "delete"
### Document Tools (Consolidated)
### Document Tools
- `list_documents(project_id, document_id=None, query=None, document_type=None, page=1, per_page=10)`
- List project documents, search, filter by type, or get specific document
- `manage_document(action, project_id, document_id=None, title=None, document_type=None, content=None, ...)`
- Actions: "create", "update", "delete"
## 🔍 Research Patterns
- **Architecture patterns**: `rag_search_knowledge_base(query="[tech] architecture patterns", match_count=5)`
- **Code examples**: `rag_search_code_examples(query="[feature] implementation", match_count=3)`
- **Source discovery**: `rag_get_available_sources()`
- Keep match_count around 3-5 for focused results
### CRITICAL: Keep Queries Short and Focused!
Vector search works best with 2-5 keywords, NOT long sentences or keyword dumps.
✅ GOOD Queries (concise, focused):
- `rag_search_knowledge_base(query="vector search pgvector")`
- `rag_search_code_examples(query="React useState")`
- `rag_search_knowledge_base(query="authentication JWT")`
- `rag_search_code_examples(query="FastAPI middleware")`
❌ BAD Queries (too long, unfocused):
- `rag_search_knowledge_base(query="how to implement vector search with pgvector in PostgreSQL for semantic similarity matching with OpenAI embeddings")`
- `rag_search_code_examples(query="React hooks useState useEffect useContext useReducer useMemo useCallback")`
### Query Construction Tips:
- Extract 2-5 most important keywords from the user's request
- Focus on technical terms and specific technologies
- Omit filler words like "how to", "implement", "create", "example"
- For multi-concept searches, do multiple focused queries instead of one broad query
## 📊 Task Status Flow
`todo` → `doing` → `review` → `done`
@@ -255,25 +288,26 @@ MCP_INSTRUCTIONS = """
- Use 'review' for completed work awaiting validation
- Mark tasks 'done' only after verification
## 💾 Version Management (Consolidated)
- `list_versions(project_id, field_name=None, version_number=None, page=1, per_page=10)`
- List all versions, filter by field, or get specific version
- `manage_version(action, project_id, field_name, version_number=None, content=None, change_summary=None, ...)`
- Actions: "create", "restore"
- Field names: "docs", "features", "data", "prd"
## 📝 Task Granularity Guidelines
## 🎯 Best Practices
1. **Atomic Tasks**: Create tasks that take 1-4 hours
2. **Clear Descriptions**: Include acceptance criteria in task descriptions
3. **Use Features**: Group related tasks with feature labels
4. **Add Sources**: Link relevant documentation to tasks
5. **Track Progress**: Update task status as you work
### Project Scope Determines Task Granularity
## 📊 Optimization Updates
- **Payload Optimization**: Tasks in lists return truncated descriptions (200 chars)
- **Array Counts**: Source/example arrays replaced with counts in list responses
- **Smart Defaults**: Default page size reduced from 50 to 10 items
- **Search Support**: New `query` parameter in list_tasks for keyword search
**For Feature-Specific Projects** (project = single feature):
Create granular implementation tasks:
- "Set up development environment"
- "Install required dependencies"
- "Create database schema"
- "Implement API endpoints"
- "Add frontend components"
- "Write unit tests"
- "Add integration tests"
- "Update documentation"
**For Codebase-Wide Projects** (project = entire application):
Create feature-level tasks:
- "Implement user authentication feature"
- "Add payment processing system"
- "Create admin dashboard"
"""
# Initialize the main FastMCP server with fixed configuration
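As a concrete companion to the targeted-search instructions in MCP_INSTRUCTIONS above, a hypothetical workflow sketch; the mcp_tools object and its attribute access are placeholders for however a client exposes the Archon tools, and it assumes rag_get_available_sources yields parsed source records with id and title fields:

async def search_supabase_docs(mcp_tools) -> str:
    # Resolve the source id first, then run a filtered, keyword-style search.
    sources = await mcp_tools.rag_get_available_sources()
    supabase = next(s for s in sources if "supabase" in s["title"].lower())
    return await mcp_tools.rag_search_knowledge_base(
        query="vector functions",        # 2-5 focused keywords
        source_id=supabase["id"],        # the 'id' field, never a URL or domain name
        match_count=5,
    )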

View File

@@ -14,6 +14,7 @@ from .internal_api import router as internal_router
from .knowledge_api import router as knowledge_router
from .mcp_api import router as mcp_router
from .projects_api import router as projects_router
from .providers_api import router as providers_router
from .settings_api import router as settings_router
__all__ = [
@@ -23,4 +24,5 @@ __all__ = [
"projects_router",
"agent_chat_router",
"internal_router",
"providers_router",
]

View File

@@ -18,6 +18,8 @@ from urllib.parse import urlparse
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from pydantic import BaseModel
# Basic validation - simplified inline version
# Import unified logging
from ..config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info
from ..services.crawler_manager import get_crawler
@@ -62,26 +64,59 @@ async def _validate_provider_api_key(provider: str = None) -> None:
logger.info("🔑 Starting API key validation...")
try:
# Basic provider validation
if not provider:
provider = "openai"
else:
# Simple provider validation
allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"}
if provider not in allowed_providers:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid provider name",
"message": f"Provider '{provider}' not supported",
"error_type": "validation_error"
}
)
logger.info(f"🔑 Testing {provider.title()} API key with minimal embedding request...")
# Test API key with minimal embedding request - this will fail if key is invalid
from ..services.embeddings.embedding_service import create_embedding
test_result = await create_embedding(text="test")
if not test_result:
logger.error(f"{provider.title()} API key validation failed - no embedding returned")
raise HTTPException(
status_code=401,
detail={
"error": f"Invalid {provider.title()} API key",
"message": f"Please verify your {provider.title()} API key in Settings.",
"error_type": "authentication_failed",
"provider": provider
}
)
# Basic sanitization for logging
safe_provider = provider[:20] # Limit length
logger.info(f"🔑 Testing {safe_provider.title()} API key with minimal embedding request...")
try:
# Test API key with minimal embedding request using provider-scoped configuration
from ..services.embeddings.embedding_service import create_embedding
test_result = await create_embedding(text="test", provider=provider)
if not test_result:
logger.error(
f"{provider.title()} API key validation failed - no embedding returned"
)
raise HTTPException(
status_code=401,
detail={
"error": f"Invalid {provider.title()} API key",
"message": f"Please verify your {provider.title()} API key in Settings.",
"error_type": "authentication_failed",
"provider": provider,
},
)
except Exception as e:
logger.error(
f"{provider.title()} API key validation failed: {e}",
exc_info=True,
)
raise HTTPException(
status_code=401,
detail={
"error": f"Invalid {provider.title()} API key",
"message": f"Please verify your {provider.title()} API key in Settings. Error: {str(e)[:100]}",
"error_type": "authentication_failed",
"provider": provider,
},
)
logger.info(f"{provider.title()} API key validation successful")
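The refactored validation above reduces to a single provider-scoped embedding call; a hedged sketch of that core check, usable only alongside the handler module since it reuses the same relative import (everything else is illustrative):

from ..services.embeddings.embedding_service import create_embedding

async def is_api_key_valid(provider: str = "openai") -> bool:
    # Any exception or empty result is treated as an invalid or unconfigured key.
    try:
        return bool(await create_embedding(text="test", provider=provider))
    except Exception:
        return False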

View File

@@ -0,0 +1,170 @@
"""
API routes for database migration tracking and management.
"""
from datetime import datetime
import logfire
from fastapi import APIRouter, Header, HTTPException, Response
from pydantic import BaseModel
from ..config.version import ARCHON_VERSION
from ..services.migration_service import migration_service
from ..utils.etag_utils import check_etag, generate_etag
# Response models
class MigrationRecord(BaseModel):
"""Represents an applied migration."""
version: str
migration_name: str
applied_at: datetime
checksum: str | None = None
class PendingMigration(BaseModel):
"""Represents a pending migration."""
version: str
name: str
sql_content: str
file_path: str
checksum: str | None = None
class MigrationStatusResponse(BaseModel):
"""Complete migration status response."""
pending_migrations: list[PendingMigration]
applied_migrations: list[MigrationRecord]
has_pending: bool
bootstrap_required: bool
current_version: str
pending_count: int
applied_count: int
class MigrationHistoryResponse(BaseModel):
"""Migration history response."""
migrations: list[MigrationRecord]
total_count: int
current_version: str
# Create router
router = APIRouter(prefix="/api/migrations", tags=["migrations"])
@router.get("/status", response_model=MigrationStatusResponse)
async def get_migration_status(
response: Response, if_none_match: str | None = Header(None)
):
"""
Get current migration status including pending and applied migrations.
Returns comprehensive migration status with:
- List of pending migrations with SQL content
- List of applied migrations
- Bootstrap flag if migrations table doesn't exist
- Current version information
"""
try:
# Get migration status from service
status = await migration_service.get_migration_status()
# Generate ETag for response
etag = generate_etag(status)
# Check if client has current data
if check_etag(if_none_match, etag):
# Client has current data, return 304
response.status_code = 304
response.headers["ETag"] = f'"{etag}"'
response.headers["Cache-Control"] = "no-cache, must-revalidate"
return Response(status_code=304)
else:
# Client needs new data
response.headers["ETag"] = f'"{etag}"'
response.headers["Cache-Control"] = "no-cache, must-revalidate"
return MigrationStatusResponse(**status)
except Exception as e:
logfire.error(f"Error getting migration status: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get migration status: {str(e)}") from e
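Since the migration endpoints return an ETag and honor If-None-Match, a client can avoid re-fetching unchanged status; a minimal sketch, assuming only the endpoint path and headers shown above:

import httpx

async def fetch_migration_status(base_url: str, cached_etag: str | None = None) -> tuple[dict | None, str | None]:
    headers = {"If-None-Match": cached_etag} if cached_etag else {}
    async with httpx.AsyncClient(base_url=base_url) as client:
        resp = await client.get("/api/migrations/status", headers=headers)
    if resp.status_code == 304:
        return None, cached_etag            # cached copy is still current
    resp.raise_for_status()
    return resp.json(), resp.headers.get("ETag")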
@router.get("/history", response_model=MigrationHistoryResponse)
async def get_migration_history(response: Response, if_none_match: str | None = Header(None)):
"""
Get history of applied migrations.
Returns list of all applied migrations sorted by date.
"""
try:
# Get applied migrations from service
applied = await migration_service.get_applied_migrations()
# Format response
history = {
"migrations": [
MigrationRecord(
version=m.version,
migration_name=m.migration_name,
applied_at=m.applied_at,
checksum=m.checksum,
)
for m in applied
],
"total_count": len(applied),
"current_version": ARCHON_VERSION,
}
# Generate ETag for response
etag = generate_etag(history)
# Check if client has current data
if check_etag(if_none_match, etag):
# Client has current data, return 304
response.status_code = 304
response.headers["ETag"] = f'"{etag}"'
response.headers["Cache-Control"] = "no-cache, must-revalidate"
return Response(status_code=304)
else:
# Client needs new data
response.headers["ETag"] = f'"{etag}"'
response.headers["Cache-Control"] = "no-cache, must-revalidate"
return MigrationHistoryResponse(**history)
except Exception as e:
logfire.error(f"Error getting migration history: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get migration history: {str(e)}") from e
@router.get("/pending", response_model=list[PendingMigration])
async def get_pending_migrations():
"""
Get list of pending migrations only.
Returns simplified list of migrations that need to be applied.
"""
try:
# Get pending migrations from service
pending = await migration_service.get_pending_migrations()
# Format response
return [
PendingMigration(
version=m.version,
name=m.name,
sql_content=m.sql_content,
file_path=m.file_path,
checksum=m.checksum,
)
for m in pending
]
except Exception as e:
logfire.error(f"Error getting pending migrations: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get pending migrations: {str(e)}") from e

View File

@@ -0,0 +1,154 @@
"""
Provider status API endpoints for testing connectivity
Handles server-side provider connectivity testing without exposing API keys to the frontend.

"""
import httpx
from fastapi import APIRouter, HTTPException, Path
from ..config.logfire_config import logfire
from ..services.credential_service import credential_service
# Provider validation - simplified inline version
router = APIRouter(prefix="/api/providers", tags=["providers"])
async def test_openai_connection(api_key: str) -> bool:
"""Test OpenAI API connectivity"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
"https://api.openai.com/v1/models",
headers={"Authorization": f"Bearer {api_key}"}
)
return response.status_code == 200
except Exception as e:
logfire.warning(f"OpenAI connectivity test failed: {e}")
return False
async def test_google_connection(api_key: str) -> bool:
"""Test Google AI API connectivity"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
"https://generativelanguage.googleapis.com/v1/models",
headers={"x-goog-api-key": api_key}
)
return response.status_code == 200
except Exception as e:
logfire.warning(f"Google AI connectivity test failed: {e}")
return False
async def test_anthropic_connection(api_key: str) -> bool:
"""Test Anthropic API connectivity"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
"https://api.anthropic.com/v1/models",
headers={
"x-api-key": api_key,
"anthropic-version": "2023-06-01"
}
)
return response.status_code == 200
except Exception as e:
logfire.warning(f"Anthropic connectivity test failed: {e}")
return False
async def test_openrouter_connection(api_key: str) -> bool:
"""Test OpenRouter API connectivity"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
"https://openrouter.ai/api/v1/models",
headers={"Authorization": f"Bearer {api_key}"}
)
return response.status_code == 200
except Exception as e:
logfire.warning(f"OpenRouter connectivity test failed: {e}")
return False
async def test_grok_connection(api_key: str) -> bool:
"""Test Grok API connectivity"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
"https://api.x.ai/v1/models",
headers={"Authorization": f"Bearer {api_key}"}
)
return response.status_code == 200
except Exception as e:
logfire.warning(f"Grok connectivity test failed: {e}")
return False
PROVIDER_TESTERS = {
"openai": test_openai_connection,
"google": test_google_connection,
"anthropic": test_anthropic_connection,
"openrouter": test_openrouter_connection,
"grok": test_grok_connection,
}
@router.get("/{provider}/status")
async def get_provider_status(
provider: str = Path(
...,
description="Provider name to test connectivity for",
regex="^[a-z0-9_]+$",
max_length=20
)
):
"""Test provider connectivity using server-side API key (secure)"""
try:
# Basic provider validation
allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"}
if provider not in allowed_providers:
raise HTTPException(
status_code=400,
detail=f"Invalid provider '{provider}'. Allowed providers: {sorted(allowed_providers)}"
)
# Basic sanitization for logging
safe_provider = provider[:20] # Limit length
logfire.info(f"Testing {safe_provider} connectivity server-side")
if provider not in PROVIDER_TESTERS:
raise HTTPException(
status_code=400,
detail=f"Provider '{provider}' not supported for connectivity testing"
)
# Get API key server-side (never expose to client)
key_name = f"{provider.upper()}_API_KEY"
api_key = await credential_service.get_credential(key_name, decrypt=True)
if not api_key or not isinstance(api_key, str) or not api_key.strip():
logfire.info(f"No API key configured for {safe_provider}")
return {"ok": False, "reason": "no_key"}
# Test connectivity using server-side key
tester = PROVIDER_TESTERS[provider]
is_connected = await tester(api_key)
logfire.info(f"{safe_provider} connectivity test result: {is_connected}")
return {
"ok": is_connected,
"reason": "connected" if is_connected else "connection_failed",
"provider": provider # Echo back validated provider name
}
except HTTPException:
# Re-raise HTTP exceptions (they're already properly formatted)
raise
except Exception as e:
# Basic error sanitization for logging
safe_error = str(e)[:100] # Limit length
logfire.error(f"Error testing {provider[:20]} connectivity: {safe_error}")
raise HTTPException(status_code=500, detail={"error": "Internal server error during connectivity test"})
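For completeness, a small sketch of how a frontend or script might consume this endpoint without ever seeing the API key (response shape taken from the handler above; base_url and timeout are illustrative):

import httpx

async def check_provider(base_url: str, provider: str) -> bool:
    async with httpx.AsyncClient(base_url=base_url, timeout=15.0) as client:
        resp = await client.get(f"/api/providers/{provider}/status")
    resp.raise_for_status()
    data = resp.json()      # e.g. {"ok": true, "reason": "connected", "provider": "openai"}
    return bool(data.get("ok"))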

View File

@@ -0,0 +1,121 @@
"""
API routes for version checking and update management.
"""
from datetime import datetime
from typing import Any
import logfire
from fastapi import APIRouter, Header, HTTPException, Response
from pydantic import BaseModel
from ..config.version import ARCHON_VERSION
from ..services.version_service import version_service
from ..utils.etag_utils import check_etag, generate_etag
# Response models
class ReleaseAsset(BaseModel):
"""Represents a downloadable asset from a release."""
name: str
size: int
download_count: int
browser_download_url: str
content_type: str
class VersionCheckResponse(BaseModel):
"""Version check response with update information."""
current: str
latest: str | None
update_available: bool
release_url: str | None
release_notes: str | None
published_at: datetime | None
check_error: str | None = None
assets: list[dict[str, Any]] | None = None
author: str | None = None
class CurrentVersionResponse(BaseModel):
"""Simple current version response."""
version: str
timestamp: datetime
# Create router
router = APIRouter(prefix="/api/version", tags=["version"])
@router.get("/check", response_model=VersionCheckResponse)
async def check_for_updates(response: Response, if_none_match: str | None = Header(None)):
"""
Check for available Archon updates.
Queries GitHub releases API to determine if a newer version is available.
Results are cached for 1 hour to avoid rate limiting.
Returns:
Version information including current, latest, and update availability
"""
try:
# Get version check results from service
result = await version_service.check_for_updates()
# Generate ETag for response
etag = generate_etag(result)
# Check if client has current data
if check_etag(if_none_match, etag):
# Client has current data, return 304
response.status_code = 304
response.headers["ETag"] = f'"{etag}"'
response.headers["Cache-Control"] = "no-cache, must-revalidate"
return Response(status_code=304)
else:
# Client needs new data
response.headers["ETag"] = f'"{etag}"'
response.headers["Cache-Control"] = "no-cache, must-revalidate"
return VersionCheckResponse(**result)
except Exception as e:
logfire.error(f"Error checking for updates: {e}")
# Return safe response with error
return VersionCheckResponse(
current=ARCHON_VERSION,
latest=None,
update_available=False,
release_url=None,
release_notes=None,
published_at=None,
check_error=str(e),
)
@router.get("/current", response_model=CurrentVersionResponse)
async def get_current_version():
"""
Get the current Archon version.
Simple endpoint that returns the installed version without checking for updates.
"""
return CurrentVersionResponse(version=ARCHON_VERSION, timestamp=datetime.now())
@router.post("/clear-cache")
async def clear_version_cache():
"""
Clear the version check cache.
Forces the next version check to query GitHub API instead of using cached data.
Useful for testing or forcing an immediate update check.
"""
try:
version_service.clear_cache()
return {"message": "Version cache cleared successfully", "success": True}
except Exception as e:
logfire.error(f"Error clearing version cache: {e}")
raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}") from e

View File

@@ -0,0 +1,11 @@
"""
Version configuration for Archon.
"""
# Current version of Archon
# Update this with each release
ARCHON_VERSION = "0.1.0"
# Repository information for GitHub API
GITHUB_REPO_OWNER = "coleam00"
GITHUB_REPO_NAME = "Archon"

View File

@@ -23,9 +23,12 @@ from .api_routes.bug_report_api import router as bug_report_router
from .api_routes.internal_api import router as internal_router
from .api_routes.knowledge_api import router as knowledge_router
from .api_routes.mcp_api import router as mcp_router
from .api_routes.migration_api import router as migration_router
from .api_routes.ollama_api import router as ollama_router
from .api_routes.progress_api import router as progress_router
from .api_routes.projects_api import router as projects_router
from .api_routes.providers_api import router as providers_router
from .api_routes.version_api import router as version_router
# Import modular API routers
from .api_routes.settings_api import router as settings_router
@@ -186,6 +189,9 @@ app.include_router(progress_router)
app.include_router(agent_chat_router)
app.include_router(internal_router)
app.include_router(bug_report_router)
app.include_router(providers_router)
app.include_router(version_router)
app.include_router(migration_router)
# Root endpoint

View File

@@ -139,6 +139,7 @@ class CodeExtractionService:
source_id: str,
progress_callback: Callable | None = None,
cancellation_check: Callable[[], None] | None = None,
provider: str | None = None,
) -> int:
"""
Extract code examples from crawled documents and store them.
@@ -204,7 +205,7 @@ class CodeExtractionService:
# Generate summaries for code blocks
summary_results = await self._generate_code_summaries(
all_code_blocks, summary_callback, cancellation_check
all_code_blocks, summary_callback, cancellation_check, provider
)
# Prepare code examples for storage
@@ -223,7 +224,7 @@ class CodeExtractionService:
# Store code examples in database
return await self._store_code_examples(
storage_data, url_to_full_document, storage_callback
storage_data, url_to_full_document, storage_callback, provider
)
async def _extract_code_blocks_from_documents(
@@ -1523,6 +1524,7 @@ class CodeExtractionService:
all_code_blocks: list[dict[str, Any]],
progress_callback: Callable | None = None,
cancellation_check: Callable[[], None] | None = None,
provider: str | None = None,
) -> list[dict[str, str]]:
"""
Generate summaries for all code blocks.
@@ -1587,7 +1589,7 @@ class CodeExtractionService:
try:
results = await generate_code_summaries_batch(
code_blocks_for_summaries, max_workers, progress_callback=summary_progress_callback
code_blocks_for_summaries, max_workers, progress_callback=summary_progress_callback, provider=provider
)
# Ensure all results are valid dicts
@@ -1667,6 +1669,7 @@ class CodeExtractionService:
storage_data: dict[str, list[Any]],
url_to_full_document: dict[str, str],
progress_callback: Callable | None = None,
provider: str | None = None,
) -> int:
"""
Store code examples in the database.
@@ -1709,7 +1712,7 @@ class CodeExtractionService:
batch_size=20,
url_to_full_document=url_to_full_document,
progress_callback=storage_progress_callback,
provider=None, # Use configured provider
provider=provider,
)
# Report completion of code extraction/storage phase

View File

@@ -75,10 +75,11 @@ class CrawlingService:
self.url_handler = URLHandler()
self.site_config = SiteConfig()
self.markdown_generator = self.site_config.get_markdown_generator()
self.link_pruning_markdown_generator = self.site_config.get_link_pruning_markdown_generator()
# Initialize strategies
self.batch_strategy = BatchCrawlStrategy(crawler, self.markdown_generator)
self.recursive_strategy = RecursiveCrawlStrategy(crawler, self.markdown_generator)
self.batch_strategy = BatchCrawlStrategy(crawler, self.link_pruning_markdown_generator)
self.recursive_strategy = RecursiveCrawlStrategy(crawler, self.link_pruning_markdown_generator)
self.single_page_strategy = SinglePageCrawlStrategy(crawler, self.markdown_generator)
self.sitemap_strategy = SitemapCrawlStrategy()
@@ -551,12 +552,24 @@ class CrawlingService:
)
try:
# Extract provider from request or use credential service default
provider = request.get("provider")
if not provider:
try:
from ..credential_service import credential_service
provider_config = await credential_service.get_active_provider("llm")
provider = provider_config.get("provider", "openai")
except Exception as e:
logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai")
provider = "openai"
code_examples_count = await self.doc_storage_ops.extract_and_store_code_examples(
crawl_results,
storage_results["url_to_full_document"],
storage_results["source_id"],
code_progress_callback,
self._check_cancellation,
provider,
)
except RuntimeError as e:
# Code extraction failed, continue crawl with warning

View File

@@ -351,6 +351,7 @@ class DocumentStorageOperations:
source_id: str,
progress_callback: Callable | None = None,
cancellation_check: Callable[[], None] | None = None,
provider: str | None = None,
) -> int:
"""
Extract code examples from crawled documents and store them.
@@ -361,12 +362,13 @@ class DocumentStorageOperations:
source_id: The unique source_id for all documents
progress_callback: Optional callback for progress updates
cancellation_check: Optional function to check for cancellation
provider: Optional LLM provider to use for code summaries
Returns:
Number of code examples stored
"""
result = await self.code_extraction_service.extract_and_store_code_examples(
crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check
crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check, provider
)
return result

View File

@@ -4,6 +4,7 @@ Site Configuration Helper
Handles site-specific configurations and detection.
"""
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_filter_strategy import PruningContentFilter
from ....config.logfire_config import get_logger
@@ -96,3 +97,33 @@ class SiteConfig:
"code_language_callback": lambda el: el.get('class', '').replace('language-', '') if el else ''
}
)
@staticmethod
def get_link_pruning_markdown_generator():
"""
Get a markdown generator that prunes low-value content from crawled pages (used by the batch and recursive crawl strategies).
Returns:
Configured markdown generator
"""
prune_filter = PruningContentFilter(
threshold=0.2,
threshold_type="fixed"
)
return DefaultMarkdownGenerator(
content_source="html", # Use raw HTML to preserve code blocks
content_filter=prune_filter,
options={
"mark_code": True, # Mark code blocks properly
"handle_code_in_pre": True, # Handle <pre><code> tags
"body_width": 0, # No line wrapping
"skip_internal_links": True, # Add to reduce noise
"include_raw_html": False, # Prevent HTML in markdown
"escape": False, # Don't escape special chars in code
"decode_unicode": True, # Decode unicode characters
"strip_empty_lines": False, # Preserve empty lines in code
"preserve_code_formatting": True, # Custom option if supported
"code_language_callback": lambda el: el.get('class', '').replace('language-', '') if el else ''
}
)
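A hedged usage sketch of the new generator; the crawl4ai class names and parameters here are assumptions from typical crawl4ai usage rather than part of this diff, and the result handling mirrors the strategy changes further below that read result.markdown.fit_markdown:

from crawl4ai import AsyncWebCrawler, CrawlerRunConfig

async def crawl_with_pruning(url: str) -> str | None:
    run_config = CrawlerRunConfig(
        markdown_generator=SiteConfig.get_link_pruning_markdown_generator()
    )
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(url=url, config=run_config)
    # With a content filter attached, the pruned output lives on fit_markdown.
    if result.success and result.markdown and result.markdown.fit_markdown:
        return result.markdown.fit_markdown
    return None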

View File

@@ -231,12 +231,12 @@ class BatchCrawlStrategy:
raise
processed += 1
if result.success and result.markdown:
if result.success and result.markdown and result.markdown.fit_markdown:
# Map back to original URL
original_url = url_mapping.get(result.url, result.url)
successful_results.append({
"url": original_url,
"markdown": result.markdown,
"markdown": result.markdown.fit_markdown,
"html": result.html, # Use raw HTML
})
else:

View File

@@ -276,10 +276,10 @@ class RecursiveCrawlStrategy:
visited.add(norm_url)
total_processed += 1
if result.success and result.markdown:
if result.success and result.markdown and result.markdown.fit_markdown:
results_all.append({
"url": original_url,
"markdown": result.markdown,
"markdown": result.markdown.fit_markdown,
"html": result.html, # Always use raw HTML for code extraction
})
depth_successful += 1

View File

@@ -36,6 +36,44 @@ class CredentialItem:
description: str | None = None
def _detect_embedding_provider_from_model(embedding_model: str) -> str:
"""
Detect the appropriate embedding provider based on model name.
Args:
embedding_model: The embedding model name
Returns:
Provider name: 'google' or 'openai' (default)
"""
if not embedding_model:
return "openai" # Default
model_lower = embedding_model.lower()
# Google embedding models
google_patterns = [
"text-embedding-004",
"text-embedding-005",
"text-multilingual-embedding",
"gemini-embedding",
"multimodalembedding"
]
if any(pattern in model_lower for pattern in google_patterns):
return "google"
# OpenAI embedding models (and default for unknown)
openai_patterns = [
"text-embedding-ada-002",
"text-embedding-3-small",
"text-embedding-3-large"
]
# Default to OpenAI for OpenAI models or unknown models
return "openai"
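A few illustrative checks of the helper above, using model names drawn from its own pattern lists:

assert _detect_embedding_provider_from_model("text-embedding-004") == "google"
assert _detect_embedding_provider_from_model("gemini-embedding-001") == "google"
assert _detect_embedding_provider_from_model("text-embedding-3-small") == "openai"
assert _detect_embedding_provider_from_model("") == "openai"     # default when unset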
class CredentialService:
"""Service for managing application credentials and configuration."""
@@ -239,6 +277,14 @@ class CredentialService:
self._rag_cache_timestamp = None
logger.debug(f"Invalidated RAG settings cache due to update of {key}")
# Also invalidate provider service cache to ensure immediate effect
try:
from .llm_provider_service import clear_provider_cache
clear_provider_cache()
logger.debug("Also cleared LLM provider service cache")
except Exception as e:
logger.warning(f"Failed to clear provider service cache: {e}")
# Also invalidate LLM provider service cache for provider config
try:
from . import llm_provider_service
@@ -281,6 +327,14 @@ class CredentialService:
self._rag_cache_timestamp = None
logger.debug(f"Invalidated RAG settings cache due to deletion of {key}")
# Also invalidate provider service cache to ensure immediate effect
try:
from .llm_provider_service import clear_provider_cache
clear_provider_cache()
logger.debug("Also cleared LLM provider service cache")
except Exception as e:
logger.warning(f"Failed to clear provider service cache: {e}")
# Also invalidate LLM provider service cache for provider config
try:
from . import llm_provider_service
@@ -419,8 +473,33 @@ class CredentialService:
# Get RAG strategy settings (where UI saves provider selection)
rag_settings = await self.get_credentials_by_category("rag_strategy")
# Get the selected provider
provider = rag_settings.get("LLM_PROVIDER", "openai")
# Get the selected provider based on service type
if service_type == "embedding":
# Get the LLM provider setting to determine embedding provider
llm_provider = rag_settings.get("LLM_PROVIDER", "openai")
embedding_model = rag_settings.get("EMBEDDING_MODEL", "text-embedding-3-small")
# Determine embedding provider based on LLM provider
if llm_provider == "google":
provider = "google"
elif llm_provider == "ollama":
provider = "ollama"
elif llm_provider == "openrouter":
# OpenRouter supports both OpenAI and Google embedding models
provider = _detect_embedding_provider_from_model(embedding_model)
elif llm_provider in ["anthropic", "grok"]:
# Anthropic and Grok support both OpenAI and Google embedding models
provider = _detect_embedding_provider_from_model(embedding_model)
else:
# Default case (openai, or unknown providers)
provider = "openai"
logger.debug(f"Determined embedding provider '{provider}' from LLM provider '{llm_provider}' and embedding model '{embedding_model}'")
else:
provider = rag_settings.get("LLM_PROVIDER", "openai")
# Ensure provider is a valid string, not a boolean or other type
if not isinstance(provider, str) or provider.lower() in ("true", "false", "none", "null"):
provider = "openai"
# Get API key for this provider
api_key = await self._get_provider_api_key(provider)
@@ -464,6 +543,9 @@ class CredentialService:
key_mapping = {
"openai": "OPENAI_API_KEY",
"google": "GOOGLE_API_KEY",
"openrouter": "OPENROUTER_API_KEY",
"anthropic": "ANTHROPIC_API_KEY",
"grok": "GROK_API_KEY",
"ollama": None, # No API key needed
}
@@ -475,9 +557,15 @@ class CredentialService:
def _get_provider_base_url(self, provider: str, rag_settings: dict) -> str | None:
"""Get base URL for provider."""
if provider == "ollama":
return rag_settings.get("LLM_BASE_URL", "http://localhost:11434/v1")
return rag_settings.get("LLM_BASE_URL", "http://host.docker.internal:11434/v1")
elif provider == "google":
return "https://generativelanguage.googleapis.com/v1beta/openai/"
elif provider == "openrouter":
return "https://openrouter.ai/api/v1"
elif provider == "anthropic":
return "https://api.anthropic.com/v1"
elif provider == "grok":
return "https://api.x.ai/v1"
return None # Use default for OpenAI
async def set_active_provider(self, provider: str, service_type: str = "llm") -> bool:
@@ -485,7 +573,7 @@ class CredentialService:
try:
# For now, we'll update the RAG strategy settings
return await self.set_credential(
"llm_provider",
"LLM_PROVIDER",
provider,
category="rag_strategy",
description=f"Active {service_type} provider",

View File

@@ -10,7 +10,13 @@ import os
import openai
from ...config.logfire_config import search_logger
from ..llm_provider_service import get_llm_client
from ..credential_service import credential_service
from ..llm_provider_service import (
extract_message_text,
get_llm_client,
prepare_chat_completion_params,
requires_max_completion_tokens,
)
from ..threading_service import get_threading_service
@@ -32,8 +38,6 @@ async def generate_contextual_embedding(
"""
# Model choice is a RAG setting, get from credential service
try:
from ...services.credential_service import credential_service
model_choice = await credential_service.get_credential("MODEL_CHOICE", "gpt-4.1-nano")
except Exception as e:
# Fallback to environment variable or default
@@ -65,20 +69,25 @@ Please give a short succinct context to situate this chunk within the overall do
# Get model from provider configuration
model = await _get_model_choice(provider)
response = await client.chat.completions.create(
model=model,
messages=[
# Prepare parameters and convert max_tokens for GPT-5/reasoning models
params = {
"model": model,
"messages": [
{
"role": "system",
"content": "You are a helpful assistant that provides concise contextual information.",
},
{"role": "user", "content": prompt},
],
temperature=0.3,
max_tokens=200,
)
"temperature": 0.3,
"max_tokens": 1200 if requires_max_completion_tokens(model) else 200, # Much more tokens for reasoning models (GPT-5 needs extra for reasoning process)
}
final_params = prepare_chat_completion_params(model, params)
response = await client.chat.completions.create(**final_params)
context = response.choices[0].message.content.strip()
choice = response.choices[0] if response.choices else None
context, _, _ = extract_message_text(choice)
context = context.strip()
contextual_text = f"{context}\n---\n{chunk}"
return contextual_text, True
@@ -111,7 +120,7 @@ async def process_chunk_with_context(
async def _get_model_choice(provider: str | None = None) -> str:
"""Get model choice from credential service."""
"""Get model choice from credential service with centralized defaults."""
from ..credential_service import credential_service
# Get the active provider configuration
@@ -119,31 +128,36 @@ async def _get_model_choice(provider: str | None = None) -> str:
model = provider_config.get("chat_model", "").strip() # Strip whitespace
provider_name = provider_config.get("provider", "openai")
# Handle empty model case - fallback to provider-specific defaults or explicit config
# Handle empty model case - use centralized defaults
if not model:
search_logger.warning(f"chat_model is empty for provider {provider_name}, using fallback logic")
search_logger.warning(f"chat_model is empty for provider {provider_name}, using centralized defaults")
# Special handling for Ollama to check specific credential
if provider_name == "ollama":
# Try to get OLLAMA_CHAT_MODEL specifically
try:
ollama_model = await credential_service.get_credential("OLLAMA_CHAT_MODEL")
if ollama_model and ollama_model.strip():
model = ollama_model.strip()
search_logger.info(f"Using OLLAMA_CHAT_MODEL fallback: {model}")
else:
# Use a sensible Ollama default
# Use default for Ollama
model = "llama3.2:latest"
search_logger.info(f"Using Ollama default model: {model}")
search_logger.info(f"Using Ollama default: {model}")
except Exception as e:
search_logger.error(f"Error getting OLLAMA_CHAT_MODEL: {e}")
model = "llama3.2:latest"
search_logger.info(f"Using Ollama fallback model: {model}")
elif provider_name == "google":
model = "gemini-1.5-flash"
search_logger.info(f"Using Ollama fallback: {model}")
else:
# OpenAI or other providers
model = "gpt-4o-mini"
# Use provider-specific defaults
provider_defaults = {
"openai": "gpt-4o-mini",
"openrouter": "anthropic/claude-3.5-sonnet",
"google": "gemini-1.5-flash",
"anthropic": "claude-3-5-haiku-20241022",
"grok": "grok-3-mini"
}
model = provider_defaults.get(provider_name, "gpt-4o-mini")
search_logger.debug(f"Using default model for provider {provider_name}: {model}")
search_logger.debug(f"Using model from credential service: {model}")
return model
@@ -174,38 +188,48 @@ async def generate_contextual_embeddings_batch(
model_choice = await _get_model_choice(provider)
# Build batch prompt for ALL chunks at once
batch_prompt = (
"Process the following chunks and provide contextual information for each:\\n\\n"
)
batch_prompt = "Process the following chunks and provide contextual information for each:\n\n"
for i, (doc, chunk) in enumerate(zip(full_documents, chunks, strict=False)):
# Use only 2000 chars of document context to save tokens
doc_preview = doc[:2000] if len(doc) > 2000 else doc
batch_prompt += f"CHUNK {i + 1}:\\n"
batch_prompt += f"<document_preview>\\n{doc_preview}\\n</document_preview>\\n"
batch_prompt += f"<chunk>\\n{chunk[:500]}\\n</chunk>\\n\\n" # Limit chunk preview
batch_prompt += f"CHUNK {i + 1}:\n"
batch_prompt += f"<document_preview>\n{doc_preview}\n</document_preview>\n"
batch_prompt += f"<chunk>\n{chunk[:500]}\n</chunk>\n\n" # Limit chunk preview
batch_prompt += "For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. Format your response as:\\nCHUNK 1: [context]\\nCHUNK 2: [context]\\netc."
batch_prompt += (
"For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. "
"Format your response as:\nCHUNK 1: [context]\nCHUNK 2: [context]\netc."
)
# Make single API call for ALL chunks
response = await client.chat.completions.create(
model=model_choice,
messages=[
# Prepare parameters and convert max_tokens for GPT-5/reasoning models
batch_params = {
"model": model_choice,
"messages": [
{
"role": "system",
"content": "You are a helpful assistant that generates contextual information for document chunks.",
},
{"role": "user", "content": batch_prompt},
],
temperature=0,
max_tokens=100 * len(chunks), # Limit response size
)
"temperature": 0,
"max_tokens": (600 if requires_max_completion_tokens(model_choice) else 100) * len(chunks), # Much more tokens for reasoning models (GPT-5 needs extra reasoning space)
}
final_batch_params = prepare_chat_completion_params(model_choice, batch_params)
response = await client.chat.completions.create(**final_batch_params)
# Parse response
response_text = response.choices[0].message.content
choice = response.choices[0] if response.choices else None
response_text, _, _ = extract_message_text(choice)
if not response_text:
search_logger.error(
"Empty response from LLM when generating contextual embeddings batch"
)
return [(chunk, False) for chunk in chunks]
# Extract contexts from response
lines = response_text.strip().split("\\n")
lines = response_text.strip().split("\n")
chunk_contexts = {}
for line in lines:
@@ -245,4 +269,4 @@ async def generate_contextual_embeddings_batch(
except Exception as e:
search_logger.error(f"Error in contextual embedding batch: {e}")
# Return non-contextual for all chunks
return [(chunk, False) for chunk in chunks]
return [(chunk, False) for chunk in chunks]

View File

@@ -13,7 +13,7 @@ import openai
from ...config.logfire_config import safe_span, search_logger
from ..credential_service import credential_service
from ..llm_provider_service import get_embedding_model, get_llm_client
from ..llm_provider_service import get_embedding_model, get_llm_client, is_google_embedding_model, is_openai_embedding_model
from ..threading_service import get_threading_service
from .embedding_exceptions import (
EmbeddingAPIError,
@@ -152,34 +152,56 @@ async def create_embeddings_batch(
if not texts:
return EmbeddingBatchResult()
result = EmbeddingBatchResult()
# Validate that all items in texts are strings
validated_texts = []
for i, text in enumerate(texts):
if not isinstance(text, str):
search_logger.error(
f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True
)
# Try to convert to string
try:
validated_texts.append(str(text))
except Exception as e:
search_logger.error(
f"Failed to convert text at index {i} to string: {e}", exc_info=True
)
validated_texts.append("") # Use empty string as fallback
else:
if isinstance(text, str):
validated_texts.append(text)
continue
search_logger.error(
f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True
)
try:
converted = str(text)
validated_texts.append(converted)
except Exception as conversion_error:
search_logger.error(
f"Failed to convert text at index {i} to string: {conversion_error}",
exc_info=True,
)
result.add_failure(
repr(text),
EmbeddingAPIError("Invalid text type", original_error=conversion_error),
batch_index=None,
)
texts = validated_texts
result = EmbeddingBatchResult()
threading_service = get_threading_service()
with safe_span(
"create_embeddings_batch", text_count=len(texts), total_chars=sum(len(t) for t in texts)
) as span:
try:
async with get_llm_client(provider=provider, use_embedding_provider=True) as client:
# Intelligent embedding provider routing based on model type
# Get the embedding model first to determine the correct provider
embedding_model = await get_embedding_model(provider=provider)
# Route to correct provider based on model type
if is_google_embedding_model(embedding_model):
embedding_provider = "google"
search_logger.info(f"Routing to Google for embedding model: {embedding_model}")
elif is_openai_embedding_model(embedding_model) or "openai/" in embedding_model.lower():
embedding_provider = "openai"
search_logger.info(f"Routing to OpenAI for embedding model: {embedding_model}")
else:
# Keep original provider for ollama and other providers
embedding_provider = provider
search_logger.info(f"Using original provider '{provider}' for embedding model: {embedding_model}")
async with get_llm_client(provider=embedding_provider, use_embedding_provider=True) as client:
# Load batch size and dimensions from settings
try:
rag_settings = await credential_service.get_credentials_by_category(
@@ -220,7 +242,8 @@ async def create_embeddings_batch(
while retry_count < max_retries:
try:
# Create embeddings for this batch
embedding_model = await get_embedding_model(provider=provider)
embedding_model = await get_embedding_model(provider=embedding_provider)
response = await client.embeddings.create(
model=embedding_model,
input=batch,

File diff suppressed because it is too large

View File

@@ -0,0 +1,233 @@
"""
Database migration tracking and management service.
"""
import hashlib
from pathlib import Path
from typing import Any
import logfire
from supabase import Client
from .client_manager import get_supabase_client
from ..config.version import ARCHON_VERSION
class MigrationRecord:
"""Represents a migration record from the database."""
def __init__(self, data: dict[str, Any]):
self.id = data.get("id")
self.version = data.get("version")
self.migration_name = data.get("migration_name")
self.applied_at = data.get("applied_at")
self.checksum = data.get("checksum")
class PendingMigration:
"""Represents a pending migration from the filesystem."""
def __init__(self, version: str, name: str, sql_content: str, file_path: str):
self.version = version
self.name = name
self.sql_content = sql_content
self.file_path = file_path
self.checksum = self._calculate_checksum(sql_content)
def _calculate_checksum(self, content: str) -> str:
"""Calculate MD5 checksum of migration content."""
return hashlib.md5(content.encode()).hexdigest()
class MigrationService:
"""Service for managing database migrations."""
def __init__(self):
self._supabase: Client | None = None
# Handle both Docker (/app/migration) and local (./migration) environments
if Path("/app/migration").exists():
self._migrations_dir = Path("/app/migration")
else:
self._migrations_dir = Path("migration")
def _get_supabase_client(self) -> Client:
"""Get or create Supabase client."""
if not self._supabase:
self._supabase = get_supabase_client()
return self._supabase
async def check_migrations_table_exists(self) -> bool:
"""
Check if the archon_migrations table exists in the database.
Returns:
True if table exists, False otherwise
"""
try:
supabase = self._get_supabase_client()
# Query to check if table exists
result = supabase.rpc(
"sql",
{
"query": """
SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = 'archon_migrations'
) as exists
"""
}
).execute()
# Check if result indicates table exists
if result.data and len(result.data) > 0:
return result.data[0].get("exists", False)
return False
except Exception:
# If the SQL function doesn't exist or query fails, try direct query
try:
supabase = self._get_supabase_client()
# Try to select from the table with limit 0
supabase.table("archon_migrations").select("id").limit(0).execute()
return True
except Exception as e:
logfire.info(f"Migrations table does not exist: {e}")
return False
async def get_applied_migrations(self) -> list[MigrationRecord]:
"""
Get list of applied migrations from the database.
Returns:
List of MigrationRecord objects
"""
try:
# Check if table exists first
if not await self.check_migrations_table_exists():
logfire.info("Migrations table does not exist, returning empty list")
return []
supabase = self._get_supabase_client()
result = supabase.table("archon_migrations").select("*").order("applied_at", desc=True).execute()
return [MigrationRecord(row) for row in result.data]
except Exception as e:
logfire.error(f"Error fetching applied migrations: {e}")
# Return empty list if we can't fetch migrations
return []
async def scan_migration_directory(self) -> list[PendingMigration]:
"""
Scan the migration directory for all SQL files.
Returns:
List of PendingMigration objects
"""
migrations = []
if not self._migrations_dir.exists():
logfire.warning(f"Migration directory does not exist: {self._migrations_dir}")
return migrations
# Scan all version directories
for version_dir in sorted(self._migrations_dir.iterdir()):
if not version_dir.is_dir():
continue
version = version_dir.name
# Scan all SQL files in version directory
for sql_file in sorted(version_dir.glob("*.sql")):
try:
# Read SQL content
with open(sql_file, encoding="utf-8") as f:
sql_content = f.read()
# Extract migration name (filename without extension)
migration_name = sql_file.stem
# Create pending migration object
migration = PendingMigration(
version=version,
name=migration_name,
sql_content=sql_content,
file_path=str(sql_file.relative_to(Path.cwd())),
)
migrations.append(migration)
except Exception as e:
logfire.error(f"Error reading migration file {sql_file}: {e}")
return migrations
async def get_pending_migrations(self) -> list[PendingMigration]:
"""
Get list of pending migrations by comparing filesystem with database.
Returns:
List of PendingMigration objects that haven't been applied
"""
# Get all migrations from filesystem
all_migrations = await self.scan_migration_directory()
# Check if migrations table exists
if not await self.check_migrations_table_exists():
# Bootstrap case - all migrations are pending
logfire.info("Migrations table doesn't exist, all migrations are pending")
return all_migrations
# Get applied migrations from database
applied_migrations = await self.get_applied_migrations()
# Create set of applied migration identifiers
applied_set = {(m.version, m.migration_name) for m in applied_migrations}
# Filter out applied migrations
pending = [m for m in all_migrations if (m.version, m.name) not in applied_set]
return pending
async def get_migration_status(self) -> dict[str, Any]:
"""
Get comprehensive migration status.
Returns:
Dictionary with pending and applied migrations info
"""
pending = await self.get_pending_migrations()
applied = await self.get_applied_migrations()
# Check if bootstrap is required
bootstrap_required = not await self.check_migrations_table_exists()
return {
"pending_migrations": [
{
"version": m.version,
"name": m.name,
"sql_content": m.sql_content,
"file_path": m.file_path,
"checksum": m.checksum,
}
for m in pending
],
"applied_migrations": [
{
"version": m.version,
"migration_name": m.migration_name,
"applied_at": m.applied_at,
"checksum": m.checksum,
}
for m in applied
],
"has_pending": len(pending) > 0,
"bootstrap_required": bootstrap_required,
"current_version": ARCHON_VERSION,
"pending_count": len(pending),
"applied_count": len(applied),
}
# Export singleton instance
migration_service = MigrationService()
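The pending/applied comparison in get_pending_migrations() boils down to a set difference on (version, name) pairs; a toy illustration with made-up migration names:

# Toy data only; mirrors the (version, migration_name) keying used above.
applied = {("0.1.0", "001_initial"), ("0.1.0", "002_add_column")}
on_disk = [
    ("0.1.0", "001_initial"),
    ("0.1.0", "002_add_column"),
    ("0.1.0", "003_add_index"),
]
pending = [m for m in on_disk if m not in applied]
assert pending == [("0.1.0", "003_add_index")]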

View File

@@ -2,7 +2,7 @@
Provider Discovery Service
Discovers available models, checks provider health, and provides model specifications
for OpenAI, Google Gemini, Ollama, and Anthropic providers.
for OpenAI, Google Gemini, Ollama, Anthropic, and Grok providers.
"""
import time
@@ -23,7 +23,7 @@ _provider_cache: dict[str, tuple[Any, float]] = {}
_CACHE_TTL_SECONDS = 300 # 5 minutes
# Default Ollama instance URL (configurable via environment/settings)
DEFAULT_OLLAMA_URL = "http://localhost:11434"
DEFAULT_OLLAMA_URL = "http://host.docker.internal:11434"
# Model pattern detection for dynamic capabilities (no hardcoded model names)
CHAT_MODEL_PATTERNS = ["llama", "qwen", "mistral", "codellama", "phi", "gemma", "vicuna", "orca"]
@@ -359,6 +359,36 @@ class ProviderDiscoveryService:
return models
async def discover_grok_models(self, api_key: str) -> list[ModelSpec]:
"""Discover available Grok models."""
cache_key = f"grok_models_{hash(api_key)}"
cached = self._get_cached_result(cache_key)
if cached:
return cached
models = []
try:
# Grok model specifications
model_specs = [
ModelSpec("grok-3-mini", "grok", 32768, True, True, False, None, 0.15, 0.60, "Fast and efficient Grok model"),
ModelSpec("grok-3", "grok", 32768, True, True, False, None, 2.00, 10.00, "Standard Grok model"),
ModelSpec("grok-4", "grok", 32768, True, True, False, None, 5.00, 25.00, "Advanced Grok model"),
ModelSpec("grok-2-vision", "grok", 8192, True, True, True, None, 3.00, 15.00, "Grok model with vision capabilities"),
ModelSpec("grok-2-latest", "grok", 8192, True, True, False, None, 2.00, 10.00, "Latest Grok 2 model"),
]
# Test connectivity - Grok doesn't have a models list endpoint,
# so we'll just return the known models if API key is provided
if api_key:
models = model_specs
self._cache_result(cache_key, models)
logger.info(f"Discovered {len(models)} Grok models")
except Exception as e:
logger.error(f"Error discovering Grok models: {e}")
return models
async def check_provider_health(self, provider: str, config: dict[str, Any]) -> ProviderStatus:
"""Check health and connectivity status of a provider."""
start_time = time.time()
@@ -456,6 +486,23 @@ class ProviderDiscoveryService:
last_checked=time.time()
)
elif provider == "grok":
api_key = config.get("api_key")
if not api_key:
return ProviderStatus(provider, False, None, "API key not configured")
# Grok doesn't have a health check endpoint, so we'll assume it's available
# if API key is provided. In a real implementation, you might want to make a
# small test request to verify the key is valid.
response_time = (time.time() - start_time) * 1000
return ProviderStatus(
provider="grok",
is_available=True,
response_time_ms=response_time,
models_available=5, # Known model count
last_checked=time.time()
)
else:
return ProviderStatus(provider, False, None, f"Unknown provider: {provider}")
@@ -496,6 +543,11 @@ class ProviderDiscoveryService:
if anthropic_key:
providers["anthropic"] = await self.discover_anthropic_models(anthropic_key)
# Grok
grok_key = await credential_service.get_credential("GROK_API_KEY")
if grok_key:
providers["grok"] = await self.discover_grok_models(grok_key)
except Exception as e:
logger.error(f"Error getting all available models: {e}")

View File

@@ -14,7 +14,7 @@ from ...config.logfire_config import get_logger, safe_span
logger = get_logger(__name__)
# Fixed similarity threshold for vector results
SIMILARITY_THRESHOLD = 0.15
SIMILARITY_THRESHOLD = 0.05
class BaseSearchStrategy:

View File

@@ -11,7 +11,7 @@ from supabase import Client
from ..config.logfire_config import get_logger, search_logger
from .client_manager import get_supabase_client
from .llm_provider_service import get_llm_client
from .llm_provider_service import extract_message_text, get_llm_client
logger = get_logger(__name__)
@@ -72,20 +72,21 @@ The above content is from the documentation for '{source_id}'. Please provide a
)
# Extract the generated summary with proper error handling
if not response or not response.choices or len(response.choices) == 0:
search_logger.error(f"Empty or invalid response from LLM for {source_id}")
return default_summary
message_content = response.choices[0].message.content
if message_content is None:
search_logger.error(f"LLM returned None content for {source_id}")
return default_summary
summary = message_content.strip()
# Ensure the summary is not too long
if len(summary) > max_length:
summary = summary[:max_length] + "..."
if not response or not response.choices or len(response.choices) == 0:
search_logger.error(f"Empty or invalid response from LLM for {source_id}")
return default_summary
choice = response.choices[0]
summary_text, _, _ = extract_message_text(choice)
if not summary_text:
search_logger.error(f"LLM returned None content for {source_id}")
return default_summary
summary = summary_text.strip()
# Ensure the summary is not too long
if len(summary) > max_length:
summary = summary[:max_length] + "..."
return summary
@@ -187,7 +188,9 @@ Generate only the title, nothing else."""
],
)
generated_title = response.choices[0].message.content.strip()
choice = response.choices[0]
generated_title, _, _ = extract_message_text(choice)
generated_title = generated_title.strip()
# Clean up the title
generated_title = generated_title.strip("\"'")
if len(generated_title) < 50: # Sanity check
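extract_message_text is imported above but not shown in this hunk. Judging by how callers unpack it (summary_text, _, _ = extract_message_text(choice)), it returns a three-tuple with the content first; the sketch below is purely an assumed stand-in, and the reasoning attribute name in particular is a guess:

def extract_message_text(choice) -> tuple[str, str, object]:
    """Hypothetical sketch: split an LLM choice into content and reasoning text."""
    message = getattr(choice, "message", None)
    content = (getattr(message, "content", None) or "") if message else ""
    # Some providers surface reasoning separately from the final answer.
    reasoning = (getattr(message, "reasoning_content", None) or "") if message else ""
    return content, reasoning, message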

View File

@@ -8,6 +8,7 @@ import asyncio
import json
import os
import re
import time
from collections import defaultdict, deque
from collections.abc import Callable
from difflib import SequenceMatcher
@@ -17,25 +18,104 @@ from urllib.parse import urlparse
from supabase import Client
from ...config.logfire_config import search_logger
from ..credential_service import credential_service
from ..embeddings.contextual_embedding_service import generate_contextual_embeddings_batch
from ..embeddings.embedding_service import create_embeddings_batch
from ..llm_provider_service import (
extract_json_from_reasoning,
extract_message_text,
get_llm_client,
prepare_chat_completion_params,
synthesize_json_from_reasoning,
)
def _get_model_choice() -> str:
"""Get MODEL_CHOICE with direct fallback."""
def _extract_json_payload(raw_response: str, context_code: str = "", language: str = "") -> str:
"""Return the best-effort JSON object from an LLM response."""
if not raw_response:
return raw_response
cleaned = raw_response.strip()
# Check if this looks like reasoning text first
if _is_reasoning_text_response(cleaned):
# Try intelligent extraction from reasoning text with context
extracted = extract_json_from_reasoning(cleaned, context_code, language)
if extracted:
return extracted
# extract_json_from_reasoning may return nothing; synthesize a fallback JSON if so
fallback_json = synthesize_json_from_reasoning("", context_code, language)
if fallback_json:
return fallback_json
# If all else fails, return a minimal valid JSON object to avoid downstream errors
return '{"example_name": "Code Example", "summary": "Code example extracted from context."}'
if cleaned.startswith("```"):
lines = cleaned.splitlines()
# Drop opening fence
lines = lines[1:]
# Drop closing fence if present
if lines and lines[-1].strip().startswith("```"):
lines = lines[:-1]
cleaned = "\n".join(lines).strip()
# Trim any leading/trailing text outside the outermost JSON braces
start = cleaned.find("{")
end = cleaned.rfind("}")
if start != -1 and end != -1 and end >= start:
cleaned = cleaned[start : end + 1]
return cleaned.strip()
REASONING_STARTERS = [
"okay, let's see", "okay, let me", "let me think", "first, i need to", "looking at this",
"i need to", "analyzing", "let me work through", "thinking about", "let me see"
]
def _is_reasoning_text_response(text: str) -> bool:
"""Detect if response is reasoning text rather than direct JSON."""
if not text or len(text) < 20:
return False
text_lower = text.lower().strip()
# Check if it's clearly not JSON (starts with reasoning text)
starts_with_reasoning = any(text_lower.startswith(starter) for starter in REASONING_STARTERS)
# Check if it lacks immediate JSON structure
lacks_immediate_json = not text_lower.lstrip().startswith('{')
return starts_with_reasoning or (lacks_immediate_json and any(pattern in text_lower for pattern in REASONING_STARTERS))
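To make the fence-stripping path of _extract_json_payload concrete, here is a standalone reproduction of just that step on a toy reply (it deliberately skips the reasoning-text branch):

raw = '```json\n{"example_name": "Demo", "summary": "Shows fence stripping."}\n```'
cleaned = raw.strip()
if cleaned.startswith("```"):
    lines = cleaned.splitlines()[1:]                      # drop the opening fence
    if lines and lines[-1].strip().startswith("```"):
        lines = lines[:-1]                                # drop the closing fence
    cleaned = "\n".join(lines).strip()
start, end = cleaned.find("{"), cleaned.rfind("}")
if start != -1 and end >= start:
    cleaned = cleaned[start:end + 1]
assert cleaned == '{"example_name": "Demo", "summary": "Shows fence stripping."}'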
async def _get_model_choice() -> str:
"""Get MODEL_CHOICE with provider-aware defaults from centralized service."""
try:
# Direct cache/env fallback
from ..credential_service import credential_service
# Get the active provider configuration
provider_config = await credential_service.get_active_provider("llm")
active_provider = provider_config.get("provider", "openai")
model = provider_config.get("chat_model")
if credential_service._cache_initialized and "MODEL_CHOICE" in credential_service._cache:
model = credential_service._cache["MODEL_CHOICE"]
else:
model = os.getenv("MODEL_CHOICE", "gpt-4.1-nano")
search_logger.debug(f"Using model choice: {model}")
# If no custom model is set, use provider-specific defaults
if not model or model.strip() == "":
# Provider-specific defaults
provider_defaults = {
"openai": "gpt-4o-mini",
"openrouter": "anthropic/claude-3.5-sonnet",
"google": "gemini-1.5-flash",
"ollama": "llama3.2:latest",
"anthropic": "claude-3-5-haiku-20241022",
"grok": "grok-3-mini"
}
model = provider_defaults.get(active_provider, "gpt-4o-mini")
search_logger.debug(f"Using default model for provider {active_provider}: {model}")
search_logger.debug(f"Using model for provider {active_provider}: {model}")
return model
except Exception as e:
search_logger.warning(f"Error getting model choice: {e}, using default")
return "gpt-4.1-nano"
return "gpt-4o-mini"
def _get_max_workers() -> int:
@@ -155,6 +235,7 @@ def _select_best_code_variant(similar_blocks: list[dict[str, Any]]) -> dict[str,
return best_block
def extract_code_blocks(markdown_content: str, min_length: int = None) -> list[dict[str, Any]]:
"""
Extract code blocks from markdown content along with context.
@@ -168,8 +249,6 @@ def extract_code_blocks(markdown_content: str, min_length: int = None) -> list[d
"""
# Load all code extraction settings with direct fallback
try:
from ...services.credential_service import credential_service
def _get_setting_fallback(key: str, default: str) -> str:
if credential_service._cache_initialized and key in credential_service._cache:
return credential_service._cache[key]
@@ -507,7 +586,7 @@ def generate_code_example_summary(
A dictionary with 'summary' and 'example_name'
"""
import asyncio
# Run the async version in the current thread
return asyncio.run(_generate_code_example_summary_async(code, context_before, context_after, language, provider))
@@ -518,13 +597,22 @@ async def _generate_code_example_summary_async(
"""
Async version of generate_code_example_summary using unified LLM provider service.
"""
from ..llm_provider_service import get_llm_client
# Get model choice from credential service (RAG setting)
model_choice = _get_model_choice()
# Create the prompt
prompt = f"""<context_before>
# Get model choice from credential service (RAG setting)
model_choice = await _get_model_choice()
# If provider is not specified, get it from credential service
if provider is None:
try:
provider_config = await credential_service.get_active_provider("llm")
provider = provider_config.get("provider", "openai")
search_logger.debug(f"Auto-detected provider from credential service: {provider}")
except Exception as e:
search_logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai")
provider = "openai"
# Create the prompt variants: base prompt, guarded prompt (JSON reminder), and strict prompt for retries
base_prompt = f"""<context_before>
{context_before[-500:] if len(context_before) > 500 else context_before}
</context_before>
@@ -548,6 +636,16 @@ Format your response as JSON:
"summary": "2-3 sentence description of what the code demonstrates"
}}
"""
guard_prompt = (
base_prompt
+ "\n\nImportant: Respond with a valid JSON object that exactly matches the keys "
'{"example_name": string, "summary": string}. Do not include commentary, '
"markdown fences, or reasoning notes."
)
strict_prompt = (
guard_prompt
+ "\n\nSecond attempt enforcement: Return JSON only with the exact schema. No additional text or reasoning content."
)
try:
# Use unified LLM provider service
@@ -555,25 +653,261 @@ Format your response as JSON:
search_logger.info(
f"Generating summary for {hash(code) & 0xffffff:06x} using model: {model_choice}"
)
response = await client.chat.completions.create(
model=model_choice,
messages=[
{
"role": "system",
"content": "You are a helpful assistant that analyzes code examples and provides JSON responses with example names and summaries.",
},
{"role": "user", "content": prompt},
],
response_format={"type": "json_object"},
max_tokens=500,
temperature=0.3,
provider_lower = provider.lower()
is_grok_model = (provider_lower == "grok") or ("grok" in model_choice.lower())
supports_response_format_base = (
provider_lower in {"openai", "google", "anthropic"}
or (provider_lower == "openrouter" and model_choice.startswith("openai/"))
)
response_content = response.choices[0].message.content.strip()
last_response_obj = None
last_elapsed_time = None
last_response_content = ""
last_json_error: json.JSONDecodeError | None = None
for enforce_json, current_prompt in ((False, guard_prompt), (True, strict_prompt)):
request_params = {
"model": model_choice,
"messages": [
{
"role": "system",
"content": "You are a helpful assistant that analyzes code examples and provides JSON responses with example names and summaries.",
},
{"role": "user", "content": current_prompt},
],
"max_tokens": 2000,
"temperature": 0.3,
}
should_use_response_format = False
if enforce_json:
if not is_grok_model and (supports_response_format_base or provider_lower == "openrouter"):
should_use_response_format = True
else:
if supports_response_format_base:
should_use_response_format = True
if should_use_response_format:
request_params["response_format"] = {"type": "json_object"}
if is_grok_model:
unsupported_params = ["presence_penalty", "frequency_penalty", "stop", "reasoning_effort"]
for param in unsupported_params:
if param in request_params:
removed_value = request_params.pop(param)
search_logger.warning(f"Removed unsupported Grok parameter '{param}': {removed_value}")
supported_params = ["model", "messages", "max_tokens", "temperature", "response_format", "stream", "tools", "tool_choice"]
for param in list(request_params.keys()):
if param not in supported_params:
search_logger.warning(f"Parameter '{param}' may not be supported by Grok reasoning models")
start_time = time.time()
max_retries = 3 if is_grok_model else 1
retry_delay = 1.0
response_content_local = ""
reasoning_text_local = ""
json_error_occurred = False
for attempt in range(max_retries):
try:
if is_grok_model and attempt > 0:
search_logger.info(f"Grok retry attempt {attempt + 1}/{max_retries} after {retry_delay:.1f}s delay")
await asyncio.sleep(retry_delay)
final_params = prepare_chat_completion_params(model_choice, request_params)
response = await client.chat.completions.create(**final_params)
last_response_obj = response
choice = response.choices[0] if response.choices else None
message = choice.message if choice and hasattr(choice, "message") else None
response_content_local = ""
reasoning_text_local = ""
if choice:
response_content_local, reasoning_text_local, _ = extract_message_text(choice)
# Enhanced logging for response analysis
if message and reasoning_text_local:
content_preview = response_content_local[:100] if response_content_local else "None"
reasoning_preview = reasoning_text_local[:100] if reasoning_text_local else "None"
search_logger.debug(
f"Response has reasoning content - content: '{content_preview}', reasoning: '{reasoning_preview}'"
)
if response_content_local:
last_response_content = response_content_local.strip()
# Pre-validate response before processing
if len(last_response_content) < 20 or (len(last_response_content) < 50 and not last_response_content.strip().startswith('{')):
# Very minimal response - likely "Okay\nOkay" type
search_logger.debug(f"Minimal response detected: {repr(last_response_content)}")
# Generate fallback directly from context
fallback_json = synthesize_json_from_reasoning("", code, language)
if fallback_json:
try:
result = json.loads(fallback_json)
final_result = {
"example_name": result.get("example_name", f"Code Example{f' ({language})' if language else ''}"),
"summary": result.get("summary", "Code example for demonstration purposes."),
}
search_logger.info(f"Generated fallback summary from context - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}")
return final_result
except json.JSONDecodeError:
pass # Continue to normal error handling
else:
# Even synthesis failed - provide hardcoded fallback for minimal responses
final_result = {
"example_name": f"Code Example{f' ({language})' if language else ''}",
"summary": "Code example extracted from development context.",
}
search_logger.info(f"Used hardcoded fallback for minimal response - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}")
return final_result
payload = _extract_json_payload(last_response_content, code, language)
if payload != last_response_content:
search_logger.debug(
f"Sanitized LLM response payload before parsing: {repr(payload[:200])}..."
)
try:
result = json.loads(payload)
if not result.get("example_name") or not result.get("summary"):
search_logger.warning(f"Incomplete response from LLM: {result}")
final_result = {
"example_name": result.get(
"example_name", f"Code Example{f' ({language})' if language else ''}"
),
"summary": result.get("summary", "Code example for demonstration purposes."),
}
search_logger.info(
f"Generated code example summary - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}"
)
return final_result
except json.JSONDecodeError as json_error:
last_json_error = json_error
json_error_occurred = True
snippet = last_response_content[:200]
if not enforce_json:
# Check if this was reasoning text that couldn't be parsed
if _is_reasoning_text_response(last_response_content):
search_logger.debug(
f"Reasoning text detected but no JSON extracted. Response snippet: {repr(snippet)}"
)
else:
search_logger.warning(
f"Failed to parse JSON response from LLM (non-strict attempt). Error: {json_error}. Response snippet: {repr(snippet)}"
)
break
else:
search_logger.error(
f"Strict JSON enforcement still failed to produce valid JSON: {json_error}. Response snippet: {repr(snippet)}"
)
break
elif is_grok_model and attempt < max_retries - 1:
search_logger.warning(f"Grok empty response on attempt {attempt + 1}, retrying...")
retry_delay *= 2
continue
else:
break
except Exception as e:
if is_grok_model and attempt < max_retries - 1:
search_logger.error(f"Grok request failed on attempt {attempt + 1}: {e}, retrying...")
retry_delay *= 2
continue
else:
raise
if is_grok_model:
elapsed_time = time.time() - start_time
last_elapsed_time = elapsed_time
search_logger.debug(f"Grok total response time: {elapsed_time:.2f}s")
if json_error_occurred:
if not enforce_json:
continue
else:
break
if response_content_local:
# We would have returned already on success; if we reach here, parsing failed but we are not retrying
continue
response_content = last_response_content
response = last_response_obj
elapsed_time = last_elapsed_time if last_elapsed_time is not None else 0.0
if last_json_error is not None and response_content:
search_logger.error(
f"LLM response after strict enforcement was still not valid JSON: {last_json_error}. Clearing response to trigger error handling."
)
response_content = ""
if not response_content:
search_logger.error(f"Empty response from LLM for model: {model_choice} (provider: {provider})")
if is_grok_model:
search_logger.error("Grok empty response debugging:")
search_logger.error(f" - Request took: {elapsed_time:.2f}s")
search_logger.error(f" - Response status: {getattr(response, 'status_code', 'N/A')}")
search_logger.error(f" - Response headers: {getattr(response, 'headers', 'N/A')}")
search_logger.error(f" - Full response: {response}")
search_logger.error(f" - Response choices length: {len(response.choices) if response.choices else 0}")
if response.choices:
search_logger.error(f" - First choice: {response.choices[0]}")
search_logger.error(f" - Message content: '{response.choices[0].message.content}'")
search_logger.error(f" - Message role: {response.choices[0].message.role}")
search_logger.error("Check: 1) API key validity, 2) rate limits, 3) model availability")
# Implement fallback for Grok failures
search_logger.warning("Attempting fallback to OpenAI due to Grok failure...")
try:
# Use OpenAI as fallback with similar parameters
fallback_params = {
"model": "gpt-4o-mini",
"messages": request_params["messages"],
"temperature": request_params.get("temperature", 0.1),
"max_tokens": request_params.get("max_tokens", 500),
}
async with get_llm_client(provider="openai") as fallback_client:
fallback_response = await fallback_client.chat.completions.create(**fallback_params)
fallback_content = fallback_response.choices[0].message.content
if fallback_content and fallback_content.strip():
search_logger.info("gpt-4o-mini fallback succeeded")
response_content = fallback_content.strip()
else:
search_logger.error("gpt-4o-mini fallback also returned empty response")
raise ValueError(f"Both {model_choice} and gpt-4o-mini fallback failed")
except Exception as fallback_error:
search_logger.error(f"gpt-4o-mini fallback failed: {fallback_error}")
raise ValueError(f"{model_choice} failed and fallback to gpt-4o-mini also failed: {fallback_error}") from fallback_error
else:
search_logger.debug(f"Full response object: {response}")
raise ValueError("Empty response from LLM")
if not response_content:
# This should not happen after fallback logic, but safety check
raise ValueError("No valid response content after all attempts")
response_content = response_content.strip()
search_logger.debug(f"LLM API response: {repr(response_content[:200])}...")
result = json.loads(response_content)
payload = _extract_json_payload(response_content, code, language)
if payload != response_content:
search_logger.debug(
f"Sanitized LLM response payload before parsing: {repr(payload[:200])}..."
)
result = json.loads(payload)
# Validate the response has the required fields
if not result.get("example_name") or not result.get("summary"):
@@ -595,12 +929,38 @@ Format your response as JSON:
search_logger.error(
f"Failed to parse JSON response from LLM: {e}, Response: {repr(response_content) if 'response_content' in locals() else 'No response'}"
)
# Try to generate context-aware fallback
try:
fallback_json = synthesize_json_from_reasoning("", code, language)
if fallback_json:
fallback_result = json.loads(fallback_json)
search_logger.info(f"Generated context-aware fallback summary")
return {
"example_name": fallback_result.get("example_name", f"Code Example{f' ({language})' if language else ''}"),
"summary": fallback_result.get("summary", "Code example for demonstration purposes."),
}
except Exception:
pass # Fall through to generic fallback
return {
"example_name": f"Code Example{f' ({language})' if language else ''}",
"summary": "Code example for demonstration purposes.",
}
except Exception as e:
search_logger.error(f"Error generating code summary using unified LLM provider: {e}")
# Try to generate context-aware fallback
try:
fallback_json = synthesize_json_from_reasoning("", code, language)
if fallback_json:
fallback_result = json.loads(fallback_json)
search_logger.info(f"Generated context-aware fallback summary after error")
return {
"example_name": fallback_result.get("example_name", f"Code Example{f' ({language})' if language else ''}"),
"summary": fallback_result.get("summary", "Code example for demonstration purposes."),
}
except Exception:
pass # Fall through to generic fallback
return {
"example_name": f"Code Example{f' ({language})' if language else ''}",
"summary": "Code example for demonstration purposes.",
@@ -608,7 +968,7 @@ Format your response as JSON:
async def generate_code_summaries_batch(
code_blocks: list[dict[str, Any]], max_workers: int = None, progress_callback=None
code_blocks: list[dict[str, Any]], max_workers: int = None, progress_callback=None, provider: str = None
) -> list[dict[str, str]]:
"""
Generate summaries for multiple code blocks with rate limiting and proper worker management.
@@ -617,6 +977,7 @@ async def generate_code_summaries_batch(
code_blocks: List of code block dictionaries
max_workers: Maximum number of concurrent API requests
progress_callback: Optional callback for progress updates (async function)
provider: LLM provider to use for generation (e.g., 'grok', 'openai', 'anthropic')
Returns:
List of summary dictionaries
@@ -627,8 +988,6 @@ async def generate_code_summaries_batch(
# Get max_workers from settings if not provided
if max_workers is None:
try:
from ...services.credential_service import credential_service
if (
credential_service._cache_initialized
and "CODE_SUMMARY_MAX_WORKERS" in credential_service._cache
@@ -663,6 +1022,7 @@ async def generate_code_summaries_batch(
block["context_before"],
block["context_after"],
block.get("language", ""),
provider,
)
# Update progress
@@ -757,29 +1117,17 @@ async def add_code_examples_to_supabase(
except Exception as e:
search_logger.error(f"Error deleting existing code examples for {url}: {e}")
# Check if contextual embeddings are enabled
# Check if contextual embeddings are enabled (use proper async method like document storage)
try:
from ..credential_service import credential_service
use_contextual_embeddings = credential_service._cache.get("USE_CONTEXTUAL_EMBEDDINGS")
if isinstance(use_contextual_embeddings, str):
use_contextual_embeddings = use_contextual_embeddings.lower() == "true"
elif isinstance(use_contextual_embeddings, dict) and use_contextual_embeddings.get(
"is_encrypted"
):
# Handle encrypted value
encrypted_value = use_contextual_embeddings.get("encrypted_value")
if encrypted_value:
try:
decrypted = credential_service._decrypt_value(encrypted_value)
use_contextual_embeddings = decrypted.lower() == "true"
except:
use_contextual_embeddings = False
else:
use_contextual_embeddings = False
raw_value = await credential_service.get_credential(
"USE_CONTEXTUAL_EMBEDDINGS", "false", decrypt=True
)
if isinstance(raw_value, str):
use_contextual_embeddings = raw_value.lower() == "true"
else:
use_contextual_embeddings = bool(use_contextual_embeddings)
except:
use_contextual_embeddings = bool(raw_value)
except Exception as e:
search_logger.error(f"DEBUG: Error reading contextual embeddings: {e}")
# Fallback to environment variable
use_contextual_embeddings = (
os.getenv("USE_CONTEXTUAL_EMBEDDINGS", "false").lower() == "true"
@@ -848,14 +1196,13 @@ async def add_code_examples_to_supabase(
# Use only successful embeddings
valid_embeddings = result.embeddings
successful_texts = result.texts_processed
# Get model information for tracking
from ..llm_provider_service import get_embedding_model
from ..credential_service import credential_service
# Get embedding model name
embedding_model_name = await get_embedding_model(provider=provider)
# Get LLM chat model (used for code summaries and contextual embeddings if enabled)
llm_chat_model = None
try:
@@ -868,7 +1215,7 @@ async def add_code_examples_to_supabase(
llm_chat_model = await credential_service.get_credential("MODEL_CHOICE", "gpt-4o-mini")
else:
# For code summaries, we use MODEL_CHOICE
llm_chat_model = _get_model_choice()
llm_chat_model = await _get_model_choice()
except Exception as e:
search_logger.warning(f"Failed to get LLM chat model: {e}")
llm_chat_model = "gpt-4o-mini" # Default fallback
@@ -888,7 +1235,7 @@ async def add_code_examples_to_supabase(
positions_by_text[text].append(original_indices[k])
# Map successful texts back to their original indices
for embedding, text in zip(valid_embeddings, successful_texts, strict=False):
for embedding, text in zip(valid_embeddings, successful_texts, strict=True):
# Get the next available index for this text (handles duplicates)
if positions_by_text[text]:
orig_idx = positions_by_text[text].popleft() # Original j index in [i, batch_end)
@@ -908,7 +1255,7 @@ async def add_code_examples_to_supabase(
# Determine the correct embedding column based on dimension
embedding_dim = len(embedding) if isinstance(embedding, list) else len(embedding.tolist())
embedding_column = None
if embedding_dim == 768:
embedding_column = "embedding_768"
elif embedding_dim == 1024:
@@ -918,10 +1265,12 @@ async def add_code_examples_to_supabase(
elif embedding_dim == 3072:
embedding_column = "embedding_3072"
else:
# Default to closest supported dimension
search_logger.warning(f"Unsupported embedding dimension {embedding_dim}, using embedding_1536")
embedding_column = "embedding_1536"
# Skip unsupported dimensions to avoid corrupting the schema
search_logger.error(
f"Unsupported embedding dimension {embedding_dim}; skipping record to prevent column mismatch"
)
continue
batch_data.append({
"url": urls[idx],
"chunk_number": chunk_numbers[idx],
@@ -954,9 +1303,7 @@ async def add_code_examples_to_supabase(
f"Error inserting batch into Supabase (attempt {retry + 1}/{max_retries}): {e}"
)
search_logger.info(f"Retrying in {retry_delay} seconds...")
import time
time.sleep(retry_delay)
await asyncio.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
else:
# Final attempt failed
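The dimension-to-column selection above is effectively a fixed lookup; a sketch (the 768/1536/3072 names appear in the hunk, the 1024 name follows the same pattern and is assumed):

# Sketch of the supported-dimension lookup; unknown dimensions are skipped
# rather than coerced, matching the behaviour above.
SUPPORTED_EMBEDDING_COLUMNS = {
    768: "embedding_768",
    1024: "embedding_1024",   # assumed from the naming pattern
    1536: "embedding_1536",
    3072: "embedding_3072",
}

def embedding_column_for(dim: int) -> str | None:
    """Return the target column for an embedding dimension, or None to skip."""
    return SUPPORTED_EMBEDDING_COLUMNS.get(dim)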

View File

@@ -0,0 +1,162 @@
"""
Version checking service with GitHub API integration.
"""
from datetime import datetime, timedelta
from typing import Any
import httpx
import logfire
from ..config.version import ARCHON_VERSION, GITHUB_REPO_NAME, GITHUB_REPO_OWNER
from ..utils.semantic_version import is_newer_version
class VersionService:
"""Service for checking Archon version against GitHub releases."""
def __init__(self):
self._cache: dict[str, Any] | None = None
self._cache_time: datetime | None = None
self._cache_ttl = 3600 # 1 hour cache TTL
def _is_cache_valid(self) -> bool:
"""Check if cached data is still valid."""
if not self._cache or not self._cache_time:
return False
age = datetime.now() - self._cache_time
return age < timedelta(seconds=self._cache_ttl)
async def get_latest_release(self) -> dict[str, Any] | None:
"""
Fetch latest release information from GitHub API.
Returns:
Release data dictionary or None if no releases
"""
# Check cache first
if self._is_cache_valid():
logfire.debug("Using cached version data")
return self._cache
# GitHub API endpoint
url = f"https://api.github.com/repos/{GITHUB_REPO_OWNER}/{GITHUB_REPO_NAME}/releases/latest"
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
url,
headers={
"Accept": "application/vnd.github.v3+json",
"User-Agent": f"Archon/{ARCHON_VERSION}",
},
)
# Handle 404 - no releases yet
if response.status_code == 404:
logfire.info("No releases found on GitHub")
return None
response.raise_for_status()
data = response.json()
# Cache the successful response
self._cache = data
self._cache_time = datetime.now()
return data
except httpx.TimeoutException:
logfire.warning("GitHub API request timed out")
# Return cached data if available
if self._cache:
return self._cache
return None
except httpx.HTTPError as e:
logfire.error(f"HTTP error fetching latest release: {e}")
# Return cached data if available
if self._cache:
return self._cache
return None
except Exception as e:
logfire.error(f"Unexpected error fetching latest release: {e}")
# Return cached data if available
if self._cache:
return self._cache
return None
async def check_for_updates(self) -> dict[str, Any]:
"""
Check if a newer version of Archon is available.
Returns:
Dictionary with version check results
"""
try:
# Get latest release from GitHub
release = await self.get_latest_release()
if not release:
# No releases found or error occurred
return {
"current": ARCHON_VERSION,
"latest": None,
"update_available": False,
"release_url": None,
"release_notes": None,
"published_at": None,
"check_error": None,
}
# Extract version from tag_name (e.g., "v1.0.0" -> "1.0.0")
latest_version = release.get("tag_name", "")
if latest_version.startswith("v"):
latest_version = latest_version[1:]
# Check if update is available
update_available = is_newer_version(ARCHON_VERSION, latest_version)
# Parse published date
published_at = None
if release.get("published_at"):
try:
published_at = datetime.fromisoformat(
release["published_at"].replace("Z", "+00:00")
)
except Exception:
pass
return {
"current": ARCHON_VERSION,
"latest": latest_version,
"update_available": update_available,
"release_url": release.get("html_url"),
"release_notes": release.get("body"),
"published_at": published_at,
"check_error": None,
"assets": release.get("assets", []),
"author": release.get("author", {}).get("login"),
}
except Exception as e:
logfire.error(f"Error checking for updates: {e}")
# Return safe default with error
return {
"current": ARCHON_VERSION,
"latest": None,
"update_available": False,
"release_url": None,
"release_notes": None,
"published_at": None,
"check_error": str(e),
}
def clear_cache(self):
"""Clear the cached version data."""
self._cache = None
self._cache_time = None
# Export singleton instance
version_service = VersionService()
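A minimal usage sketch of the singleton above (assuming the import path matches the one used by the tests below, and that outbound access to api.github.com is available):

import asyncio

from src.server.services.version_service import version_service

async def _demo() -> None:
    status = await version_service.check_for_updates()
    if status["update_available"]:
        print(f"Update available: {status['current']} -> {status['latest']}")
    else:
        print(f"Archon {status['current']} is up to date")

# asyncio.run(_demo())  # uncomment to run; performs a live GitHub API call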

View File

@@ -0,0 +1,107 @@
"""
Semantic version parsing and comparison utilities.
"""
import re
def parse_version(version_string: str) -> tuple[int, int, int, str | None]:
"""
Parse a semantic version string into major, minor, patch, and optional prerelease.
Supports formats like:
- "1.0.0"
- "v1.0.0"
- "1.0.0-beta"
- "v1.0.0-rc.1"
Args:
version_string: Version string to parse
Returns:
Tuple of (major, minor, patch, prerelease)
"""
# Remove 'v' prefix if present
version = version_string.strip()
if version.lower().startswith('v'):
version = version[1:]
# Parse version with optional prerelease
pattern = r'^(\d+)\.(\d+)\.(\d+)(?:-(.+))?$'
match = re.match(pattern, version)
if not match:
# Try to handle incomplete versions like "1.0"
simple_pattern = r'^(\d+)(?:\.(\d+))?(?:\.(\d+))?$'
simple_match = re.match(simple_pattern, version)
if simple_match:
major = int(simple_match.group(1))
minor = int(simple_match.group(2) or 0)
patch = int(simple_match.group(3) or 0)
return (major, minor, patch, None)
raise ValueError(f"Invalid version string: {version_string}")
major = int(match.group(1))
minor = int(match.group(2))
patch = int(match.group(3))
prerelease = match.group(4)
return (major, minor, patch, prerelease)
def compare_versions(version1: str, version2: str) -> int:
"""
Compare two semantic version strings.
Args:
version1: First version string
version2: Second version string
Returns:
-1 if version1 < version2
0 if version1 == version2
1 if version1 > version2
"""
v1 = parse_version(version1)
v2 = parse_version(version2)
# Compare major, minor, patch
for i in range(3):
if v1[i] < v2[i]:
return -1
elif v1[i] > v2[i]:
return 1
# If main versions are equal, check prerelease
# No prerelease is considered newer than any prerelease
if v1[3] is None and v2[3] is None:
return 0
elif v1[3] is None:
return 1 # v1 is release, v2 is prerelease
elif v2[3] is None:
return -1 # v1 is prerelease, v2 is release
else:
# Both have prereleases, compare lexicographically
if v1[3] < v2[3]:
return -1
elif v1[3] > v2[3]:
return 1
return 0
def is_newer_version(current: str, latest: str) -> bool:
"""
Check if latest version is newer than current version.
Args:
current: Current version string
latest: Latest version string to compare
Returns:
True if latest > current, False otherwise
"""
try:
return compare_versions(latest, current) > 0
except ValueError:
# If we can't parse versions, assume no update
return False
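A few behavioural checks for the helpers above (assuming they are imported from this module, e.g. the src.server.utils.semantic_version path that version_service.py resolves to):

from src.server.utils.semantic_version import compare_versions, is_newer_version, parse_version

assert parse_version("v1.2.3-rc.1") == (1, 2, 3, "rc.1")
assert parse_version("1.0") == (1, 0, 0, None)          # incomplete versions are padded
assert compare_versions("1.0.0", "1.0.0-beta") == 1     # a release outranks its prerelease
assert is_newer_version("0.1.0", "0.2.0") is True
assert is_newer_version("1.0.0", "not-a-version") is False  # unparsable input means "no update"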

View File

@@ -0,0 +1,206 @@
"""
Unit tests for migration_api.py
"""
from datetime import datetime
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from src.server.config.version import ARCHON_VERSION
from src.server.main import app
from src.server.services.migration_service import MigrationRecord, PendingMigration
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture
def mock_applied_migrations():
"""Mock applied migration data."""
return [
MigrationRecord({
"version": "0.1.0",
"migration_name": "001_initial",
"applied_at": datetime(2025, 1, 1, 0, 0, 0),
"checksum": "abc123",
}),
MigrationRecord({
"version": "0.1.0",
"migration_name": "002_add_column",
"applied_at": datetime(2025, 1, 2, 0, 0, 0),
"checksum": "def456",
}),
]
@pytest.fixture
def mock_pending_migrations():
"""Mock pending migration data."""
return [
PendingMigration(
version="0.1.0",
name="003_add_index",
sql_content="CREATE INDEX idx_test ON test_table(name);",
file_path="migration/0.1.0/003_add_index.sql"
),
PendingMigration(
version="0.1.0",
name="004_add_table",
sql_content="CREATE TABLE new_table (id INT);",
file_path="migration/0.1.0/004_add_table.sql"
),
]
@pytest.fixture
def mock_migration_status(mock_applied_migrations, mock_pending_migrations):
"""Mock complete migration status."""
return {
"pending_migrations": [
{"version": m.version, "name": m.name, "sql_content": m.sql_content, "file_path": m.file_path, "checksum": m.checksum}
for m in mock_pending_migrations
],
"applied_migrations": [
{"version": m.version, "migration_name": m.migration_name, "applied_at": m.applied_at, "checksum": m.checksum}
for m in mock_applied_migrations
],
"has_pending": True,
"bootstrap_required": False,
"current_version": ARCHON_VERSION,
"pending_count": 2,
"applied_count": 2,
}
def test_get_migration_status_success(client, mock_migration_status):
"""Test successful migration status retrieval."""
with patch("src.server.api_routes.migration_api.migration_service") as mock_service:
mock_service.get_migration_status = AsyncMock(return_value=mock_migration_status)
response = client.get("/api/migrations/status")
assert response.status_code == 200
data = response.json()
assert data["current_version"] == ARCHON_VERSION
assert data["has_pending"] is True
assert data["bootstrap_required"] is False
assert data["pending_count"] == 2
assert data["applied_count"] == 2
assert len(data["pending_migrations"]) == 2
assert len(data["applied_migrations"]) == 2
def test_get_migration_status_bootstrap_required(client):
"""Test migration status when bootstrap is required."""
mock_status = {
"pending_migrations": [],
"applied_migrations": [],
"has_pending": True,
"bootstrap_required": True,
"current_version": ARCHON_VERSION,
"pending_count": 5,
"applied_count": 0,
}
with patch("src.server.api_routes.migration_api.migration_service") as mock_service:
mock_service.get_migration_status = AsyncMock(return_value=mock_status)
response = client.get("/api/migrations/status")
assert response.status_code == 200
data = response.json()
assert data["bootstrap_required"] is True
assert data["applied_count"] == 0
def test_get_migration_status_error(client):
"""Test error handling in migration status."""
with patch("src.server.api_routes.migration_api.migration_service") as mock_service:
mock_service.get_migration_status = AsyncMock(side_effect=Exception("Database error"))
response = client.get("/api/migrations/status")
assert response.status_code == 500
assert "Failed to get migration status" in response.json()["detail"]
def test_get_migration_history_success(client, mock_applied_migrations):
"""Test successful migration history retrieval."""
with patch("src.server.api_routes.migration_api.migration_service") as mock_service:
mock_service.get_applied_migrations = AsyncMock(return_value=mock_applied_migrations)
response = client.get("/api/migrations/history")
assert response.status_code == 200
data = response.json()
assert data["total_count"] == 2
assert data["current_version"] == ARCHON_VERSION
assert len(data["migrations"]) == 2
assert data["migrations"][0]["migration_name"] == "001_initial"
def test_get_migration_history_empty(client):
"""Test migration history when no migrations applied."""
with patch("src.server.api_routes.migration_api.migration_service") as mock_service:
mock_service.get_applied_migrations = AsyncMock(return_value=[])
response = client.get("/api/migrations/history")
assert response.status_code == 200
data = response.json()
assert data["total_count"] == 0
assert len(data["migrations"]) == 0
def test_get_migration_history_error(client):
"""Test error handling in migration history."""
with patch("src.server.api_routes.migration_api.migration_service") as mock_service:
mock_service.get_applied_migrations = AsyncMock(side_effect=Exception("Database error"))
response = client.get("/api/migrations/history")
assert response.status_code == 500
assert "Failed to get migration history" in response.json()["detail"]
def test_get_pending_migrations_success(client, mock_pending_migrations):
"""Test successful pending migrations retrieval."""
with patch("src.server.api_routes.migration_api.migration_service") as mock_service:
mock_service.get_pending_migrations = AsyncMock(return_value=mock_pending_migrations)
response = client.get("/api/migrations/pending")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["name"] == "003_add_index"
assert data[0]["sql_content"] == "CREATE INDEX idx_test ON test_table(name);"
assert data[1]["name"] == "004_add_table"
def test_get_pending_migrations_none(client):
"""Test when no pending migrations exist."""
with patch("src.server.api_routes.migration_api.migration_service") as mock_service:
mock_service.get_pending_migrations = AsyncMock(return_value=[])
response = client.get("/api/migrations/pending")
assert response.status_code == 200
data = response.json()
assert len(data) == 0
def test_get_pending_migrations_error(client):
"""Test error handling in pending migrations."""
with patch("src.server.api_routes.migration_api.migration_service") as mock_service:
mock_service.get_pending_migrations = AsyncMock(side_effect=Exception("File error"))
response = client.get("/api/migrations/pending")
assert response.status_code == 500
assert "Failed to get pending migrations" in response.json()["detail"]

View File

@@ -0,0 +1,147 @@
"""
Unit tests for version_api.py
"""
from datetime import datetime
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from src.server.config.version import ARCHON_VERSION
from src.server.main import app
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture
def mock_version_data():
"""Mock version check data."""
return {
"current": ARCHON_VERSION,
"latest": "0.2.0",
"update_available": True,
"release_url": "https://github.com/coleam00/Archon/releases/tag/v0.2.0",
"release_notes": "New features and bug fixes",
"published_at": datetime(2025, 1, 1, 0, 0, 0),
"check_error": None,
"author": "coleam00",
"assets": [{"name": "archon.zip", "size": 1024000}],
}
def test_check_for_updates_success(client, mock_version_data):
"""Test successful version check."""
with patch("src.server.api_routes.version_api.version_service") as mock_service:
mock_service.check_for_updates = AsyncMock(return_value=mock_version_data)
response = client.get("/api/version/check")
assert response.status_code == 200
data = response.json()
assert data["current"] == ARCHON_VERSION
assert data["latest"] == "0.2.0"
assert data["update_available"] is True
assert data["release_url"] == mock_version_data["release_url"]
def test_check_for_updates_no_update(client):
"""Test when no update is available."""
mock_data = {
"current": ARCHON_VERSION,
"latest": ARCHON_VERSION,
"update_available": False,
"release_url": None,
"release_notes": None,
"published_at": None,
"check_error": None,
}
with patch("src.server.api_routes.version_api.version_service") as mock_service:
mock_service.check_for_updates = AsyncMock(return_value=mock_data)
response = client.get("/api/version/check")
assert response.status_code == 200
data = response.json()
assert data["current"] == ARCHON_VERSION
assert data["latest"] == ARCHON_VERSION
assert data["update_available"] is False
def test_check_for_updates_with_etag_modified(client, mock_version_data):
"""Test ETag handling when data has changed."""
with patch("src.server.api_routes.version_api.version_service") as mock_service:
mock_service.check_for_updates = AsyncMock(return_value=mock_version_data)
# First request
response1 = client.get("/api/version/check")
assert response1.status_code == 200
old_etag = response1.headers.get("etag")
# Modify data
modified_data = mock_version_data.copy()
modified_data["latest"] = "0.3.0"
mock_service.check_for_updates = AsyncMock(return_value=modified_data)
# Second request with old ETag
response2 = client.get("/api/version/check", headers={"If-None-Match": old_etag})
assert response2.status_code == 200 # Data changed, return new data
data = response2.json()
assert data["latest"] == "0.3.0"
def test_check_for_updates_error_handling(client):
"""Test error handling in version check."""
with patch("src.server.api_routes.version_api.version_service") as mock_service:
mock_service.check_for_updates = AsyncMock(side_effect=Exception("API error"))
response = client.get("/api/version/check")
assert response.status_code == 200 # Should still return 200
data = response.json()
assert data["current"] == ARCHON_VERSION
assert data["latest"] is None
assert data["update_available"] is False
assert data["check_error"] == "API error"
def test_get_current_version(client):
"""Test getting current version."""
response = client.get("/api/version/current")
assert response.status_code == 200
data = response.json()
assert data["version"] == ARCHON_VERSION
assert "timestamp" in data
def test_clear_version_cache_success(client):
"""Test clearing version cache."""
with patch("src.server.api_routes.version_api.version_service") as mock_service:
mock_service.clear_cache.return_value = None
response = client.post("/api/version/clear-cache")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["message"] == "Version cache cleared successfully"
mock_service.clear_cache.assert_called_once()
def test_clear_version_cache_error(client):
"""Test error handling when clearing cache fails."""
with patch("src.server.api_routes.version_api.version_service") as mock_service:
mock_service.clear_cache.side_effect = Exception("Cache error")
response = client.post("/api/version/clear-cache")
assert response.status_code == 500
assert "Failed to clear cache" in response.json()["detail"]

View File

@@ -0,0 +1,271 @@
"""
Fixed unit tests for migration_service.py
"""
import hashlib
from datetime import datetime
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from src.server.config.version import ARCHON_VERSION
from src.server.services.migration_service import (
MigrationRecord,
MigrationService,
PendingMigration,
)
@pytest.fixture
def migration_service():
"""Create a migration service instance."""
with patch("src.server.services.migration_service.Path.exists") as mock_exists:
# Mock that migration directory exists locally
mock_exists.return_value = False # Docker path doesn't exist
service = MigrationService()
return service
@pytest.fixture
def mock_supabase_client():
"""Mock Supabase client."""
client = MagicMock()
return client
def test_pending_migration_init():
"""Test PendingMigration initialization and checksum calculation."""
migration = PendingMigration(
version="0.1.0",
name="001_initial",
sql_content="CREATE TABLE test (id INT);",
file_path="migration/0.1.0/001_initial.sql"
)
assert migration.version == "0.1.0"
assert migration.name == "001_initial"
assert migration.sql_content == "CREATE TABLE test (id INT);"
assert migration.file_path == "migration/0.1.0/001_initial.sql"
assert migration.checksum == hashlib.md5("CREATE TABLE test (id INT);".encode()).hexdigest()
def test_migration_record_init():
"""Test MigrationRecord initialization from database data."""
data = {
"id": "123-456",
"version": "0.1.0",
"migration_name": "001_initial",
"applied_at": "2025-01-01T00:00:00Z",
"checksum": "abc123"
}
record = MigrationRecord(data)
assert record.id == "123-456"
assert record.version == "0.1.0"
assert record.migration_name == "001_initial"
assert record.applied_at == "2025-01-01T00:00:00Z"
assert record.checksum == "abc123"
def test_migration_service_init_local():
"""Test MigrationService initialization with local path."""
with patch("src.server.services.migration_service.Path.exists") as mock_exists:
# Mock that Docker path doesn't exist
mock_exists.return_value = False
service = MigrationService()
assert service._migrations_dir == Path("migration")
def test_migration_service_init_docker():
"""Test MigrationService initialization with Docker path."""
with patch("src.server.services.migration_service.Path.exists") as mock_exists:
# Mock that Docker path exists
mock_exists.return_value = True
service = MigrationService()
assert service._migrations_dir == Path("/app/migration")
@pytest.mark.asyncio
async def test_get_applied_migrations_success(migration_service, mock_supabase_client):
"""Test successful retrieval of applied migrations."""
mock_response = MagicMock()
mock_response.data = [
{
"id": "123",
"version": "0.1.0",
"migration_name": "001_initial",
"applied_at": "2025-01-01T00:00:00Z",
"checksum": "abc123",
},
]
mock_supabase_client.table.return_value.select.return_value.order.return_value.execute.return_value = mock_response
with patch.object(migration_service, '_get_supabase_client', return_value=mock_supabase_client):
with patch.object(migration_service, 'check_migrations_table_exists', return_value=True):
result = await migration_service.get_applied_migrations()
assert len(result) == 1
assert isinstance(result[0], MigrationRecord)
assert result[0].version == "0.1.0"
assert result[0].migration_name == "001_initial"
@pytest.mark.asyncio
async def test_get_applied_migrations_table_not_exists(migration_service, mock_supabase_client):
"""Test handling when migrations table doesn't exist."""
with patch.object(migration_service, '_get_supabase_client', return_value=mock_supabase_client):
with patch.object(migration_service, 'check_migrations_table_exists', return_value=False):
result = await migration_service.get_applied_migrations()
assert result == []
@pytest.mark.asyncio
async def test_get_pending_migrations_with_files(migration_service, mock_supabase_client):
"""Test getting pending migrations from filesystem."""
# Mock scan_migration_directory to return test migrations
mock_migrations = [
PendingMigration(
version="0.1.0",
name="001_initial",
sql_content="CREATE TABLE test;",
file_path="migration/0.1.0/001_initial.sql"
),
PendingMigration(
version="0.1.0",
name="002_update",
sql_content="ALTER TABLE test ADD col TEXT;",
file_path="migration/0.1.0/002_update.sql"
)
]
# Mock no applied migrations
with patch.object(migration_service, 'scan_migration_directory', return_value=mock_migrations):
with patch.object(migration_service, 'get_applied_migrations', return_value=[]):
result = await migration_service.get_pending_migrations()
assert len(result) == 2
assert all(isinstance(m, PendingMigration) for m in result)
assert result[0].name == "001_initial"
assert result[1].name == "002_update"
@pytest.mark.asyncio
async def test_get_pending_migrations_some_applied(migration_service, mock_supabase_client):
"""Test getting pending migrations when some are already applied."""
# Mock all migrations
mock_all_migrations = [
PendingMigration(
version="0.1.0",
name="001_initial",
sql_content="CREATE TABLE test;",
file_path="migration/0.1.0/001_initial.sql"
),
PendingMigration(
version="0.1.0",
name="002_update",
sql_content="ALTER TABLE test ADD col TEXT;",
file_path="migration/0.1.0/002_update.sql"
)
]
# Mock first migration as applied
mock_applied = [
MigrationRecord({
"version": "0.1.0",
"migration_name": "001_initial",
"applied_at": "2025-01-01T00:00:00Z",
"checksum": None
})
]
with patch.object(migration_service, 'scan_migration_directory', return_value=mock_all_migrations):
with patch.object(migration_service, 'get_applied_migrations', return_value=mock_applied):
with patch.object(migration_service, 'check_migrations_table_exists', return_value=True):
result = await migration_service.get_pending_migrations()
assert len(result) == 1
assert result[0].name == "002_update"
@pytest.mark.asyncio
async def test_get_migration_status_all_applied(migration_service, mock_supabase_client):
"""Test migration status when all migrations are applied."""
# Mock one migration file
mock_all_migrations = [
PendingMigration(
version="0.1.0",
name="001_initial",
sql_content="CREATE TABLE test;",
file_path="migration/0.1.0/001_initial.sql"
)
]
# Mock migration as applied
mock_applied = [
MigrationRecord({
"version": "0.1.0",
"migration_name": "001_initial",
"applied_at": "2025-01-01T00:00:00Z",
"checksum": None
})
]
with patch.object(migration_service, 'scan_migration_directory', return_value=mock_all_migrations):
with patch.object(migration_service, 'get_applied_migrations', return_value=mock_applied):
with patch.object(migration_service, 'check_migrations_table_exists', return_value=True):
result = await migration_service.get_migration_status()
assert result["current_version"] == ARCHON_VERSION
assert result["has_pending"] is False
assert result["bootstrap_required"] is False
assert result["pending_count"] == 0
assert result["applied_count"] == 1
@pytest.mark.asyncio
async def test_get_migration_status_bootstrap_required(migration_service, mock_supabase_client):
"""Test migration status when bootstrap is required (table doesn't exist)."""
# Mock migration files
mock_all_migrations = [
PendingMigration(
version="0.1.0",
name="001_initial",
sql_content="CREATE TABLE test;",
file_path="migration/0.1.0/001_initial.sql"
),
PendingMigration(
version="0.1.0",
name="002_update",
sql_content="ALTER TABLE test ADD col TEXT;",
file_path="migration/0.1.0/002_update.sql"
)
]
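# Without the tracking table, the service is assumed to flag bootstrap_required and treat every scanned file as pending.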
with patch.object(migration_service, 'scan_migration_directory', return_value=mock_all_migrations):
with patch.object(migration_service, 'get_applied_migrations', return_value=[]):
with patch.object(migration_service, 'check_migrations_table_exists', return_value=False):
result = await migration_service.get_migration_status()
assert result["bootstrap_required"] is True
assert result["has_pending"] is True
assert result["pending_count"] == 2
assert result["applied_count"] == 0
assert len(result["pending_migrations"]) == 2
@pytest.mark.asyncio
async def test_get_migration_status_no_files(migration_service, mock_supabase_client):
"""Test migration status when no migration files exist."""
with patch.object(migration_service, 'scan_migration_directory', return_value=[]):
with patch.object(migration_service, 'get_applied_migrations', return_value=[]):
with patch.object(migration_service, 'check_migrations_table_exists', return_value=True):
result = await migration_service.get_migration_status()
assert result["has_pending"] is False
assert result["pending_count"] == 0
assert len(result["pending_migrations"]) == 0

View File

@@ -0,0 +1,234 @@
"""
Unit tests for version_service.py
"""
import json
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from src.server.config.version import ARCHON_VERSION
from src.server.services.version_service import VersionService
@pytest.fixture
def version_service():
"""Create a fresh version service instance for each test."""
service = VersionService()
# Clear any cache from previous tests
service._cache = None
service._cache_time = None
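# The tests below poke these private attributes directly, so _cache/_cache_time are assumed to be stable internals of VersionService.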
return service
@pytest.fixture
def mock_release_data():
"""Mock GitHub release data."""
return {
"tag_name": "v0.2.0",
"name": "Archon v0.2.0",
"html_url": "https://github.com/coleam00/Archon/releases/tag/v0.2.0",
"body": "## Release Notes\n\nNew features and bug fixes",
"published_at": "2025-01-01T00:00:00Z",
"author": {"login": "coleam00"},
"assets": [
{
"name": "archon-v0.2.0.zip",
"size": 1024000,
"download_count": 100,
"browser_download_url": "https://github.com/coleam00/Archon/releases/download/v0.2.0/archon-v0.2.0.zip",
"content_type": "application/zip",
}
],
}
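# The shape above mirrors GitHub's "latest release" REST payload; only the fields the service is assumed to read
# (tag_name, html_url, body, published_at, author, assets) are included.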
@pytest.mark.asyncio
async def test_get_latest_release_success(version_service, mock_release_data):
"""Test successful fetching of latest release from GitHub."""
with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = mock_release_data
mock_client.get.return_value = mock_response
mock_client_class.return_value.__aenter__.return_value = mock_client
result = await version_service.get_latest_release()
assert result == mock_release_data
assert version_service._cache == mock_release_data
assert version_service._cache_time is not None
@pytest.mark.asyncio
async def test_get_latest_release_uses_cache(version_service, mock_release_data):
"""Test that cache is used when available and not expired."""
# Set up cache
version_service._cache = mock_release_data
version_service._cache_time = datetime.now()
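# A just-set timestamp is assumed to be within the cache TTL, so get_latest_release() should short-circuit without any HTTP call.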
with patch("httpx.AsyncClient") as mock_client_class:
result = await version_service.get_latest_release()
# Should not make HTTP request
mock_client_class.assert_not_called()
assert result == mock_release_data
@pytest.mark.asyncio
async def test_get_latest_release_cache_expired(version_service, mock_release_data):
"""Test that cache is refreshed when expired."""
# Set up expired cache
old_data = {"tag_name": "v0.1.0"}
version_service._cache = old_data
version_service._cache_time = datetime.now() - timedelta(hours=2)
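# Two hours is assumed to exceed the cache TTL, so a fresh GitHub request should be made.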
with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = mock_release_data
mock_client.get.return_value = mock_response
mock_client_class.return_value.__aenter__.return_value = mock_client
result = await version_service.get_latest_release()
# Should make new HTTP request
mock_client.get.assert_called_once()
assert result == mock_release_data
assert version_service._cache == mock_release_data
@pytest.mark.asyncio
async def test_get_latest_release_404(version_service):
"""Test handling of 404 (no releases)."""
with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_response = MagicMock()
mock_response.status_code = 404
mock_client.get.return_value = mock_response
mock_client_class.return_value.__aenter__.return_value = mock_client
result = await version_service.get_latest_release()
assert result is None
@pytest.mark.asyncio
async def test_get_latest_release_timeout(version_service, mock_release_data):
"""Test handling of timeout with cache fallback."""
# Set up cache
version_service._cache = mock_release_data
version_service._cache_time = datetime.now() - timedelta(hours=2) # Expired
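# Even though expired, the stale cache is assumed to serve as a fallback when the GitHub request times out.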
with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.get.side_effect = httpx.TimeoutException("Timeout")
mock_client_class.return_value.__aenter__.return_value = mock_client
result = await version_service.get_latest_release()
# Should return cached data
assert result == mock_release_data
@pytest.mark.asyncio
async def test_check_for_updates_new_version_available(version_service, mock_release_data):
"""Test when a new version is available."""
with patch.object(version_service, "get_latest_release", return_value=mock_release_data):
result = await version_service.check_for_updates()
assert result["current"] == ARCHON_VERSION
assert result["latest"] == "0.2.0"
assert result["update_available"] is True
assert result["release_url"] == mock_release_data["html_url"]
assert result["release_notes"] == mock_release_data["body"]
assert result["published_at"] == datetime.fromisoformat("2025-01-01T00:00:00+00:00")
assert result["author"] == "coleam00"
assert len(result["assets"]) == 1
@pytest.mark.asyncio
async def test_check_for_updates_same_version(version_service):
"""Test when current version is up to date."""
mock_data = {"tag_name": f"v{ARCHON_VERSION}", "html_url": "test_url", "body": "notes"}
with patch.object(version_service, "get_latest_release", return_value=mock_data):
result = await version_service.check_for_updates()
assert result["current"] == ARCHON_VERSION
assert result["latest"] == ARCHON_VERSION
assert result["update_available"] is False
@pytest.mark.asyncio
async def test_check_for_updates_no_release(version_service):
"""Test when no releases are found."""
with patch.object(version_service, "get_latest_release", return_value=None):
result = await version_service.check_for_updates()
assert result["current"] == ARCHON_VERSION
assert result["latest"] is None
assert result["update_available"] is False
assert result["release_url"] is None
@pytest.mark.asyncio
async def test_check_for_updates_parse_version(version_service, mock_release_data):
"""Test version parsing with and without 'v' prefix."""
# Test with 'v' prefix
mock_release_data["tag_name"] = "v1.2.3"
with patch.object(version_service, "get_latest_release", return_value=mock_release_data):
result = await version_service.check_for_updates()
assert result["latest"] == "1.2.3"
# Test without 'v' prefix
mock_release_data["tag_name"] = "2.0.0"
with patch.object(version_service, "get_latest_release", return_value=mock_release_data):
result = await version_service.check_for_updates()
assert result["latest"] == "2.0.0"
@pytest.mark.asyncio
async def test_check_for_updates_missing_fields(version_service):
"""Test handling of incomplete release data."""
mock_data = {"tag_name": "v0.2.0"} # Minimal data
with patch.object(version_service, "get_latest_release", return_value=mock_data):
result = await version_service.check_for_updates()
assert result["latest"] == "0.2.0"
assert result["release_url"] is None
assert result["release_notes"] is None
assert result["published_at"] is None
assert result["author"] is None
assert result["assets"] == [] # Empty list, not None
def test_clear_cache(version_service, mock_release_data):
"""Test cache clearing."""
# Set up cache
version_service._cache = mock_release_data
version_service._cache_time = datetime.now()
# Clear cache
version_service.clear_cache()
assert version_service._cache is None
assert version_service._cache_time is None
def test_is_newer_version():
"""Test version comparison logic using the utility function."""
from src.server.utils.semantic_version import is_newer_version
# Test various version comparisons
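# Argument order is assumed to be (current, latest): True only when latest is strictly newer under semantic-version ordering.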
assert is_newer_version("1.0.0", "2.0.0") is True
assert is_newer_version("2.0.0", "1.0.0") is False
assert is_newer_version("1.0.0", "1.0.0") is False
assert is_newer_version("1.0.0", "1.1.0") is True
assert is_newer_version("1.0.0", "1.0.1") is True
assert is_newer_version("1.2.3", "1.2.3") is False

View File

@@ -33,6 +33,12 @@ class AsyncContextManager:
class TestAsyncLLMProviderService:
"""Test suite for async LLM provider service functions"""
@staticmethod
def _make_mock_client():
client = MagicMock()
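# aclose must be an AsyncMock so it can be awaited; the provider's context manager presumably awaits
# client.aclose() on exit, and a plain MagicMock attribute would return a non-awaitable.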
client.aclose = AsyncMock()
return client
@pytest.fixture(autouse=True)
def clear_cache(self):
"""Clear cache before each test"""
@@ -69,7 +75,7 @@ class TestAsyncLLMProviderService:
return {
"provider": "ollama",
"api_key": "ollama",
"base_url": "http://localhost:11434/v1",
"base_url": "http://host.docker.internal:11434/v1",
"chat_model": "llama2",
"embedding_model": "nomic-embed-text",
}
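# host.docker.internal is assumed as the default so a containerized service can reach an Ollama instance
# running on the host; localhost would resolve to the container itself.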
@@ -98,7 +104,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
async with get_llm_client() as client:
@@ -121,13 +127,13 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
async with get_llm_client() as client:
assert client == mock_client
mock_openai.assert_called_once_with(
api_key="ollama", base_url="http://localhost:11434/v1"
api_key="ollama", base_url="http://host.docker.internal:11434/v1"
)
@pytest.mark.asyncio
@@ -143,7 +149,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
async with get_llm_client() as client:
@@ -166,7 +172,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
async with get_llm_client(provider="openai") as client:
@@ -194,7 +200,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
async with get_llm_client(use_embedding_provider=True) as client:
@@ -216,7 +222,7 @@ class TestAsyncLLMProviderService:
}
mock_credential_service.get_active_provider.return_value = config_without_key
mock_credential_service.get_credentials_by_category = AsyncMock(return_value={
"LLM_BASE_URL": "http://localhost:11434"
"LLM_BASE_URL": "http://host.docker.internal:11434"
})
with patch(
@@ -225,7 +231,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
# Should fallback to Ollama instead of raising an error
@@ -234,7 +240,7 @@ class TestAsyncLLMProviderService:
# Verify it created an Ollama client with correct params
mock_openai.assert_called_once_with(
api_key="ollama",
base_url="http://localhost:11434/v1"
base_url="http://host.docker.internal:11434/v1"
)
@pytest.mark.asyncio
@@ -426,7 +432,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
# First call should hit the credential service
@@ -464,7 +470,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
client_ref = None
@@ -474,13 +480,14 @@ class TestAsyncLLMProviderService:
# After context manager exits, should still have reference to client
assert client_ref == mock_client
mock_client.aclose.assert_awaited_once()
@pytest.mark.asyncio
async def test_multiple_providers_in_sequence(self, mock_credential_service):
"""Test creating clients for different providers in sequence"""
configs = [
{"provider": "openai", "api_key": "openai-key", "base_url": None},
{"provider": "ollama", "api_key": "ollama", "base_url": "http://localhost:11434/v1"},
{"provider": "ollama", "api_key": "ollama", "base_url": "http://host.docker.internal:11434/v1"},
{
"provider": "google",
"api_key": "google-key",
@@ -494,7 +501,7 @@ class TestAsyncLLMProviderService:
with patch(
"src.server.services.llm_provider_service.openai.AsyncOpenAI"
) as mock_openai:
mock_client = MagicMock()
mock_client = self._make_mock_client()
mock_openai.return_value = mock_client
for config in configs:

View File

@@ -104,13 +104,15 @@ class TestCodeExtractionSourceId:
)
# Verify the correct source_id was passed (now with cancellation_check parameter)
mock_extract.assert_called_once_with(
crawl_results,
url_to_full_document,
source_id, # This should be the third argument
None,
None # cancellation_check parameter
)
mock_extract.assert_called_once()
args, kwargs = mock_extract.call_args
assert args[0] == crawl_results
assert args[1] == url_to_full_document
assert args[2] == source_id
assert args[3] is None
assert args[4] is None
if len(args) > 5:
assert args[5] is None
assert result == 5
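# Asserting positional args individually (instead of assert_called_once_with) is presumably meant to keep the test
# stable if extra trailing parameters are appended to the extractor's signature.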
@pytest.mark.asyncio
@@ -174,4 +176,4 @@ class TestCodeExtractionSourceId:
import inspect
source = inspect.getsource(module)
assert "from urllib.parse import urlparse" not in source, \
"Should not import urlparse since we don't extract domain from URL anymore"
"Should not import urlparse since we don't extract domain from URL anymore"