wip on perf optimizations. Also changed some actions to withAuthV2

This commit is contained in:
bkellam 2025-11-25 18:19:15 -08:00
parent cbe381ad0c
commit 3863f6dd81
21 changed files with 464 additions and 514 deletions

View file

@ -19,7 +19,6 @@ import { useBrowseState } from "../../hooks/useBrowseState";
import { rangeHighlightingExtension } from "./rangeHighlightingExtension";
import useCaptureEvent from "@/hooks/useCaptureEvent";
import { createAuditAction } from "@/ee/features/audit/actions";
import { useDomain } from "@/hooks/useDomain";
interface PureCodePreviewPanelProps {
path: string;
@ -43,7 +42,6 @@ export const PureCodePreviewPanel = ({
const hasCodeNavEntitlement = useHasEntitlement("code-nav");
const { updateBrowseState } = useBrowseState();
const { navigateToPath } = useBrowseNavigation();
const domain = useDomain();
const captureEvent = useCaptureEvent();
const highlightRangeQuery = useNonEmptyQueryParam(HIGHLIGHT_RANGE_QUERY_PARAM);
@ -145,7 +143,7 @@ export const PureCodePreviewPanel = ({
metadata: {
message: symbolName,
},
}, domain)
})
updateBrowseState({
selectedSymbolInfo: {
@ -157,7 +155,7 @@ export const PureCodePreviewPanel = ({
isBottomPanelCollapsed: false,
activeExploreMenuTab: "references",
})
}, [captureEvent, updateBrowseState, repoName, revisionName, language, domain]);
}, [captureEvent, updateBrowseState, repoName, revisionName, language]);
// If we resolve multiple matches, instead of navigating to the first match, we should
@ -171,7 +169,7 @@ export const PureCodePreviewPanel = ({
metadata: {
message: symbolName,
},
}, domain)
})
if (symbolDefinitions.length === 0) {
return;
@ -200,7 +198,7 @@ export const PureCodePreviewPanel = ({
isBottomPanelCollapsed: false,
})
}
}, [captureEvent, navigateToPath, revisionName, updateBrowseState, repoName, language, domain]);
}, [captureEvent, navigateToPath, revisionName, updateBrowseState, repoName, language]);
const theme = useCodeMirrorTheme();

View file

