diff --git a/packages/web/src/app/[domain]/browse/[...path]/components/pureCodePreviewPanel.tsx b/packages/web/src/app/[domain]/browse/[...path]/components/pureCodePreviewPanel.tsx index 49f08be8..bdf6f878 100644 --- a/packages/web/src/app/[domain]/browse/[...path]/components/pureCodePreviewPanel.tsx +++ b/packages/web/src/app/[domain]/browse/[...path]/components/pureCodePreviewPanel.tsx @@ -19,7 +19,6 @@ import { useBrowseState } from "../../hooks/useBrowseState"; import { rangeHighlightingExtension } from "./rangeHighlightingExtension"; import useCaptureEvent from "@/hooks/useCaptureEvent"; import { createAuditAction } from "@/ee/features/audit/actions"; -import { useDomain } from "@/hooks/useDomain"; interface PureCodePreviewPanelProps { path: string; @@ -43,7 +42,6 @@ export const PureCodePreviewPanel = ({ const hasCodeNavEntitlement = useHasEntitlement("code-nav"); const { updateBrowseState } = useBrowseState(); const { navigateToPath } = useBrowseNavigation(); - const domain = useDomain(); const captureEvent = useCaptureEvent(); const highlightRangeQuery = useNonEmptyQueryParam(HIGHLIGHT_RANGE_QUERY_PARAM); @@ -145,7 +143,7 @@ export const PureCodePreviewPanel = ({ metadata: { message: symbolName, }, - }, domain) + }) updateBrowseState({ selectedSymbolInfo: { @@ -157,7 +155,7 @@ export const PureCodePreviewPanel = ({ isBottomPanelCollapsed: false, activeExploreMenuTab: "references", }) - }, [captureEvent, updateBrowseState, repoName, revisionName, language, domain]); + }, [captureEvent, updateBrowseState, repoName, revisionName, language]); // If we resolve multiple matches, instead of navigating to the first match, we should @@ -171,7 +169,7 @@ export const PureCodePreviewPanel = ({ metadata: { message: symbolName, }, - }, domain) + }) if (symbolDefinitions.length === 0) { return; @@ -200,7 +198,7 @@ export const PureCodePreviewPanel = ({ isBottomPanelCollapsed: false, }) } - }, [captureEvent, navigateToPath, revisionName, updateBrowseState, repoName, language, domain]); + }, [captureEvent, navigateToPath, revisionName, updateBrowseState, repoName, language]); const theme = useCodeMirrorTheme(); diff --git a/packages/web/src/app/[domain]/chat/[id]/page.tsx b/packages/web/src/app/[domain]/chat/[id]/page.tsx index 24cd94e9..c22cbd0c 100644 --- a/packages/web/src/app/[domain]/chat/[id]/page.tsx +++ b/packages/web/src/app/[domain]/chat/[id]/page.tsx @@ -24,9 +24,9 @@ export default async function Page(props: PageProps) { const languageModels = await getConfiguredLanguageModelsInfo(); const repos = await getRepos(); const searchContexts = await getSearchContexts(params.domain); - const chatInfo = await getChatInfo({ chatId: params.id }, params.domain); + const chatInfo = await getChatInfo({ chatId: params.id }); const session = await auth(); - const chatHistory = session ? await getUserChatHistory(params.domain) : []; + const chatHistory = session ? await getUserChatHistory() : []; if (isServiceError(chatHistory)) { throw new ServiceErrorException(chatHistory); diff --git a/packages/web/src/app/[domain]/chat/components/chatName.tsx b/packages/web/src/app/[domain]/chat/components/chatName.tsx index b305d3d7..147fbcf5 100644 --- a/packages/web/src/app/[domain]/chat/components/chatName.tsx +++ b/packages/web/src/app/[domain]/chat/components/chatName.tsx @@ -4,7 +4,6 @@ import { useToast } from "@/components/hooks/use-toast"; import { Badge } from "@/components/ui/badge"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { updateChatName } from "@/features/chat/actions"; -import { useDomain } from "@/hooks/useDomain"; import { isServiceError } from "@/lib/utils"; import { GlobeIcon } from "@radix-ui/react-icons"; import { ChatVisibility } from "@sourcebot/db"; @@ -23,7 +22,6 @@ interface ChatNameProps { export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) => { const [isRenameDialogOpen, setIsRenameDialogOpen] = useState(false); const { toast } = useToast(); - const domain = useDomain(); const router = useRouter(); const onRenameChat = useCallback(async (name: string) => { @@ -31,7 +29,7 @@ export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) => const response = await updateChatName({ chatId: id, name: name, - }, domain); + }); if (isServiceError(response)) { toast({ @@ -43,7 +41,7 @@ export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) => }); router.refresh(); } - }, [id, domain, toast, router]); + }, [id, toast, router]); return ( <> diff --git a/packages/web/src/app/[domain]/chat/components/chatSidePanel.tsx b/packages/web/src/app/[domain]/chat/components/chatSidePanel.tsx index 16e955f8..ec2fce51 100644 --- a/packages/web/src/app/[domain]/chat/components/chatSidePanel.tsx +++ b/packages/web/src/app/[domain]/chat/components/chatSidePanel.tsx @@ -9,7 +9,6 @@ import { ScrollArea } from "@/components/ui/scroll-area"; import { Separator } from "@/components/ui/separator"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { deleteChat, updateChatName } from "@/features/chat/actions"; -import { useDomain } from "@/hooks/useDomain"; import { cn, isServiceError } from "@/lib/utils"; import { CirclePlusIcon, EllipsisIcon, PencilIcon, TrashIcon } from "lucide-react"; import { useRouter } from "next/navigation"; @@ -23,6 +22,7 @@ import { useChatId } from "../useChatId"; import { RenameChatDialog } from "./renameChatDialog"; import { DeleteChatDialog } from "./deleteChatDialog"; import Link from "next/link"; +import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants"; interface ChatSidePanelProps { order: number; @@ -41,7 +41,6 @@ export const ChatSidePanel = ({ isAuthenticated, isCollapsedInitially, }: ChatSidePanelProps) => { - const domain = useDomain(); const [isCollapsed, setIsCollapsed] = useState(isCollapsedInitially); const sidePanelRef = useRef(null); const router = useRouter(); @@ -72,7 +71,7 @@ export const ChatSidePanel = ({ const response = await updateChatName({ chatId, name: name, - }, domain); + }); if (isServiceError(response)) { toast({ @@ -84,14 +83,14 @@ export const ChatSidePanel = ({ }); router.refresh(); } - }, [router, toast, domain]); + }, [router, toast]); const onDeleteChat = useCallback(async (chatIdToDelete: string) => { if (!chatIdToDelete) { return; } - const response = await deleteChat({ chatId: chatIdToDelete }, domain); + const response = await deleteChat({ chatId: chatIdToDelete }); if (isServiceError(response)) { toast({ @@ -104,12 +103,12 @@ export const ChatSidePanel = ({ // If we just deleted the current chat, navigate to new chat if (chatIdToDelete === chatId) { - router.push(`/${domain}/chat`); + router.push(`/${SINGLE_TENANT_ORG_DOMAIN}/chat`); } router.refresh(); } - }, [chatId, router, toast, domain]); + }, [chatId, router, toast]); return ( <> @@ -131,7 +130,7 @@ export const ChatSidePanel = ({ size="sm" className="w-full" onClick={() => { - router.push(`/${domain}/chat`); + router.push(`/${SINGLE_TENANT_ORG_DOMAIN}/chat`); }} > @@ -145,7 +144,7 @@ export const ChatSidePanel = ({

