Merge branch 'main' into fix-ai-no-response-error

This commit is contained in:
Aditya Raj Prasad 2025-11-28 15:26:58 +05:30 committed by GitHub
commit fa833b3574
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 537 additions and 536 deletions

View file

@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed ### Fixed
- Fixed issue where single quotes could not be used in search queries. [#629](https://github.com/sourcebot-dev/sourcebot/pull/629) - Fixed issue where single quotes could not be used in search queries. [#629](https://github.com/sourcebot-dev/sourcebot/pull/629)
- Fixed issue where files with special characters would fail to load. [#636](https://github.com/sourcebot-dev/sourcebot/issues/636)
- Fixed Ask performance issues. [#632](https://github.com/sourcebot-dev/sourcebot/pull/632)
## [4.10.0] - 2025-11-24 ## [4.10.0] - 2025-11-24

View file

@ -137,6 +137,7 @@
"embla-carousel-auto-scroll": "^8.3.0", "embla-carousel-auto-scroll": "^8.3.0",
"embla-carousel-react": "^8.3.0", "embla-carousel-react": "^8.3.0",
"escape-string-regexp": "^5.0.0", "escape-string-regexp": "^5.0.0",
"fast-deep-equal": "^3.1.3",
"fuse.js": "^7.0.0", "fuse.js": "^7.0.0",
"google-auth-library": "^10.1.0", "google-auth-library": "^10.1.0",
"graphql": "^16.9.0", "graphql": "^16.9.0",

View file

@ -19,7 +19,6 @@ import { useBrowseState } from "../../hooks/useBrowseState";
import { rangeHighlightingExtension } from "./rangeHighlightingExtension"; import { rangeHighlightingExtension } from "./rangeHighlightingExtension";
import useCaptureEvent from "@/hooks/useCaptureEvent"; import useCaptureEvent from "@/hooks/useCaptureEvent";
import { createAuditAction } from "@/ee/features/audit/actions"; import { createAuditAction } from "@/ee/features/audit/actions";
import { useDomain } from "@/hooks/useDomain";
interface PureCodePreviewPanelProps { interface PureCodePreviewPanelProps {
path: string; path: string;
@ -43,7 +42,6 @@ export const PureCodePreviewPanel = ({
const hasCodeNavEntitlement = useHasEntitlement("code-nav"); const hasCodeNavEntitlement = useHasEntitlement("code-nav");
const { updateBrowseState } = useBrowseState(); const { updateBrowseState } = useBrowseState();
const { navigateToPath } = useBrowseNavigation(); const { navigateToPath } = useBrowseNavigation();
const domain = useDomain();
const captureEvent = useCaptureEvent(); const captureEvent = useCaptureEvent();
const highlightRangeQuery = useNonEmptyQueryParam(HIGHLIGHT_RANGE_QUERY_PARAM); const highlightRangeQuery = useNonEmptyQueryParam(HIGHLIGHT_RANGE_QUERY_PARAM);
@ -145,7 +143,7 @@ export const PureCodePreviewPanel = ({
metadata: { metadata: {
message: symbolName, message: symbolName,
}, },
}, domain) })
updateBrowseState({ updateBrowseState({
selectedSymbolInfo: { selectedSymbolInfo: {
@ -157,7 +155,7 @@ export const PureCodePreviewPanel = ({
isBottomPanelCollapsed: false, isBottomPanelCollapsed: false,
activeExploreMenuTab: "references", activeExploreMenuTab: "references",
}) })
}, [captureEvent, updateBrowseState, repoName, revisionName, language, domain]); }, [captureEvent, updateBrowseState, repoName, revisionName, language]);
// If we resolve multiple matches, instead of navigating to the first match, we should // If we resolve multiple matches, instead of navigating to the first match, we should
@ -171,7 +169,7 @@ export const PureCodePreviewPanel = ({
metadata: { metadata: {
message: symbolName, message: symbolName,
}, },
}, domain) })
if (symbolDefinitions.length === 0) { if (symbolDefinitions.length === 0) {
return; return;
@ -200,7 +198,7 @@ export const PureCodePreviewPanel = ({
isBottomPanelCollapsed: false, isBottomPanelCollapsed: false,
}) })
} }
}, [captureEvent, navigateToPath, revisionName, updateBrowseState, repoName, language, domain]); }, [captureEvent, navigateToPath, revisionName, updateBrowseState, repoName, language]);
const theme = useCodeMirrorTheme(); const theme = useCodeMirrorTheme();

View file