@ -24,9 +24,9 @@ export default async function Page(props: PageProps) {
const languageModels = await getConfiguredLanguageModelsInfo();
const repos = await getRepos();
const searchContexts = await getSearchContexts(params.domain);
const chatInfo = await getChatInfo({ chatId: params.id }, params.domain);
const chatInfo = await getChatInfo({ chatId: params.id });
const session = await auth();
const chatHistory = session ? await getUserChatHistory(params.domain) : [];
const chatHistory = session ? await getUserChatHistory() : [];
if (isServiceError(chatHistory)) {
throw new ServiceErrorException(chatHistory);

View file

@ -4,7 +4,6 @@ import { useToast } from "@/components/hooks/use-toast";
import { Badge } from "@/components/ui/badge";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { updateChatName } from "@/features/chat/actions";
import { useDomain } from "@/hooks/useDomain";
import { isServiceError } from "@/lib/utils";
import { GlobeIcon } from "@radix-ui/react-icons";
import { ChatVisibility } from "@sourcebot/db";
@ -23,7 +22,6 @@ interface ChatNameProps {
export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) => {
const [isRenameDialogOpen, setIsRenameDialogOpen] = useState(false);
const { toast } = useToast();
const domain = useDomain();
const router = useRouter();
const onRenameChat = useCallback(async (name: string) => {
@ -31,7 +29,7 @@ export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) =>
const response = await updateChatName({
chatId: id,
name: name,
}, domain);
});
if (isServiceError(response)) {
toast({
@ -43,7 +41,7 @@ export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) =>
});
router.refresh();
}
}, [id, domain, toast, router]);
}, [id, toast, router]);
return (
<>

View file

@ -9,7 +9,6 @@ import { ScrollArea } from "@/components/ui/scroll-area";
import { Separator } from "@/components/ui/separator";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { deleteChat, updateChatName } from "@/features/chat/actions";
import { useDomain } from "@/hooks/useDomain";
import { cn, isServiceError } from "@/lib/utils";
import { CirclePlusIcon, EllipsisIcon, PencilIcon, TrashIcon } from "lucide-react";
import { useRouter } from "next/navigation";
@ -23,6 +22,7 @@ import { useChatId } from "../useChatId";
import { RenameChatDialog } from "./renameChatDialog";
import { DeleteChatDialog } from "./deleteChatDialog";
import Link from "next/link";
import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants";
interface ChatSidePanelProps {
order: number;
@ -41,7 +41,6 @@ export const ChatSidePanel = ({
isAuthenticated,
isCollapsedInitially,
}: ChatSidePanelProps) => {
const domain = useDomain();
const [isCollapsed, setIsCollapsed] = useState(isCollapsedInitially);
const sidePanelRef = useRef<ImperativePanelHandle>(null);
const router = useRouter();
@ -72,7 +71,7 @@ export const ChatSidePanel = ({
const response = await updateChatName({
chatId,
name: name,
}, domain);
});
if (isServiceError(response)) {
toast({
@ -84,14 +83,14 @@ export const ChatSidePanel = ({
});
router.refresh();
}
}, [router, toast, domain]);
}, [router, toast]);
const onDeleteChat = useCallback(async (chatIdToDelete: string) => {
if (!chatIdToDelete) {
return;
}
const response = await deleteChat({ chatId: chatIdToDelete }, domain);
const response = await deleteChat({ chatId: chatIdToDelete });
if (isServiceError(response)) {
toast({
@ -104,12 +103,12 @@ export const ChatSidePanel = ({
// If we just deleted the current chat, navigate to new chat
if (chatIdToDelete === chatId) {
router.push(`/${domain}/chat`);
router.push(`/${SINGLE_TENANT_ORG_DOMAIN}/chat`);
}
router.refresh();
}
}, [chatId, router, toast, domain]);
}, [chatId, router, toast]);
return (
<>
@ -131,7 +130,7 @@ export const ChatSidePanel = ({
size="sm"
className="w-full"
onClick={() => {
router.push(`/${domain}/chat`);
router.push(`/${SINGLE_TENANT_ORG_DOMAIN}/chat`);
}}
>
<CirclePlusIcon className="w-4 h-4 mr-1" />
@ -145,7 +144,7 @@ export const ChatSidePanel = ({
<div className="flex flex-col">
<p className="text-sm text-muted-foreground mb-4">
<Link
href={`/login?callbackUrl=${encodeURIComponent(`/${domain}/chat`)}`}
href={`/login?callbackUrl=${encodeURIComponent(`/${SINGLE_TENANT_ORG_DOMAIN}/chat`)}`}
className="text-sm text-link hover:underline cursor-pointer"
>
Sign in
@ -163,7 +162,7 @@ export const ChatSidePanel = ({
chat.id === chatId && "bg-muted"
)}
onClick={() => {
router.push(`/${domain}/chat/${chat.id}`);
router.push(`/${SINGLE_TENANT_ORG_DOMAIN}/chat/${chat.id}`);
}}
>
<span className="text-sm truncate">{chat.name ?? 'Untitled chat'}</span>

View file

@ -221,7 +221,7 @@ export const SearchBar = ({
metadata: {
message: query,
},
}, domain)
})
const url = createPathWithQueryParams(`/${domain}/search`,
[SearchQueryParams.query, query],

View file

@ -22,8 +22,6 @@ import { symbolHoverTargetsExtension } from "@/ee/features/codeNav/components/sy
import { useHasEntitlement } from "@/features/entitlements/useHasEntitlement";
import { SymbolDefinition } from "@/ee/features/codeNav/components/symbolHoverPopup/useHoveredOverSymbolInfo";
import { createAuditAction } from "@/ee/features/audit/actions";
import { useDomain } from "@/hooks/useDomain";
import useCaptureEvent from "@/hooks/useCaptureEvent";
export interface CodePreviewFile {
@ -53,7 +51,6 @@ export const CodePreview = ({
const [editorRef, setEditorRef] = useState<ReactCodeMirrorRef | null>(null);
const { navigateToPath } = useBrowseNavigation();
const hasCodeNavEntitlement = useHasEntitlement("code-nav");
const domain = useDomain();
const [gutterWidth, setGutterWidth] = useState(0);
const theme = useCodeMirrorTheme();
@ -127,7 +124,7 @@ export const CodePreview = ({
metadata: {
message: symbolName,
},
}, domain)
})
if (symbolDefinitions.length === 0) {
return;
@ -162,7 +159,7 @@ export const CodePreview = ({
}
});
}
}, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName, domain]);
}, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName]);
const onFindReferences = useCallback((symbolName: string) => {
captureEvent('wa_find_references_pressed', {
@ -173,7 +170,7 @@ export const CodePreview = ({
metadata: {
message: symbolName,
},
}, domain)
})
navigateToPath({
repoName,
@ -191,7 +188,7 @@ export const CodePreview = ({
isBottomPanelCollapsed: false,
}
})
}, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName, domain]);
}, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName]);
return (
<div className="flex flex-col h-full">

View file

@ -1,4 +1,4 @@
import { sew, withAuth, withOrgMembership } from "@/actions";
import { sew } from "@/actions";
import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, updateChatMessages } from "@/features/chat/actions";
import { createAgentStream } from "@/features/chat/agent";
import { additionalChatRequestParamsSchema, LanguageModelInfo, SBChatMessage, SearchScope } from "@/features/chat/types";
@ -6,10 +6,10 @@ import { getAnswerPartFromAssistantMessage, getLanguageModelKey } from "@/featur
import { ErrorCode } from "@/lib/errorCodes";
import { notFound, schemaValidationError, serviceErrorResponse } from "@/lib/serviceError";
import { isServiceError } from "@/lib/utils";
import { prisma } from "@/prisma";
import { withOptionalAuthV2 } from "@/withAuthV2";
import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider";
import * as Sentry from "@sentry/nextjs";
import { OrgRole } from "@sourcebot/db";
import { PrismaClient } from "@sourcebot/db";
import { createLogger } from "@sourcebot/shared";
import {
createUIMessageStream,
@ -34,15 +34,6 @@ const chatRequestSchema = z.object({
})
export async function POST(req: Request) {
const domain = req.headers.get("X-Org-Domain");
if (!domain) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.MISSING_ORG_DOMAIN_HEADER,
message: "Missing X-Org-Domain header",
});
}
const requestBody = await req.json();
const parsed = await chatRequestSchema.safeParseAsync(requestBody);
if (!parsed.success) {
@ -56,56 +47,54 @@ export async function POST(req: Request) {
const languageModel = _languageModel as LanguageModelInfo;
const response = await sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
// Validate that the chat exists and is not readonly.
const chat = await prisma.chat.findUnique({
where: {
orgId: org.id,
id,
},
});
if (!chat) {
return notFound();
}
if (chat.isReadonly) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: "Chat is readonly and cannot be edited.",
});
}
// From the language model ID, attempt to find the
// corresponding config in `config.json`.
const languageModelConfig =
(await _getConfiguredLanguageModelsFull())
.find((model) => getLanguageModelKey(model) === getLanguageModelKey(languageModel));
if (!languageModelConfig) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: `Language model ${languageModel.model} is not configured.`,
});
}
const { model, providerOptions } = await _getAISDKLanguageModelAndOptions(languageModelConfig);
return createMessageStreamResponse({
messages,
id,
selectedSearchScopes,
model,
modelName: languageModelConfig.displayName ?? languageModelConfig.model,
modelProviderOptions: providerOptions,
domain,
withOptionalAuthV2(async ({ org, prisma }) => {
// Validate that the chat exists and is not readonly.
const chat = await prisma.chat.findUnique({
where: {
orgId: org.id,
id,
},
});
if (!chat) {
return notFound();
}
if (chat.isReadonly) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: "Chat is readonly and cannot be edited.",
});
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true
)
}
// From the language model ID, attempt to find the
// corresponding config in `config.json`.
const languageModelConfig =
(await _getConfiguredLanguageModelsFull())
.find((model) => getLanguageModelKey(model) === getLanguageModelKey(languageModel));
if (!languageModelConfig) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: `Language model ${languageModel.model} is not configured.`,
});
}
const { model, providerOptions } = await _getAISDKLanguageModelAndOptions(languageModelConfig);
return createMessageStreamResponse({
messages,
id,
selectedSearchScopes,
model,
modelName: languageModelConfig.displayName ?? languageModelConfig.model,
modelProviderOptions: providerOptions,
orgId: org.id,
prisma,
});
})
)
if (isServiceError(response)) {
@ -132,8 +121,8 @@ interface CreateMessageStreamResponseProps {
model: AISDKLanguageModelV2;
modelName: string;
modelProviderOptions?: Record<string, Record<string, JSONValue>>;
domain: string;
orgId: number;
prisma: PrismaClient;
}
const createMessageStreamResponse = async ({
@ -143,8 +132,8 @@ const createMessageStreamResponse = async ({
model,
modelName,
modelProviderOptions,
domain,
orgId,
prisma,
}: CreateMessageStreamResponseProps) => {
const latestMessage = messages[messages.length - 1];
const sources = latestMessage.parts
@ -254,7 +243,7 @@ const createMessageStreamResponse = async ({
await updateChatMessages({
chatId: id,
messages
}, domain);
});
},
});

View file

@ -1,46 +1,25 @@
'use server';
import { NextRequest } from "next/server";
import { fetchAuditRecords } from "@/ee/features/audit/actions";
import { isServiceError } from "@/lib/utils";
import { serviceErrorResponse } from "@/lib/serviceError";
import { StatusCodes } from "http-status-codes";
import { ErrorCode } from "@/lib/errorCodes";
import { env } from "@sourcebot/shared";
import { serviceErrorResponse } from "@/lib/serviceError";
import { isServiceError } from "@/lib/utils";
import { getEntitlements } from "@sourcebot/shared";
import { StatusCodes } from "http-status-codes";
export const GET = async (request: NextRequest) => {
const domain = request.headers.get("X-Org-Domain");
const apiKey = request.headers.get("X-Sourcebot-Api-Key") ?? undefined;
export const GET = async () => {
const entitlements = getEntitlements();
if (!entitlements.includes('audit')) {
return serviceErrorResponse({
statusCode: StatusCodes.FORBIDDEN,
errorCode: ErrorCode.NOT_FOUND,
message: "Audit logging is not enabled for your license",
});
}
if (!domain) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.MISSING_ORG_DOMAIN_HEADER,
message: "Missing X-Org-Domain header",
});
}
if (env.SOURCEBOT_EE_AUDIT_LOGGING_ENABLED === 'false') {
return serviceErrorResponse({
statusCode: StatusCodes.NOT_FOUND,
errorCode: ErrorCode.NOT_FOUND,
message: "Audit logging is not enabled",
});
}
const entitlements = getEntitlements();
if (!entitlements.includes('audit')) {
return serviceErrorResponse({
statusCode: StatusCodes.FORBIDDEN,
errorCode: ErrorCode.NOT_FOUND,
message: "Audit logging is not enabled for your license",
});
}
const result = await fetchAuditRecords(domain, apiKey);
if (isServiceError(result)) {
return serviceErrorResponse(result);
}
return Response.json(result);
const result = await fetchAuditRecords();
if (isServiceError(result)) {
return serviceErrorResponse(result);
}
return Response.json(result);
};

