From 987d626489ad8d038108cbcf4a3ff3c68fb0a34f Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Tue, 14 Nov 2023 13:16:56 +0100 Subject: [PATCH] Enforce free plan at generation level (#2531) * Enforce free plan at generation level * fix lint --- front/lib/api/assistant/generation.ts | 19 ++++++++++++++++++- front/lib/assistant.ts | 21 +++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/front/lib/api/assistant/generation.ts b/front/lib/api/assistant/generation.ts index 89daddb0ad33..92f150e8c93c 100644 --- a/front/lib/api/assistant/generation.ts +++ b/front/lib/api/assistant/generation.ts @@ -13,9 +13,11 @@ import { getSupportedModelConfig, GPT_4_32K_MODEL_ID, GPT_4_MODEL_CONFIG, + isLargeModel, } from "@app/lib/assistant"; import { Authenticator } from "@app/lib/auth"; import { CoreAPI } from "@app/lib/core_api"; +import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes"; import { redisClient } from "@app/lib/redis"; import { Err, Ok, Result } from "@app/lib/result"; import logger from "@app/logger/logger"; @@ -311,7 +313,8 @@ export async function* runGeneration( void > { const owner = auth.workspace(); - if (!owner) { + const plan = auth.plan(); + if (!owner || !plan) { throw new Error("Unexpected unauthenticated call to `runGeneration`"); } @@ -334,6 +337,20 @@ export async function* runGeneration( let model = c.model; + if (isLargeModel(model) && plan.code === FREE_TEST_PLAN_CODE) { + yield { + type: "generation_error", + created: Date.now(), + configurationId: configuration.sId, + messageId: agentMessage.sId, + error: { + code: "free_plan_error", + message: `Free plan does not support large models. 
Please upgrade to a paid plan to use this model.`, + }, + }; + return; + } + const contextSize = getSupportedModelConfig(c.model).contextSize; const MIN_GENERATION_TOKENS = 2048; diff --git a/front/lib/assistant.ts b/front/lib/assistant.ts index b0f16e5becdb..dfcbafe828d1 100644 --- a/front/lib/assistant.ts +++ b/front/lib/assistant.ts @@ -15,6 +15,7 @@ export const GPT_4_32K_MODEL_CONFIG = { displayName: "GPT 4", contextSize: 32768, recommendedTopK: 32, + largeModel: true, } as const; export const GPT_4_MODEL_CONFIG = { @@ -23,6 +24,7 @@ export const GPT_4_MODEL_CONFIG = { displayName: "GPT 4", contextSize: 8192, recommendedTopK: 16, + largeModel: true, }; export const GPT_4_TURBO_MODEL_CONFIG = { @@ -31,6 +33,7 @@ export const GPT_4_TURBO_MODEL_CONFIG = { displayName: "GPT 4", contextSize: 128000, recommendedTopK: 32, + largeModel: true, } as const; export const GPT_3_5_TURBO_16K_MODEL_CONFIG = { @@ -39,6 +42,7 @@ export const GPT_3_5_TURBO_16K_MODEL_CONFIG = { displayName: "GPT 3.5 Turbo", contextSize: 16384, recommendedTopK: 16, + largeModel: false, } as const; export const GPT_3_5_TURBO_MODEL_CONFIG = { @@ -47,6 +51,7 @@ export const GPT_3_5_TURBO_MODEL_CONFIG = { displayName: "GPT 3.5 Turbo", contextSize: 4096, recommendedTopK: 16, + largeModel: false, } as const; export const CLAUDE_DEFAULT_MODEL_CONFIG = { @@ -55,6 +60,7 @@ export const CLAUDE_DEFAULT_MODEL_CONFIG = { displayName: "Claude 2", contextSize: 100000, recommendedTopK: 32, + largeModel: true, } as const; export const CLAUDE_INSTANT_DEFAULT_MODEL_CONFIG = { @@ -63,6 +69,7 @@ export const CLAUDE_INSTANT_DEFAULT_MODEL_CONFIG = { displayName: "Claude Instant 1.2", contextSize: 100000, recommendedTopK: 32, + largeModel: false, } as const; export const MISTRAL_7B_DEFAULT_MODEL_CONFIG = { @@ -71,6 +78,7 @@ export const MISTRAL_7B_DEFAULT_MODEL_CONFIG = { displayName: "Mistral 7B", contextSize: 8192, recommendedTopK: 16, + largeModel: false, } as const; export const SUPPORTED_MODEL_CONFIGS = [ @@ -100,6 
+108,19 @@ export function isSupportedModel(model: unknown): model is SupportedModel {
   );
 }
 
+// Whether `model` matches (by modelId and providerId) a supported config that
+// is flagged `largeModel`. Unmatched or non-object values yield false.
+export function isLargeModel(model: unknown): boolean {
+  if (typeof model !== "object" || model === null) {
+    return false;
+  }
+  const { modelId, providerId } = model as Partial<SupportedModel>;
+  const config = SUPPORTED_MODEL_CONFIGS.find(
+    (m) => m.modelId === modelId && m.providerId === providerId
+  );
+  return config ? config.largeModel : false;
+}
+
 export function getSupportedModelConfig(supportedModel: SupportedModel) {
   // here it is safe to cast the result to non-nullable because SupportedModel
   // is derived from the const array of configs above