Skip to content

Commit

Permalink
Merge branch 'main' into flav/eslint-consistent-type-import
Browse files Browse the repository at this point in the history
  • Loading branch information
flvndvd committed Jan 17, 2024
2 parents a8da464 + e658c1c commit 3eb8a27
Show file tree
Hide file tree
Showing 12 changed files with 84 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { useContext, useEffect, useState } from "react";
import type { NotificationType } from "@app/components/sparkle/Notification";
import { SendNotificationsContext } from "@app/components/sparkle/Notification";
import { isLargeModel } from "@app/lib/assistant";
import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes";
import { isUpgraded } from "@app/lib/plans/plan_codes";
import type { PostAgentListStatusRequestBody } from "@app/pages/api/w/[wId]/members/me/agent_list_status";

type AssistantPreviewFlow = "personal" | "workspace";
Expand Down Expand Up @@ -174,7 +174,7 @@ export function GalleryAssistantPreviewContainer({

const isGlobal = scope === "global";
const isAddedToWorkspace = flow === "workspace" && isAdded;
const hasAccessToLargeModels = plan?.code !== FREE_TEST_PLAN_CODE;
const hasAccessToLargeModels = isUpgraded(plan);
const eligibleForTesting =
hasAccessToLargeModels || !isLargeModel(generation?.model);
const isTestable = !isGlobal && !isAdded && eligibleForTesting;
Expand Down
13 changes: 5 additions & 8 deletions front/components/assistant_builder/AssistantBuilder.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ import { SendNotificationsContext } from "@app/components/sparkle/Notification";
import { getSupportedModelConfig } from "@app/lib/assistant";
import { CONNECTOR_CONFIGURATIONS } from "@app/lib/connector_providers";
import { isActivatedStructuredDB } from "@app/lib/development";
import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes";
import { isUpgraded } from "@app/lib/plans/plan_codes";
import { useSlackChannelsLinkedWithAgent } from "@app/lib/swr";
import { classNames } from "@app/lib/utils";

Expand Down Expand Up @@ -308,10 +308,9 @@ export default function AssistantBuilder({
scope: defaultScope,
generationSettings: {
...DEFAULT_ASSISTANT_STATE.generationSettings,
modelSettings:
plan.code === FREE_TEST_PLAN_CODE
? GPT_3_5_TURBO_MODEL_CONFIG
: GPT_4_TURBO_MODEL_CONFIG,
modelSettings: !isUpgraded(plan)
? GPT_3_5_TURBO_MODEL_CONFIG
: GPT_4_TURBO_MODEL_CONFIG,
},
}
);
Expand Down Expand Up @@ -1532,9 +1531,7 @@ function AdvancedSettings({
</DropdownMenu.Button>
<DropdownMenu.Items origin="bottomRight">
{usedModelConfigs
.filter(
(m) => !(m.largeModel && plan.code === FREE_TEST_PLAN_CODE)
)
.filter((m) => !(m.largeModel && !isUpgraded(plan)))
.map((modelConfig) => (
<DropdownMenu.Item
key={modelConfig.modelId}
Expand Down
8 changes: 2 additions & 6 deletions front/lib/api/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ import {
runGeneration,
} from "@app/lib/api/assistant/generation";
import type { Authenticator } from "@app/lib/auth";
import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes";
import logger from "@app/logger/logger";

/**
Expand All @@ -68,10 +67,7 @@ export async function generateActionInputs(

const MIN_GENERATION_TOKENS = 2048;

const plan = auth.plan();
const isFree = !plan || plan.code === FREE_TEST_PLAN_CODE;

let model: { providerId: string; modelId: string } = isFree
let model: { providerId: string; modelId: string } = !auth.isUpgraded()
? {
providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId,
Expand All @@ -81,7 +77,7 @@ export async function generateActionInputs(
modelId: GPT_4_32K_MODEL_CONFIG.modelId,
};

const contextSize = isFree
const contextSize = !auth.isUpgraded()
? GPT_3_5_TURBO_MODEL_CONFIG.contextSize
: GPT_4_32K_MODEL_CONFIG.contextSize;

Expand Down
6 changes: 2 additions & 4 deletions front/lib/api/assistant/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ import {
import { getAgentConfigurations } from "@app/lib/api/assistant/configuration";
import { getSupportedModelConfig, isLargeModel } from "@app/lib/assistant";
import type { Authenticator } from "@app/lib/auth";
import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes";
import { redisClient } from "@app/lib/redis";
import logger from "@app/logger/logger";
const CANCELLATION_CHECK_INTERVAL = 500;
Expand Down Expand Up @@ -285,8 +284,7 @@ export async function* runGeneration(
void
> {
const owner = auth.workspace();
const plan = auth.plan();
if (!owner || !plan) {
if (!owner) {
throw new Error("Unexpected unauthenticated call to `runGeneration`");
}

Expand All @@ -309,7 +307,7 @@ export async function* runGeneration(

let model = c.model;

if (isLargeModel(model) && plan.code === FREE_TEST_PLAN_CODE) {
if (isLargeModel(model) && !auth.isUpgraded()) {
yield {
type: "generation_error",
created: Date.now(),
Expand Down
104 changes: 41 additions & 63 deletions front/lib/api/assistant/global_agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import type {
ConnectorProvider,
DataSourceType,
} from "@dust-tt/types";
import type { GlobalAgentStatus, PlanType } from "@dust-tt/types";
import type { GlobalAgentStatus } from "@dust-tt/types";
import { GEMINI_PRO_DEFAULT_MODEL_CONFIG } from "@dust-tt/types";
import {
CLAUDE_DEFAULT_MODEL_CONFIG,
Expand All @@ -25,7 +25,6 @@ import { GLOBAL_AGENTS_SID } from "@app/lib/assistant";
import type { Authenticator } from "@app/lib/auth";
import { prodAPICredentialsForOwner } from "@app/lib/auth";
import { GlobalAgentSettings } from "@app/lib/models/assistant/agent";
import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes";
import logger from "@app/logger/logger";

class HelperAssistantPrompt {
Expand Down Expand Up @@ -84,20 +83,15 @@ async function _getHelperGlobalAgent(
if (!owner) {
throw new Error("Unexpected `auth` without `workspace`.");
}
const plan = auth.plan();
if (!plan) {
throw new Error("Unexpected `auth` without `plan`.");
}
const model =
plan.code === FREE_TEST_PLAN_CODE
? {
providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId,
}
: {
providerId: GPT_4_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_4_TURBO_MODEL_CONFIG.modelId,
};
const model = !auth.isUpgraded()
? {
providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId,
}
: {
providerId: GPT_4_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_4_TURBO_MODEL_CONFIG.modelId,
};
return {
id: -1,
sId: GLOBAL_AGENTS_SID.HELPER,
Expand Down Expand Up @@ -153,12 +147,11 @@ async function _getGPT35TurboGlobalAgent({
}

async function _getGPT4GlobalAgent({
plan,
auth,
}: {
plan: PlanType;
auth: Authenticator;
}): Promise<AgentConfigurationType> {
const status =
plan.code === FREE_TEST_PLAN_CODE ? "disabled_free_workspace" : "active";
const status = !auth.isUpgraded() ? "disabled_free_workspace" : "active";
return {
id: -1,
sId: GLOBAL_AGENTS_SID.GPT4,
Expand Down Expand Up @@ -218,14 +211,13 @@ async function _getClaudeInstantGlobalAgent({
}

async function _getClaudeGlobalAgent({
auth,
settings,
plan,
}: {
auth: Authenticator;
settings: GlobalAgentSettings | null;
plan: PlanType;
}): Promise<AgentConfigurationType> {
const status =
plan.code === FREE_TEST_PLAN_CODE ? "disabled_free_workspace" : "active";
const status = !auth.isUpgraded() ? "disabled_free_workspace" : "active";
return {
id: -1,
sId: GLOBAL_AGENTS_SID.CLAUDE,
Expand All @@ -252,14 +244,14 @@ async function _getClaudeGlobalAgent({
}

async function _getMistralMediumGlobalAgent({
plan,
auth,
settings,
}: {
plan: PlanType;
auth: Authenticator;
settings: GlobalAgentSettings | null;
}): Promise<AgentConfigurationType> {
let status = settings?.status ?? "disabled_by_admin";
if (plan.code === FREE_TEST_PLAN_CODE) {
if (!auth.isUpgraded()) {
status = "disabled_free_workspace";
}

Expand Down Expand Up @@ -378,11 +370,6 @@ async function _getManagedDataSourceAgent(
throw new Error("Unexpected `auth` without `workspace`.");
}

const plan = auth.plan();
if (!plan) {
throw new Error("Unexpected `auth` without `plan`.");
}

const prodCredentials = await prodAPICredentialsForOwner(owner);

// Check if deactivated by an admin
Expand Down Expand Up @@ -441,16 +428,15 @@ async function _getManagedDataSourceAgent(
generation: {
id: -1,
prompt,
model:
plan.code === FREE_TEST_PLAN_CODE
? {
providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId,
}
: {
providerId: GPT_4_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_4_TURBO_MODEL_CONFIG.modelId,
},
model: !auth.isUpgraded()
? {
providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId,
}
: {
providerId: GPT_4_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_4_TURBO_MODEL_CONFIG.modelId,
},
temperature: 0.4,
},
action: {
Expand Down Expand Up @@ -567,10 +553,8 @@ async function _getNotionGlobalAgent(
async function _getDustGlobalAgent(
auth: Authenticator,
{
plan,
settings,
}: {
plan: PlanType;
settings: GlobalAgentSettings | null;
}
): Promise<AgentConfigurationType | null> {
Expand Down Expand Up @@ -647,16 +631,15 @@ async function _getDustGlobalAgent(
id: -1,
prompt:
"Assist the user based on the retrieved data from their workspace. Unlesss the user explicitely asks for a detailed answer, you goal is to provide a quick answer to their question.",
model:
plan.code === FREE_TEST_PLAN_CODE
? {
providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId,
}
: {
providerId: GPT_4_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_4_TURBO_MODEL_CONFIG.modelId,
},
model: !auth.isUpgraded()
? {
providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId,
}
: {
providerId: GPT_4_TURBO_MODEL_CONFIG.providerId,
modelId: GPT_4_TURBO_MODEL_CONFIG.modelId,
},
temperature: 0.4,
},
action: {
Expand Down Expand Up @@ -693,11 +676,6 @@ export async function getGlobalAgent(
throw new Error("Cannot find Global Agent Configuration: no workspace.");
}

const plan = auth.plan();
if (!plan) {
throw new Error("Unexpected `auth` without `plan`.");
}

if (preFetchedDataSources === null) {
const prodCredentials = await prodAPICredentialsForOwner(owner);
const api = new DustAPI(prodCredentials, logger);
Expand All @@ -721,18 +699,18 @@ export async function getGlobalAgent(
agentConfiguration = await _getGPT35TurboGlobalAgent({ settings });
break;
case GLOBAL_AGENTS_SID.GPT4:
agentConfiguration = await _getGPT4GlobalAgent({ plan });
agentConfiguration = await _getGPT4GlobalAgent({ auth });
break;
case GLOBAL_AGENTS_SID.CLAUDE_INSTANT:
agentConfiguration = await _getClaudeInstantGlobalAgent({ settings });
break;
case GLOBAL_AGENTS_SID.CLAUDE:
agentConfiguration = await _getClaudeGlobalAgent({ settings, plan });
agentConfiguration = await _getClaudeGlobalAgent({ auth, settings });
break;
case GLOBAL_AGENTS_SID.MISTRAL_MEDIUM:
agentConfiguration = await _getMistralMediumGlobalAgent({
plan,
settings,
auth,
});
break;
case GLOBAL_AGENTS_SID.MISTRAL_SMALL:
Expand Down Expand Up @@ -766,7 +744,7 @@ export async function getGlobalAgent(
});
break;
case GLOBAL_AGENTS_SID.DUST:
agentConfiguration = await _getDustGlobalAgent(auth, { plan, settings });
agentConfiguration = await _getDustGlobalAgent(auth, { settings });
break;
default:
return null;
Expand Down
3 changes: 3 additions & 0 deletions front/lib/api/assistant/pubsub.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ export async function postUserMessageWithPubSub(
let rateLimitKey: string | undefined = "";
if (auth.user()?.id) {
maxPerTimeframe = 50;
if (auth.isUpgraded()) {
maxPerTimeframe = 200;
}
timeframeSeconds = 60 * 60 * 3;
rateLimitKey = `postUserMessageUser:${auth.user()?.id}`;
} else {
Expand Down
5 changes: 5 additions & 0 deletions front/lib/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
} from "@app/lib/models";
import type { PlanAttributes } from "@app/lib/plans/free_plans";
import { FREE_TEST_PLAN_DATA } from "@app/lib/plans/free_plans";
import { isUpgraded } from "@app/lib/plans/plan_codes";
import { new_id } from "@app/lib/utils";
import logger from "@app/logger/logger";
import { authOptions } from "@app/pages/api/auth/[...nextauth]";
Expand Down Expand Up @@ -328,6 +329,10 @@ export class Authenticator {
return this._subscription ? this._subscription.plan : null;
}

isUpgraded(): boolean {
return isUpgraded(this.plan());
}

/**
* This is a convenience method to get the user from the Authenticator. The returned UserType
* object won't have the user's workspaces set.
Expand Down
14 changes: 14 additions & 0 deletions front/lib/plans/plan_codes.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { PlanType } from "@dust-tt/types";

// Current free plans:
export const FREE_UPGRADED_PLAN_CODE = "FREE_UPGRADED_PLAN";
export const FREE_TEST_PLAN_CODE = "FREE_TEST_PLAN";
Expand All @@ -9,3 +11,15 @@ export const PRO_PLAN_SEAT_29_CODE = "PRO_PLAN_SEAT_29";
* ENT_PLAN_FAKE is not subscribable and is only used to display the Enterprise plan in the UI (hence it's not stored on the db).
*/
export const ENT_PLAN_FAKE_CODE = "ENT_PLAN_FAKE_CODE";

/**
* `isUpgraded` returns true if the plan has access to all features of Dust, including large
* language models (meaning it's either a paid plan or free plan with (eg friends and family, or
* free trial plan)).
*
* Note: We didn't go for isFree or isPayingWorkspace as we have "upgraded" plans that are free.
*/
export const isUpgraded = (plan: PlanType | null): boolean => {
if (!plan) return false;
return plan.code !== FREE_TEST_PLAN_CODE;
};
3 changes: 1 addition & 2 deletions front/pages/api/w/[wId]/data_sources/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import type { NextApiRequest, NextApiResponse } from "next";
import { getDataSources } from "@app/lib/api/data_sources";
import { Authenticator, getSession } from "@app/lib/auth";
import { DataSource } from "@app/lib/models";
import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes";
import logger from "@app/logger/logger";
import { apiError, withLogging } from "@app/logger/withlogging";

Expand Down Expand Up @@ -140,7 +139,7 @@ async function handler(
splitter_id: "base_v0",
max_chunk_size: dataSourceMaxChunkSize,
qdrant_config:
plan.code !== FREE_TEST_PLAN_CODE && NODE_ENV === "production"
auth.isUpgraded() && NODE_ENV === "production"
? {
cluster: "dedicated-1",
shadow_write_cluster: null,
Expand Down
Loading

0 comments on commit 3eb8a27

Please sign in to comment.