View file

@ -1,59 +1,62 @@
"use server";
import { prisma } from "@/prisma";
import { ErrorCode } from "@/lib/errorCodes";
import { StatusCodes } from "http-status-codes";
import { sew, withAuth, withOrgMembership } from "@/actions";
import { OrgRole } from "@sourcebot/db";
import { createLogger } from "@sourcebot/shared";
import { ServiceError } from "@/lib/serviceError";
import { sew } from "@/actions";
import { getAuditService } from "@/ee/features/audit/factory";
import { ErrorCode } from "@/lib/errorCodes";
import { ServiceError } from "@/lib/serviceError";
import { prisma } from "@/prisma";
import { withAuthV2 } from "@/withAuthV2";
import { createLogger } from "@sourcebot/shared";
import { StatusCodes } from "http-status-codes";
import { AuditEvent } from "./types";
const auditService = getAuditService();
const logger = createLogger('audit-utils');
export const createAuditAction = async (event: Omit<AuditEvent, 'sourcebotVersion' | 'orgId' | 'actor' | 'target'>, domain: string) => sew(async () =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
await auditService.createAudit({ ...event, orgId: org.id, actor: { id: userId, type: "user" }, target: { id: org.id.toString(), type: "org" } })
}, /* minRequiredRole = */ OrgRole.MEMBER), /* allowAnonymousAccess = */ true)
);
export const fetchAuditRecords = async (domain: string, apiKey: string | undefined = undefined) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
try {
const auditRecords = await prisma.audit.findMany({
where: {
orgId: org.id,
},
orderBy: {
timestamp: 'desc'
}
});
export const createAuditAction = async (event: Omit<AuditEvent, 'sourcebotVersion' | 'orgId' | 'actor' | 'target'>) => sew(async () =>
withAuthV2(async ({ user, org }) => {
await auditService.createAudit({
action: "audit.fetch",
actor: {
id: userId,
type: "user"
},
target: {
id: org.id.toString(),
type: "org"
},
orgId: org.id
...event,
orgId: org.id,
actor: { id: user.id, type: "user" },
target: { id: org.id.toString(), type: "org" },
})
return auditRecords;
} catch (error) {
logger.error('Error fetching audit logs', { error });
return {
statusCode: StatusCodes.INTERNAL_SERVER_ERROR,
errorCode: ErrorCode.UNEXPECTED_ERROR,
message: "Failed to fetch audit logs",
} satisfies ServiceError;
}
}, /* minRequiredRole = */ OrgRole.OWNER), /* allowAnonymousAccess = */ true, apiKey ? { apiKey, domain } : undefined)
})
);
export const fetchAuditRecords = async () => sew(() =>
withAuthV2(async ({ user, org }) => {
try {
const auditRecords = await prisma.audit.findMany({
where: {
orgId: org.id,
},
orderBy: {
timestamp: 'desc'
}
});
await auditService.createAudit({
action: "audit.fetch",
actor: {
id: user.id,
type: "user"
},
target: {
id: org.id.toString(),
type: "org"
},
orgId: org.id
})
return auditRecords;
} catch (error) {
logger.error('Error fetching audit logs', { error });
return {
statusCode: StatusCodes.INTERNAL_SERVER_ERROR,
errorCode: ErrorCode.UNEXPECTED_ERROR,
message: "Failed to fetch audit logs",
} satisfies ServiceError;
}
})
);

View file

