fix(ask_sb): Fix long generation times on first message in thread (#447)

Brendan Kellam 2025-08-07 21:56:56 -07:00 committed by GitHub
parent 0773399392
commit 4f2644daa2
5 changed files with 489 additions and 408 deletions

@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Removed prefix from structured log output. [#443](https://github.com/sourcebot-dev/sourcebot/pull/443)
- [ask sb] Fixed long generation times for first message in a chat thread. [#447](https://github.com/sourcebot-dev/sourcebot/pull/447)
### Changed
- Bumped AI SDK and associated packages version. [#444](https://github.com/sourcebot-dev/sourcebot/pull/444)
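
The fix itself: previously the chat route handler awaited a `generateText` call to name a brand-new thread (via `generateChatTitle` and `updateChatName`) before it ever started the agent stream, so the first message in a thread paid the full title-generation latency up front. The diff below moves that work into a standalone server action, `generateAndUpdateChatNameFromMessage`, which the client fires off alongside sending the first message while the route handler starts streaming immediately. A minimal sketch of the pattern, with illustrative helper names that are not the actual Sourcebot APIs:

```typescript
// Sketch: run the slow, non-critical side task (title generation) in the
// background instead of awaiting it on the request's critical path.
// `sendMessage`, `generateTitle`, and `onTitle` are hypothetical stand-ins.

type SendMessage = (text: string) => Promise<void>;
type GenerateTitle = (text: string) => Promise<string>;

async function sendFirstMessage(
    text: string,
    sendMessage: SendMessage,
    generateTitle: GenerateTitle,
    onTitle: (title: string) => void,
) {
    // Before: the handler awaited title generation, delaying the stream.
    // After: kick it off and let it resolve whenever it finishes.
    void generateTitle(text)
        .then(onTitle)
        .catch((err) => console.error("title generation failed", err));

    // The user-visible response starts immediately.
    await sendMessage(text);
}
```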

@@ -1,6 +1,5 @@
import { sew, withAuth, withOrgMembership } from "@/actions";
import { env } from "@/env.mjs";
import { _getConfiguredLanguageModelsFull, updateChatMessages, updateChatName } from "@/features/chat/actions";
import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, updateChatMessages } from "@/features/chat/actions";
import { createAgentStream } from "@/features/chat/agent";
import { additionalChatRequestParamsSchema, SBChatMessage, SearchScope } from "@/features/chat/types";
import { getAnswerPartFromAssistantMessage } from "@/features/chat/utils";
@@ -8,33 +7,18 @@ import { ErrorCode } from "@/lib/errorCodes";
import { notFound, schemaValidationError, serviceErrorResponse } from "@/lib/serviceError";
import { isServiceError } from "@/lib/utils";
import { prisma } from "@/prisma";
import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock';
import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic';
import { createAzure } from '@ai-sdk/azure';
import { createDeepSeek } from '@ai-sdk/deepseek';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { createVertex } from '@ai-sdk/google-vertex';
import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic';
import { createMistral } from '@ai-sdk/mistral';
import { createOpenAI, OpenAIResponsesProviderOptions } from "@ai-sdk/openai";
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider";
import { createXai } from '@ai-sdk/xai';
import { createOpenRouter } from '@openrouter/ai-sdk-provider';
import * as Sentry from "@sentry/nextjs";
import { getTokenFromConfig } from "@sourcebot/crypto";
import { OrgRole } from "@sourcebot/db";
import { createLogger } from "@sourcebot/logger";
import { LanguageModel } from "@sourcebot/schemas/v3/index.type";
import {
createUIMessageStream,
createUIMessageStreamResponse,
generateText,
JSONValue,
ModelMessage,
StreamTextResult,
UIMessageStreamOptions,
UIMessageStreamWriter
} from "ai";
import { randomUUID } from "crypto";
import { StatusCodes } from "http-status-codes";
@@ -66,12 +50,60 @@ export async function POST(req: Request) {
}
const { messages, id, selectedSearchScopes, languageModelId } = parsed.data;
const response = await chatHandler({
messages,
id,
selectedSearchScopes,
languageModelId,
}, domain);
const response = await sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
// Validate that the chat exists and is not readonly.
const chat = await prisma.chat.findUnique({
where: {
orgId: org.id,
id,
},
});
if (!chat) {
return notFound();
}
if (chat.isReadonly) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: "Chat is readonly and cannot be edited.",
});
}
// From the language model ID, attempt to find the
// corresponding config in `config.json`.
const languageModelConfig =
(await _getConfiguredLanguageModelsFull())
.find((model) => model.model === languageModelId);
if (!languageModelConfig) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: `Language model ${languageModelId} is not configured.`,
});
}
const { model, providerOptions, headers } = await _getAISDKLanguageModelAndOptions(languageModelConfig, org.id);
return createMessageStreamResponse({
messages,
id,
selectedSearchScopes,
model,
modelName: languageModelConfig.displayName ?? languageModelConfig.model,
modelProviderOptions: providerOptions,
modelHeaders: headers,
domain,
orgId: org.id,
});
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true
)
)
if (isServiceError(response)) {
return serviceErrorResponse(response);
@@ -90,400 +122,146 @@ const mergeStreamAsync = async (stream: StreamTextResult<any, any>, writer: UIMe
})));
}
interface ChatHandlerProps {
interface CreateMessageStreamResponseProps {
messages: SBChatMessage[];
id: string;
selectedSearchScopes: SearchScope[];
languageModelId: string;
model: AISDKLanguageModelV2;
modelName: string;
modelProviderOptions?: Record<string, Record<string, JSONValue>>;
modelHeaders?: Record<string, string>;
domain: string;
orgId: number;
}
const chatHandler = ({ messages, id, selectedSearchScopes, languageModelId }: ChatHandlerProps, domain: string) => sew(async () =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
const chat = await prisma.chat.findUnique({
where: {
orgId: org.id,
id,
},
const createMessageStreamResponse = async ({
messages,
id,
selectedSearchScopes,
model,
modelName,
modelProviderOptions,
modelHeaders,
domain,
orgId,
}: CreateMessageStreamResponseProps) => {
const latestMessage = messages[messages.length - 1];
const sources = latestMessage.parts
.filter((part) => part.type === 'data-source')
.map((part) => part.data);
const traceId = randomUUID();
// Extract user messages and assistant answers.
// We will use this as the context we carry between messages.
const messageHistory =
messages.map((message): ModelMessage | undefined => {
if (message.role === 'user') {
return {
role: 'user',
content: message.parts[0].type === 'text' ? message.parts[0].text : '',
};
}
if (message.role === 'assistant') {
const answerPart = getAnswerPartFromAssistantMessage(message, false);
if (answerPart) {
return {
role: 'assistant',
content: [answerPart]
}
}
}
}).filter(message => message !== undefined);
const stream = createUIMessageStream<SBChatMessage>({
execute: async ({ writer }) => {
writer.write({
type: 'start',
});
if (!chat) {
return notFound();
}
const startTime = new Date();
if (chat.isReadonly) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: "Chat is readonly and cannot be edited.",
});
}
const expandedReposArrays = await Promise.all(selectedSearchScopes.map(async (scope) => {
if (scope.type === 'repo') {
return [scope.value];
}
const latestMessage = messages[messages.length - 1];
const sources = latestMessage.parts
.filter((part) => part.type === 'data-source')
.map((part) => part.data);
// From the language model ID, attempt to find the
// corresponding config in `config.json`.
const languageModelConfig =
(await _getConfiguredLanguageModelsFull())
.find((model) => model.model === languageModelId);
if (!languageModelConfig) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: `Language model ${languageModelId} is not configured.`,
});
}
const { model, providerOptions, headers } = await getAISDKLanguageModelAndOptions(languageModelConfig, org.id);
if (
messages.length === 1 &&
messages[0].role === "user" &&
messages[0].parts.length >= 1 &&
messages[0].parts[0].type === 'text'
) {
const content = messages[0].parts[0].text;
const title = await generateChatTitle(content, model);
await updateChatName({
chatId: id,
name: title,
}, domain);
}
const traceId = randomUUID();
// Extract user messages and assistant answers.
// We will use this as the context we carry between messages.
const messageHistory =
messages.map((message): ModelMessage | undefined => {
if (message.role === 'user') {
return {
role: 'user',
content: message.parts[0].type === 'text' ? message.parts[0].text : '',
};
}
if (message.role === 'assistant') {
const answerPart = getAnswerPartFromAssistantMessage(message, false);
if (answerPart) {
return {
role: 'assistant',
content: [answerPart]
}
}
}
}).filter(message => message !== undefined);
const stream = createUIMessageStream<SBChatMessage>({
execute: async ({ writer }) => {
writer.write({
type: 'start',
});
const startTime = new Date();
const expandedReposArrays = await Promise.all(selectedSearchScopes.map(async (scope) => {
if (scope.type === 'repo') {
return [scope.value];
}
if (scope.type === 'reposet') {
const reposet = await prisma.searchContext.findFirst({
where: {
orgId: org.id,
name: scope.value
},
include: {
repos: true
}
});
if (reposet) {
return reposet.repos.map(repo => repo.name);
}
}
return [];
}));
const expandedRepos = expandedReposArrays.flat();
const researchStream = await createAgentStream({
model,
providerOptions,
headers,
inputMessages: messageHistory,
inputSources: sources,
searchScopeRepoNames: expandedRepos,
onWriteSource: (source) => {
writer.write({
type: 'data-source',
data: source,
});
if (scope.type === 'reposet') {
const reposet = await prisma.searchContext.findFirst({
where: {
orgId,
name: scope.value
},
traceId,
});
await mergeStreamAsync(researchStream, writer, {
sendReasoning: true,
sendStart: false,
sendFinish: false,
});
const totalUsage = await researchStream.totalUsage;
writer.write({
type: 'message-metadata',
messageMetadata: {
totalTokens: totalUsage.totalTokens,
totalInputTokens: totalUsage.inputTokens,
totalOutputTokens: totalUsage.outputTokens,
totalResponseTimeMs: new Date().getTime() - startTime.getTime(),
modelName: languageModelConfig.displayName ?? languageModelConfig.model,
selectedSearchScopes,
traceId,
include: {
repos: true
}
})
});
if (reposet) {
return reposet.repos.map(repo => repo.name);
}
}
return [];
}));
const expandedRepos = expandedReposArrays.flat();
const researchStream = await createAgentStream({
model,
providerOptions: modelProviderOptions,
headers: modelHeaders,
inputMessages: messageHistory,
inputSources: sources,
searchScopeRepoNames: expandedRepos,
onWriteSource: (source) => {
writer.write({
type: 'finish',
type: 'data-source',
data: source,
});
},
onError: errorHandler,
originalMessages: messages,
onFinish: async ({ messages }) => {
await updateChatMessages({
chatId: id,
messages
}, domain);
},
traceId,
});
return createUIMessageStreamResponse({
stream,
await mergeStreamAsync(researchStream, writer, {
sendReasoning: true,
sendStart: false,
sendFinish: false,
});
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true
));
const generateChatTitle = async (message: string, model: AISDKLanguageModelV2) => {
const prompt = `Convert this question into a short topic title (max 50 characters).
const totalUsage = await researchStream.totalUsage;
Rules:
- Do NOT include question words (what, where, how, why, when, which)
- Do NOT end with a question mark
- Capitalize the first letter of the title
- Focus on the subject/topic being discussed
- Make it sound like a file name or category
writer.write({
type: 'message-metadata',
messageMetadata: {
totalTokens: totalUsage.totalTokens,
totalInputTokens: totalUsage.inputTokens,
totalOutputTokens: totalUsage.outputTokens,
totalResponseTimeMs: new Date().getTime() - startTime.getTime(),
modelName,
selectedSearchScopes,
traceId,
}
});
Examples:
"Where is the authentication code?" "Authentication Code"
"How to setup the database?" "Database Setup"
"What are the API endpoints?" "API Endpoints"
User question: ${message}`;
const result = await generateText({
model,
prompt,
writer.write({
type: 'finish',
});
},
onError: errorHandler,
originalMessages: messages,
onFinish: async ({ messages }) => {
await updateChatMessages({
chatId: id,
messages
}, domain);
},
});
return result.text;
}
const getAISDKLanguageModelAndOptions = async (config: LanguageModel, orgId: number): Promise<{
model: AISDKLanguageModelV2,
providerOptions?: Record<string, Record<string, JSONValue>>,
headers?: Record<string, string>,
}> => {
const { provider, model: modelId } = config;
switch (provider) {
case 'amazon-bedrock': {
const aws = createAmazonBedrock({
baseURL: config.baseUrl,
region: config.region ?? env.AWS_REGION,
accessKeyId: config.accessKeyId
? await getTokenFromConfig(config.accessKeyId, orgId, prisma)
: env.AWS_ACCESS_KEY_ID,
secretAccessKey: config.accessKeySecret
? await getTokenFromConfig(config.accessKeySecret, orgId, prisma)
: env.AWS_SECRET_ACCESS_KEY,
});
return {
model: aws(modelId),
};
}
case 'anthropic': {
const anthropic = createAnthropic({
baseURL: config.baseUrl,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: env.ANTHROPIC_API_KEY,
});
return {
model: anthropic(modelId),
providerOptions: {
anthropic: {
thinking: {
type: "enabled",
budgetTokens: env.ANTHROPIC_THINKING_BUDGET_TOKENS,
}
} satisfies AnthropicProviderOptions,
},
headers: {
// @see: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
'anthropic-beta': 'interleaved-thinking-2025-05-14',
},
};
}
case 'azure': {
const azure = createAzure({
baseURL: config.baseUrl,
apiKey: config.token ? (await getTokenFromConfig(config.token, orgId, prisma)) : env.AZURE_API_KEY,
apiVersion: config.apiVersion,
resourceName: config.resourceName ?? env.AZURE_RESOURCE_NAME,
});
return {
model: azure(modelId),
};
}
case 'deepseek': {
const deepseek = createDeepSeek({
baseURL: config.baseUrl,
apiKey: config.token ? (await getTokenFromConfig(config.token, orgId, prisma)) : env.DEEPSEEK_API_KEY,
});
return {
model: deepseek(modelId),
};
}
case 'google-generative-ai': {
const google = createGoogleGenerativeAI({
baseURL: config.baseUrl,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: env.GOOGLE_GENERATIVE_AI_API_KEY,
});
return {
model: google(modelId),
};
}
case 'google-vertex': {
const vertex = createVertex({
project: config.project ?? env.GOOGLE_VERTEX_PROJECT,
location: config.region ?? env.GOOGLE_VERTEX_REGION,
...(config.credentials ? {
googleAuthOptions: {
keyFilename: await getTokenFromConfig(config.credentials, orgId, prisma),
}
} : {}),
});
return {
model: vertex(modelId),
providerOptions: {
google: {
thinkingConfig: {
thinkingBudget: env.GOOGLE_VERTEX_THINKING_BUDGET_TOKENS,
includeThoughts: env.GOOGLE_VERTEX_INCLUDE_THOUGHTS === 'true',
}
}
},
};
}
case 'google-vertex-anthropic': {
const vertexAnthropic = createVertexAnthropic({
project: config.project ?? env.GOOGLE_VERTEX_PROJECT,
location: config.region ?? env.GOOGLE_VERTEX_REGION,
...(config.credentials ? {
googleAuthOptions: {
keyFilename: await getTokenFromConfig(config.credentials, orgId, prisma),
}
} : {}),
});
return {
model: vertexAnthropic(modelId),
};
}
case 'mistral': {
const mistral = createMistral({
baseURL: config.baseUrl,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: env.MISTRAL_API_KEY,
});
return {
model: mistral(modelId),
};
}
case 'openai': {
const openai = createOpenAI({
baseURL: config.baseUrl,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: env.OPENAI_API_KEY,
});
return {
model: openai(modelId),
providerOptions: {
openai: {
reasoningEffort: config.reasoningEffort ?? 'medium'
} satisfies OpenAIResponsesProviderOptions,
},
};
}
case 'openai-compatible': {
const openai = createOpenAICompatible({
baseURL: config.baseUrl,
name: config.displayName ?? modelId,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: undefined,
});
return {
model: openai.chatModel(modelId),
}
}
case 'openrouter': {
const openrouter = createOpenRouter({
baseURL: config.baseUrl,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: env.OPENROUTER_API_KEY,
});
return {
model: openrouter(modelId),
};
}
case 'xai': {
const xai = createXai({
baseURL: config.baseUrl,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: env.XAI_API_KEY,
});
return {
model: xai(modelId),
};
}
}
}
return createUIMessageStreamResponse({
stream,
});
};
const errorHandler = (error: unknown) => {
logger.error(error);

@@ -2,17 +2,32 @@
import { sew, withAuth, withOrgMembership } from "@/actions";
import { env } from "@/env.mjs";
import { chatIsReadonly, notFound, ServiceError } from "@/lib/serviceError";
import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants";
import { ErrorCode } from "@/lib/errorCodes";
import { chatIsReadonly, notFound, ServiceError, serviceErrorResponse } from "@/lib/serviceError";
import { prisma } from "@/prisma";
import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock';
import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic';
import { createAzure } from '@ai-sdk/azure';
import { createDeepSeek } from '@ai-sdk/deepseek';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { createVertex } from '@ai-sdk/google-vertex';
import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic';
import { createMistral } from '@ai-sdk/mistral';
import { createOpenAI, OpenAIResponsesProviderOptions } from "@ai-sdk/openai";
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider";
import { createXai } from '@ai-sdk/xai';
import { createOpenRouter } from '@openrouter/ai-sdk-provider';
import { getTokenFromConfig } from "@sourcebot/crypto";
import { ChatVisibility, OrgRole, Prisma } from "@sourcebot/db";
import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type";
import { loadConfig } from "@sourcebot/shared";
import { generateText, JSONValue } from "ai";
import fs from 'fs';
import { StatusCodes } from "http-status-codes";
import path from 'path';
import { LanguageModelInfo, SBChatMessage } from "./types";
import { loadConfig } from "@sourcebot/shared";
import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type";
import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants";
import { StatusCodes } from "http-status-codes";
import { ErrorCode } from "@/lib/errorCodes";
export const createChat = async (domain: string) => sew(() =>
withAuth((userId) =>
@@ -170,6 +185,58 @@ export const updateChatName = async ({ chatId, name }: { chatId: string, name: s
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
);
export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }, domain: string) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
// From the language model ID, attempt to find the
// corresponding config in `config.json`.
const languageModelConfig =
(await _getConfiguredLanguageModelsFull())
.find((model) => model.model === languageModelId);
if (!languageModelConfig) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: `Language model ${languageModelId} is not configured.`,
});
}
const { model } = await _getAISDKLanguageModelAndOptions(languageModelConfig, org.id);
const prompt = `Convert this question into a short topic title (max 50 characters).
Rules:
- Do NOT include question words (what, where, how, why, when, which)
- Do NOT end with a question mark
- Capitalize the first letter of the title
- Focus on the subject/topic being discussed
- Make it sound like a file name or category
Examples:
"Where is the authentication code?" "Authentication Code"
"How to setup the database?" "Database Setup"
"What are the API endpoints?" "API Endpoints"
User question: ${message}`;
const result = await generateText({
model,
prompt,
});
await updateChatName({
chatId,
name: result.text,
}, domain);
return {
success: true,
}
})
)
);
export const deleteChat = async ({ chatId }: { chatId: string }, domain: string) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
@@ -303,3 +370,193 @@ export const _getConfiguredLanguageModelsFull = async (): Promise<LanguageModel[
return [];
}
}
export const _getAISDKLanguageModelAndOptions = async (config: LanguageModel, orgId: number): Promise<{
model: AISDKLanguageModelV2,
providerOptions?: Record<string, Record<string, JSONValue>>,
headers?: Record<string, string>,
}> => {
const { provider, model: modelId } = config;
switch (provider) {
case 'amazon-bedrock': {
const aws = createAmazonBedrock({
baseURL: config.baseUrl,
region: config.region ?? env.AWS_REGION,
accessKeyId: config.accessKeyId
? await getTokenFromConfig(config.accessKeyId, orgId, prisma)
: env.AWS_ACCESS_KEY_ID,
secretAccessKey: config.accessKeySecret
? await getTokenFromConfig(config.accessKeySecret, orgId, prisma)
: env.AWS_SECRET_ACCESS_KEY,
});
return {
model: aws(modelId),
};
}
case 'anthropic': {
const anthropic = createAnthropic({
baseURL: config.baseUrl,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: env.ANTHROPIC_API_KEY,
});
return {
model: anthropic(modelId),
providerOptions: {
anthropic: {
thinking: {
type: "enabled",
budgetTokens: env.ANTHROPIC_THINKING_BUDGET_TOKENS,
}
} satisfies AnthropicProviderOptions,
},
headers: {
// @see: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
'anthropic-beta': 'interleaved-thinking-2025-05-14',
},
};
}
case 'azure': {
const azure = createAzure({
baseURL: config.baseUrl,
apiKey: config.token ? (await getTokenFromConfig(config.token, orgId, prisma)) : env.AZURE_API_KEY,
apiVersion: config.apiVersion,
resourceName: config.resourceName ?? env.AZURE_RESOURCE_NAME,
});
return {
model: azure(modelId),
};
}
case 'deepseek': {
const deepseek = createDeepSeek({
baseURL: config.baseUrl,
apiKey: config.token ? (await getTokenFromConfig(config.token, orgId, prisma)) : env.DEEPSEEK_API_KEY,
});
return {
model: deepseek(modelId),
};
}
case 'google-generative-ai': {
const google = createGoogleGenerativeAI({
baseURL: config.baseUrl,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: env.GOOGLE_GENERATIVE_AI_API_KEY,
});
return {
model: google(modelId),
};
}
case 'google-vertex': {
const vertex = createVertex({
project: config.project ?? env.GOOGLE_VERTEX_PROJECT,
location: config.region ?? env.GOOGLE_VERTEX_REGION,
...(config.credentials ? {
googleAuthOptions: {
keyFilename: await getTokenFromConfig(config.credentials, orgId, prisma),
}
} : {}),
});
return {
model: vertex(modelId),
providerOptions: {
google: {
thinkingConfig: {
thinkingBudget: env.GOOGLE_VERTEX_THINKING_BUDGET_TOKENS,
includeThoughts: env.GOOGLE_VERTEX_INCLUDE_THOUGHTS === 'true',
}
}
},
};
}
case 'google-vertex-anthropic': {
const vertexAnthropic = createVertexAnthropic({
project: config.project ?? env.GOOGLE_VERTEX_PROJECT,
location: config.region ?? env.GOOGLE_VERTEX_REGION,
...(config.credentials ? {
googleAuthOptions: {
keyFilename: await getTokenFromConfig(config.credentials, orgId, prisma),
}
} : {}),
});
return {
model: vertexAnthropic(modelId),
};
}
case 'mistral': {
const mistral = createMistral({
baseURL: config.baseUrl,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: env.MISTRAL_API_KEY,
});
return {
model: mistral(modelId),
};
}
case 'openai': {
const openai = createOpenAI({
baseURL: config.baseUrl,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: env.OPENAI_API_KEY,
});
return {
model: openai(modelId),
providerOptions: {
openai: {
reasoningEffort: config.reasoningEffort ?? 'medium',
} satisfies OpenAIResponsesProviderOptions,
},
};
}
case 'openai-compatible': {
const openai = createOpenAICompatible({
baseURL: config.baseUrl,
name: config.displayName ?? modelId,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: undefined,
});
return {
model: openai.chatModel(modelId),
}
}
case 'openrouter': {
const openrouter = createOpenRouter({
baseURL: config.baseUrl,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: env.OPENROUTER_API_KEY,
});
return {
model: openrouter(modelId),
};
}
case 'xai': {
const xai = createXai({
baseURL: config.baseUrl,
apiKey: config.token
? await getTokenFromConfig(config.token, orgId, prisma)
: env.XAI_API_KEY,
});
return {
model: xai(modelId),
};
}
}
}

@@ -23,6 +23,8 @@ import { ErrorBanner } from './errorBanner';
import { useRouter } from 'next/navigation';
import { usePrevious } from '@uidotdev/usehooks';
import { RepositoryQuery, SearchContextQuery } from '@/lib/types';
import { generateAndUpdateChatNameFromMessage } from '../../actions';
import { isServiceError } from '@/lib/utils';
type ChatHistoryState = {
scrollOffset?: number;
@@ -118,15 +120,58 @@ export const ChatThread = ({
selectedSearchScopes,
languageModelId: selectedLanguageModel.model,
} satisfies AdditionalChatRequestParams,
});
}, [_sendMessage, selectedLanguageModel, toast, selectedSearchScopes]);
});
if (
messages.length === 0 &&
message.parts.length > 0 &&
message.parts[0].type === 'text'
) {
generateAndUpdateChatNameFromMessage(
{
chatId,
languageModelId: selectedLanguageModel.model,
message: message.parts[0].text,
},
domain
).then((response) => {
if (isServiceError(response)) {
toast({
description: `❌ Failed to generate chat name. Reason: ${response.message}`,
variant: "destructive",
});
}
// Refresh the page to update the chat name.
router.refresh();
});
}
}, [
selectedLanguageModel,
_sendMessage,
selectedSearchScopes,
messages.length,
toast,
chatId,
domain,
router,
]);
const messagePairs = useMessagePairs(messages);
useNavigationGuard({
enabled: status === "streaming" || status === "submitted",
confirm: () => window.confirm("You have unsaved changes that will be lost.")
enabled: ({ type }) => {
// @note: a "refresh" in this context means we have triggered a client side
// refresh via `router.refresh()`, and not the user pressing "CMD+R"
// (that would be a "beforeunload" event). We can safely perform refreshes
// without losing any unsaved changes.
if (type === "refresh") {
return false;
}
return status === "streaming" || status === "submitted";
},
confirm: () => window.confirm("You have unsaved changes that will be lost."),
});
// When the chat is finished, refresh the page to update the chat history.

@@ -1,6 +1,6 @@
import { CreateUIMessage, TextUIPart, UIMessagePart } from "ai"
import { Descendant, Editor, Point, Range, Transforms } from "slate"
import { ANSWER_TAG, FILE_REFERENCE_PREFIX, FILE_REFERENCE_REGEX } from "./constants"
import { CreateUIMessage, TextUIPart, UIMessagePart } from "ai";
import { Descendant, Editor, Point, Range, Transforms } from "slate";
import { ANSWER_TAG, FILE_REFERENCE_PREFIX, FILE_REFERENCE_REGEX } from "./constants";
import {
CustomEditor,
CustomText,
@@ -14,7 +14,7 @@ import {
SBChatMessageToolTypes,
SearchScope,
Source,
} from "./types"
} from "./types";
export const insertMention = (editor: CustomEditor, data: MentionData, target?: Range | null) => {
const mention: MentionElement = {