wip on perf optimizations. Also changed some actions to withAuthV2

This commit is contained in:
bkellam 2025-11-25 18:19:15 -08:00
parent cbe381ad0c
commit 3863f6dd81
21 changed files with 464 additions and 514 deletions

View file

@ -19,7 +19,6 @@ import { useBrowseState } from "../../hooks/useBrowseState";
import { rangeHighlightingExtension } from "./rangeHighlightingExtension"; import { rangeHighlightingExtension } from "./rangeHighlightingExtension";
import useCaptureEvent from "@/hooks/useCaptureEvent"; import useCaptureEvent from "@/hooks/useCaptureEvent";
import { createAuditAction } from "@/ee/features/audit/actions"; import { createAuditAction } from "@/ee/features/audit/actions";
import { useDomain } from "@/hooks/useDomain";
interface PureCodePreviewPanelProps { interface PureCodePreviewPanelProps {
path: string; path: string;
@ -43,7 +42,6 @@ export const PureCodePreviewPanel = ({
const hasCodeNavEntitlement = useHasEntitlement("code-nav"); const hasCodeNavEntitlement = useHasEntitlement("code-nav");
const { updateBrowseState } = useBrowseState(); const { updateBrowseState } = useBrowseState();
const { navigateToPath } = useBrowseNavigation(); const { navigateToPath } = useBrowseNavigation();
const domain = useDomain();
const captureEvent = useCaptureEvent(); const captureEvent = useCaptureEvent();
const highlightRangeQuery = useNonEmptyQueryParam(HIGHLIGHT_RANGE_QUERY_PARAM); const highlightRangeQuery = useNonEmptyQueryParam(HIGHLIGHT_RANGE_QUERY_PARAM);
@ -145,7 +143,7 @@ export const PureCodePreviewPanel = ({
metadata: { metadata: {
message: symbolName, message: symbolName,
}, },
}, domain) })
updateBrowseState({ updateBrowseState({
selectedSymbolInfo: { selectedSymbolInfo: {
@ -157,7 +155,7 @@ export const PureCodePreviewPanel = ({
isBottomPanelCollapsed: false, isBottomPanelCollapsed: false,
activeExploreMenuTab: "references", activeExploreMenuTab: "references",
}) })
}, [captureEvent, updateBrowseState, repoName, revisionName, language, domain]); }, [captureEvent, updateBrowseState, repoName, revisionName, language]);
// If we resolve multiple matches, instead of navigating to the first match, we should // If we resolve multiple matches, instead of navigating to the first match, we should
@ -171,7 +169,7 @@ export const PureCodePreviewPanel = ({
metadata: { metadata: {
message: symbolName, message: symbolName,
}, },
}, domain) })
if (symbolDefinitions.length === 0) { if (symbolDefinitions.length === 0) {
return; return;
@ -200,7 +198,7 @@ export const PureCodePreviewPanel = ({
isBottomPanelCollapsed: false, isBottomPanelCollapsed: false,
}) })
} }
}, [captureEvent, navigateToPath, revisionName, updateBrowseState, repoName, language, domain]); }, [captureEvent, navigateToPath, revisionName, updateBrowseState, repoName, language]);
const theme = useCodeMirrorTheme(); const theme = useCodeMirrorTheme();

View file