@ -24,9 +24,9 @@ export default async function Page(props: PageProps) {
const languageModels = await getConfiguredLanguageModelsInfo(); const languageModels = await getConfiguredLanguageModelsInfo();
const repos = await getRepos(); const repos = await getRepos();
const searchContexts = await getSearchContexts(params.domain); const searchContexts = await getSearchContexts(params.domain);
const chatInfo = await getChatInfo({ chatId: params.id }, params.domain); const chatInfo = await getChatInfo({ chatId: params.id });
const session = await auth(); const session = await auth();
const chatHistory = session ? await getUserChatHistory(params.domain) : []; const chatHistory = session ? await getUserChatHistory() : [];
if (isServiceError(chatHistory)) { if (isServiceError(chatHistory)) {
throw new ServiceErrorException(chatHistory); throw new ServiceErrorException(chatHistory);

View file

@ -4,7 +4,6 @@ import { useToast } from "@/components/hooks/use-toast";
import { Badge } from "@/components/ui/badge"; import { Badge } from "@/components/ui/badge";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { updateChatName } from "@/features/chat/actions"; import { updateChatName } from "@/features/chat/actions";
import { useDomain } from "@/hooks/useDomain";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
import { GlobeIcon } from "@radix-ui/react-icons"; import { GlobeIcon } from "@radix-ui/react-icons";
import { ChatVisibility } from "@sourcebot/db"; import { ChatVisibility } from "@sourcebot/db";
@ -23,7 +22,6 @@ interface ChatNameProps {
export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) => { export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) => {
const [isRenameDialogOpen, setIsRenameDialogOpen] = useState(false); const [isRenameDialogOpen, setIsRenameDialogOpen] = useState(false);
const { toast } = useToast(); const { toast } = useToast();
const domain = useDomain();
const router = useRouter(); const router = useRouter();
const onRenameChat = useCallback(async (name: string) => { const onRenameChat = useCallback(async (name: string) => {
@ -31,7 +29,7 @@ export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) =>
const response = await updateChatName({ const response = await updateChatName({
chatId: id, chatId: id,
name: name, name: name,
}, domain); });
if (isServiceError(response)) { if (isServiceError(response)) {
toast({ toast({
@ -43,7 +41,7 @@ export const ChatName = ({ name, visibility, id, isReadonly }: ChatNameProps) =>
}); });
router.refresh(); router.refresh();
} }
}, [id, domain, toast, router]); }, [id, toast, router]);
return ( return (
<> <>

View file

@ -9,7 +9,6 @@ import { ScrollArea } from "@/components/ui/scroll-area";
import { Separator } from "@/components/ui/separator"; import { Separator } from "@/components/ui/separator";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { deleteChat, updateChatName } from "@/features/chat/actions"; import { deleteChat, updateChatName } from "@/features/chat/actions";
import { useDomain } from "@/hooks/useDomain";
import { cn, isServiceError } from "@/lib/utils"; import { cn, isServiceError } from "@/lib/utils";
import { CirclePlusIcon, EllipsisIcon, PencilIcon, TrashIcon } from "lucide-react"; import { CirclePlusIcon, EllipsisIcon, PencilIcon, TrashIcon } from "lucide-react";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
@ -23,6 +22,7 @@ import { useChatId } from "../useChatId";
import { RenameChatDialog } from "./renameChatDialog"; import { RenameChatDialog } from "./renameChatDialog";
import { DeleteChatDialog } from "./deleteChatDialog"; import { DeleteChatDialog } from "./deleteChatDialog";
import Link from "next/link"; import Link from "next/link";
import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants";
interface ChatSidePanelProps { interface ChatSidePanelProps {
order: number; order: number;
@ -41,7 +41,6 @@ export const ChatSidePanel = ({
isAuthenticated, isAuthenticated,
isCollapsedInitially, isCollapsedInitially,
}: ChatSidePanelProps) => { }: ChatSidePanelProps) => {
const domain = useDomain();
const [isCollapsed, setIsCollapsed] = useState(isCollapsedInitially); const [isCollapsed, setIsCollapsed] = useState(isCollapsedInitially);
const sidePanelRef = useRef<ImperativePanelHandle>(null); const sidePanelRef = useRef<ImperativePanelHandle>(null);
const router = useRouter(); const router = useRouter();
@ -72,7 +71,7 @@ export const ChatSidePanel = ({
const response = await updateChatName({ const response = await updateChatName({
chatId, chatId,
name: name, name: name,
}, domain); });
if (isServiceError(response)) { if (isServiceError(response)) {
toast({ toast({
@ -84,14 +83,14 @@ export const ChatSidePanel = ({
}); });
router.refresh(); router.refresh();
} }
}, [router, toast, domain]); }, [router, toast]);
const onDeleteChat = useCallback(async (chatIdToDelete: string) => { const onDeleteChat = useCallback(async (chatIdToDelete: string) => {
if (!chatIdToDelete) { if (!chatIdToDelete) {
return; return;
} }
const response = await deleteChat({ chatId: chatIdToDelete }, domain); const response = await deleteChat({ chatId: chatIdToDelete });
if (isServiceError(response)) { if (isServiceError(response)) {
toast({ toast({
@ -104,12 +103,12 @@ export const ChatSidePanel = ({
// If we just deleted the current chat, navigate to new chat // If we just deleted the current chat, navigate to new chat
if (chatIdToDelete === chatId) { if (chatIdToDelete === chatId) {
router.push(`/${domain}/chat`); router.push(`/${SINGLE_TENANT_ORG_DOMAIN}/chat`);
} }
router.refresh(); router.refresh();
} }
}, [chatId, router, toast, domain]); }, [chatId, router, toast]);
return ( return (
<> <>
@ -131,7 +130,7 @@ export const ChatSidePanel = ({
size="sm" size="sm"
className="w-full" className="w-full"
onClick={() => { onClick={() => {
router.push(`/${domain}/chat`); router.push(`/${SINGLE_TENANT_ORG_DOMAIN}/chat`);
}} }}
> >
<CirclePlusIcon className="w-4 h-4 mr-1" /> <CirclePlusIcon className="w-4 h-4 mr-1" />
@ -145,7 +144,7 @@ export const ChatSidePanel = ({
<div className="flex flex-col"> <div className="flex flex-col">
<p className="text-sm text-muted-foreground mb-4"> <p className="text-sm text-muted-foreground mb-4">
<Link <Link
href={`/login?callbackUrl=${encodeURIComponent(`/${domain}/chat`)}`} href={`/login?callbackUrl=${encodeURIComponent(`/${SINGLE_TENANT_ORG_DOMAIN}/chat`)}`}
className="text-sm text-link hover:underline cursor-pointer" className="text-sm text-link hover:underline cursor-pointer"
> >
Sign in Sign in
@ -163,7 +162,7 @@ export const ChatSidePanel = ({
chat.id === chatId && "bg-muted" chat.id === chatId && "bg-muted"
)} )}
onClick={() => { onClick={() => {
router.push(`/${domain}/chat/${chat.id}`); router.push(`/${SINGLE_TENANT_ORG_DOMAIN}/chat/${chat.id}`);
}} }}
> >
<span className="text-sm truncate">{chat.name ?? 'Untitled chat'}</span> <span className="text-sm truncate">{chat.name ?? 'Untitled chat'}</span>

View file

@ -221,7 +221,7 @@ export const SearchBar = ({
metadata: { metadata: {
message: query, message: query,
}, },
}, domain) })
const url = createPathWithQueryParams(`/${domain}/search`, const url = createPathWithQueryParams(`/${domain}/search`,
[SearchQueryParams.query, query], [SearchQueryParams.query, query],

View file

@ -22,8 +22,6 @@ import { symbolHoverTargetsExtension } from "@/ee/features/codeNav/components/sy
import { useHasEntitlement } from "@/features/entitlements/useHasEntitlement"; import { useHasEntitlement } from "@/features/entitlements/useHasEntitlement";
import { SymbolDefinition } from "@/ee/features/codeNav/components/symbolHoverPopup/useHoveredOverSymbolInfo"; import { SymbolDefinition } from "@/ee/features/codeNav/components/symbolHoverPopup/useHoveredOverSymbolInfo";
import { createAuditAction } from "@/ee/features/audit/actions"; import { createAuditAction } from "@/ee/features/audit/actions";
import { useDomain } from "@/hooks/useDomain";
import useCaptureEvent from "@/hooks/useCaptureEvent"; import useCaptureEvent from "@/hooks/useCaptureEvent";
export interface CodePreviewFile { export interface CodePreviewFile {
@ -53,7 +51,6 @@ export const CodePreview = ({
const [editorRef, setEditorRef] = useState<ReactCodeMirrorRef | null>(null); const [editorRef, setEditorRef] = useState<ReactCodeMirrorRef | null>(null);
const { navigateToPath } = useBrowseNavigation(); const { navigateToPath } = useBrowseNavigation();
const hasCodeNavEntitlement = useHasEntitlement("code-nav"); const hasCodeNavEntitlement = useHasEntitlement("code-nav");
const domain = useDomain();
const [gutterWidth, setGutterWidth] = useState(0); const [gutterWidth, setGutterWidth] = useState(0);
const theme = useCodeMirrorTheme(); const theme = useCodeMirrorTheme();
@ -127,7 +124,7 @@ export const CodePreview = ({
metadata: { metadata: {
message: symbolName, message: symbolName,
}, },
}, domain) })
if (symbolDefinitions.length === 0) { if (symbolDefinitions.length === 0) {
return; return;
@ -162,7 +159,7 @@ export const CodePreview = ({
} }
}); });
} }
}, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName, domain]); }, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName]);
const onFindReferences = useCallback((symbolName: string) => { const onFindReferences = useCallback((symbolName: string) => {
captureEvent('wa_find_references_pressed', { captureEvent('wa_find_references_pressed', {
@ -173,7 +170,7 @@ export const CodePreview = ({
metadata: { metadata: {
message: symbolName, message: symbolName,
}, },
}, domain) })
navigateToPath({ navigateToPath({
repoName, repoName,
@ -191,7 +188,7 @@ export const CodePreview = ({
isBottomPanelCollapsed: false, isBottomPanelCollapsed: false,
} }
}) })
}, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName, domain]); }, [captureEvent, file.filepath, file.language, file.revision, navigateToPath, repoName]);
return ( return (
<div className="flex flex-col h-full"> <div className="flex flex-col h-full">

View file

@ -1,4 +1,4 @@
import { sew, withAuth, withOrgMembership } from "@/actions"; import { sew } from "@/actions";
import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, updateChatMessages } from "@/features/chat/actions"; import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, updateChatMessages } from "@/features/chat/actions";
import { createAgentStream } from "@/features/chat/agent"; import { createAgentStream } from "@/features/chat/agent";
import { additionalChatRequestParamsSchema, LanguageModelInfo, SBChatMessage, SearchScope } from "@/features/chat/types"; import { additionalChatRequestParamsSchema, LanguageModelInfo, SBChatMessage, SearchScope } from "@/features/chat/types";
@ -6,10 +6,10 @@ import { getAnswerPartFromAssistantMessage, getLanguageModelKey } from "@/featur
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { notFound, schemaValidationError, serviceErrorResponse } from "@/lib/serviceError"; import { notFound, schemaValidationError, serviceErrorResponse } from "@/lib/serviceError";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
import { prisma } from "@/prisma"; import { withOptionalAuthV2 } from "@/withAuthV2";
import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider"; import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider";
import * as Sentry from "@sentry/nextjs"; import * as Sentry from "@sentry/nextjs";
import { OrgRole } from "@sourcebot/db"; import { PrismaClient } from "@sourcebot/db";
import { createLogger } from "@sourcebot/shared"; import { createLogger } from "@sourcebot/shared";
import { import {
createUIMessageStream, createUIMessageStream,
@ -34,15 +34,6 @@ const chatRequestSchema = z.object({
}) })
export async function POST(req: Request) { export async function POST(req: Request) {
const domain = req.headers.get("X-Org-Domain");
if (!domain) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.MISSING_ORG_DOMAIN_HEADER,
message: "Missing X-Org-Domain header",
});
}
const requestBody = await req.json(); const requestBody = await req.json();
const parsed = await chatRequestSchema.safeParseAsync(requestBody); const parsed = await chatRequestSchema.safeParseAsync(requestBody);
if (!parsed.success) { if (!parsed.success) {
@ -56,56 +47,54 @@ export async function POST(req: Request) {
const languageModel = _languageModel as LanguageModelInfo; const languageModel = _languageModel as LanguageModelInfo;
const response = await sew(() => const response = await sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { // Validate that the chat exists and is not readonly.
// Validate that the chat exists and is not readonly. const chat = await prisma.chat.findUnique({
const chat = await prisma.chat.findUnique({ where: {
where: {
orgId: org.id,
id,
},
});
if (!chat) {
return notFound();
}
if (chat.isReadonly) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: "Chat is readonly and cannot be edited.",
});
}
// From the language model ID, attempt to find the
// corresponding config in `config.json`.
const languageModelConfig =
(await _getConfiguredLanguageModelsFull())
.find((model) => getLanguageModelKey(model) === getLanguageModelKey(languageModel));
if (!languageModelConfig) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: `Language model ${languageModel.model} is not configured.`,
});
}
const { model, providerOptions } = await _getAISDKLanguageModelAndOptions(languageModelConfig);
return createMessageStreamResponse({
messages,
id,
selectedSearchScopes,
model,
modelName: languageModelConfig.displayName ?? languageModelConfig.model,
modelProviderOptions: providerOptions,
domain,
orgId: org.id, orgId: org.id,
id,
},
});
if (!chat) {
return notFound();
}
if (chat.isReadonly) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: "Chat is readonly and cannot be edited.",
}); });
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true }
)
// From the language model ID, attempt to find the
// corresponding config in `config.json`.
const languageModelConfig =
(await _getConfiguredLanguageModelsFull())
.find((model) => getLanguageModelKey(model) === getLanguageModelKey(languageModel));
if (!languageModelConfig) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: `Language model ${languageModel.model} is not configured.`,
});
}
const { model, providerOptions } = await _getAISDKLanguageModelAndOptions(languageModelConfig);
return createMessageStreamResponse({
messages,
id,
selectedSearchScopes,
model,
modelName: languageModelConfig.displayName ?? languageModelConfig.model,
modelProviderOptions: providerOptions,
orgId: org.id,
prisma,
});
})
) )
if (isServiceError(response)) { if (isServiceError(response)) {
@ -132,8 +121,8 @@ interface CreateMessageStreamResponseProps {
model: AISDKLanguageModelV2; model: AISDKLanguageModelV2;
modelName: string; modelName: string;
modelProviderOptions?: Record<string, Record<string, JSONValue>>; modelProviderOptions?: Record<string, Record<string, JSONValue>>;
domain: string;
orgId: number; orgId: number;
prisma: PrismaClient;
} }
const createMessageStreamResponse = async ({ const createMessageStreamResponse = async ({
@ -143,8 +132,8 @@ const createMessageStreamResponse = async ({
model, model,
modelName, modelName,
modelProviderOptions, modelProviderOptions,
domain,
orgId, orgId,
prisma,
}: CreateMessageStreamResponseProps) => { }: CreateMessageStreamResponseProps) => {
const latestMessage = messages[messages.length - 1]; const latestMessage = messages[messages.length - 1];
const sources = latestMessage.parts const sources = latestMessage.parts
@ -254,7 +243,7 @@ const createMessageStreamResponse = async ({
await updateChatMessages({ await updateChatMessages({
chatId: id, chatId: id,
messages messages
}, domain); });
}, },
}); });

View file

@ -1,46 +1,25 @@
'use server'; 'use server';
import { NextRequest } from "next/server";
import { fetchAuditRecords } from "@/ee/features/audit/actions"; import { fetchAuditRecords } from "@/ee/features/audit/actions";
import { isServiceError } from "@/lib/utils";
import { serviceErrorResponse } from "@/lib/serviceError";
import { StatusCodes } from "http-status-codes";
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { env } from "@sourcebot/shared"; import { serviceErrorResponse } from "@/lib/serviceError";
import { isServiceError } from "@/lib/utils";
import { getEntitlements } from "@sourcebot/shared"; import { getEntitlements } from "@sourcebot/shared";
import { StatusCodes } from "http-status-codes";
export const GET = async (request: NextRequest) => { export const GET = async () => {
const domain = request.headers.get("X-Org-Domain"); const entitlements = getEntitlements();
const apiKey = request.headers.get("X-Sourcebot-Api-Key") ?? undefined; if (!entitlements.includes('audit')) {
return serviceErrorResponse({
statusCode: StatusCodes.FORBIDDEN,
errorCode: ErrorCode.NOT_FOUND,
message: "Audit logging is not enabled for your license",
});
}
if (!domain) { const result = await fetchAuditRecords();
return serviceErrorResponse({ if (isServiceError(result)) {
statusCode: StatusCodes.BAD_REQUEST, return serviceErrorResponse(result);
errorCode: ErrorCode.MISSING_ORG_DOMAIN_HEADER, }
message: "Missing X-Org-Domain header", return Response.json(result);
});
}
if (env.SOURCEBOT_EE_AUDIT_LOGGING_ENABLED === 'false') {
return serviceErrorResponse({
statusCode: StatusCodes.NOT_FOUND,
errorCode: ErrorCode.NOT_FOUND,
message: "Audit logging is not enabled",
});
}
const entitlements = getEntitlements();
if (!entitlements.includes('audit')) {
return serviceErrorResponse({
statusCode: StatusCodes.FORBIDDEN,
errorCode: ErrorCode.NOT_FOUND,
message: "Audit logging is not enabled for your license",
});
}
const result = await fetchAuditRecords(domain, apiKey);
if (isServiceError(result)) {
return serviceErrorResponse(result);
}
return Response.json(result);
}; };

View file

@ -1,59 +1,64 @@
"use server"; "use server";
import { prisma } from "@/prisma"; import { sew } from "@/actions";
import { ErrorCode } from "@/lib/errorCodes";
import { StatusCodes } from "http-status-codes";
import { sew, withAuth, withOrgMembership } from "@/actions";
import { OrgRole } from "@sourcebot/db";
import { createLogger } from "@sourcebot/shared";
import { ServiceError } from "@/lib/serviceError";
import { getAuditService } from "@/ee/features/audit/factory"; import { getAuditService } from "@/ee/features/audit/factory";
import { ErrorCode } from "@/lib/errorCodes";
import { ServiceError } from "@/lib/serviceError";
import { prisma } from "@/prisma";
import { withAuthV2, withMinimumOrgRole } from "@/withAuthV2";
import { createLogger } from "@sourcebot/shared";
import { StatusCodes } from "http-status-codes";
import { AuditEvent } from "./types"; import { AuditEvent } from "./types";
import { OrgRole } from "@sourcebot/db";
const auditService = getAuditService(); const auditService = getAuditService();
const logger = createLogger('audit-utils'); const logger = createLogger('audit-utils');
export const createAuditAction = async (event: Omit<AuditEvent, 'sourcebotVersion' | 'orgId' | 'actor' | 'target'>, domain: string) => sew(async () => export const createAuditAction = async (event: Omit<AuditEvent, 'sourcebotVersion' | 'orgId' | 'actor' | 'target'>) => sew(async () =>
withAuth((userId) => withAuthV2(async ({ user, org }) => {
withOrgMembership(userId, domain, async ({ org }) => {
await auditService.createAudit({ ...event, orgId: org.id, actor: { id: userId, type: "user" }, target: { id: org.id.toString(), type: "org" } })
}, /* minRequiredRole = */ OrgRole.MEMBER), /* allowAnonymousAccess = */ true)
);
export const fetchAuditRecords = async (domain: string, apiKey: string | undefined = undefined) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
try {
const auditRecords = await prisma.audit.findMany({
where: {
orgId: org.id,
},
orderBy: {
timestamp: 'desc'
}
});
await auditService.createAudit({ await auditService.createAudit({
action: "audit.fetch", ...event,
actor: { orgId: org.id,
id: userId, actor: { id: user.id, type: "user" },
type: "user" target: { id: org.id.toString(), type: "org" },
},
target: {
id: org.id.toString(),
type: "org"
},
orgId: org.id
}) })
})
return auditRecords; );
} catch (error) {
logger.error('Error fetching audit logs', { error }); export const fetchAuditRecords = async () => sew(() =>
return { withAuthV2(async ({ user, org, role }) =>
statusCode: StatusCodes.INTERNAL_SERVER_ERROR, withMinimumOrgRole(role, OrgRole.OWNER, async () => {
errorCode: ErrorCode.UNEXPECTED_ERROR, try {
message: "Failed to fetch audit logs", const auditRecords = await prisma.audit.findMany({
} satisfies ServiceError; where: {
} orgId: org.id,
}, /* minRequiredRole = */ OrgRole.OWNER), /* allowAnonymousAccess = */ true, apiKey ? { apiKey, domain } : undefined) },
orderBy: {
timestamp: 'desc'
}
});
await auditService.createAudit({
action: "audit.fetch",
actor: {
id: user.id,
type: "user"
},
target: {
id: org.id.toString(),
type: "org"
},
orgId: org.id
})
return auditRecords;
} catch (error) {
logger.error('Error fetching audit logs', { error });
return {
statusCode: StatusCodes.INTERNAL_SERVER_ERROR,
errorCode: ErrorCode.UNEXPECTED_ERROR,
message: "Failed to fetch audit logs",
} satisfies ServiceError;
}
}))
); );

