From c9e864d53a5df96675ad3159ac22f8b3bb121866 Mon Sep 17 00:00:00 2001 From: Brendan Kellam Date: Mon, 15 Sep 2025 23:13:29 -0700 Subject: [PATCH] fix(web): Fix carousel perf issue + improvements to withAuth middleware (#507) --- CHANGELOG.md | 1 + packages/mcp/src/index.ts | 4 +- packages/mcp/src/schemas.ts | 32 +- packages/mcp/src/types.ts | 1 - packages/web/package.json | 5 +- packages/web/src/__mocks__/prisma.ts | 48 ++ packages/web/src/actions.ts | 75 +- .../web/src/app/[domain]/chat/[id]/page.tsx | 2 +- packages/web/src/app/[domain]/chat/page.tsx | 2 +- .../[domain]/components/errorNavIndicator.tsx | 4 +- .../homepage/repositorySnapshot.tsx | 10 +- .../components/progressNavIndicator.tsx | 6 +- .../searchBar/useSuggestionsData.ts | 10 +- .../connections/[id]/components/repoList.tsx | 2 +- packages/web/src/app/[domain]/page.tsx | 2 +- .../app/[domain]/repos/repositoryTable.tsx | 6 +- packages/web/src/app/api/(client)/client.ts | 11 +- .../web/src/app/api/(server)/repos/route.ts | 23 +- packages/web/src/features/chat/tools.ts | 4 +- .../web/src/features/search/listReposApi.ts | 49 -- packages/web/src/features/search/schemas.ts | 11 - packages/web/src/features/search/types.ts | 4 - packages/web/src/lib/schemas.ts | 5 +- packages/web/src/lib/types.ts | 5 +- packages/web/src/withAuthV2.test.ts | 733 ++++++++++++++++++ packages/web/src/withAuthV2.ts | 196 +++++ yarn.lock | 25 + 27 files changed, 1112 insertions(+), 164 deletions(-) create mode 100644 packages/web/src/__mocks__/prisma.ts delete mode 100644 packages/web/src/features/search/listReposApi.ts create mode 100644 packages/web/src/withAuthV2.test.ts create mode 100644 packages/web/src/withAuthV2.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index d75c99e2..13defaca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fixed Bitbucket Cloud pagination not working beyond first page. [#295](https://github.com/sourcebot-dev/sourcebot/issues/295) - Fixed search bar line wrapping. [#501](https://github.com/sourcebot-dev/sourcebot/pull/501) +- Fixed carousel perf issues. [#507](https://github.com/sourcebot-dev/sourcebot/pull/507) ## [4.6.7] - 2025-09-08 diff --git a/packages/mcp/src/index.ts b/packages/mcp/src/index.ts index 21c52b78..4411b580 100644 --- a/packages/mcp/src/index.ts +++ b/packages/mcp/src/index.ts @@ -161,10 +161,10 @@ server.tool( }; } - const content: TextContent[] = response.repos.map(repo => { + const content: TextContent[] = response.map(repo => { return { type: "text", - text: `id: ${repo.name}\nurl: ${repo.webUrl}`, + text: `id: ${repo.repoName}\nurl: ${repo.webUrl}`, } }); diff --git a/packages/mcp/src/schemas.ts b/packages/mcp/src/schemas.ts index ad94b7dc..75413cce 100644 --- a/packages/mcp/src/schemas.ts +++ b/packages/mcp/src/schemas.ts @@ -92,16 +92,34 @@ export const searchResponseSchema = z.object({ isBranchFilteringEnabled: z.boolean(), }); -export const repositorySchema = z.object({ - name: z.string(), - branches: z.array(z.string()), +enum RepoIndexingStatus { + NEW = 'NEW', + IN_INDEX_QUEUE = 'IN_INDEX_QUEUE', + INDEXING = 'INDEXING', + INDEXED = 'INDEXED', + FAILED = 'FAILED', + IN_GC_QUEUE = 'IN_GC_QUEUE', + GARBAGE_COLLECTING = 'GARBAGE_COLLECTING', + GARBAGE_COLLECTION_FAILED = 'GARBAGE_COLLECTION_FAILED' +} + +export const repositoryQuerySchema = z.object({ + codeHostType: z.string(), + repoId: z.number(), + repoName: z.string(), + repoDisplayName: z.string().optional(), + repoCloneUrl: z.string(), webUrl: z.string().optional(), - rawConfig: z.record(z.string(), z.string()).optional(), + linkedConnections: z.array(z.object({ + id: z.number(), + name: z.string(), + })), + imageUrl: z.string().optional(), + indexedAt: z.coerce.date().optional(), + repoIndexingStatus: z.nativeEnum(RepoIndexingStatus), }); -export const listRepositoriesResponseSchema = z.object({ - repos: z.array(repositorySchema), -}); +export const listRepositoriesResponseSchema = repositoryQuerySchema.array(); export const fileSourceRequestSchema = z.object({ fileName: z.string(), diff --git a/packages/mcp/src/types.ts b/packages/mcp/src/types.ts index f789c8c1..9c858fe5 100644 --- a/packages/mcp/src/types.ts +++ b/packages/mcp/src/types.ts @@ -22,7 +22,6 @@ export type SearchResultChunk = SearchResultFile["chunks"][number]; export type SearchSymbol = z.infer; export type ListRepositoriesResponse = z.infer; -export type Repository = ListRepositoriesResponse["repos"][number]; export type FileSourceRequest = z.infer; export type FileSourceResponse = z.infer; diff --git a/packages/web/package.json b/packages/web/package.json index 279ffa42..ac9e9d95 100644 --- a/packages/web/package.json +++ b/packages/web/package.json @@ -7,7 +7,7 @@ "build": "cross-env SKIP_ENV_VALIDATION=1 next build", "start": "next start", "lint": "cross-env SKIP_ENV_VALIDATION=1 eslint .", - "test": "vitest", + "test": "cross-env SKIP_ENV_VALIDATION=1 vitest", "dev:emails": "email dev --dir ./src/emails", "stripe:listen": "stripe listen --forward-to http://localhost:3000/api/stripe" }, @@ -212,7 +212,8 @@ "tsx": "^4.19.2", "typescript": "^5", "vite-tsconfig-paths": "^5.1.3", - "vitest": "^2.1.5" + "vitest": "^2.1.5", + "vitest-mock-extended": "^3.1.0" }, "resolutions": { "@types/react": "19.1.10", diff --git a/packages/web/src/__mocks__/prisma.ts b/packages/web/src/__mocks__/prisma.ts new file mode 100644 index 00000000..66470017 --- /dev/null +++ b/packages/web/src/__mocks__/prisma.ts @@ -0,0 +1,48 @@ +import { SINGLE_TENANT_ORG_DOMAIN, SINGLE_TENANT_ORG_ID, SINGLE_TENANT_ORG_NAME } from '@/lib/constants'; +import { ApiKey, Org, PrismaClient, User } from '@prisma/client'; +import { beforeEach } from 'vitest'; +import { mockDeep, mockReset } from 'vitest-mock-extended'; + +beforeEach(() => { + mockReset(prisma); +}); + +export const prisma = mockDeep(); + +export const MOCK_ORG: Org = { + id: SINGLE_TENANT_ORG_ID, + name: SINGLE_TENANT_ORG_NAME, + domain: SINGLE_TENANT_ORG_DOMAIN, + createdAt: new Date(), + updatedAt: new Date(), + isOnboarded: true, + imageUrl: null, + metadata: null, + memberApprovalRequired: false, + stripeCustomerId: null, + stripeSubscriptionStatus: null, + stripeLastUpdatedAt: null, + inviteLinkEnabled: false, + inviteLinkId: null +} + +export const MOCK_API_KEY: ApiKey = { + name: 'Test API Key', + hash: 'apikey', + createdAt: new Date(), + lastUsedAt: new Date(), + orgId: 1, + createdById: '1', +} + +export const MOCK_USER: User = { + id: '1', + name: 'Test User', + email: 'test@test.com', + createdAt: new Date(), + updatedAt: new Date(), + hashedPassword: null, + emailVerified: null, + image: null +} + diff --git a/packages/web/src/actions.ts b/packages/web/src/actions.ts index b23f3ae2..80094989 100644 --- a/packages/web/src/actions.ts +++ b/packages/web/src/actions.ts @@ -40,6 +40,7 @@ import { getAuditService } from "@/ee/features/audit/factory"; import { addUserToOrganization, orgHasAvailability } from "@/lib/authUtils"; import { getOrgMetadata } from "@/lib/utils"; import { getOrgFromDomain } from "./data/org"; +import { withOptionalAuthV2 } from "./withAuthV2"; const ajv = new Ajv({ validateFormats: false, @@ -637,49 +638,47 @@ export const getConnectionInfo = async (connectionId: number, domain: string) => } }))); -export const getRepos = async (domain: string, filter: { status?: RepoIndexingStatus[], connectionId?: number } = {}) => sew(() => - withAuth((userId) => - withOrgMembership(userId, domain, async ({ org }) => { - const repos = await prisma.repo.findMany({ - where: { - orgId: org.id, - ...(filter.status ? { - repoIndexingStatus: { in: filter.status } - } : {}), - ...(filter.connectionId ? { - connections: { - some: { - connectionId: filter.connectionId - } - } - } : {}), - }, - include: { +export const getRepos = async (filter: { status?: RepoIndexingStatus[], connectionId?: number } = {}) => sew(() => + withOptionalAuthV2(async ({ org }) => { + const repos = await prisma.repo.findMany({ + where: { + orgId: org.id, + ...(filter.status ? { + repoIndexingStatus: { in: filter.status } + } : {}), + ...(filter.connectionId ? { connections: { - include: { - connection: true, + some: { + connectionId: filter.connectionId } } + } : {}), + }, + include: { + connections: { + include: { + connection: true, + } } - }); + } + }); - return repos.map((repo) => repositoryQuerySchema.parse({ - codeHostType: repo.external_codeHostType, - repoId: repo.id, - repoName: repo.name, - repoDisplayName: repo.displayName ?? undefined, - repoCloneUrl: repo.cloneUrl, - webUrl: repo.webUrl ?? undefined, - linkedConnections: repo.connections.map(({ connection }) => ({ - id: connection.id, - name: connection.name, - })), - imageUrl: repo.imageUrl ?? undefined, - indexedAt: repo.indexedAt ?? undefined, - repoIndexingStatus: repo.repoIndexingStatus, - })); - }, /* minRequiredRole = */ OrgRole.GUEST), /* allowAnonymousAccess = */ true - )); + return repos.map((repo) => repositoryQuerySchema.parse({ + codeHostType: repo.external_codeHostType, + repoId: repo.id, + repoName: repo.name, + repoDisplayName: repo.displayName ?? undefined, + repoCloneUrl: repo.cloneUrl, + webUrl: repo.webUrl ?? undefined, + linkedConnections: repo.connections.map(({ connection }) => ({ + id: connection.id, + name: connection.name, + })), + imageUrl: repo.imageUrl ?? undefined, + indexedAt: repo.indexedAt ?? undefined, + repoIndexingStatus: repo.repoIndexingStatus, + })) + })); export const getRepoInfoByName = async (repoName: string, domain: string) => sew(() => withAuth((userId) => diff --git a/packages/web/src/app/[domain]/chat/[id]/page.tsx b/packages/web/src/app/[domain]/chat/[id]/page.tsx index 9c59a242..d37076cd 100644 --- a/packages/web/src/app/[domain]/chat/[id]/page.tsx +++ b/packages/web/src/app/[domain]/chat/[id]/page.tsx @@ -22,7 +22,7 @@ interface PageProps { export default async function Page(props: PageProps) { const params = await props.params; const languageModels = await getConfiguredLanguageModelsInfo(); - const repos = await getRepos(params.domain); + const repos = await getRepos(); const searchContexts = await getSearchContexts(params.domain); const chatInfo = await getChatInfo({ chatId: params.id }, params.domain); const session = await auth(); diff --git a/packages/web/src/app/[domain]/chat/page.tsx b/packages/web/src/app/[domain]/chat/page.tsx index 74950101..8bdddf9e 100644 --- a/packages/web/src/app/[domain]/chat/page.tsx +++ b/packages/web/src/app/[domain]/chat/page.tsx @@ -18,7 +18,7 @@ interface PageProps { export default async function Page(props: PageProps) { const params = await props.params; const languageModels = await getConfiguredLanguageModelsInfo(); - const repos = await getRepos(params.domain); + const repos = await getRepos(); const searchContexts = await getSearchContexts(params.domain); const session = await auth(); const chatHistory = session ? await getUserChatHistory(params.domain) : []; diff --git a/packages/web/src/app/[domain]/components/errorNavIndicator.tsx b/packages/web/src/app/[domain]/components/errorNavIndicator.tsx index 2f024a61..4c1bd6b0 100644 --- a/packages/web/src/app/[domain]/components/errorNavIndicator.tsx +++ b/packages/web/src/app/[domain]/components/errorNavIndicator.tsx @@ -10,8 +10,8 @@ import { env } from "@/env.mjs"; import { useQuery } from "@tanstack/react-query"; import { ConnectionSyncStatus, RepoIndexingStatus } from "@sourcebot/db"; import { getConnections } from "@/actions"; -import { getRepos } from "@/actions"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; +import { getRepos } from "@/app/api/(client)/client"; export const ErrorNavIndicator = () => { const domain = useDomain(); @@ -19,7 +19,7 @@ export const ErrorNavIndicator = () => { const { data: repos, isPending: isPendingRepos, isError: isErrorRepos } = useQuery({ queryKey: ['repos', domain], - queryFn: () => unwrapServiceError(getRepos(domain)), + queryFn: () => unwrapServiceError(getRepos()), select: (data) => data.filter(repo => repo.repoIndexingStatus === RepoIndexingStatus.FAILED), refetchInterval: env.NEXT_PUBLIC_POLLING_INTERVAL_MS, }); diff --git a/packages/web/src/app/[domain]/components/homepage/repositorySnapshot.tsx b/packages/web/src/app/[domain]/components/homepage/repositorySnapshot.tsx index 6248a659..d6845fa2 100644 --- a/packages/web/src/app/[domain]/components/homepage/repositorySnapshot.tsx +++ b/packages/web/src/app/[domain]/components/homepage/repositorySnapshot.tsx @@ -5,7 +5,7 @@ import { RepositoryCarousel } from "./repositoryCarousel"; import { useDomain } from "@/hooks/useDomain"; import { useQuery } from "@tanstack/react-query"; import { unwrapServiceError } from "@/lib/utils"; -import { getRepos } from "@/actions"; +import { getRepos } from "@/app/api/(client)/client"; import { env } from "@/env.mjs"; import { Skeleton } from "@/components/ui/skeleton"; import { @@ -22,6 +22,8 @@ interface RepositorySnapshotProps { repos: RepositoryQuery[]; } +const MAX_REPOS_TO_DISPLAY_IN_CAROUSEL = 15; + export function RepositorySnapshot({ repos: initialRepos, }: RepositorySnapshotProps) { @@ -29,7 +31,7 @@ export function RepositorySnapshot({ const { data: repos, isPending, isError } = useQuery({ queryKey: ['repos', domain], - queryFn: () => unwrapServiceError(getRepos(domain)), + queryFn: () => unwrapServiceError(getRepos()), refetchInterval: env.NEXT_PUBLIC_POLLING_INTERVAL_MS, placeholderData: initialRepos, }); @@ -78,7 +80,9 @@ export function RepositorySnapshot({ {` indexed`} - + {process.env.NEXT_PUBLIC_SOURCEBOT_CLOUD_ENVIRONMENT === "demo" && (

Interested in using Sourcebot on your code? Check out our{' '} diff --git a/packages/web/src/app/[domain]/components/progressNavIndicator.tsx b/packages/web/src/app/[domain]/components/progressNavIndicator.tsx index 7d79d7bc..68a8611e 100644 --- a/packages/web/src/app/[domain]/components/progressNavIndicator.tsx +++ b/packages/web/src/app/[domain]/components/progressNavIndicator.tsx @@ -1,6 +1,5 @@ "use client"; -import { getRepos } from "@/actions"; import { HoverCard, HoverCardContent, HoverCardTrigger } from "@/components/ui/hover-card"; import useCaptureEvent from "@/hooks/useCaptureEvent"; import { useDomain } from "@/hooks/useDomain"; @@ -10,14 +9,15 @@ import { RepoIndexingStatus } from "@prisma/client"; import { useQuery } from "@tanstack/react-query"; import { Loader2Icon } from "lucide-react"; import Link from "next/link"; +import { getRepos } from "@/app/api/(client)/client"; export const ProgressNavIndicator = () => { const domain = useDomain(); const captureEvent = useCaptureEvent(); const { data: inProgressRepos, isPending, isError } = useQuery({ - queryKey: ['repos', domain], - queryFn: () => unwrapServiceError(getRepos(domain)), + queryKey: ['repos'], + queryFn: () => unwrapServiceError(getRepos()), select: (data) => data.filter(repo => repo.repoIndexingStatus === RepoIndexingStatus.IN_INDEX_QUEUE || repo.repoIndexingStatus === RepoIndexingStatus.INDEXING), refetchInterval: env.NEXT_PUBLIC_POLLING_INTERVAL_MS, }); diff --git a/packages/web/src/app/[domain]/components/searchBar/useSuggestionsData.ts b/packages/web/src/app/[domain]/components/searchBar/useSuggestionsData.ts index 04c2514d..a1f9c453 100644 --- a/packages/web/src/app/[domain]/components/searchBar/useSuggestionsData.ts +++ b/packages/web/src/app/[domain]/components/searchBar/useSuggestionsData.ts @@ -19,7 +19,7 @@ import { VscSymbolVariable } from "react-icons/vsc"; import { useSearchHistory } from "@/hooks/useSearchHistory"; -import { getDisplayTime, isServiceError } from "@/lib/utils"; +import { getDisplayTime, isServiceError, unwrapServiceError } from "@/lib/utils"; import { useDomain } from "@/hooks/useDomain"; @@ -37,12 +37,12 @@ export const useSuggestionsData = ({ }: Props) => { const domain = useDomain(); const { data: repoSuggestions, isLoading: _isLoadingRepos } = useQuery({ - queryKey: ["repoSuggestions", domain], - queryFn: () => getRepos(domain), + queryKey: ["repoSuggestions"], + queryFn: () => unwrapServiceError(getRepos()), select: (data): Suggestion[] => { - return data.repos + return data .map(r => ({ - value: r.name, + value: r.repoName, })); }, enabled: suggestionMode === "repo", diff --git a/packages/web/src/app/[domain]/connections/[id]/components/repoList.tsx b/packages/web/src/app/[domain]/connections/[id]/components/repoList.tsx index a1e49637..3e91443e 100644 --- a/packages/web/src/app/[domain]/connections/[id]/components/repoList.tsx +++ b/packages/web/src/app/[domain]/connections/[id]/components/repoList.tsx @@ -64,7 +64,7 @@ export const RepoList = ({ connectionId }: RepoListProps) => { const { data: unfilteredRepos, isPending: isReposPending, error: reposError, refetch: refetchRepos } = useQuery({ queryKey: ['repos', domain, connectionId], queryFn: async () => { - const repos = await unwrapServiceError(getRepos(domain, { connectionId })); + const repos = await unwrapServiceError(getRepos({ connectionId })); return repos.sort((a, b) => { const priorityA = getPriority(a.repoIndexingStatus); const priorityB = getPriority(b.repoIndexingStatus); diff --git a/packages/web/src/app/[domain]/page.tsx b/packages/web/src/app/[domain]/page.tsx index 09bf65a9..5f172bc0 100644 --- a/packages/web/src/app/[domain]/page.tsx +++ b/packages/web/src/app/[domain]/page.tsx @@ -30,7 +30,7 @@ export default async function Home(props: { params: Promise<{ domain: string }> const session = await auth(); const models = await getConfiguredLanguageModelsInfo(); - const repos = await getRepos(domain); + const repos = await getRepos(); const searchContexts = await getSearchContexts(domain); const chatHistory = session ? await getUserChatHistory(domain) : []; diff --git a/packages/web/src/app/[domain]/repos/repositoryTable.tsx b/packages/web/src/app/[domain]/repos/repositoryTable.tsx index 9ccf7298..45a20100 100644 --- a/packages/web/src/app/[domain]/repos/repositoryTable.tsx +++ b/packages/web/src/app/[domain]/repos/repositoryTable.tsx @@ -3,7 +3,6 @@ import { DataTable } from "@/components/ui/data-table"; import { columns, RepositoryColumnInfo } from "./columns"; import { unwrapServiceError } from "@/lib/utils"; -import { getRepos } from "@/actions"; import { useQuery } from "@tanstack/react-query"; import { useDomain } from "@/hooks/useDomain"; import { RepoIndexingStatus } from "@sourcebot/db"; @@ -14,6 +13,7 @@ import { Button } from "@/components/ui/button"; import { PlusIcon } from "lucide-react"; import { AddRepositoryDialog } from "./components/addRepositoryDialog"; import { useState } from "react"; +import { getRepos } from "@/app/api/(client)/client"; interface RepositoryTableProps { isAddReposButtonVisible: boolean @@ -26,9 +26,9 @@ export const RepositoryTable = ({ const [isAddDialogOpen, setIsAddDialogOpen] = useState(false); const { data: repos, isLoading: reposLoading, error: reposError } = useQuery({ - queryKey: ['repos', domain], + queryKey: ['repos'], queryFn: async () => { - return await unwrapServiceError(getRepos(domain)); + return await unwrapServiceError(getRepos()); }, refetchInterval: env.NEXT_PUBLIC_POLLING_INTERVAL_MS, refetchIntervalInBackground: true, diff --git a/packages/web/src/app/api/(client)/client.ts b/packages/web/src/app/api/(client)/client.ts index c23d0cde..3238c7c5 100644 --- a/packages/web/src/app/api/(client)/client.ts +++ b/packages/web/src/app/api/(client)/client.ts @@ -1,19 +1,17 @@ 'use client'; -import { getVersionResponseSchema } from "@/lib/schemas"; +import { getVersionResponseSchema, getReposResponseSchema } from "@/lib/schemas"; import { ServiceError } from "@/lib/serviceError"; -import { GetVersionResponse } from "@/lib/types"; +import { GetVersionResponse, GetReposResponse } from "@/lib/types"; import { isServiceError } from "@/lib/utils"; import { FileSourceResponse, FileSourceRequest, - ListRepositoriesResponse, SearchRequest, SearchResponse, } from "@/features/search/types"; import { fileSourceResponseSchema, - listRepositoriesResponseSchema, searchResponseSchema, } from "@/features/search/schemas"; @@ -47,16 +45,15 @@ export const fetchFileSource = async (body: FileSourceRequest, domain: string): return fileSourceResponseSchema.parse(result); } -export const getRepos = async (domain: string): Promise => { +export const getRepos = async (): Promise => { const result = await fetch("/api/repos", { method: "GET", headers: { "Content-Type": "application/json", - "X-Org-Domain": domain, }, }).then(response => response.json()); - return listRepositoriesResponseSchema.parse(result); + return getReposResponseSchema.parse(result); } export const getVersion = async (): Promise => { diff --git a/packages/web/src/app/api/(server)/repos/route.ts b/packages/web/src/app/api/(server)/repos/route.ts index 6673f0eb..acc3f9ce 100644 --- a/packages/web/src/app/api/(server)/repos/route.ts +++ b/packages/web/src/app/api/(server)/repos/route.ts @@ -1,24 +1,11 @@ -'use server'; - -import { listRepositories } from "@/features/search/listReposApi"; -import { NextRequest } from "next/server"; -import { isServiceError } from "@/lib/utils"; +import { getRepos } from "@/actions"; import { serviceErrorResponse } from "@/lib/serviceError"; -import { StatusCodes } from "http-status-codes"; -import { ErrorCode } from "@/lib/errorCodes"; +import { isServiceError } from "@/lib/utils"; +import { GetReposResponse } from "@/lib/types"; -export const GET = async (request: NextRequest) => { - const domain = request.headers.get("X-Org-Domain"); - const apiKey = request.headers.get("X-Sourcebot-Api-Key") ?? undefined; - if (!domain) { - return serviceErrorResponse({ - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.MISSING_ORG_DOMAIN_HEADER, - message: "Missing X-Org-Domain header", - }); - } - const response = await listRepositories(domain, apiKey); +export const GET = async () => { + const response: GetReposResponse = await getRepos(); if (isServiceError(response)) { return serviceErrorResponse(response); } diff --git a/packages/web/src/features/chat/tools.ts b/packages/web/src/features/chat/tools.ts index 34cf1a23..608c1213 100644 --- a/packages/web/src/features/chat/tools.ts +++ b/packages/web/src/features/chat/tools.ts @@ -221,7 +221,7 @@ export const searchReposTool = tool({ limit: z.number().default(10).describe("Maximum number of repositories to return (default: 10)") }), execute: async ({ query, limit }) => { - const reposResponse = await getRepos(SINGLE_TENANT_ORG_DOMAIN); + const reposResponse = await getRepos(); if (isServiceError(reposResponse)) { return reposResponse; @@ -255,7 +255,7 @@ export const listAllReposTool = tool({ description: `Lists all repositories in the codebase. This provides a complete overview of all available repositories.`, inputSchema: z.object({}), execute: async () => { - const reposResponse = await getRepos(SINGLE_TENANT_ORG_DOMAIN); + const reposResponse = await getRepos(); if (isServiceError(reposResponse)) { return reposResponse; diff --git a/packages/web/src/features/search/listReposApi.ts b/packages/web/src/features/search/listReposApi.ts deleted file mode 100644 index 90a685bd..00000000 --- a/packages/web/src/features/search/listReposApi.ts +++ /dev/null @@ -1,49 +0,0 @@ -import { OrgRole } from "@sourcebot/db"; -import { invalidZoektResponse, ServiceError } from "../../lib/serviceError"; -import { ListRepositoriesResponse } from "./types"; -import { zoektFetch } from "./zoektClient"; -import { zoektListRepositoriesResponseSchema } from "./zoektSchema"; -import { sew, withAuth, withOrgMembership } from "@/actions"; - -export const listRepositories = async (domain: string, apiKey: string | undefined = undefined): Promise => sew(() => - withAuth((userId, _apiKeyHash) => - withOrgMembership(userId, domain, async ({ org }) => { - const body = JSON.stringify({ - opts: { - Field: 0, - } - }); - - let header: Record = {}; - header = { - "X-Tenant-ID": org.id.toString() - }; - - const listResponse = await zoektFetch({ - path: "/api/list", - body, - header, - method: "POST", - cache: "no-store", - }); - - if (!listResponse.ok) { - return invalidZoektResponse(listResponse); - } - - const listBody = await listResponse.json(); - - const parser = zoektListRepositoriesResponseSchema.transform(({ List }) => ({ - repos: List.Repos.map((repo) => ({ - name: repo.Repository.Name, - webUrl: repo.Repository.URL.length > 0 ? repo.Repository.URL : undefined, - branches: repo.Repository.Branches?.map((branch) => branch.Name) ?? [], - rawConfig: repo.Repository.RawConfig ?? undefined, - })) - } satisfies ListRepositoriesResponse)); - - const result = parser.parse(listBody); - - return result; - }, /* minRequiredRole = */ OrgRole.GUEST), /* allowAnonymousAccess = */ true, apiKey ? { apiKey, domain } : undefined) -); diff --git a/packages/web/src/features/search/schemas.ts b/packages/web/src/features/search/schemas.ts index 7b84081e..18dfd8d4 100644 --- a/packages/web/src/features/search/schemas.ts +++ b/packages/web/src/features/search/schemas.ts @@ -94,17 +94,6 @@ export const searchResponseSchema = z.object({ isBranchFilteringEnabled: z.boolean(), }); -export const repositorySchema = z.object({ - name: z.string(), - branches: z.array(z.string()), - webUrl: z.string().optional(), - rawConfig: z.record(z.string(), z.string()).optional(), -}); - -export const listRepositoriesResponseSchema = z.object({ - repos: z.array(repositorySchema), -}); - export const fileSourceRequestSchema = z.object({ fileName: z.string(), repository: z.string(), diff --git a/packages/web/src/features/search/types.ts b/packages/web/src/features/search/types.ts index 5271b94b..f9af8dbe 100644 --- a/packages/web/src/features/search/types.ts +++ b/packages/web/src/features/search/types.ts @@ -1,7 +1,6 @@ // @NOTE : Please keep this file in sync with @sourcebot/mcp/src/types.ts import { fileSourceResponseSchema, - listRepositoriesResponseSchema, locationSchema, searchRequestSchema, searchResponseSchema, @@ -19,9 +18,6 @@ export type SearchResultFile = SearchResponse["files"][number]; export type SearchResultChunk = SearchResultFile["chunks"][number]; export type SearchSymbol = z.infer; -export type ListRepositoriesResponse = z.infer; -export type Repository = ListRepositoriesResponse["repos"][number]; - export type FileSourceRequest = z.infer; export type FileSourceResponse = z.infer; diff --git a/packages/web/src/lib/schemas.ts b/packages/web/src/lib/schemas.ts index 6c66e0fc..56af1084 100644 --- a/packages/web/src/lib/schemas.ts +++ b/packages/web/src/lib/schemas.ts @@ -2,6 +2,7 @@ import { checkIfOrgDomainExists } from "@/actions"; import { RepoIndexingStatus } from "@sourcebot/db"; import { z } from "zod"; import { isServiceError } from "./utils"; +import { serviceErrorSchema } from "./serviceError"; export const secretCreateRequestSchema = z.object({ key: z.string(), @@ -24,7 +25,7 @@ export const repositoryQuerySchema = z.object({ name: z.string(), })), imageUrl: z.string().optional(), - indexedAt: z.date().optional(), + indexedAt: z.coerce.date().optional(), repoIndexingStatus: z.nativeEnum(RepoIndexingStatus), }); @@ -74,3 +75,5 @@ export const orgDomainSchema = z.string() export const getVersionResponseSchema = z.object({ version: z.string(), }); + +export const getReposResponseSchema = z.union([repositoryQuerySchema.array(), serviceErrorSchema]); \ No newline at end of file diff --git a/packages/web/src/lib/types.ts b/packages/web/src/lib/types.ts index 043b27df..545dbbf4 100644 --- a/packages/web/src/lib/types.ts +++ b/packages/web/src/lib/types.ts @@ -1,5 +1,5 @@ import { z } from "zod"; -import { getVersionResponseSchema, repositoryQuerySchema, searchContextQuerySchema } from "./schemas"; +import { getReposResponseSchema, getVersionResponseSchema, repositoryQuerySchema, searchContextQuerySchema } from "./schemas"; import { tenancyModeSchema } from "@/env.mjs"; export type KeymapType = "default" | "vim"; @@ -26,4 +26,5 @@ export type NewsItem = { export type TenancyMode = z.infer; export type RepositoryQuery = z.infer; -export type SearchContextQuery = z.infer; \ No newline at end of file +export type SearchContextQuery = z.infer; +export type GetReposResponse = z.infer; \ No newline at end of file diff --git a/packages/web/src/withAuthV2.test.ts b/packages/web/src/withAuthV2.test.ts new file mode 100644 index 00000000..7056cbef --- /dev/null +++ b/packages/web/src/withAuthV2.test.ts @@ -0,0 +1,733 @@ +import { expect, test, vi, beforeEach, describe } from 'vitest'; +import { Session } from 'next-auth'; +import { notAuthenticated } from './lib/serviceError'; +import { getAuthContext, getAuthenticatedUser, withAuthV2, withOptionalAuthV2 } from './withAuthV2'; +import { MOCK_API_KEY, MOCK_ORG, MOCK_USER, prisma } from './__mocks__/prisma'; +import { OrgRole } from '@sourcebot/db'; + +const mocks = vi.hoisted(() => { + return { + // Defaults to a empty session. + auth: vi.fn(async (): Promise => null), + headers: vi.fn(async (): Promise => new Headers()), + hasEntitlement: vi.fn((_entitlement: string) => false), + } +}); + +vi.mock('./auth', () => ({ + auth: mocks.auth, +})); + +vi.mock('@/env.mjs', () => ({ + env: {} +})); + +vi.mock('next/headers', () => ({ + headers: mocks.headers, +})); + +vi.mock('@/env.mjs', () => ({ + env: {} +})); + +vi.mock('@/prisma', async () => { + // @see: https://github.com/prisma/prisma/discussions/20244#discussioncomment-7976447 + const actual = await vi.importActual('@/__mocks__/prisma'); + return { + ...actual, + }; +}); + +vi.mock('@sourcebot/crypto', () => ({ + hashSecret: vi.fn((secret: string) => secret), +})); + +vi.mock('server-only', () => ({ + default: vi.fn(), +})); + +vi.mock('@sourcebot/shared', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); + +// Test utility to set the mock session +const setMockSession = (session: Session | null) => { + mocks.auth.mockResolvedValue(session); +}; + +const setMockHeaders = (headers: Headers) => { + mocks.headers.mockResolvedValue(headers); +}; + +// Helper to create mock session objects +const createMockSession = (overrides: Partial = {}): Session => ({ + user: { + id: 'test-user-id', + email: 'test@example.com', + name: 'Test User', + image: null, + ...overrides.user, + }, + expires: '2099-01-01T00:00:00.000Z', + ...overrides, +}); + + +beforeEach(() => { + vi.clearAllMocks(); + mocks.auth.mockResolvedValue(null); + mocks.headers.mockResolvedValue(new Headers()); +}); + +describe('getAuthenticatedUser', () => { + test('should return a user object if a valid session is present', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + const user = await getAuthenticatedUser(); + expect(user).not.toBeUndefined(); + expect(user?.id).toBe(userId); + }); + + test('should return a user object if a valid api key is present', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: userId, + }); + + setMockHeaders(new Headers({ 'X-Sourcebot-Api-Key': 'sourcebot-apikey' })); + const user = await getAuthenticatedUser(); + expect(user).not.toBeUndefined(); + expect(user?.id).toBe(userId); + expect(prisma.apiKey.update).toHaveBeenCalledWith({ + where: { + hash: 'apikey', + }, + data: { + lastUsedAt: expect.any(Date), + }, + }); + }); + + test('should return undefined if no session or api key is present', async () => { + const user = await getAuthenticatedUser(); + expect(user).toBeUndefined(); + }); + + test('should return undefined if a api key does not exist', async () => { + prisma.apiKey.findUnique.mockResolvedValue(null); + setMockHeaders(new Headers({ 'X-Sourcebot-Api-Key': 'sourcebot-apikey' })); + const user = await getAuthenticatedUser(); + expect(user).toBeUndefined(); + }); + + test('should return undefined if a api key is present but is invalid', async () => { + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'different-hash', + createdById: 'test-user-id', + }); + setMockHeaders(new Headers({ 'X-Sourcebot-Api-Key': 'sourcebot-apikey' })); + const user = await getAuthenticatedUser(); + expect(user).toBeUndefined(); + }); + + test('should return undefined if a valid session is present but the user is not found', async () => { + prisma.user.findUnique.mockResolvedValue(null); + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + const user = await getAuthenticatedUser(); + expect(user).toBeUndefined(); + }); + + test('should return undefined if a valid api key is present but the user is not found', async () => { + prisma.user.findUnique.mockResolvedValue(null); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: 'test-user-id', + }); + setMockHeaders(new Headers({ 'X-Sourcebot-Api-Key': 'sourcebot-apikey' })); + const user = await getAuthenticatedUser(); + expect(user).toBeUndefined(); + }); +}); + +describe('getAuthContext', () => { + test('should return a auth context object if a valid session is present and the user is a member of the organization', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.MEMBER, + }); + + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + const authContext = await getAuthContext(); + expect(authContext).not.toBeUndefined(); + expect(authContext).toStrictEqual({ + user: { + ...MOCK_USER, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.MEMBER, + }); + }); + + test('should return a auth context object if a valid session is present and the user is a member of the organization with OWNER role', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.OWNER, + }); + + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + const authContext = await getAuthContext(); + expect(authContext).not.toBeUndefined(); + expect(authContext).toStrictEqual({ + user: { + ...MOCK_USER, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.OWNER, + }); + }); + + test('should return a auth context object if a valid session is present and the user is not a member of the organization. The role should be GUEST.', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue(null); + + setMockSession(createMockSession({ user: { id: userId } })); + const authContext = await getAuthContext(); + expect(authContext).not.toBeUndefined(); + expect(authContext).toStrictEqual({ + user: { + ...MOCK_USER, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.GUEST, + }); + }); + + test('should return a auth context object if no auth session is present. The role should be GUEST and the user should be undefined.', async () => { + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue(null); + + const authContext = await getAuthContext(); + expect(authContext).not.toBeUndefined(); + expect(authContext).toStrictEqual({ + user: undefined, + org: MOCK_ORG, + role: OrgRole.GUEST, + }); + }); +}); + +describe('withAuthV2', () => { + test('should call the callback with the auth context object if a valid session is present and the user is a member of the organization', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.MEMBER, + }); + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + + const cb = vi.fn(); + const result = await withAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.MEMBER + }); + expect(result).toEqual(undefined); + }); + + test('should call the callback with the auth context object if a valid session is present and the user is a member of the organization with OWNER role', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.OWNER, + }); + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + + const cb = vi.fn(); + const result = await withAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.OWNER + }); + expect(result).toEqual(undefined); + }); + + test('should call the callback with the auth context object if a valid session is present and the user is a member of the organization (api key)', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.MEMBER, + }); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: userId, + }); + setMockHeaders(new Headers({ 'X-Sourcebot-Api-Key': 'sourcebot-apikey' })); + + const cb = vi.fn(); + const result = await withAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.MEMBER + }); + expect(result).toEqual(undefined); + }); + + test('should call the callback with the auth context object if a valid session is present and the user is a member of the organization with OWNER role (api key)', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.OWNER, + }); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: userId, + }); + setMockHeaders(new Headers({ 'X-Sourcebot-Api-Key': 'sourcebot-apikey' })); + + const cb = vi.fn(); + const result = await withAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.OWNER + }); + expect(result).toEqual(undefined); + }); + + test('should return a service error if the user is a member of the organization but does not have a valid session', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.MEMBER, + }); + setMockSession(null); + + const cb = vi.fn(); + const result = await withAuthV2(cb); + expect(cb).not.toHaveBeenCalled(); + expect(result).toStrictEqual(notAuthenticated()); + }); + + test('should return a service error if the user is a guest of the organization', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.GUEST, + }); + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + + const cb = vi.fn(); + const result = await withAuthV2(cb); + expect(cb).not.toHaveBeenCalled(); + expect(result).toStrictEqual(notAuthenticated()); + }); + + + test('should return a service error if the user is not a member of the organization (guest role)', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + // user is not a member of the organization + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + + const cb = vi.fn(); + const result = await withAuthV2(cb); + expect(cb).not.toHaveBeenCalled(); + expect(result).toStrictEqual(notAuthenticated()); + }); +}); + +describe('withOptionalAuthV2', () => { + test('should call the callback with the auth context object if a valid session is present and the user is a member of the organization', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.MEMBER, + }); + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + + const cb = vi.fn(); + const result = await withOptionalAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.MEMBER + }); + expect(result).toEqual(undefined); + }); + + test('should call the callback with the auth context object if a valid session is present and the user is a member of the organization with OWNER role', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.OWNER, + }); + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + + const cb = vi.fn(); + const result = await withOptionalAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.OWNER + }); + expect(result).toEqual(undefined); + }); + + test('should call the callback with the auth context object if a valid session is present and the user is a member of the organization (api key)', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.MEMBER, + }); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: userId, + }); + setMockHeaders(new Headers({ 'X-Sourcebot-Api-Key': 'sourcebot-apikey' })); + + const cb = vi.fn(); + const result = await withOptionalAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.MEMBER + }); + expect(result).toEqual(undefined); + }); + + test('should call the callback with the auth context object if a valid session is present and the user is a member of the organization with OWNER role (api key)', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.OWNER, + }); + prisma.apiKey.findUnique.mockResolvedValue({ + ...MOCK_API_KEY, + hash: 'apikey', + createdById: userId, + }); + setMockHeaders(new Headers({ 'X-Sourcebot-Api-Key': 'sourcebot-apikey' })); + + const cb = vi.fn(); + const result = await withOptionalAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER, + id: userId, + }, + org: MOCK_ORG, + role: OrgRole.OWNER + }); + expect(result).toEqual(undefined); + }); + + test('should return a service error if the user is a member of the organization but does not have a valid session', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.MEMBER, + }); + setMockSession(null); + + const cb = vi.fn(); + const result = await withOptionalAuthV2(cb); + expect(cb).not.toHaveBeenCalled(); + expect(result).toStrictEqual(notAuthenticated()); + }); + + test('should return a service error if the user is a guest of the organization', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId: userId, + orgId: MOCK_ORG.id, + role: OrgRole.GUEST, + }); + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + + const cb = vi.fn(); + const result = await withOptionalAuthV2(cb); + expect(cb).not.toHaveBeenCalled(); + expect(result).toStrictEqual(notAuthenticated()); + }); + + + test('should return a service error if the user is not a member of the organization (guest role)', async () => { + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + // user is not a member of the organization + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + + const cb = vi.fn(); + const result = await withOptionalAuthV2(cb); + expect(cb).not.toHaveBeenCalled(); + expect(result).toStrictEqual(notAuthenticated()); + }); + + test('should call the callback with the auth context object if the user is a guest of the organization and the anonymous access entitlement is enabled', async () => { + mocks.hasEntitlement.mockReturnValue(true); + + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + metadata: { + anonymousAccessEnabled: true, + }, + }); + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + + const cb = vi.fn(); + const result = await withOptionalAuthV2(cb); + expect(cb).toHaveBeenCalledWith({ + user: { + ...MOCK_USER, + id: userId, + }, + org: { + ...MOCK_ORG, + metadata: { + anonymousAccessEnabled: true, + }, + }, + role: OrgRole.GUEST, + }); + expect(result).toEqual(undefined); + }); + + test('should return a service error when anonymousAccessEnabled is true but hasAnonymousAccessEntitlement is false', async () => { + mocks.hasEntitlement.mockReturnValue(false); + + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + metadata: { + anonymousAccessEnabled: true, + }, + }); + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + + const cb = vi.fn(); + const result = await withOptionalAuthV2(cb); + expect(cb).not.toHaveBeenCalled(); + expect(result).toStrictEqual(notAuthenticated()); + }); + + test('should return a service error when hasAnonymousAccessEntitlement is true but anonymousAccessEnabled is false', async () => { + mocks.hasEntitlement.mockReturnValue(true); + + const userId = 'test-user-id'; + prisma.user.findUnique.mockResolvedValue({ + ...MOCK_USER, + id: userId, + }); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + metadata: { + anonymousAccessEnabled: false, + }, + }); + setMockSession(createMockSession({ user: { id: 'test-user-id' } })); + + const cb = vi.fn(); + const result = await withOptionalAuthV2(cb); + expect(cb).not.toHaveBeenCalled(); + expect(result).toStrictEqual(notAuthenticated()); + }); +}); diff --git a/packages/web/src/withAuthV2.ts b/packages/web/src/withAuthV2.ts new file mode 100644 index 00000000..aba1a5d4 --- /dev/null +++ b/packages/web/src/withAuthV2.ts @@ -0,0 +1,196 @@ +import { prisma } from "@/prisma"; +import { hashSecret } from "@sourcebot/crypto"; +import { ApiKey, Org, OrgRole, User } from "@sourcebot/db"; +import { headers } from "next/headers"; +import { auth } from "./auth"; +import { notAuthenticated, notFound, ServiceError } from "./lib/serviceError"; +import { SINGLE_TENANT_ORG_ID } from "./lib/constants"; +import { StatusCodes } from "http-status-codes"; +import { ErrorCode } from "./lib/errorCodes"; +import { getOrgMetadata, isServiceError } from "./lib/utils"; +import { hasEntitlement } from "@sourcebot/shared"; + +interface OptionalAuthContext { + user?: User; + org: Org; + role: OrgRole; +} + +interface RequiredAuthContext { + user: User; + org: Org; + role: Omit; +} + +export const withAuthV2 = async (fn: (params: RequiredAuthContext) => Promise) => { + const authContext = await getAuthContext(); + + if (isServiceError(authContext)) { + return authContext; + } + + const { user, org, role } = authContext; + + if (!user || role === OrgRole.GUEST) { + return notAuthenticated(); + } + + return fn({ user, org, role }); +}; + +export const withOptionalAuthV2 = async (fn: (params: OptionalAuthContext) => Promise) => { + const authContext = await getAuthContext(); + if (isServiceError(authContext)) { + return authContext; + } + + const { user, org, role } = authContext; + + const hasAnonymousAccessEntitlement = hasEntitlement("anonymous-access"); + const orgMetadata = getOrgMetadata(org); + + if ( + ( + !user || + role === OrgRole.GUEST + ) && ( + !hasAnonymousAccessEntitlement || + !orgMetadata?.anonymousAccessEnabled + ) + ) { + return notAuthenticated(); + } + + return fn({ user, org, role }); +}; + +export const getAuthContext = async (): Promise => { + const user = await getAuthenticatedUser(); + + const org = await prisma.org.findUnique({ + where: { + id: SINGLE_TENANT_ORG_ID, + } + }); + + if (!org) { + return notFound("Organization not found"); + } + + const membership = user ? await prisma.userToOrg.findUnique({ + where: { + orgId_userId: { + orgId: org.id, + userId: user.id, + }, + }, + }) : null; + + return { + user: user ?? undefined, + org, + role: membership?.role ?? OrgRole.GUEST, + }; +}; + +export const getAuthenticatedUser = async () => { + // First, check if we have a valid JWT session. + const session = await auth(); + if (session) { + const userId = session.user.id; + const user = await prisma.user.findUnique({ + where: { + id: userId, + } + }); + + return user ?? undefined; + } + + // If not, check if we have a valid API key. + const apiKeyString = (await headers()).get("X-Sourcebot-Api-Key") ?? undefined; + if (apiKeyString) { + const apiKey = await getVerifiedApiObject(apiKeyString); + if (!apiKey) { + return undefined; + } + + // Attempt to find the user associated with this api key. + const user = await prisma.user.findUnique({ + where: { + id: apiKey.createdById, + }, + }); + + if (!user) { + return undefined; + } + + // Update the last used at timestamp for this api key. + await prisma.apiKey.update({ + where: { + hash: apiKey.hash, + }, + data: { + lastUsedAt: new Date(), + }, + }); + + return user; + } + + return undefined; +} + +/** + * Returns a API key object if the API key string is valid, otherwise returns undefined. + */ +const getVerifiedApiObject = async (apiKeyString: string): Promise => { + const parts = apiKeyString.split("-"); + if (parts.length !== 2 || parts[0] !== "sourcebot") { + return undefined; + } + + const hash = hashSecret(parts[1]); + const apiKey = await prisma.apiKey.findUnique({ + where: { + hash, + }, + }); + + if (!apiKey) { + return undefined; + } + + return apiKey; +} + +export const withMinimumOrgRole = async ( + userRole: OrgRole, + minRequiredRole: OrgRole = OrgRole.MEMBER, + fn: () => Promise, +) => { + + const getAuthorizationPrecedence = (role: OrgRole): number => { + switch (role) { + case OrgRole.GUEST: + return 0; + case OrgRole.MEMBER: + return 1; + case OrgRole.OWNER: + return 2; + } + }; + + if ( + getAuthorizationPrecedence(userRole) < getAuthorizationPrecedence(minRequiredRole) + ) { + return { + statusCode: StatusCodes.FORBIDDEN, + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + message: "You do not have sufficient permissions to perform this action.", + } satisfies ServiceError; + } + + return fn(); +} diff --git a/yarn.lock b/yarn.lock index 8636b807..88b65afa 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6939,6 +6939,7 @@ __metadata: usehooks-ts: "npm:^3.1.0" vite-tsconfig-paths: "npm:^5.1.3" vitest: "npm:^2.1.5" + vitest-mock-extended: "npm:^3.1.0" vscode-icons-js: "npm:^11.6.1" zod: "npm:^3.25.74" zod-to-json-schema: "npm:^3.24.5" @@ -18337,6 +18338,18 @@ __metadata: languageName: node linkType: hard +"ts-essentials@npm:>=10.0.0": + version: 10.1.1 + resolution: "ts-essentials@npm:10.1.1" + peerDependencies: + typescript: ">=4.5.0" + peerDependenciesMeta: + typescript: + optional: true + checksum: 10c0/8c59148a03eae086e7b1454fa6895e94e2f71385089ccda7e1f720a586749ede7e49ff7338e5f27e44a79f4bed740cc5dc3ad59313769bec028a85fa985685ff + languageName: node + linkType: hard + "ts-interface-checker@npm:^0.1.9": version: 0.1.13 resolution: "ts-interface-checker@npm:0.1.13" @@ -18979,6 +18992,18 @@ __metadata: languageName: node linkType: hard +"vitest-mock-extended@npm:^3.1.0": + version: 3.1.0 + resolution: "vitest-mock-extended@npm:3.1.0" + dependencies: + ts-essentials: "npm:>=10.0.0" + peerDependencies: + typescript: 3.x || 4.x || 5.x + vitest: ">=3.0.0" + checksum: 10c0/1d73c15b26c11f06ec8d1e8d3c9c2c727725fe2238e936dc260f1be919e0be67b90c92cab3ce67c79536264c0ffe77ce14d66bce60244429cf6d2e6c8273e36b + languageName: node + linkType: hard + "vitest@npm:^2.1.5, vitest@npm:^2.1.9": version: 2.1.9 resolution: "vitest@npm:2.1.9"