diff --git a/packages/db/src/index.ts b/packages/db/src/index.ts index e7cb7554..245206d9 100644 --- a/packages/db/src/index.ts +++ b/packages/db/src/index.ts @@ -1 +1,3 @@ +import type { User, Account } from ".prisma/client"; +export type UserWithAccounts = User & { accounts: Account[] }; export * from ".prisma/client"; \ No newline at end of file diff --git a/packages/web/src/features/search/searchApi.ts b/packages/web/src/features/search/searchApi.ts index 696d3a10..e0b48955 100644 --- a/packages/web/src/features/search/searchApi.ts +++ b/packages/web/src/features/search/searchApi.ts @@ -12,37 +12,45 @@ import { withOptionalAuthV2 } from "@/withAuthV2"; import * as grpc from '@grpc/grpc-js'; import * as protoLoader from '@grpc/proto-loader'; import * as Sentry from '@sentry/nextjs'; -import { PrismaClient, Repo } from "@sourcebot/db"; -import { createLogger, env } from "@sourcebot/shared"; +import { PrismaClient, Repo, UserWithAccounts } from "@sourcebot/db"; +import { createLogger, env, hasEntitlement } from "@sourcebot/shared"; import path from 'path'; import { parseQueryIntoLezerTree, transformLezerTreeToZoektGrpcQuery } from './query'; import { RepositoryInfo, SearchRequest, SearchResponse, SearchResultFile, SearchStats, SourceRange, StreamedSearchResponse } from "./types"; import { FlushReason as ZoektFlushReason } from "@/proto/zoekt/webserver/v1/FlushReason"; import { RevisionExpr } from "@sourcebot/query-language"; import { getCodeHostBrowseFileAtBranchUrl } from "@/lib/utils"; +import { getRepoPermissionFilterForUser } from "@/prisma"; const logger = createLogger("searchApi"); export const search = (searchRequest: SearchRequest) => sew(() => - withOptionalAuthV2(async ({ prisma }) => { + withOptionalAuthV2(async ({ prisma, user }) => { + const repoSearchScope = await getAccessibleRepoNamesForUser({ user, prisma }); + const zoektSearchRequest = await createZoektSearchRequest({ searchRequest, prisma, + repoSearchScope, }); - logger.debug('zoektSearchRequest:', JSON.stringify(zoektSearchRequest, null, 2)); + + logger.debug(`zoektSearchRequest:\n${JSON.stringify(zoektSearchRequest, null, 2)}`); return zoektSearch(zoektSearchRequest, prisma); })); export const streamSearch = (searchRequest: SearchRequest) => sew(() => - withOptionalAuthV2(async ({ prisma }) => { + withOptionalAuthV2(async ({ prisma, user }) => { + const repoSearchScope = await getAccessibleRepoNamesForUser({ user, prisma }); + const zoektSearchRequest = await createZoektSearchRequest({ searchRequest, prisma, + repoSearchScope, }); - logger.debug('zoektStreamSearchRequest:', JSON.stringify(zoektSearchRequest, null, 2)); + console.log(`zoektStreamSearchRequest:\n${JSON.stringify(zoektSearchRequest, null, 2)}`); return zoektStreamSearch(zoektSearchRequest, prisma); })); @@ -296,9 +304,9 @@ const transformZoektSearchResponse = async (response: ZoektGrpcSearchResponse, r const repoId = getRepoIdForFile(file); const repo = reposMapCache.get(repoId); - // This can happen if the user doesn't have access to the repository. + // This should never happen. if (!repo) { - return undefined; + throw new Error(`Repository not found for file: ${file.file_name}`); } // @todo: address "file_name might not be a valid UTF-8 string" warning. @@ -432,9 +440,12 @@ const getRepoIdForFile = (file: ZoektGrpcFileMatch): string | number => { const createZoektSearchRequest = async ({ searchRequest, prisma, + repoSearchScope, }: { searchRequest: SearchRequest; prisma: PrismaClient; + // Allows the caller to scope the search to a specific set of repositories. + repoSearchScope?: string[]; }) => { const tree = parseQueryIntoLezerTree(searchRequest.query); const zoektQuery = await transformLezerTreeToZoektGrpcQuery({ @@ -487,6 +498,14 @@ const createZoektSearchRequest = async ({ exact: true, } }] : []), + ...(repoSearchScope ? [{ + repo_set: { + set: repoSearchScope.reduce((acc, repo) => { + acc[repo] = true; + return acc; + }, {} as Record) + } + }] : []), ] } }, @@ -542,6 +561,27 @@ const createZoektSearchRequest = async ({ return zoektSearchRequest; } +/** + * Returns a list of repository names that the user has access to. + * If permission syncing is disabled, returns undefined. + */ +const getAccessibleRepoNamesForUser = async ({ user, prisma }: { user?: UserWithAccounts, prisma: PrismaClient }) => { + if ( + env.EXPERIMENT_EE_PERMISSION_SYNC_ENABLED !== 'true' || + !hasEntitlement('permission-syncing') + ) { + return undefined; + } + + const accessibleRepos = await prisma.repo.findMany({ + where: getRepoPermissionFilterForUser(user), + select: { + name: true, + } + }); + return accessibleRepos.map(repo => repo.name); +} + const createGrpcClient = (): WebserverServiceClient => { // Path to proto files - these should match your monorepo structure const protoBasePath = path.join(process.cwd(), '../../vendor/zoekt/grpc/protos'); diff --git a/packages/web/src/prisma.ts b/packages/web/src/prisma.ts index 0d520de7..1de13668 100644 --- a/packages/web/src/prisma.ts +++ b/packages/web/src/prisma.ts @@ -1,6 +1,6 @@ import 'server-only'; import { env, getDBConnectionString } from "@sourcebot/shared"; -import { Prisma, PrismaClient } from "@sourcebot/db"; +import { Prisma, PrismaClient, UserWithAccounts } from "@sourcebot/db"; import { hasEntitlement } from "@sourcebot/shared"; // @see: https://authjs.dev/getting-started/adapters/prisma @@ -24,7 +24,7 @@ export const prisma = globalForPrisma.prisma || new PrismaClient({ url: dbConnectionString, }, } - }: {}), + } : {}), }) if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma @@ -32,7 +32,7 @@ if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma * Creates a prisma client extension that scopes queries to striclty information * a given user should be able to access. */ -export const userScopedPrismaClientExtension = (accountIds?: string[]) => { +export const userScopedPrismaClientExtension = (user?: UserWithAccounts) => { return Prisma.defineExtension( (prisma) => { return prisma.$extends({ @@ -46,24 +46,7 @@ export const userScopedPrismaClientExtension = (accountIds?: string[]) => { argsWithWhere.where = { ...(argsWithWhere.where || {}), - OR: [ - // Only include repos that are permitted to the user - ...(accountIds ? [ - { - permittedAccounts: { - some: { - accountId: { - in: accountIds, - } - } - } - }, - ] : []), - // or are public. - { - isPublic: true, - } - ] + ...getRepoPermissionFilterForUser(user), }; return query(args); @@ -74,3 +57,29 @@ export const userScopedPrismaClientExtension = (accountIds?: string[]) => { }) }) } + +/** + * Returns a filter for repositories that the user has access to. + */ +export const getRepoPermissionFilterForUser = (user?: UserWithAccounts): Prisma.RepoWhereInput => { + return { + OR: [ + // Only include repos that are permitted to the user + ...((user && user.accounts.length > 0) ? [ + { + permittedAccounts: { + some: { + accountId: { + in: user.accounts.map(account => account.id), + } + } + } + }, + ] : []), + // or are public. + { + isPublic: true, + } + ] + } +} diff --git a/packages/web/src/withAuthV2.ts b/packages/web/src/withAuthV2.ts index 1b055533..f1e22962 100644 --- a/packages/web/src/withAuthV2.ts +++ b/packages/web/src/withAuthV2.ts @@ -1,6 +1,6 @@ import { prisma as __unsafePrisma, userScopedPrismaClientExtension } from "@/prisma"; import { hashSecret } from "@sourcebot/shared"; -import { ApiKey, Org, OrgRole, PrismaClient, User } from "@sourcebot/db"; +import { ApiKey, Org, OrgRole, PrismaClient, UserWithAccounts } from "@sourcebot/db"; import { headers } from "next/headers"; import { auth } from "./auth"; import { notAuthenticated, notFound, ServiceError } from "./lib/serviceError"; @@ -11,14 +11,14 @@ import { getOrgMetadata, isServiceError } from "./lib/utils"; import { hasEntitlement } from "@sourcebot/shared"; interface OptionalAuthContext { - user?: User; + user?: UserWithAccounts; org: Org; role: OrgRole; prisma: PrismaClient; } interface RequiredAuthContext { - user: User; + user: UserWithAccounts; org: Org; role: Exclude; prisma: PrismaClient; @@ -88,8 +88,7 @@ export const getAuthContext = async (): Promise account.id); - const prisma = __unsafePrisma.$extends(userScopedPrismaClientExtension(accountIds)) as PrismaClient; + const prisma = __unsafePrisma.$extends(userScopedPrismaClientExtension(user)) as PrismaClient; return { user: user ?? undefined,