@ -1,10 +1,9 @@
'use server';
import { sew, withAuth, withOrgMembership } from "@/actions";
import { sew } from "@/actions";
import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants";
import { ErrorCode } from "@/lib/errorCodes";
import { chatIsReadonly, notFound, ServiceError, serviceErrorResponse } from "@/lib/serviceError";
import { prisma } from "@/prisma";
import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock';
import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic';
import { createAzure } from '@ai-sdk/azure';
@ -20,7 +19,7 @@ import { createXai } from '@ai-sdk/xai';
import { fromNodeProviderChain } from '@aws-sdk/credential-providers';
import { createOpenRouter } from '@openrouter/ai-sdk-provider';
import { getTokenFromConfig, createLogger, env } from "@sourcebot/shared";
import { ChatVisibility, OrgRole, Prisma } from "@sourcebot/db";
import { ChatVisibility, Prisma } from "@sourcebot/db";
import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type";
import { Token } from "@sourcebot/schemas/v3/shared.type";
import { generateText, JSONValue, extractReasoningMiddleware, wrapLanguageModel } from "ai";
@ -29,168 +28,161 @@ import fs from 'fs';
import { StatusCodes } from "http-status-codes";
import path from 'path';
import { LanguageModelInfo, SBChatMessage } from "./types";
import { withAuthV2, withOptionalAuthV2 } from "@/withAuthV2";
const logger = createLogger('chat-actions');
export const createChat = async (domain: string) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
export const createChat = async () => sew(() =>
withOptionalAuthV2(async ({ org, user, prisma }) => {
const isGuestUser = user?.id === SOURCEBOT_GUEST_USER_ID;
const isGuestUser = userId === SOURCEBOT_GUEST_USER_ID;
const chat = await prisma.chat.create({
data: {
orgId: org.id,
messages: [] as unknown as Prisma.InputJsonValue,
createdById: user?.id ?? SOURCEBOT_GUEST_USER_ID,
visibility: isGuestUser ? ChatVisibility.PUBLIC : ChatVisibility.PRIVATE,
},
});
const chat = await prisma.chat.create({
data: {
orgId: org.id,
messages: [] as unknown as Prisma.InputJsonValue,
createdById: userId,
visibility: isGuestUser ? ChatVisibility.PUBLIC : ChatVisibility.PRIVATE,
},
});
return {
id: chat.id,
}
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
return {
id: chat.id,
}
})
);
export const getChatInfo = async ({ chatId }: { chatId: string }, domain: string) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
const chat = await prisma.chat.findUnique({
where: {
id: chatId,
orgId: org.id,
},
});
export const getChatInfo = async ({ chatId }: { chatId: string }) => sew(() =>
withOptionalAuthV2(async ({ org, user, prisma }) => {
const chat = await prisma.chat.findUnique({
where: {
id: chatId,
orgId: org.id,
},
});
if (!chat) {
return notFound();
}
if (!chat) {
return notFound();
}
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) {
return notFound();
}
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) {
return notFound();
}
return {
messages: chat.messages as unknown as SBChatMessage[],
visibility: chat.visibility,
name: chat.name,
isReadonly: chat.isReadonly,
};
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
return {
messages: chat.messages as unknown as SBChatMessage[],
visibility: chat.visibility,
name: chat.name,
isReadonly: chat.isReadonly,
};
})
);
export const updateChatMessages = async ({ chatId, messages }: { chatId: string, messages: SBChatMessage[] }, domain: string) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
const chat = await prisma.chat.findUnique({
where: {
id: chatId,
orgId: org.id,
},
});
export const updateChatMessages = async ({ chatId, messages }: { chatId: string, messages: SBChatMessage[] }) => sew(() =>
withOptionalAuthV2(async ({ org, user, prisma }) => {
const chat = await prisma.chat.findUnique({
where: {
id: chatId,
orgId: org.id,
},
});
if (!chat) {
return notFound();
if (!chat) {
return notFound();
}
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) {
return notFound();
}
if (chat.isReadonly) {
return chatIsReadonly();
}
await prisma.chat.update({
where: {
id: chatId,
},
data: {
messages: messages as unknown as Prisma.InputJsonValue,
},
});
if (env.DEBUG_WRITE_CHAT_MESSAGES_TO_FILE) {
const chatDir = path.join(env.DATA_CACHE_DIR, 'chats');
if (!fs.existsSync(chatDir)) {
fs.mkdirSync(chatDir, { recursive: true });
}
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) {
return notFound();
}
const chatFile = path.join(chatDir, `${chatId}.json`);
fs.writeFileSync(chatFile, JSON.stringify(messages, null, 2));
}
if (chat.isReadonly) {
return chatIsReadonly();
}
await prisma.chat.update({
where: {
id: chatId,
},
data: {
messages: messages as unknown as Prisma.InputJsonValue,
},
});
if (env.DEBUG_WRITE_CHAT_MESSAGES_TO_FILE) {
const chatDir = path.join(env.DATA_CACHE_DIR, 'chats');
if (!fs.existsSync(chatDir)) {
fs.mkdirSync(chatDir, { recursive: true });
}
const chatFile = path.join(chatDir, `${chatId}.json`);
fs.writeFileSync(chatFile, JSON.stringify(messages, null, 2));
}
return {
success: true,
}
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
return {
success: true,
}
})
);
export const getUserChatHistory = async (domain: string) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
const chats = await prisma.chat.findMany({
where: {
orgId: org.id,
createdById: userId,
},
orderBy: {
updatedAt: 'desc',
},
});
export const getUserChatHistory = async () => sew(() =>
withAuthV2(async ({ org, user, prisma }) => {
const chats = await prisma.chat.findMany({
where: {
orgId: org.id,
createdById: user.id,
},
orderBy: {
updatedAt: 'desc',
},
});
return chats.map((chat) => ({
id: chat.id,
createdAt: chat.createdAt,
name: chat.name,
visibility: chat.visibility,
}))
})
)
return chats.map((chat) => ({
id: chat.id,
createdAt: chat.createdAt,
name: chat.name,
visibility: chat.visibility,
}))
})
);
export const updateChatName = async ({ chatId, name }: { chatId: string, name: string }, domain: string) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
const chat = await prisma.chat.findUnique({
where: {
id: chatId,
orgId: org.id,
},
});
export const updateChatName = async ({ chatId, name }: { chatId: string, name: string }) => sew(() =>
withOptionalAuthV2(async ({ org, user, prisma }) => {
const chat = await prisma.chat.findUnique({
where: {
id: chatId,
orgId: org.id,
},
});
if (!chat) {
return notFound();
}
if (!chat) {
return notFound();
}
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) {
return notFound();
}
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) {
return notFound();
}
if (chat.isReadonly) {
return chatIsReadonly();
}
if (chat.isReadonly) {
return chatIsReadonly();
}
await prisma.chat.update({
where: {
id: chatId,
orgId: org.id,
},
data: {
name,
},
});
await prisma.chat.update({
where: {
id: chatId,
orgId: org.id,
},
data: {
name,
},
});
return {
success: true,
}
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
return {
success: true,
}
})
);
export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }, domain: string) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async () => {
export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }) => sew(() =>
withOptionalAuthV2(async () => {
// From the language model ID, attempt to find the
// corresponding config in `config.json`.
const languageModelConfig =
@ -231,48 +223,6 @@ User question: ${message}`;
await updateChatName({
chatId,
name: result.text,
}, domain);
return {
success: true,
}
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true
)
);
export const deleteChat = async ({ chatId }: { chatId: string }, domain: string) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
const chat = await prisma.chat.findUnique({
where: {
id: chatId,
orgId: org.id,
},
});
if (!chat) {
return notFound();
}
// Public chats cannot be deleted.
if (chat.visibility === ChatVisibility.PUBLIC) {
return {
statusCode: StatusCodes.FORBIDDEN,
errorCode: ErrorCode.UNEXPECTED_ERROR,
message: 'You are not allowed to delete this chat.',
} satisfies ServiceError;
}
// Only the creator of a chat can delete it.
if (chat.createdById !== userId) {
return notFound();
}
await prisma.chat.delete({
where: {
id: chatId,
orgId: org.id,
},
});
return {
@ -280,6 +230,45 @@ export const deleteChat = async ({ chatId }: { chatId: string }, domain: string)
}
})
)
export const deleteChat = async ({ chatId }: { chatId: string }) => sew(() =>
withAuthV2(async ({ org, user, prisma }) => {
const chat = await prisma.chat.findUnique({
where: {
id: chatId,
orgId: org.id,
},
});
if (!chat) {
return notFound();
}
// Public chats cannot be deleted.
if (chat.visibility === ChatVisibility.PUBLIC) {
return {
statusCode: StatusCodes.FORBIDDEN,
errorCode: ErrorCode.UNEXPECTED_ERROR,
message: 'You are not allowed to delete this chat.',
} satisfies ServiceError;
}
// Only the creator of a chat can delete it.
if (chat.createdById !== user.id) {
return notFound();
}
await prisma.chat.delete({
where: {
id: chatId,
orgId: org.id,
},
});
return {
success: true,
}
})
);
export const submitFeedback = async ({
@ -290,56 +279,55 @@ export const submitFeedback = async ({
chatId: string,
messageId: string,
feedbackType: 'like' | 'dislike'
}, domain: string) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
const chat = await prisma.chat.findUnique({
where: {
id: chatId,
orgId: org.id,
},
});
}) => sew(() =>
withOptionalAuthV2(async ({ org, user, prisma }) => {
const chat = await prisma.chat.findUnique({
where: {
id: chatId,
orgId: org.id,
},
});
if (!chat) {
return notFound();
if (!chat) {
return notFound();
}
// When a chat is private, only the creator can submit feedback.
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) {
return notFound();
}
const messages = chat.messages as unknown as SBChatMessage[];
const updatedMessages = messages.map(message => {
if (message.id === messageId && message.role === 'assistant') {
return {
...message,
metadata: {
...message.metadata,
feedback: [
...(message.metadata?.feedback ?? []),
{
type: feedbackType,
timestamp: new Date().toISOString(),
userId: user?.id,
}
]
}
} satisfies SBChatMessage;
}
return message;
});
// When a chat is private, only the creator can submit feedback.
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) {
return notFound();
}
await prisma.chat.update({
where: { id: chatId },
data: {
messages: updatedMessages as unknown as Prisma.InputJsonValue,
},
});
const messages = chat.messages as unknown as SBChatMessage[];
const updatedMessages = messages.map(message => {
if (message.id === messageId && message.role === 'assistant') {
return {
...message,
metadata: {
...message.metadata,
feedback: [
...(message.metadata?.feedback ?? []),
{
type: feedbackType,
timestamp: new Date().toISOString(),
userId: userId,
}
]
}
} satisfies SBChatMessage;
}
return message;
});
await prisma.chat.update({
where: { id: chatId },
data: {
messages: updatedMessages as unknown as Prisma.InputJsonValue,
},
});
return { success: true };
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
);
return { success: true };
})
)
/**
* Returns the subset of information about the configured language models

View file

@ -6,7 +6,7 @@ import { Button } from "@/components/ui/button";
import { TableOfContentsIcon, ThumbsDown, ThumbsUp } from "lucide-react";
import { Separator } from "@/components/ui/separator";
import { MarkdownRenderer } from "./markdownRenderer";
import { forwardRef, useCallback, useImperativeHandle, useRef, useState } from "react";
import { forwardRef, memo, useCallback, useImperativeHandle, useRef, useState } from "react";
import { Toggle } from "@/components/ui/toggle";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { CopyIconButton } from "@/app/[domain]/components/copyIconButton";
@ -14,7 +14,6 @@ import { useToast } from "@/components/hooks/use-toast";
import { convertLLMOutputToPortableMarkdown } from "../../utils";
import { submitFeedback } from "../../actions";
import { isServiceError } from "@/lib/utils";
import { useDomain } from "@/hooks/useDomain";
import useCaptureEvent from "@/hooks/useCaptureEvent";
import { LangfuseWeb } from "langfuse";
import { env } from "@sourcebot/shared/client";
@ -31,7 +30,7 @@ const langfuseWeb = (env.NEXT_PUBLIC_SOURCEBOT_CLOUD_ENVIRONMENT !== undefined &
baseUrl: env.NEXT_PUBLIC_LANGFUSE_BASE_URL,
}) : null;
export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
const AnswerCardComponent = forwardRef<HTMLDivElement, AnswerCardProps>(({
answerText,
messageId,
chatId,
@ -41,7 +40,6 @@ export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
const { tocItems, activeId } = useExtractTOCItems({ target: markdownRendererRef.current });
const [isTOCButtonToggled, setIsTOCButtonToggled] = useState(false);
const { toast } = useToast();
const domain = useDomain();
const [isSubmittingFeedback, setIsSubmittingFeedback] = useState(false);
const [feedback, setFeedback] = useState<'like' | 'dislike' | undefined>(undefined);
const captureEvent = useCaptureEvent();
@ -67,7 +65,7 @@ export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
chatId,
messageId,
feedbackType
}, domain);
});
if (isServiceError(response)) {
toast({
@ -93,7 +91,7 @@ export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
}
setIsSubmittingFeedback(false);
}, [chatId, messageId, domain, toast, captureEvent, traceId]);
}, [chatId, messageId, toast, captureEvent, traceId]);
return (
<div className="flex flex-row w-full relative scroll-mt-16">
@ -178,4 +176,6 @@ export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
)
})
AnswerCard.displayName = 'AnswerCard';
AnswerCardComponent.displayName = 'AnswerCard';
export const AnswerCard = memo(AnswerCardComponent);

View file

@ -7,7 +7,6 @@ import { Separator } from '@/components/ui/separator';
import { CustomSlateEditor } from '@/features/chat/customSlateEditor';
import { AdditionalChatRequestParams, CustomEditor, LanguageModelInfo, SBChatMessage, SearchScope, Source } from '@/features/chat/types';
import { createUIMessage, getAllMentionElements, resetEditor, slateContentToString } from '@/features/chat/utils';
import { useDomain } from '@/hooks/useDomain';
import { useChat } from '@ai-sdk/react';
import { CreateUIMessage, DefaultChatTransport } from 'ai';
import { ArrowDownIcon } from 'lucide-react';
@ -54,7 +53,6 @@ export const ChatThread = ({
onSelectedSearchScopesChange,
isChatReadonly,
}: ChatThreadProps) => {
const domain = useDomain();
const [isErrorBannerVisible, setIsErrorBannerVisible] = useState(false);
const scrollAreaRef = useRef<HTMLDivElement>(null);
const latestMessagePairRef = useRef<HTMLDivElement>(null);
@ -89,9 +87,6 @@ export const ChatThread = ({
messages: initialMessages,
transport: new DefaultChatTransport({
api: '/api/chat',
headers: {
"X-Org-Domain": domain,
}
}),
onData: (dataPart) => {
// Keeps sources added by the assistant in sync.
@ -134,7 +129,6 @@ export const ChatThread = ({
languageModelId: selectedLanguageModel.model,
message: message.parts[0].text,
},
domain
).then((response) => {
if (isServiceError(response)) {
toast({
@ -153,7 +147,6 @@ export const ChatThread = ({
messages.length,
toast,
chatId,
domain,
router,
]);
@ -196,46 +189,47 @@ export const ChatThread = ({
hasSubmittedInputMessage.current = true;
}, [inputMessage, sendMessage]);
// @todo: this need to be optimized to avoid excessive re-renders
// Track scroll position changes.
useEffect(() => {
const scrollElement = scrollAreaRef.current?.querySelector('[data-radix-scroll-area-viewport]') as HTMLElement;
if (!scrollElement) return;
// useEffect(() => {
// const scrollElement = scrollAreaRef.current?.querySelector('[data-radix-scroll-area-viewport]') as HTMLElement;
// if (!scrollElement) return;
let timeout: NodeJS.Timeout | null = null;
// let timeout: NodeJS.Timeout | null = null;
const handleScroll = () => {
const scrollOffset = scrollElement.scrollTop;
// const handleScroll = () => {
// const scrollOffset = scrollElement.scrollTop;
const threshold = 50; // pixels from bottom to consider "at bottom"
const { scrollHeight, clientHeight } = scrollElement;
const isAtBottom = scrollHeight - scrollOffset - clientHeight <= threshold;
setIsAutoScrollEnabled(isAtBottom);
// const threshold = 50; // pixels from bottom to consider "at bottom"
// const { scrollHeight, clientHeight } = scrollElement;
// const isAtBottom = scrollHeight - scrollOffset - clientHeight <= threshold;
// setIsAutoScrollEnabled(isAtBottom);
// Debounce the history state update
if (timeout) {
clearTimeout(timeout);
}
// // Debounce the history state update
// if (timeout) {
// clearTimeout(timeout);
// }
timeout = setTimeout(() => {
history.replaceState(
{
scrollOffset,
} satisfies ChatHistoryState,
'',
window.location.href
);
}, 300);
};
// timeout = setTimeout(() => {
// history.replaceState(
// {
// scrollOffset,
// } satisfies ChatHistoryState,
// '',
// window.location.href
// );
// }, 300);
// };
scrollElement.addEventListener('scroll', handleScroll, { passive: true });
// scrollElement.addEventListener('scroll', handleScroll, { passive: true });
return () => {
scrollElement.removeEventListener('scroll', handleScroll);
if (timeout) {
clearTimeout(timeout);
}
};
}, []);
// return () => {
// scrollElement.removeEventListener('scroll', handleScroll);
// if (timeout) {
// clearTimeout(timeout);
// }
// };
// }, []);
useEffect(() => {
const scrollElement = scrollAreaRef.current?.querySelector('[data-radix-scroll-area-viewport]') as HTMLElement;
@ -313,9 +307,11 @@ export const ChatThread = ({
{messagePairs.map(([userMessage, assistantMessage], index) => {
const isLastPair = index === messagePairs.length - 1;
const isStreaming = isLastPair && (status === "streaming" || status === "submitted");
// Use a stable key based on user message ID
const key = userMessage.id;
return (
<Fragment key={index}>
<Fragment key={key}>
<ChatThreadListItem
index={index}
chatId={chatId}

View file

@ -4,7 +4,7 @@ import { AnimatedResizableHandle } from '@/components/ui/animatedResizableHandle
import { ResizablePanel, ResizablePanelGroup } from '@/components/ui/resizable';
import { Skeleton } from '@/components/ui/skeleton';
import { CheckCircle, Loader2 } from 'lucide-react';
import { CSSProperties, forwardRef, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { CSSProperties, forwardRef, memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import scrollIntoView from 'scroll-into-view-if-needed';
import { Reference, referenceSchema, SBChatMessage, Source } from "../../types";
import { useExtractReferences } from '../../useExtractReferences';
@ -24,7 +24,7 @@ interface ChatThreadListItemProps {
index: number;
}
export const ChatThreadListItem = forwardRef<HTMLDivElement, ChatThreadListItemProps>(({
export const ChatThreadListItemComponent = forwardRef<HTMLDivElement, ChatThreadListItemProps>(({
userMessage,
assistantMessage: _assistantMessage,
isStreaming,
@ -32,6 +32,7 @@ export const ChatThreadListItem = forwardRef<HTMLDivElement, ChatThreadListItemP
chatId,
index,
}, ref) => {
console.log(`re-rendering chat thread list item`, index);
const leftPanelRef = useRef<HTMLDivElement>(null);
const [leftPanelHeight, setLeftPanelHeight] = useState<number | null>(null);
const answerRef = useRef<HTMLDivElement>(null);
@ -393,7 +394,12 @@ export const ChatThreadListItem = forwardRef<HTMLDivElement, ChatThreadListItemP
)
});
ChatThreadListItem.displayName = 'ChatThreadListItem';
ChatThreadListItemComponent.displayName = 'ChatThreadListItem';
// Only allow re-rendering when the message _is_ streaming.
// This is a performance optimizations to prevent unnecessary
// re-renders for chatThreadListItems that are not streaming.
export const ChatThreadListItem = memo(ChatThreadListItemComponent, (_, nextProps) => !nextProps.isStreaming);
// Finds the nearest reference element to the viewport center.
const getNearestReferenceElement = (referenceElements: Element[]) => {

View file

@ -7,6 +7,7 @@ import { Skeleton } from '@/components/ui/skeleton';
import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip';
import { cn } from '@/lib/utils';
import { Brain, ChevronDown, ChevronRight, Clock, InfoIcon, Loader2, List, ScanSearchIcon, Zap } from 'lucide-react';
import { memo } from 'react';
import { MarkdownRenderer } from './markdownRenderer';
import { FindSymbolDefinitionsToolComponent } from './tools/findSymbolDefinitionsToolComponent';
import { FindSymbolReferencesToolComponent } from './tools/findSymbolReferencesToolComponent';
@ -27,7 +28,7 @@ interface DetailsCardProps {
metadata?: SBChatMessageMetadata;
}
export const DetailsCard = ({
const DetailsCardComponent = ({
isExpanded,
onExpandedChanged,
isThinking,
@ -210,3 +211,5 @@ export const DetailsCard = ({
</Card>
)
}
export const DetailsCard = memo(DetailsCardComponent);

View file

@ -1,7 +1,6 @@
'use client';
import { CodeSnippet } from '@/app/components/codeSnippet';
import { useDomain } from '@/hooks/useDomain';
import { SearchQueryParams } from '@/lib/types';
import { cn, createPathWithQueryParams } from '@/lib/utils';
import type { Element, Root } from "hast";
@ -10,7 +9,7 @@ import { CopyIcon, SearchIcon } from 'lucide-react';
import type { Heading, Nodes } from "mdast";
import { findAndReplace } from 'mdast-util-find-and-replace';
import { useRouter } from 'next/navigation';
import React, { useCallback, useMemo, forwardRef } from 'react';
import React, { useCallback, useMemo, forwardRef, memo } from 'react';
import Markdown from 'react-markdown';
import rehypeRaw from 'rehype-raw';
import rehypeSanitize, { defaultSchema } from 'rehype-sanitize';
@ -20,6 +19,7 @@ import { visit } from 'unist-util-visit';
import { CodeBlock } from './codeBlock';
import { FILE_REFERENCE_REGEX } from '@/features/chat/constants';
import { createFileReference } from '@/features/chat/utils';
import { SINGLE_TENANT_ORG_DOMAIN } from '@/lib/constants';
export const REFERENCE_PAYLOAD_ATTRIBUTE = 'data-reference-payload';
@ -102,8 +102,7 @@ interface MarkdownRendererProps {
className?: string;
}
export const MarkdownRenderer = forwardRef<HTMLDivElement, MarkdownRendererProps>(({ content, className }, ref) => {
const domain = useDomain();
const MarkdownRendererComponent = forwardRef<HTMLDivElement, MarkdownRendererProps>(({ content, className }, ref) => {
const router = useRouter();
const remarkPlugins = useMemo((): PluggableList => {
@ -176,7 +175,7 @@ export const MarkdownRenderer = forwardRef<HTMLDivElement, MarkdownRendererProps
onClick={(e) => {
e.preventDefault();
e.stopPropagation();
const url = createPathWithQueryParams(`/${domain}/search`, [SearchQueryParams.query, `"${text}"`])
const url = createPathWithQueryParams(`/${SINGLE_TENANT_ORG_DOMAIN}/search`, [SearchQueryParams.query, `"${text}"`])
router.push(url);
}}
title="Search for snippet"
@ -199,7 +198,7 @@ export const MarkdownRenderer = forwardRef<HTMLDivElement, MarkdownRendererProps
</span>
)
}, [domain, router]);
}, [router]);
return (
<div
@ -220,4 +219,6 @@ export const MarkdownRenderer = forwardRef<HTMLDivElement, MarkdownRendererProps
);
});
MarkdownRenderer.displayName = 'MarkdownRenderer';
MarkdownRendererComponent.displayName = 'MarkdownRenderer';
export const MarkdownRenderer = memo(MarkdownRendererComponent);

View file

@ -14,13 +14,12 @@ import { Range } from "@codemirror/state";
import { Decoration, DecorationSet, EditorView } from '@codemirror/view';
import CodeMirror, { ReactCodeMirrorRef, StateField } from '@uiw/react-codemirror';
import { ChevronDown, ChevronRight } from "lucide-react";
import { forwardRef, Ref, useCallback, useImperativeHandle, useMemo, useState } from "react";
import { forwardRef, memo, Ref, useCallback, useImperativeHandle, useMemo, useState } from "react";
import { FileReference } from "../../types";
import { createCodeFoldingExtension } from "./codeFoldingExtension";
import useCaptureEvent from "@/hooks/useCaptureEvent";
import { createAuditAction } from "@/ee/features/audit/actions";
import { useDomain } from "@/hooks/useDomain";
import { CodeHostType } from "@sourcebot/db";
import { createAuditAction } from "@/ee/features/audit/actions";
const lineDecoration = Decoration.line({
attributes: { class: "cm-range-border-radius chat-lineHighlight" },
@ -75,7 +74,6 @@ const ReferencedFileSourceListItem = ({
const theme = useCodeMirrorTheme();
const [editorRef, setEditorRef] = useState<ReactCodeMirrorRef | null>(null);
const captureEvent = useCaptureEvent();
const domain = useDomain();
useImperativeHandle(
forwardedRef,
@ -124,6 +122,8 @@ const ReferencedFileSourceListItem = ({
return createCodeFoldingExtension(references, 3);
}, [references]);
// console.log(`re-renderign for file ${fileName}`);
const extensions = useMemo(() => {
return [
languageExtension,
@ -231,7 +231,7 @@ const ReferencedFileSourceListItem = ({
metadata: {
message: symbolName,
},
}, domain);
});
if (symbolDefinitions.length === 1) {
const symbolDefinition = symbolDefinitions[0];
@ -263,7 +263,7 @@ const ReferencedFileSourceListItem = ({
});
}
}, [captureEvent, domain, navigateToPath, revision, repoName, fileName, language]);
}, [captureEvent, navigateToPath, revision, repoName, fileName, language]);
const onFindReferences = useCallback((symbolName: string) => {
captureEvent('wa_find_references_pressed', {
@ -274,7 +274,7 @@ const ReferencedFileSourceListItem = ({
metadata: {
message: symbolName,
},
}, domain);
});
navigateToPath({
repoName,
@ -293,7 +293,7 @@ const ReferencedFileSourceListItem = ({
}
})
}, [captureEvent, domain, fileName, language, navigateToPath, repoName, revision]);
}, [captureEvent, fileName, language, navigateToPath, repoName, revision]);
const ExpandCollapseIcon = useMemo(() => {
return isExpanded ? ChevronDown : ChevronRight;
@ -355,6 +355,6 @@ const ReferencedFileSourceListItem = ({
)
}
export default forwardRef(ReferencedFileSourceListItem) as (
export default memo(forwardRef(ReferencedFileSourceListItem)) as (
props: ReferencedFileSourceListItemProps & { ref?: Ref<ReactCodeMirrorRef> },
) => ReturnType<typeof ReferencedFileSourceListItem>;

View file

@ -4,11 +4,10 @@ import { getFileSource } from "@/app/api/(client)/client";
import { VscodeFileIcon } from "@/app/components/vscodeFileIcon";
import { ScrollArea } from "@/components/ui/scroll-area";
import { Skeleton } from "@/components/ui/skeleton";
import { useDomain } from "@/hooks/useDomain";
import { isServiceError, unwrapServiceError } from "@/lib/utils";
import { useQueries } from "@tanstack/react-query";
import { ReactCodeMirrorRef } from '@uiw/react-codemirror';
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { memo, useCallback, useEffect, useMemo, useRef, useState } from "react";
import scrollIntoView from 'scroll-into-view-if-needed';
import { FileReference, FileSource, Reference, Source } from "../../types";
import ReferencedFileSourceListItem from "./referencedFileSourceListItem";
@ -31,7 +30,7 @@ const resolveFileReference = (reference: FileReference, sources: FileSource[]):
);
}
export const ReferencedSourcesListView = ({
const ReferencedSourcesListViewComponent = ({
references,
sources,
index,
@ -43,7 +42,6 @@ export const ReferencedSourcesListView = ({
}: ReferencedSourcesListViewProps) => {
const scrollAreaRef = useRef<HTMLDivElement>(null);
const editorRefsMap = useRef<Map<string, ReactCodeMirrorRef>>(new Map());
const domain = useDomain();
const [collapsedFileIds, setCollapsedFileIds] = useState<string[]>([]);
const getFileId = useCallback((fileSource: FileSource) => {
@ -98,7 +96,7 @@ export const ReferencedSourcesListView = ({
const fileSourceQueries = useQueries({
queries: referencedFileSources.map((file) => ({
queryKey: ['fileSource', file.path, file.repo, file.revision, domain],
queryKey: ['fileSource', file.path, file.repo, file.revision],
queryFn: () => unwrapServiceError(getFileSource({
fileName: file.path,
repository: file.repo,
@ -183,6 +181,25 @@ export const ReferencedSourcesListView = ({
}
}, [getFileId, referencedFileSources, selectedReference]);
const onExpandedChanged = useCallback((fileId: string, isExpanded: boolean) => {
if (isExpanded) {
setCollapsedFileIds(collapsedFileIds => collapsedFileIds.filter((id) => id !== fileId));
} else {
setCollapsedFileIds(collapsedFileIds => [...collapsedFileIds, fileId]);
}
if (!isExpanded) {
const fileSourceStart = document.getElementById(`${fileId}-start`);
if (fileSourceStart) {
scrollIntoView(fileSourceStart, {
scrollMode: 'if-needed',
block: 'start',
behavior: 'instant',
});
}
}
}, []);
if (referencedFileSources.length === 0) {
return (
<div className="p-4 text-center text-muted-foreground text-sm">
@ -253,30 +270,7 @@ export const ReferencedSourcesListView = ({
selectedReference={selectedReference}
hoveredReference={hoveredReference}
isExpanded={!collapsedFileIds.includes(fileId)}
onExpandedChanged={(isExpanded) => {
if (isExpanded) {
setCollapsedFileIds(collapsedFileIds.filter((id) => id !== fileId));
} else {
setCollapsedFileIds([...collapsedFileIds, fileId]);
}
// When collapsing a file when you are deep in a scroll, it's a better
// experience to have the scroll automatically restored to the top of the file
// s.t., header is still sticky to the top of the scroll area.
if (!isExpanded) {
const fileSourceStart = document.getElementById(`${fileId}-start`);
if (!fileSourceStart) {
return;
}
scrollIntoView(fileSourceStart, {
scrollMode: 'if-needed',
block: 'start',
behavior: 'instant',
});
}
}
}
onExpandedChanged={(isExpanded) => onExpandedChanged(fileId, isExpanded)}
/>
);
})}
@ -284,3 +278,6 @@ export const ReferencedSourcesListView = ({
</ScrollArea>
);
}
// Memoize to prevent unnecessary re-renders
export const ReferencedSourcesListView = memo(ReferencedSourcesListViewComponent);

View file

@ -1,7 +1,6 @@
'use client';
import { SearchCodeToolUIPart } from "@/features/chat/tools";
import { useDomain } from "@/hooks/useDomain";
import { createPathWithQueryParams, isServiceError } from "@/lib/utils";
import { useMemo, useState } from "react";
import { FileListItem, ToolHeader, TreeList } from "./shared";
@ -12,10 +11,10 @@ import Link from "next/link";
import { SearchQueryParams } from "@/lib/types";
import { PlayIcon } from "@radix-ui/react-icons";
import { buildSearchQuery } from "@/features/chat/utils";
import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants";
export const SearchCodeToolComponent = ({ part }: { part: SearchCodeToolUIPart }) => {
const [isExpanded, setIsExpanded] = useState(false);
const domain = useDomain();
const displayQuery = useMemo(() => {
if (part.state !== 'input-available' && part.state !== 'output-available') {
@ -78,7 +77,7 @@ export const SearchCodeToolComponent = ({ part }: { part: SearchCodeToolUIPart }
</TreeList>
)}
<Link
href={createPathWithQueryParams(`/${domain}/search`,
href={createPathWithQueryParams(`/${SINGLE_TENANT_ORG_DOMAIN}/search`,
[SearchQueryParams.query, part.output.query],
)}
className='flex flex-row items-center gap-2 text-sm text-muted-foreground mt-2 ml-auto w-fit hover:text-foreground'

View file

@ -2,12 +2,12 @@
import { VscodeFileIcon } from '@/app/components/vscodeFileIcon';
import { ScrollArea } from '@/components/ui/scroll-area';
import { useDomain } from '@/hooks/useDomain';
import { cn } from '@/lib/utils';
import { ChevronDown, ChevronRight, Loader2 } from 'lucide-react';
import Link from 'next/link';
import React from 'react';
import { getBrowsePath } from "@/app/[domain]/browse/hooks/utils";
import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants";
export const FileListItem = ({
@ -17,8 +17,6 @@ export const FileListItem = ({
path: string,
repoName: string,
}) => {
const domain = useDomain();
return (
<div key={path} className="flex flex-row items-center overflow-hidden hover:bg-accent hover:text-accent-foreground rounded-sm cursor-pointer p-0.5">
<VscodeFileIcon fileName={path} className="mr-1 flex-shrink-0" />
@ -28,7 +26,7 @@ export const FileListItem = ({
repoName,
revisionName: 'HEAD',
path,
domain,
domain: SINGLE_TENANT_ORG_DOMAIN,
pathType: 'blob',
})}
>

View file

@ -70,7 +70,7 @@ export const sbChatMessageMetadataSchema = z.object({
feedback: z.array(z.object({
type: z.enum(['like', 'dislike']),
timestamp: z.string(), // ISO date string
userId: z.string(),
userId: z.string().optional(),
})).optional(),
selectedSearchScopes: z.array(searchScopeSchema).optional(),
traceId: z.string().optional(),

View file

@ -4,7 +4,6 @@ import { useCallback, useState } from "react";
import { Descendant } from "slate";
import { createUIMessage, getAllMentionElements } from "./utils";
import { slateContentToString } from "./utils";
import { useDomain } from "@/hooks/useDomain";
import { useToast } from "@/components/hooks/use-toast";
import { useRouter } from "next/navigation";
import { createChat } from "./actions";
@ -13,9 +12,9 @@ import { createPathWithQueryParams } from "@/lib/utils";
import { SearchScope, SET_CHAT_STATE_SESSION_STORAGE_KEY, SetChatStatePayload } from "./types";
import { useSessionStorage } from "usehooks-ts";
import useCaptureEvent from "@/hooks/useCaptureEvent";
import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants";
export const useCreateNewChatThread = () => {
const domain = useDomain();
const [isLoading, setIsLoading] = useState(false);
const { toast } = useToast();
const router = useRouter();
@ -31,7 +30,7 @@ export const useCreateNewChatThread = () => {
const inputMessage = createUIMessage(text, mentions.map((mention) => mention.data), selectedSearchScopes);
setIsLoading(true);
const response = await createChat(domain);
const response = await createChat();
if (isServiceError(response)) {
toast({
description: `❌ Failed to create chat. Reason: ${response.message}`
@ -47,11 +46,11 @@ export const useCreateNewChatThread = () => {
selectedSearchScopes,
});
const url = createPathWithQueryParams(`/${domain}/chat/${response.id}`);
const url = createPathWithQueryParams(`/${SINGLE_TENANT_ORG_DOMAIN}/chat/${response.id}`);
router.push(url);
router.refresh();
}, [domain, router, toast, setChatState, captureEvent]);
}, [router, toast, setChatState, captureEvent]);
return {
createNewChatThread,