View file

@ -1,10 +1,9 @@
'use server'; 'use server';
import { sew, withAuth, withOrgMembership } from "@/actions"; import { sew } from "@/actions";
import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants"; import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants";
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { chatIsReadonly, notFound, ServiceError, serviceErrorResponse } from "@/lib/serviceError"; import { chatIsReadonly, notFound, ServiceError, serviceErrorResponse } from "@/lib/serviceError";
import { prisma } from "@/prisma";
import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock'; import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock';
import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic'; import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic';
import { createAzure } from '@ai-sdk/azure'; import { createAzure } from '@ai-sdk/azure';
@ -20,7 +19,7 @@ import { createXai } from '@ai-sdk/xai';
import { fromNodeProviderChain } from '@aws-sdk/credential-providers'; import { fromNodeProviderChain } from '@aws-sdk/credential-providers';
import { createOpenRouter } from '@openrouter/ai-sdk-provider'; import { createOpenRouter } from '@openrouter/ai-sdk-provider';
import { getTokenFromConfig, createLogger, env } from "@sourcebot/shared"; import { getTokenFromConfig, createLogger, env } from "@sourcebot/shared";
import { ChatVisibility, OrgRole, Prisma } from "@sourcebot/db"; import { ChatVisibility, Prisma } from "@sourcebot/db";
import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type"; import { LanguageModel } from "@sourcebot/schemas/v3/languageModel.type";
import { Token } from "@sourcebot/schemas/v3/shared.type"; import { Token } from "@sourcebot/schemas/v3/shared.type";
import { generateText, JSONValue, extractReasoningMiddleware, wrapLanguageModel } from "ai"; import { generateText, JSONValue, extractReasoningMiddleware, wrapLanguageModel } from "ai";
@ -29,168 +28,161 @@ import fs from 'fs';
import { StatusCodes } from "http-status-codes"; import { StatusCodes } from "http-status-codes";
import path from 'path'; import path from 'path';
import { LanguageModelInfo, SBChatMessage } from "./types"; import { LanguageModelInfo, SBChatMessage } from "./types";
import { withAuthV2, withOptionalAuthV2 } from "@/withAuthV2";
const logger = createLogger('chat-actions'); const logger = createLogger('chat-actions');
export const createChat = async (domain: string) => sew(() => export const createChat = async () => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const isGuestUser = user?.id === SOURCEBOT_GUEST_USER_ID;
const isGuestUser = userId === SOURCEBOT_GUEST_USER_ID; const chat = await prisma.chat.create({
data: {
orgId: org.id,
messages: [] as unknown as Prisma.InputJsonValue,
createdById: user?.id ?? SOURCEBOT_GUEST_USER_ID,
visibility: isGuestUser ? ChatVisibility.PUBLIC : ChatVisibility.PRIVATE,
},
});
const chat = await prisma.chat.create({ return {
data: { id: chat.id,
orgId: org.id, }
messages: [] as unknown as Prisma.InputJsonValue, })
createdById: userId,
visibility: isGuestUser ? ChatVisibility.PUBLIC : ChatVisibility.PRIVATE,
},
});
return {
id: chat.id,
}
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
); );
export const getChatInfo = async ({ chatId }: { chatId: string }, domain: string) => sew(() => export const getChatInfo = async ({ chatId }: { chatId: string }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const chat = await prisma.chat.findUnique({
const chat = await prisma.chat.findUnique({ where: {
where: { id: chatId,
id: chatId, orgId: org.id,
orgId: org.id, },
}, });
});
if (!chat) { if (!chat) {
return notFound(); return notFound();
} }
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) {
return notFound(); return notFound();
} }
return { return {
messages: chat.messages as unknown as SBChatMessage[], messages: chat.messages as unknown as SBChatMessage[],
visibility: chat.visibility, visibility: chat.visibility,
name: chat.name, name: chat.name,
isReadonly: chat.isReadonly, isReadonly: chat.isReadonly,
}; };
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) })
); );
export const updateChatMessages = async ({ chatId, messages }: { chatId: string, messages: SBChatMessage[] }, domain: string) => sew(() => export const updateChatMessages = async ({ chatId, messages }: { chatId: string, messages: SBChatMessage[] }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const chat = await prisma.chat.findUnique({
const chat = await prisma.chat.findUnique({ where: {
where: { id: chatId,
id: chatId, orgId: org.id,
orgId: org.id, },
}, });
});
if (!chat) { if (!chat) {
return notFound(); return notFound();
}
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) {
return notFound();
}
if (chat.isReadonly) {
return chatIsReadonly();
}
await prisma.chat.update({
where: {
id: chatId,
},
data: {
messages: messages as unknown as Prisma.InputJsonValue,
},
});
if (env.DEBUG_WRITE_CHAT_MESSAGES_TO_FILE) {
const chatDir = path.join(env.DATA_CACHE_DIR, 'chats');
if (!fs.existsSync(chatDir)) {
fs.mkdirSync(chatDir, { recursive: true });
} }
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { const chatFile = path.join(chatDir, `${chatId}.json`);
return notFound(); fs.writeFileSync(chatFile, JSON.stringify(messages, null, 2));
} }
if (chat.isReadonly) { return {
return chatIsReadonly(); success: true,
} }
})
await prisma.chat.update({
where: {
id: chatId,
},
data: {
messages: messages as unknown as Prisma.InputJsonValue,
},
});
if (env.DEBUG_WRITE_CHAT_MESSAGES_TO_FILE) {
const chatDir = path.join(env.DATA_CACHE_DIR, 'chats');
if (!fs.existsSync(chatDir)) {
fs.mkdirSync(chatDir, { recursive: true });
}
const chatFile = path.join(chatDir, `${chatId}.json`);
fs.writeFileSync(chatFile, JSON.stringify(messages, null, 2));
}
return {
success: true,
}
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
); );
export const getUserChatHistory = async (domain: string) => sew(() => export const getUserChatHistory = async () => sew(() =>
withAuth((userId) => withAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const chats = await prisma.chat.findMany({
const chats = await prisma.chat.findMany({ where: {
where: { orgId: org.id,
orgId: org.id, createdById: user.id,
createdById: userId, },
}, orderBy: {
orderBy: { updatedAt: 'desc',
updatedAt: 'desc', },
}, });
});
return chats.map((chat) => ({ return chats.map((chat) => ({
id: chat.id, id: chat.id,
createdAt: chat.createdAt, createdAt: chat.createdAt,
name: chat.name, name: chat.name,
visibility: chat.visibility, visibility: chat.visibility,
})) }))
}) })
)
); );
export const updateChatName = async ({ chatId, name }: { chatId: string, name: string }, domain: string) => sew(() => export const updateChatName = async ({ chatId, name }: { chatId: string, name: string }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const chat = await prisma.chat.findUnique({
const chat = await prisma.chat.findUnique({ where: {
where: { id: chatId,
id: chatId, orgId: org.id,
orgId: org.id, },
}, });
});
if (!chat) { if (!chat) {
return notFound(); return notFound();
} }
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) {
return notFound(); return notFound();
} }
if (chat.isReadonly) { if (chat.isReadonly) {
return chatIsReadonly(); return chatIsReadonly();
} }
await prisma.chat.update({ await prisma.chat.update({
where: { where: {
id: chatId, id: chatId,
orgId: org.id, orgId: org.id,
}, },
data: { data: {
name, name,
}, },
}); });
return { return {
success: true, success: true,
} }
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) })
); );
export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }, domain: string) => sew(() => export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async () => {
withOrgMembership(userId, domain, async () => {
// From the language model ID, attempt to find the // From the language model ID, attempt to find the
// corresponding config in `config.json`. // corresponding config in `config.json`.
const languageModelConfig = const languageModelConfig =
@ -231,48 +223,6 @@ User question: ${message}`;
await updateChatName({ await updateChatName({
chatId, chatId,
name: result.text, name: result.text,
}, domain);
return {
success: true,
}
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true
)
);
export const deleteChat = async ({ chatId }: { chatId: string }, domain: string) => sew(() =>
withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => {
const chat = await prisma.chat.findUnique({
where: {
id: chatId,
orgId: org.id,
},
});
if (!chat) {
return notFound();
}
// Public chats cannot be deleted.
if (chat.visibility === ChatVisibility.PUBLIC) {
return {
statusCode: StatusCodes.FORBIDDEN,
errorCode: ErrorCode.UNEXPECTED_ERROR,
message: 'You are not allowed to delete this chat.',
} satisfies ServiceError;
}
// Only the creator of a chat can delete it.
if (chat.createdById !== userId) {
return notFound();
}
await prisma.chat.delete({
where: {
id: chatId,
orgId: org.id,
},
}); });
return { return {
@ -280,6 +230,45 @@ export const deleteChat = async ({ chatId }: { chatId: string }, domain: string)
} }
}) })
) )
export const deleteChat = async ({ chatId }: { chatId: string }) => sew(() =>
withAuthV2(async ({ org, user, prisma }) => {
const chat = await prisma.chat.findUnique({
where: {
id: chatId,
orgId: org.id,
},
});
if (!chat) {
return notFound();
}
// Public chats cannot be deleted.
if (chat.visibility === ChatVisibility.PUBLIC) {
return {
statusCode: StatusCodes.FORBIDDEN,
errorCode: ErrorCode.UNEXPECTED_ERROR,
message: 'You are not allowed to delete this chat.',
} satisfies ServiceError;
}
// Only the creator of a chat can delete it.
if (chat.createdById !== user.id) {
return notFound();
}
await prisma.chat.delete({
where: {
id: chatId,
orgId: org.id,
},
});
return {
success: true,
}
})
); );
export const submitFeedback = async ({ export const submitFeedback = async ({
@ -290,56 +279,55 @@ export const submitFeedback = async ({
chatId: string, chatId: string,
messageId: string, messageId: string,
feedbackType: 'like' | 'dislike' feedbackType: 'like' | 'dislike'
}, domain: string) => sew(() => }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, user, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const chat = await prisma.chat.findUnique({
const chat = await prisma.chat.findUnique({ where: {
where: { id: chatId,
id: chatId, orgId: org.id,
orgId: org.id, },
}, });
});
if (!chat) { if (!chat) {
return notFound(); return notFound();
}
// When a chat is private, only the creator can submit feedback.
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== user?.id) {
return notFound();
}
const messages = chat.messages as unknown as SBChatMessage[];
const updatedMessages = messages.map(message => {
if (message.id === messageId && message.role === 'assistant') {
return {
...message,
metadata: {
...message.metadata,
feedback: [
...(message.metadata?.feedback ?? []),
{
type: feedbackType,
timestamp: new Date().toISOString(),
userId: user?.id,
}
]
}
} satisfies SBChatMessage;
} }
return message;
});
// When a chat is private, only the creator can submit feedback. await prisma.chat.update({
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { where: { id: chatId },
return notFound(); data: {
} messages: updatedMessages as unknown as Prisma.InputJsonValue,
},
});
const messages = chat.messages as unknown as SBChatMessage[]; return { success: true };
const updatedMessages = messages.map(message => { })
if (message.id === messageId && message.role === 'assistant') { )
return {
...message,
metadata: {
...message.metadata,
feedback: [
...(message.metadata?.feedback ?? []),
{
type: feedbackType,
timestamp: new Date().toISOString(),
userId: userId,
}
]
}
} satisfies SBChatMessage;
}
return message;
});
await prisma.chat.update({
where: { id: chatId },
data: {
messages: updatedMessages as unknown as Prisma.InputJsonValue,
},
});
return { success: true };
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
);
/** /**
* Returns the subset of information about the configured language models * Returns the subset of information about the configured language models

View file

@ -8,7 +8,7 @@ import { insertMention, slateContentToString } from "@/features/chat/utils";
import { cn, IS_MAC } from "@/lib/utils"; import { cn, IS_MAC } from "@/lib/utils";
import { computePosition, flip, offset, shift, VirtualElement } from "@floating-ui/react"; import { computePosition, flip, offset, shift, VirtualElement } from "@floating-ui/react";
import { ArrowUp, Loader2, StopCircleIcon, TriangleAlertIcon } from "lucide-react"; import { ArrowUp, Loader2, StopCircleIcon, TriangleAlertIcon } from "lucide-react";
import { Fragment, KeyboardEvent, useCallback, useEffect, useMemo, useRef, useState } from "react"; import { Fragment, KeyboardEvent, memo, useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useHotkeys } from "react-hotkeys-hook"; import { useHotkeys } from "react-hotkeys-hook";
import { Descendant, insertText } from "slate"; import { Descendant, insertText } from "slate";
import { Editable, ReactEditor, RenderElementProps, RenderLeafProps, useFocused, useSelected, useSlate } from "slate-react"; import { Editable, ReactEditor, RenderElementProps, RenderLeafProps, useFocused, useSelected, useSlate } from "slate-react";
@ -19,6 +19,7 @@ import { useSuggestionModeAndQuery } from "./useSuggestionModeAndQuery";
import { useSuggestionsData } from "./useSuggestionsData"; import { useSuggestionsData } from "./useSuggestionsData";
import { useToast } from "@/components/hooks/use-toast"; import { useToast } from "@/components/hooks/use-toast";
import { SearchContextQuery } from "@/lib/types"; import { SearchContextQuery } from "@/lib/types";
import isEqual from "fast-deep-equal/react";
interface ChatBoxProps { interface ChatBoxProps {
onSubmit: (children: Descendant[], editor: CustomEditor) => void; onSubmit: (children: Descendant[], editor: CustomEditor) => void;
@ -34,7 +35,7 @@ interface ChatBoxProps {
onContextSelectorOpenChanged: (isOpen: boolean) => void; onContextSelectorOpenChanged: (isOpen: boolean) => void;
} }
export const ChatBox = ({ const ChatBoxComponent = ({
onSubmit: _onSubmit, onSubmit: _onSubmit,
onStop, onStop,
preferredSuggestionsBoxPlacement = "bottom-start", preferredSuggestionsBoxPlacement = "bottom-start",
@ -368,6 +369,8 @@ export const ChatBox = ({
) )
} }
export const ChatBox = memo(ChatBoxComponent, isEqual);
const DefaultElement = (props: RenderElementProps) => { const DefaultElement = (props: RenderElementProps) => {
return <p {...props.attributes}>{props.children}</p> return <p {...props.attributes}>{props.children}</p>
} }

View file

@ -32,13 +32,19 @@ export const useSuggestionsData = ({
const { data: fileSuggestions, isLoading: _isLoadingFileSuggestions } = useQuery({ const { data: fileSuggestions, isLoading: _isLoadingFileSuggestions } = useQuery({
queryKey: ["fileSuggestions-agentic", suggestionQuery, domain, selectedRepos], queryKey: ["fileSuggestions-agentic", suggestionQuery, domain, selectedRepos],
queryFn: () => { queryFn: () => {
let query = `file:${suggestionQuery}`; const query = [];
if (suggestionQuery.length > 0) {
query.push(`file:${suggestionQuery}`);
} else {
query.push('file:.*');
}
if (selectedRepos.length > 0) { if (selectedRepos.length > 0) {
query += ` reposet:${selectedRepos.join(',')}`; query.push(`reposet:${selectedRepos.join(',')}`);
} }
return unwrapServiceError(search({ return unwrapServiceError(search({
query, query: query.join(' '),
matches: 10, matches: 10,
contextLines: 1, contextLines: 1,
})) }))

View file

@ -6,7 +6,7 @@ import { Button } from "@/components/ui/button";
import { TableOfContentsIcon, ThumbsDown, ThumbsUp } from "lucide-react"; import { TableOfContentsIcon, ThumbsDown, ThumbsUp } from "lucide-react";
import { Separator } from "@/components/ui/separator"; import { Separator } from "@/components/ui/separator";
import { MarkdownRenderer } from "./markdownRenderer"; import { MarkdownRenderer } from "./markdownRenderer";
import { forwardRef, useCallback, useImperativeHandle, useRef, useState } from "react"; import { forwardRef, memo, useCallback, useImperativeHandle, useRef, useState } from "react";
import { Toggle } from "@/components/ui/toggle"; import { Toggle } from "@/components/ui/toggle";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { CopyIconButton } from "@/app/[domain]/components/copyIconButton"; import { CopyIconButton } from "@/app/[domain]/components/copyIconButton";
@ -14,10 +14,10 @@ import { useToast } from "@/components/hooks/use-toast";
import { convertLLMOutputToPortableMarkdown } from "../../utils"; import { convertLLMOutputToPortableMarkdown } from "../../utils";
import { submitFeedback } from "../../actions"; import { submitFeedback } from "../../actions";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
import { useDomain } from "@/hooks/useDomain";
import useCaptureEvent from "@/hooks/useCaptureEvent"; import useCaptureEvent from "@/hooks/useCaptureEvent";
import { LangfuseWeb } from "langfuse"; import { LangfuseWeb } from "langfuse";
import { env } from "@sourcebot/shared/client"; import { env } from "@sourcebot/shared/client";
import isEqual from "fast-deep-equal/react";
interface AnswerCardProps { interface AnswerCardProps {
answerText: string; answerText: string;
@ -31,7 +31,7 @@ const langfuseWeb = (env.NEXT_PUBLIC_SOURCEBOT_CLOUD_ENVIRONMENT !== undefined &
baseUrl: env.NEXT_PUBLIC_LANGFUSE_BASE_URL, baseUrl: env.NEXT_PUBLIC_LANGFUSE_BASE_URL,
}) : null; }) : null;
export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({ const AnswerCardComponent = forwardRef<HTMLDivElement, AnswerCardProps>(({
answerText, answerText,
messageId, messageId,
chatId, chatId,
@ -41,7 +41,6 @@ export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
const { tocItems, activeId } = useExtractTOCItems({ target: markdownRendererRef.current }); const { tocItems, activeId } = useExtractTOCItems({ target: markdownRendererRef.current });
const [isTOCButtonToggled, setIsTOCButtonToggled] = useState(false); const [isTOCButtonToggled, setIsTOCButtonToggled] = useState(false);
const { toast } = useToast(); const { toast } = useToast();
const domain = useDomain();
const [isSubmittingFeedback, setIsSubmittingFeedback] = useState(false); const [isSubmittingFeedback, setIsSubmittingFeedback] = useState(false);
const [feedback, setFeedback] = useState<'like' | 'dislike' | undefined>(undefined); const [feedback, setFeedback] = useState<'like' | 'dislike' | undefined>(undefined);
const captureEvent = useCaptureEvent(); const captureEvent = useCaptureEvent();
@ -67,7 +66,7 @@ export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
chatId, chatId,
messageId, messageId,
feedbackType feedbackType
}, domain); });
if (isServiceError(response)) { if (isServiceError(response)) {
toast({ toast({
@ -93,7 +92,7 @@ export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
} }
setIsSubmittingFeedback(false); setIsSubmittingFeedback(false);
}, [chatId, messageId, domain, toast, captureEvent, traceId]); }, [chatId, messageId, toast, captureEvent, traceId]);
return ( return (
<div className="flex flex-row w-full relative scroll-mt-16"> <div className="flex flex-row w-full relative scroll-mt-16">
@ -178,4 +177,6 @@ export const AnswerCard = forwardRef<HTMLDivElement, AnswerCardProps>(({
) )
}) })
AnswerCard.displayName = 'AnswerCard'; AnswerCardComponent.displayName = 'AnswerCard';
export const AnswerCard = memo(AnswerCardComponent, isEqual);

View file

@ -7,7 +7,6 @@ import { Separator } from '@/components/ui/separator';
import { CustomSlateEditor } from '@/features/chat/customSlateEditor'; import { CustomSlateEditor } from '@/features/chat/customSlateEditor';
import { AdditionalChatRequestParams, CustomEditor, LanguageModelInfo, SBChatMessage, SearchScope, Source } from '@/features/chat/types'; import { AdditionalChatRequestParams, CustomEditor, LanguageModelInfo, SBChatMessage, SearchScope, Source } from '@/features/chat/types';
import { createUIMessage, getAllMentionElements, resetEditor, slateContentToString } from '@/features/chat/utils'; import { createUIMessage, getAllMentionElements, resetEditor, slateContentToString } from '@/features/chat/utils';
import { useDomain } from '@/hooks/useDomain';
import { useChat } from '@ai-sdk/react'; import { useChat } from '@ai-sdk/react';
import { CreateUIMessage, DefaultChatTransport } from 'ai'; import { CreateUIMessage, DefaultChatTransport } from 'ai';
import { ArrowDownIcon } from 'lucide-react'; import { ArrowDownIcon } from 'lucide-react';
@ -54,7 +53,6 @@ export const ChatThread = ({
onSelectedSearchScopesChange, onSelectedSearchScopesChange,
isChatReadonly, isChatReadonly,
}: ChatThreadProps) => { }: ChatThreadProps) => {
const domain = useDomain();
const [isErrorBannerVisible, setIsErrorBannerVisible] = useState(false); const [isErrorBannerVisible, setIsErrorBannerVisible] = useState(false);
const scrollAreaRef = useRef<HTMLDivElement>(null); const scrollAreaRef = useRef<HTMLDivElement>(null);
const latestMessagePairRef = useRef<HTMLDivElement>(null); const latestMessagePairRef = useRef<HTMLDivElement>(null);
@ -89,9 +87,6 @@ export const ChatThread = ({
messages: initialMessages, messages: initialMessages,
transport: new DefaultChatTransport({ transport: new DefaultChatTransport({
api: '/api/chat', api: '/api/chat',
headers: {
"X-Org-Domain": domain,
}
}), }),
onData: (dataPart) => { onData: (dataPart) => {
// Keeps sources added by the assistant in sync. // Keeps sources added by the assistant in sync.
@ -134,7 +129,6 @@ export const ChatThread = ({
languageModelId: selectedLanguageModel.model, languageModelId: selectedLanguageModel.model,
message: message.parts[0].text, message: message.parts[0].text,
}, },
domain
).then((response) => { ).then((response) => {
if (isServiceError(response)) { if (isServiceError(response)) {
toast({ toast({
@ -153,7 +147,6 @@ export const ChatThread = ({
messages.length, messages.length,
toast, toast,
chatId, chatId,
domain,
router, router,
]); ]);
@ -224,7 +217,7 @@ export const ChatThread = ({
'', '',
window.location.href window.location.href
); );
}, 300); }, 500);
}; };
scrollElement.addEventListener('scroll', handleScroll, { passive: true }); scrollElement.addEventListener('scroll', handleScroll, { passive: true });
@ -243,11 +236,17 @@ export const ChatThread = ({
return; return;
} }
const { scrollOffset } = (history.state ?? {}) as ChatHistoryState; // @hack: without this setTimeout, the scroll position would not be restored
scrollElement.scrollTo({ // at the correct position (it was slightly too high). The theory is that the
top: scrollOffset ?? 0, // content hasn't fully rendered yet, so restoring the scroll position too
behavior: 'instant', // early results in weirdness. Waiting 10ms seems to fix the issue.
}); setTimeout(() => {
const { scrollOffset } = (history.state ?? {}) as ChatHistoryState;
scrollElement.scrollTo({
top: scrollOffset ?? 0,
behavior: 'instant',
});
}, 10);
}, []); }, []);
// When messages are being streamed, scroll to the latest message // When messages are being streamed, scroll to the latest message
@ -313,9 +312,11 @@ export const ChatThread = ({
{messagePairs.map(([userMessage, assistantMessage], index) => { {messagePairs.map(([userMessage, assistantMessage], index) => {
const isLastPair = index === messagePairs.length - 1; const isLastPair = index === messagePairs.length - 1;
const isStreaming = isLastPair && (status === "streaming" || status === "submitted"); const isStreaming = isLastPair && (status === "streaming" || status === "submitted");
// Use a stable key based on user message ID
const key = userMessage.id;
return ( return (
<Fragment key={index}> <Fragment key={key}>
<ChatThreadListItem <ChatThreadListItem
index={index} index={index}
chatId={chatId} chatId={chatId}

View file

@ -4,16 +4,17 @@ import { AnimatedResizableHandle } from '@/components/ui/animatedResizableHandle
import { ResizablePanel, ResizablePanelGroup } from '@/components/ui/resizable'; import { ResizablePanel, ResizablePanelGroup } from '@/components/ui/resizable';
import { Skeleton } from '@/components/ui/skeleton'; import { Skeleton } from '@/components/ui/skeleton';
import { CheckCircle, Loader2 } from 'lucide-react'; import { CheckCircle, Loader2 } from 'lucide-react';
import { CSSProperties, forwardRef, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { CSSProperties, forwardRef, memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import scrollIntoView from 'scroll-into-view-if-needed'; import scrollIntoView from 'scroll-into-view-if-needed';
import { Reference, referenceSchema, SBChatMessage, Source } from "../../types"; import { Reference, referenceSchema, SBChatMessage, Source } from "../../types";
import { useExtractReferences } from '../../useExtractReferences'; import { useExtractReferences } from '../../useExtractReferences';
import { getAnswerPartFromAssistantMessage, groupMessageIntoSteps, repairReferences } from '../../utils'; import { getAnswerPartFromAssistantMessage, groupMessageIntoSteps, repairReferences, tryResolveFileReference } from '../../utils';
import { AnswerCard } from './answerCard'; import { AnswerCard } from './answerCard';
import { DetailsCard } from './detailsCard'; import { DetailsCard } from './detailsCard';
import { MarkdownRenderer, REFERENCE_PAYLOAD_ATTRIBUTE } from './markdownRenderer'; import { MarkdownRenderer, REFERENCE_PAYLOAD_ATTRIBUTE } from './markdownRenderer';
import { ReferencedSourcesListView } from './referencedSourcesListView'; import { ReferencedSourcesListView } from './referencedSourcesListView';
import { uiVisiblePartTypes } from '../../constants'; import { uiVisiblePartTypes } from '../../constants';
import isEqual from "fast-deep-equal/react";
interface ChatThreadListItemProps { interface ChatThreadListItemProps {
userMessage: SBChatMessage; userMessage: SBChatMessage;
@ -24,7 +25,7 @@ interface ChatThreadListItemProps {
index: number; index: number;
} }
export const ChatThreadListItem = forwardRef<HTMLDivElement, ChatThreadListItemProps>(({ const ChatThreadListItemComponent = forwardRef<HTMLDivElement, ChatThreadListItemProps>(({
userMessage, userMessage,
assistantMessage: _assistantMessage, assistantMessage: _assistantMessage,
isStreaming, isStreaming,
@ -80,7 +81,6 @@ export const ChatThreadListItem = forwardRef<HTMLDivElement, ChatThreadListItemP
return getAnswerPartFromAssistantMessage(assistantMessage, isStreaming); return getAnswerPartFromAssistantMessage(assistantMessage, isStreaming);
}, [assistantMessage, isStreaming]); }, [assistantMessage, isStreaming]);
const references = useExtractReferences(answerPart);
// Groups parts into steps that are associated with thinking steps that // Groups parts into steps that are associated with thinking steps that
// should be visible to the user. By "steps", we mean parts that originated // should be visible to the user. By "steps", we mean parts that originated
@ -279,6 +279,26 @@ export const ChatThreadListItem = forwardRef<HTMLDivElement, ChatThreadListItemP
}; };
}, [hoveredReference]); }, [hoveredReference]);
const references = useExtractReferences(answerPart);
// Extract the file sources that are referenced by the answer part.
const referencedFileSources = useMemo(() => {
const fileSources = sources.filter((source) => source.type === 'file');
return references
.filter((reference) => reference.type === 'file')
.map((reference) => tryResolveFileReference(reference, fileSources))
.filter((file) => file !== undefined)
// de-duplicate files
.filter((file, index, self) =>
index === self.findIndex((t) =>
t?.path === file?.path
&& t?.repo === file?.repo
&& t?.revision === file?.revision
)
);
}, [references, sources]);
return ( return (
<div <div
@ -364,11 +384,11 @@ export const ChatThreadListItem = forwardRef<HTMLDivElement, ChatThreadListItemP
<div <div
className="sticky top-0" className="sticky top-0"
> >
{references.length > 0 ? ( {referencedFileSources.length > 0 ? (
<ReferencedSourcesListView <ReferencedSourcesListView
index={index} index={index}
references={references} references={references}
sources={sources} sources={referencedFileSources}
hoveredReference={hoveredReference} hoveredReference={hoveredReference}
selectedReference={selectedReference} selectedReference={selectedReference}
onSelectedReferenceChanged={setSelectedReference} onSelectedReferenceChanged={setSelectedReference}
@ -393,7 +413,34 @@ export const ChatThreadListItem = forwardRef<HTMLDivElement, ChatThreadListItemP
) )
}); });
ChatThreadListItem.displayName = 'ChatThreadListItem'; ChatThreadListItemComponent.displayName = 'ChatThreadListItem';
// Custom comparison function that handles the known issue where useChat mutates
// message objects in place during streaming, causing fast-deep-equal to return
// true even when content changes (because it checks reference equality first).
// See: https://github.com/vercel/ai/issues/6466
const arePropsEqual = (
prevProps: ChatThreadListItemProps,
nextProps: ChatThreadListItemProps
): boolean => {
// Always re-render if streaming status changes
if (prevProps.isStreaming !== nextProps.isStreaming) {
return false;
}
// If currently streaming, always allow re-render
// This bypasses the fast-deep-equal reference check issue when useChat
// mutates message objects in place during token streaming
if (nextProps.isStreaming) {
return false;
}
// For non-streaming messages, use deep equality
// At this point, useChat should have finished and created final objects
return isEqual(prevProps, nextProps);
};
export const ChatThreadListItem = memo(ChatThreadListItemComponent, arePropsEqual);
// Finds the nearest reference element to the viewport center. // Finds the nearest reference element to the viewport center.
const getNearestReferenceElement = (referenceElements: Element[]) => { const getNearestReferenceElement = (referenceElements: Element[]) => {

View file

@ -7,6 +7,7 @@ import { Skeleton } from '@/components/ui/skeleton';
import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip'; import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip';
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
import { Brain, ChevronDown, ChevronRight, Clock, InfoIcon, Loader2, List, ScanSearchIcon, Zap } from 'lucide-react'; import { Brain, ChevronDown, ChevronRight, Clock, InfoIcon, Loader2, List, ScanSearchIcon, Zap } from 'lucide-react';
import { memo } from 'react';
import { MarkdownRenderer } from './markdownRenderer'; import { MarkdownRenderer } from './markdownRenderer';
import { FindSymbolDefinitionsToolComponent } from './tools/findSymbolDefinitionsToolComponent'; import { FindSymbolDefinitionsToolComponent } from './tools/findSymbolDefinitionsToolComponent';
import { FindSymbolReferencesToolComponent } from './tools/findSymbolReferencesToolComponent'; import { FindSymbolReferencesToolComponent } from './tools/findSymbolReferencesToolComponent';
@ -16,6 +17,7 @@ import { SearchReposToolComponent } from './tools/searchReposToolComponent';
import { ListAllReposToolComponent } from './tools/listAllReposToolComponent'; import { ListAllReposToolComponent } from './tools/listAllReposToolComponent';
import { SBChatMessageMetadata, SBChatMessagePart } from '../../types'; import { SBChatMessageMetadata, SBChatMessagePart } from '../../types';
import { SearchScopeIcon } from '../searchScopeIcon'; import { SearchScopeIcon } from '../searchScopeIcon';
import isEqual from "fast-deep-equal/react";
interface DetailsCardProps { interface DetailsCardProps {
@ -27,7 +29,7 @@ interface DetailsCardProps {
metadata?: SBChatMessageMetadata; metadata?: SBChatMessageMetadata;
} }
export const DetailsCard = ({ const DetailsCardComponent = ({
isExpanded, isExpanded,
onExpandedChanged, onExpandedChanged,
isThinking, isThinking,
@ -35,7 +37,6 @@ export const DetailsCard = ({
metadata, metadata,
thinkingSteps, thinkingSteps,
}: DetailsCardProps) => { }: DetailsCardProps) => {
return ( return (
<Card className="mb-4"> <Card className="mb-4">
<Collapsible open={isExpanded} onOpenChange={onExpandedChanged}> <Collapsible open={isExpanded} onOpenChange={onExpandedChanged}>
@ -209,4 +210,6 @@ export const DetailsCard = ({
</Collapsible> </Collapsible>
</Card> </Card>
) )
} }
export const DetailsCard = memo(DetailsCardComponent, isEqual);

View file

@ -1,7 +1,6 @@
'use client'; 'use client';
import { CodeSnippet } from '@/app/components/codeSnippet'; import { CodeSnippet } from '@/app/components/codeSnippet';
import { useDomain } from '@/hooks/useDomain';
import { SearchQueryParams } from '@/lib/types'; import { SearchQueryParams } from '@/lib/types';
import { cn, createPathWithQueryParams } from '@/lib/utils'; import { cn, createPathWithQueryParams } from '@/lib/utils';
import type { Element, Root } from "hast"; import type { Element, Root } from "hast";
@ -10,7 +9,7 @@ import { CopyIcon, SearchIcon } from 'lucide-react';
import type { Heading, Nodes } from "mdast"; import type { Heading, Nodes } from "mdast";
import { findAndReplace } from 'mdast-util-find-and-replace'; import { findAndReplace } from 'mdast-util-find-and-replace';
import { useRouter } from 'next/navigation'; import { useRouter } from 'next/navigation';
import React, { useCallback, useMemo, forwardRef } from 'react'; import React, { useCallback, useMemo, forwardRef, memo } from 'react';
import Markdown from 'react-markdown'; import Markdown from 'react-markdown';
import rehypeRaw from 'rehype-raw'; import rehypeRaw from 'rehype-raw';
import rehypeSanitize, { defaultSchema } from 'rehype-sanitize'; import rehypeSanitize, { defaultSchema } from 'rehype-sanitize';
@ -20,6 +19,8 @@ import { visit } from 'unist-util-visit';
import { CodeBlock } from './codeBlock'; import { CodeBlock } from './codeBlock';
import { FILE_REFERENCE_REGEX } from '@/features/chat/constants'; import { FILE_REFERENCE_REGEX } from '@/features/chat/constants';
import { createFileReference } from '@/features/chat/utils'; import { createFileReference } from '@/features/chat/utils';
import { SINGLE_TENANT_ORG_DOMAIN } from '@/lib/constants';
import isEqual from "fast-deep-equal/react";
export const REFERENCE_PAYLOAD_ATTRIBUTE = 'data-reference-payload'; export const REFERENCE_PAYLOAD_ATTRIBUTE = 'data-reference-payload';
@ -102,8 +103,7 @@ interface MarkdownRendererProps {
className?: string; className?: string;
} }
export const MarkdownRenderer = forwardRef<HTMLDivElement, MarkdownRendererProps>(({ content, className }, ref) => { const MarkdownRendererComponent = forwardRef<HTMLDivElement, MarkdownRendererProps>(({ content, className }, ref) => {
const domain = useDomain();
const router = useRouter(); const router = useRouter();
const remarkPlugins = useMemo((): PluggableList => { const remarkPlugins = useMemo((): PluggableList => {
@ -176,7 +176,7 @@ export const MarkdownRenderer = forwardRef<HTMLDivElement, MarkdownRendererProps
onClick={(e) => { onClick={(e) => {
e.preventDefault(); e.preventDefault();
e.stopPropagation(); e.stopPropagation();
const url = createPathWithQueryParams(`/${domain}/search`, [SearchQueryParams.query, `"${text}"`]) const url = createPathWithQueryParams(`/${SINGLE_TENANT_ORG_DOMAIN}/search`, [SearchQueryParams.query, `"${text}"`])
router.push(url); router.push(url);
}} }}
title="Search for snippet" title="Search for snippet"
@ -199,7 +199,7 @@ export const MarkdownRenderer = forwardRef<HTMLDivElement, MarkdownRendererProps
</span> </span>
) )
}, [domain, router]); }, [router]);
return ( return (
<div <div
@ -220,4 +220,6 @@ export const MarkdownRenderer = forwardRef<HTMLDivElement, MarkdownRendererProps
); );
}); });
MarkdownRenderer.displayName = 'MarkdownRenderer'; MarkdownRendererComponent.displayName = 'MarkdownRenderer';
export const MarkdownRenderer = memo(MarkdownRendererComponent, isEqual);

View file

@ -14,13 +14,13 @@ import { Range } from "@codemirror/state";
import { Decoration, DecorationSet, EditorView } from '@codemirror/view'; import { Decoration, DecorationSet, EditorView } from '@codemirror/view';
import CodeMirror, { ReactCodeMirrorRef, StateField } from '@uiw/react-codemirror'; import CodeMirror, { ReactCodeMirrorRef, StateField } from '@uiw/react-codemirror';
import { ChevronDown, ChevronRight } from "lucide-react"; import { ChevronDown, ChevronRight } from "lucide-react";
import { forwardRef, Ref, useCallback, useImperativeHandle, useMemo, useState } from "react"; import { forwardRef, memo, Ref, useCallback, useImperativeHandle, useMemo, useState } from "react";
import { FileReference } from "../../types"; import { FileReference } from "../../types";
import { createCodeFoldingExtension } from "./codeFoldingExtension"; import { createCodeFoldingExtension } from "./codeFoldingExtension";
import useCaptureEvent from "@/hooks/useCaptureEvent"; import useCaptureEvent from "@/hooks/useCaptureEvent";
import { createAuditAction } from "@/ee/features/audit/actions";
import { useDomain } from "@/hooks/useDomain";
import { CodeHostType } from "@sourcebot/db"; import { CodeHostType } from "@sourcebot/db";
import { createAuditAction } from "@/ee/features/audit/actions";
import isEqual from "fast-deep-equal/react";
const lineDecoration = Decoration.line({ const lineDecoration = Decoration.line({
attributes: { class: "cm-range-border-radius chat-lineHighlight" }, attributes: { class: "cm-range-border-radius chat-lineHighlight" },
@ -75,7 +75,6 @@ const ReferencedFileSourceListItem = ({
const theme = useCodeMirrorTheme(); const theme = useCodeMirrorTheme();
const [editorRef, setEditorRef] = useState<ReactCodeMirrorRef | null>(null); const [editorRef, setEditorRef] = useState<ReactCodeMirrorRef | null>(null);
const captureEvent = useCaptureEvent(); const captureEvent = useCaptureEvent();
const domain = useDomain();
useImperativeHandle( useImperativeHandle(
forwardedRef, forwardedRef,
@ -231,7 +230,7 @@ const ReferencedFileSourceListItem = ({
metadata: { metadata: {
message: symbolName, message: symbolName,
}, },
}, domain); });
if (symbolDefinitions.length === 1) { if (symbolDefinitions.length === 1) {
const symbolDefinition = symbolDefinitions[0]; const symbolDefinition = symbolDefinitions[0];
@ -263,7 +262,7 @@ const ReferencedFileSourceListItem = ({
}); });
} }
}, [captureEvent, domain, navigateToPath, revision, repoName, fileName, language]); }, [captureEvent, navigateToPath, revision, repoName, fileName, language]);
const onFindReferences = useCallback((symbolName: string) => { const onFindReferences = useCallback((symbolName: string) => {
captureEvent('wa_find_references_pressed', { captureEvent('wa_find_references_pressed', {
@ -274,7 +273,7 @@ const ReferencedFileSourceListItem = ({
metadata: { metadata: {
message: symbolName, message: symbolName,
}, },
}, domain); });
navigateToPath({ navigateToPath({
repoName, repoName,
@ -293,7 +292,7 @@ const ReferencedFileSourceListItem = ({
} }
}) })
}, [captureEvent, domain, fileName, language, navigateToPath, repoName, revision]); }, [captureEvent, fileName, language, navigateToPath, repoName, revision]);
const ExpandCollapseIcon = useMemo(() => { const ExpandCollapseIcon = useMemo(() => {
return isExpanded ? ChevronDown : ChevronRight; return isExpanded ? ChevronDown : ChevronRight;
@ -355,6 +354,6 @@ const ReferencedFileSourceListItem = ({
) )
} }
export default forwardRef(ReferencedFileSourceListItem) as ( export default memo(forwardRef(ReferencedFileSourceListItem), isEqual) as (
props: ReferencedFileSourceListItemProps & { ref?: Ref<ReactCodeMirrorRef> }, props: ReferencedFileSourceListItemProps & { ref?: Ref<ReactCodeMirrorRef> },
) => ReturnType<typeof ReferencedFileSourceListItem>; ) => ReturnType<typeof ReferencedFileSourceListItem>;

View file

@ -4,18 +4,19 @@ import { getFileSource } from "@/app/api/(client)/client";
import { VscodeFileIcon } from "@/app/components/vscodeFileIcon"; import { VscodeFileIcon } from "@/app/components/vscodeFileIcon";
import { ScrollArea } from "@/components/ui/scroll-area"; import { ScrollArea } from "@/components/ui/scroll-area";
import { Skeleton } from "@/components/ui/skeleton"; import { Skeleton } from "@/components/ui/skeleton";
import { useDomain } from "@/hooks/useDomain";
import { isServiceError, unwrapServiceError } from "@/lib/utils"; import { isServiceError, unwrapServiceError } from "@/lib/utils";
import { useQueries } from "@tanstack/react-query"; import { useQueries } from "@tanstack/react-query";
import { ReactCodeMirrorRef } from '@uiw/react-codemirror'; import { ReactCodeMirrorRef } from '@uiw/react-codemirror';
import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { memo, useCallback, useEffect, useMemo, useRef, useState } from "react";
import scrollIntoView from 'scroll-into-view-if-needed'; import scrollIntoView from 'scroll-into-view-if-needed';
import { FileReference, FileSource, Reference, Source } from "../../types"; import { FileReference, FileSource, Reference } from "../../types";
import { tryResolveFileReference } from '../../utils';
import ReferencedFileSourceListItem from "./referencedFileSourceListItem"; import ReferencedFileSourceListItem from "./referencedFileSourceListItem";
import isEqual from 'fast-deep-equal/react';
interface ReferencedSourcesListViewProps { interface ReferencedSourcesListViewProps {
references: FileReference[]; references: FileReference[];
sources: Source[]; sources: FileSource[];
index: number; index: number;
hoveredReference?: Reference; hoveredReference?: Reference;
onHoveredReferenceChanged: (reference?: Reference) => void; onHoveredReferenceChanged: (reference?: Reference) => void;
@ -24,14 +25,7 @@ interface ReferencedSourcesListViewProps {
style: React.CSSProperties; style: React.CSSProperties;
} }
const resolveFileReference = (reference: FileReference, sources: FileSource[]): FileSource | undefined => { const ReferencedSourcesListViewComponent = ({
return sources.find(
(source) => source.repo.endsWith(reference.repo) &&
source.path.endsWith(reference.path)
);
}
export const ReferencedSourcesListView = ({
references, references,
sources, sources,
index, index,
@ -43,7 +37,6 @@ export const ReferencedSourcesListView = ({
}: ReferencedSourcesListViewProps) => { }: ReferencedSourcesListViewProps) => {
const scrollAreaRef = useRef<HTMLDivElement>(null); const scrollAreaRef = useRef<HTMLDivElement>(null);
const editorRefsMap = useRef<Map<string, ReactCodeMirrorRef>>(new Map()); const editorRefsMap = useRef<Map<string, ReactCodeMirrorRef>>(new Map());
const domain = useDomain();
const [collapsedFileIds, setCollapsedFileIds] = useState<string[]>([]); const [collapsedFileIds, setCollapsedFileIds] = useState<string[]>([]);
const getFileId = useCallback((fileSource: FileSource) => { const getFileId = useCallback((fileSource: FileSource) => {
@ -61,44 +54,27 @@ export const ReferencedSourcesListView = ({
} }
}, []); }, []);
const referencedFileSources = useMemo((): FileSource[] => {
const fileSources = sources.filter((source) => source.type === 'file');
return references
.filter((reference) => reference.type === 'file')
.map((reference) => resolveFileReference(reference, fileSources))
.filter((file) => file !== undefined)
// de-duplicate files
.filter((file, index, self) =>
index === self.findIndex((t) =>
t?.path === file?.path
&& t?.repo === file?.repo
&& t?.revision === file?.revision
)
);
}, [references, sources]);
// Memoize the computation of references grouped by file source // Memoize the computation of references grouped by file source
const referencesGroupedByFile = useMemo(() => { const referencesGroupedByFile = useMemo(() => {
const groupedReferences = new Map<string, FileReference[]>(); const groupedReferences = new Map<string, FileReference[]>();
for (const fileSource of referencedFileSources) { for (const fileSource of sources) {
const fileKey = getFileId(fileSource); const fileKey = getFileId(fileSource);
const referencesInFile = references.filter((reference) => { const referencesInFile = references.filter((reference) => {
if (reference.type !== 'file') { if (reference.type !== 'file') {
return false; return false;
} }
return resolveFileReference(reference, [fileSource]) !== undefined; return tryResolveFileReference(reference, [fileSource]) !== undefined;
}); });
groupedReferences.set(fileKey, referencesInFile); groupedReferences.set(fileKey, referencesInFile);
} }
return groupedReferences; return groupedReferences;
}, [references, referencedFileSources, getFileId]); }, [references, sources, getFileId]);
const fileSourceQueries = useQueries({ const fileSourceQueries = useQueries({
queries: referencedFileSources.map((file) => ({ queries: sources.map((file) => ({
queryKey: ['fileSource', file.path, file.repo, file.revision, domain], queryKey: ['fileSource', file.path, file.repo, file.revision],
queryFn: () => unwrapServiceError(getFileSource({ queryFn: () => unwrapServiceError(getFileSource({
fileName: file.path, fileName: file.path,
repository: file.repo, repository: file.repo,
@ -114,7 +90,7 @@ export const ReferencedSourcesListView = ({
return; return;
} }
const fileSource = resolveFileReference(selectedReference, referencedFileSources); const fileSource = tryResolveFileReference(selectedReference, sources);
if (!fileSource) { if (!fileSource) {
return; return;
} }
@ -181,9 +157,28 @@ export const ReferencedSourcesListView = ({
behavior: 'smooth', behavior: 'smooth',
}); });
} }
}, [getFileId, referencedFileSources, selectedReference]); }, [getFileId, sources, selectedReference]);
if (referencedFileSources.length === 0) { const onExpandedChanged = useCallback((fileId: string, isExpanded: boolean) => {
if (isExpanded) {
setCollapsedFileIds(collapsedFileIds => collapsedFileIds.filter((id) => id !== fileId));
} else {
setCollapsedFileIds(collapsedFileIds => [...collapsedFileIds, fileId]);
}
if (!isExpanded) {
const fileSourceStart = document.getElementById(`${fileId}-start`);
if (fileSourceStart) {
scrollIntoView(fileSourceStart, {
scrollMode: 'if-needed',
block: 'start',
behavior: 'instant',
});
}
}
}, []);
if (sources.length === 0) {
return ( return (
<div className="p-4 text-center text-muted-foreground text-sm"> <div className="p-4 text-center text-muted-foreground text-sm">
No file references found No file references found
@ -198,7 +193,7 @@ export const ReferencedSourcesListView = ({
> >
<div className="space-y-4 pr-2"> <div className="space-y-4 pr-2">
{fileSourceQueries.map((query, index) => { {fileSourceQueries.map((query, index) => {
const fileSource = referencedFileSources[index]; const fileSource = sources[index];
const fileName = fileSource.path.split('/').pop() ?? fileSource.path; const fileName = fileSource.path.split('/').pop() ?? fileSource.path;
if (query.isLoading) { if (query.isLoading) {
@ -253,30 +248,7 @@ export const ReferencedSourcesListView = ({
selectedReference={selectedReference} selectedReference={selectedReference}
hoveredReference={hoveredReference} hoveredReference={hoveredReference}
isExpanded={!collapsedFileIds.includes(fileId)} isExpanded={!collapsedFileIds.includes(fileId)}
onExpandedChanged={(isExpanded) => { onExpandedChanged={(isExpanded) => onExpandedChanged(fileId, isExpanded)}
if (isExpanded) {
setCollapsedFileIds(collapsedFileIds.filter((id) => id !== fileId));
} else {
setCollapsedFileIds([...collapsedFileIds, fileId]);
}
// When collapsing a file when you are deep in a scroll, it's a better
// experience to have the scroll automatically restored to the top of the file
// s.t., header is still sticky to the top of the scroll area.
if (!isExpanded) {
const fileSourceStart = document.getElementById(`${fileId}-start`);
if (!fileSourceStart) {
return;
}
scrollIntoView(fileSourceStart, {
scrollMode: 'if-needed',
block: 'start',
behavior: 'instant',
});
}
}
}
/> />
); );
})} })}
@ -284,3 +256,6 @@ export const ReferencedSourcesListView = ({
</ScrollArea> </ScrollArea>
); );
} }
// Memoize to prevent unnecessary re-renders
export const ReferencedSourcesListView = memo(ReferencedSourcesListViewComponent, isEqual);

View file

@ -1,7 +1,6 @@
'use client'; 'use client';
import { SearchCodeToolUIPart } from "@/features/chat/tools"; import { SearchCodeToolUIPart } from "@/features/chat/tools";
import { useDomain } from "@/hooks/useDomain";
import { createPathWithQueryParams, isServiceError } from "@/lib/utils"; import { createPathWithQueryParams, isServiceError } from "@/lib/utils";
import { useMemo, useState } from "react"; import { useMemo, useState } from "react";
import { FileListItem, ToolHeader, TreeList } from "./shared"; import { FileListItem, ToolHeader, TreeList } from "./shared";
@ -12,10 +11,10 @@ import Link from "next/link";
import { SearchQueryParams } from "@/lib/types"; import { SearchQueryParams } from "@/lib/types";
import { PlayIcon } from "@radix-ui/react-icons"; import { PlayIcon } from "@radix-ui/react-icons";
import { buildSearchQuery } from "@/features/chat/utils"; import { buildSearchQuery } from "@/features/chat/utils";
import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants";
export const SearchCodeToolComponent = ({ part }: { part: SearchCodeToolUIPart }) => { export const SearchCodeToolComponent = ({ part }: { part: SearchCodeToolUIPart }) => {
const [isExpanded, setIsExpanded] = useState(false); const [isExpanded, setIsExpanded] = useState(false);
const domain = useDomain();
const displayQuery = useMemo(() => { const displayQuery = useMemo(() => {
if (part.state !== 'input-available' && part.state !== 'output-available') { if (part.state !== 'input-available' && part.state !== 'output-available') {
@ -78,7 +77,7 @@ export const SearchCodeToolComponent = ({ part }: { part: SearchCodeToolUIPart }
</TreeList> </TreeList>
)} )}
<Link <Link
href={createPathWithQueryParams(`/${domain}/search`, href={createPathWithQueryParams(`/${SINGLE_TENANT_ORG_DOMAIN}/search`,
[SearchQueryParams.query, part.output.query], [SearchQueryParams.query, part.output.query],
)} )}
className='flex flex-row items-center gap-2 text-sm text-muted-foreground mt-2 ml-auto w-fit hover:text-foreground' className='flex flex-row items-center gap-2 text-sm text-muted-foreground mt-2 ml-auto w-fit hover:text-foreground'

View file

@ -2,12 +2,12 @@
import { VscodeFileIcon } from '@/app/components/vscodeFileIcon'; import { VscodeFileIcon } from '@/app/components/vscodeFileIcon';
import { ScrollArea } from '@/components/ui/scroll-area'; import { ScrollArea } from '@/components/ui/scroll-area';
import { useDomain } from '@/hooks/useDomain';
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
import { ChevronDown, ChevronRight, Loader2 } from 'lucide-react'; import { ChevronDown, ChevronRight, Loader2 } from 'lucide-react';
import Link from 'next/link'; import Link from 'next/link';
import React from 'react'; import React from 'react';
import { getBrowsePath } from "@/app/[domain]/browse/hooks/utils"; import { getBrowsePath } from "@/app/[domain]/browse/hooks/utils";
import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants";
export const FileListItem = ({ export const FileListItem = ({
@ -17,8 +17,6 @@ export const FileListItem = ({
path: string, path: string,
repoName: string, repoName: string,
}) => { }) => {
const domain = useDomain();
return ( return (
<div key={path} className="flex flex-row items-center overflow-hidden hover:bg-accent hover:text-accent-foreground rounded-sm cursor-pointer p-0.5"> <div key={path} className="flex flex-row items-center overflow-hidden hover:bg-accent hover:text-accent-foreground rounded-sm cursor-pointer p-0.5">
<VscodeFileIcon fileName={path} className="mr-1 flex-shrink-0" /> <VscodeFileIcon fileName={path} className="mr-1 flex-shrink-0" />
@ -28,7 +26,7 @@ export const FileListItem = ({
repoName, repoName,
revisionName: 'HEAD', revisionName: 'HEAD',
path, path,
domain, domain: SINGLE_TENANT_ORG_DOMAIN,
pathType: 'blob', pathType: 'blob',
})} })}
> >

View file

@ -70,7 +70,7 @@ export const sbChatMessageMetadataSchema = z.object({
feedback: z.array(z.object({ feedback: z.array(z.object({
type: z.enum(['like', 'dislike']), type: z.enum(['like', 'dislike']),
timestamp: z.string(), // ISO date string timestamp: z.string(), // ISO date string
userId: z.string(), userId: z.string().optional(),
})).optional(), })).optional(),
selectedSearchScopes: z.array(searchScopeSchema).optional(), selectedSearchScopes: z.array(searchScopeSchema).optional(),
traceId: z.string().optional(), traceId: z.string().optional(),

View file

@ -4,7 +4,6 @@ import { useCallback, useState } from "react";
import { Descendant } from "slate"; import { Descendant } from "slate";
import { createUIMessage, getAllMentionElements } from "./utils"; import { createUIMessage, getAllMentionElements } from "./utils";
import { slateContentToString } from "./utils"; import { slateContentToString } from "./utils";
import { useDomain } from "@/hooks/useDomain";
import { useToast } from "@/components/hooks/use-toast"; import { useToast } from "@/components/hooks/use-toast";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { createChat } from "./actions"; import { createChat } from "./actions";
@ -13,9 +12,9 @@ import { createPathWithQueryParams } from "@/lib/utils";
import { SearchScope, SET_CHAT_STATE_SESSION_STORAGE_KEY, SetChatStatePayload } from "./types"; import { SearchScope, SET_CHAT_STATE_SESSION_STORAGE_KEY, SetChatStatePayload } from "./types";
import { useSessionStorage } from "usehooks-ts"; import { useSessionStorage } from "usehooks-ts";
import useCaptureEvent from "@/hooks/useCaptureEvent"; import useCaptureEvent from "@/hooks/useCaptureEvent";
import { SINGLE_TENANT_ORG_DOMAIN } from "@/lib/constants";
export const useCreateNewChatThread = () => { export const useCreateNewChatThread = () => {
const domain = useDomain();
const [isLoading, setIsLoading] = useState(false); const [isLoading, setIsLoading] = useState(false);
const { toast } = useToast(); const { toast } = useToast();
const router = useRouter(); const router = useRouter();
@ -31,7 +30,7 @@ export const useCreateNewChatThread = () => {
const inputMessage = createUIMessage(text, mentions.map((mention) => mention.data), selectedSearchScopes); const inputMessage = createUIMessage(text, mentions.map((mention) => mention.data), selectedSearchScopes);
setIsLoading(true); setIsLoading(true);
const response = await createChat(domain); const response = await createChat();
if (isServiceError(response)) { if (isServiceError(response)) {
toast({ toast({
description: `❌ Failed to create chat. Reason: ${response.message}` description: `❌ Failed to create chat. Reason: ${response.message}`
@ -47,11 +46,11 @@ export const useCreateNewChatThread = () => {
selectedSearchScopes, selectedSearchScopes,
}); });
const url = createPathWithQueryParams(`/${domain}/chat/${response.id}`); const url = createPathWithQueryParams(`/${SINGLE_TENANT_ORG_DOMAIN}/chat/${response.id}`);
router.push(url); router.push(url);
router.refresh(); router.refresh();
}, [domain, router, toast, setChatState, captureEvent]); }, [router, toast, setChatState, captureEvent]);
return { return {
createNewChatThread, createNewChatThread,

View file

@ -374,3 +374,13 @@ export const buildSearchQuery = (options: {
export const getLanguageModelKey = (model: LanguageModelInfo) => { export const getLanguageModelKey = (model: LanguageModelInfo) => {
return `${model.provider}-${model.model}-${model.displayName}`; return `${model.provider}-${model.model}-${model.displayName}`;
} }
/**
* Given a file reference and a list of file sources, attempts to resolve the file source that the reference points to.
*/
export const tryResolveFileReference = (reference: FileReference, sources: FileSource[]): FileSource | undefined => {
return sources.find(
(source) => source.repo.endsWith(reference.repo) &&
source.path.endsWith(reference.path)
);
}

View file

@ -6,6 +6,7 @@ import { search } from "./searchApi";
import { sew } from "@/actions"; import { sew } from "@/actions";
import { withOptionalAuthV2 } from "@/withAuthV2"; import { withOptionalAuthV2 } from "@/withAuthV2";
import { QueryIR } from './ir'; import { QueryIR } from './ir';
// @todo (bkellam) #574 : We should really be using `git show <hash>:<path>` to fetch file contents here. // @todo (bkellam) #574 : We should really be using `git show <hash>:<path>` to fetch file contents here.
// This will allow us to support permalinks to files at a specific revision that may not be indexed // This will allow us to support permalinks to files at a specific revision that may not be indexed
// by zoekt. We should also refactor this out of the /search folder. // by zoekt. We should also refactor this out of the /search folder.
@ -21,12 +22,12 @@ export const getFileSource = async ({ fileName, repository, branch }: FileSource
}, },
}, },
{ {
regexp: { substring: {
regexp: fileName, pattern: fileName,
case_sensitive: true, case_sensitive: true,
file_name: true, file_name: true,
content: false content: false,
}, }
}, },
...(branch ? [{ ...(branch ? [{
branch: { branch: {

View file

@ -8197,6 +8197,7 @@ __metadata:
eslint-config-next: "npm:15.5.0" eslint-config-next: "npm:15.5.0"
eslint-plugin-react: "npm:^7.37.5" eslint-plugin-react: "npm:^7.37.5"
eslint-plugin-react-hooks: "npm:^5.2.0" eslint-plugin-react-hooks: "npm:^5.2.0"
fast-deep-equal: "npm:^3.1.3"
fuse.js: "npm:^7.0.0" fuse.js: "npm:^7.0.0"
google-auth-library: "npm:^10.1.0" google-auth-library: "npm:^10.1.0"
graphql: "npm:^16.9.0" graphql: "npm:^16.9.0"