@ -24,9 +24,9 @@ export default async function Page(props: PageProps) {
const languageModels = await getConfiguredLanguageModelsInfo(); const languageModels = await getConfiguredLanguageModelsInfo();
const repos = await getRepos(); const repos = await getRepos();
const searchContexts = await getSearchContexts(params.domain); const searchContexts = await getSearchContexts(params.domain);
const chatInfo = await getChatInfo({ chatId: params.id }, params.domain); const chatInfo = await getChatInfo({ chatId: params.id });
const session = await auth(); const session = await auth();
const chatHistory = session ? await getUserChatHistory(params.domain) : []; const chatHistory = session ? await getUserChatHistory() : [];
if (isServiceError(chatHistory)) { if (isServiceError(chatHistory)) {
throw new ServiceErrorException(chatHistory); throw new ServiceErrorException(chatHistory);

View file

@ -4,7 +4,6 @@ import { useToast } from "@/components/hooks/use-toast";
import { Badge } from "@/components/ui/badge"; import { Badge } from "@/components/ui/badge";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { updateChatName } from "@/features/chat/actions"; import { updateChatName } from "@/features/chat/actions";
import { useDomain } from "@/hooks/useDomain";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
import { GlobeIcon } from "@radix-ui/react-icons"; import { GlobeIcon } from "@radix-ui/react-icons";
import { ChatVisibility } from "@sourcebot/db"; import { ChatVisibility } from "@sourcebot/db";
@ -23,7 +22,6 @@ interface ChatNameProps {
export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) => { export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) => {
const [isRenameDialogOpen, setIsRenameDialogOpen] = useState(false); const [isRenameDialogOpen, setIsRenameDialogOpen] = useState(false);
const { toast } = useToast(); const { toast } = useToast();
const domain = useDomain();
const router = useRouter(); const router = useRouter();
const onRenameChat = useCallback(async (name: string) => { const onRenameChat = useCallback(async (name: string) => {
@ -31,7 +29,7 @@ export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) =>
const response = await updateChatName({ const response = await updateChatName({
chatId: id, chatId: id,
name: name, name: name,
}, domain); });
if (isServiceError(response)) { if (isServiceError(response)) {
toast({ toast({
@ -43,7 +41,7 @@ export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) =>
}); });
router.refresh(); router.refresh();
} }
}, [id, domain, toast, router]); }, [id, toast, router]);
return ( return (
<> <>

View file

@ -9,7 +9,6 @@ import { ScrollArea } from "@/components/ui/scroll-area";
import { Separator } from "@/components/ui/separator"; import { Separator } from "@/components/ui/separator";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { deleteChat, updateChatName } from "@/features/chat/actions"; import { deleteChat, updateChatName } from "@/features/chat/actions";
import { useDomain } from "@/hooks/useDomain";
import { cn, isServiceError } from "@/lib/utils"; import { cn, isServiceError } from "@/lib/utils";
import { CirclePlusIcon, EllipsisIcon, PencilIcon, TrashIcon } from "lucide-react"; import { CirclePlusIcon, EllipsisIcon, PencilIcon, TrashIcon } from "lucide-react";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
@ -23,6 +22,7 @@ import { useChatId } from "../useChatId";
import { RenameChatDialog } from "./renameChatDialog"; import { RenameChatDialog } from "./renameChatDialog";
import { DeleteChatDialog } from "./deleteChatDialog"; import { DeleteChatDialog } from "./deleteChatDialog";
import Link from "next/link"; import Link from "next/link";
import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants";
interface ChatSidePanelProps { interface ChatSidePanelProps {
order: number; order: number;
@ -41,7 +41,6 @@ export const ChatSidePanel = ({
isAuthenticated, isAuthenticated,
isCollapsedInitially, isCollapsedInitially,
}: ChatSidePanelProps) => { }: ChatSidePanelProps) => {
const domain = useDomain();
const [isCollapsed, setIsCollapsed] = useState(isCollapsedInitially); const [isCollapsed, setIsCollapsed] = useState(isCollapsedInitially);
const sidePanelRef = useRef<ImperativePanelHandle>(null); const sidePanelRef = useRef<ImperativePanelHandle>(null);
const router = useRouter(); const router = useRouter();
@ -72,7 +71,7 @@ export const ChatSidePanel = ({
const response = await updateChatName({ const response = await updateChatName({
chatId, chatId,
name: name, name: name,
}, domain); });
if (isServiceError(response)) { if (isServiceError(response)) {
toast({ toast({
@ -84,14 +83,14 @@ export const ChatSidePanel = ({
}); });
router.refresh(); router.refresh();
} }
}, [router, toast, domain]); }, [router, toast]);
const onDeleteChat = useCallback(async (chatIdToDelete: string) => { const onDeleteChat = useCallback(async (chatIdToDelete: string) => {
if (!chatIdToDelete) { if (!chatIdToDelete) {
return; return;
} }
const response = await deleteChat({ chatId: chatIdToDelete }, domain); const response = await deleteChat({ chatId: chatIdToDelete });
if (isServiceError(response)) { if (isServiceError(response)) {
toast({ toast({
@ -104,12 +103,12 @@ export const ChatSidePanel = ({
// If we just deleted the current chat, navigate to new chat // If we just deleted the current chat, navigate to new chat
if (chatIdToDelete === chatId) { if (chatIdToDelete === chatId) {
router.push(`/${domain}/chat`); router.push(`/${SINGLE_TENANT_ORG_DOMAIN}/chat`);
} }
router.refresh(); router.refresh();
} }
}, [chatId, router, toast, domain]); }, [chatId, router, toast]);
return ( return (
<> <>
@ -131,7 +130,7 @@ export const ChatSidePanel = ({
size="sm" size="sm"
className="w-full" className="w-full"
onClick={() => { onClick={() => {
router.push(`/${domain}/chat`); router.push(`/${SINGLE_TENANT_ORG_DOMAIN}/chat`);
}} }}
> >
<CirclePlusIcon className="w-4 h-4 mr-1" /> <CirclePlusIcon className="w-4 h-4 mr-1" />
@ -145,7 +144,7 @@ export const ChatSidePanel = ({
<div className="flex flex-col"> <div className="flex flex-col">
<p className="text-sm text-muted-foreground mb-4"> <p className="text-sm text-muted-foreground mb-4">
<Link <Link
href={`/login?callbackUrl=${encodeURIComponent(`/${domain}/chat`)}`} href={`/login?callbackUrl=${encodeURIComponent(`/${SINGLE_TENANT_ORG_DOMAIN}/chat`)}`}
className="text-sm text-link hover:underline cursor-pointer" className="text-sm text-link hover:underline cursor-pointer"
> >
Sign in Sign in
@ -163,7 +162,7 @@ export const ChatSidePanel = ({
chat.id === chatId && "bg-muted" chat.id === chatId && "bg-muted"
)} )}
onClick={() => { onClick={() => {
router.push(`/${domain}/chat/${chat.id}`); router.push(`/${SINGLE_TENANT_ORG_DOMAIN}/chat/${chat.id}`);
}} }}
> >
<span className="text-sm truncate">{chat.name ?? 'Untitled chat'}</span> <span className="text-sm truncate">{chat.name ?? 'Untitled chat'}</span>

View file

@ -221,7 +221,7 @@ export const SearchBar = ({
metadata: { metadata: {
message: query, message: query,
}, },
}, domain) })
const url = createPathWithQueryParams(`/${domain}/search`, const url = createPathWithQueryParams(`/${domain}/search`,
[SearchQueryParams.query, query], [SearchQueryParams.query, query],

View file

@ -22,8 +22,6 @@ import { symbolHoverTargetsExtension } from "@/ee/features/codeNav/components/sy
import { useHasEntitlement } from "@/features/entitlements/useHasEntitlement"; import { useHasEntitlement } from "@/features/entitlements/useHasEntitlement";
import { SymbolDefinition } from "@/ee/features/codeNav/components/symbolHoverPopup/useHoveredOverSymbolInfo"; import { SymbolDefinition } from "@/ee/features/codeNav/components/symbolHoverPopup/useHoveredOverSymbolInfo";
import { createAuditAction } from "@/ee/features/audit/actions"; import { createAuditAction } from "@/ee/features/audit/actions";
import { useDomain } from "@/hooks/useDomain";
import useCaptureEvent from "@/hooks/useCaptureEvent"; import useCaptureEvent from "@/hooks/useCaptureEvent";
export interface CodePreviewFile { export interface CodePreviewFile {
@ -53,7 +51,6 @@ export const CodePreview = ({
const [editorRef, setEditorRef] = useState<ReactCodeMirrorRef | null>(null); const [editorRef, setEditorRef] = useState<ReactCodeMirrorRef | null>(null);
const { navigateToPath } = useBrowseNavigation(); const { navigateToPath } = useBrowseNavigation();
const hasCodeNavEntitlement = useHasEntitlement("code-nav"); const hasCodeNavEntitlement = useHasEntitlement("code-nav");
const domain = useDomain();
const [gutterWidth, setGutterWidth] = useState(0); const [gutterWidth, setGutterWidth] = useState(0);
const theme = useCodeMirrorTheme(); const theme = useCodeMirrorTheme();
@ -127,7 +124,7 @@ export const CodePreview = ({
metadata: { metadata: {
message: symbolName, message: symbolName,
}, },
}, domain) })
if (symbolDefinitions.length === 0) { if (symbolDefinitions.length === 0) {
return; return;
@ -162,7 +159,7 @@ export const CodePreview = ({
} }
}); });
} }
}, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName, domain]); }, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName]);
const onFindReferences = useCallback((symbolName: string) => { const onFindReferences = useCallback((symbolName: string) => {
captureEvent('wa_find_references_pressed', { captureEvent('wa_find_references_pressed', {
@ -173,7 +170,7 @@ export const CodePreview = ({
metadata: { metadata: {
message: symbolName, message: symbolName,
}, },
}, domain) })
navigateToPath({ navigateToPath({
repoName, repoName,
@ -191,7 +188,7 @@ export const CodePreview = ({
isBottomPanelCollapsed: false, isBottomPanelCollapsed: false,
} }
}) })
}, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName, domain]); }, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName]);
return ( return (
<div className="flex flex-col h-full"> <div className="flex flex-col h-full">

View file

@ -1,4 +1,4 @@
import { sew, withAuth, withOrgMembership } from "@/actions"; import { sew } from "@/actions";
import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, updateChatMessages } from "@/features/chat/actions"; import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, updateChatMessages } from "@/features/chat/actions";
import { createAgentStream } from "@/features/chat/agent"; import { createAgentStream } from "@/features/chat/agent";
import { additionalChatRequestParamsSchema, LanguageModelInfo, SBChatMessage, SearchScope } from "@/features/chat/types"; import { additionalChatRequestParamsSchema, LanguageModelInfo, SBChatMessage, SearchScope } from "@/features/chat/types";
@ -6,10 +6,10 @@ import { getAnswerPartFromAssistantMessage, getLanguageModelKey } from "@/featur
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { notFound, schemaValidationError, serviceErrorResponse } from "@/lib/serviceError"; import { notFound, schemaValidationError, serviceErrorResponse } from "@/lib/serviceError";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
import { prisma } from "@/prisma"; import { withOptionalAuthV2 } from "@/withAuthV2";
import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider"; import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider";
import * as Sentry from "@sentry/nextjs"; import * as Sentry from "@sentry/nextjs";
import { OrgRole } from "@sourcebot/db"; import { PrismaClient } from "@sourcebot/db";
import { createLogger } from "@sourcebot/shared"; import { createLogger } from "@sourcebot/shared";
import { import {
createUIMessageStream, createUIMessageStream,
@ -34,15 +34,6 @@ const chatRequestSchema = z.object({
}) })
export async function POST(req: Request) { export async function POST(req: Request) {
const domain = req.headers.get("X-Org-Domain");
if (!domain) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.MISSING_ORG_DOMAIN_HEADER,
message: "Missing X-Org-Domain header",
});
}
const requestBody = await req.json(); const requestBody = await req.json();
const parsed = await chatRequestSchema.safeParseAsync(requestBody); const parsed = await chatRequestSchema.safeParseAsync(requestBody);
if (!parsed.success) { if (!parsed.success) {
@ -56,8 +47,7 @@ export async function POST(req: Request) {
const languageModel = _languageModel as LanguageModelInfo; const languageModel = _languageModel as LanguageModelInfo;
const response = await sew(() => const response = await sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => {
// Validate that the chat exists and is not readonly. // Validate that the chat exists and is not readonly.
const chat = await prisma.chat.findUnique({ const chat = await prisma.chat.findUnique({
where: { where: {
@ -101,11 +91,10 @@ export async function POST(req: Request) {
model, model,
modelName: languageModelConfig.displayName ?? languageModelConfig.model, modelName: languageModelConfig.displayName ?? languageModelConfig.model,
modelProviderOptions: providerOptions, modelProviderOptions: providerOptions,
domain,
orgId: org.id, orgId: org.id,
prisma,
}); });
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true })
)
) )
if (isServiceError(response)) { if (isServiceError(response)) {
@ -132,8 +121,8 @@ interface CreateMessageStreamResponseProps {
model: AISDKLanguageModelV2; model: AISDKLanguageModelV2;
modelName: string; modelName: string;
modelProviderOptions?: Record<string, Record<string, JSONValue>>; modelProviderOptions?: Record<string, Record<string, JSONValue>>;
domain: string;
orgId: number; orgId: number;
prisma: PrismaClient;
} }
const createMessageStreamResponse = async ({ const createMessageStreamResponse = async ({
@ -143,8 +132,8 @@ const createMessageStreamResponse = async ({
model, model,
modelName, modelName,
modelProviderOptions, modelProviderOptions,
domain,
orgId, orgId,
prisma,
}: CreateMessageStreamResponseProps) => { }: CreateMessageStreamResponseProps) => {
const latestMessage = messages[messages.length - 1]; const latestMessage = messages[messages.length - 1];
const sources = latestMessage.parts const sources = latestMessage.parts
@ -254,7 +243,7 @@ const createMessageStreamResponse = async ({
await updateChatMessages({ await updateChatMessages({
chatId: id, chatId: id,
messages messages
}, domain); });
}, },
}); });

View file

@ -1,34 +1,13 @@
'use server'; 'use server';
import { NextRequest } from "next/server";
import { fetchAuditRecords } from "@/ee/features/audit/actions"; import { fetchAuditRecords } from "@/ee/features/audit/actions";
import { isServiceError } from "@/lib/utils";
import { serviceErrorResponse } from "@/lib/serviceError";
import { StatusCodes } from "http-status-codes";
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { env } from "@sourcebot/shared"; import { serviceErrorResponse } from "@/lib/serviceError";
import { isServiceError } from "@/lib/utils";
import { getEntitlements } from "@sourcebot/shared"; import { getEntitlements } from "@sourcebot/shared";
import { StatusCodes } from "http-status-codes";
export const GET = async (request: NextRequest) => { export const GET = async () => {
const domain = request.headers.get("X-Org-Domain");
const apiKey = request.headers.get("X-Sourcebot-Api-Key") ?? undefined;
if (!domain) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.MISSING_ORG_DOMAIN_HEADER,
message: "Missing X-Org-Domain header",
});
}
if (env.SOURCEBOT_EE_AUDIT_LOGGING_ENABLED === 'false') {
return serviceErrorResponse({
statusCode: StatusCodes.NOT_FOUND,
errorCode: ErrorCode.NOT_FOUND,
message: "Audit logging is not enabled",
});
}
const entitlements = getEntitlements(); const entitlements = getEntitlements();
if (!entitlements.includes('audit')) { if (!entitlements.includes('audit')) {
return serviceErrorResponse({ return serviceErrorResponse({
@ -38,7 +17,7 @@ export const GET = async (request: NextRequest) => {
}); });
} }
const result = await fetchAuditRecords(domain, apiKey); const result = await fetchAuditRecords();
if (isServiceError(result)) { if (isServiceError(result)) {
return serviceErrorResponse(result); return serviceErrorResponse(result);
} }

View file

@ -1,28 +1,31 @@
"use server"; "use server";
import { prisma } from "@/prisma"; import { sew } from "@/actions";
import { ErrorCode } from "@/lib/errorCodes";
import { StatusCodes } from "http-status-codes";
import { sew, withAuth, withOrgMembership } from "@/actions";
import { OrgRole } from "@sourcebot/db";
import { createLogger } from "@sourcebot/shared";
import { ServiceError } from "@/lib/serviceError";
import { getAuditService } from "@/ee/features/audit/factory"; import { getAuditService } from "@/ee/features/audit/factory";
import { ErrorCode } from "@/lib/errorCodes";
import { ServiceError } from "@/lib/serviceError";
import { prisma } from "@/prisma";
import { withAuthV2 } from "@/withAuthV2";
import { createLogger } from "@sourcebot/shared";
import { StatusCodes } from "http-status-codes";
import { AuditEvent } from "./types"; import { AuditEvent } from "./types";
const auditService = getAuditService(); const auditService = getAuditService();
const logger = createLogger('audit-utils'); const logger = createLogger('audit-utils');
export const createAuditAction = async (event: Omit<AuditEvent, 'sourcebotVersion' | 'orgId' | 'actor' | 'target'>, domain: string) => sew(async () => export const createAuditAction = async (event: Omit<AuditEvent, 'sourcebotVersion' | 'orgId' | 'actor' | 'target'>) => sew(async () =>
withAuth((userId) => withAuthV2(async ({ user, org }) => {
withOrgMembership(userId, domain, async ({ org }) => { await auditService.createAudit({
await auditService.createAudit({ ...event, orgId: org.id, actor: { id: userId, type: "user" }, target: { id: org.id.toString(), type: "org" } }) ...event,
}, /* minRequiredRole = */ OrgRole.MEMBER), /* allowAnonymousAccess = */ true) orgId: org.id,
actor: { id: user.id, type: "user" },
target: { id: org.id.toString(), type: "org" },
})
})
); );
export const fetchAuditRecords = async (domain: string, apiKey: string | undefined = undefined) => sew(() => export const fetchAuditRecords = async () => sew(() =>
withAuth((userId) => withAuthV2(async ({ user, org }) => {
withOrgMembership(userId, domain, async ({ org }) => {
try { try {
const auditRecords = await prisma.audit.findMany({ const auditRecords = await prisma.audit.findMany({
where: { where: {
@ -36,7 +39,7 @@ export const fetchAuditRecords = async (domain: string, apiKey: string | undefin
await auditService.createAudit({ await auditService.createAudit({
action: "audit.fetch", action: "audit.fetch",
actor: { actor: {
id: userId, id: user.id,
type: "user" type: "user"
}, },
target: { target: {
@ -55,5 +58,5 @@ export const fetchAuditRecords = async (domain: string, apiKey: string | undefin
message: "Failed to fetch audit logs", message: "Failed to fetch audit logs",
} satisfies ServiceError; } satisfies ServiceError;
} }
}, /* minRequiredRole = */ OrgRole.OWNER), /* allowAnonymousAccess = */ true, apiKey ? { apiKey, domain } : undefined) })
); );

View file

@ -1,10 +1,9 @@
'use server'; 'use server';
import { sew, withAuth, withOrgMembership } from "@/actions"; import { sew } from "@/actions";
import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants"; import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants";
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { chatIsReadonly, notFound, ServiceError, serviceErrorResponse } from "@/lib/serviceError"; import { chatIsReadonly, notFound, ServiceError, serviceErrorResponse } from "@/lib/serviceError";
import { prisma } from "@/prisma";
import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock'; import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock';
import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic'; import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic';
import { createAzure } from '@ai-sdk/azure'; import { createAzure } from '@ai-sdk/azure';
@ -20,7 +19,7 @@ import { createXai } from '@ai-sdk/xai';
import { fromNodeProviderChain } from '@aws-sdk/credential-providers'; import { fromNodeProviderChain } from '@aws-sdk/credential-providers';
import { createOpenRouter } from '@openrouter/ai-sdk-provider'; import { createOpenRouter } from '@openrouter/ai-sdk-provider';
import { getTokenFromConfig, createLogger, env } from "@sourcebot/shared"; import { getTokenFromConfig, createLogger, env } from "@sourcebot/shared";
import { ChatVisibility, OrgRole, Prisma } from "@sourcebot/db"; import { ChatVisibility, Prisma } from "@sourcebot/db";
import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type"; import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type";
import { Token } from "@sourcebot/schemas/v3/shared.type"; import { Token } from "@sourcebot/schemas/v3/shared.type";
import { generateText, JSONValue, extractReasoningMiddleware, wrapLanguageModel } from "ai"; import { generateText, JSONValue, extractReasoningMiddleware, wrapLanguageModel } from "ai";
@ -29,20 +28,19 @@ import fs from 'fs';
import { StatusCodes } from "http-status-codes"; import { StatusCodes } from "http-status-codes";
import path from 'path'; import path from 'path';
import { LanguageModelInfo, SBChatMessage } from "./types"; import { LanguageModelInfo, SBChatMessage } from "./types";
import { withAuthV2, withOptionalAuthV2 } from "@/withAuthV2";
const logger = createLogger('chat-actions'); const logger = createLogger('chat-actions');
export const createChat = async (domain: string) => sew(() => export const createChat = async () => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const isGuestUser = user?.id === SOURCEBOT_GUEST_USER_ID;
const isGuestUser = userId === SOURCEBOT_GUEST_USER_ID;
const chat = await prisma.chat.create({ const chat = await prisma.chat.create({
data: { data: {
orgId: org.id, orgId: org.id,
messages: [] as unknown as Prisma.InputJsonValue, messages: [] as unknown as Prisma.InputJsonValue,
createdById: userId, createdById: user?.id ?? SOURCEBOT_GUEST_USER_ID,
visibility: isGuestUser ? ChatVisibility.PUBLIC : ChatVisibility.PRIVATE, visibility: isGuestUser ? ChatVisibility.PUBLIC : ChatVisibility.PRIVATE,
}, },
}); });
@ -50,12 +48,11 @@ export const createChat = async (domain: string) => sew(() =>
return { return {
id: chat.id, id: chat.id,
} }
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) })
); );
export const getChatInfo = async ({ chatId }: { chatId: string }, domain: string) => sew(() => export const getChatInfo = async ({ chatId }: { chatId: string }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => {
const chat = await prisma.chat.findUnique({ const chat = await prisma.chat.findUnique({
where: { where: {
id: chatId, id: chatId,
@ -67,7 +64,7 @@ export const getChatInfo = async ({ chatId }: { chatId: string }, domain: string
return notFound(); return notFound();
} }
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) {
return notFound(); return notFound();
} }
@ -77,12 +74,11 @@ export const getChatInfo = async ({ chatId }: { chatId: string }, domain: string
name: chat.name, name: chat.name,
isReadonly: chat.isReadonly, isReadonly: chat.isReadonly,
}; };
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) })
); );
export const updateChatMessages = async ({ chatId, messages }: { chatId: string, messages: SBChatMessage[] }, domain: string) => sew(() => export const updateChatMessages = async ({ chatId, messages }: { chatId: string, messages: SBChatMessage[] }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => {
const chat = await prisma.chat.findUnique({ const chat = await prisma.chat.findUnique({
where: { where: {
id: chatId, id: chatId,
@ -94,7 +90,7 @@ export const updateChatMessages = async ({ chatId, messages }: { chatId: string,
return notFound(); return notFound();
} }
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) {
return notFound(); return notFound();
} }
@ -124,16 +120,15 @@ export const updateChatMessages = async ({ chatId, messages }: { chatId: string,
return { return {
success: true, success: true,
} }
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) })
); );
export const getUserChatHistory = async (domain: string) => sew(() => export const getUserChatHistory = async () => sew(() =>
withAuth((userId) => withAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => {
const chats = await prisma.chat.findMany({ const chats = await prisma.chat.findMany({
where: { where: {
orgId: org.id, orgId: org.id,
createdById: userId, createdById: user.id,
}, },
orderBy: { orderBy: {
updatedAt: 'desc', updatedAt: 'desc',
@ -147,12 +142,10 @@ export const getUserChatHistory = async (domain: string) => sew(() =>
visibility: chat.visibility, visibility: chat.visibility,
})) }))
}) })
)
); );
export const updateChatName = async ({ chatId, name }: { chatId: string, name: string }, domain: string) => sew(() => export const updateChatName = async ({ chatId, name }: { chatId: string, name: string }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => {
const chat = await prisma.chat.findUnique({ const chat = await prisma.chat.findUnique({
where: { where: {
id: chatId, id: chatId,
@ -164,7 +157,7 @@ export const updateChatName = async ({ chatId, name }: { chatId: string, name: s
return notFound(); return notFound();
} }
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) {
return notFound(); return notFound();
} }
@ -185,12 +178,11 @@ export const updateChatName = async ({ chatId, name }: { chatId: string, name: s
return { return {
success: true, success: true,
} }
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) })
); );
export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }, domain: string) => sew(() => export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async () => {
withOrgMembership(userId, domain, async () => {
// From the language model ID, attempt to find the // From the language model ID, attempt to find the
// corresponding config in `config.json`. // corresponding config in `config.json`.
const languageModelConfig = const languageModelConfig =
@ -231,18 +223,16 @@ User question: ${message}`;
await updateChatName({ await updateChatName({
chatId, chatId,
name: result.text, name: result.text,
}, domain); });
return { return {
success: true, success: true,
} }
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true })
) )
);
export const deleteChat = async ({ chatId }: { chatId: string }, domain: string) => sew(() => export const deleteChat = async ({ chatId }: { chatId: string }) => sew(() =>
withAuth((userId) => withAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => {
const chat = await prisma.chat.findUnique({ const chat = await prisma.chat.findUnique({
where: { where: {
id: chatId, id: chatId,
@ -264,7 +254,7 @@ export const deleteChat = async ({ chatId }: { chatId: string }, domain: string)
} }
// Only the creator of a chat can delete it. // Only the creator of a chat can delete it.
if (chat.createdById !== userId) { if (chat.createdById !== user.id) {
return notFound(); return notFound();
} }
@ -279,7 +269,6 @@ export const deleteChat = async ({ chatId }: { chatId: string }, domain: string)
success: true, success: true,
} }
}) })
)
); );
export const submitFeedback = async ({ export const submitFeedback = async ({
@ -290,9 +279,8 @@ export const submitFeedback = async ({
chatId: string, chatId: string,
messageId: string, messageId: string,
feedbackType: 'like' | 'dislike' feedbackType: 'like' | 'dislike'
}, domain: string) => sew(() => }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => {
const chat = await prisma.chat.findUnique({ const chat = await prisma.chat.findUnique({
where: { where: {
id: chatId, id: chatId,
@ -305,7 +293,7 @@ export const submitFeedback = async ({
} }
// When a chat is private, only the creator can submit feedback. // When a chat is private, only the creator can submit feedback.
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) {
return notFound(); return notFound();
} }
@ -321,7 +309,7 @@ export const submitFeedback = async ({
{ {
type: feedbackType, type: feedbackType,
timestamp: new Date().toISOString(), timestamp: new Date().toISOString(),
userId: userId, userId: user?.id,
} }
] ]
} }
@ -338,8 +326,8 @@ export const submitFeedback = async ({
}); });
return { success: true }; return { success: true };
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) })
); )
/** /**
* Returns the subset of information about the configured language models * Returns the subset of information about the configured language models

View file

@ -6,7 +6,7 @@ import { Button } from "@/components/ui/button";
import { TableOfContentsIcon, ThumbsDown, ThumbsUp } from "lucide-react"; import { TableOfContentsIcon, ThumbsDown, ThumbsUp } from "lucide-react";
import { Separator } from "@/components/ui/separator"; import { Separator } from "@/components/ui/separator";
import { MarkdownRenderer } from "./markdownRenderer"; import { MarkdownRenderer } from "./markdownRenderer";
import { forwardRef, useCallback, useImperativeHandle, useRef, useState } from "react"; import { forwardRef, memo, useCallback, useImperativeHandle, useRef, useState } from "react";
import { Toggle } from "@/components/ui/toggle"; import { Toggle } from "@/components/ui/toggle";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { CopyIconButton } from "@/app/[domain]/components/copyIconButton"; import { CopyIconButton } from "@/app/[domain]/components/copyIconButton";
@ -14,7 +14,6 @@ import { useToast } from "@/components/hooks/use-toast";
import { convertLLMOutputToPortableMarkdown } from "../../utils"; import { convertLLMOutputToPortableMarkdown } from "../../utils";
import { submitFeedback } from "../../actions"; import { submitFeedback } from "../../actions";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
import { useDomain } from "@/hooks/useDomain";
import useCaptureEvent from "@/hooks/useCaptureEvent"; import useCaptureEvent from "@/hooks/useCaptureEvent";
import { LangfuseWeb } from "langfuse"; import { LangfuseWeb } from "langfuse";
import { env } from "@sourcebot/shared/client"; import { env } from "@sourcebot/shared/client";
@ -31,7 +30,7 @@ const langfuseWeb = (env.NEXT_PUBLIC_SOURCEBOT_CLOUD_ENVIRONMENT !== undefined &
baseUrl: env.NEXT_PUBLIC_LANGFUSE_BASE_URL, baseUrl: env.NEXT_PUBLIC_LANGFUSE_BASE_URL,
}) : null; }) : null;
export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({ const AnswerCardComponent = forwardRef<HTMLDivElement, AnswerCardProps>(({
answerText, answerText,
messageId, messageId,
chatId, chatId,
@ -41,7 +40,6 @@ export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
const { tocItems, activeId } = useExtractTOCItems({ target: markdownRendererRef.current }); const { tocItems, activeId } = useExtractTOCItems({ target: markdownRendererRef.current });
const [isTOCButtonToggled, setIsTOCButtonToggled] = useState(false); const [isTOCButtonToggled, setIsTOCButtonToggled] = useState(false);
const { toast } = useToast(); const { toast } = useToast();
const domain = useDomain();
const [isSubmittingFeedback, setIsSubmittingFeedback] = useState(false); const [isSubmittingFeedback, setIsSubmittingFeedback] = useState(false);
const [feedback, setFeedback] = useState<'like' | 'dislike' | undefined>(undefined); const [feedback, setFeedback] = useState<'like' | 'dislike' | undefined>(undefined);
const captureEvent = useCaptureEvent(); const captureEvent = useCaptureEvent();
@ -67,7 +65,7 @@ export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
chatId, chatId,
messageId, messageId,
feedbackType feedbackType
}, domain); });
if (isServiceError(response)) { if (isServiceError(response)) {
toast({ toast({
@ -93,7 +91,7 @@ export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
} }
setIsSubmittingFeedback(false); setIsSubmittingFeedback(false);
}, [chatId, messageId, domain, toast, captureEvent, traceId]); }, [chatId, messageId, toast, captureEvent, traceId]);
return ( return (
<div className="flex flex-row w-full relative scroll-mt-16"> <div className="flex flex-row w-full relative scroll-mt-16">
@ -178,4 +176,6 @@ export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
) )
}) })
AnswerCard.displayName = 'AnswerCard'; AnswerCardComponent.displayName = 'AnswerCard';
export const AnswerCard = memo(AnswerCardComponent);

View file

@ -7,7 +7,6 @@ import { Separator } from '@/components/ui/separator';
import { CustomSlateEditor } from '@/features/chat/customSlateEditor'; import { CustomSlateEditor } from '@/features/chat/customSlateEditor';
import { AdditionalChatRequestParams, CustomEditor, LanguageModelInfo, SBChatMessage, SearchScope, Source } from '@/features/chat/types'; import { AdditionalChatRequestParams, CustomEditor, LanguageModelInfo, SBChatMessage, SearchScope, Source } from '@/features/chat/types';
import { createUIMessage, getAllMentionElements, resetEditor, slateContentToString } from '@/features/chat/utils'; import { createUIMessage, getAllMentionElements, resetEditor, slateContentToString } from '@/features/chat/utils';
import { useDomain } from '@/hooks/useDomain';
import { useChat } from '@ai-sdk/react'; import { useChat } from '@ai-sdk/react';
import { CreateUIMessage, DefaultChatTransport } from 'ai'; import { CreateUIMessage, DefaultChatTransport } from 'ai';
import { ArrowDownIcon } from 'lucide-react'; import { ArrowDownIcon } from 'lucide-react';
@ -54,7 +53,6 @@ export const ChatThread = ({
onSelectedSearchScopesChange, onSelectedSearchScopesChange,
isChatReadonly, isChatReadonly,
}: ChatThreadProps) => { }: ChatThreadProps) => {
const domain = useDomain();
const [isErrorBannerVisible, setIsErrorBannerVisible] = useState(false); const [isErrorBannerVisible, setIsErrorBannerVisible] = useState(false);
const scrollAreaRef = useRef<HTMLDivElement>(null); const scrollAreaRef = useRef<HTMLDivElement>(null);
const latestMessagePairRef = useRef<HTMLDivElement>(null); const latestMessagePairRef = useRef<HTMLDivElement>(null);
@ -89,9 +87,6 @@ export const ChatThread = ({
messages: initialMessages, messages: initialMessages,
transport: new DefaultChatTransport({ transport: new DefaultChatTransport({
api: '/api/chat', api: '/api/chat',
headers: {
"X-Org-Domain": domain,
}
}), }),
onData: (dataPart) => { onData: (dataPart) => {
// Keeps sources added by the assistant in sync. // Keeps sources added by the assistant in sync.
@ -134,7 +129,6 @@ export const ChatThread = ({
languageModelId: selectedLanguageModel.model, languageModelId: selectedLanguageModel.model,
message: message.parts[0].text, message: message.parts[0].text,
}, },
domain
).then((response) => { ).then((response) => {
if (isServiceError(response)) { if (isServiceError(response)) {
toast({ toast({
@ -153,7 +147,6 @@ export const ChatThread = ({
messages.length, messages.length,
toast, toast,
chatId, chatId,
domain,
router, router,
]); ]);
@ -196,46 +189,47 @@ export const ChatThread = ({
hasSubmittedInputMessage.current = true; hasSubmittedInputMessage.current = true;
}, [inputMessage, sendMessage]); }, [inputMessage, sendMessage]);
// @todo: this need to be optimized to avoid excessive re-renders
// Track scroll position changes. // Track scroll position changes.
useEffect(() => { // useEffect(() => {
const scrollElement = scrollAreaRef.current?.querySelector('[data-radix-scroll-area-viewport]') as HTMLElement; // const scrollElement = scrollAreaRef.current?.querySelector('[data-radix-scroll-area-viewport]') as HTMLElement;
if (!scrollElement) return; // if (!scrollElement) return;
let timeout: NodeJS.Timeout | null = null; // let timeout: NodeJS.Timeout | null = null;
const handleScroll = () => { // const handleScroll = () => {
const scrollOffset = scrollElement.scrollTop; // const scrollOffset = scrollElement.scrollTop;
const threshold = 50; // pixels from bottom to consider "at bottom" // const threshold = 50; // pixels from bottom to consider "at bottom"
const { scrollHeight, clientHeight } = scrollElement; // const { scrollHeight, clientHeight } = scrollElement;
const isAtBottom = scrollHeight - scrollOffset - clientHeight <= threshold; // const isAtBottom = scrollHeight - scrollOffset - clientHeight <= threshold;
setIsAutoScrollEnabled(isAtBottom); // setIsAutoScrollEnabled(isAtBottom);
// Debounce the history state update // // Debounce the history state update
if (timeout) { // if (timeout) {
clearTimeout(timeout); // clearTimeout(timeout);
} // }
timeout = setTimeout(() => { // timeout = setTimeout(() => {
history.replaceState( // history.replaceState(
{ // {
scrollOffset, // scrollOffset,
} satisfies ChatHistoryState, // } satisfies ChatHistoryState,
'', // '',
window.location.href // window.location.href
); // );
}, 300); // }, 300);
}; // };
scrollElement.addEventListener('scroll', handleScroll, { passive: true }); // scrollElement.addEventListener('scroll', handleScroll, { passive: true });
return () => { // return () => {
scrollElement.removeEventListener('scroll', handleScroll); // scrollElement.removeEventListener('scroll', handleScroll);
if (timeout) { // if (timeout) {
clearTimeout(timeout); // clearTimeout(timeout);
} // }
}; // };
}, []); // }, []);
useEffect(() => { useEffect(() => {
const scrollElement = scrollAreaRef.current?.querySelector('[data-radix-scroll-area-viewport]') as HTMLElement; const scrollElement = scrollAreaRef.current?.querySelector('[data-radix-scroll-area-viewport]') as HTMLElement;
@ -313,9 +307,11 @@ export const ChatThread = ({
{messagePairs.map(([userMessage, assistantMessage], index) => { {messagePairs.map(([userMessage, assistantMessage], index) => {
const isLastPair = index === messagePairs.length - 1; const isLastPair = index === messagePairs.length - 1;
const isStreaming = isLastPair && (status === "streaming" || status === "submitted"); const isStreaming = isLastPair && (status === "streaming" || status === "submitted");
// Use a stable key based on user message ID
const key = userMessage.id;
return ( return (
<Fragment key={index}> <Fragment key={key}>
<ChatThreadListItem <ChatThreadListItem
index={index} index={index}
chatId={chatId} chatId={chatId}

View file

@ -4,7 +4,7 @@ import { AnimatedResizableHandle } from '@/components/ui/animatedResizableHandle
import { ResizablePanel, ResizablePanelGroup } from '@/components/ui/resizable'; import { ResizablePanel, ResizablePanelGroup } from '@/components/ui/resizable';
import { Skeleton } from '@/components/ui/skeleton'; import { Skeleton } from '@/components/ui/skeleton';
import { CheckCircle, Loader2 } from 'lucide-react'; import { CheckCircle, Loader2 } from 'lucide-react';
import { CSSProperties, forwardRef, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { CSSProperties, forwardRef, memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import scrollIntoView from 'scroll-into-view-if-needed'; import scrollIntoView from 'scroll-into-view-if-needed';
import { Reference, referenceSchema, SBChatMessage, Source } from "../../types"; import { Reference, referenceSchema, SBChatMessage, Source } from "../../types";
import { useExtractReferences } from '../../useExtractReferences'; import { useExtractReferences } from '../../useExtractReferences';
@ -24,7 +24,7 @@ interface ChatThreadListItemProps {
index: number; index: number;
} }
export const ChatThreadListItem = forwardRef<HTMLDivElement, ChatThreadListItemProps>(({ export const ChatThreadListItemComponent = forwardRef<HTMLDivElement, ChatThreadListItemProps>(({
userMessage, userMessage,
assistantMessage: _assistantMessage, assistantMessage: _assistantMessage,
isStreaming, isStreaming,
@ -32,6 +32,7 @@ export const ChatThreadListItem = forwardRef<HTMLDivElement, ChatThreadListItemP
chatId, chatId,
index, index,
}, ref) => { }, ref) => {
console.log(`re-rendering chat thread list item`, index);
const leftPanelRef = useRef<HTMLDivElement>(null); const leftPanelRef = useRef<HTMLDivElement>(null);
const [leftPanelHeight, setLeftPanelHeight] = useState<number | null>(null); const [leftPanelHeight, setLeftPanelHeight] = useState<number | null>(null);
const answerRef = useRef<HTMLDivElement>(null); const answerRef = useRef<HTMLDivElement>(null);
@ -393,7 +394,12 @@ export const ChatThreadListItem = forwardRef<HTMLDivElement, ChatThreadListItemP
) )
}); });
ChatThreadListItem.displayName = 'ChatThreadListItem'; ChatThreadListItemComponent.displayName = 'ChatThreadListItem';
// Only allow re-rendering when the message _is_ streaming.
// This is a performance optimizations to prevent unnecessary
// re-renders for chatThreadListItems that are not streaming.
export const ChatThreadListItem = memo(ChatThreadListItemComponent, (_, nextProps) => !nextProps.isStreaming);
// Finds the nearest reference element to the viewport center. // Finds the nearest reference element to the viewport center.
const getNearestReferenceElement = (referenceElements: Element[]) => { const getNearestReferenceElement = (referenceElements: Element[]) => {

View file

@ -7,6 +7,7 @@ import { Skeleton } from '@/components/ui/skeleton';
import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip'; import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip';
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
import { Brain, ChevronDown, ChevronRight, Clock, InfoIcon, Loader2, List, ScanSearchIcon, Zap } from 'lucide-react'; import { Brain, ChevronDown, ChevronRight, Clock, InfoIcon, Loader2, List, ScanSearchIcon, Zap } from 'lucide-react';
import { memo } from 'react';
import { MarkdownRenderer } from './markdownRenderer'; import { MarkdownRenderer } from './markdownRenderer';
import { FindSymbolDefinitionsToolComponent } from './tools/findSymbolDefinitionsToolComponent'; import { FindSymbolDefinitionsToolComponent } from './tools/findSymbolDefinitionsToolComponent';
import { FindSymbolReferencesToolComponent } from './tools/findSymbolReferencesToolComponent'; import { FindSymbolReferencesToolComponent } from './tools/findSymbolReferencesToolComponent';
@ -27,7 +28,7 @@ interface DetailsCardProps {
metadata?: SBChatMessageMetadata; metadata?: SBChatMessageMetadata;
} }
export const DetailsCard = ({ const DetailsCardComponent = ({
isExpanded, isExpanded,
onExpandedChanged, onExpandedChanged,
isThinking, isThinking,
@ -210,3 +211,5 @@ export const DetailsCard = ({
</Card> </Card>
) )
} }
export const DetailsCard = memo(DetailsCardComponent);

View file

@ -1,7 +1,6 @@
'use client'; 'use client';
import { CodeSnippet } from '@/app/components/codeSnippet'; import { CodeSnippet } from '@/app/components/codeSnippet';
import { useDomain } from '@/hooks/useDomain';
import { SearchQueryParams } from '@/lib/types'; import { SearchQueryParams } from '@/lib/types';
import { cn, createPathWithQueryParams } from '@/lib/utils'; import { cn, createPathWithQueryParams } from '@/lib/utils';
import type { Element, Root } from "hast"; import type { Element, Root } from "hast";
@ -10,7 +9,7 @@ import { CopyIcon, SearchIcon } from 'lucide-react';
import type { Heading, Nodes } from "mdast"; import type { Heading, Nodes } from "mdast";
import { findAndReplace } from 'mdast-util-find-and-replace'; import { findAndReplace } from 'mdast-util-find-and-replace';
import { useRouter } from 'next/navigation'; import { useRouter } from 'next/navigation';
import React, { useCallback, useMemo, forwardRef } from 'react'; import React, { useCallback, useMemo, forwardRef, memo } from 'react';
import Markdown from 'react-markdown'; import Markdown from 'react-markdown';
import rehypeRaw from 'rehype-raw'; import rehypeRaw from 'rehype-raw';
import rehypeSanitize, { defaultSchema } from 'rehype-sanitize'; import rehypeSanitize, { defaultSchema } from 'rehype-sanitize';
@ -20,6 +19,7 @@ import { visit } from 'unist-util-visit';
import { CodeBlock } from './codeBlock'; import { CodeBlock } from './codeBlock';
import { FILE_REFERENCE_REGEX } from '@/features/chat/constants'; import { FILE_REFERENCE_REGEX } from '@/features/chat/constants';
import { createFileReference } from '@/features/chat/utils'; import { createFileReference } from '@/features/chat/utils';
import { SINGLE_TENANT_ORG_DOMAIN } from '@/lib/constants';
export const REFERENCE_PAYLOAD_ATTRIBUTE = 'data-reference-payload'; export const REFERENCE_PAYLOAD_ATTRIBUTE = 'data-reference-payload';
@ -102,8 +102,7 @@ interface MarkdownRendererProps {
className?: string; className?: string;
} }
export const MarkdownRenderer = forwardRef<HTMLDivElement, MarkdownRendererProps>(({ content, className }, ref) => { const MarkdownRendererComponent = forwardRef<HTMLDivElement, MarkdownRendererProps>(({ content, className }, ref) => {
const domain = useDomain();
const router = useRouter(); const router = useRouter();
const remarkPlugins = useMemo((): PluggableList => { const remarkPlugins = useMemo((): PluggableList => {
@ -176,7 +175,7 @@ export const MarkdownRenderer = forwardRef<HTMLDivElement, MarkdownRendererProps
onClick={(e) => { onClick={(e) => {
e.preventDefault(); e.preventDefault();
e.stopPropagation(); e.stopPropagation();
const url = createPathWithQueryParams(`/${domain}/search`, [SearchQueryParams.query, `"${text}"`]) const url = createPathWithQueryParams(`/${SINGLE_TENANT_ORG_DOMAIN}/search`, [SearchQueryParams.query, `"${text}"`])
router.push(url); router.push(url);
}} }}
title="Search for snippet" title="Search for snippet"
@ -199,7 +198,7 @@ export const MarkdownRenderer = forwardRef<HTMLDivElement, MarkdownRendererProps
</span> </span>
) )
}, [domain, router]); }, [router]);
return ( return (
<div <div
@ -220,4 +219,6 @@ export const MarkdownRenderer = forwardRef<HTMLDivElement, MarkdownRendererProps
); );
}); });
MarkdownRenderer.displayName = 'MarkdownRenderer'; MarkdownRendererComponent.displayName = 'MarkdownRenderer';
export const MarkdownRenderer = memo(MarkdownRendererComponent);

View file

@ -14,13 +14,12 @@ import { Range } from "@codemirror/state";
import { Decoration, DecorationSet, EditorView } from '@codemirror/view'; import { Decoration, DecorationSet, EditorView } from '@codemirror/view';
import CodeMirror, { ReactCodeMirrorRef, StateField } from '@uiw/react-codemirror'; import CodeMirror, { ReactCodeMirrorRef, StateField } from '@uiw/react-codemirror';
import { ChevronDown, ChevronRight } from "lucide-react"; import { ChevronDown, ChevronRight } from "lucide-react";
import { forwardRef, Ref, useCallback, useImperativeHandle, useMemo, useState } from "react"; import { forwardRef, memo, Ref, useCallback, useImperativeHandle, useMemo, useState } from "react";
import { FileReference } from "../../types"; import { FileReference } from "../../types";
import { createCodeFoldingExtension } from "./codeFoldingExtension"; import { createCodeFoldingExtension } from "./codeFoldingExtension";
import useCaptureEvent from "@/hooks/useCaptureEvent"; import useCaptureEvent from "@/hooks/useCaptureEvent";
import { createAuditAction } from "@/ee/features/audit/actions";
import { useDomain } from "@/hooks/useDomain";
import { CodeHostType } from "@sourcebot/db"; import { CodeHostType } from "@sourcebot/db";
import { createAuditAction } from "@/ee/features/audit/actions";
const lineDecoration = Decoration.line({ const lineDecoration = Decoration.line({
attributes: { class: "cm-range-border-radius chat-lineHighlight" }, attributes: { class: "cm-range-border-radius chat-lineHighlight" },
@ -75,7 +74,6 @@ const ReferencedFileSourceListItem = ({
const theme = useCodeMirrorTheme(); const theme = useCodeMirrorTheme();
const [editorRef, setEditorRef] = useState<ReactCodeMirrorRef | null>(null); const [editorRef, setEditorRef] = useState<ReactCodeMirrorRef | null>(null);
const captureEvent = useCaptureEvent(); const captureEvent = useCaptureEvent();
const domain = useDomain();
useImperativeHandle( useImperativeHandle(
forwardedRef, forwardedRef,
@ -124,6 +122,8 @@ const ReferencedFileSourceListItem = ({
return createCodeFoldingExtension(references, 3); return createCodeFoldingExtension(references, 3);
}, [references]); }, [references]);
// console.log(`re-renderign for file ${fileName}`);
const extensions = useMemo(() => { const extensions = useMemo(() => {
return [ return [
languageExtension, languageExtension,
@ -231,7 +231,7 @@ const ReferencedFileSourceListItem = ({
metadata: { metadata: {
message: symbolName, message: symbolName,
}, },
}, domain); });
if (symbolDefinitions.length === 1) { if (symbolDefinitions.length === 1) {
const symbolDefinition = symbolDefinitions[0]; const symbolDefinition = symbolDefinitions[0];
@ -263,7 +263,7 @@ const ReferencedFileSourceListItem = ({
}); });
} }
}, [captureEvent, domain, navigateToPath, revision, repoName, fileName, language]); }, [captureEvent, navigateToPath, revision, repoName, fileName, language]);
const onFindReferences = useCallback((symbolName: string) => { const onFindReferences = useCallback((symbolName: string) => {
captureEvent('wa_find_references_pressed', { captureEvent('wa_find_references_pressed', {
@ -274,7 +274,7 @@ const ReferencedFileSourceListItem = ({
metadata: { metadata: {
message: symbolName, message: symbolName,
}, },
}, domain); });
navigateToPath({ navigateToPath({
repoName, repoName,
@ -293,7 +293,7 @@ const ReferencedFileSourceListItem = ({
} }
}) })
}, [captureEvent, domain, fileName, language, navigateToPath, repoName, revision]); }, [captureEvent, fileName, language, navigateToPath, repoName, revision]);
const ExpandCollapseIcon = useMemo(() => { const ExpandCollapseIcon = useMemo(() => {
return isExpanded ? ChevronDown : ChevronRight; return isExpanded ? ChevronDown : ChevronRight;
@ -355,6 +355,6 @@ const ReferencedFileSourceListItem = ({
) )
} }
export default forwardRef(ReferencedFileSourceListItem) as ( export default memo(forwardRef(ReferencedFileSourceListItem)) as (
props: ReferencedFileSourceListItemProps & { ref?: Ref<ReactCodeMirrorRef> }, props: ReferencedFileSourceListItemProps & { ref?: Ref<ReactCodeMirrorRef> },
) => ReturnType<typeof ReferencedFileSourceListItem>; ) => ReturnType<typeof ReferencedFileSourceListItem>;

View file

@ -4,11 +4,10 @@ import { getFileSource } from "@/app/api/(client)/client";
import { VscodeFileIcon } from "@/app/components/vscodeFileIcon"; import { VscodeFileIcon } from "@/app/components/vscodeFileIcon";
import { ScrollArea } from "@/components/ui/scroll-area"; import { ScrollArea } from "@/components/ui/scroll-area";
import { Skeleton } from "@/components/ui/skeleton"; import { Skeleton } from "@/components/ui/skeleton";
import { useDomain } from "@/hooks/useDomain";
import { isServiceError, unwrapServiceError } from "@/lib/utils"; import { isServiceError, unwrapServiceError } from "@/lib/utils";
import { useQueries } from "@tanstack/react-query"; import { useQueries } from "@tanstack/react-query";
import { ReactCodeMirrorRef } from '@uiw/react-codemirror'; import { ReactCodeMirrorRef } from '@uiw/react-codemirror';
import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { memo, useCallback, useEffect, useMemo, useRef, useState } from "react";
import scrollIntoView from 'scroll-into-view-if-needed'; import scrollIntoView from 'scroll-into-view-if-needed';
import { FileReference, FileSource, Reference, Source } from "../../types"; import { FileReference, FileSource, Reference, Source } from "../../types";
import ReferencedFileSourceListItem from "./referencedFileSourceListItem"; import ReferencedFileSourceListItem from "./referencedFileSourceListItem";
@ -31,7 +30,7 @@ const resolveFileReference = (reference: FileReference, sources: FileSource[]):
); );
} }
export const ReferencedSourcesListView = ({ const ReferencedSourcesListViewComponent = ({
references, references,
sources, sources,
index, index,
@ -43,7 +42,6 @@ export const ReferencedSourcesListView = ({
}: ReferencedSourcesListViewProps) => { }: ReferencedSourcesListViewProps) => {
const scrollAreaRef = useRef<HTMLDivElement>(null); const scrollAreaRef = useRef<HTMLDivElement>(null);
const editorRefsMap = useRef<Map<string, ReactCodeMirrorRef>>(new Map()); const editorRefsMap = useRef<Map<string, ReactCodeMirrorRef>>(new Map());
const domain = useDomain();
const [collapsedFileIds, setCollapsedFileIds] = useState<string[]>([]); const [collapsedFileIds, setCollapsedFileIds] = useState<string[]>([]);
const getFileId = useCallback((fileSource: FileSource) => { const getFileId = useCallback((fileSource: FileSource) => {
@ -98,7 +96,7 @@ export const ReferencedSourcesListView = ({
const fileSourceQueries = useQueries({ const fileSourceQueries = useQueries({
queries: referencedFileSources.map((file) => ({ queries: referencedFileSources.map((file) => ({
queryKey: ['fileSource', file.path, file.repo, file.revision, domain], queryKey: ['fileSource', file.path, file.repo, file.revision],
queryFn: () => unwrapServiceError(getFileSource({ queryFn: () => unwrapServiceError(getFileSource({
fileName: file.path, fileName: file.path,
repository: file.repo, repository: file.repo,
@ -183,6 +181,25 @@ export const ReferencedSourcesListView = ({
} }
}, [getFileId, referencedFileSources, selectedReference]); }, [getFileId, referencedFileSources, selectedReference]);
const onExpandedChanged = useCallback((fileId: string, isExpanded: boolean) => {
if (isExpanded) {
setCollapsedFileIds(collapsedFileIds => collapsedFileIds.filter((id) => id !== fileId));
} else {
setCollapsedFileIds(collapsedFileIds => [...collapsedFileIds, fileId]);
}
if (!isExpanded) {
const fileSourceStart = document.getElementById(`${fileId}-start`);
if (fileSourceStart) {
scrollIntoView(fileSourceStart, {
scrollMode: 'if-needed',
block: 'start',
behavior: 'instant',
});
}
}
}, []);
if (referencedFileSources.length === 0) { if (referencedFileSources.length === 0) {
return ( return (
<div className="p-4 text-center text-muted-foreground text-sm"> <div className="p-4 text-center text-muted-foreground text-sm">
@ -253,30 +270,7 @@ export const ReferencedSourcesListView = ({
selectedReference={selectedReference} selectedReference={selectedReference}
hoveredReference={hoveredReference} hoveredReference={hoveredReference}
isExpanded={!collapsedFileIds.includes(fileId)} isExpanded={!collapsedFileIds.includes(fileId)}
onExpandedChanged={(isExpanded) => { onExpandedChanged={(isExpanded) => onExpandedChanged(fileId, isExpanded)}
if (isExpanded) {
setCollapsedFileIds(collapsedFileIds.filter((id) => id !== fileId));
} else {
setCollapsedFileIds([...collapsedFileIds, fileId]);
}
// When collapsing a file when you are deep in a scroll, it's a better
// experience to have the scroll automatically restored to the top of the file
// s.t., header is still sticky to the top of the scroll area.
if (!isExpanded) {
const fileSourceStart = document.getElementById(`${fileId}-start`);
if (!fileSourceStart) {
return;
}
scrollIntoView(fileSourceStart, {
scrollMode: 'if-needed',
block: 'start',
behavior: 'instant',
});
}
}
}
/> />
); );
})} })}
@ -284,3 +278,6 @@ export const ReferencedSourcesListView = ({
</ScrollArea> </ScrollArea>
); );
} }
// Memoize to prevent unnecessary re-renders
export const ReferencedSourcesListView = memo(ReferencedSourcesListViewComponent);

View file

@ -1,7 +1,6 @@
'use client'; 'use client';
import { SearchCodeToolUIPart } from "@/features/chat/tools"; import { SearchCodeToolUIPart } from "@/features/chat/tools";
import { useDomain } from "@/hooks/useDomain";
import { createPathWithQueryParams, isServiceError } from "@/lib/utils"; import { createPathWithQueryParams, isServiceError } from "@/lib/utils";
import { useMemo, useState } from "react"; import { useMemo, useState } from "react";
import { FileListItem, ToolHeader, TreeList } from "./shared"; import { FileListItem, ToolHeader, TreeList } from "./shared";
@ -12,10 +11,10 @@ import Link from "next/link";
import { SearchQueryParams } from "@/lib/types"; import { SearchQueryParams } from "@/lib/types";
import { PlayIcon } from "@radix-ui/react-icons"; import { PlayIcon } from "@radix-ui/react-icons";
import { buildSearchQuery } from "@/features/chat/utils"; import { buildSearchQuery } from "@/features/chat/utils";
import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants";
export const SearchCodeToolComponent = ({ part }: { part: SearchCodeToolUIPart }) => { export const SearchCodeToolComponent = ({ part }: { part: SearchCodeToolUIPart }) => {
const [isExpanded, setIsExpanded] = useState(false); const [isExpanded, setIsExpanded] = useState(false);
const domain = useDomain();
const displayQuery = useMemo(() => { const displayQuery = useMemo(() => {
if (part.state !== 'input-available' && part.state !== 'output-available') { if (part.state !== 'input-available' && part.state !== 'output-available') {
@ -78,7 +77,7 @@ export const SearchCodeToolComponent = ({ part }: { part: SearchCodeToolUIPart }
</TreeList> </TreeList>
)} )}
<Link <Link
href={createPathWithQueryParams(`/${domain}/search`, href={createPathWithQueryParams(`/${SINGLE_TENANT_ORG_DOMAIN}/search`,
[SearchQueryParams.query, part.output.query], [SearchQueryParams.query, part.output.query],
)} )}
className='flex flex-row items-center gap-2 text-sm text-muted-foreground mt-2 ml-auto w-fit hover:text-foreground' className='flex flex-row items-center gap-2 text-sm text-muted-foreground mt-2 ml-auto w-fit hover:text-foreground'

View file

@ -2,12 +2,12 @@
import { VscodeFileIcon } from '@/app/components/vscodeFileIcon'; import { VscodeFileIcon } from '@/app/components/vscodeFileIcon';
import { ScrollArea } from '@/components/ui/scroll-area'; import { ScrollArea } from '@/components/ui/scroll-area';
import { useDomain } from '@/hooks/useDomain';
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
import { ChevronDown, ChevronRight, Loader2 } from 'lucide-react'; import { ChevronDown, ChevronRight, Loader2 } from 'lucide-react';
import Link from 'next/link'; import Link from 'next/link';
import React from 'react'; import React from 'react';
import { getBrowsePath } from "@/app/[domain]/browse/hooks/utils"; import { getBrowsePath } from "@/app/[domain]/browse/hooks/utils";
import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants";
export const FileListItem = ({ export const FileListItem = ({
@ -17,8 +17,6 @@ export const FileListItem = ({
path: string, path: string,
repoName: string, repoName: string,
}) => { }) => {
const domain = useDomain();
return ( return (
<div key={path} className="flex flex-row items-center overflow-hidden hover:bg-accent hover:text-accent-foreground rounded-sm cursor-pointer p-0.5"> <div key={path} className="flex flex-row items-center overflow-hidden hover:bg-accent hover:text-accent-foreground rounded-sm cursor-pointer p-0.5">
<VscodeFileIcon fileName={path} className="mr-1 flex-shrink-0" /> <VscodeFileIcon fileName={path} className="mr-1 flex-shrink-0" />
@ -28,7 +26,7 @@ export const FileListItem = ({
repoName, repoName,
revisionName: 'HEAD', revisionName: 'HEAD',
path, path,
domain, domain: SINGLE_TENANT_ORG_DOMAIN,
pathType: 'blob', pathType: 'blob',
})} })}
> >

View file

@ -70,7 +70,7 @@ export const sbChatMessageMetadataSchema = z.object({
feedback: z.array(z.object({ feedback: z.array(z.object({
type: z.enum(['like', 'dislike']), type: z.enum(['like', 'dislike']),
timestamp: z.string(), // ISO date string timestamp: z.string(), // ISO date string
userId: z.string(), userId: z.string().optional(),
})).optional(), })).optional(),
selectedSearchScopes: z.array(searchScopeSchema).optional(), selectedSearchScopes: z.array(searchScopeSchema).optional(),
traceId: z.string().optional(), traceId: z.string().optional(),

View file

@ -4,7 +4,6 @@ import { useCallback, useState } from "react";
import { Descendant } from "slate"; import { Descendant } from "slate";
import { createUIMessage, getAllMentionElements } from "./utils"; import { createUIMessage, getAllMentionElements } from "./utils";
import { slateContentToString } from "./utils"; import { slateContentToString } from "./utils";
import { useDomain } from "@/hooks/useDomain";
import { useToast } from "@/components/hooks/use-toast"; import { useToast } from "@/components/hooks/use-toast";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { createChat } from "./actions"; import { createChat } from "./actions";
@ -13,9 +12,9 @@ import { createPathWithQueryParams } from "@/lib/utils";
import { SearchScope, SET_CHAT_STATE_SESSION_STORAGE_KEY, SetChatStatePayload } from "./types"; import { SearchScope, SET_CHAT_STATE_SESSION_STORAGE_KEY, SetChatStatePayload } from "./types";
import { useSessionStorage } from "usehooks-ts"; import { useSessionStorage } from "usehooks-ts";
import useCaptureEvent from "@/hooks/useCaptureEvent"; import useCaptureEvent from "@/hooks/useCaptureEvent";
import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants";
export const useCreateNewChatThread = () => { export const useCreateNewChatThread = () => {
const domain = useDomain();
const [isLoading, setIsLoading] = useState(false); const [isLoading, setIsLoading] = useState(false);
const { toast } = useToast(); const { toast } = useToast();
const router = useRouter(); const router = useRouter();
@ -31,7 +30,7 @@ export const useCreateNewChatThread = () => {
const inputMessage = createUIMessage(text, mentions.map((mention) => mention.data), selectedSearchScopes); const inputMessage = createUIMessage(text, mentions.map((mention) => mention.data), selectedSearchScopes);
setIsLoading(true); setIsLoading(true);
const response = await createChat(domain); const response = await createChat();
if (isServiceError(response)) { if (isServiceError(response)) {
toast({ toast({
description: `❌ Failed to create chat. Reason: ${response.message}` description: `❌ Failed to create chat. Reason: ${response.message}`
@ -47,11 +46,11 @@ export const useCreateNewChatThread = () => {
selectedSearchScopes, selectedSearchScopes,
}); });
const url = createPathWithQueryParams(`/${domain}/chat/${response.id}`); const url = createPathWithQueryParams(`/${SINGLE_TENANT_ORG_DOMAIN}/chat/${response.id}`);
router.push(url); router.push(url);
router.refresh(); router.refresh();
}, [domain, router, toast, setChatState, captureEvent]); }, [router, toast, setChatState, captureEvent]);
return { return {
createNewChatThread, createNewChatThread,