mirror of https://github.com/coleam00/Archon.git (synced 2025-12-30 13:39:44 -05:00)
Merge branch 'main' into feature/automatic-discovery-llms-sitemap-430
@@ -16,7 +16,6 @@ import os
from urllib.parse import urljoin

import httpx

from mcp.server.fastmcp import Context, FastMCP

# Import service discovery for HTTP communication
@@ -78,15 +77,18 @@ def register_rag_tools(mcp: FastMCP):

    @mcp.tool()
    async def rag_search_knowledge_base(
-        ctx: Context, query: str, source_domain: str | None = None, match_count: int = 5
+        ctx: Context, query: str, source_id: str | None = None, match_count: int = 5
    ) -> str:
        """
        Search knowledge base for relevant content using RAG.

        Args:
-            query: Search query
-            source_domain: Optional domain filter (e.g., 'docs.anthropic.com').
-                Note: This is a domain name, not the source_id from get_available_sources.
+            query: Search query - Keep it SHORT and FOCUSED (2-5 keywords).
+                Good: "vector search", "authentication JWT", "React hooks"
+                Bad: "how to implement user authentication with JWT tokens in React with TypeScript and handle refresh tokens"
+            source_id: Optional source ID filter from rag_get_available_sources().
+                This is the 'id' field from available sources, NOT a URL or domain name.
+                Example: "src_1234abcd" not "docs.anthropic.com"
            match_count: Max results (default: 5)

        Returns:
@@ -102,8 +104,8 @@ def register_rag_tools(mcp: FastMCP):

        async with httpx.AsyncClient(timeout=timeout) as client:
            request_data = {"query": query, "match_count": match_count}
-            if source_domain:
-                request_data["source"] = source_domain
+            if source_id:
+                request_data["source"] = source_id

            response = await client.post(urljoin(api_url, "/api/rag/query"), json=request_data)

@@ -135,15 +137,18 @@ def register_rag_tools(mcp: FastMCP):

    @mcp.tool()
    async def rag_search_code_examples(
-        ctx: Context, query: str, source_domain: str | None = None, match_count: int = 5
+        ctx: Context, query: str, source_id: str | None = None, match_count: int = 5
    ) -> str:
        """
        Search for relevant code examples in the knowledge base.

        Args:
-            query: Search query
-            source_domain: Optional domain filter (e.g., 'docs.anthropic.com').
-                Note: This is a domain name, not the source_id from get_available_sources.
+            query: Search query - Keep it SHORT and FOCUSED (2-5 keywords).
+                Good: "React useState", "FastAPI middleware", "vector pgvector"
+                Bad: "React hooks useState useEffect useContext useReducer useMemo useCallback"
+            source_id: Optional source ID filter from rag_get_available_sources().
+                This is the 'id' field from available sources, NOT a URL or domain name.
+                Example: "src_1234abcd" not "docs.anthropic.com"
            match_count: Max results (default: 5)

        Returns:
@@ -159,8 +164,8 @@ def register_rag_tools(mcp: FastMCP):

        async with httpx.AsyncClient(timeout=timeout) as client:
            request_data = {"query": query, "match_count": match_count}
-            if source_domain:
-                request_data["source"] = source_domain
+            if source_id:
+                request_data["source"] = source_id

            # Call the dedicated code examples endpoint
            response = await client.post(
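The net effect of this change is that the `source` field sent to `/api/rag/query` now carries a source ID rather than a domain name. A minimal sketch of an equivalent direct call, hedged on the assumption that the API server listens on http://localhost:8181 (the port and source ID below are illustrative, not taken from this diff):

import asyncio
import httpx

async def search_knowledge_base() -> None:
    # Same payload shape the updated tool builds: a short, focused query plus
    # a source ID (the 'id' field from rag_get_available_sources), not a domain.
    request_data = {
        "query": "vector search pgvector",  # 2-5 focused keywords
        "match_count": 5,
        "source": "src_1234abcd",  # illustrative source ID
    }
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post("http://localhost:8181/api/rag/query", json=request_data)
        response.raise_for_status()
        print(response.json())

asyncio.run(search_knowledge_base())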
@@ -10,8 +10,8 @@ from typing import Any
from urllib.parse import urljoin

import httpx

from mcp.server.fastmcp import Context, FastMCP

from src.mcp_server.utils.error_handling import MCPErrorFormatter
from src.mcp_server.utils.timeout_config import get_default_timeout
from src.server.config.service_discovery import get_api_url
@@ -31,20 +31,20 @@ def truncate_text(text: str, max_length: int = MAX_DESCRIPTION_LENGTH) -> str:
def optimize_task_response(task: dict) -> dict:
    """Optimize task object for MCP response."""
    task = task.copy()  # Don't modify original

    # Truncate description if present
    if "description" in task and task["description"]:
        task["description"] = truncate_text(task["description"])

    # Replace arrays with counts
    if "sources" in task and isinstance(task["sources"], list):
        task["sources_count"] = len(task["sources"])
        del task["sources"]

    if "code_examples" in task and isinstance(task["code_examples"], list):
        task["code_examples_count"] = len(task["code_examples"])
        del task["code_examples"]

    return task
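A quick illustration of what this optimization does to a task payload (the field values are invented for the example):

task = {
    "id": "t-1",
    "description": "x" * 500,       # long description, will be truncated
    "sources": ["doc-a", "doc-b"],  # replaced by a count
    "code_examples": [],            # replaced by a count
}

optimized = optimize_task_response(task)
# optimized["description"] is cut to MAX_DESCRIPTION_LENGTH characters,
# optimized["sources_count"] == 2 and optimized["code_examples_count"] == 0,
# the original arrays are removed, and the input dict is left unmodified.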
@@ -88,12 +88,12 @@ def register_task_tools(mcp: FastMCP):
        try:
            api_url = get_api_url()
            timeout = get_default_timeout()

            # Single task get mode
            if task_id:
                async with httpx.AsyncClient(timeout=timeout) as client:
                    response = await client.get(urljoin(api_url, f"/api/tasks/{task_id}"))

                    if response.status_code == 200:
                        task = response.json()
                        # Don't optimize single task get - return full details
@@ -107,18 +107,18 @@ def register_task_tools(mcp: FastMCP):
                        )
                    else:
                        return MCPErrorFormatter.from_http_error(response, "get task")

            # List mode with search and filters
            params: dict[str, Any] = {
                "page": page,
                "per_page": per_page,
                "exclude_large_fields": True,  # Always exclude large fields in MCP responses
            }

            # Add search query if provided
            if query:
                params["q"] = query

            if filter_by == "project" and filter_value:
                # Use project-specific endpoint for project filtering
                url = urljoin(api_url, f"/api/projects/{filter_value}/tasks")
@@ -146,13 +146,13 @@ def register_task_tools(mcp: FastMCP):
                # No specific filters - get all tasks
                url = urljoin(api_url, "/api/tasks")
                params["include_closed"] = include_closed

            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.get(url, params=params)
                response.raise_for_status()

                result = response.json()

                # Normalize response format
                if isinstance(result, list):
                    tasks = result
@@ -176,10 +176,10 @@ def register_task_tools(mcp: FastMCP):
                        message="Invalid response type from API",
                        details={"response_type": type(result).__name__},
                    )

                # Optimize task responses
                optimized_tasks = [optimize_task_response(task) for task in tasks]

                return json.dumps({
                    "success": True,
                    "tasks": optimized_tasks,
@@ -187,7 +187,7 @@ def register_task_tools(mcp: FastMCP):
                    "count": len(optimized_tasks),
                    "query": query,  # Include search query in response
                })

        except httpx.RequestError as e:
            return MCPErrorFormatter.from_exception(
                e, "list tasks", {"filter_by": filter_by, "filter_value": filter_value}
@@ -211,13 +211,19 @@ def register_task_tools(mcp: FastMCP):
    ) -> str:
        """
        Manage tasks (consolidated: create/update/delete).

        TASK GRANULARITY GUIDANCE:
        - For feature-specific projects: Create detailed implementation tasks (setup, implement, test, document)
        - For codebase-wide projects: Create feature-level tasks
        - Default to more granular tasks when project scope is unclear
        - Each task should represent 30 minutes to 4 hours of work

        Args:
            action: "create" | "update" | "delete"
            task_id: Task UUID for update/delete
            project_id: Project UUID for create
            title: Task title text
-            description: Detailed task description
+            description: Detailed task description with clear completion criteria
            status: "todo" | "doing" | "review" | "done"
            assignee: String name of the assignee. Can be any agent name,
                "User" for human assignment, or custom agent identifiers
@@ -228,16 +234,17 @@ def register_task_tools(mcp: FastMCP):
            feature: Feature label for grouping

        Examples:
            manage_task("create", project_id="p-1", title="Fix auth bug", assignee="CodeAnalyzer-v2")
            manage_task("create", project_id="p-1", title="Research existing patterns", description="Study codebase for similar implementations")
            manage_task("create", project_id="p-1", title="Write unit tests", description="Cover all edge cases with 80% coverage")
            manage_task("update", task_id="t-1", status="doing", assignee="User")
            manage_task("delete", task_id="t-1")

        Returns: {success: bool, task?: object, message: string}
        """
        try:
            api_url = get_api_url()
            timeout = get_default_timeout()

            async with httpx.AsyncClient(timeout=timeout) as client:
                if action == "create":
                    if not project_id or not title:
@@ -246,7 +253,7 @@ def register_task_tools(mcp: FastMCP):
                            "project_id and title required for create",
                            suggestion="Provide both project_id and title"
                        )

                    response = await client.post(
                        urljoin(api_url, "/api/tasks"),
                        json={
@@ -260,15 +267,15 @@ def register_task_tools(mcp: FastMCP):
                            "code_examples": [],
                        },
                    )

                    if response.status_code == 200:
                        result = response.json()
                        task = result.get("task")

                        # Optimize task response
                        if task:
                            task = optimize_task_response(task)

                        return json.dumps({
                            "success": True,
                            "task": task,
@@ -277,7 +284,7 @@ def register_task_tools(mcp: FastMCP):
                        })
                    else:
                        return MCPErrorFormatter.from_http_error(response, "create task")

                elif action == "update":
                    if not task_id:
                        return MCPErrorFormatter.format_error(
@@ -285,7 +292,7 @@ def register_task_tools(mcp: FastMCP):
                            "task_id required for update",
                            suggestion="Provide task_id to update"
                        )

                    # Build update fields
                    update_fields = {}
                    if title is not None:
@@ -300,27 +307,27 @@ def register_task_tools(mcp: FastMCP):
                        update_fields["task_order"] = task_order
                    if feature is not None:
                        update_fields["feature"] = feature

                    if not update_fields:
                        return MCPErrorFormatter.format_error(
                            error_type="validation_error",
                            message="No fields to update",
                            suggestion="Provide at least one field to update",
                        )

                    response = await client.put(
                        urljoin(api_url, f"/api/tasks/{task_id}"),
                        json=update_fields
                    )

                    if response.status_code == 200:
                        result = response.json()
                        task = result.get("task")

                        # Optimize task response
                        if task:
                            task = optimize_task_response(task)

                        return json.dumps({
                            "success": True,
                            "task": task,
@@ -328,7 +335,7 @@ def register_task_tools(mcp: FastMCP):
                        })
                    else:
                        return MCPErrorFormatter.from_http_error(response, "update task")

                elif action == "delete":
                    if not task_id:
                        return MCPErrorFormatter.format_error(
@@ -336,11 +343,11 @@ def register_task_tools(mcp: FastMCP):
                            "task_id required for delete",
                            suggestion="Provide task_id to delete"
                        )

                    response = await client.delete(
                        urljoin(api_url, f"/api/tasks/{task_id}")
                    )

                    if response.status_code == 200:
                        result = response.json()
                        return json.dumps({
@@ -349,14 +356,14 @@ def register_task_tools(mcp: FastMCP):
                        })
                    else:
                        return MCPErrorFormatter.from_http_error(response, "delete task")

                else:
                    return MCPErrorFormatter.format_error(
                        "invalid_action",
                        f"Unknown action: {action}",
                        suggestion="Use 'create', 'update', or 'delete'"
                    )

        except httpx.RequestError as e:
            return MCPErrorFormatter.from_exception(
                e, f"{action} task", {"task_id": task_id, "project_id": project_id}
@@ -194,12 +194,30 @@ MCP_INSTRUCTIONS = """
## 🚨 CRITICAL RULES (ALWAYS FOLLOW)
1. **Task Management**: ALWAYS use Archon MCP tools for task management.
   - Combine with your local TODO tools for granular tracking
   - First TODO: Update Archon task status
   - Last TODO: Update Archon with findings/completion

2. **Research First**: Before implementing, use rag_search_knowledge_base and rag_search_code_examples
3. **Task-Driven Development**: Never code without checking current tasks first

## 🎯 Targeted Documentation Search

When searching specific documentation (very common!):
1. **Get available sources**: `rag_get_available_sources()` - Returns list with id, title, url
2. **Find source ID**: Match user's request to source title (e.g., "PydanticAI docs" -> find ID)
3. **Filter search**: `rag_search_knowledge_base(query="...", source_id="src_xxx", match_count=5)`

Examples:
- User: "Search the Supabase docs for vector functions"
  1. Call `rag_get_available_sources()`
  2. Find Supabase source ID from results (e.g., "src_abc123")
  3. Call `rag_search_knowledge_base(query="vector functions", source_id="src_abc123")`

- User: "Find authentication examples in the MCP documentation"
  1. Call `rag_get_available_sources()`
  2. Find MCP docs source ID
  3. Call `rag_search_code_examples(query="authentication", source_id="src_def456")`

IMPORTANT: Always use source_id (not URLs or domain names) for filtering!

## 📋 Core Workflow

### Task Management Cycle
@@ -215,9 +233,9 @@ MCP_INSTRUCTIONS = """

### Consolidated Task Tools (Optimized ~2 tools from 5)
- `list_tasks(query=None, task_id=None, filter_by=None, filter_value=None, per_page=10)`
-  - **Consolidated**: list + search + get in one tool
-  - **NEW**: Search with keyword query parameter
-  - **NEW**: task_id parameter for getting single task (full details)
+  - list + search + get in one tool
+  - Search with keyword query parameter (optional)
+  - task_id parameter for getting single task (full details)
  - Filter by status, project, or assignee
  - **Optimized**: Returns truncated descriptions and array counts (lists only)
  - **Default**: 10 items per page (was 50)
@@ -231,23 +249,38 @@ MCP_INSTRUCTIONS = """

## 🏗️ Project Management

-### Project Tools (Consolidated)
+### Project Tools
- `list_projects(project_id=None, query=None, page=1, per_page=10)`
  - List all projects, search by query, or get specific project by ID
- `manage_project(action, project_id=None, title=None, description=None, github_repo=None)`
  - Actions: "create", "update", "delete"

-### Document Tools (Consolidated)
+### Document Tools
- `list_documents(project_id, document_id=None, query=None, document_type=None, page=1, per_page=10)`
  - List project documents, search, filter by type, or get specific document
- `manage_document(action, project_id, document_id=None, title=None, document_type=None, content=None, ...)`
  - Actions: "create", "update", "delete"

## 🔍 Research Patterns
- **Architecture patterns**: `rag_search_knowledge_base(query="[tech] architecture patterns", match_count=5)`
- **Code examples**: `rag_search_code_examples(query="[feature] implementation", match_count=3)`
- **Source discovery**: `rag_get_available_sources()`
- Keep match_count around 3-5 for focused results

### CRITICAL: Keep Queries Short and Focused!
Vector search works best with 2-5 keywords, NOT long sentences or keyword dumps.

✅ GOOD Queries (concise, focused):
- `rag_search_knowledge_base(query="vector search pgvector")`
- `rag_search_code_examples(query="React useState")`
- `rag_search_knowledge_base(query="authentication JWT")`
- `rag_search_code_examples(query="FastAPI middleware")`

❌ BAD Queries (too long, unfocused):
- `rag_search_knowledge_base(query="how to implement vector search with pgvector in PostgreSQL for semantic similarity matching with OpenAI embeddings")`
- `rag_search_code_examples(query="React hooks useState useEffect useContext useReducer useMemo useCallback")`

### Query Construction Tips:
- Extract 2-5 most important keywords from the user's request
- Focus on technical terms and specific technologies
- Omit filler words like "how to", "implement", "create", "example"
- For multi-concept searches, do multiple focused queries instead of one broad query

## 📊 Task Status Flow
`todo` → `doing` → `review` → `done`
@@ -255,25 +288,26 @@ MCP_INSTRUCTIONS = """
- Use 'review' for completed work awaiting validation
- Mark tasks 'done' only after verification

## 💾 Version Management (Consolidated)
- `list_versions(project_id, field_name=None, version_number=None, page=1, per_page=10)`
  - List all versions, filter by field, or get specific version
- `manage_version(action, project_id, field_name, version_number=None, content=None, change_summary=None, ...)`
  - Actions: "create", "restore"
  - Field names: "docs", "features", "data", "prd"

-## 🎯 Best Practices
-1. **Atomic Tasks**: Create tasks that take 1-4 hours
-2. **Clear Descriptions**: Include acceptance criteria in task descriptions
-3. **Use Features**: Group related tasks with feature labels
-4. **Add Sources**: Link relevant documentation to tasks
-5. **Track Progress**: Update task status as you work
-
-## 📊 Optimization Updates
-- **Payload Optimization**: Tasks in lists return truncated descriptions (200 chars)
-- **Array Counts**: Source/example arrays replaced with counts in list responses
-- **Smart Defaults**: Default page size reduced from 50 to 10 items
-- **Search Support**: New `query` parameter in list_tasks for keyword search
+## 📝 Task Granularity Guidelines
+
+### Project Scope Determines Task Granularity
+
+**For Feature-Specific Projects** (project = single feature):
+Create granular implementation tasks:
+- "Set up development environment"
+- "Install required dependencies"
+- "Create database schema"
+- "Implement API endpoints"
+- "Add frontend components"
+- "Write unit tests"
+- "Add integration tests"
+- "Update documentation"
+
+**For Codebase-Wide Projects** (project = entire application):
+Create feature-level tasks:
+- "Implement user authentication feature"
+- "Add payment processing system"
+- "Create admin dashboard"
"""

# Initialize the main FastMCP server with fixed configuration
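To make the multi-query tip from the instructions above concrete, a small pseudo-code sketch (the MCP tools are written as plain function calls purely for illustration):

# One broad query buries the signal:
#   rag_search_knowledge_base(query="how to add JWT auth middleware to FastAPI with refresh tokens")
# Several focused queries retrieve each concept cleanly:
for focused_query in ("FastAPI middleware", "authentication JWT", "refresh tokens"):
    results = rag_search_knowledge_base(query=focused_query, match_count=3)
    # ...inspect or merge results per concept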
@@ -14,6 +14,7 @@ from .internal_api import router as internal_router
from .knowledge_api import router as knowledge_router
from .mcp_api import router as mcp_router
from .projects_api import router as projects_router
+from .providers_api import router as providers_router
from .settings_api import router as settings_router

__all__ = [
@@ -23,4 +24,5 @@ __all__ = [
    "projects_router",
    "agent_chat_router",
    "internal_router",
+    "providers_router",
]
@@ -18,6 +18,8 @@ from urllib.parse import urlparse
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from pydantic import BaseModel

# Basic validation - simplified inline version

# Import unified logging
from ..config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info
from ..services.crawler_manager import get_crawler
@@ -62,26 +64,59 @@ async def _validate_provider_api_key(provider: str = None) -> None:
    logger.info("🔑 Starting API key validation...")

    try:
        # Basic provider validation
        if not provider:
            provider = "openai"
        else:
            # Simple provider validation
            allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"}
            if provider not in allowed_providers:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Invalid provider name",
                        "message": f"Provider '{provider}' not supported",
                        "error_type": "validation_error"
                    }
                )

-        logger.info(f"🔑 Testing {provider.title()} API key with minimal embedding request...")
-
-        # Test API key with minimal embedding request - this will fail if key is invalid
-        from ..services.embeddings.embedding_service import create_embedding
-        test_result = await create_embedding(text="test")
-
-        if not test_result:
-            logger.error(f"❌ {provider.title()} API key validation failed - no embedding returned")
-            raise HTTPException(
-                status_code=401,
-                detail={
-                    "error": f"Invalid {provider.title()} API key",
-                    "message": f"Please verify your {provider.title()} API key in Settings.",
-                    "error_type": "authentication_failed",
-                    "provider": provider
-                }
-            )
+        # Basic sanitization for logging
+        safe_provider = provider[:20]  # Limit length
+        logger.info(f"🔑 Testing {safe_provider.title()} API key with minimal embedding request...")
+
+        try:
+            # Test API key with minimal embedding request using provider-scoped configuration
+            from ..services.embeddings.embedding_service import create_embedding
+
+            test_result = await create_embedding(text="test", provider=provider)
+
+            if not test_result:
+                logger.error(
+                    f"❌ {provider.title()} API key validation failed - no embedding returned"
+                )
+                raise HTTPException(
+                    status_code=401,
+                    detail={
+                        "error": f"Invalid {provider.title()} API key",
+                        "message": f"Please verify your {provider.title()} API key in Settings.",
+                        "error_type": "authentication_failed",
+                        "provider": provider,
+                    },
+                )
+        except Exception as e:
+            logger.error(
+                f"❌ {provider.title()} API key validation failed: {e}",
+                exc_info=True,
+            )
+            raise HTTPException(
+                status_code=401,
+                detail={
+                    "error": f"Invalid {provider.title()} API key",
+                    "message": f"Please verify your {provider.title()} API key in Settings. Error: {str(e)[:100]}",
+                    "error_type": "authentication_failed",
+                    "provider": provider,
+                },
+            )

        logger.info(f"✅ {provider.title()} API key validation successful")
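As a sketch of how this guard is meant to be used, a route can call the validator before kicking off expensive work; the route path below is illustrative, not taken from this diff:

from fastapi import APIRouter

router = APIRouter()

@router.post("/api/knowledge/crawl")  # illustrative route
async def start_crawl(provider: str | None = None):
    # Raises HTTPException(400) for unknown providers and HTTPException(401)
    # when the test embedding fails, so bad keys are rejected up front.
    await _validate_provider_api_key(provider)
    return {"status": "started"}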
python/src/server/api_routes/migration_api.py (new file, 170 lines)
@@ -0,0 +1,170 @@
"""
API routes for database migration tracking and management.
"""

from datetime import datetime

import logfire
from fastapi import APIRouter, Header, HTTPException, Response
from pydantic import BaseModel

from ..config.version import ARCHON_VERSION
from ..services.migration_service import migration_service
from ..utils.etag_utils import check_etag, generate_etag


# Response models
class MigrationRecord(BaseModel):
    """Represents an applied migration."""

    version: str
    migration_name: str
    applied_at: datetime
    checksum: str | None = None


class PendingMigration(BaseModel):
    """Represents a pending migration."""

    version: str
    name: str
    sql_content: str
    file_path: str
    checksum: str | None = None


class MigrationStatusResponse(BaseModel):
    """Complete migration status response."""

    pending_migrations: list[PendingMigration]
    applied_migrations: list[MigrationRecord]
    has_pending: bool
    bootstrap_required: bool
    current_version: str
    pending_count: int
    applied_count: int


class MigrationHistoryResponse(BaseModel):
    """Migration history response."""

    migrations: list[MigrationRecord]
    total_count: int
    current_version: str


# Create router
router = APIRouter(prefix="/api/migrations", tags=["migrations"])


@router.get("/status", response_model=MigrationStatusResponse)
async def get_migration_status(
    response: Response, if_none_match: str | None = Header(None)
):
    """
    Get current migration status including pending and applied migrations.

    Returns comprehensive migration status with:
    - List of pending migrations with SQL content
    - List of applied migrations
    - Bootstrap flag if migrations table doesn't exist
    - Current version information
    """
    try:
        # Get migration status from service
        status = await migration_service.get_migration_status()

        # Generate ETag for response
        etag = generate_etag(status)

        # Check if client has current data
        if check_etag(if_none_match, etag):
            # Client has current data, return 304
            response.status_code = 304
            response.headers["ETag"] = f'"{etag}"'
            response.headers["Cache-Control"] = "no-cache, must-revalidate"
            return Response(status_code=304)
        else:
            # Client needs new data
            response.headers["ETag"] = f'"{etag}"'
            response.headers["Cache-Control"] = "no-cache, must-revalidate"
            return MigrationStatusResponse(**status)

    except Exception as e:
        logfire.error(f"Error getting migration status: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get migration status: {str(e)}") from e


@router.get("/history", response_model=MigrationHistoryResponse)
async def get_migration_history(response: Response, if_none_match: str | None = Header(None)):
    """
    Get history of applied migrations.

    Returns list of all applied migrations sorted by date.
    """
    try:
        # Get applied migrations from service
        applied = await migration_service.get_applied_migrations()

        # Format response
        history = {
            "migrations": [
                MigrationRecord(
                    version=m.version,
                    migration_name=m.migration_name,
                    applied_at=m.applied_at,
                    checksum=m.checksum,
                )
                for m in applied
            ],
            "total_count": len(applied),
            "current_version": ARCHON_VERSION,
        }

        # Generate ETag for response
        etag = generate_etag(history)

        # Check if client has current data
        if check_etag(if_none_match, etag):
            # Client has current data, return 304
            response.status_code = 304
            response.headers["ETag"] = f'"{etag}"'
            response.headers["Cache-Control"] = "no-cache, must-revalidate"
            return Response(status_code=304)
        else:
            # Client needs new data
            response.headers["ETag"] = f'"{etag}"'
            response.headers["Cache-Control"] = "no-cache, must-revalidate"
            return MigrationHistoryResponse(**history)

    except Exception as e:
        logfire.error(f"Error getting migration history: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get migration history: {str(e)}") from e


@router.get("/pending", response_model=list[PendingMigration])
async def get_pending_migrations():
    """
    Get list of pending migrations only.

    Returns simplified list of migrations that need to be applied.
    """
    try:
        # Get pending migrations from service
        pending = await migration_service.get_pending_migrations()

        # Format response
        return [
            PendingMigration(
                version=m.version,
                name=m.name,
                sql_content=m.sql_content,
                file_path=m.file_path,
                checksum=m.checksum,
            )
            for m in pending
        ]

    except Exception as e:
        logfire.error(f"Error getting pending migrations: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get pending migrations: {str(e)}") from e
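These endpoints implement standard HTTP conditional requests: the handler computes an ETag over the payload and returns 304 when the client's If-None-Match header matches. A minimal client sketch, assuming the API server listens on http://localhost:8181:

import httpx

etag: str | None = None

with httpx.Client(base_url="http://localhost:8181") as client:
    headers = {"If-None-Match": etag} if etag else {}
    resp = client.get("/api/migrations/status", headers=headers)

    if resp.status_code == 304:
        print("Migration status unchanged; reuse the cached copy")
    else:
        etag = resp.headers.get("ETag")  # remember for the next poll
        status = resp.json()
        print(f"{status['pending_count']} pending, {status['applied_count']} applied")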
python/src/server/api_routes/providers_api.py (new file, 154 lines)
@@ -0,0 +1,154 @@
"""
Provider status API endpoints for testing connectivity

Handles server-side provider connectivity testing without exposing API keys to frontend.
"""

import httpx
from fastapi import APIRouter, HTTPException, Path

from ..config.logfire_config import logfire
from ..services.credential_service import credential_service

# Provider validation - simplified inline version

router = APIRouter(prefix="/api/providers", tags=["providers"])


async def test_openai_connection(api_key: str) -> bool:
    """Test OpenAI API connectivity"""
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(
                "https://api.openai.com/v1/models",
                headers={"Authorization": f"Bearer {api_key}"}
            )
            return response.status_code == 200
    except Exception as e:
        logfire.warning(f"OpenAI connectivity test failed: {e}")
        return False


async def test_google_connection(api_key: str) -> bool:
    """Test Google AI API connectivity"""
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(
                "https://generativelanguage.googleapis.com/v1/models",
                headers={"x-goog-api-key": api_key}
            )
            return response.status_code == 200
    except Exception:
        logfire.warning("Google AI connectivity test failed")
        return False


async def test_anthropic_connection(api_key: str) -> bool:
    """Test Anthropic API connectivity"""
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(
                "https://api.anthropic.com/v1/models",
                headers={
                    "x-api-key": api_key,
                    "anthropic-version": "2023-06-01"
                }
            )
            return response.status_code == 200
    except Exception as e:
        logfire.warning(f"Anthropic connectivity test failed: {e}")
        return False


async def test_openrouter_connection(api_key: str) -> bool:
    """Test OpenRouter API connectivity"""
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(
                "https://openrouter.ai/api/v1/models",
                headers={"Authorization": f"Bearer {api_key}"}
            )
            return response.status_code == 200
    except Exception as e:
        logfire.warning(f"OpenRouter connectivity test failed: {e}")
        return False


async def test_grok_connection(api_key: str) -> bool:
    """Test Grok API connectivity"""
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(
                "https://api.x.ai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"}
            )
            return response.status_code == 200
    except Exception as e:
        logfire.warning(f"Grok connectivity test failed: {e}")
        return False


PROVIDER_TESTERS = {
    "openai": test_openai_connection,
    "google": test_google_connection,
    "anthropic": test_anthropic_connection,
    "openrouter": test_openrouter_connection,
    "grok": test_grok_connection,
}


@router.get("/{provider}/status")
async def get_provider_status(
    provider: str = Path(
        ...,
        description="Provider name to test connectivity for",
        regex="^[a-z0-9_]+$",
        max_length=20
    )
):
    """Test provider connectivity using server-side API key (secure)"""
    try:
        # Basic provider validation
        allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"}
        if provider not in allowed_providers:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid provider '{provider}'. Allowed providers: {sorted(allowed_providers)}"
            )

        # Basic sanitization for logging
        safe_provider = provider[:20]  # Limit length
        logfire.info(f"Testing {safe_provider} connectivity server-side")

        if provider not in PROVIDER_TESTERS:
            raise HTTPException(
                status_code=400,
                detail=f"Provider '{provider}' not supported for connectivity testing"
            )

        # Get API key server-side (never expose to client)
        key_name = f"{provider.upper()}_API_KEY"
        api_key = await credential_service.get_credential(key_name, decrypt=True)

        if not api_key or not isinstance(api_key, str) or not api_key.strip():
            logfire.info(f"No API key configured for {safe_provider}")
            return {"ok": False, "reason": "no_key"}

        # Test connectivity using server-side key
        tester = PROVIDER_TESTERS[provider]
        is_connected = await tester(api_key)

        logfire.info(f"{safe_provider} connectivity test result: {is_connected}")
        return {
            "ok": is_connected,
            "reason": "connected" if is_connected else "connection_failed",
            "provider": provider  # Echo back validated provider name
        }

    except HTTPException:
        # Re-raise HTTP exceptions (they're already properly formatted)
        raise
    except Exception as e:
        # Basic error sanitization for logging
        safe_error = str(e)[:100]  # Limit length
        logfire.error(f"Error testing {provider[:20]} connectivity: {safe_error}")
        raise HTTPException(status_code=500, detail={"error": "Internal server error during connectivity test"})
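The tester registry keeps the endpoint easy to extend: each provider maps to an async function that takes an API key and returns a bool. A hypothetical sketch of adding another OpenAI-compatible provider (the name and URL are invented for illustration, and the provider would also need to be added to allowed_providers):

async def test_example_connection(api_key: str) -> bool:
    """Hypothetical tester for an OpenAI-compatible provider."""
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(
                "https://api.example-llm.com/v1/models",  # invented URL
                headers={"Authorization": f"Bearer {api_key}"}
            )
            return response.status_code == 200
    except Exception as e:
        logfire.warning(f"Example provider connectivity test failed: {e}")
        return False

PROVIDER_TESTERS["example"] = test_example_connection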
python/src/server/api_routes/version_api.py (new file, 121 lines)
@@ -0,0 +1,121 @@
"""
API routes for version checking and update management.
"""

from datetime import datetime
from typing import Any

import logfire
from fastapi import APIRouter, Header, HTTPException, Response
from pydantic import BaseModel

from ..config.version import ARCHON_VERSION
from ..services.version_service import version_service
from ..utils.etag_utils import check_etag, generate_etag


# Response models
class ReleaseAsset(BaseModel):
    """Represents a downloadable asset from a release."""

    name: str
    size: int
    download_count: int
    browser_download_url: str
    content_type: str


class VersionCheckResponse(BaseModel):
    """Version check response with update information."""

    current: str
    latest: str | None
    update_available: bool
    release_url: str | None
    release_notes: str | None
    published_at: datetime | None
    check_error: str | None = None
    assets: list[dict[str, Any]] | None = None
    author: str | None = None


class CurrentVersionResponse(BaseModel):
    """Simple current version response."""

    version: str
    timestamp: datetime


# Create router
router = APIRouter(prefix="/api/version", tags=["version"])


@router.get("/check", response_model=VersionCheckResponse)
async def check_for_updates(response: Response, if_none_match: str | None = Header(None)):
    """
    Check for available Archon updates.

    Queries GitHub releases API to determine if a newer version is available.
    Results are cached for 1 hour to avoid rate limiting.

    Returns:
        Version information including current, latest, and update availability
    """
    try:
        # Get version check results from service
        result = await version_service.check_for_updates()

        # Generate ETag for response
        etag = generate_etag(result)

        # Check if client has current data
        if check_etag(if_none_match, etag):
            # Client has current data, return 304
            response.status_code = 304
            response.headers["ETag"] = f'"{etag}"'
            response.headers["Cache-Control"] = "no-cache, must-revalidate"
            return Response(status_code=304)
        else:
            # Client needs new data
            response.headers["ETag"] = f'"{etag}"'
            response.headers["Cache-Control"] = "no-cache, must-revalidate"
            return VersionCheckResponse(**result)

    except Exception as e:
        logfire.error(f"Error checking for updates: {e}")
        # Return safe response with error
        return VersionCheckResponse(
            current=ARCHON_VERSION,
            latest=None,
            update_available=False,
            release_url=None,
            release_notes=None,
            published_at=None,
            check_error=str(e),
        )


@router.get("/current", response_model=CurrentVersionResponse)
async def get_current_version():
    """
    Get the current Archon version.

    Simple endpoint that returns the installed version without checking for updates.
    """
    return CurrentVersionResponse(version=ARCHON_VERSION, timestamp=datetime.now())


@router.post("/clear-cache")
async def clear_version_cache():
    """
    Clear the version check cache.

    Forces the next version check to query GitHub API instead of using cached data.
    Useful for testing or forcing an immediate update check.
    """
    try:
        version_service.clear_cache()
        return {"message": "Version cache cleared successfully", "success": True}
    except Exception as e:
        logfire.error(f"Error clearing version cache: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}") from e
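A small usage sketch for these endpoints, again assuming a local server on http://localhost:8181:

import httpx

with httpx.Client(base_url="http://localhost:8181") as client:
    # Force a fresh check against GitHub, then read the result.
    client.post("/api/version/clear-cache")
    info = client.get("/api/version/check").json()

    if info["update_available"]:
        print(f"Update available: {info['current']} -> {info['latest']}")
    else:
        print(f"Archon {info['current']} is up to date")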
python/src/server/config/version.py (new file, 11 lines)
@@ -0,0 +1,11 @@
"""
Version configuration for Archon.
"""

# Current version of Archon
# Update this with each release
ARCHON_VERSION = "0.1.0"

# Repository information for GitHub API
GITHUB_REPO_OWNER = "coleam00"
GITHUB_REPO_NAME = "Archon"
@@ -23,9 +23,12 @@ from .api_routes.bug_report_api import router as bug_report_router
from .api_routes.internal_api import router as internal_router
from .api_routes.knowledge_api import router as knowledge_router
from .api_routes.mcp_api import router as mcp_router
+from .api_routes.migration_api import router as migration_router
from .api_routes.ollama_api import router as ollama_router
from .api_routes.progress_api import router as progress_router
from .api_routes.projects_api import router as projects_router
+from .api_routes.providers_api import router as providers_router
+from .api_routes.version_api import router as version_router

# Import modular API routers
from .api_routes.settings_api import router as settings_router
@@ -186,6 +189,9 @@ app.include_router(progress_router)
app.include_router(agent_chat_router)
app.include_router(internal_router)
app.include_router(bug_report_router)
+app.include_router(providers_router)
+app.include_router(version_router)
+app.include_router(migration_router)


# Root endpoint
@@ -139,6 +139,7 @@ class CodeExtractionService:
        source_id: str,
        progress_callback: Callable | None = None,
        cancellation_check: Callable[[], None] | None = None,
+        provider: str | None = None,
    ) -> int:
        """
        Extract code examples from crawled documents and store them.
@@ -204,7 +205,7 @@ class CodeExtractionService:

        # Generate summaries for code blocks
        summary_results = await self._generate_code_summaries(
-            all_code_blocks, summary_callback, cancellation_check
+            all_code_blocks, summary_callback, cancellation_check, provider
        )

        # Prepare code examples for storage
@@ -223,7 +224,7 @@ class CodeExtractionService:

        # Store code examples in database
        return await self._store_code_examples(
-            storage_data, url_to_full_document, storage_callback
+            storage_data, url_to_full_document, storage_callback, provider
        )

    async def _extract_code_blocks_from_documents(
@@ -1523,6 +1524,7 @@ class CodeExtractionService:
        all_code_blocks: list[dict[str, Any]],
        progress_callback: Callable | None = None,
        cancellation_check: Callable[[], None] | None = None,
+        provider: str | None = None,
    ) -> list[dict[str, str]]:
        """
        Generate summaries for all code blocks.
@@ -1587,7 +1589,7 @@ class CodeExtractionService:

        try:
            results = await generate_code_summaries_batch(
-                code_blocks_for_summaries, max_workers, progress_callback=summary_progress_callback
+                code_blocks_for_summaries, max_workers, progress_callback=summary_progress_callback, provider=provider
            )

            # Ensure all results are valid dicts
@@ -1667,6 +1669,7 @@ class CodeExtractionService:
        storage_data: dict[str, list[Any]],
        url_to_full_document: dict[str, str],
        progress_callback: Callable | None = None,
+        provider: str | None = None,
    ) -> int:
        """
        Store code examples in the database.
@@ -1709,7 +1712,7 @@ class CodeExtractionService:
            batch_size=20,
            url_to_full_document=url_to_full_document,
            progress_callback=storage_progress_callback,
-            provider=None,  # Use configured provider
+            provider=provider,
        )

        # Report completion of code extraction/storage phase
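Taken together, these hunks thread an optional provider override from the public entry point down to summary generation and storage. A hedged sketch of what a caller gains (to run inside an async context; the argument values are illustrative):

# Pin the LLM used for code summaries for this one crawl instead of
# relying on the globally configured provider.
count = await code_extraction_service.extract_and_store_code_examples(
    crawl_results,
    url_to_full_document,
    source_id="src_1234abcd",   # illustrative
    progress_callback=None,
    cancellation_check=None,
    provider="google",          # overrides the configured default for this run
)
print(f"Stored {count} code examples")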
@@ -75,10 +75,11 @@ class CrawlingService:
        self.url_handler = URLHandler()
        self.site_config = SiteConfig()
        self.markdown_generator = self.site_config.get_markdown_generator()
+        self.link_pruning_markdown_generator = self.site_config.get_link_pruning_markdown_generator()

        # Initialize strategies
-        self.batch_strategy = BatchCrawlStrategy(crawler, self.markdown_generator)
-        self.recursive_strategy = RecursiveCrawlStrategy(crawler, self.markdown_generator)
+        self.batch_strategy = BatchCrawlStrategy(crawler, self.link_pruning_markdown_generator)
+        self.recursive_strategy = RecursiveCrawlStrategy(crawler, self.link_pruning_markdown_generator)
        self.single_page_strategy = SinglePageCrawlStrategy(crawler, self.markdown_generator)
        self.sitemap_strategy = SitemapCrawlStrategy()

@@ -551,12 +552,24 @@ class CrawlingService:
            )

        try:
+            # Extract provider from request or use credential service default
+            provider = request.get("provider")
+            if not provider:
+                try:
+                    from ..credential_service import credential_service
+                    provider_config = await credential_service.get_active_provider("llm")
+                    provider = provider_config.get("provider", "openai")
+                except Exception as e:
+                    logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai")
+                    provider = "openai"
+
            code_examples_count = await self.doc_storage_ops.extract_and_store_code_examples(
                crawl_results,
                storage_results["url_to_full_document"],
                storage_results["source_id"],
                code_progress_callback,
                self._check_cancellation,
+                provider,
            )
        except RuntimeError as e:
            # Code extraction failed, continue crawl with warning
@@ -351,6 +351,7 @@ class DocumentStorageOperations:
        source_id: str,
        progress_callback: Callable | None = None,
        cancellation_check: Callable[[], None] | None = None,
+        provider: str | None = None,
    ) -> int:
        """
        Extract code examples from crawled documents and store them.
@@ -361,12 +362,13 @@ class DocumentStorageOperations:
            source_id: The unique source_id for all documents
            progress_callback: Optional callback for progress updates
            cancellation_check: Optional function to check for cancellation
+            provider: Optional LLM provider to use for code summaries

        Returns:
            Number of code examples stored
        """
        result = await self.code_extraction_service.extract_and_store_code_examples(
-            crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check
+            crawl_results, url_to_full_document, source_id, progress_callback, cancellation_check, provider
        )

        return result
@@ -4,6 +4,7 @@ Site Configuration Helper

Handles site-specific configurations and detection.
"""
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
+from crawl4ai.content_filter_strategy import PruningContentFilter

from ....config.logfire_config import get_logger

@@ -96,3 +97,33 @@ class SiteConfig:
                "code_language_callback": lambda el: el.get('class', '').replace('language-', '') if el else ''
            }
        )

    @staticmethod
    def get_link_pruning_markdown_generator():
        """
        Get a markdown generator for the recursive crawling strategy that cleans up crawled pages.

        Returns:
            Configured markdown generator
        """
        prune_filter = PruningContentFilter(
            threshold=0.2,
            threshold_type="fixed"
        )

        return DefaultMarkdownGenerator(
            content_source="html",  # Use raw HTML to preserve code blocks
            content_filter=prune_filter,
            options={
                "mark_code": True,  # Mark code blocks properly
                "handle_code_in_pre": True,  # Handle <pre><code> tags
                "body_width": 0,  # No line wrapping
                "skip_internal_links": True,  # Add to reduce noise
                "include_raw_html": False,  # Prevent HTML in markdown
                "escape": False,  # Don't escape special chars in code
                "decode_unicode": True,  # Decode unicode characters
                "strip_empty_lines": False,  # Preserve empty lines in code
                "preserve_code_formatting": True,  # Custom option if supported
                "code_language_callback": lambda el: el.get('class', '').replace('language-', '') if el else ''
            }
        )
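A minimal sketch of how this generator plugs into a crawl4ai run, assuming crawl4ai's standard AsyncWebCrawler/CrawlerRunConfig API; because a content filter is attached, the pruned output lands in result.markdown.fit_markdown, which the strategy changes below start checking:

import asyncio
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig

async def main() -> None:
    config = CrawlerRunConfig(
        markdown_generator=SiteConfig.get_link_pruning_markdown_generator()
    )
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(url="https://example.com/docs", config=config)
        if result.success and result.markdown and result.markdown.fit_markdown:
            # fit_markdown holds the pruned content; the raw markdown keeps everything
            print(result.markdown.fit_markdown[:500])

asyncio.run(main())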
@@ -231,12 +231,12 @@ class BatchCrawlStrategy:
                    raise

                processed += 1
-                if result.success and result.markdown:
+                if result.success and result.markdown and result.markdown.fit_markdown:
                    # Map back to original URL
                    original_url = url_mapping.get(result.url, result.url)
                    successful_results.append({
                        "url": original_url,
-                        "markdown": result.markdown,
+                        "markdown": result.markdown.fit_markdown,
                        "html": result.html,  # Use raw HTML
                    })
                else:

@@ -276,10 +276,10 @@ class RecursiveCrawlStrategy:
                    visited.add(norm_url)
                    total_processed += 1

-                    if result.success and result.markdown:
+                    if result.success and result.markdown and result.markdown.fit_markdown:
                        results_all.append({
                            "url": original_url,
-                            "markdown": result.markdown,
+                            "markdown": result.markdown.fit_markdown,
                            "html": result.html,  # Always use raw HTML for code extraction
                        })
                        depth_successful += 1
@@ -36,6 +36,44 @@ class CredentialItem:
    description: str | None = None


def _detect_embedding_provider_from_model(embedding_model: str) -> str:
    """
    Detect the appropriate embedding provider based on model name.

    Args:
        embedding_model: The embedding model name

    Returns:
        Provider name: 'google' or 'openai' (default)
    """
    if not embedding_model:
        return "openai"  # Default

    model_lower = embedding_model.lower()

    # Google embedding models
    google_patterns = [
        "text-embedding-004",
        "text-embedding-005",
        "text-multilingual-embedding",
        "gemini-embedding",
        "multimodalembedding"
    ]

    if any(pattern in model_lower for pattern in google_patterns):
        return "google"

    # OpenAI embedding models (and default for unknown)
    openai_patterns = [
        "text-embedding-ada-002",
        "text-embedding-3-small",
        "text-embedding-3-large"
    ]

    # Default to OpenAI for OpenAI models or unknown models
    return "openai"

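A quick demonstration of the detection heuristic (the model names are examples of the listed patterns):

assert _detect_embedding_provider_from_model("text-embedding-004") == "google"
assert _detect_embedding_provider_from_model("gemini-embedding-001") == "google"
assert _detect_embedding_provider_from_model("text-embedding-3-small") == "openai"
assert _detect_embedding_provider_from_model("") == "openai"  # default
assert _detect_embedding_provider_from_model("some-custom-model") == "openai"  # unknown -> OpenAI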
class CredentialService:
    """Service for managing application credentials and configuration."""

@@ -239,6 +277,14 @@ class CredentialService:
            self._rag_cache_timestamp = None
            logger.debug(f"Invalidated RAG settings cache due to update of {key}")

+            # Also invalidate provider service cache to ensure immediate effect
+            try:
+                from .llm_provider_service import clear_provider_cache
+                clear_provider_cache()
+                logger.debug("Also cleared LLM provider service cache")
+            except Exception as e:
+                logger.warning(f"Failed to clear provider service cache: {e}")
+
            # Also invalidate LLM provider service cache for provider config
            try:
                from . import llm_provider_service
@@ -281,6 +327,14 @@ class CredentialService:
            self._rag_cache_timestamp = None
            logger.debug(f"Invalidated RAG settings cache due to deletion of {key}")

+            # Also invalidate provider service cache to ensure immediate effect
+            try:
+                from .llm_provider_service import clear_provider_cache
+                clear_provider_cache()
+                logger.debug("Also cleared LLM provider service cache")
+            except Exception as e:
+                logger.warning(f"Failed to clear provider service cache: {e}")
+
            # Also invalidate LLM provider service cache for provider config
            try:
                from . import llm_provider_service
@@ -419,8 +473,33 @@ class CredentialService:
        # Get RAG strategy settings (where UI saves provider selection)
        rag_settings = await self.get_credentials_by_category("rag_strategy")

-        # Get the selected provider
-        provider = rag_settings.get("LLM_PROVIDER", "openai")
+        # Get the selected provider based on service type
+        if service_type == "embedding":
+            # Get the LLM provider setting to determine embedding provider
+            llm_provider = rag_settings.get("LLM_PROVIDER", "openai")
+            embedding_model = rag_settings.get("EMBEDDING_MODEL", "text-embedding-3-small")
+
+            # Determine embedding provider based on LLM provider
+            if llm_provider == "google":
+                provider = "google"
+            elif llm_provider == "ollama":
+                provider = "ollama"
+            elif llm_provider == "openrouter":
+                # OpenRouter supports both OpenAI and Google embedding models
+                provider = _detect_embedding_provider_from_model(embedding_model)
+            elif llm_provider in ["anthropic", "grok"]:
+                # Anthropic and Grok support both OpenAI and Google embedding models
+                provider = _detect_embedding_provider_from_model(embedding_model)
+            else:
+                # Default case (openai, or unknown providers)
+                provider = "openai"
+
+            logger.debug(f"Determined embedding provider '{provider}' from LLM provider '{llm_provider}' and embedding model '{embedding_model}'")
+        else:
+            provider = rag_settings.get("LLM_PROVIDER", "openai")
        # Ensure provider is a valid string, not a boolean or other type
        if not isinstance(provider, str) or provider.lower() in ("true", "false", "none", "null"):
            provider = "openai"

        # Get API key for this provider
        api_key = await self._get_provider_api_key(provider)
@@ -464,6 +543,9 @@ class CredentialService:
        key_mapping = {
            "openai": "OPENAI_API_KEY",
            "google": "GOOGLE_API_KEY",
            "openrouter": "OPENROUTER_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "grok": "GROK_API_KEY",
            "ollama": None,  # No API key needed
        }

@@ -475,9 +557,15 @@ class CredentialService:
    def _get_provider_base_url(self, provider: str, rag_settings: dict) -> str | None:
        """Get base URL for provider."""
        if provider == "ollama":
-            return rag_settings.get("LLM_BASE_URL", "http://localhost:11434/v1")
+            return rag_settings.get("LLM_BASE_URL", "http://host.docker.internal:11434/v1")
        elif provider == "google":
            return "https://generativelanguage.googleapis.com/v1beta/openai/"
        elif provider == "openrouter":
            return "https://openrouter.ai/api/v1"
        elif provider == "anthropic":
            return "https://api.anthropic.com/v1"
        elif provider == "grok":
            return "https://api.x.ai/v1"
        return None  # Use default for OpenAI

    async def set_active_provider(self, provider: str, service_type: str = "llm") -> bool:
@@ -485,7 +573,7 @@ class CredentialService:
        try:
            # For now, we'll update the RAG strategy settings
            return await self.set_credential(
-                "llm_provider",
+                "LLM_PROVIDER",
                provider,
                category="rag_strategy",
                description=f"Active {service_type} provider",
@@ -10,7 +10,13 @@ import os
import openai

from ...config.logfire_config import search_logger
-from ..llm_provider_service import get_llm_client
+from ..credential_service import credential_service
+from ..llm_provider_service import (
+    extract_message_text,
+    get_llm_client,
+    prepare_chat_completion_params,
+    requires_max_completion_tokens,
+)
from ..threading_service import get_threading_service

@@ -32,8 +38,6 @@ async def generate_contextual_embedding(
    """
    # Model choice is a RAG setting, get from credential service
    try:
-        from ...services.credential_service import credential_service
-
        model_choice = await credential_service.get_credential("MODEL_CHOICE", "gpt-4.1-nano")
    except Exception as e:
        # Fallback to environment variable or default
@@ -65,20 +69,25 @@ Please give a short succinct context to situate this chunk within the overall do
        # Get model from provider configuration
        model = await _get_model_choice(provider)

-        response = await client.chat.completions.create(
-            model=model,
-            messages=[
+        # Prepare parameters and convert max_tokens for GPT-5/reasoning models
+        params = {
+            "model": model,
+            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant that provides concise contextual information.",
                },
                {"role": "user", "content": prompt},
            ],
-            temperature=0.3,
-            max_tokens=200,
-        )
+            "temperature": 0.3,
+            "max_tokens": 1200 if requires_max_completion_tokens(model) else 200,  # Much more tokens for reasoning models (GPT-5 needs extra for reasoning process)
+        }
+        final_params = prepare_chat_completion_params(model, params)
+        response = await client.chat.completions.create(**final_params)

-        context = response.choices[0].message.content.strip()
+        choice = response.choices[0] if response.choices else None
+        context, _, _ = extract_message_text(choice)
+        context = context.strip()
        contextual_text = f"{context}\n---\n{chunk}"

        return contextual_text, True
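The switch from response.choices[0].message.content.strip() to extract_message_text guards against empty choice lists and None content. For intuition, a minimal stand-in might look like this (an assumption for illustration only; the real helper lives in llm_provider_service and may return richer metadata in the other tuple slots):

def extract_message_text_sketch(choice) -> tuple[str, str | None, str | None]:
    # Hypothetical simplification: tolerate a missing choice or empty content
    # instead of raising AttributeError the way the old chained access could.
    if choice is None or choice.message is None or choice.message.content is None:
        return "", None, None
    return choice.message.content, getattr(choice, "finish_reason", None), None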
@@ -111,7 +120,7 @@ async def process_chunk_with_context(


async def _get_model_choice(provider: str | None = None) -> str:
    """Get model choice from credential service."""
    """Get model choice from credential service with centralized defaults."""
    from ..credential_service import credential_service

    # Get the active provider configuration
@@ -119,31 +128,36 @@ async def _get_model_choice(provider: str | None = None) -> str:
    model = provider_config.get("chat_model", "").strip()  # Strip whitespace
    provider_name = provider_config.get("provider", "openai")

    # Handle empty model case - fallback to provider-specific defaults or explicit config
    # Handle empty model case - use centralized defaults
    if not model:
        search_logger.warning(f"chat_model is empty for provider {provider_name}, using fallback logic")
        search_logger.warning(f"chat_model is empty for provider {provider_name}, using centralized defaults")

        # Special handling for Ollama to check specific credential
        if provider_name == "ollama":
            # Try to get OLLAMA_CHAT_MODEL specifically
            try:
                ollama_model = await credential_service.get_credential("OLLAMA_CHAT_MODEL")
                if ollama_model and ollama_model.strip():
                    model = ollama_model.strip()
                    search_logger.info(f"Using OLLAMA_CHAT_MODEL fallback: {model}")
                else:
                    # Use a sensible Ollama default
                    # Use default for Ollama
                    model = "llama3.2:latest"
                    search_logger.info(f"Using Ollama default model: {model}")
                    search_logger.info(f"Using Ollama default: {model}")
            except Exception as e:
                search_logger.error(f"Error getting OLLAMA_CHAT_MODEL: {e}")
                model = "llama3.2:latest"
                search_logger.info(f"Using Ollama fallback model: {model}")
        elif provider_name == "google":
            model = "gemini-1.5-flash"
                search_logger.info(f"Using Ollama fallback: {model}")
        else:
            # OpenAI or other providers
            model = "gpt-4o-mini"

            # Use provider-specific defaults
            provider_defaults = {
                "openai": "gpt-4o-mini",
                "openrouter": "anthropic/claude-3.5-sonnet",
                "google": "gemini-1.5-flash",
                "anthropic": "claude-3-5-haiku-20241022",
                "grok": "grok-3-mini"
            }
            model = provider_defaults.get(provider_name, "gpt-4o-mini")
            search_logger.debug(f"Using default model for provider {provider_name}: {model}")
    search_logger.debug(f"Using model from credential service: {model}")

    return model
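The resolution order implemented above is: explicit chat_model, then OLLAMA_CHAT_MODEL (Ollama only), then the provider default, then "gpt-4o-mini". A tiny illustration of the final lookup, using the defaults table copied from the diff:

provider_defaults = {
    "openai": "gpt-4o-mini",
    "openrouter": "anthropic/claude-3.5-sonnet",
    "google": "gemini-1.5-flash",
    "anthropic": "claude-3-5-haiku-20241022",
    "grok": "grok-3-mini",
}
assert provider_defaults.get("grok", "gpt-4o-mini") == "grok-3-mini"
assert provider_defaults.get("mistral", "gpt-4o-mini") == "gpt-4o-mini"  # unknown provider falls back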
@@ -174,38 +188,48 @@ async def generate_contextual_embeddings_batch(
        model_choice = await _get_model_choice(provider)

        # Build batch prompt for ALL chunks at once
        batch_prompt = (
            "Process the following chunks and provide contextual information for each:\\n\\n"
        )
        batch_prompt = "Process the following chunks and provide contextual information for each:\n\n"

        for i, (doc, chunk) in enumerate(zip(full_documents, chunks, strict=False)):
            # Use only 2000 chars of document context to save tokens
            doc_preview = doc[:2000] if len(doc) > 2000 else doc
            batch_prompt += f"CHUNK {i + 1}:\\n"
            batch_prompt += f"<document_preview>\\n{doc_preview}\\n</document_preview>\\n"
            batch_prompt += f"<chunk>\\n{chunk[:500]}\\n</chunk>\\n\\n"  # Limit chunk preview
            batch_prompt += f"CHUNK {i + 1}:\n"
            batch_prompt += f"<document_preview>\n{doc_preview}\n</document_preview>\n"
            batch_prompt += f"<chunk>\n{chunk[:500]}\n</chunk>\n\n"  # Limit chunk preview

        batch_prompt += "For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. Format your response as:\\nCHUNK 1: [context]\\nCHUNK 2: [context]\\netc."
        batch_prompt += (
            "For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. "
            "Format your response as:\nCHUNK 1: [context]\nCHUNK 2: [context]\netc."
        )
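For reference, with two chunks the assembled prompt has roughly this shape (contents abbreviated, values illustrative):

Process the following chunks and provide contextual information for each:

CHUNK 1:
<document_preview>
...up to 2000 characters of the source document...
</document_preview>
<chunk>
...up to 500 characters of chunk 1...
</chunk>

CHUNK 2:
...

For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. Format your response as:
CHUNK 1: [context]
CHUNK 2: [context]
etc.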
        # Make single API call for ALL chunks
        response = await client.chat.completions.create(
            model=model_choice,
            messages=[
        # Prepare parameters and convert max_tokens for GPT-5/reasoning models
        batch_params = {
            "model": model_choice,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant that generates contextual information for document chunks.",
                },
                {"role": "user", "content": batch_prompt},
            ],
            temperature=0,
            max_tokens=100 * len(chunks),  # Limit response size
        )
            "temperature": 0,
            "max_tokens": (600 if requires_max_completion_tokens(model_choice) else 100) * len(chunks),  # Much more tokens for reasoning models (GPT-5 needs extra reasoning space)
        }
        final_batch_params = prepare_chat_completion_params(model_choice, batch_params)
        response = await client.chat.completions.create(**final_batch_params)
        # Parse response
        response_text = response.choices[0].message.content
        choice = response.choices[0] if response.choices else None
        response_text, _, _ = extract_message_text(choice)
        if not response_text:
            search_logger.error(
                "Empty response from LLM when generating contextual embeddings batch"
            )
            return [(chunk, False) for chunk in chunks]

        # Extract contexts from response
        lines = response_text.strip().split("\\n")
        lines = response_text.strip().split("\n")
        chunk_contexts = {}

        for line in lines:
@@ -245,4 +269,4 @@ async def generate_contextual_embeddings_batch(
    except Exception as e:
        search_logger.error(f"Error in contextual embedding batch: {e}")
        # Return non-contextual for all chunks
        return [(chunk, False) for chunk in chunks]
    return [(chunk, False) for chunk in chunks]
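The parsing loop body is elided by the hunk boundary above. A minimal self-contained sketch of parsing the "CHUNK n: context" lines the prompt requests (the function name is hypothetical, not the code in this file):

import re

def parse_chunk_contexts(response_text: str) -> dict[int, str]:
    contexts: dict[int, str] = {}
    for line in response_text.strip().split("\n"):
        match = re.match(r"^CHUNK\s+(\d+):\s*(.+)$", line.strip())
        if match:
            contexts[int(match.group(1)) - 1] = match.group(2)  # 0-based chunk index
    return contexts

assert parse_chunk_contexts("CHUNK 1: intro section\nCHUNK 2: API notes") == {
    0: "intro section",
    1: "API notes",
}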
@@ -13,7 +13,7 @@ import openai

from ...config.logfire_config import safe_span, search_logger
from ..credential_service import credential_service
from ..llm_provider_service import get_embedding_model, get_llm_client
from ..llm_provider_service import get_embedding_model, get_llm_client, is_google_embedding_model, is_openai_embedding_model
from ..threading_service import get_threading_service
from .embedding_exceptions import (
    EmbeddingAPIError,
@@ -152,34 +152,56 @@ async def create_embeddings_batch(
    if not texts:
        return EmbeddingBatchResult()

    result = EmbeddingBatchResult()

    # Validate that all items in texts are strings
    validated_texts = []
    for i, text in enumerate(texts):
        if not isinstance(text, str):
            search_logger.error(
                f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True
            )
            # Try to convert to string
            try:
                validated_texts.append(str(text))
            except Exception as e:
                search_logger.error(
                    f"Failed to convert text at index {i} to string: {e}", exc_info=True
                )
                validated_texts.append("")  # Use empty string as fallback
        else:
        if isinstance(text, str):
            validated_texts.append(text)
            continue

        search_logger.error(
            f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True
        )
        try:
            converted = str(text)
            validated_texts.append(converted)
        except Exception as conversion_error:
            search_logger.error(
                f"Failed to convert text at index {i} to string: {conversion_error}",
                exc_info=True,
            )
            result.add_failure(
                repr(text),
                EmbeddingAPIError("Invalid text type", original_error=conversion_error),
                batch_index=None,
            )

    texts = validated_texts

    result = EmbeddingBatchResult()
    threading_service = get_threading_service()
    with safe_span(
        "create_embeddings_batch", text_count=len(texts), total_chars=sum(len(t) for t in texts)
    ) as span:
        try:
            async with get_llm_client(provider=provider, use_embedding_provider=True) as client:
            # Intelligent embedding provider routing based on model type
            # Get the embedding model first to determine the correct provider
            embedding_model = await get_embedding_model(provider=provider)

            # Route to correct provider based on model type
            if is_google_embedding_model(embedding_model):
                embedding_provider = "google"
                search_logger.info(f"Routing to Google for embedding model: {embedding_model}")
            elif is_openai_embedding_model(embedding_model) or "openai/" in embedding_model.lower():
                embedding_provider = "openai"
                search_logger.info(f"Routing to OpenAI for embedding model: {embedding_model}")
            else:
                # Keep original provider for ollama and other providers
                embedding_provider = provider
                search_logger.info(f"Using original provider '{provider}' for embedding model: {embedding_model}")

            async with get_llm_client(provider=embedding_provider, use_embedding_provider=True) as client:
                # Load batch size and dimensions from settings
                try:
                    rag_settings = await credential_service.get_credentials_by_category(
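The routing above depends on two model-type checks whose implementations live in llm_provider_service and are not part of this diff. A hedged sketch of plausible name-based checks; the exact patterns are assumptions:

def is_openai_embedding_model_sketch(model: str) -> bool:
    m = model.lower()
    # Assumption: OpenAI embedding models follow the text-embedding-3-* / ada-002 naming.
    return m.startswith("text-embedding-3") or m == "text-embedding-ada-002"

def is_google_embedding_model_sketch(model: str) -> bool:
    m = model.lower()
    # Assumption: Google embedding models carry "gemini" or the text-embedding-004 suffix.
    return "gemini" in m or m.endswith("text-embedding-004")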
@@ -220,7 +242,8 @@ async def create_embeddings_batch(
                while retry_count < max_retries:
                    try:
                        # Create embeddings for this batch
                        embedding_model = await get_embedding_model(provider=provider)
                        embedding_model = await get_embedding_model(provider=embedding_provider)

                        response = await client.embeddings.create(
                            model=embedding_model,
                            input=batch,
File diff suppressed because it is too large

python/src/server/services/migration_service.py (new file, 233 lines)
@@ -0,0 +1,233 @@
"""
|
||||
Database migration tracking and management service.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import logfire
|
||||
from supabase import Client
|
||||
|
||||
from .client_manager import get_supabase_client
|
||||
from ..config.version import ARCHON_VERSION
|
||||
|
||||
|
||||
class MigrationRecord:
|
||||
"""Represents a migration record from the database."""
|
||||
|
||||
def __init__(self, data: dict[str, Any]):
|
||||
self.id = data.get("id")
|
||||
self.version = data.get("version")
|
||||
self.migration_name = data.get("migration_name")
|
||||
self.applied_at = data.get("applied_at")
|
||||
self.checksum = data.get("checksum")
|
||||
|
||||
|
||||
class PendingMigration:
|
||||
"""Represents a pending migration from the filesystem."""
|
||||
|
||||
def __init__(self, version: str, name: str, sql_content: str, file_path: str):
|
||||
self.version = version
|
||||
self.name = name
|
||||
self.sql_content = sql_content
|
||||
self.file_path = file_path
|
||||
self.checksum = self._calculate_checksum(sql_content)
|
||||
|
||||
def _calculate_checksum(self, content: str) -> str:
|
||||
"""Calculate MD5 checksum of migration content."""
|
||||
return hashlib.md5(content.encode()).hexdigest()
|
||||
|
||||
|
||||
class MigrationService:
|
||||
"""Service for managing database migrations."""
|
||||
|
||||
def __init__(self):
|
||||
self._supabase: Client | None = None
|
||||
# Handle both Docker (/app/migration) and local (./migration) environments
|
||||
if Path("/app/migration").exists():
|
||||
self._migrations_dir = Path("/app/migration")
|
||||
else:
|
||||
self._migrations_dir = Path("migration")
|
||||
|
||||
def _get_supabase_client(self) -> Client:
|
||||
"""Get or create Supabase client."""
|
||||
if not self._supabase:
|
||||
self._supabase = get_supabase_client()
|
||||
return self._supabase
|
||||
|
||||
async def check_migrations_table_exists(self) -> bool:
|
||||
"""
|
||||
Check if the archon_migrations table exists in the database.
|
||||
|
||||
Returns:
|
||||
True if table exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
supabase = self._get_supabase_client()
|
||||
|
||||
# Query to check if table exists
|
||||
result = supabase.rpc(
|
||||
"sql",
|
||||
{
|
||||
"query": """
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'archon_migrations'
|
||||
) as exists
|
||||
"""
|
||||
}
|
||||
).execute()
|
||||
|
||||
# Check if result indicates table exists
|
||||
if result.data and len(result.data) > 0:
|
||||
return result.data[0].get("exists", False)
|
||||
return False
|
||||
except Exception:
|
||||
# If the SQL function doesn't exist or query fails, try direct query
|
||||
try:
|
||||
supabase = self._get_supabase_client()
|
||||
# Try to select from the table with limit 0
|
||||
supabase.table("archon_migrations").select("id").limit(0).execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
logfire.info(f"Migrations table does not exist: {e}")
|
||||
return False
|
||||
|
||||
async def get_applied_migrations(self) -> list[MigrationRecord]:
|
||||
"""
|
||||
Get list of applied migrations from the database.
|
||||
|
||||
Returns:
|
||||
List of MigrationRecord objects
|
||||
"""
|
||||
try:
|
||||
# Check if table exists first
|
||||
if not await self.check_migrations_table_exists():
|
||||
logfire.info("Migrations table does not exist, returning empty list")
|
||||
return []
|
||||
|
||||
supabase = self._get_supabase_client()
|
||||
result = supabase.table("archon_migrations").select("*").order("applied_at", desc=True).execute()
|
||||
|
||||
return [MigrationRecord(row) for row in result.data]
|
||||
except Exception as e:
|
||||
logfire.error(f"Error fetching applied migrations: {e}")
|
||||
# Return empty list if we can't fetch migrations
|
||||
return []
|
||||
|
||||
async def scan_migration_directory(self) -> list[PendingMigration]:
|
||||
"""
|
||||
Scan the migration directory for all SQL files.
|
||||
|
||||
Returns:
|
||||
List of PendingMigration objects
|
||||
"""
|
||||
migrations = []
|
||||
|
||||
if not self._migrations_dir.exists():
|
||||
logfire.warning(f"Migration directory does not exist: {self._migrations_dir}")
|
||||
return migrations
|
||||
|
||||
# Scan all version directories
|
||||
for version_dir in sorted(self._migrations_dir.iterdir()):
|
||||
if not version_dir.is_dir():
|
||||
continue
|
||||
|
||||
version = version_dir.name
|
||||
|
||||
# Scan all SQL files in version directory
|
||||
for sql_file in sorted(version_dir.glob("*.sql")):
|
||||
try:
|
||||
# Read SQL content
|
||||
with open(sql_file, encoding="utf-8") as f:
|
||||
sql_content = f.read()
|
||||
|
||||
# Extract migration name (filename without extension)
|
||||
migration_name = sql_file.stem
|
||||
|
||||
# Create pending migration object
|
||||
migration = PendingMigration(
|
||||
version=version,
|
||||
name=migration_name,
|
||||
sql_content=sql_content,
|
||||
file_path=str(sql_file.relative_to(Path.cwd())),
|
||||
)
|
||||
migrations.append(migration)
|
||||
except Exception as e:
|
||||
logfire.error(f"Error reading migration file {sql_file}: {e}")
|
||||
|
||||
return migrations
|
||||
|
||||
async def get_pending_migrations(self) -> list[PendingMigration]:
|
||||
"""
|
||||
Get list of pending migrations by comparing filesystem with database.
|
||||
|
||||
Returns:
|
||||
List of PendingMigration objects that haven't been applied
|
||||
"""
|
||||
# Get all migrations from filesystem
|
||||
all_migrations = await self.scan_migration_directory()
|
||||
|
||||
# Check if migrations table exists
|
||||
if not await self.check_migrations_table_exists():
|
||||
# Bootstrap case - all migrations are pending
|
||||
logfire.info("Migrations table doesn't exist, all migrations are pending")
|
||||
return all_migrations
|
||||
|
||||
# Get applied migrations from database
|
||||
applied_migrations = await self.get_applied_migrations()
|
||||
|
||||
# Create set of applied migration identifiers
|
||||
applied_set = {(m.version, m.migration_name) for m in applied_migrations}
|
||||
|
||||
# Filter out applied migrations
|
||||
pending = [m for m in all_migrations if (m.version, m.name) not in applied_set]
|
||||
|
||||
return pending
|
||||
|
||||
async def get_migration_status(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get comprehensive migration status.
|
||||
|
||||
Returns:
|
||||
Dictionary with pending and applied migrations info
|
||||
"""
|
||||
pending = await self.get_pending_migrations()
|
||||
applied = await self.get_applied_migrations()
|
||||
|
||||
# Check if bootstrap is required
|
||||
bootstrap_required = not await self.check_migrations_table_exists()
|
||||
|
||||
return {
|
||||
"pending_migrations": [
|
||||
{
|
||||
"version": m.version,
|
||||
"name": m.name,
|
||||
"sql_content": m.sql_content,
|
||||
"file_path": m.file_path,
|
||||
"checksum": m.checksum,
|
||||
}
|
||||
for m in pending
|
||||
],
|
||||
"applied_migrations": [
|
||||
{
|
||||
"version": m.version,
|
||||
"migration_name": m.migration_name,
|
||||
"applied_at": m.applied_at,
|
||||
"checksum": m.checksum,
|
||||
}
|
||||
for m in applied
|
||||
],
|
||||
"has_pending": len(pending) > 0,
|
||||
"bootstrap_required": bootstrap_required,
|
||||
"current_version": ARCHON_VERSION,
|
||||
"pending_count": len(pending),
|
||||
"applied_count": len(applied),
|
||||
}
|
||||
|
||||
|
||||
# Export singleton instance
|
||||
migration_service = MigrationService()
|
||||
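A sketch of how the singleton above might be consumed, for example from an API route; the import path is assumed from the file location in this diff, and the route wiring itself is outside the change:

import asyncio

from src.server.services.migration_service import migration_service  # path assumed

async def main() -> None:
    status = await migration_service.get_migration_status()
    if status["bootstrap_required"]:
        print("No archon_migrations table yet - run the bootstrap SQL first.")
    for m in status["pending_migrations"]:
        print(f"pending: {m['version']}/{m['name']} ({m['checksum'][:8]})")

asyncio.run(main())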
@@ -2,7 +2,7 @@
Provider Discovery Service

Discovers available models, checks provider health, and provides model specifications
for OpenAI, Google Gemini, Ollama, and Anthropic providers.
for OpenAI, Google Gemini, Ollama, Anthropic, and Grok providers.
"""

import time
@@ -23,7 +23,7 @@ _provider_cache: dict[str, tuple[Any, float]] = {}
_CACHE_TTL_SECONDS = 300  # 5 minutes

# Default Ollama instance URL (configurable via environment/settings)
DEFAULT_OLLAMA_URL = "http://localhost:11434"
DEFAULT_OLLAMA_URL = "http://host.docker.internal:11434"

# Model pattern detection for dynamic capabilities (no hardcoded model names)
CHAT_MODEL_PATTERNS = ["llama", "qwen", "mistral", "codellama", "phi", "gemma", "vicuna", "orca"]
@@ -359,6 +359,36 @@ class ProviderDiscoveryService:

        return models

    async def discover_grok_models(self, api_key: str) -> list[ModelSpec]:
        """Discover available Grok models."""
        cache_key = f"grok_models_{hash(api_key)}"
        cached = self._get_cached_result(cache_key)
        if cached:
            return cached

        models = []
        try:
            # Grok model specifications
            model_specs = [
                ModelSpec("grok-3-mini", "grok", 32768, True, True, False, None, 0.15, 0.60, "Fast and efficient Grok model"),
                ModelSpec("grok-3", "grok", 32768, True, True, False, None, 2.00, 10.00, "Standard Grok model"),
                ModelSpec("grok-4", "grok", 32768, True, True, False, None, 5.00, 25.00, "Advanced Grok model"),
                ModelSpec("grok-2-vision", "grok", 8192, True, True, True, None, 3.00, 15.00, "Grok model with vision capabilities"),
                ModelSpec("grok-2-latest", "grok", 8192, True, True, False, None, 2.00, 10.00, "Latest Grok 2 model"),
            ]

            # Test connectivity - Grok doesn't have a models list endpoint,
            # so we'll just return the known models if API key is provided
            if api_key:
                models = model_specs
                self._cache_result(cache_key, models)
                logger.info(f"Discovered {len(models)} Grok models")

        except Exception as e:
            logger.error(f"Error discovering Grok models: {e}")

        return models
    async def check_provider_health(self, provider: str, config: dict[str, Any]) -> ProviderStatus:
        """Check health and connectivity status of a provider."""
        start_time = time.time()
@@ -456,6 +486,23 @@ class ProviderDiscoveryService:
                    last_checked=time.time()
                )

            elif provider == "grok":
                api_key = config.get("api_key")
                if not api_key:
                    return ProviderStatus(provider, False, None, "API key not configured")

                # Grok doesn't have a health check endpoint, so we'll assume it's available
                # if API key is provided. In a real implementation, you might want to make a
                # small test request to verify the key is valid.
                response_time = (time.time() - start_time) * 1000
                return ProviderStatus(
                    provider="grok",
                    is_available=True,
                    response_time_ms=response_time,
                    models_available=5,  # Known model count
                    last_checked=time.time()
                )

            else:
                return ProviderStatus(provider, False, None, f"Unknown provider: {provider}")

@@ -496,6 +543,11 @@ class ProviderDiscoveryService:
            if anthropic_key:
                providers["anthropic"] = await self.discover_anthropic_models(anthropic_key)

            # Grok
            grok_key = await credential_service.get_credential("GROK_API_KEY")
            if grok_key:
                providers["grok"] = await self.discover_grok_models(grok_key)

        except Exception as e:
            logger.error(f"Error getting all available models: {e}")
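The comment in the Grok branch leaves real key validation open. One hedged sketch of a small test request, using the xAI base URL and a model name already present in this diff; the endpoint shape (OpenAI-compatible chat completions) and error handling are assumptions:

import httpx

async def probe_grok_key(api_key: str) -> bool:
    # Minimal one-token completion as a key check; assumes the API at
    # https://api.x.ai/v1 is OpenAI-compatible, as the base URL above implies.
    payload = {
        "model": "grok-3-mini",
        "messages": [{"role": "user", "content": "ping"}],
        "max_tokens": 1,
    }
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                "https://api.x.ai/v1/chat/completions",
                headers={"Authorization": f"Bearer {api_key}"},
                json=payload,
            )
        return resp.status_code == 200
    except httpx.HTTPError:
        return False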
@@ -14,7 +14,7 @@ from ...config.logfire_config import get_logger, safe_span
logger = get_logger(__name__)

# Fixed similarity threshold for vector results
SIMILARITY_THRESHOLD = 0.15
SIMILARITY_THRESHOLD = 0.05


class BaseSearchStrategy:
@@ -11,7 +11,7 @@ from supabase import Client

from ..config.logfire_config import get_logger, search_logger
from .client_manager import get_supabase_client
from .llm_provider_service import get_llm_client
from .llm_provider_service import extract_message_text, get_llm_client

logger = get_logger(__name__)

@@ -72,20 +72,21 @@ The above content is from the documentation for '{source_id}'. Please provide a
        )

        # Extract the generated summary with proper error handling
        if not response or not response.choices or len(response.choices) == 0:
            search_logger.error(f"Empty or invalid response from LLM for {source_id}")
            return default_summary

        message_content = response.choices[0].message.content
        if message_content is None:
            search_logger.error(f"LLM returned None content for {source_id}")
            return default_summary

        summary = message_content.strip()

        # Ensure the summary is not too long
        if len(summary) > max_length:
            summary = summary[:max_length] + "..."
        if not response or not response.choices or len(response.choices) == 0:
            search_logger.error(f"Empty or invalid response from LLM for {source_id}")
            return default_summary

        choice = response.choices[0]
        summary_text, _, _ = extract_message_text(choice)
        if not summary_text:
            search_logger.error(f"LLM returned None content for {source_id}")
            return default_summary

        summary = summary_text.strip()

        # Ensure the summary is not too long
        if len(summary) > max_length:
            summary = summary[:max_length] + "..."

        return summary

@@ -187,7 +188,9 @@ Generate only the title, nothing else."""
            ],
        )

        generated_title = response.choices[0].message.content.strip()
        choice = response.choices[0]
        generated_title, _, _ = extract_message_text(choice)
        generated_title = generated_title.strip()
        # Clean up the title
        generated_title = generated_title.strip("\"'")
        if len(generated_title) < 50:  # Sanity check
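Several call sites in this commit switch from reading message.content directly to extract_message_text(choice). A hedged sketch of the contract those call sites assume: a (content, reasoning, raw) triple, with a fallback to reasoning fields that some providers populate instead of content. The field name reasoning_content and the third element are assumptions; the real implementation is in llm_provider_service.

def extract_message_text_sketch(choice) -> tuple[str, str, object]:
    message = getattr(choice, "message", None)
    content = (getattr(message, "content", None) or "") if message else ""
    reasoning = (getattr(message, "reasoning_content", None) or "") if message else ""  # assumption
    return content, reasoning, message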
@@ -8,6 +8,7 @@ import asyncio
import json
import os
import re
import time
from collections import defaultdict, deque
from collections.abc import Callable
from difflib import SequenceMatcher
@@ -17,25 +18,104 @@ from urllib.parse import urlparse
from supabase import Client

from ...config.logfire_config import search_logger
from ..credential_service import credential_service
from ..embeddings.contextual_embedding_service import generate_contextual_embeddings_batch
from ..embeddings.embedding_service import create_embeddings_batch
from ..llm_provider_service import (
    extract_json_from_reasoning,
    extract_message_text,
    get_llm_client,
    prepare_chat_completion_params,
    synthesize_json_from_reasoning,
)
def _get_model_choice() -> str:
    """Get MODEL_CHOICE with direct fallback."""
def _extract_json_payload(raw_response: str, context_code: str = "", language: str = "") -> str:
    """Return the best-effort JSON object from an LLM response."""

    if not raw_response:
        return raw_response

    cleaned = raw_response.strip()

    # Check if this looks like reasoning text first
    if _is_reasoning_text_response(cleaned):
        # Try intelligent extraction from reasoning text with context
        extracted = extract_json_from_reasoning(cleaned, context_code, language)
        if extracted:
            return extracted
        # extract_json_from_reasoning may return nothing; synthesize a fallback JSON if so
        fallback_json = synthesize_json_from_reasoning("", context_code, language)
        if fallback_json:
            return fallback_json
        # If all else fails, return a minimal valid JSON object to avoid downstream errors
        return '{"example_name": "Code Example", "summary": "Code example extracted from context."}'

    if cleaned.startswith("```"):
        lines = cleaned.splitlines()
        # Drop opening fence
        lines = lines[1:]
        # Drop closing fence if present
        if lines and lines[-1].strip().startswith("```"):
            lines = lines[:-1]
        cleaned = "\n".join(lines).strip()

    # Trim any leading/trailing text outside the outermost JSON braces
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start != -1 and end != -1 and end >= start:
        cleaned = cleaned[start : end + 1]

    return cleaned.strip()
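Two usage examples for _extract_json_payload as defined above (values illustrative):

fenced = '```json\n{"example_name": "Retry loop", "summary": "Shows exponential backoff."}\n```'
print(_extract_json_payload(fenced))
# -> {"example_name": "Retry loop", "summary": "Shows exponential backoff."}

noisy = 'Sure! Here is the JSON: {"example_name": "X", "summary": "Y"} hope that helps'
print(_extract_json_payload(noisy))
# -> {"example_name": "X", "summary": "Y"}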
REASONING_STARTERS = [
    "okay, let's see", "okay, let me", "let me think", "first, i need to", "looking at this",
    "i need to", "analyzing", "let me work through", "thinking about", "let me see"
]

def _is_reasoning_text_response(text: str) -> bool:
    """Detect if response is reasoning text rather than direct JSON."""
    if not text or len(text) < 20:
        return False

    text_lower = text.lower().strip()

    # Check if it's clearly not JSON (starts with reasoning text)
    starts_with_reasoning = any(text_lower.startswith(starter) for starter in REASONING_STARTERS)

    # Check if it lacks immediate JSON structure
    lacks_immediate_json = not text_lower.lstrip().startswith('{')

    return starts_with_reasoning or (lacks_immediate_json and any(pattern in text_lower for pattern in REASONING_STARTERS))
async def _get_model_choice() -> str:
    """Get MODEL_CHOICE with provider-aware defaults from centralized service."""
    try:
        # Direct cache/env fallback
        from ..credential_service import credential_service
        # Get the active provider configuration
        provider_config = await credential_service.get_active_provider("llm")
        active_provider = provider_config.get("provider", "openai")
        model = provider_config.get("chat_model")

        if credential_service._cache_initialized and "MODEL_CHOICE" in credential_service._cache:
            model = credential_service._cache["MODEL_CHOICE"]
        else:
            model = os.getenv("MODEL_CHOICE", "gpt-4.1-nano")
        search_logger.debug(f"Using model choice: {model}")
        # If no custom model is set, use provider-specific defaults
        if not model or model.strip() == "":
            # Provider-specific defaults
            provider_defaults = {
                "openai": "gpt-4o-mini",
                "openrouter": "anthropic/claude-3.5-sonnet",
                "google": "gemini-1.5-flash",
                "ollama": "llama3.2:latest",
                "anthropic": "claude-3-5-haiku-20241022",
                "grok": "grok-3-mini"
            }
            model = provider_defaults.get(active_provider, "gpt-4o-mini")
            search_logger.debug(f"Using default model for provider {active_provider}: {model}")

        search_logger.debug(f"Using model for provider {active_provider}: {model}")
        return model
    except Exception as e:
        search_logger.warning(f"Error getting model choice: {e}, using default")
        return "gpt-4.1-nano"
        return "gpt-4o-mini"
def _get_max_workers() -> int:
@@ -155,6 +235,7 @@ def _select_best_code_variant(similar_blocks: list[dict[str, Any]]) -> dict[str,
    return best_block



def extract_code_blocks(markdown_content: str, min_length: int = None) -> list[dict[str, Any]]:
    """
    Extract code blocks from markdown content along with context.
@@ -168,8 +249,6 @@ def extract_code_blocks(markdown_content: str, min_length: int = None) -> list[d
    """
    # Load all code extraction settings with direct fallback
    try:
        from ...services.credential_service import credential_service

        def _get_setting_fallback(key: str, default: str) -> str:
            if credential_service._cache_initialized and key in credential_service._cache:
                return credential_service._cache[key]
@@ -507,7 +586,7 @@ def generate_code_example_summary(
        A dictionary with 'summary' and 'example_name'
    """
    import asyncio

    # Run the async version in the current thread
    return asyncio.run(_generate_code_example_summary_async(code, context_before, context_after, language, provider))
@@ -518,13 +597,22 @@ async def _generate_code_example_summary_async(
    """
    Async version of generate_code_example_summary using unified LLM provider service.
    """
    from ..llm_provider_service import get_llm_client

    # Get model choice from credential service (RAG setting)
    model_choice = _get_model_choice()

    # Create the prompt
    prompt = f"""<context_before>
    # Get model choice from credential service (RAG setting)
    model_choice = await _get_model_choice()

    # If provider is not specified, get it from credential service
    if provider is None:
        try:
            provider_config = await credential_service.get_active_provider("llm")
            provider = provider_config.get("provider", "openai")
            search_logger.debug(f"Auto-detected provider from credential service: {provider}")
        except Exception as e:
            search_logger.warning(f"Failed to get provider from credential service: {e}, defaulting to openai")
            provider = "openai"

    # Create the prompt variants: base prompt, guarded prompt (JSON reminder), and strict prompt for retries
    base_prompt = f"""<context_before>
{context_before[-500:] if len(context_before) > 500 else context_before}
</context_before>

@@ -548,6 +636,16 @@ Format your response as JSON:
    "summary": "2-3 sentence description of what the code demonstrates"
}}
"""
    guard_prompt = (
        base_prompt
        + "\n\nImportant: Respond with a valid JSON object that exactly matches the keys "
        '{"example_name": string, "summary": string}. Do not include commentary, '
        "markdown fences, or reasoning notes."
    )
    strict_prompt = (
        guard_prompt
        + "\n\nSecond attempt enforcement: Return JSON only with the exact schema. No additional text or reasoning content."
    )
    try:
        # Use unified LLM provider service
@@ -555,25 +653,261 @@ Format your response as JSON:
            search_logger.info(
                f"Generating summary for {hash(code) & 0xffffff:06x} using model: {model_choice}"
            )

            response = await client.chat.completions.create(
                model=model_choice,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that analyzes code examples and provides JSON responses with example names and summaries.",
                    },
                    {"role": "user", "content": prompt},
                ],
                response_format={"type": "json_object"},
                max_tokens=500,
                temperature=0.3,

            provider_lower = provider.lower()
            is_grok_model = (provider_lower == "grok") or ("grok" in model_choice.lower())

            supports_response_format_base = (
                provider_lower in {"openai", "google", "anthropic"}
                or (provider_lower == "openrouter" and model_choice.startswith("openai/"))
            )

            response_content = response.choices[0].message.content.strip()
            last_response_obj = None
            last_elapsed_time = None
            last_response_content = ""
            last_json_error: json.JSONDecodeError | None = None
            for enforce_json, current_prompt in ((False, guard_prompt), (True, strict_prompt)):
                request_params = {
                    "model": model_choice,
                    "messages": [
                        {
                            "role": "system",
                            "content": "You are a helpful assistant that analyzes code examples and provides JSON responses with example names and summaries.",
                        },
                        {"role": "user", "content": current_prompt},
                    ],
                    "max_tokens": 2000,
                    "temperature": 0.3,
                }

                should_use_response_format = False
                if enforce_json:
                    if not is_grok_model and (supports_response_format_base or provider_lower == "openrouter"):
                        should_use_response_format = True
                else:
                    if supports_response_format_base:
                        should_use_response_format = True

                if should_use_response_format:
                    request_params["response_format"] = {"type": "json_object"}

                if is_grok_model:
                    unsupported_params = ["presence_penalty", "frequency_penalty", "stop", "reasoning_effort"]
                    for param in unsupported_params:
                        if param in request_params:
                            removed_value = request_params.pop(param)
                            search_logger.warning(f"Removed unsupported Grok parameter '{param}': {removed_value}")

                    supported_params = ["model", "messages", "max_tokens", "temperature", "response_format", "stream", "tools", "tool_choice"]
                    for param in list(request_params.keys()):
                        if param not in supported_params:
                            search_logger.warning(f"Parameter '{param}' may not be supported by Grok reasoning models")

                start_time = time.time()
                max_retries = 3 if is_grok_model else 1
                retry_delay = 1.0
                response_content_local = ""
                reasoning_text_local = ""
                json_error_occurred = False
                for attempt in range(max_retries):
                    try:
                        if is_grok_model and attempt > 0:
                            search_logger.info(f"Grok retry attempt {attempt + 1}/{max_retries} after {retry_delay:.1f}s delay")
                            await asyncio.sleep(retry_delay)

                        final_params = prepare_chat_completion_params(model_choice, request_params)
                        response = await client.chat.completions.create(**final_params)
                        last_response_obj = response

                        choice = response.choices[0] if response.choices else None
                        message = choice.message if choice and hasattr(choice, "message") else None
                        response_content_local = ""
                        reasoning_text_local = ""

                        if choice:
                            response_content_local, reasoning_text_local, _ = extract_message_text(choice)

                        # Enhanced logging for response analysis
                        if message and reasoning_text_local:
                            content_preview = response_content_local[:100] if response_content_local else "None"
                            reasoning_preview = reasoning_text_local[:100] if reasoning_text_local else "None"
                            search_logger.debug(
                                f"Response has reasoning content - content: '{content_preview}', reasoning: '{reasoning_preview}'"
                            )

                        if response_content_local:
                            last_response_content = response_content_local.strip()

                            # Pre-validate response before processing
                            if len(last_response_content) < 20 or (len(last_response_content) < 50 and not last_response_content.strip().startswith('{')):
                                # Very minimal response - likely "Okay\nOkay" type
                                search_logger.debug(f"Minimal response detected: {repr(last_response_content)}")
                                # Generate fallback directly from context
                                fallback_json = synthesize_json_from_reasoning("", code, language)
                                if fallback_json:
                                    try:
                                        result = json.loads(fallback_json)
                                        final_result = {
                                            "example_name": result.get("example_name", f"Code Example{f' ({language})' if language else ''}"),
                                            "summary": result.get("summary", "Code example for demonstration purposes."),
                                        }
                                        search_logger.info(f"Generated fallback summary from context - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}")
                                        return final_result
                                    except json.JSONDecodeError:
                                        pass  # Continue to normal error handling
                                else:
                                    # Even synthesis failed - provide hardcoded fallback for minimal responses
                                    final_result = {
                                        "example_name": f"Code Example{f' ({language})' if language else ''}",
                                        "summary": "Code example extracted from development context.",
                                    }
                                    search_logger.info(f"Used hardcoded fallback for minimal response - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}")
                                    return final_result

                            payload = _extract_json_payload(last_response_content, code, language)
                            if payload != last_response_content:
                                search_logger.debug(
                                    f"Sanitized LLM response payload before parsing: {repr(payload[:200])}..."
                                )

                            try:
                                result = json.loads(payload)

                                if not result.get("example_name") or not result.get("summary"):
                                    search_logger.warning(f"Incomplete response from LLM: {result}")

                                final_result = {
                                    "example_name": result.get(
                                        "example_name", f"Code Example{f' ({language})' if language else ''}"
                                    ),
                                    "summary": result.get("summary", "Code example for demonstration purposes."),
                                }

                                search_logger.info(
                                    f"Generated code example summary - Name: '{final_result['example_name']}', Summary length: {len(final_result['summary'])}"
                                )
                                return final_result
                            except json.JSONDecodeError as json_error:
                                last_json_error = json_error
                                json_error_occurred = True
                                snippet = last_response_content[:200]
                                if not enforce_json:
                                    # Check if this was reasoning text that couldn't be parsed
                                    if _is_reasoning_text_response(last_response_content):
                                        search_logger.debug(
                                            f"Reasoning text detected but no JSON extracted. Response snippet: {repr(snippet)}"
                                        )
                                    else:
                                        search_logger.warning(
                                            f"Failed to parse JSON response from LLM (non-strict attempt). Error: {json_error}. Response snippet: {repr(snippet)}"
                                        )
                                    break
                                else:
                                    search_logger.error(
                                        f"Strict JSON enforcement still failed to produce valid JSON: {json_error}. Response snippet: {repr(snippet)}"
                                    )
                                    break

                        elif is_grok_model and attempt < max_retries - 1:
                            search_logger.warning(f"Grok empty response on attempt {attempt + 1}, retrying...")
                            retry_delay *= 2
                            continue
                        else:
                            break
                    except Exception as e:
                        if is_grok_model and attempt < max_retries - 1:
                            search_logger.error(f"Grok request failed on attempt {attempt + 1}: {e}, retrying...")
                            retry_delay *= 2
                            continue
                        else:
                            raise

                if is_grok_model:
                    elapsed_time = time.time() - start_time
                    last_elapsed_time = elapsed_time
                    search_logger.debug(f"Grok total response time: {elapsed_time:.2f}s")

                if json_error_occurred:
                    if not enforce_json:
                        continue
                    else:
                        break

                if response_content_local:
                    # We would have returned already on success; if we reach here, parsing failed but we are not retrying
                    continue

            response_content = last_response_content
            response = last_response_obj
            elapsed_time = last_elapsed_time if last_elapsed_time is not None else 0.0

            if last_json_error is not None and response_content:
                search_logger.error(
                    f"LLM response after strict enforcement was still not valid JSON: {last_json_error}. Clearing response to trigger error handling."
                )
                response_content = ""
            if not response_content:
                search_logger.error(f"Empty response from LLM for model: {model_choice} (provider: {provider})")
                if is_grok_model:
                    search_logger.error("Grok empty response debugging:")
                    search_logger.error(f"  - Request took: {elapsed_time:.2f}s")
                    search_logger.error(f"  - Response status: {getattr(response, 'status_code', 'N/A')}")
                    search_logger.error(f"  - Response headers: {getattr(response, 'headers', 'N/A')}")
                    search_logger.error(f"  - Full response: {response}")
                    search_logger.error(f"  - Response choices length: {len(response.choices) if response.choices else 0}")
                    if response.choices:
                        search_logger.error(f"  - First choice: {response.choices[0]}")
                        search_logger.error(f"  - Message content: '{response.choices[0].message.content}'")
                        search_logger.error(f"  - Message role: {response.choices[0].message.role}")
                    search_logger.error("Check: 1) API key validity, 2) rate limits, 3) model availability")

                    # Implement fallback for Grok failures
                    search_logger.warning("Attempting fallback to OpenAI due to Grok failure...")
                    try:
                        # Use OpenAI as fallback with similar parameters
                        fallback_params = {
                            "model": "gpt-4o-mini",
                            "messages": request_params["messages"],
                            "temperature": request_params.get("temperature", 0.1),
                            "max_tokens": request_params.get("max_tokens", 500),
                        }

                        async with get_llm_client(provider="openai") as fallback_client:
                            fallback_response = await fallback_client.chat.completions.create(**fallback_params)
                            fallback_content = fallback_response.choices[0].message.content
                            if fallback_content and fallback_content.strip():
                                search_logger.info("gpt-4o-mini fallback succeeded")
                                response_content = fallback_content.strip()
                            else:
                                search_logger.error("gpt-4o-mini fallback also returned empty response")
                                raise ValueError(f"Both {model_choice} and gpt-4o-mini fallback failed")

                    except Exception as fallback_error:
                        search_logger.error(f"gpt-4o-mini fallback failed: {fallback_error}")
                        raise ValueError(f"{model_choice} failed and fallback to gpt-4o-mini also failed: {fallback_error}") from fallback_error
                else:
                    search_logger.debug(f"Full response object: {response}")
                    raise ValueError("Empty response from LLM")

            if not response_content:
                # This should not happen after fallback logic, but safety check
                raise ValueError("No valid response content after all attempts")
            response_content = response_content.strip()
            search_logger.debug(f"LLM API response: {repr(response_content[:200])}...")

            result = json.loads(response_content)
            payload = _extract_json_payload(response_content, code, language)
            if payload != response_content:
                search_logger.debug(
                    f"Sanitized LLM response payload before parsing: {repr(payload[:200])}..."
                )

            result = json.loads(payload)

            # Validate the response has the required fields
            if not result.get("example_name") or not result.get("summary"):
@@ -595,12 +929,38 @@ Format your response as JSON:
        search_logger.error(
            f"Failed to parse JSON response from LLM: {e}, Response: {repr(response_content) if 'response_content' in locals() else 'No response'}"
        )
        # Try to generate context-aware fallback
        try:
            fallback_json = synthesize_json_from_reasoning("", code, language)
            if fallback_json:
                fallback_result = json.loads(fallback_json)
                search_logger.info(f"Generated context-aware fallback summary")
                return {
                    "example_name": fallback_result.get("example_name", f"Code Example{f' ({language})' if language else ''}"),
                    "summary": fallback_result.get("summary", "Code example for demonstration purposes."),
                }
        except Exception:
            pass  # Fall through to generic fallback

        return {
            "example_name": f"Code Example{f' ({language})' if language else ''}",
            "summary": "Code example for demonstration purposes.",
        }
    except Exception as e:
        search_logger.error(f"Error generating code summary using unified LLM provider: {e}")
        # Try to generate context-aware fallback
        try:
            fallback_json = synthesize_json_from_reasoning("", code, language)
            if fallback_json:
                fallback_result = json.loads(fallback_json)
                search_logger.info(f"Generated context-aware fallback summary after error")
                return {
                    "example_name": fallback_result.get("example_name", f"Code Example{f' ({language})' if language else ''}"),
                    "summary": fallback_result.get("summary", "Code example for demonstration purposes."),
                }
        except Exception:
            pass  # Fall through to generic fallback

        return {
            "example_name": f"Code Example{f' ({language})' if language else ''}",
            "summary": "Code example for demonstration purposes.",
@@ -608,7 +968,7 @@ Format your response as JSON:


async def generate_code_summaries_batch(
    code_blocks: list[dict[str, Any]], max_workers: int = None, progress_callback=None
    code_blocks: list[dict[str, Any]], max_workers: int = None, progress_callback=None, provider: str = None
) -> list[dict[str, str]]:
    """
    Generate summaries for multiple code blocks with rate limiting and proper worker management.
@@ -617,6 +977,7 @@ async def generate_code_summaries_batch(
        code_blocks: List of code block dictionaries
        max_workers: Maximum number of concurrent API requests
        progress_callback: Optional callback for progress updates (async function)
        provider: LLM provider to use for generation (e.g., 'grok', 'openai', 'anthropic')

    Returns:
        List of summary dictionaries
@@ -627,8 +988,6 @@ async def generate_code_summaries_batch(
    # Get max_workers from settings if not provided
    if max_workers is None:
        try:
            from ...services.credential_service import credential_service

            if (
                credential_service._cache_initialized
                and "CODE_SUMMARY_MAX_WORKERS" in credential_service._cache
@@ -663,6 +1022,7 @@ async def generate_code_summaries_batch(
                block["context_before"],
                block["context_after"],
                block.get("language", ""),
                provider,
            )

            # Update progress
@@ -757,29 +1117,17 @@ async def add_code_examples_to_supabase(
    except Exception as e:
        search_logger.error(f"Error deleting existing code examples for {url}: {e}")

    # Check if contextual embeddings are enabled
    # Check if contextual embeddings are enabled (use proper async method like document storage)
    try:
        from ..credential_service import credential_service

        use_contextual_embeddings = credential_service._cache.get("USE_CONTEXTUAL_EMBEDDINGS")
        if isinstance(use_contextual_embeddings, str):
            use_contextual_embeddings = use_contextual_embeddings.lower() == "true"
        elif isinstance(use_contextual_embeddings, dict) and use_contextual_embeddings.get(
            "is_encrypted"
        ):
            # Handle encrypted value
            encrypted_value = use_contextual_embeddings.get("encrypted_value")
            if encrypted_value:
                try:
                    decrypted = credential_service._decrypt_value(encrypted_value)
                    use_contextual_embeddings = decrypted.lower() == "true"
                except:
                    use_contextual_embeddings = False
            else:
                use_contextual_embeddings = False
        raw_value = await credential_service.get_credential(
            "USE_CONTEXTUAL_EMBEDDINGS", "false", decrypt=True
        )
        if isinstance(raw_value, str):
            use_contextual_embeddings = raw_value.lower() == "true"
        else:
            use_contextual_embeddings = bool(use_contextual_embeddings)
    except:
            use_contextual_embeddings = bool(raw_value)
    except Exception as e:
        search_logger.error(f"DEBUG: Error reading contextual embeddings: {e}")
        # Fallback to environment variable
        use_contextual_embeddings = (
            os.getenv("USE_CONTEXTUAL_EMBEDDINGS", "false").lower() == "true"
@@ -848,14 +1196,13 @@ async def add_code_examples_to_supabase(
        # Use only successful embeddings
        valid_embeddings = result.embeddings
        successful_texts = result.texts_processed

        # Get model information for tracking
        from ..llm_provider_service import get_embedding_model
        from ..credential_service import credential_service

        # Get embedding model name
        embedding_model_name = await get_embedding_model(provider=provider)

        # Get LLM chat model (used for code summaries and contextual embeddings if enabled)
        llm_chat_model = None
        try:
@@ -868,7 +1215,7 @@ async def add_code_examples_to_supabase(
                llm_chat_model = await credential_service.get_credential("MODEL_CHOICE", "gpt-4o-mini")
            else:
                # For code summaries, we use MODEL_CHOICE
                llm_chat_model = _get_model_choice()
                llm_chat_model = await _get_model_choice()
        except Exception as e:
            search_logger.warning(f"Failed to get LLM chat model: {e}")
            llm_chat_model = "gpt-4o-mini"  # Default fallback
@@ -888,7 +1235,7 @@ async def add_code_examples_to_supabase(
                positions_by_text[text].append(original_indices[k])

        # Map successful texts back to their original indices
        for embedding, text in zip(valid_embeddings, successful_texts, strict=False):
        for embedding, text in zip(valid_embeddings, successful_texts, strict=True):
            # Get the next available index for this text (handles duplicates)
            if positions_by_text[text]:
                orig_idx = positions_by_text[text].popleft()  # Original j index in [i, batch_end)
@@ -908,7 +1255,7 @@ async def add_code_examples_to_supabase(
            # Determine the correct embedding column based on dimension
            embedding_dim = len(embedding) if isinstance(embedding, list) else len(embedding.tolist())
            embedding_column = None

            if embedding_dim == 768:
                embedding_column = "embedding_768"
            elif embedding_dim == 1024:
@@ -918,10 +1265,12 @@ async def add_code_examples_to_supabase(
            elif embedding_dim == 3072:
                embedding_column = "embedding_3072"
            else:
                # Default to closest supported dimension
                search_logger.warning(f"Unsupported embedding dimension {embedding_dim}, using embedding_1536")
                embedding_column = "embedding_1536"

                # Skip unsupported dimensions to avoid corrupting the schema
                search_logger.error(
                    f"Unsupported embedding dimension {embedding_dim}; skipping record to prevent column mismatch"
                )
                continue

            batch_data.append({
                "url": urls[idx],
                "chunk_number": chunk_numbers[idx],
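The dimension-to-column mapping above as a standalone sketch. The mapping values come from the diff, except the 1024 entry: that branch body is elided by the hunk boundary, so its column name is an assumption here, and the helper name is hypothetical.

EMBEDDING_COLUMNS = {
    768: "embedding_768",
    1024: "embedding_1024",  # assumption: the elided 1024 branch follows the same naming
    1536: "embedding_1536",
    3072: "embedding_3072",
}

def embedding_column_for(dim: int) -> str | None:
    # None now means "skip the record" instead of silently coercing into embedding_1536.
    return EMBEDDING_COLUMNS.get(dim)

assert embedding_column_for(3072) == "embedding_3072"
assert embedding_column_for(512) is None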
@@ -954,9 +1303,7 @@ async def add_code_examples_to_supabase(
                        f"Error inserting batch into Supabase (attempt {retry + 1}/{max_retries}): {e}"
                    )
                    search_logger.info(f"Retrying in {retry_delay} seconds...")
                    import time

                    time.sleep(retry_delay)
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    # Final attempt failed
python/src/server/services/version_service.py (new file, 162 lines)
@@ -0,0 +1,162 @@
"""
|
||||
Version checking service with GitHub API integration.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import logfire
|
||||
|
||||
from ..config.version import ARCHON_VERSION, GITHUB_REPO_NAME, GITHUB_REPO_OWNER
|
||||
from ..utils.semantic_version import is_newer_version


class VersionService:
    """Service for checking Archon version against GitHub releases."""

    def __init__(self):
        self._cache: dict[str, Any] | None = None
        self._cache_time: datetime | None = None
        self._cache_ttl = 3600  # 1 hour cache TTL

    def _is_cache_valid(self) -> bool:
        """Check if cached data is still valid."""
        if not self._cache or not self._cache_time:
            return False

        age = datetime.now() - self._cache_time
        return age < timedelta(seconds=self._cache_ttl)

    async def get_latest_release(self) -> dict[str, Any] | None:
        """
        Fetch latest release information from GitHub API.

        Returns:
            Release data dictionary or None if no releases
        """
        # Check cache first
        if self._is_cache_valid():
            logfire.debug("Using cached version data")
            return self._cache

        # GitHub API endpoint
        url = f"https://api.github.com/repos/{GITHUB_REPO_OWNER}/{GITHUB_REPO_NAME}/releases/latest"

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(
                    url,
                    headers={
                        "Accept": "application/vnd.github.v3+json",
                        "User-Agent": f"Archon/{ARCHON_VERSION}",
                    },
                )

                # Handle 404 - no releases yet
                if response.status_code == 404:
                    logfire.info("No releases found on GitHub")
                    return None

                response.raise_for_status()
                data = response.json()

                # Cache the successful response
                self._cache = data
                self._cache_time = datetime.now()

                return data

        except httpx.TimeoutException:
            logfire.warning("GitHub API request timed out")
            # Return cached data if available
            if self._cache:
                return self._cache
            return None
        except httpx.HTTPError as e:
            logfire.error(f"HTTP error fetching latest release: {e}")
            # Return cached data if available
            if self._cache:
                return self._cache
            return None
        except Exception as e:
            logfire.error(f"Unexpected error fetching latest release: {e}")
            # Return cached data if available
            if self._cache:
                return self._cache
            return None

    async def check_for_updates(self) -> dict[str, Any]:
        """
        Check if a newer version of Archon is available.

        Returns:
            Dictionary with version check results
        """
        try:
            # Get latest release from GitHub
            release = await self.get_latest_release()

            if not release:
                # No releases found or error occurred
                return {
                    "current": ARCHON_VERSION,
                    "latest": None,
                    "update_available": False,
                    "release_url": None,
                    "release_notes": None,
                    "published_at": None,
                    "check_error": None,
                }

            # Extract version from tag_name (e.g., "v1.0.0" -> "1.0.0")
            latest_version = release.get("tag_name", "")
            if latest_version.startswith("v"):
                latest_version = latest_version[1:]

            # Check if update is available
            update_available = is_newer_version(ARCHON_VERSION, latest_version)

            # Parse published date
            published_at = None
            if release.get("published_at"):
                try:
                    published_at = datetime.fromisoformat(
                        release["published_at"].replace("Z", "+00:00")
                    )
                except Exception:
                    pass

            return {
                "current": ARCHON_VERSION,
                "latest": latest_version,
                "update_available": update_available,
                "release_url": release.get("html_url"),
                "release_notes": release.get("body"),
                "published_at": published_at,
                "check_error": None,
                "assets": release.get("assets", []),
                "author": release.get("author", {}).get("login"),
            }

        except Exception as e:
            logfire.error(f"Error checking for updates: {e}")
            # Return safe default with error
            return {
                "current": ARCHON_VERSION,
                "latest": None,
                "update_available": False,
                "release_url": None,
                "release_notes": None,
                "published_at": None,
                "check_error": str(e),
            }

    def clear_cache(self):
        """Clear the cached version data."""
        self._cache = None
        self._cache_time = None


# Export singleton instance
version_service = VersionService()
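
As a quick usage sketch (an editorial addition, not part of the diff): the singleton would typically be awaited from an API route. The route path and handler name below are assumptions for illustration only.

# Hypothetical FastAPI route consuming the singleton above; the
# "/api/version/check" path is an assumed example, not from this diff.
from fastapi import APIRouter

router = APIRouter()

@router.get("/api/version/check")
async def check_version():
    # check_for_updates() never raises; failures surface in "check_error".
    return await version_service.check_for_updates()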

python/src/server/utils/semantic_version.py (new file, 107 lines)
@@ -0,0 +1,107 @@
"""
Semantic version parsing and comparison utilities.
"""

import re


def parse_version(version_string: str) -> tuple[int, int, int, str | None]:
    """
    Parse a semantic version string into major, minor, patch, and optional prerelease.

    Supports formats like:
    - "1.0.0"
    - "v1.0.0"
    - "1.0.0-beta"
    - "v1.0.0-rc.1"

    Args:
        version_string: Version string to parse

    Returns:
        Tuple of (major, minor, patch, prerelease)
    """
    # Remove 'v' prefix if present
    version = version_string.strip()
    if version.lower().startswith('v'):
        version = version[1:]

    # Parse version with optional prerelease
    pattern = r'^(\d+)\.(\d+)\.(\d+)(?:-(.+))?$'
    match = re.match(pattern, version)

    if not match:
        # Try to handle incomplete versions like "1.0"
        simple_pattern = r'^(\d+)(?:\.(\d+))?(?:\.(\d+))?$'
        simple_match = re.match(simple_pattern, version)
        if simple_match:
            major = int(simple_match.group(1))
            minor = int(simple_match.group(2) or 0)
            patch = int(simple_match.group(3) or 0)
            return (major, minor, patch, None)
        raise ValueError(f"Invalid version string: {version_string}")

    major = int(match.group(1))
    minor = int(match.group(2))
    patch = int(match.group(3))
    prerelease = match.group(4)

    return (major, minor, patch, prerelease)
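
A few illustrative inputs and outputs for the parser above (editorial examples, not part of the diff):

>>> parse_version("v1.2.3")
(1, 2, 3, None)
>>> parse_version("1.0.0-rc.1")
(1, 0, 0, 'rc.1')
>>> parse_version("2.1")  # incomplete versions are padded with zeros
(2, 1, 0, None)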


def compare_versions(version1: str, version2: str) -> int:
    """
    Compare two semantic version strings.

    Args:
        version1: First version string
        version2: Second version string

    Returns:
        -1 if version1 < version2
        0 if version1 == version2
        1 if version1 > version2
    """
    v1 = parse_version(version1)
    v2 = parse_version(version2)

    # Compare major, minor, patch
    for i in range(3):
        if v1[i] < v2[i]:
            return -1
        elif v1[i] > v2[i]:
            return 1

    # If main versions are equal, check prerelease
    # No prerelease is considered newer than any prerelease
    if v1[3] is None and v2[3] is None:
        return 0
    elif v1[3] is None:
        return 1  # v1 is release, v2 is prerelease
    elif v2[3] is None:
        return -1  # v1 is prerelease, v2 is release
    else:
        # Both have prereleases, compare lexicographically
        if v1[3] < v2[3]:
            return -1
        elif v1[3] > v2[3]:
            return 1
        return 0
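
One behavioral note worth flagging (editorial): prerelease tags are compared as plain strings, which matches SemVer for simple cases but diverges when prerelease identifiers are numeric. Illustrative calls:

compare_versions("1.0.0", "1.0.0-beta")        # 1: a release outranks any prerelease
compare_versions("1.0.0-alpha", "1.0.0-beta")  # -1: lexicographic order
compare_versions("1.0.0-rc.10", "1.0.0-rc.2")  # -1 here, though full SemVer ranks rc.10 higher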


def is_newer_version(current: str, latest: str) -> bool:
    """
    Check if latest version is newer than current version.

    Args:
        current: Current version string
        latest: Latest version string to compare

    Returns:
        True if latest > current, False otherwise
    """
    try:
        return compare_versions(latest, current) > 0
    except ValueError:
        # If we can't parse versions, assume no update
        return False
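
And the helper as VersionService calls it (illustrative examples, not part of the diff):

is_newer_version("1.0.0", "1.1.0")        # True: update available
is_newer_version("1.1.0", "1.1.0")        # False: already current
is_newer_version("1.1.0", "garbage-tag")  # False: unparseable input fails closed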