Sign in @@ -163,7 +162,7 @@ export const ChatSidePanel = ({ chat.id === chatId && "bg-muted" )} onClick={() => { - router.push(`/${domain}/chat/${chat.id}`); + router.push(`/${SINGLE_TENANT_ORG_DOMAIN}/chat/${chat.id}`); }} > {chat.name ?? 'Untitled chat'} diff --git a/packages/web/src/app/[domain]/components/searchBar/searchBar.tsx b/packages/web/src/app/[domain]/components/searchBar/searchBar.tsx index 9ec5f664..dda7ab2a 100644 --- a/packages/web/src/app/[domain]/components/searchBar/searchBar.tsx +++ b/packages/web/src/app/[domain]/components/searchBar/searchBar.tsx @@ -221,7 +221,7 @@ export const SearchBar = ({ metadata: { message: query, }, - }, domain) + }) const url = createPathWithQueryParams(`/${domain}/search`, [SearchQueryParams.query, query], diff --git a/packages/web/src/app/[domain]/search/components/codePreviewPanel/codePreview.tsx b/packages/web/src/app/[domain]/search/components/codePreviewPanel/codePreview.tsx index e62917ce..d500bbe9 100644 --- a/packages/web/src/app/[domain]/search/components/codePreviewPanel/codePreview.tsx +++ b/packages/web/src/app/[domain]/search/components/codePreviewPanel/codePreview.tsx @@ -22,8 +22,6 @@ import { symbolHoverTargetsExtension } from "@/ee/features/codeNav/components/sy import { useHasEntitlement } from "@/features/entitlements/useHasEntitlement"; import { SymbolDefinition } from "@/ee/features/codeNav/components/symbolHoverPopup/useHoveredOverSymbolInfo"; import { createAuditAction } from "@/ee/features/audit/actions"; -import { useDomain } from "@/hooks/useDomain"; - import useCaptureEvent from "@/hooks/useCaptureEvent"; export interface CodePreviewFile { @@ -53,7 +51,6 @@ export const CodePreview = ({ const [editorRef, setEditorRef] = useState(null); const { navigateToPath } = useBrowseNavigation(); const hasCodeNavEntitlement = useHasEntitlement("code-nav"); - const domain = useDomain(); const [gutterWidth, setGutterWidth] = useState(0); const theme = useCodeMirrorTheme(); @@ -127,7 +124,7 @@ export const CodePreview = ({ metadata: { message: symbolName, }, - }, domain) + }) if (symbolDefinitions.length === 0) { return; @@ -162,7 +159,7 @@ export const CodePreview = ({ } }); } - }, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName, domain]); + }, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName]); const onFindReferences = useCallback((symbolName: string) => { captureEvent('wa_find_references_pressed', { @@ -173,7 +170,7 @@ export const CodePreview = ({ metadata: { message: symbolName, }, - }, domain) + }) navigateToPath({ repoName, @@ -191,7 +188,7 @@ export const CodePreview = ({ isBottomPanelCollapsed: false, } }) - }, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName, domain]); + }, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName]); return (

diff --git a/packages/web/src/app/api/(server)/chat/route.ts b/packages/web/src/app/api/(server)/chat/route.ts index 1abc1a1c..e0fc2b8b 100644 --- a/packages/web/src/app/api/(server)/chat/route.ts +++ b/packages/web/src/app/api/(server)/chat/route.ts @@ -1,4 +1,4 @@ -import { sew, withAuth, withOrgMembership } from "@/actions"; +import { sew } from "@/actions"; import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, updateChatMessages } from "@/features/chat/actions"; import { createAgentStream } from "@/features/chat/agent"; import { additionalChatRequestParamsSchema, LanguageModelInfo, SBChatMessage, SearchScope } from "@/features/chat/types"; @@ -6,10 +6,10 @@ import { getAnswerPartFromAssistantMessage, getLanguageModelKey } from "@/featur import { ErrorCode } from "@/lib/errorCodes"; import { notFound, schemaValidationError, serviceErrorResponse } from "@/lib/serviceError"; import { isServiceError } from "@/lib/utils"; -import { prisma } from "@/prisma"; +import { withOptionalAuthV2 } from "@/withAuthV2"; import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider"; import * as Sentry from "@sentry/nextjs"; -import { OrgRole } from "@sourcebot/db"; +import { PrismaClient } from "@sourcebot/db"; import { createLogger } from "@sourcebot/shared"; import { createUIMessageStream, @@ -34,15 +34,6 @@ const chatRequestSchema = z.object({ }) export async function POST(req: Request) { - const domain = req.headers.get("X-Org-Domain"); - if (!domain) { - return serviceErrorResponse({ - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.MISSING_ORG_DOMAIN_HEADER, - message: "Missing X-Org-Domain header", - }); - } - const requestBody = await req.json(); const parsed = await chatRequestSchema.safeParseAsync(requestBody); if (!parsed.success) { @@ -56,56 +47,54 @@ export async function POST(req: Request) { const languageModel = _languageModel as LanguageModelInfo; const response = await sew(() => - withAuth((userId) => - withOrgMembership(userId, domain, async ({ org }) => { - // Validate that the chat exists and is not readonly. - const chat = await prisma.chat.findUnique({ - where: { - orgId: org.id, - id, - }, - }); - - if (!chat) { - return notFound(); - } - - if (chat.isReadonly) { - return serviceErrorResponse({ - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.INVALID_REQUEST_BODY, - message: "Chat is readonly and cannot be edited.", - }); - } - - // From the language model ID, attempt to find the - // corresponding config in `config.json`. - const languageModelConfig = - (await _getConfiguredLanguageModelsFull()) - .find((model) => getLanguageModelKey(model) === getLanguageModelKey(languageModel)); - - if (!languageModelConfig) { - return serviceErrorResponse({ - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.INVALID_REQUEST_BODY, - message: `Language model ${languageModel.model} is not configured.`, - }); - } - - const { model, providerOptions } = await _getAISDKLanguageModelAndOptions(languageModelConfig); - - return createMessageStreamResponse({ - messages, - id, - selectedSearchScopes, - model, - modelName: languageModelConfig.displayName ?? languageModelConfig.model, - modelProviderOptions: providerOptions, - domain, + withOptionalAuthV2(async ({ org, prisma }) => { + // Validate that the chat exists and is not readonly. + const chat = await prisma.chat.findUnique({ + where: { orgId: org.id, + id, + }, + }); + + if (!chat) { + return notFound(); + } + + if (chat.isReadonly) { + return serviceErrorResponse({ + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: "Chat is readonly and cannot be edited.", }); - }, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true - ) + } + + // From the language model ID, attempt to find the + // corresponding config in `config.json`. + const languageModelConfig = + (await _getConfiguredLanguageModelsFull()) + .find((model) => getLanguageModelKey(model) === getLanguageModelKey(languageModel)); + + if (!languageModelConfig) { + return serviceErrorResponse({ + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: `Language model ${languageModel.model} is not configured.`, + }); + } + + const { model, providerOptions } = await _getAISDKLanguageModelAndOptions(languageModelConfig); + + return createMessageStreamResponse({ + messages, + id, + selectedSearchScopes, + model, + modelName: languageModelConfig.displayName ?? languageModelConfig.model, + modelProviderOptions: providerOptions, + orgId: org.id, + prisma, + }); + }) ) if (isServiceError(response)) { @@ -132,8 +121,8 @@ interface CreateMessageStreamResponseProps { model: AISDKLanguageModelV2; modelName: string; modelProviderOptions?: Record>; - domain: string; orgId: number; + prisma: PrismaClient; } const createMessageStreamResponse = async ({ @@ -143,8 +132,8 @@ const createMessageStreamResponse = async ({ model, modelName, modelProviderOptions, - domain, orgId, + prisma, }: CreateMessageStreamResponseProps) => { const latestMessage = messages[messages.length - 1]; const sources = latestMessage.parts @@ -254,7 +243,7 @@ const createMessageStreamResponse = async ({ await updateChatMessages({ chatId: id, messages - }, domain); + }); }, }); diff --git a/packages/web/src/app/api/(server)/ee/audit/route.ts b/packages/web/src/app/api/(server)/ee/audit/route.ts index c1b8c86c..84d89f26 100644 --- a/packages/web/src/app/api/(server)/ee/audit/route.ts +++ b/packages/web/src/app/api/(server)/ee/audit/route.ts @@ -1,46 +1,25 @@ 'use server'; -import { NextRequest } from "next/server"; import { fetchAuditRecords } from "@/ee/features/audit/actions"; -import { isServiceError } from "@/lib/utils"; -import { serviceErrorResponse } from "@/lib/serviceError"; -import { StatusCodes } from "http-status-codes"; import { ErrorCode } from "@/lib/errorCodes"; -import { env } from "@sourcebot/shared"; +import { serviceErrorResponse } from "@/lib/serviceError"; +import { isServiceError } from "@/lib/utils"; import { getEntitlements } from "@sourcebot/shared"; +import { StatusCodes } from "http-status-codes"; -export const GET = async (request: NextRequest) => { - const domain = request.headers.get("X-Org-Domain"); - const apiKey = request.headers.get("X-Sourcebot-Api-Key") ?? undefined; +export const GET = async () => { + const entitlements = getEntitlements(); + if (!entitlements.includes('audit')) { + return serviceErrorResponse({ + statusCode: StatusCodes.FORBIDDEN, + errorCode: ErrorCode.NOT_FOUND, + message: "Audit logging is not enabled for your license", + }); + } - if (!domain) { - return serviceErrorResponse({ - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.MISSING_ORG_DOMAIN_HEADER, - message: "Missing X-Org-Domain header", - }); - } - - if (env.SOURCEBOT_EE_AUDIT_LOGGING_ENABLED === 'false') { - return serviceErrorResponse({ - statusCode: StatusCodes.NOT_FOUND, - errorCode: ErrorCode.NOT_FOUND, - message: "Audit logging is not enabled", - }); - } - - const entitlements = getEntitlements(); - if (!entitlements.includes('audit')) { - return serviceErrorResponse({ - statusCode: StatusCodes.FORBIDDEN, - errorCode: ErrorCode.NOT_FOUND, - message: "Audit logging is not enabled for your license", - }); - } - - const result = await fetchAuditRecords(domain, apiKey); - if (isServiceError(result)) { - return serviceErrorResponse(result); - } - return Response.json(result); + const result = await fetchAuditRecords(); + if (isServiceError(result)) { + return serviceErrorResponse(result); + } + return Response.json(result); }; \ No newline at end of file diff --git a/packages/web/src/ee/features/audit/actions.ts b/packages/web/src/ee/features/audit/actions.ts index 60814946..519f6555 100644 --- a/packages/web/src/ee/features/audit/actions.ts +++ b/packages/web/src/ee/features/audit/actions.ts @@ -1,59 +1,62 @@ "use server"; -import { prisma } from "@/prisma"; -import { ErrorCode } from "@/lib/errorCodes"; -import { StatusCodes } from "http-status-codes"; -import { sew, withAuth, withOrgMembership } from "@/actions"; -import { OrgRole } from "@sourcebot/db"; -import { createLogger } from "@sourcebot/shared"; -import { ServiceError } from "@/lib/serviceError"; +import { sew } from "@/actions"; import { getAuditService } from "@/ee/features/audit/factory"; +import { ErrorCode } from "@/lib/errorCodes"; +import { ServiceError } from "@/lib/serviceError"; +import { prisma } from "@/prisma"; +import { withAuthV2 } from "@/withAuthV2"; +import { createLogger } from "@sourcebot/shared"; +import { StatusCodes } from "http-status-codes"; import { AuditEvent } from "./types"; const auditService = getAuditService(); const logger = createLogger('audit-utils'); -export const createAuditAction = async (event: Omit, domain: string) => sew(async () => - withAuth((userId) => - withOrgMembership(userId, domain, async ({ org }) => { - await auditService.createAudit({ ...event, orgId: org.id, actor: { id: userId, type: "user" }, target: { id: org.id.toString(), type: "org" } }) - }, /* minRequiredRole = */ OrgRole.MEMBER), /* allowAnonymousAccess = */ true) -); - -export const fetchAuditRecords = async (domain: string, apiKey: string | undefined = undefined) => sew(() => - withAuth((userId) => - withOrgMembership(userId, domain, async ({ org }) => { - try { - const auditRecords = await prisma.audit.findMany({ - where: { - orgId: org.id, - }, - orderBy: { - timestamp: 'desc' - } - }); - +export const createAuditAction = async (event: Omit) => sew(async () => + withAuthV2(async ({ user, org }) => { await auditService.createAudit({ - action: "audit.fetch", - actor: { - id: userId, - type: "user" - }, - target: { - id: org.id.toString(), - type: "org" - }, - orgId: org.id + ...event, + orgId: org.id, + actor: { id: user.id, type: "user" }, + target: { id: org.id.toString(), type: "org" }, }) - - return auditRecords; - } catch (error) { - logger.error('Error fetching audit logs', { error }); - return { - statusCode: StatusCodes.INTERNAL_SERVER_ERROR, - errorCode: ErrorCode.UNEXPECTED_ERROR, - message: "Failed to fetch audit logs", - } satisfies ServiceError; - } - }, /* minRequiredRole = */ OrgRole.OWNER), /* allowAnonymousAccess = */ true, apiKey ? { apiKey, domain } : undefined) + }) +); + +export const fetchAuditRecords = async () => sew(() => + withAuthV2(async ({ user, org }) => { + try { + const auditRecords = await prisma.audit.findMany({ + where: { + orgId: org.id, + }, + orderBy: { + timestamp: 'desc' + } + }); + + await auditService.createAudit({ + action: "audit.fetch", + actor: { + id: user.id, + type: "user" + }, + target: { + id: org.id.toString(), + type: "org" + }, + orgId: org.id + }) + + return auditRecords; + } catch (error) { + logger.error('Error fetching audit logs', { error }); + return { + statusCode: StatusCodes.INTERNAL_SERVER_ERROR, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: "Failed to fetch audit logs", + } satisfies ServiceError; + } + }) ); diff --git a/packages/web/src/features/chat/actions.ts b/packages/web/src/features/chat/actions.ts index 86f86301..a987dbbc 100644 --- a/packages/web/src/features/chat/actions.ts +++ b/packages/web/src/features/chat/actions.ts @@ -1,10 +1,9 @@ 'use server'; -import { sew, withAuth, withOrgMembership } from "@/actions"; +import { sew } from "@/actions"; import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants"; import { ErrorCode } from "@/lib/errorCodes"; import { chatIsReadonly, notFound, ServiceError, serviceErrorResponse } from "@/lib/serviceError"; -import { prisma } from "@/prisma"; import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock'; import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic'; import { createAzure } from '@ai-sdk/azure'; @@ -20,7 +19,7 @@ import { createXai } from '@ai-sdk/xai'; import { fromNodeProviderChain } from '@aws-sdk/credential-providers'; import { createOpenRouter } from '@openrouter/ai-sdk-provider'; import { getTokenFromConfig, createLogger, env } from "@sourcebot/shared"; -import { ChatVisibility, OrgRole, Prisma } from "@sourcebot/db"; +import { ChatVisibility, Prisma } from "@sourcebot/db"; import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type"; import { Token } from "@sourcebot/schemas/v3/shared.type"; import { generateText, JSONValue, extractReasoningMiddleware, wrapLanguageModel } from "ai"; @@ -29,168 +28,161 @@ import fs from 'fs'; import { StatusCodes } from "http-status-codes"; import path from 'path'; import { LanguageModelInfo, SBChatMessage } from "./types"; +import { withAuthV2, withOptionalAuthV2 } from "@/withAuthV2"; const logger = createLogger('chat-actions'); -export const createChat = async (domain: string) => sew(() => - withAuth((userId) => - withOrgMembership(userId, domain, async ({ org }) => { +export const createChat = async () => sew(() => + withOptionalAuthV2(async ({ org, user, prisma }) => { + const isGuestUser = user?.id === SOURCEBOT_GUEST_USER_ID; - const isGuestUser = userId === SOURCEBOT_GUEST_USER_ID; + const chat = await prisma.chat.create({ + data: { + orgId: org.id, + messages: [] as unknown as Prisma.InputJsonValue, + createdById: user?.id ?? SOURCEBOT_GUEST_USER_ID, + visibility: isGuestUser ? ChatVisibility.PUBLIC : ChatVisibility.PRIVATE, + }, + }); - const chat = await prisma.chat.create({ - data: { - orgId: org.id, - messages: [] as unknown as Prisma.InputJsonValue, - createdById: userId, - visibility: isGuestUser ? ChatVisibility.PUBLIC : ChatVisibility.PRIVATE, - }, - }); - - return { - id: chat.id, - } - }, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) + return { + id: chat.id, + } + }) ); -export const getChatInfo = async ({ chatId }: { chatId: string }, domain: string) => sew(() => - withAuth((userId) => - withOrgMembership(userId, domain, async ({ org }) => { - const chat = await prisma.chat.findUnique({ - where: { - id: chatId, - orgId: org.id, - }, - }); +export const getChatInfo = async ({ chatId }: { chatId: string }) => sew(() => + withOptionalAuthV2(async ({ org, user, prisma }) => { + const chat = await prisma.chat.findUnique({ + where: { + id: chatId, + orgId: org.id, + }, + }); - if (!chat) { - return notFound(); - } + if (!chat) { + return notFound(); + } - if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { - return notFound(); - } + if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) { + return notFound(); + } - return { - messages: chat.messages as unknown as SBChatMessage[], - visibility: chat.visibility, - name: chat.name, - isReadonly: chat.isReadonly, - }; - }, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) + return { + messages: chat.messages as unknown as SBChatMessage[], + visibility: chat.visibility, + name: chat.name, + isReadonly: chat.isReadonly, + }; + }) ); -export const updateChatMessages = async ({ chatId, messages }: { chatId: string, messages: SBChatMessage[] }, domain: string) => sew(() => - withAuth((userId) => - withOrgMembership(userId, domain, async ({ org }) => { - const chat = await prisma.chat.findUnique({ - where: { - id: chatId, - orgId: org.id, - }, - }); +export const updateChatMessages = async ({ chatId, messages }: { chatId: string, messages: SBChatMessage[] }) => sew(() => + withOptionalAuthV2(async ({ org, user, prisma }) => { + const chat = await prisma.chat.findUnique({ + where: { + id: chatId, + orgId: org.id, + }, + }); - if (!chat) { - return notFound(); + if (!chat) { + return notFound(); + } + + if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) { + return notFound(); + } + + if (chat.isReadonly) { + return chatIsReadonly(); + } + + await prisma.chat.update({ + where: { + id: chatId, + }, + data: { + messages: messages as unknown as Prisma.InputJsonValue, + }, + }); + + if (env.DEBUG_WRITE_CHAT_MESSAGES_TO_FILE) { + const chatDir = path.join(env.DATA_CACHE_DIR, 'chats'); + if (!fs.existsSync(chatDir)) { + fs.mkdirSync(chatDir, { recursive: true }); } - if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { - return notFound(); - } + const chatFile = path.join(chatDir, `${chatId}.json`); + fs.writeFileSync(chatFile, JSON.stringify(messages, null, 2)); + } - if (chat.isReadonly) { - return chatIsReadonly(); - } - - await prisma.chat.update({ - where: { - id: chatId, - }, - data: { - messages: messages as unknown as Prisma.InputJsonValue, - }, - }); - - if (env.DEBUG_WRITE_CHAT_MESSAGES_TO_FILE) { - const chatDir = path.join(env.DATA_CACHE_DIR, 'chats'); - if (!fs.existsSync(chatDir)) { - fs.mkdirSync(chatDir, { recursive: true }); - } - - const chatFile = path.join(chatDir, `${chatId}.json`); - fs.writeFileSync(chatFile, JSON.stringify(messages, null, 2)); - } - - return { - success: true, - } - }, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) + return { + success: true, + } + }) ); -export const getUserChatHistory = async (domain: string) => sew(() => - withAuth((userId) => - withOrgMembership(userId, domain, async ({ org }) => { - const chats = await prisma.chat.findMany({ - where: { - orgId: org.id, - createdById: userId, - }, - orderBy: { - updatedAt: 'desc', - }, - }); +export const getUserChatHistory = async () => sew(() => + withAuthV2(async ({ org, user, prisma }) => { + const chats = await prisma.chat.findMany({ + where: { + orgId: org.id, + createdById: user.id, + }, + orderBy: { + updatedAt: 'desc', + }, + }); - return chats.map((chat) => ({ - id: chat.id, - createdAt: chat.createdAt, - name: chat.name, - visibility: chat.visibility, - })) - }) - ) + return chats.map((chat) => ({ + id: chat.id, + createdAt: chat.createdAt, + name: chat.name, + visibility: chat.visibility, + })) + }) ); -export const updateChatName = async ({ chatId, name }: { chatId: string, name: string }, domain: string) => sew(() => - withAuth((userId) => - withOrgMembership(userId, domain, async ({ org }) => { - const chat = await prisma.chat.findUnique({ - where: { - id: chatId, - orgId: org.id, - }, - }); +export const updateChatName = async ({ chatId, name }: { chatId: string, name: string }) => sew(() => + withOptionalAuthV2(async ({ org, user, prisma }) => { + const chat = await prisma.chat.findUnique({ + where: { + id: chatId, + orgId: org.id, + }, + }); - if (!chat) { - return notFound(); - } + if (!chat) { + return notFound(); + } - if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { - return notFound(); - } + if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) { + return notFound(); + } - if (chat.isReadonly) { - return chatIsReadonly(); - } + if (chat.isReadonly) { + return chatIsReadonly(); + } - await prisma.chat.update({ - where: { - id: chatId, - orgId: org.id, - }, - data: { - name, - }, - }); + await prisma.chat.update({ + where: { + id: chatId, + orgId: org.id, + }, + data: { + name, + }, + }); - return { - success: true, - } - }, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) + return { + success: true, + } + }) ); -export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }, domain: string) => sew(() => - withAuth((userId) => - withOrgMembership(userId, domain, async () => { +export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }) => sew(() => + withOptionalAuthV2(async () => { // From the language model ID, attempt to find the // corresponding config in `config.json`. const languageModelConfig = @@ -231,48 +223,6 @@ User question: ${message}`; await updateChatName({ chatId, name: result.text, - }, domain); - - return { - success: true, - } - }, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true - ) -); - -export const deleteChat = async ({ chatId }: { chatId: string }, domain: string) => sew(() => - withAuth((userId) => - withOrgMembership(userId, domain, async ({ org }) => { - const chat = await prisma.chat.findUnique({ - where: { - id: chatId, - orgId: org.id, - }, - }); - - if (!chat) { - return notFound(); - } - - // Public chats cannot be deleted. - if (chat.visibility === ChatVisibility.PUBLIC) { - return { - statusCode: StatusCodes.FORBIDDEN, - errorCode: ErrorCode.UNEXPECTED_ERROR, - message: 'You are not allowed to delete this chat.', - } satisfies ServiceError; - } - - // Only the creator of a chat can delete it. - if (chat.createdById !== userId) { - return notFound(); - } - - await prisma.chat.delete({ - where: { - id: chatId, - orgId: org.id, - }, }); return { @@ -280,6 +230,45 @@ export const deleteChat = async ({ chatId }: { chatId: string }, domain: string) } }) ) + +export const deleteChat = async ({ chatId }: { chatId: string }) => sew(() => + withAuthV2(async ({ org, user, prisma }) => { + const chat = await prisma.chat.findUnique({ + where: { + id: chatId, + orgId: org.id, + }, + }); + + if (!chat) { + return notFound(); + } + + // Public chats cannot be deleted. + if (chat.visibility === ChatVisibility.PUBLIC) { + return { + statusCode: StatusCodes.FORBIDDEN, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: 'You are not allowed to delete this chat.', + } satisfies ServiceError; + } + + // Only the creator of a chat can delete it. + if (chat.createdById !== user.id) { + return notFound(); + } + + await prisma.chat.delete({ + where: { + id: chatId, + orgId: org.id, + }, + }); + + return { + success: true, + } + }) ); export const submitFeedback = async ({ @@ -290,56 +279,55 @@ export const submitFeedback = async ({ chatId: string, messageId: string, feedbackType: 'like' | 'dislike' -}, domain: string) => sew(() => - withAuth((userId) => - withOrgMembership(userId, domain, async ({ org }) => { - const chat = await prisma.chat.findUnique({ - where: { - id: chatId, - orgId: org.id, - }, - }); +}) => sew(() => + withOptionalAuthV2(async ({ org, user, prisma }) => { + const chat = await prisma.chat.findUnique({ + where: { + id: chatId, + orgId: org.id, + }, + }); - if (!chat) { - return notFound(); + if (!chat) { + return notFound(); + } + + // When a chat is private, only the creator can submit feedback. + if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) { + return notFound(); + } + + const messages = chat.messages as unknown as SBChatMessage[]; + const updatedMessages = messages.map(message => { + if (message.id === messageId && message.role === 'assistant') { + return { + ...message, + metadata: { + ...message.metadata, + feedback: [ + ...(message.metadata?.feedback ?? []), + { + type: feedbackType, + timestamp: new Date().toISOString(), + userId: user?.id, + } + ] + } + } satisfies SBChatMessage; } + return message; + }); - // When a chat is private, only the creator can submit feedback. - if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { - return notFound(); - } + await prisma.chat.update({ + where: { id: chatId }, + data: { + messages: updatedMessages as unknown as Prisma.InputJsonValue, + }, + }); - const messages = chat.messages as unknown as SBChatMessage[]; - const updatedMessages = messages.map(message => { - if (message.id === messageId && message.role === 'assistant') { - return { - ...message, - metadata: { - ...message.metadata, - feedback: [ - ...(message.metadata?.feedback ?? []), - { - type: feedbackType, - timestamp: new Date().toISOString(), - userId: userId, - } - ] - } - } satisfies SBChatMessage; - } - return message; - }); - - await prisma.chat.update({ - where: { id: chatId }, - data: { - messages: updatedMessages as unknown as Prisma.InputJsonValue, - }, - }); - - return { success: true }; - }, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) -); + return { success: true }; + }) +) /** * Returns the subset of information about the configured language models diff --git a/packages/web/src/features/chat/components/chatThread/answerCard.tsx b/packages/web/src/features/chat/components/chatThread/answerCard.tsx index e3f89f62..5feebd32 100644 --- a/packages/web/src/features/chat/components/chatThread/answerCard.tsx +++ b/packages/web/src/features/chat/components/chatThread/answerCard.tsx @@ -6,7 +6,7 @@ import { Button } from "@/components/ui/button"; import { TableOfContentsIcon, ThumbsDown, ThumbsUp } from "lucide-react"; import { Separator } from "@/components/ui/separator"; import { MarkdownRenderer } from "./markdownRenderer"; -import { forwardRef, useCallback, useImperativeHandle, useRef, useState } from "react"; +import { forwardRef, memo, useCallback, useImperativeHandle, useRef, useState } from "react"; import { Toggle } from "@/components/ui/toggle"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { CopyIconButton } from "@/app/[domain]/components/copyIconButton"; @@ -14,7 +14,6 @@ import { useToast } from "@/components/hooks/use-toast"; import { convertLLMOutputToPortableMarkdown } from "../../utils"; import { submitFeedback } from "../../actions"; import { isServiceError } from "@/lib/utils"; -import { useDomain } from "@/hooks/useDomain"; import useCaptureEvent from "@/hooks/useCaptureEvent"; import { LangfuseWeb } from "langfuse"; import { env } from "@sourcebot/shared/client"; @@ -31,7 +30,7 @@ const langfuseWeb = (env.NEXT_PUBLIC_SOURCEBOT_CLOUD_ENVIRONMENT !== undefined & baseUrl: env.NEXT_PUBLIC_LANGFUSE_BASE_URL, }) : null; -export const AnswerCard = forwardRef(({ +const AnswerCardComponent = forwardRef(({ answerText, messageId, chatId, @@ -41,7 +40,6 @@ export const AnswerCard = forwardRef(({ const { tocItems, activeId } = useExtractTOCItems({ target: markdownRendererRef.current }); const [isTOCButtonToggled, setIsTOCButtonToggled] = useState(false); const { toast } = useToast(); - const domain = useDomain(); const [isSubmittingFeedback, setIsSubmittingFeedback] = useState(false); const [feedback, setFeedback] = useState<'like' | 'dislike' | undefined>(undefined); const captureEvent = useCaptureEvent(); @@ -67,7 +65,7 @@ export const AnswerCard = forwardRef(({ chatId, messageId, feedbackType - }, domain); + }); if (isServiceError(response)) { toast({ @@ -93,7 +91,7 @@ export const AnswerCard = forwardRef(({ } setIsSubmittingFeedback(false); - }, [chatId, messageId, domain, toast, captureEvent, traceId]); + }, [chatId, messageId, toast, captureEvent, traceId]); return (
@@ -178,4 +176,6 @@ export const AnswerCard = forwardRef(({ ) }) -AnswerCard.displayName = 'AnswerCard'; \ No newline at end of file +AnswerCardComponent.displayName = 'AnswerCard'; + +export const AnswerCard = memo(AnswerCardComponent); \ No newline at end of file diff --git a/packages/web/src/features/chat/components/chatThread/chatThread.tsx b/packages/web/src/features/chat/components/chatThread/chatThread.tsx index e7a49ba9..05ea70e0 100644 --- a/packages/web/src/features/chat/components/chatThread/chatThread.tsx +++ b/packages/web/src/features/chat/components/chatThread/chatThread.tsx @@ -7,7 +7,6 @@ import { Separator } from '@/components/ui/separator'; import { CustomSlateEditor } from '@/features/chat/customSlateEditor'; import { AdditionalChatRequestParams, CustomEditor, LanguageModelInfo, SBChatMessage, SearchScope, Source } from '@/features/chat/types'; import { createUIMessage, getAllMentionElements, resetEditor, slateContentToString } from '@/features/chat/utils'; -import { useDomain } from '@/hooks/useDomain'; import { useChat } from '@ai-sdk/react'; import { CreateUIMessage, DefaultChatTransport } from 'ai'; import { ArrowDownIcon } from 'lucide-react'; @@ -54,7 +53,6 @@ export const ChatThread = ({ onSelectedSearchScopesChange, isChatReadonly, }: ChatThreadProps) => { - const domain = useDomain(); const [isErrorBannerVisible, setIsErrorBannerVisible] = useState(false); const scrollAreaRef = useRef(null); const latestMessagePairRef = useRef(null); @@ -89,9 +87,6 @@ export const ChatThread = ({ messages: initialMessages, transport: new DefaultChatTransport({ api: '/api/chat', - headers: { - "X-Org-Domain": domain, - } }), onData: (dataPart) => { // Keeps sources added by the assistant in sync. @@ -134,7 +129,6 @@ export const ChatThread = ({ languageModelId: selectedLanguageModel.model, message: message.parts[0].text, }, - domain ).then((response) => { if (isServiceError(response)) { toast({ @@ -153,7 +147,6 @@ export const ChatThread = ({ messages.length, toast, chatId, - domain, router, ]); @@ -196,46 +189,47 @@ export const ChatThread = ({ hasSubmittedInputMessage.current = true; }, [inputMessage, sendMessage]); + // @todo: this need to be optimized to avoid excessive re-renders // Track scroll position changes. - useEffect(() => { - const scrollElement = scrollAreaRef.current?.querySelector('[data-radix-scroll-area-viewport]') as HTMLElement; - if (!scrollElement) return; + // useEffect(() => { + // const scrollElement = scrollAreaRef.current?.querySelector('[data-radix-scroll-area-viewport]') as HTMLElement; + // if (!scrollElement) return; - let timeout: NodeJS.Timeout | null = null; + // let timeout: NodeJS.Timeout | null = null; - const handleScroll = () => { - const scrollOffset = scrollElement.scrollTop; + // const handleScroll = () => { + // const scrollOffset = scrollElement.scrollTop; - const threshold = 50; // pixels from bottom to consider "at bottom" - const { scrollHeight, clientHeight } = scrollElement; - const isAtBottom = scrollHeight - scrollOffset - clientHeight <= threshold; - setIsAutoScrollEnabled(isAtBottom); + // const threshold = 50; // pixels from bottom to consider "at bottom" + // const { scrollHeight, clientHeight } = scrollElement; + // const isAtBottom = scrollHeight - scrollOffset - clientHeight <= threshold; + // setIsAutoScrollEnabled(isAtBottom); - // Debounce the history state update - if (timeout) { - clearTimeout(timeout); - } + // // Debounce the history state update + // if (timeout) { + // clearTimeout(timeout); + // } - timeout = setTimeout(() => { - history.replaceState( - { - scrollOffset, - } satisfies ChatHistoryState, - '', - window.location.href - ); - }, 300); - }; + // timeout = setTimeout(() => { + // history.replaceState( + // { + // scrollOffset, + // } satisfies ChatHistoryState, + // '', + // window.location.href + // ); + // }, 300); + // }; - scrollElement.addEventListener('scroll', handleScroll, { passive: true }); + // scrollElement.addEventListener('scroll', handleScroll, { passive: true }); - return () => { - scrollElement.removeEventListener('scroll', handleScroll); - if (timeout) { - clearTimeout(timeout); - } - }; - }, []); + // return () => { + // scrollElement.removeEventListener('scroll', handleScroll); + // if (timeout) { + // clearTimeout(timeout); + // } + // }; + // }, []); useEffect(() => { const scrollElement = scrollAreaRef.current?.querySelector('[data-radix-scroll-area-viewport]') as HTMLElement; @@ -313,9 +307,11 @@ export const ChatThread = ({ {messagePairs.map(([userMessage, assistantMessage], index) => { const isLastPair = index === messagePairs.length - 1; const isStreaming = isLastPair && (status === "streaming" || status === "submitted"); + // Use a stable key based on user message ID + const key = userMessage.id; return ( - + (({ +export const ChatThreadListItemComponent = forwardRef(({ userMessage, assistantMessage: _assistantMessage, isStreaming, @@ -32,6 +32,7 @@ export const ChatThreadListItem = forwardRef { + console.log(`re-rendering chat thread list item`, index); const leftPanelRef = useRef(null); const [leftPanelHeight, setLeftPanelHeight] = useState(null); const answerRef = useRef(null); @@ -393,7 +394,12 @@ export const ChatThreadListItem = forwardRef !nextProps.isStreaming); // Finds the nearest reference element to the viewport center. const getNearestReferenceElement = (referenceElements: Element[]) => { diff --git a/packages/web/src/features/chat/components/chatThread/detailsCard.tsx b/packages/web/src/features/chat/components/chatThread/detailsCard.tsx index 0fe18a64..5cd6f335 100644 --- a/packages/web/src/features/chat/components/chatThread/detailsCard.tsx +++ b/packages/web/src/features/chat/components/chatThread/detailsCard.tsx @@ -7,6 +7,7 @@ import { Skeleton } from '@/components/ui/skeleton'; import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip'; import { cn } from '@/lib/utils'; import { Brain, ChevronDown, ChevronRight, Clock, InfoIcon, Loader2, List, ScanSearchIcon, Zap } from 'lucide-react'; +import { memo } from 'react'; import { MarkdownRenderer } from './markdownRenderer'; import { FindSymbolDefinitionsToolComponent } from './tools/findSymbolDefinitionsToolComponent'; import { FindSymbolReferencesToolComponent } from './tools/findSymbolReferencesToolComponent'; @@ -27,7 +28,7 @@ interface DetailsCardProps { metadata?: SBChatMessageMetadata; } -export const DetailsCard = ({ +const DetailsCardComponent = ({ isExpanded, onExpandedChanged, isThinking, @@ -209,4 +210,6 @@ export const DetailsCard = ({ ) -} \ No newline at end of file +} + +export const DetailsCard = memo(DetailsCardComponent); \ No newline at end of file diff --git a/packages/web/src/features/chat/components/chatThread/markdownRenderer.tsx b/packages/web/src/features/chat/components/chatThread/markdownRenderer.tsx index a4992aa0..b69145f0 100644 --- a/packages/web/src/features/chat/components/chatThread/markdownRenderer.tsx +++ b/packages/web/src/features/chat/components/chatThread/markdownRenderer.tsx @@ -1,7 +1,6 @@ 'use client'; import { CodeSnippet } from '@/app/components/codeSnippet'; -import { useDomain } from '@/hooks/useDomain'; import { SearchQueryParams } from '@/lib/types'; import { cn, createPathWithQueryParams } from '@/lib/utils'; import type { Element, Root } from "hast"; @@ -10,7 +9,7 @@ import { CopyIcon, SearchIcon } from 'lucide-react'; import type { Heading, Nodes } from "mdast"; import { findAndReplace } from 'mdast-util-find-and-replace'; import { useRouter } from 'next/navigation'; -import React, { useCallback, useMemo, forwardRef } from 'react'; +import React, { useCallback, useMemo, forwardRef, memo } from 'react'; import Markdown from 'react-markdown'; import rehypeRaw from 'rehype-raw'; import rehypeSanitize, { defaultSchema } from 'rehype-sanitize'; @@ -20,6 +19,7 @@ import { visit } from 'unist-util-visit'; import { CodeBlock } from './codeBlock'; import { FILE_REFERENCE_REGEX } from '@/features/chat/constants'; import { createFileReference } from '@/features/chat/utils'; +import { SINGLE_TENANT_ORG_DOMAIN } from '@/lib/constants'; export const REFERENCE_PAYLOAD_ATTRIBUTE = 'data-reference-payload'; @@ -102,8 +102,7 @@ interface MarkdownRendererProps { className?: string; } -export const MarkdownRenderer = forwardRef(({ content, className }, ref) => { - const domain = useDomain(); +const MarkdownRendererComponent = forwardRef(({ content, className }, ref) => { const router = useRouter(); const remarkPlugins = useMemo((): PluggableList => { @@ -176,7 +175,7 @@ export const MarkdownRenderer = forwardRef { e.preventDefault(); e.stopPropagation(); - const url = createPathWithQueryParams(`/${domain}/search`, [SearchQueryParams.query, `"${text}"`]) + const url = createPathWithQueryParams(`/${SINGLE_TENANT_ORG_DOMAIN}/search`, [SearchQueryParams.query, `"${text}"`]) router.push(url); }} title="Search for snippet" @@ -199,7 +198,7 @@ export const MarkdownRenderer = forwardRef ) - }, [domain, router]); + }, [router]); return (
(null); const captureEvent = useCaptureEvent(); - const domain = useDomain(); useImperativeHandle( forwardedRef, @@ -124,6 +122,8 @@ const ReferencedFileSourceListItem = ({ return createCodeFoldingExtension(references, 3); }, [references]); + // console.log(`re-renderign for file ${fileName}`); + const extensions = useMemo(() => { return [ languageExtension, @@ -231,7 +231,7 @@ const ReferencedFileSourceListItem = ({ metadata: { message: symbolName, }, - }, domain); + }); if (symbolDefinitions.length === 1) { const symbolDefinition = symbolDefinitions[0]; @@ -263,7 +263,7 @@ const ReferencedFileSourceListItem = ({ }); } - }, [captureEvent, domain, navigateToPath, revision, repoName, fileName, language]); + }, [captureEvent, navigateToPath, revision, repoName, fileName, language]); const onFindReferences = useCallback((symbolName: string) => { captureEvent('wa_find_references_pressed', { @@ -274,7 +274,7 @@ const ReferencedFileSourceListItem = ({ metadata: { message: symbolName, }, - }, domain); + }); navigateToPath({ repoName, @@ -293,7 +293,7 @@ const ReferencedFileSourceListItem = ({ } }) - }, [captureEvent, domain, fileName, language, navigateToPath, repoName, revision]); + }, [captureEvent, fileName, language, navigateToPath, repoName, revision]); const ExpandCollapseIcon = useMemo(() => { return isExpanded ? ChevronDown : ChevronRight; @@ -355,6 +355,6 @@ const ReferencedFileSourceListItem = ({ ) } -export default forwardRef(ReferencedFileSourceListItem) as ( +export default memo(forwardRef(ReferencedFileSourceListItem)) as ( props: ReferencedFileSourceListItemProps & { ref?: Ref }, ) => ReturnType; diff --git a/packages/web/src/features/chat/components/chatThread/referencedSourcesListView.tsx b/packages/web/src/features/chat/components/chatThread/referencedSourcesListView.tsx index b24085d8..7792bb3a 100644 --- a/packages/web/src/features/chat/components/chatThread/referencedSourcesListView.tsx +++ b/packages/web/src/features/chat/components/chatThread/referencedSourcesListView.tsx @@ -4,11 +4,10 @@ import { getFileSource } from "@/app/api/(client)/client"; import { VscodeFileIcon } from "@/app/components/vscodeFileIcon"; import { ScrollArea } from "@/components/ui/scroll-area"; import { Skeleton } from "@/components/ui/skeleton"; -import { useDomain } from "@/hooks/useDomain"; import { isServiceError, unwrapServiceError } from "@/lib/utils"; import { useQueries } from "@tanstack/react-query"; import { ReactCodeMirrorRef } from '@uiw/react-codemirror'; -import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { memo, useCallback, useEffect, useMemo, useRef, useState } from "react"; import scrollIntoView from 'scroll-into-view-if-needed'; import { FileReference, FileSource, Reference, Source } from "../../types"; import ReferencedFileSourceListItem from "./referencedFileSourceListItem"; @@ -31,7 +30,7 @@ const resolveFileReference = (reference: FileReference, sources: FileSource[]): ); } -export const ReferencedSourcesListView = ({ +const ReferencedSourcesListViewComponent = ({ references, sources, index, @@ -43,7 +42,6 @@ export const ReferencedSourcesListView = ({ }: ReferencedSourcesListViewProps) => { const scrollAreaRef = useRef(null); const editorRefsMap = useRef>(new Map()); - const domain = useDomain(); const [collapsedFileIds, setCollapsedFileIds] = useState([]); const getFileId = useCallback((fileSource: FileSource) => { @@ -98,7 +96,7 @@ export const ReferencedSourcesListView = ({ const fileSourceQueries = useQueries({ queries: referencedFileSources.map((file) => ({ - queryKey: ['fileSource', file.path, file.repo, file.revision, domain], + queryKey: ['fileSource', file.path, file.repo, file.revision], queryFn: () => unwrapServiceError(getFileSource({ fileName: file.path, repository: file.repo, @@ -183,6 +181,25 @@ export const ReferencedSourcesListView = ({ } }, [getFileId, referencedFileSources, selectedReference]); + const onExpandedChanged = useCallback((fileId: string, isExpanded: boolean) => { + if (isExpanded) { + setCollapsedFileIds(collapsedFileIds => collapsedFileIds.filter((id) => id !== fileId)); + } else { + setCollapsedFileIds(collapsedFileIds => [...collapsedFileIds, fileId]); + } + + if (!isExpanded) { + const fileSourceStart = document.getElementById(`${fileId}-start`); + if (fileSourceStart) { + scrollIntoView(fileSourceStart, { + scrollMode: 'if-needed', + block: 'start', + behavior: 'instant', + }); + } + } + }, []); + if (referencedFileSources.length === 0) { return (
@@ -253,30 +270,7 @@ export const ReferencedSourcesListView = ({ selectedReference={selectedReference} hoveredReference={hoveredReference} isExpanded={!collapsedFileIds.includes(fileId)} - onExpandedChanged={(isExpanded) => { - if (isExpanded) { - setCollapsedFileIds(collapsedFileIds.filter((id) => id !== fileId)); - } else { - setCollapsedFileIds([...collapsedFileIds, fileId]); - } - - // When collapsing a file when you are deep in a scroll, it's a better - // experience to have the scroll automatically restored to the top of the file - // s.t., header is still sticky to the top of the scroll area. - if (!isExpanded) { - const fileSourceStart = document.getElementById(`${fileId}-start`); - if (!fileSourceStart) { - return; - } - - scrollIntoView(fileSourceStart, { - scrollMode: 'if-needed', - block: 'start', - behavior: 'instant', - }); - } - } - } + onExpandedChanged={(isExpanded) => onExpandedChanged(fileId, isExpanded)} /> ); })} @@ -284,3 +278,6 @@ export const ReferencedSourcesListView = ({ ); } + +// Memoize to prevent unnecessary re-renders +export const ReferencedSourcesListView = memo(ReferencedSourcesListViewComponent); diff --git a/packages/web/src/features/chat/components/chatThread/tools/searchCodeToolComponent.tsx b/packages/web/src/features/chat/components/chatThread/tools/searchCodeToolComponent.tsx index 9131d8cc..53287b88 100644 --- a/packages/web/src/features/chat/components/chatThread/tools/searchCodeToolComponent.tsx +++ b/packages/web/src/features/chat/components/chatThread/tools/searchCodeToolComponent.tsx @@ -1,7 +1,6 @@ 'use client'; import { SearchCodeToolUIPart } from "@/features/chat/tools"; -import { useDomain } from "@/hooks/useDomain"; import { createPathWithQueryParams, isServiceError } from "@/lib/utils"; import { useMemo, useState } from "react"; import { FileListItem, ToolHeader, TreeList } from "./shared"; @@ -12,10 +11,10 @@ import Link from "next/link"; import { SearchQueryParams } from "@/lib/types"; import { PlayIcon } from "@radix-ui/react-icons"; import { buildSearchQuery } from "@/features/chat/utils"; +import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants"; export const SearchCodeToolComponent = ({ part }: { part: SearchCodeToolUIPart }) => { const [isExpanded, setIsExpanded] = useState(false); - const domain = useDomain(); const displayQuery = useMemo(() => { if (part.state !== 'input-available' && part.state !== 'output-available') { @@ -78,7 +77,7 @@ export const SearchCodeToolComponent = ({ part }: { part: SearchCodeToolUIPart } )} { - const domain = useDomain(); - return (
@@ -28,7 +26,7 @@ export const FileListItem = ({ repoName, revisionName: 'HEAD', path, - domain, + domain: SINGLE_TENANT_ORG_DOMAIN, pathType: 'blob', })} > diff --git a/packages/web/src/features/chat/types.ts b/packages/web/src/features/chat/types.ts index 8d543b1a..71ee959a 100644 --- a/packages/web/src/features/chat/types.ts +++ b/packages/web/src/features/chat/types.ts @@ -70,7 +70,7 @@ export const sbChatMessageMetadataSchema = z.object({ feedback: z.array(z.object({ type: z.enum(['like', 'dislike']), timestamp: z.string(), // ISO date string - userId: z.string(), + userId: z.string().optional(), })).optional(), selectedSearchScopes: z.array(searchScopeSchema).optional(), traceId: z.string().optional(), diff --git a/packages/web/src/features/chat/useCreateNewChatThread.ts b/packages/web/src/features/chat/useCreateNewChatThread.ts index 25c67b22..b6db9342 100644 --- a/packages/web/src/features/chat/useCreateNewChatThread.ts +++ b/packages/web/src/features/chat/useCreateNewChatThread.ts @@ -4,7 +4,6 @@ import { useCallback, useState } from "react"; import { Descendant } from "slate"; import { createUIMessage, getAllMentionElements } from "./utils"; import { slateContentToString } from "./utils"; -import { useDomain } from "@/hooks/useDomain"; import { useToast } from "@/components/hooks/use-toast"; import { useRouter } from "next/navigation"; import { createChat } from "./actions"; @@ -13,9 +12,9 @@ import { createPathWithQueryParams } from "@/lib/utils"; import { SearchScope, SET_CHAT_STATE_SESSION_STORAGE_KEY, SetChatStatePayload } from "./types"; import { useSessionStorage } from "usehooks-ts"; import useCaptureEvent from "@/hooks/useCaptureEvent"; +import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants"; export const useCreateNewChatThread = () => { - const domain = useDomain(); const [isLoading, setIsLoading] = useState(false); const { toast } = useToast(); const router = useRouter(); @@ -31,7 +30,7 @@ export const useCreateNewChatThread = () => { const inputMessage = createUIMessage(text, mentions.map((mention) => mention.data), selectedSearchScopes); setIsLoading(true); - const response = await createChat(domain); + const response = await createChat(); if (isServiceError(response)) { toast({ description: `❌ Failed to create chat. Reason: ${response.message}` @@ -47,11 +46,11 @@ export const useCreateNewChatThread = () => { selectedSearchScopes, }); - const url = createPathWithQueryParams(`/${domain}/chat/${response.id}`); + const url = createPathWithQueryParams(`/${SINGLE_TENANT_ORG_DOMAIN}/chat/${response.id}`); router.push(url); router.refresh(); - }, [domain, router, toast, setChatState, captureEvent]); + }, [router, toast, setChatState, captureEvent]); return { createNewChatThread,