add repo_sets filter for repositories a user has access to

This commit is contained in:
bkellam 2025-11-20 12:40:20 -08:00
parent aad3507cad
commit 3fd5f49045
4 changed files with 84 additions and 34 deletions

View file

@ -1 +1,3 @@
import type { User, Account } from ".prisma/client";
export type UserWithAccounts = User & { accounts: Account[] };
export * from ".prisma/client"; export * from ".prisma/client";

View file

@ -12,37 +12,45 @@ import { withOptionalAuthV2 } from "@/withAuthV2";
import * as grpc from '@grpc/grpc-js'; import * as grpc from '@grpc/grpc-js';
import * as protoLoader from '@grpc/proto-loader'; import * as protoLoader from '@grpc/proto-loader';
import * as Sentry from '@sentry/nextjs'; import * as Sentry from '@sentry/nextjs';
import { PrismaClient, Repo } from "@sourcebot/db"; import { PrismaClient, Repo, UserWithAccounts } from "@sourcebot/db";
import { createLogger, env } from "@sourcebot/shared"; import { createLogger, env, hasEntitlement } from "@sourcebot/shared";
import path from 'path'; import path from 'path';
import { parseQueryIntoLezerTree, transformLezerTreeToZoektGrpcQuery } from './query'; import { parseQueryIntoLezerTree, transformLezerTreeToZoektGrpcQuery } from './query';
import { RepositoryInfo, SearchRequest, SearchResponse, SearchResultFile, SearchStats, SourceRange, StreamedSearchResponse } from "./types"; import { RepositoryInfo, SearchRequest, SearchResponse, SearchResultFile, SearchStats, SourceRange, StreamedSearchResponse } from "./types";
import { FlushReason as ZoektFlushReason } from "@/proto/zoekt/webserver/v1/FlushReason"; import { FlushReason as ZoektFlushReason } from "@/proto/zoekt/webserver/v1/FlushReason";
import { RevisionExpr } from "@sourcebot/query-language"; import { RevisionExpr } from "@sourcebot/query-language";
import { getCodeHostBrowseFileAtBranchUrl } from "@/lib/utils"; import { getCodeHostBrowseFileAtBranchUrl } from "@/lib/utils";
import { getRepoPermissionFilterForUser } from "@/prisma";
const logger = createLogger("searchApi"); const logger = createLogger("searchApi");
export const search = (searchRequest: SearchRequest) => sew(() => export const search = (searchRequest: SearchRequest) => sew(() =>
withOptionalAuthV2(async ({ prisma }) => { withOptionalAuthV2(async ({ prisma, user }) => {
const repoSearchScope = await getAccessibleRepoNamesForUser({ user, prisma });
const zoektSearchRequest = await createZoektSearchRequest({ const zoektSearchRequest = await createZoektSearchRequest({
searchRequest, searchRequest,
prisma, prisma,
repoSearchScope,
}); });
logger.debug('zoektSearchRequest:', JSON.stringify(zoektSearchRequest, null, 2));
logger.debug(`zoektSearchRequest:\n${JSON.stringify(zoektSearchRequest, null, 2)}`);
return zoektSearch(zoektSearchRequest, prisma); return zoektSearch(zoektSearchRequest, prisma);
})); }));
export const streamSearch = (searchRequest: SearchRequest) => sew(() => export const streamSearch = (searchRequest: SearchRequest) => sew(() =>
withOptionalAuthV2(async ({ prisma }) => { withOptionalAuthV2(async ({ prisma, user }) => {
const repoSearchScope = await getAccessibleRepoNamesForUser({ user, prisma });
const zoektSearchRequest = await createZoektSearchRequest({ const zoektSearchRequest = await createZoektSearchRequest({
searchRequest, searchRequest,
prisma, prisma,
repoSearchScope,
}); });
logger.debug('zoektStreamSearchRequest:', JSON.stringify(zoektSearchRequest, null, 2)); console.log(`zoektStreamSearchRequest:\n${JSON.stringify(zoektSearchRequest, null, 2)}`);
return zoektStreamSearch(zoektSearchRequest, prisma); return zoektStreamSearch(zoektSearchRequest, prisma);
})); }));
@ -296,9 +304,9 @@ const transformZoektSearchResponse = async (response: ZoektGrpcSearchResponse, r
const repoId = getRepoIdForFile(file); const repoId = getRepoIdForFile(file);
const repo = reposMapCache.get(repoId); const repo = reposMapCache.get(repoId);
// This can happen if the user doesn't have access to the repository. // This should never happen.
if (!repo) { if (!repo) {
return undefined; throw new Error(`Repository not found for file: ${file.file_name}`);
} }
// @todo: address "file_name might not be a valid UTF-8 string" warning. // @todo: address "file_name might not be a valid UTF-8 string" warning.
@ -432,9 +440,12 @@ const getRepoIdForFile = (file: ZoektGrpcFileMatch): string | number => {
const createZoektSearchRequest = async ({ const createZoektSearchRequest = async ({
searchRequest, searchRequest,
prisma, prisma,
repoSearchScope,
}: { }: {
searchRequest: SearchRequest; searchRequest: SearchRequest;
prisma: PrismaClient; prisma: PrismaClient;
// Allows the caller to scope the search to a specific set of repositories.
repoSearchScope?: string[];
}) => { }) => {
const tree = parseQueryIntoLezerTree(searchRequest.query); const tree = parseQueryIntoLezerTree(searchRequest.query);
const zoektQuery = await transformLezerTreeToZoektGrpcQuery({ const zoektQuery = await transformLezerTreeToZoektGrpcQuery({
@ -487,6 +498,14 @@ const createZoektSearchRequest = async ({
exact: true, exact: true,
} }
}] : []), }] : []),
...(repoSearchScope ? [{
repo_set: {
set: repoSearchScope.reduce((acc, repo) => {
acc[repo] = true;
return acc;
}, {} as Record<string, boolean>)
}
}] : []),
] ]
} }
}, },
@ -542,6 +561,27 @@ const createZoektSearchRequest = async ({
return zoektSearchRequest; return zoektSearchRequest;
} }
/**
* Returns a list of repository names that the user has access to.
* If permission syncing is disabled, returns undefined.
*/
const getAccessibleRepoNamesForUser = async ({ user, prisma }: { user?: UserWithAccounts, prisma: PrismaClient }) => {
if (
env.EXPERIMENT_EE_PERMISSION_SYNC_ENABLED !== 'true' ||
!hasEntitlement('permission-syncing')
) {
return undefined;
}
const accessibleRepos = await prisma.repo.findMany({
where: getRepoPermissionFilterForUser(user),
select: {
name: true,
}
});
return accessibleRepos.map(repo => repo.name);
}
const createGrpcClient = (): WebserverServiceClient => { const createGrpcClient = (): WebserverServiceClient => {
// Path to proto files - these should match your monorepo structure // Path to proto files - these should match your monorepo structure
const protoBasePath = path.join(process.cwd(), '../../vendor/zoekt/grpc/protos'); const protoBasePath = path.join(process.cwd(), '../../vendor/zoekt/grpc/protos');

View file

@ -1,6 +1,6 @@
import 'server-only'; import 'server-only';
import { env, getDBConnectionString } from "@sourcebot/shared"; import { env, getDBConnectionString } from "@sourcebot/shared";
import { Prisma, PrismaClient } from "@sourcebot/db"; import { Prisma, PrismaClient, UserWithAccounts } from "@sourcebot/db";
import { hasEntitlement } from "@sourcebot/shared"; import { hasEntitlement } from "@sourcebot/shared";
// @see: https://authjs.dev/getting-started/adapters/prisma // @see: https://authjs.dev/getting-started/adapters/prisma
@ -32,7 +32,7 @@ if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma
* Creates a prisma client extension that scopes queries to striclty information * Creates a prisma client extension that scopes queries to striclty information
* a given user should be able to access. * a given user should be able to access.
*/ */
export const userScopedPrismaClientExtension = (accountIds?: string[]) => { export const userScopedPrismaClientExtension = (user?: UserWithAccounts) => {
return Prisma.defineExtension( return Prisma.defineExtension(
(prisma) => { (prisma) => {
return prisma.$extends({ return prisma.$extends({
@ -46,24 +46,7 @@ export const userScopedPrismaClientExtension = (accountIds?: string[]) => {
argsWithWhere.where = { argsWithWhere.where = {
...(argsWithWhere.where || {}), ...(argsWithWhere.where || {}),
OR: [ ...getRepoPermissionFilterForUser(user),
// Only include repos that are permitted to the user
...(accountIds ? [
{
permittedAccounts: {
some: {
accountId: {
in: accountIds,
}
}
}
},
] : []),
// or are public.
{
isPublic: true,
}
]
}; };
return query(args); return query(args);
@ -74,3 +57,29 @@ export const userScopedPrismaClientExtension = (accountIds?: string[]) => {
}) })
}) })
} }
/**
* Returns a filter for repositories that the user has access to.
*/
export const getRepoPermissionFilterForUser = (user?: UserWithAccounts): Prisma.RepoWhereInput => {
return {
OR: [
// Only include repos that are permitted to the user
...((user && user.accounts.length > 0) ? [
{
permittedAccounts: {
some: {
accountId: {
in: user.accounts.map(account => account.id),
}
}
}
},
] : []),
// or are public.
{
isPublic: true,
}
]
}
}

View file

@ -1,6 +1,6 @@
import { prisma as __unsafePrisma, userScopedPrismaClientExtension } from "@/prisma"; import { prisma as __unsafePrisma, userScopedPrismaClientExtension } from "@/prisma";
import { hashSecret } from "@sourcebot/shared"; import { hashSecret } from "@sourcebot/shared";
import { ApiKey, Org, OrgRole, PrismaClient, User } from "@sourcebot/db"; import { ApiKey, Org, OrgRole, PrismaClient, UserWithAccounts } from "@sourcebot/db";
import { headers } from "next/headers"; import { headers } from "next/headers";
import { auth } from "./auth"; import { auth } from "./auth";
import { notAuthenticated, notFound, ServiceError } from "./lib/serviceError"; import { notAuthenticated, notFound, ServiceError } from "./lib/serviceError";
@ -11,14 +11,14 @@ import { getOrgMetadata, isServiceError } from "./lib/utils";
import { hasEntitlement } from "@sourcebot/shared"; import { hasEntitlement } from "@sourcebot/shared";
interface OptionalAuthContext { interface OptionalAuthContext {
user?: User; user?: UserWithAccounts;
org: Org; org: Org;
role: OrgRole; role: OrgRole;
prisma: PrismaClient; prisma: PrismaClient;
} }
interface RequiredAuthContext { interface RequiredAuthContext {
user: User; user: UserWithAccounts;
org: Org; org: Org;
role: Exclude<OrgRole, 'GUEST'>; role: Exclude<OrgRole, 'GUEST'>;
prisma: PrismaClient; prisma: PrismaClient;
@ -88,8 +88,7 @@ export const getAuthContext = async (): Promise<OptionalAuthContext | ServiceErr
}, },
}) : null; }) : null;
const accountIds = user?.accounts.map(account => account.id); const prisma = __unsafePrisma.$extends(userScopedPrismaClientExtension(user)) as PrismaClient;
const prisma = __unsafePrisma.$extends(userScopedPrismaClientExtension(accountIds)) as PrismaClient;
return { return {
user: user ?? undefined, user: user ?? undefined,