mirror of
https://github.com/TandoorRecipes/recipes.git
synced 2026-01-01 12:18:45 -05:00
basics of AI provider system
This commit is contained in:
@@ -74,7 +74,7 @@ from cookbook.helper.permission_helper import (CustomIsAdmin, CustomIsOwner, Cus
|
||||
CustomTokenHasScope, CustomUserPermission, IsReadOnlyDRF,
|
||||
above_space_limit,
|
||||
group_required, has_group_permission, is_space_owner,
|
||||
switch_user_active_space
|
||||
switch_user_active_space, CustomAiProviderPermission
|
||||
)
|
||||
from cookbook.helper.recipe_search import RecipeSearch
|
||||
from cookbook.helper.recipe_url_import import clean_dict, get_from_youtube_scraper, get_images_from_soup
|
||||
@@ -85,7 +85,7 @@ from cookbook.models import (Automation, BookmarkletImport, ConnectorConfig, Coo
|
||||
RecipeBookEntry, ShareLink, ShoppingListEntry,
|
||||
ShoppingListRecipe, Space, Step, Storage, Supermarket, SupermarketCategory,
|
||||
SupermarketCategoryRelation, Sync, SyncLog, Unit, UnitConversion,
|
||||
UserFile, UserPreference, UserSpace, ViewLog, RecipeImport, SearchPreference, SearchFields
|
||||
UserFile, UserPreference, UserSpace, ViewLog, RecipeImport, SearchPreference, SearchFields, AiLog, AiProvider
|
||||
)
|
||||
from cookbook.provider.dropbox import Dropbox
|
||||
from cookbook.provider.local import Local
|
||||
@@ -110,7 +110,8 @@ from cookbook.serializer import (AccessTokenSerializer, AutomationSerializer, Au
|
||||
UserSerializer, UserSpaceSerializer, ViewLogSerializer,
|
||||
LocalizationSerializer, ServerSettingsSerializer, RecipeFromSourceResponseSerializer, ShoppingListEntryBulkCreateSerializer, FdcQuerySerializer,
|
||||
AiImportSerializer, ImportOpenDataSerializer, ImportOpenDataMetaDataSerializer, ImportOpenDataResponseSerializer, ExportRequestSerializer,
|
||||
RecipeImportSerializer, ConnectorConfigSerializer, SearchPreferenceSerializer, SearchFieldsSerializer, RecipeBatchUpdateSerializer
|
||||
RecipeImportSerializer, ConnectorConfigSerializer, SearchPreferenceSerializer, SearchFieldsSerializer, RecipeBatchUpdateSerializer,
|
||||
AiProviderSerializer, AiLogSerializer
|
||||
)
|
||||
from cookbook.version_info import TANDOOR_VERSION
|
||||
from cookbook.views.import_export import get_integration
|
||||
@@ -617,6 +618,29 @@ class SearchPreferenceViewSet(LoggingMixin, viewsets.ModelViewSet):
|
||||
return self.queryset.filter(user=self.request.user)
|
||||
|
||||
|
||||
class AiProviderViewSet(LoggingMixin, viewsets.ModelViewSet):
|
||||
queryset = AiProvider.objects
|
||||
serializer_class = AiProviderSerializer
|
||||
permission_classes = [CustomAiProviderPermission & CustomTokenHasReadWriteScope]
|
||||
pagination_class = DefaultPagination
|
||||
|
||||
def get_queryset(self):
|
||||
# read only access to all space and global AiProviders
|
||||
with scopes_disabled():
|
||||
return self.queryset.filter(Q(space=self.request.space) | Q(space__isnull=True))
|
||||
|
||||
|
||||
class AiLogViewSet(LoggingMixin, viewsets.ModelViewSet):
|
||||
queryset = AiLog.objects
|
||||
serializer_class = AiLogSerializer
|
||||
permission_classes = [CustomIsUser & CustomTokenHasReadWriteScope]
|
||||
http_method_names = ['get']
|
||||
pagination_class = DefaultPagination
|
||||
|
||||
def get_queryset(self):
|
||||
return self.queryset.filter(space=self.request.space)
|
||||
|
||||
|
||||
class StorageViewSet(LoggingMixin, viewsets.ModelViewSet):
|
||||
# TODO handle delete protect error and adjust test
|
||||
queryset = Storage.objects
|
||||
@@ -2000,14 +2024,28 @@ class AiImportView(APIView):
|
||||
if serializer.is_valid():
|
||||
# TODO max file size check
|
||||
|
||||
if 'ai_provider_id' not in serializer.validated_data:
|
||||
response = {
|
||||
'error': True,
|
||||
'msg': 'You must select an AI provider to perform your request',
|
||||
}
|
||||
return Response(RecipeFromSourceResponseSerializer(context={'request': request}).to_representation(response), status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
ai_provider = AiProvider.objects.filter(pk=serializer.validated_data['ai_provider_id']).filter(Q(space=request.space) | Q(space__isnull=True)).first()
|
||||
|
||||
def log_ai_request(kwargs, completion_response, start_time, end_time):
|
||||
print(completion_response['usage']['completion_tokens'], completion_response['usage']['prompt_tokens'], start_time, end_time)
|
||||
try:
|
||||
response_cost = kwargs.get("response_cost", 0)
|
||||
print("response_cost", response_cost)
|
||||
except:
|
||||
print('could not get cost')
|
||||
traceback.print_exc()
|
||||
AiLog.objects.create(
|
||||
created_by=request.user,
|
||||
space=request.space,
|
||||
ai_provider=ai_provider,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
input_tokens=completion_response['usage']['prompt_tokens'],
|
||||
output_tokens=completion_response['usage']['completion_tokens'],
|
||||
function=AiLog.F_FILE_IMPORT,
|
||||
credit_cost=kwargs.get("response_cost", 0) * 100,
|
||||
credits_from_balance=False, # TODO implement
|
||||
)
|
||||
|
||||
litellm.success_callback = [log_ai_request]
|
||||
|
||||
@@ -2079,7 +2117,9 @@ class AiImportView(APIView):
|
||||
return Response(RecipeFromSourceResponseSerializer(context={'request': request}).to_representation(response), status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
try:
|
||||
ai_response = completion(api_key=AI_API_KEY, model=AI_MODEL_NAME, response_format={"type": "json_object"}, messages=messages, )
|
||||
ai_response = completion(api_key=ai_provider.api_key,
|
||||
model=ai_provider.model_name,
|
||||
response_format={"type": "json_object"}, messages=messages, )
|
||||
except BadRequestError as err:
|
||||
response = {
|
||||
'error': True,
|
||||
|
||||
Reference in New Issue
Block a user