diff --git a/CHANGELOG.md b/CHANGELOG.md index e5e45747..edf2bf00 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Removed prefix from structured log output. [#443](https://github.com/sourcebot-dev/sourcebot/pull/443) +- [ask sb] Fixed long generation times for first message in a chat thread. [#447](https://github.com/sourcebot-dev/sourcebot/pull/447) ### Changed - Bumped AI SDK and associated packages version. [#444](https://github.com/sourcebot-dev/sourcebot/pull/444) diff --git a/packages/web/src/app/api/(server)/chat/route.ts b/packages/web/src/app/api/(server)/chat/route.ts index ae789272..a14140a0 100644 --- a/packages/web/src/app/api/(server)/chat/route.ts +++ b/packages/web/src/app/api/(server)/chat/route.ts @@ -1,6 +1,5 @@ import { sew, withAuth, withOrgMembership } from "@/actions"; -import { env } from "@/env.mjs"; -import { _getConfiguredLanguageModelsFull, updateChatMessages, updateChatName } from "@/features/chat/actions"; +import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, updateChatMessages } from "@/features/chat/actions"; import { createAgentStream } from "@/features/chat/agent"; import { additionalChatRequestParamsSchema, SBChatMessage, SearchScope } from "@/features/chat/types"; import { getAnswerPartFromAssistantMessage } from "@/features/chat/utils"; @@ -8,33 +7,18 @@ import { ErrorCode } from "@/lib/errorCodes"; import { notFound, schemaValidationError, serviceErrorResponse } from "@/lib/serviceError"; import { isServiceError } from "@/lib/utils"; import { prisma } from "@/prisma"; -import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock'; -import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic'; -import { createAzure } from '@ai-sdk/azure'; -import { createDeepSeek } from '@ai-sdk/deepseek'; -import { createGoogleGenerativeAI } from '@ai-sdk/google'; -import { createVertex } from 
'@ai-sdk/google-vertex'; -import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic'; -import { createMistral } from '@ai-sdk/mistral'; -import { createOpenAI, OpenAIResponsesProviderOptions } from "@ai-sdk/openai"; -import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider"; -import { createXai } from '@ai-sdk/xai'; -import { createOpenRouter } from '@openrouter/ai-sdk-provider'; import * as Sentry from "@sentry/nextjs"; -import { getTokenFromConfig } from "@sourcebot/crypto"; import { OrgRole } from "@sourcebot/db"; import { createLogger } from "@sourcebot/logger"; -import { LanguageModel } from "@sourcebot/schemas/v3/index.type"; import { createUIMessageStream, createUIMessageStreamResponse, - generateText, JSONValue, ModelMessage, StreamTextResult, UIMessageStreamOptions, - UIMessageStreamWriter, + UIMessageStreamWriter } from "ai"; import { randomUUID } from "crypto"; import { StatusCodes } from "http-status-codes"; @@ -66,12 +50,60 @@ export async function POST(req: Request) { } const { messages, id, selectedSearchScopes, languageModelId } = parsed.data; - const response = await chatHandler({ - messages, - id, - selectedSearchScopes, - languageModelId, - }, domain); + + const response = await sew(() => + withAuth((userId) => + withOrgMembership(userId, domain, async ({ org }) => { + // Validate that the chat exists and is not readonly. + const chat = await prisma.chat.findUnique({ + where: { + orgId: org.id, + id, + }, + }); + + if (!chat) { + return notFound(); + } + + if (chat.isReadonly) { + return serviceErrorResponse({ + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: "Chat is readonly and cannot be edited.", + }); + } + + // From the language model ID, attempt to find the + // corresponding config in `config.json`. 
+ const languageModelConfig = + (await _getConfiguredLanguageModelsFull()) + .find((model) => model.model === languageModelId); + + if (!languageModelConfig) { + return serviceErrorResponse({ + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: `Language model ${languageModelId} is not configured.`, + }); + } + + const { model, providerOptions, headers } = await _getAISDKLanguageModelAndOptions(languageModelConfig, org.id); + + return createMessageStreamResponse({ + messages, + id, + selectedSearchScopes, + model, + modelName: languageModelConfig.displayName ?? languageModelConfig.model, + modelProviderOptions: providerOptions, + modelHeaders: headers, + domain, + orgId: org.id, + }); + }, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true + ) + ) if (isServiceError(response)) { return serviceErrorResponse(response); @@ -90,400 +122,146 @@ const mergeStreamAsync = async (stream: StreamTextResult, writer: UIMe }))); } -interface ChatHandlerProps { +interface CreateMessageStreamResponseProps { messages: SBChatMessage[]; id: string; selectedSearchScopes: SearchScope[]; - languageModelId: string; + model: AISDKLanguageModelV2; + modelName: string; + modelProviderOptions?: Record>; + modelHeaders?: Record; + domain: string; + orgId: number; } -const chatHandler = ({ messages, id, selectedSearchScopes, languageModelId }: ChatHandlerProps, domain: string) => sew(async () => - withAuth((userId) => - withOrgMembership(userId, domain, async ({ org }) => { - const chat = await prisma.chat.findUnique({ - where: { - orgId: org.id, - id, - }, +const createMessageStreamResponse = async ({ + messages, + id, + selectedSearchScopes, + model, + modelName, + modelProviderOptions, + modelHeaders, + domain, + orgId, +}: CreateMessageStreamResponseProps) => { + const latestMessage = messages[messages.length - 1]; + const sources = latestMessage.parts + .filter((part) => part.type === 'data-source') + .map((part) 
=> part.data); + + const traceId = randomUUID(); + + // Extract user messages and assistant answers. + // We will use this as the context we carry between messages. + const messageHistory = + messages.map((message): ModelMessage | undefined => { + if (message.role === 'user') { + return { + role: 'user', + content: message.parts[0].type === 'text' ? message.parts[0].text : '', + }; + } + + if (message.role === 'assistant') { + const answerPart = getAnswerPartFromAssistantMessage(message, false); + if (answerPart) { + return { + role: 'assistant', + content: [answerPart] + } + } + } + }).filter(message => message !== undefined); + + const stream = createUIMessageStream({ + execute: async ({ writer }) => { + writer.write({ + type: 'start', }); - if (!chat) { - return notFound(); - } + const startTime = new Date(); - if (chat.isReadonly) { - return serviceErrorResponse({ - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.INVALID_REQUEST_BODY, - message: "Chat is readonly and cannot be edited.", - }); - } + const expandedReposArrays = await Promise.all(selectedSearchScopes.map(async (scope) => { + if (scope.type === 'repo') { + return [scope.value]; + } - const latestMessage = messages[messages.length - 1]; - const sources = latestMessage.parts - .filter((part) => part.type === 'data-source') - .map((part) => part.data); - - // From the language model ID, attempt to find the - // corresponding config in `config.json`. 
- const languageModelConfig = - (await _getConfiguredLanguageModelsFull()) - .find((model) => model.model === languageModelId); - - if (!languageModelConfig) { - return serviceErrorResponse({ - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.INVALID_REQUEST_BODY, - message: `Language model ${languageModelId} is not configured.`, - }); - } - - const { model, providerOptions, headers } = await getAISDKLanguageModelAndOptions(languageModelConfig, org.id); - - if ( - messages.length === 1 && - messages[0].role === "user" && - messages[0].parts.length >= 1 && - messages[0].parts[0].type === 'text' - ) { - const content = messages[0].parts[0].text; - - const title = await generateChatTitle(content, model); - await updateChatName({ - chatId: id, - name: title, - }, domain); - } - - const traceId = randomUUID(); - - // Extract user messages and assistant answers. - // We will use this as the context we carry between messages. - const messageHistory = - messages.map((message): ModelMessage | undefined => { - if (message.role === 'user') { - return { - role: 'user', - content: message.parts[0].type === 'text' ? 
message.parts[0].text : '', - }; - } - - if (message.role === 'assistant') { - const answerPart = getAnswerPartFromAssistantMessage(message, false); - if (answerPart) { - return { - role: 'assistant', - content: [answerPart] - } - } - } - }).filter(message => message !== undefined); - - const stream = createUIMessageStream({ - execute: async ({ writer }) => { - writer.write({ - type: 'start', - }); - - const startTime = new Date(); - - const expandedReposArrays = await Promise.all(selectedSearchScopes.map(async (scope) => { - if (scope.type === 'repo') { - return [scope.value]; - } - - if (scope.type === 'reposet') { - const reposet = await prisma.searchContext.findFirst({ - where: { - orgId: org.id, - name: scope.value - }, - include: { - repos: true - } - }); - - if (reposet) { - return reposet.repos.map(repo => repo.name); - } - } - - return []; - })); - const expandedRepos = expandedReposArrays.flat(); - - const researchStream = await createAgentStream({ - model, - providerOptions, - headers, - inputMessages: messageHistory, - inputSources: sources, - searchScopeRepoNames: expandedRepos, - onWriteSource: (source) => { - writer.write({ - type: 'data-source', - data: source, - }); + if (scope.type === 'reposet') { + const reposet = await prisma.searchContext.findFirst({ + where: { + orgId, + name: scope.value }, - traceId, - }); - - await mergeStreamAsync(researchStream, writer, { - sendReasoning: true, - sendStart: false, - sendFinish: false, - }); - - const totalUsage = await researchStream.totalUsage; - - writer.write({ - type: 'message-metadata', - messageMetadata: { - totalTokens: totalUsage.totalTokens, - totalInputTokens: totalUsage.inputTokens, - totalOutputTokens: totalUsage.outputTokens, - totalResponseTimeMs: new Date().getTime() - startTime.getTime(), - modelName: languageModelConfig.displayName ?? 
languageModelConfig.model, - selectedSearchScopes, - traceId, + include: { + repos: true } - }) + }); + if (reposet) { + return reposet.repos.map(repo => repo.name); + } + } + return []; + })); + const expandedRepos = expandedReposArrays.flat(); + + const researchStream = await createAgentStream({ + model, + providerOptions: modelProviderOptions, + headers: modelHeaders, + inputMessages: messageHistory, + inputSources: sources, + searchScopeRepoNames: expandedRepos, + onWriteSource: (source) => { writer.write({ - type: 'finish', + type: 'data-source', + data: source, }); }, - onError: errorHandler, - originalMessages: messages, - onFinish: async ({ messages }) => { - await updateChatMessages({ - chatId: id, - messages - }, domain); - }, + traceId, }); - return createUIMessageStreamResponse({ - stream, + await mergeStreamAsync(researchStream, writer, { + sendReasoning: true, + sendStart: false, + sendFinish: false, }); - }, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true - )); -const generateChatTitle = async (message: string, model: AISDKLanguageModelV2) => { - const prompt = `Convert this question into a short topic title (max 50 characters). + const totalUsage = await researchStream.totalUsage; -Rules: -- Do NOT include question words (what, where, how, why, when, which) -- Do NOT end with a question mark -- Capitalize the first letter of the title -- Focus on the subject/topic being discussed -- Make it sound like a file name or category + writer.write({ + type: 'message-metadata', + messageMetadata: { + totalTokens: totalUsage.totalTokens, + totalInputTokens: totalUsage.inputTokens, + totalOutputTokens: totalUsage.outputTokens, + totalResponseTimeMs: new Date().getTime() - startTime.getTime(), + modelName, + selectedSearchScopes, + traceId, + } + }); -Examples: -"Where is the authentication code?" → "Authentication Code" -"How to setup the database?" → "Database Setup" -"What are the API endpoints?" 
→ "API Endpoints" - -User question: ${message}`; - - const result = await generateText({ - model, - prompt, + writer.write({ + type: 'finish', + }); + }, + onError: errorHandler, + originalMessages: messages, + onFinish: async ({ messages }) => { + await updateChatMessages({ + chatId: id, + messages + }, domain); + }, }); - return result.text; -} - -const getAISDKLanguageModelAndOptions = async (config: LanguageModel, orgId: number): Promise<{ - model: AISDKLanguageModelV2, - providerOptions?: Record>, - headers?: Record, -}> => { - - const { provider, model: modelId } = config; - - switch (provider) { - case 'amazon-bedrock': { - const aws = createAmazonBedrock({ - baseURL: config.baseUrl, - region: config.region ?? env.AWS_REGION, - accessKeyId: config.accessKeyId - ? await getTokenFromConfig(config.accessKeyId, orgId, prisma) - : env.AWS_ACCESS_KEY_ID, - secretAccessKey: config.accessKeySecret - ? await getTokenFromConfig(config.accessKeySecret, orgId, prisma) - : env.AWS_SECRET_ACCESS_KEY, - }); - - return { - model: aws(modelId), - }; - } - case 'anthropic': { - const anthropic = createAnthropic({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token, orgId, prisma) - : env.ANTHROPIC_API_KEY, - }); - - return { - model: anthropic(modelId), - providerOptions: { - anthropic: { - thinking: { - type: "enabled", - budgetTokens: env.ANTHROPIC_THINKING_BUDGET_TOKENS, - } - } satisfies AnthropicProviderOptions, - }, - headers: { - // @see: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking - 'anthropic-beta': 'interleaved-thinking-2025-05-14', - }, - }; - } - case 'azure': { - const azure = createAzure({ - baseURL: config.baseUrl, - apiKey: config.token ? (await getTokenFromConfig(config.token, orgId, prisma)) : env.AZURE_API_KEY, - apiVersion: config.apiVersion, - resourceName: config.resourceName ?? 
env.AZURE_RESOURCE_NAME, - }); - - return { - model: azure(modelId), - }; - } - case 'deepseek': { - const deepseek = createDeepSeek({ - baseURL: config.baseUrl, - apiKey: config.token ? (await getTokenFromConfig(config.token, orgId, prisma)) : env.DEEPSEEK_API_KEY, - }); - - return { - model: deepseek(modelId), - }; - } - case 'google-generative-ai': { - const google = createGoogleGenerativeAI({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token, orgId, prisma) - : env.GOOGLE_GENERATIVE_AI_API_KEY, - }); - - return { - model: google(modelId), - }; - } - case 'google-vertex': { - const vertex = createVertex({ - project: config.project ?? env.GOOGLE_VERTEX_PROJECT, - location: config.region ?? env.GOOGLE_VERTEX_REGION, - ...(config.credentials ? { - googleAuthOptions: { - keyFilename: await getTokenFromConfig(config.credentials, orgId, prisma), - } - } : {}), - }); - - return { - model: vertex(modelId), - providerOptions: { - google: { - thinkingConfig: { - thinkingBudget: env.GOOGLE_VERTEX_THINKING_BUDGET_TOKENS, - includeThoughts: env.GOOGLE_VERTEX_INCLUDE_THOUGHTS === 'true', - } - } - }, - }; - } - case 'google-vertex-anthropic': { - const vertexAnthropic = createVertexAnthropic({ - project: config.project ?? env.GOOGLE_VERTEX_PROJECT, - location: config.region ?? env.GOOGLE_VERTEX_REGION, - ...(config.credentials ? { - googleAuthOptions: { - keyFilename: await getTokenFromConfig(config.credentials, orgId, prisma), - } - } : {}), - }); - - return { - model: vertexAnthropic(modelId), - }; - } - case 'mistral': { - const mistral = createMistral({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token, orgId, prisma) - : env.MISTRAL_API_KEY, - }); - - return { - model: mistral(modelId), - }; - } - case 'openai': { - const openai = createOpenAI({ - baseURL: config.baseUrl, - apiKey: config.token - ? 
await getTokenFromConfig(config.token, orgId, prisma) - : env.OPENAI_API_KEY, - }); - - return { - model: openai(modelId), - providerOptions: { - openai: { - reasoningEffort: config.reasoningEffort ?? 'medium' - } satisfies OpenAIResponsesProviderOptions, - }, - }; - } - case 'openai-compatible': { - const openai = createOpenAICompatible({ - baseURL: config.baseUrl, - name: config.displayName ?? modelId, - apiKey: config.token - ? await getTokenFromConfig(config.token, orgId, prisma) - : undefined, - }); - - return { - model: openai.chatModel(modelId), - } - } - case 'openrouter': { - const openrouter = createOpenRouter({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token, orgId, prisma) - : env.OPENROUTER_API_KEY, - }); - - return { - model: openrouter(modelId), - }; - } - case 'xai': { - const xai = createXai({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token, orgId, prisma) - : env.XAI_API_KEY, - }); - - return { - model: xai(modelId), - }; - } - } -} + return createUIMessageStreamResponse({ + stream, + }); +}; const errorHandler = (error: unknown) => { logger.error(error); diff --git a/packages/web/src/features/chat/actions.ts b/packages/web/src/features/chat/actions.ts index 17a1335c..ac73ae6e 100644 --- a/packages/web/src/features/chat/actions.ts +++ b/packages/web/src/features/chat/actions.ts @@ -2,17 +2,32 @@ import { sew, withAuth, withOrgMembership } from "@/actions"; import { env } from "@/env.mjs"; -import { chatIsReadonly, notFound, ServiceError } from "@/lib/serviceError"; +import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants"; +import { ErrorCode } from "@/lib/errorCodes"; +import { chatIsReadonly, notFound, ServiceError, serviceErrorResponse } from "@/lib/serviceError"; import { prisma } from "@/prisma"; +import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock'; +import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic'; +import { 
createAzure } from '@ai-sdk/azure'; +import { createDeepSeek } from '@ai-sdk/deepseek'; +import { createGoogleGenerativeAI } from '@ai-sdk/google'; +import { createVertex } from '@ai-sdk/google-vertex'; +import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic'; +import { createMistral } from '@ai-sdk/mistral'; +import { createOpenAI, OpenAIResponsesProviderOptions } from "@ai-sdk/openai"; +import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; +import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider"; +import { createXai } from '@ai-sdk/xai'; +import { createOpenRouter } from '@openrouter/ai-sdk-provider'; +import { getTokenFromConfig } from "@sourcebot/crypto"; import { ChatVisibility, OrgRole, Prisma } from "@sourcebot/db"; +import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type"; +import { loadConfig } from "@sourcebot/shared"; +import { generateText, JSONValue } from "ai"; import fs from 'fs'; +import { StatusCodes } from "http-status-codes"; import path from 'path'; import { LanguageModelInfo, SBChatMessage } from "./types"; -import { loadConfig } from "@sourcebot/shared"; -import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type"; -import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants"; -import { StatusCodes } from "http-status-codes"; -import { ErrorCode } from "@/lib/errorCodes"; export const createChat = async (domain: string) => sew(() => withAuth((userId) => @@ -170,6 +185,58 @@ export const updateChatName = async ({ chatId, name }: { chatId: string, name: s }, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) ); +export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }, domain: string) => sew(() => + withAuth((userId) => + withOrgMembership(userId, domain, async ({ org }) => { + // From the language model ID, attempt to find the + // 
corresponding config in `config.json`. + const languageModelConfig = + (await _getConfiguredLanguageModelsFull()) + .find((model) => model.model === languageModelId); + + if (!languageModelConfig) { + return serviceErrorResponse({ + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: `Language model ${languageModelId} is not configured.`, + }); + } + + const { model } = await _getAISDKLanguageModelAndOptions(languageModelConfig, org.id); + + const prompt = `Convert this question into a short topic title (max 50 characters). + +Rules: +- Do NOT include question words (what, where, how, why, when, which) +- Do NOT end with a question mark +- Capitalize the first letter of the title +- Focus on the subject/topic being discussed +- Make it sound like a file name or category + +Examples: +"Where is the authentication code?" → "Authentication Code" +"How to setup the database?" → "Database Setup" +"What are the API endpoints?" → "API Endpoints" + +User question: ${message}`; + + const result = await generateText({ + model, + prompt, + }); + + await updateChatName({ + chatId, + name: result.text, + }, domain); + + return { + success: true, + } + }) + ) +); + export const deleteChat = async ({ chatId }: { chatId: string }, domain: string) => sew(() => withAuth((userId) => withOrgMembership(userId, domain, async ({ org }) => { @@ -303,3 +370,193 @@ export const _getConfiguredLanguageModelsFull = async (): Promise>, + headers?: Record, +}> => { + const { provider, model: modelId } = config; + + switch (provider) { + case 'amazon-bedrock': { + const aws = createAmazonBedrock({ + baseURL: config.baseUrl, + region: config.region ?? env.AWS_REGION, + accessKeyId: config.accessKeyId + ? await getTokenFromConfig(config.accessKeyId, orgId, prisma) + : env.AWS_ACCESS_KEY_ID, + secretAccessKey: config.accessKeySecret + ? 
await getTokenFromConfig(config.accessKeySecret, orgId, prisma) + : env.AWS_SECRET_ACCESS_KEY, + }); + + return { + model: aws(modelId), + }; + } + case 'anthropic': { + const anthropic = createAnthropic({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token, orgId, prisma) + : env.ANTHROPIC_API_KEY, + }); + + return { + model: anthropic(modelId), + providerOptions: { + anthropic: { + thinking: { + type: "enabled", + budgetTokens: env.ANTHROPIC_THINKING_BUDGET_TOKENS, + } + } satisfies AnthropicProviderOptions, + }, + headers: { + // @see: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking + 'anthropic-beta': 'interleaved-thinking-2025-05-14', + }, + }; + } + case 'azure': { + const azure = createAzure({ + baseURL: config.baseUrl, + apiKey: config.token ? (await getTokenFromConfig(config.token, orgId, prisma)) : env.AZURE_API_KEY, + apiVersion: config.apiVersion, + resourceName: config.resourceName ?? env.AZURE_RESOURCE_NAME, + }); + + return { + model: azure(modelId), + }; + } + case 'deepseek': { + const deepseek = createDeepSeek({ + baseURL: config.baseUrl, + apiKey: config.token ? (await getTokenFromConfig(config.token, orgId, prisma)) : env.DEEPSEEK_API_KEY, + }); + + return { + model: deepseek(modelId), + }; + } + case 'google-generative-ai': { + const google = createGoogleGenerativeAI({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token, orgId, prisma) + : env.GOOGLE_GENERATIVE_AI_API_KEY, + }); + + return { + model: google(modelId), + }; + } + case 'google-vertex': { + const vertex = createVertex({ + project: config.project ?? env.GOOGLE_VERTEX_PROJECT, + location: config.region ?? env.GOOGLE_VERTEX_REGION, + ...(config.credentials ? 
{ + googleAuthOptions: { + keyFilename: await getTokenFromConfig(config.credentials, orgId, prisma), + } + } : {}), + }); + + return { + model: vertex(modelId), + providerOptions: { + google: { + thinkingConfig: { + thinkingBudget: env.GOOGLE_VERTEX_THINKING_BUDGET_TOKENS, + includeThoughts: env.GOOGLE_VERTEX_INCLUDE_THOUGHTS === 'true', + } + } + }, + }; + } + case 'google-vertex-anthropic': { + const vertexAnthropic = createVertexAnthropic({ + project: config.project ?? env.GOOGLE_VERTEX_PROJECT, + location: config.region ?? env.GOOGLE_VERTEX_REGION, + ...(config.credentials ? { + googleAuthOptions: { + keyFilename: await getTokenFromConfig(config.credentials, orgId, prisma), + } + } : {}), + }); + + return { + model: vertexAnthropic(modelId), + }; + } + case 'mistral': { + const mistral = createMistral({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token, orgId, prisma) + : env.MISTRAL_API_KEY, + }); + + return { + model: mistral(modelId), + }; + } + case 'openai': { + const openai = createOpenAI({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token, orgId, prisma) + : env.OPENAI_API_KEY, + }); + + return { + model: openai(modelId), + providerOptions: { + openai: { + reasoningEffort: config.reasoningEffort ?? 'medium', + } satisfies OpenAIResponsesProviderOptions, + }, + }; + } + case 'openai-compatible': { + const openai = createOpenAICompatible({ + baseURL: config.baseUrl, + name: config.displayName ?? modelId, + apiKey: config.token + ? await getTokenFromConfig(config.token, orgId, prisma) + : undefined, + }); + + return { + model: openai.chatModel(modelId), + } + } + case 'openrouter': { + const openrouter = createOpenRouter({ + baseURL: config.baseUrl, + apiKey: config.token + ? 
await getTokenFromConfig(config.token, orgId, prisma) + : env.OPENROUTER_API_KEY, + }); + + return { + model: openrouter(modelId), + }; + } + case 'xai': { + const xai = createXai({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token, orgId, prisma) + : env.XAI_API_KEY, + }); + + return { + model: xai(modelId), + }; + } + } +} \ No newline at end of file diff --git a/packages/web/src/features/chat/components/chatThread/chatThread.tsx b/packages/web/src/features/chat/components/chatThread/chatThread.tsx index 62e066ac..0c58ae67 100644 --- a/packages/web/src/features/chat/components/chatThread/chatThread.tsx +++ b/packages/web/src/features/chat/components/chatThread/chatThread.tsx @@ -23,6 +23,8 @@ import { ErrorBanner } from './errorBanner'; import { useRouter } from 'next/navigation'; import { usePrevious } from '@uidotdev/usehooks'; import { RepositoryQuery, SearchContextQuery } from '@/lib/types'; +import { generateAndUpdateChatNameFromMessage } from '../../actions'; +import { isServiceError } from '@/lib/utils'; type ChatHistoryState = { scrollOffset?: number; @@ -118,15 +120,58 @@ export const ChatThread = ({ selectedSearchScopes, languageModelId: selectedLanguageModel.model, } satisfies AdditionalChatRequestParams, - }); - }, [_sendMessage, selectedLanguageModel, toast, selectedSearchScopes]); + }); + + if ( + messages.length === 0 && + message.parts.length > 0 && + message.parts[0].type === 'text' + ) { + generateAndUpdateChatNameFromMessage( + { + chatId, + languageModelId: selectedLanguageModel.model, + message: message.parts[0].text, + }, + domain + ).then((response) => { + if (isServiceError(response)) { + toast({ + description: `❌ Failed to generate chat name. Reason: ${response.message}`, + variant: "destructive", + }); + } + // Refresh the page to update the chat name. 
+ router.refresh(); + }); + } + }, [ + selectedLanguageModel, + _sendMessage, + selectedSearchScopes, + messages.length, + toast, + chatId, + domain, + router, + ]); const messagePairs = useMessagePairs(messages); useNavigationGuard({ - enabled: status === "streaming" || status === "submitted", - confirm: () => window.confirm("You have unsaved changes that will be lost.") + enabled: ({ type }) => { + // @note: a "refresh" in this context means we have triggered a client side + // refresh via `router.refresh()`, and not the user pressing "CMD+R" + // (that would be a "beforeunload" event). We can safely perform refreshes + // without losing any unsaved changes. + if (type === "refresh") { + return false; + } + + return status === "streaming" || status === "submitted"; + }, + confirm: () => window.confirm("You have unsaved changes that will be lost."), }); // When the chat is finished, refresh the page to update the chat history. diff --git a/packages/web/src/features/chat/utils.ts b/packages/web/src/features/chat/utils.ts index 6207cb0e..c57f3fd5 100644 --- a/packages/web/src/features/chat/utils.ts +++ b/packages/web/src/features/chat/utils.ts @@ -1,6 +1,6 @@ -import { CreateUIMessage, TextUIPart, UIMessagePart } from "ai" -import { Descendant, Editor, Point, Range, Transforms } from "slate" -import { ANSWER_TAG, FILE_REFERENCE_PREFIX, FILE_REFERENCE_REGEX } from "./constants" +import { CreateUIMessage, TextUIPart, UIMessagePart } from "ai"; +import { Descendant, Editor, Point, Range, Transforms } from "slate"; +import { ANSWER_TAG, FILE_REFERENCE_PREFIX, FILE_REFERENCE_REGEX } from "./constants"; import { CustomEditor, CustomText, @@ -14,7 +14,7 @@ import { SBChatMessageToolTypes, SearchScope, Source, -} from "./types" +} from "./types"; export const insertMention = (editor: CustomEditor, data: MentionData, target?: Range | null) => { const mention: MentionElement = {