mirror of
https://github.com/sourcebot-dev/sourcebot.git
synced 2025-12-12 04:15:30 +00:00
fix(ask_sb): Fix long generation times on first message in thread (#447)
This commit is contained in:
parent
0773399392
commit
4f2644daa2
5 changed files with 489 additions and 408 deletions
|
|
@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
### Fixed
|
||||
- Removed prefix from structured log output. [#443](https://github.com/sourcebot-dev/sourcebot/pull/443)
|
||||
- [ask sb] Fixed long generation times for first message in a chat thread. [#447](https://github.com/sourcebot-dev/sourcebot/pull/447)
|
||||
|
||||
### Changed
|
||||
- Bumped AI SDK and associated packages version. [#444](https://github.com/sourcebot-dev/sourcebot/pull/444)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import { sew, withAuth, withOrgMembership } from "@/actions";
|
||||
import { env } from "@/env.mjs";
|
||||
import { _getConfiguredLanguageModelsFull, updateChatMessages, updateChatName } from "@/features/chat/actions";
|
||||
import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, updateChatMessages } from "@/features/chat/actions";
|
||||
import { createAgentStream } from "@/features/chat/agent";
|
||||
import { additionalChatRequestParamsSchema, SBChatMessage, SearchScope } from "@/features/chat/types";
|
||||
import { getAnswerPartFromAssistantMessage } from "@/features/chat/utils";
|
||||
|
|
@ -8,33 +7,18 @@ import { ErrorCode } from "@/lib/errorCodes";
|
|||
import { notFound, schemaValidationError, serviceErrorResponse } from "@/lib/serviceError";
|
||||
import { isServiceError } from "@/lib/utils";
|
||||
import { prisma } from "@/prisma";
|
||||
import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock';
|
||||
import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic';
|
||||
import { createAzure } from '@ai-sdk/azure';
|
||||
import { createDeepSeek } from '@ai-sdk/deepseek';
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google';
|
||||
import { createVertex } from '@ai-sdk/google-vertex';
|
||||
import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic';
|
||||
import { createMistral } from '@ai-sdk/mistral';
|
||||
import { createOpenAI, OpenAIResponsesProviderOptions } from "@ai-sdk/openai";
|
||||
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
|
||||
import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider";
|
||||
import { createXai } from '@ai-sdk/xai';
|
||||
import { createOpenRouter } from '@openrouter/ai-sdk-provider';
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { getTokenFromConfig } from "@sourcebot/crypto";
|
||||
import { OrgRole } from "@sourcebot/db";
|
||||
import { createLogger } from "@sourcebot/logger";
|
||||
import { LanguageModel } from "@sourcebot/schemas/v3/index.type";
|
||||
import {
|
||||
createUIMessageStream,
|
||||
createUIMessageStreamResponse,
|
||||
generateText,
|
||||
JSONValue,
|
||||
ModelMessage,
|
||||
StreamTextResult,
|
||||
UIMessageStreamOptions,
|
||||
UIMessageStreamWriter,
|
||||
UIMessageStreamWriter
|
||||
} from "ai";
|
||||
import { randomUUID } from "crypto";
|
||||
import { StatusCodes } from "http-status-codes";
|
||||
|
|
@ -66,40 +50,11 @@ export async function POST(req: Request) {
|
|||
}
|
||||
|
||||
const { messages, id, selectedSearchScopes, languageModelId } = parsed.data;
|
||||
const response = await chatHandler({
|
||||
messages,
|
||||
id,
|
||||
selectedSearchScopes,
|
||||
languageModelId,
|
||||
}, domain);
|
||||
|
||||
if (isServiceError(response)) {
|
||||
return serviceErrorResponse(response);
|
||||
}
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
/**
 * Merges `stream` into `writer` and resolves once the merged UI message
 * stream signals that it has finished.
 *
 * Any `onFinish` passed through `options` is replaced by the internal
 * callback that settles the returned promise.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const mergeStreamAsync = async (stream: StreamTextResult<any, any>, writer: UIMessageStreamWriter<SBChatMessage>, options: UIMessageStreamOptions<SBChatMessage> = {}) => {
    await new Promise<void>((done) => {
        const uiMessageStream = stream.toUIMessageStream({
            ...options,
            onFinish: async () => {
                done();
            },
        });

        writer.merge(uiMessageStream);
    });
}
|
||||
|
||||
/** Arguments accepted by `chatHandler`, taken from the parsed request body. */
interface ChatHandlerProps {
    /** Messages of the chat thread; the last entry is the latest message. */
    messages: SBChatMessage[];
    /** Id of the chat the messages belong to. */
    id: string;
    /** Search scopes selected for this request. */
    selectedSearchScopes: SearchScope[];
    /** Id used to look up the configured language model. */
    languageModelId: string;
}
|
||||
|
||||
const chatHandler = ({ messages, id, selectedSearchScopes, languageModelId }: ChatHandlerProps, domain: string) => sew(async () =>
|
||||
const response = await sew(() =>
|
||||
withAuth((userId) =>
|
||||
withOrgMembership(userId, domain, async ({ org }) => {
|
||||
// Validate that the chat exists and is not readonly.
|
||||
const chat = await prisma.chat.findUnique({
|
||||
where: {
|
||||
orgId: org.id,
|
||||
|
|
@ -119,11 +74,6 @@ const chatHandler = ({ messages, id, selectedSearchScopes, languageModelId }: Ch
|
|||
});
|
||||
}
|
||||
|
||||
const latestMessage = messages[messages.length - 1];
|
||||
const sources = latestMessage.parts
|
||||
.filter((part) => part.type === 'data-source')
|
||||
.map((part) => part.data);
|
||||
|
||||
// From the language model ID, attempt to find the
|
||||
// corresponding config in `config.json`.
|
||||
const languageModelConfig =
|
||||
|
|
@ -138,23 +88,68 @@ const chatHandler = ({ messages, id, selectedSearchScopes, languageModelId }: Ch
|
|||
});
|
||||
}
|
||||
|
||||
const { model, providerOptions, headers } = await getAISDKLanguageModelAndOptions(languageModelConfig, org.id);
|
||||
const { model, providerOptions, headers } = await _getAISDKLanguageModelAndOptions(languageModelConfig, org.id);
|
||||
|
||||
if (
|
||||
messages.length === 1 &&
|
||||
messages[0].role === "user" &&
|
||||
messages[0].parts.length >= 1 &&
|
||||
messages[0].parts[0].type === 'text'
|
||||
) {
|
||||
const content = messages[0].parts[0].text;
|
||||
return createMessageStreamResponse({
|
||||
messages,
|
||||
id,
|
||||
selectedSearchScopes,
|
||||
model,
|
||||
modelName: languageModelConfig.displayName ?? languageModelConfig.model,
|
||||
modelProviderOptions: providerOptions,
|
||||
modelHeaders: headers,
|
||||
domain,
|
||||
orgId: org.id,
|
||||
});
|
||||
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true
|
||||
)
|
||||
)
|
||||
|
||||
const title = await generateChatTitle(content, model);
|
||||
await updateChatName({
|
||||
chatId: id,
|
||||
name: title,
|
||||
}, domain);
|
||||
if (isServiceError(response)) {
|
||||
return serviceErrorResponse(response);
|
||||
}
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
/**
 * Merges `stream` into `writer` and resolves once the merged UI message
 * stream signals that it has finished.
 *
 * Any `onFinish` passed through `options` is replaced by the internal
 * callback that settles the returned promise.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const mergeStreamAsync = async (stream: StreamTextResult<any, any>, writer: UIMessageStreamWriter<SBChatMessage>, options: UIMessageStreamOptions<SBChatMessage> = {}) => {
    await new Promise<void>((done) => {
        const uiMessageStream = stream.toUIMessageStream({
            ...options,
            onFinish: async () => {
                done();
            },
        });

        writer.merge(uiMessageStream);
    });
}
|
||||
|
||||
/** Arguments accepted by `createMessageStreamResponse`. */
interface CreateMessageStreamResponseProps {
    /** Messages of the chat thread; the last entry is the latest message. */
    messages: SBChatMessage[];
    /** Id of the chat the messages belong to. */
    id: string;
    /** Search scopes selected for this request. */
    selectedSearchScopes: SearchScope[];
    /** AI SDK language model handed to the agent stream. */
    model: AISDKLanguageModelV2;
    /** Model name reported in the response metadata. */
    modelName: string;
    /** Provider-specific options forwarded to the agent stream as `providerOptions`. */
    modelProviderOptions?: Record<string, Record<string, JSONValue>>;
    /** Extra request headers forwarded to the agent stream as `headers`. */
    modelHeaders?: Record<string, string>;
    /** Organization domain forwarded from the route handler. */
    domain: string;
    /** Organization id used when expanding `reposet` search scopes. */
    orgId: number;
}
|
||||
|
||||
const createMessageStreamResponse = async ({
|
||||
messages,
|
||||
id,
|
||||
selectedSearchScopes,
|
||||
model,
|
||||
modelName,
|
||||
modelProviderOptions,
|
||||
modelHeaders,
|
||||
domain,
|
||||
orgId,
|
||||
}: CreateMessageStreamResponseProps) => {
|
||||
const latestMessage = messages[messages.length - 1];
|
||||
const sources = latestMessage.parts
|
||||
.filter((part) => part.type === 'data-source')
|
||||
.map((part) => part.data);
|
||||
|
||||
const traceId = randomUUID();
|
||||
|
||||
// Extract user messages and assistant answers.
|
||||
|
|
@ -195,7 +190,7 @@ const chatHandler = ({ messages, id, selectedSearchScopes, languageModelId }: Ch
|
|||
if (scope.type === 'reposet') {
|
||||
const reposet = await prisma.searchContext.findFirst({
|
||||
where: {
|
||||
orgId: org.id,
|
||||
orgId,
|
||||
name: scope.value
|
||||
},
|
||||
include: {
|
||||
|
|
@ -214,8 +209,8 @@ const chatHandler = ({ messages, id, selectedSearchScopes, languageModelId }: Ch
|
|||
|
||||
const researchStream = await createAgentStream({
|
||||
model,
|
||||
providerOptions,
|
||||
headers,
|
||||
providerOptions: modelProviderOptions,
|
||||
headers: modelHeaders,
|
||||
inputMessages: messageHistory,
|
||||
inputSources: sources,
|
||||
searchScopeRepoNames: expandedRepos,
|
||||
|
|
@ -243,12 +238,11 @@ const chatHandler = ({ messages, id, selectedSearchScopes, languageModelId }: Ch
|
|||
totalInputTokens: totalUsage.inputTokens,
|
||||
totalOutputTokens: totalUsage.outputTokens,
|
||||
totalResponseTimeMs: new Date().getTime() - startTime.getTime(),
|
||||
modelName: languageModelConfig.displayName ?? languageModelConfig.model,
|
||||
modelName,
|
||||
selectedSearchScopes,
|
||||
traceId,
|
||||
}
|
||||
})
|
||||
|
||||
});
|
||||
|
||||
writer.write({
|
||||
type: 'finish',
|
||||
|
|
@ -267,223 +261,7 @@ const chatHandler = ({ messages, id, selectedSearchScopes, languageModelId }: Ch
|
|||
return createUIMessageStreamResponse({
|
||||
stream,
|
||||
});
|
||||
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true
|
||||
));
|
||||
|
||||
/**
 * Asks the given model to condense a user question into a short topic title.
 *
 * @param message - The user's question; interpolated at the end of the prompt.
 * @param model - AI SDK language model used for the generation.
 * @returns The text produced by the model, unmodified.
 */
const generateChatTitle = async (message: string, model: AISDKLanguageModelV2) => {
    const { text } = await generateText({
        model,
        prompt: `Convert this question into a short topic title (max 50 characters).

Rules:
- Do NOT include question words (what, where, how, why, when, which)
- Do NOT end with a question mark
- Capitalize the first letter of the title
- Focus on the subject/topic being discussed
- Make it sound like a file name or category

Examples:
"Where is the authentication code?" → "Authentication Code"
"How to setup the database?" → "Database Setup"
"What are the API endpoints?" → "API Endpoints"

User question: ${message}`,
    });

    return text;
}
|
||||
|
||||
/**
 * Builds the AI SDK language model for a configured language model entry,
 * together with any provider-specific call options and request headers.
 *
 * Credentials declared in `config` are resolved with `getTokenFromConfig`;
 * when a credential is not declared, the matching `env` variable is used
 * instead (the `openai-compatible` provider has no environment fallback).
 *
 * @param config - Configured language model entry.
 * @param orgId - Organization id passed to `getTokenFromConfig` when resolving credentials.
 * @returns The model instance, plus `providerOptions` / `headers` for the
 * providers that define them (`anthropic`, `google-vertex`, `openai`).
 * @throws Error when `config.provider` is not one of the handled providers.
 */
const getAISDKLanguageModelAndOptions = async (config: LanguageModel, orgId: number): Promise<{
    model: AISDKLanguageModelV2,
    providerOptions?: Record<string, Record<string, JSONValue>>,
    headers?: Record<string, string>,
}> => {
    const { provider, model: modelId } = config;

    // Resolves a credential declared in the config, falling back to the given
    // environment value when the config does not declare one.
    const resolveSecret = async (
        secret: Parameters<typeof getTokenFromConfig>[0] | undefined,
        fallback: string | undefined,
    ): Promise<string | undefined> =>
        secret ? getTokenFromConfig(secret, orgId, prisma) : fallback;

    switch (provider) {
        case 'amazon-bedrock': {
            const aws = createAmazonBedrock({
                baseURL: config.baseUrl,
                region: config.region ?? env.AWS_REGION,
                accessKeyId: await resolveSecret(config.accessKeyId, env.AWS_ACCESS_KEY_ID),
                secretAccessKey: await resolveSecret(config.accessKeySecret, env.AWS_SECRET_ACCESS_KEY),
            });

            return {
                model: aws(modelId),
            };
        }
        case 'anthropic': {
            const anthropic = createAnthropic({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.ANTHROPIC_API_KEY),
            });

            return {
                model: anthropic(modelId),
                providerOptions: {
                    anthropic: {
                        thinking: {
                            type: "enabled",
                            budgetTokens: env.ANTHROPIC_THINKING_BUDGET_TOKENS,
                        }
                    } satisfies AnthropicProviderOptions,
                },
                headers: {
                    // @see: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
                    'anthropic-beta': 'interleaved-thinking-2025-05-14',
                },
            };
        }
        case 'azure': {
            const azure = createAzure({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.AZURE_API_KEY),
                apiVersion: config.apiVersion,
                resourceName: config.resourceName ?? env.AZURE_RESOURCE_NAME,
            });

            return {
                model: azure(modelId),
            };
        }
        case 'deepseek': {
            const deepseek = createDeepSeek({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.DEEPSEEK_API_KEY),
            });

            return {
                model: deepseek(modelId),
            };
        }
        case 'google-generative-ai': {
            const google = createGoogleGenerativeAI({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.GOOGLE_GENERATIVE_AI_API_KEY),
            });

            return {
                model: google(modelId),
            };
        }
        case 'google-vertex': {
            const vertex = createVertex({
                project: config.project ?? env.GOOGLE_VERTEX_PROJECT,
                location: config.region ?? env.GOOGLE_VERTEX_REGION,
                // Only pass auth options when credentials are declared in the config.
                ...(config.credentials ? {
                    googleAuthOptions: {
                        keyFilename: await getTokenFromConfig(config.credentials, orgId, prisma),
                    }
                } : {}),
            });

            return {
                model: vertex(modelId),
                providerOptions: {
                    google: {
                        thinkingConfig: {
                            thinkingBudget: env.GOOGLE_VERTEX_THINKING_BUDGET_TOKENS,
                            includeThoughts: env.GOOGLE_VERTEX_INCLUDE_THOUGHTS === 'true',
                        }
                    }
                },
            };
        }
        case 'google-vertex-anthropic': {
            const vertexAnthropic = createVertexAnthropic({
                project: config.project ?? env.GOOGLE_VERTEX_PROJECT,
                location: config.region ?? env.GOOGLE_VERTEX_REGION,
                // Only pass auth options when credentials are declared in the config.
                ...(config.credentials ? {
                    googleAuthOptions: {
                        keyFilename: await getTokenFromConfig(config.credentials, orgId, prisma),
                    }
                } : {}),
            });

            return {
                model: vertexAnthropic(modelId),
            };
        }
        case 'mistral': {
            const mistral = createMistral({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.MISTRAL_API_KEY),
            });

            return {
                model: mistral(modelId),
            };
        }
        case 'openai': {
            const openai = createOpenAI({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.OPENAI_API_KEY),
            });

            return {
                model: openai(modelId),
                providerOptions: {
                    openai: {
                        reasoningEffort: config.reasoningEffort ?? 'medium',
                    } satisfies OpenAIResponsesProviderOptions,
                },
            };
        }
        case 'openai-compatible': {
            const openai = createOpenAICompatible({
                baseURL: config.baseUrl,
                name: config.displayName ?? modelId,
                // No environment fallback for this provider.
                apiKey: await resolveSecret(config.token, undefined),
            });

            return {
                model: openai.chatModel(modelId),
            };
        }
        case 'openrouter': {
            const openrouter = createOpenRouter({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.OPENROUTER_API_KEY),
            });

            return {
                model: openrouter(modelId),
            };
        }
        case 'xai': {
            const xai = createXai({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.XAI_API_KEY),
            });

            return {
                model: xai(modelId),
            };
        }
        default: {
            // Fail loudly on a provider value this switch does not handle,
            // rather than resolving to `undefined` and breaking callers that
            // destructure the result.
            throw new Error(`Unsupported language model provider: ${String(provider)}`);
        }
    }
}
|
||||
};
|
||||
|
||||
const errorHandler = (error: unknown) => {
|
||||
logger.error(error);
|
||||
|
|
|
|||
|
|
@ -2,17 +2,32 @@
|
|||
|
||||
import { sew, withAuth, withOrgMembership } from "@/actions";
|
||||
import { env } from "@/env.mjs";
|
||||
import { chatIsReadonly, notFound, ServiceError } from "@/lib/serviceError";
|
||||
import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants";
|
||||
import { ErrorCode } from "@/lib/errorCodes";
|
||||
import { chatIsReadonly, notFound, ServiceError, serviceErrorResponse } from "@/lib/serviceError";
|
||||
import { prisma } from "@/prisma";
|
||||
import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock';
|
||||
import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic';
|
||||
import { createAzure } from '@ai-sdk/azure';
|
||||
import { createDeepSeek } from '@ai-sdk/deepseek';
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google';
|
||||
import { createVertex } from '@ai-sdk/google-vertex';
|
||||
import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic';
|
||||
import { createMistral } from '@ai-sdk/mistral';
|
||||
import { createOpenAI, OpenAIResponsesProviderOptions } from "@ai-sdk/openai";
|
||||
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
|
||||
import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider";
|
||||
import { createXai } from '@ai-sdk/xai';
|
||||
import { createOpenRouter } from '@openrouter/ai-sdk-provider';
|
||||
import { getTokenFromConfig } from "@sourcebot/crypto";
|
||||
import { ChatVisibility, OrgRole, Prisma } from "@sourcebot/db";
|
||||
import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type";
|
||||
import { loadConfig } from "@sourcebot/shared";
|
||||
import { generateText, JSONValue } from "ai";
|
||||
import fs from 'fs';
|
||||
import { StatusCodes } from "http-status-codes";
|
||||
import path from 'path';
|
||||
import { LanguageModelInfo, SBChatMessage } from "./types";
|
||||
import { loadConfig } from "@sourcebot/shared";
|
||||
import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type";
|
||||
import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants";
|
||||
import { StatusCodes } from "http-status-codes";
|
||||
import { ErrorCode } from "@/lib/errorCodes";
|
||||
|
||||
export const createChat = async (domain: string) => sew(() =>
|
||||
withAuth((userId) =>
|
||||
|
|
@ -170,6 +185,58 @@ export const updateChatName = async ({ chatId, name }: { chatId: string, name: s
|
|||
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
|
||||
);
|
||||
|
||||
/**
 * Generates a short title for a chat from a user message and stores it as the
 * chat's name.
 *
 * The title is produced by the language model identified by `languageModelId`
 * and persisted through `updateChatName`.
 *
 * @param params - `chatId`: id of the chat to rename; `languageModelId`:
 * `model` value of the configured language model to use; `message`: user
 * message the title is derived from.
 * @param domain - Organization domain used for the membership check.
 * @returns `{ success: true }` once the name has been updated, or a
 * `ServiceError` when the requested language model is not configured.
 */
export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }, domain: string) => sew(() =>
    withAuth((userId) =>
        withOrgMembership(userId, domain, async ({ org }) => {
            // From the language model ID, attempt to find the
            // corresponding config in `config.json`.
            const configuredModels = await _getConfiguredLanguageModelsFull();
            const languageModelConfig = configuredModels.find((model) => model.model === languageModelId);

            if (!languageModelConfig) {
                // Return a plain `ServiceError` rather than an HTTP response:
                // callers of this action detect failures with `isServiceError`.
                return {
                    statusCode: StatusCodes.BAD_REQUEST,
                    errorCode: ErrorCode.INVALID_REQUEST_BODY,
                    message: `Language model ${languageModelId} is not configured.`,
                } satisfies ServiceError;
            }

            const { model } = await _getAISDKLanguageModelAndOptions(languageModelConfig, org.id);

            const prompt = `Convert this question into a short topic title (max 50 characters).

Rules:
- Do NOT include question words (what, where, how, why, when, which)
- Do NOT end with a question mark
- Capitalize the first letter of the title
- Focus on the subject/topic being discussed
- Make it sound like a file name or category

Examples:
"Where is the authentication code?" → "Authentication Code"
"How to setup the database?" → "Database Setup"
"What are the API endpoints?" → "API Endpoints"

User question: ${message}`;

            const result = await generateText({
                model,
                prompt,
            });

            // NOTE(review): the result of `updateChatName` is not inspected
            // here — confirm whether a failure there should be surfaced.
            await updateChatName({
                chatId,
                name: result.text,
            }, domain);

            return {
                success: true,
            };
        })
    )
);
|
||||
|
||||
export const deleteChat = async ({ chatId }: { chatId: string }, domain: string) => sew(() =>
|
||||
withAuth((userId) =>
|
||||
withOrgMembership(userId, domain, async ({ org }) => {
|
||||
|
|
@ -303,3 +370,193 @@ export const _getConfiguredLanguageModelsFull = async (): Promise<LanguageModel[
|
|||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
 * Builds the AI SDK language model for a configured language model entry,
 * together with any provider-specific call options and request headers.
 *
 * Credentials declared in `config` are resolved with `getTokenFromConfig`;
 * when a credential is not declared, the matching `env` variable is used
 * instead (the `openai-compatible` provider has no environment fallback).
 *
 * @param config - Configured language model entry, e.g. one returned by
 * `_getConfiguredLanguageModelsFull`.
 * @param orgId - Organization id passed to `getTokenFromConfig` when resolving credentials.
 * @returns The model instance, plus `providerOptions` / `headers` for the
 * providers that define them (`anthropic`, `google-vertex`, `openai`).
 * @throws Error when `config.provider` is not one of the handled providers.
 */
export const _getAISDKLanguageModelAndOptions = async (config: LanguageModel, orgId: number): Promise<{
    model: AISDKLanguageModelV2,
    providerOptions?: Record<string, Record<string, JSONValue>>,
    headers?: Record<string, string>,
}> => {
    const { provider, model: modelId } = config;

    // Resolves a credential declared in the config, falling back to the given
    // environment value when the config does not declare one.
    const resolveSecret = async (
        secret: Parameters<typeof getTokenFromConfig>[0] | undefined,
        fallback: string | undefined,
    ): Promise<string | undefined> =>
        secret ? getTokenFromConfig(secret, orgId, prisma) : fallback;

    switch (provider) {
        case 'amazon-bedrock': {
            const aws = createAmazonBedrock({
                baseURL: config.baseUrl,
                region: config.region ?? env.AWS_REGION,
                accessKeyId: await resolveSecret(config.accessKeyId, env.AWS_ACCESS_KEY_ID),
                secretAccessKey: await resolveSecret(config.accessKeySecret, env.AWS_SECRET_ACCESS_KEY),
            });

            return {
                model: aws(modelId),
            };
        }
        case 'anthropic': {
            const anthropic = createAnthropic({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.ANTHROPIC_API_KEY),
            });

            return {
                model: anthropic(modelId),
                providerOptions: {
                    anthropic: {
                        thinking: {
                            type: "enabled",
                            budgetTokens: env.ANTHROPIC_THINKING_BUDGET_TOKENS,
                        }
                    } satisfies AnthropicProviderOptions,
                },
                headers: {
                    // @see: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
                    'anthropic-beta': 'interleaved-thinking-2025-05-14',
                },
            };
        }
        case 'azure': {
            const azure = createAzure({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.AZURE_API_KEY),
                apiVersion: config.apiVersion,
                resourceName: config.resourceName ?? env.AZURE_RESOURCE_NAME,
            });

            return {
                model: azure(modelId),
            };
        }
        case 'deepseek': {
            const deepseek = createDeepSeek({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.DEEPSEEK_API_KEY),
            });

            return {
                model: deepseek(modelId),
            };
        }
        case 'google-generative-ai': {
            const google = createGoogleGenerativeAI({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.GOOGLE_GENERATIVE_AI_API_KEY),
            });

            return {
                model: google(modelId),
            };
        }
        case 'google-vertex': {
            const vertex = createVertex({
                project: config.project ?? env.GOOGLE_VERTEX_PROJECT,
                location: config.region ?? env.GOOGLE_VERTEX_REGION,
                // Only pass auth options when credentials are declared in the config.
                ...(config.credentials ? {
                    googleAuthOptions: {
                        keyFilename: await getTokenFromConfig(config.credentials, orgId, prisma),
                    }
                } : {}),
            });

            return {
                model: vertex(modelId),
                providerOptions: {
                    google: {
                        thinkingConfig: {
                            thinkingBudget: env.GOOGLE_VERTEX_THINKING_BUDGET_TOKENS,
                            includeThoughts: env.GOOGLE_VERTEX_INCLUDE_THOUGHTS === 'true',
                        }
                    }
                },
            };
        }
        case 'google-vertex-anthropic': {
            const vertexAnthropic = createVertexAnthropic({
                project: config.project ?? env.GOOGLE_VERTEX_PROJECT,
                location: config.region ?? env.GOOGLE_VERTEX_REGION,
                // Only pass auth options when credentials are declared in the config.
                ...(config.credentials ? {
                    googleAuthOptions: {
                        keyFilename: await getTokenFromConfig(config.credentials, orgId, prisma),
                    }
                } : {}),
            });

            return {
                model: vertexAnthropic(modelId),
            };
        }
        case 'mistral': {
            const mistral = createMistral({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.MISTRAL_API_KEY),
            });

            return {
                model: mistral(modelId),
            };
        }
        case 'openai': {
            const openai = createOpenAI({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.OPENAI_API_KEY),
            });

            return {
                model: openai(modelId),
                providerOptions: {
                    openai: {
                        reasoningEffort: config.reasoningEffort ?? 'medium',
                    } satisfies OpenAIResponsesProviderOptions,
                },
            };
        }
        case 'openai-compatible': {
            const openai = createOpenAICompatible({
                baseURL: config.baseUrl,
                name: config.displayName ?? modelId,
                // No environment fallback for this provider.
                apiKey: await resolveSecret(config.token, undefined),
            });

            return {
                model: openai.chatModel(modelId),
            };
        }
        case 'openrouter': {
            const openrouter = createOpenRouter({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.OPENROUTER_API_KEY),
            });

            return {
                model: openrouter(modelId),
            };
        }
        case 'xai': {
            const xai = createXai({
                baseURL: config.baseUrl,
                apiKey: await resolveSecret(config.token, env.XAI_API_KEY),
            });

            return {
                model: xai(modelId),
            };
        }
        default: {
            // Fail loudly on a provider value this switch does not handle,
            // rather than resolving to `undefined` and breaking callers that
            // destructure the result.
            throw new Error(`Unsupported language model provider: ${String(provider)}`);
        }
    }
}
|
||||
|
|
@ -23,6 +23,8 @@ import { ErrorBanner } from './errorBanner';
|
|||
import { useRouter } from 'next/navigation';
|
||||
import { usePrevious } from '@uidotdev/usehooks';
|
||||
import { RepositoryQuery, SearchContextQuery } from '@/lib/types';
|
||||
import { generateAndUpdateChatNameFromMessage } from '../../actions';
|
||||
import { isServiceError } from '@/lib/utils';
|
||||
|
||||
type ChatHistoryState = {
|
||||
scrollOffset?: number;
|
||||
|
|
@ -119,14 +121,57 @@ export const ChatThread = ({
|
|||
languageModelId: selectedLanguageModel.model,
|
||||
} satisfies AdditionalChatRequestParams,
|
||||
});
|
||||
}, [_sendMessage, selectedLanguageModel, toast, selectedSearchScopes]);
|
||||
|
||||
if (
|
||||
messages.length === 0 &&
|
||||
message.parts.length > 0 &&
|
||||
message.parts[0].type === 'text'
|
||||
) {
|
||||
generateAndUpdateChatNameFromMessage(
|
||||
{
|
||||
chatId,
|
||||
languageModelId: selectedLanguageModel.model,
|
||||
message: message.parts[0].text,
|
||||
},
|
||||
domain
|
||||
).then((response) => {
|
||||
if (isServiceError(response)) {
|
||||
toast({
|
||||
description: `❌ Failed to generate chat name. Reason: ${response.message}`,
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
// Refresh the page to update the chat name.
|
||||
router.refresh();
|
||||
});
|
||||
}
|
||||
}, [
|
||||
selectedLanguageModel,
|
||||
_sendMessage,
|
||||
selectedSearchScopes,
|
||||
messages.length,
|
||||
toast,
|
||||
chatId,
|
||||
domain,
|
||||
router,
|
||||
]);
|
||||
|
||||
|
||||
const messagePairs = useMessagePairs(messages);
|
||||
|
||||
useNavigationGuard({
|
||||
enabled: status === "streaming" || status === "submitted",
|
||||
confirm: () => window.confirm("You have unsaved changes that will be lost.")
|
||||
enabled: ({ type }) => {
|
||||
// @note: a "refresh" in this context means we have triggered a client side
|
||||
// refresh via `router.refresh()`, and not the user pressing "CMD+R"
|
||||
// (that would be a "beforeunload" event). We can safely perform refreshes
|
||||
// without losing any unsaved changes.
|
||||
if (type === "refresh") {
|
||||
return false;
|
||||
}
|
||||
|
||||
return status === "streaming" || status === "submitted";
|
||||
},
|
||||
confirm: () => window.confirm("You have unsaved changes that will be lost."),
|
||||
});
|
||||
|
||||
// When the chat is finished, refresh the page to update the chat history.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import { CreateUIMessage, TextUIPart, UIMessagePart } from "ai"
|
||||
import { Descendant, Editor, Point, Range, Transforms } from "slate"
|
||||
import { ANSWER_TAG, FILE_REFERENCE_PREFIX, FILE_REFERENCE_REGEX } from "./constants"
|
||||
import { CreateUIMessage, TextUIPart, UIMessagePart } from "ai";
|
||||
import { Descendant, Editor, Point, Range, Transforms } from "slate";
|
||||
import { ANSWER_TAG, FILE_REFERENCE_PREFIX, FILE_REFERENCE_REGEX } from "./constants";
|
||||
import {
|
||||
CustomEditor,
|
||||
CustomText,
|
||||
|
|
@ -14,7 +14,7 @@ import {
|
|||
SBChatMessageToolTypes,
|
||||
SearchScope,
|
||||
Source,
|
||||
} from "./types"
|
||||
} from "./types";
|
||||
|
||||
export const insertMention = (editor: CustomEditor, data: MentionData, target?: Range | null) => {
|
||||
const mention: MentionElement = {
|
||||
|
|
|
|||
Loading…
Reference in a new issue