wip: move permissions check to Prisma extension

This commit is contained in:
bkellam 2025-09-17 18:31:42 -07:00
parent 0b03f94f67
commit 671fd78360
14 changed files with 697 additions and 737 deletions

View file

@ -74,7 +74,7 @@ await repoManager.validateIndexedReposHaveShards();
const connectionManagerInterval = connectionManager.startScheduler(); const connectionManagerInterval = connectionManager.startScheduler();
const repoManagerInterval = repoManager.startScheduler(); const repoManagerInterval = repoManager.startScheduler();
const permissionSyncerInterval = env.EXPERIMENT_PERMISSION_SYNC_ENABLED ? permissionSyncer.startScheduler() : null; const permissionSyncerInterval = env.EXPERIMENT_PERMISSION_SYNC_ENABLED === 'true' ? permissionSyncer.startScheduler() : null;
const cleanup = async (signal: string) => { const cleanup = async (signal: string) => {

View file

@ -59,7 +59,7 @@ export class RepoPermissionSyncer {
} }
// @todo: make this configurable // @todo: make this configurable
}, 1000 * 5); }, 1000 * 60);
} }
public dispose() { public dispose() {

View file

@ -1,46 +1,45 @@
'use server'; 'use server';
import { getAuditService } from "@/ee/features/audit/factory";
import { env } from "@/env.mjs"; import { env } from "@/env.mjs";
import { addUserToOrganization, orgHasAvailability } from "@/lib/authUtils";
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { notAuthenticated, notFound, orgNotFound, secretAlreadyExists, ServiceError, ServiceErrorException, unexpectedError } from "@/lib/serviceError"; import { notAuthenticated, notFound, orgNotFound, secretAlreadyExists, ServiceError, ServiceErrorException, unexpectedError } from "@/lib/serviceError";
import { CodeHostType, isHttpError, isServiceError } from "@/lib/utils"; import { CodeHostType, getOrgMetadata, isHttpError, isServiceError } from "@/lib/utils";
import { prisma } from "@/prisma"; import { prisma } from "@/prisma";
import { render } from "@react-email/components"; import { render } from "@react-email/components";
import * as Sentry from '@sentry/nextjs'; import * as Sentry from '@sentry/nextjs';
import { decrypt, encrypt, generateApiKey, hashSecret, getTokenFromConfig } from "@sourcebot/crypto"; import { decrypt, encrypt, generateApiKey, getTokenFromConfig, hashSecret } from "@sourcebot/crypto";
import { ConnectionSyncStatus, OrgRole, Prisma, RepoIndexingStatus, StripeSubscriptionStatus, Org, ApiKey } from "@sourcebot/db"; import { ApiKey, ConnectionSyncStatus, Org, OrgRole, Prisma, RepoIndexingStatus, StripeSubscriptionStatus } from "@sourcebot/db";
import { createLogger } from "@sourcebot/logger";
import { azuredevopsSchema } from "@sourcebot/schemas/v3/azuredevops.schema";
import { bitbucketSchema } from "@sourcebot/schemas/v3/bitbucket.schema";
import { ConnectionConfig } from "@sourcebot/schemas/v3/connection.type"; import { ConnectionConfig } from "@sourcebot/schemas/v3/connection.type";
import { genericGitHostSchema } from "@sourcebot/schemas/v3/genericGitHost.schema";
import { gerritSchema } from "@sourcebot/schemas/v3/gerrit.schema"; import { gerritSchema } from "@sourcebot/schemas/v3/gerrit.schema";
import { giteaSchema } from "@sourcebot/schemas/v3/gitea.schema"; import { giteaSchema } from "@sourcebot/schemas/v3/gitea.schema";
import { githubSchema } from "@sourcebot/schemas/v3/github.schema";
import { gitlabSchema } from "@sourcebot/schemas/v3/gitlab.schema";
import { azuredevopsSchema } from "@sourcebot/schemas/v3/azuredevops.schema";
import { GithubConnectionConfig } from "@sourcebot/schemas/v3/github.type";
import { GitlabConnectionConfig } from "@sourcebot/schemas/v3/gitlab.type";
import { GiteaConnectionConfig } from "@sourcebot/schemas/v3/gitea.type"; import { GiteaConnectionConfig } from "@sourcebot/schemas/v3/gitea.type";
import { githubSchema } from "@sourcebot/schemas/v3/github.schema";
import { GithubConnectionConfig } from "@sourcebot/schemas/v3/github.type";
import { gitlabSchema } from "@sourcebot/schemas/v3/gitlab.schema";
import { GitlabConnectionConfig } from "@sourcebot/schemas/v3/gitlab.type";
import { getPlan, hasEntitlement } from "@sourcebot/shared";
import Ajv from "ajv"; import Ajv from "ajv";
import { StatusCodes } from "http-status-codes"; import { StatusCodes } from "http-status-codes";
import { cookies, headers } from "next/headers"; import { cookies, headers } from "next/headers";
import { createTransport } from "nodemailer"; import { createTransport } from "nodemailer";
import { auth } from "./auth";
import { Octokit } from "octokit"; import { Octokit } from "octokit";
import { auth } from "./auth";
import { getConnection } from "./data/connection"; import { getConnection } from "./data/connection";
import { getOrgFromDomain } from "./data/org";
import { decrementOrgSeatCount, getSubscriptionForOrg } from "./ee/features/billing/serverUtils";
import { IS_BILLING_ENABLED } from "./ee/features/billing/stripe"; import { IS_BILLING_ENABLED } from "./ee/features/billing/stripe";
import InviteUserEmail from "./emails/inviteUserEmail"; import InviteUserEmail from "./emails/inviteUserEmail";
import JoinRequestApprovedEmail from "./emails/joinRequestApprovedEmail";
import JoinRequestSubmittedEmail from "./emails/joinRequestSubmittedEmail";
import { AGENTIC_SEARCH_TUTORIAL_DISMISSED_COOKIE_NAME, MOBILE_UNSUPPORTED_SPLASH_SCREEN_DISMISSED_COOKIE_NAME, SEARCH_MODE_COOKIE_NAME, SINGLE_TENANT_ORG_DOMAIN, SOURCEBOT_GUEST_USER_ID, SOURCEBOT_SUPPORT_EMAIL } from "./lib/constants"; import { AGENTIC_SEARCH_TUTORIAL_DISMISSED_COOKIE_NAME, MOBILE_UNSUPPORTED_SPLASH_SCREEN_DISMISSED_COOKIE_NAME, SEARCH_MODE_COOKIE_NAME, SINGLE_TENANT_ORG_DOMAIN, SOURCEBOT_GUEST_USER_ID, SOURCEBOT_SUPPORT_EMAIL } from "./lib/constants";
import { orgDomainSchema, orgNameSchema, repositoryQuerySchema } from "./lib/schemas"; import { orgDomainSchema, orgNameSchema, repositoryQuerySchema } from "./lib/schemas";
import { TenancyMode, ApiKeyPayload } from "./lib/types"; import { ApiKeyPayload, TenancyMode } from "./lib/types";
import { decrementOrgSeatCount, getSubscriptionForOrg } from "./ee/features/billing/serverUtils";
import { bitbucketSchema } from "@sourcebot/schemas/v3/bitbucket.schema";
import { genericGitHostSchema } from "@sourcebot/schemas/v3/genericGitHost.schema";
import { getPlan, hasEntitlement } from "@sourcebot/shared";
import JoinRequestSubmittedEmail from "./emails/joinRequestSubmittedEmail";
import JoinRequestApprovedEmail from "./emails/joinRequestApprovedEmail";
import { createLogger } from "@sourcebot/logger";
import { getAuditService } from "@/ee/features/audit/factory";
import { addUserToOrganization, orgHasAvailability } from "@/lib/authUtils";
import { getOrgMetadata } from "@/lib/utils";
import { getOrgFromDomain } from "./data/org";
import { withOptionalAuthV2 } from "./withAuthV2"; import { withOptionalAuthV2 } from "./withAuthV2";
const ajv = new Ajv({ const ajv = new Ajv({
@ -640,7 +639,7 @@ export const getConnectionInfo = async (connectionId: number, domain: string) =>
}))); })));
export const getRepos = async (filter: { status?: RepoIndexingStatus[], connectionId?: number } = {}) => sew(() => export const getRepos = async (filter: { status?: RepoIndexingStatus[], connectionId?: number } = {}) => sew(() =>
withOptionalAuthV2(async ({ org, user }) => { withOptionalAuthV2(async ({ org, prisma }) => {
const repos = await prisma.repo.findMany({ const repos = await prisma.repo.findMany({
where: { where: {
orgId: org.id, orgId: org.id,
@ -654,13 +653,6 @@ export const getRepos = async (filter: { status?: RepoIndexingStatus[], connecti
} }
} }
} : {}), } : {}),
...(env.EXPERIMENT_PERMISSION_SYNC_ENABLED === 'true' ? {
permittedUsers: {
some: {
userId: user?.id,
}
}
} : {})
}, },
include: { include: {
connections: { connections: {
@ -688,9 +680,8 @@ export const getRepos = async (filter: { status?: RepoIndexingStatus[], connecti
})) }))
})); }));
export const getRepoInfoByName = async (repoName: string, domain: string) => sew(() => export const getRepoInfoByName = async (repoName: string) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => {
// @note: repo names are represented by their remote url // @note: repo names are represented by their remote url
// on the code host. E.g.,: // on the code host. E.g.,:
// - github.com/sourcebot-dev/sourcebot // - github.com/sourcebot-dev/sourcebot
@ -730,13 +721,6 @@ export const getRepoInfoByName = async (repoName: string, domain: string) => sew
where: { where: {
name: repoName, name: repoName,
orgId: org.id, orgId: org.id,
...(env.EXPERIMENT_PERMISSION_SYNC_ENABLED === 'true' ? {
permittedUsers: {
some: {
userId: userId,
}
}
} : {})
}, },
}); });
@ -754,8 +738,7 @@ export const getRepoInfoByName = async (repoName: string, domain: string) => sew
indexedAt: repo.indexedAt ?? undefined, indexedAt: repo.indexedAt ?? undefined,
repoIndexingStatus: repo.repoIndexingStatus, repoIndexingStatus: repo.repoIndexingStatus,
} }
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowAnonymousAccess = */ true }));
));
export const createConnection = async (name: string, type: CodeHostType, connectionConfig: string, domain: string): Promise<{ id: number } | ServiceError> => sew(() => export const createConnection = async (name: string, type: CodeHostType, connectionConfig: string, domain: string): Promise<{ id: number } | ServiceError> => sew(() =>
withAuth((userId) => withAuth((userId) =>
@ -805,9 +788,8 @@ export const createConnection = async (name: string, type: CodeHostType, connect
}, OrgRole.OWNER) }, OrgRole.OWNER)
)); ));
export const experimental_addGithubRepositoryByUrl = async (repositoryUrl: string, domain: string): Promise<{ connectionId: number } | ServiceError> => sew(() => export const experimental_addGithubRepositoryByUrl = async (repositoryUrl: string): Promise<{ connectionId: number } | ServiceError> => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => {
if (env.EXPERIMENT_SELF_SERVE_REPO_INDEXING_ENABLED !== 'true') { if (env.EXPERIMENT_SELF_SERVE_REPO_INDEXING_ENABLED !== 'true') {
return { return {
statusCode: StatusCodes.BAD_REQUEST, statusCode: StatusCodes.BAD_REQUEST,
@ -904,13 +886,6 @@ export const experimental_addGithubRepositoryByUrl = async (repositoryUrl: strin
external_id: githubRepo.id.toString(), external_id: githubRepo.id.toString(),
external_codeHostType: 'github', external_codeHostType: 'github',
external_codeHostUrl: 'https://github.com', external_codeHostUrl: 'https://github.com',
...(env.EXPERIMENT_PERMISSION_SYNC_ENABLED === 'true' ? {
permittedUsers: {
some: {
userId: userId,
}
}
} : {})
} }
}); });
@ -947,8 +922,7 @@ export const experimental_addGithubRepositoryByUrl = async (repositoryUrl: strin
return { return {
connectionId: connection.id, connectionId: connection.id,
} }
}, OrgRole.GUEST), /* allowAnonymousAccess = */ true }));
));
export const updateConnectionDisplayName = async (connectionId: number, name: string, domain: string): Promise<{ success: boolean } | ServiceError> => sew(() => export const updateConnectionDisplayName = async (connectionId: number, name: string, domain: string): Promise<{ success: boolean } | ServiceError> => sew(() =>
withAuth((userId) => withAuth((userId) =>
@ -2043,20 +2017,12 @@ export const getSearchContexts = async (domain: string) => sew(() =>
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowAnonymousAccess = */ true }, /* minRequiredRole = */ OrgRole.GUEST), /* allowAnonymousAccess = */ true
)); ));
export const getRepoImage = async (repoId: number, domain: string): Promise<ArrayBuffer | ServiceError> => sew(async () => { export const getRepoImage = async (repoId: number): Promise<ArrayBuffer | ServiceError> => sew(async () => {
return await withAuth(async (userId) => { return await withOptionalAuthV2(async ({ org, prisma }) => {
return await withOrgMembership(userId, domain, async ({ org }) => {
const repo = await prisma.repo.findUnique({ const repo = await prisma.repo.findUnique({
where: { where: {
id: repoId, id: repoId,
orgId: org.id, orgId: org.id,
...(env.EXPERIMENT_PERMISSION_SYNC_ENABLED === 'true' ? {
permittedUsers: {
some: {
userId: userId,
}
}
} : {})
}, },
include: { include: {
connections: { connections: {
@ -2117,8 +2083,7 @@ export const getRepoImage = async (repoId: number, domain: string): Promise<Arra
logger.error(`Error proxying image for repo ${repoId}:`, error); logger.error(`Error proxying image for repo ${repoId}:`, error);
return notFound(); return notFound();
} }
}, /* minRequiredRole = */ OrgRole.GUEST); })
}, /* allowAnonymousAccess = */ true);
}); });
export const getAnonymousAccessStatus = async (domain: string): Promise<boolean | ServiceError> => sew(async () => { export const getAnonymousAccessStatus = async (domain: string): Promise<boolean | ServiceError> => sew(async () => {

View file

@ -20,7 +20,7 @@ export const CodePreviewPanel = async ({ path, repoName, revisionName, domain }:
repository: repoName, repository: repoName,
branch: revisionName, branch: revisionName,
}, domain), }, domain),
getRepoInfoByName(repoName, domain), getRepoInfoByName(repoName),
]); ]);
if (isServiceError(fileSourceResponse) || isServiceError(repoInfoResponse)) { if (isServiceError(fileSourceResponse) || isServiceError(repoInfoResponse)) {

View file

@ -10,17 +10,16 @@ interface TreePreviewPanelProps {
path: string; path: string;
repoName: string; repoName: string;
revisionName?: string; revisionName?: string;
domain: string;
} }
export const TreePreviewPanel = async ({ path, repoName, revisionName, domain }: TreePreviewPanelProps) => { export const TreePreviewPanel = async ({ path, repoName, revisionName }: TreePreviewPanelProps) => {
const [repoInfoResponse, folderContentsResponse] = await Promise.all([ const [repoInfoResponse, folderContentsResponse] = await Promise.all([
getRepoInfoByName(repoName, domain), getRepoInfoByName(repoName),
getFolderContents({ getFolderContents({
repoName, repoName,
revisionName: revisionName ?? 'HEAD', revisionName: revisionName ?? 'HEAD',
path, path,
}, domain) })
]); ]);
if (isServiceError(folderContentsResponse) || isServiceError(repoInfoResponse)) { if (isServiceError(folderContentsResponse) || isServiceError(repoInfoResponse)) {

View file

@ -42,7 +42,6 @@ export default async function BrowsePage(props: BrowsePageProps) {
path={path} path={path}
repoName={repoName} repoName={repoName}
revisionName={revisionName} revisionName={revisionName}
domain={domain}
/> />
)} )}
</Suspense> </Suspense>

View file

@ -6,7 +6,6 @@ import { useHotkeys } from "react-hotkeys-hook";
import { useQuery } from "@tanstack/react-query"; import { useQuery } from "@tanstack/react-query";
import { unwrapServiceError } from "@/lib/utils"; import { unwrapServiceError } from "@/lib/utils";
import { FileTreeItem, getFiles } from "@/features/fileTree/actions"; import { FileTreeItem, getFiles } from "@/features/fileTree/actions";
import { useDomain } from "@/hooks/useDomain";
import { Dialog, DialogContent, DialogDescription, DialogTitle } from "@/components/ui/dialog"; import { Dialog, DialogContent, DialogDescription, DialogTitle } from "@/components/ui/dialog";
import { useBrowseNavigation } from "../hooks/useBrowseNavigation"; import { useBrowseNavigation } from "../hooks/useBrowseNavigation";
import { useBrowseState } from "../hooks/useBrowseState"; import { useBrowseState } from "../hooks/useBrowseState";
@ -28,7 +27,6 @@ type SearchResult = {
export const FileSearchCommandDialog = () => { export const FileSearchCommandDialog = () => {
const { repoName, revisionName } = useBrowseParams(); const { repoName, revisionName } = useBrowseParams();
const domain = useDomain();
const { state: { isFileSearchOpen }, updateBrowseState } = useBrowseState(); const { state: { isFileSearchOpen }, updateBrowseState } = useBrowseState();
const commandListRef = useRef<HTMLDivElement>(null); const commandListRef = useRef<HTMLDivElement>(null);
@ -57,8 +55,8 @@ export const FileSearchCommandDialog = () => {
}, [isFileSearchOpen]); }, [isFileSearchOpen]);
const { data: files, isLoading, isError } = useQuery({ const { data: files, isLoading, isError } = useQuery({
queryKey: ['files', repoName, revisionName, domain], queryKey: ['files', repoName, revisionName],
queryFn: () => unwrapServiceError(getFiles({ repoName, revisionName: revisionName ?? 'HEAD' }, domain)), queryFn: () => unwrapServiceError(getFiles({ repoName, revisionName: revisionName ?? 'HEAD' })),
enabled: isFileSearchOpen, enabled: isFileSearchOpen,
}); });

View file

@ -8,7 +8,6 @@ import { zodResolver } from "@hookform/resolvers/zod";
import { useForm } from "react-hook-form"; import { useForm } from "react-hook-form";
import { z } from "zod"; import { z } from "zod";
import { experimental_addGithubRepositoryByUrl } from "@/actions"; import { experimental_addGithubRepositoryByUrl } from "@/actions";
import { useDomain } from "@/hooks/useDomain";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
import { useToast } from "@/components/hooks/use-toast"; import { useToast } from "@/components/hooks/use-toast";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
@ -37,7 +36,6 @@ const formSchema = z.object({
}); });
export const AddRepositoryDialog = ({ isOpen, onOpenChange }: AddRepositoryDialogProps) => { export const AddRepositoryDialog = ({ isOpen, onOpenChange }: AddRepositoryDialogProps) => {
const domain = useDomain();
const { toast } = useToast(); const { toast } = useToast();
const router = useRouter(); const router = useRouter();
@ -52,7 +50,7 @@ export const AddRepositoryDialog = ({ isOpen, onOpenChange }: AddRepositoryDialo
const onSubmit = async (data: z.infer<typeof formSchema>) => { const onSubmit = async (data: z.infer<typeof formSchema>) => {
const result = await experimental_addGithubRepositoryByUrl(data.repositoryUrl.trim(), domain); const result = await experimental_addGithubRepositoryByUrl(data.repositoryUrl.trim());
if (isServiceError(result)) { if (isServiceError(result)) {
toast({ toast({
title: "Error adding repository", title: "Error adding repository",

View file

@ -3,18 +3,18 @@ import { isServiceError } from "@/lib/utils";
import { NextRequest } from "next/server"; import { NextRequest } from "next/server";
export async function GET( export async function GET(
request: NextRequest, _request: NextRequest,
props: { params: Promise<{ domain: string; repoId: string }> } props: { params: Promise<{ domain: string; repoId: string }> }
) { ) {
const params = await props.params; const params = await props.params;
const { domain, repoId } = params; const { repoId } = params;
const repoIdNum = parseInt(repoId); const repoIdNum = parseInt(repoId);
if (isNaN(repoIdNum)) { if (isNaN(repoIdNum)) {
return new Response("Invalid repo ID", { status: 400 }); return new Response("Invalid repo ID", { status: 400 });
} }
const result = await getRepoImage(repoIdNum, domain); const result = await getRepoImage(repoIdNum);
if (isServiceError(result)) { if (isServiceError(result)) {
return new Response(result.message, { status: result.statusCode }); return new Response(result.message, { status: result.statusCode });
} }

View file

@ -1,13 +1,13 @@
'use server'; 'use server';
import { sew, withAuth, withOrgMembership } from '@/actions'; import { sew } from '@/actions';
import { env } from '@/env.mjs'; import { env } from '@/env.mjs';
import { OrgRole, Repo } from '@sourcebot/db';
import { prisma } from '@/prisma';
import { notFound, unexpectedError } from '@/lib/serviceError'; import { notFound, unexpectedError } from '@/lib/serviceError';
import { simpleGit } from 'simple-git'; import { withOptionalAuthV2 } from '@/withAuthV2';
import path from 'path'; import { Repo } from '@sourcebot/db';
import { createLogger } from '@sourcebot/logger'; import { createLogger } from '@sourcebot/logger';
import path from 'path';
import { simpleGit } from 'simple-git';
const logger = createLogger('file-tree'); const logger = createLogger('file-tree');
@ -25,21 +25,13 @@ export type FileTreeNode = FileTreeItem & {
* Returns the tree of files (blobs) and directories (trees) for a given repository, * Returns the tree of files (blobs) and directories (trees) for a given repository,
* at a given revision. * at a given revision.
*/ */
export const getTree = async (params: { repoName: string, revisionName: string }, domain: string) => sew(() => export const getTree = async (params: { repoName: string, revisionName: string }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => {
const { repoName, revisionName } = params; const { repoName, revisionName } = params;
const repo = await prisma.repo.findFirst({ const repo = await prisma.repo.findFirst({
where: { where: {
name: repoName, name: repoName,
orgId: org.id, orgId: org.id,
...(env.EXPERIMENT_PERMISSION_SYNC_ENABLED === 'true' ? {
permittedUsers: {
some: {
userId: userId,
}
}
} : {})
}, },
}); });
@ -84,28 +76,19 @@ export const getTree = async (params: { repoName: string, revisionName: string }
tree, tree,
} }
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowAnonymousAccess = */ true) }));
);
/** /**
* Returns the contents of a folder at a given path in a given repository, * Returns the contents of a folder at a given path in a given repository,
* at a given revision. * at a given revision.
*/ */
export const getFolderContents = async (params: { repoName: string, revisionName: string, path: string }, domain: string) => sew(() => export const getFolderContents = async (params: { repoName: string, revisionName: string, path: string }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => {
const { repoName, revisionName, path } = params; const { repoName, revisionName, path } = params;
const repo = await prisma.repo.findFirst({ const repo = await prisma.repo.findFirst({
where: { where: {
name: repoName, name: repoName,
orgId: org.id, orgId: org.id,
...(env.EXPERIMENT_PERMISSION_SYNC_ENABLED === 'true' ? {
permittedUsers: {
some: {
userId: userId,
}
}
} : {})
}, },
}); });
@ -168,25 +151,16 @@ export const getFolderContents = async (params: { repoName: string, revisionName
}); });
return contents; return contents;
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowAnonymousAccess = */ true) }));
);
export const getFiles = async (params: { repoName: string, revisionName: string }, domain: string) => sew(() => export const getFiles = async (params: { repoName: string, revisionName: string }) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => {
const { repoName, revisionName } = params; const { repoName, revisionName } = params;
const repo = await prisma.repo.findFirst({ const repo = await prisma.repo.findFirst({
where: { where: {
name: repoName, name: repoName,
orgId: org.id, orgId: org.id,
...(env.EXPERIMENT_PERMISSION_SYNC_ENABLED === 'true' ? {
permittedUsers: {
some: {
userId: userId,
}
}
} : {})
}, },
}); });
@ -226,8 +200,7 @@ export const getFiles = async (params: { repoName: string, revisionName: string
return files; return files;
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowAnonymousAccess = */ true) }));
);
const buildFileTree = (flatList: { type: string, path: string }[]): FileTreeNode => { const buildFileTree = (flatList: { type: string, path: string }[]): FileTreeNode => {
const root: FileTreeNode = { const root: FileTreeNode = {

View file

@ -3,7 +3,6 @@
import { getTree } from "../actions"; import { getTree } from "../actions";
import { useQuery } from "@tanstack/react-query"; import { useQuery } from "@tanstack/react-query";
import { unwrapServiceError } from "@/lib/utils"; import { unwrapServiceError } from "@/lib/utils";
import { useDomain } from "@/hooks/useDomain";
import { ResizablePanel } from "@/components/ui/resizable"; import { ResizablePanel } from "@/components/ui/resizable";
import { Skeleton } from "@/components/ui/skeleton"; import { Skeleton } from "@/components/ui/skeleton";
import { useBrowseState } from "@/app/[domain]/browse/hooks/useBrowseState"; import { useBrowseState } from "@/app/[domain]/browse/hooks/useBrowseState";
@ -41,17 +40,16 @@ export const FileTreePanel = ({ order }: FileTreePanelProps) => {
updateBrowseState, updateBrowseState,
} = useBrowseState(); } = useBrowseState();
const domain = useDomain();
const { repoName, revisionName, path } = useBrowseParams(); const { repoName, revisionName, path } = useBrowseParams();
const fileTreePanelRef = useRef<ImperativePanelHandle>(null); const fileTreePanelRef = useRef<ImperativePanelHandle>(null);
const { data, isPending, isError } = useQuery({ const { data, isPending, isError } = useQuery({
queryKey: ['tree', repoName, revisionName, domain], queryKey: ['tree', repoName, revisionName],
queryFn: () => unwrapServiceError( queryFn: () => unwrapServiceError(
getTree({ getTree({
repoName, repoName,
revisionName: revisionName ?? 'HEAD', revisionName: revisionName ?? 'HEAD',
}, domain) })
), ),
}); });

View file

@ -4,14 +4,14 @@ import { env } from "@/env.mjs";
import { invalidZoektResponse, ServiceError } from "../../lib/serviceError"; import { invalidZoektResponse, ServiceError } from "../../lib/serviceError";
import { isServiceError } from "../../lib/utils"; import { isServiceError } from "../../lib/utils";
import { zoektFetch } from "./zoektClient"; import { zoektFetch } from "./zoektClient";
import { prisma } from "@/prisma";
import { ErrorCode } from "../../lib/errorCodes"; import { ErrorCode } from "../../lib/errorCodes";
import { StatusCodes } from "http-status-codes"; import { StatusCodes } from "http-status-codes";
import { zoektSearchResponseSchema } from "./zoektSchema"; import { zoektSearchResponseSchema } from "./zoektSchema";
import { SearchRequest, SearchResponse, SourceRange } from "./types"; import { SearchRequest, SearchResponse, SourceRange } from "./types";
import { OrgRole, Repo } from "@sourcebot/db"; import { PrismaClient, Repo } from "@sourcebot/db";
import { sew, withAuth, withOrgMembership } from "@/actions"; import { sew } from "@/actions";
import { base64Decode } from "@sourcebot/shared"; import { base64Decode } from "@sourcebot/shared";
import { withOptionalAuthV2 } from "@/withAuthV2";
// List of supported query prefixes in zoekt. // List of supported query prefixes in zoekt.
// @see : https://github.com/sourcebot-dev/zoekt/blob/main/query/parse.go#L417 // @see : https://github.com/sourcebot-dev/zoekt/blob/main/query/parse.go#L417
@ -36,7 +36,7 @@ enum zoektPrefixes {
reposet = "reposet:", reposet = "reposet:",
} }
const transformZoektQuery = async (query: string, orgId: number): Promise<string | ServiceError> => { const transformZoektQuery = async (query: string, orgId: number, prisma: PrismaClient): Promise<string | ServiceError> => {
const prevQueryParts = query.split(" "); const prevQueryParts = query.split(" ");
const newQueryParts = []; const newQueryParts = [];
@ -127,10 +127,9 @@ const getFileWebUrl = (template: string, branch: string, fileName: string): stri
return encodeURI(url + optionalQueryParams); return encodeURI(url + optionalQueryParams);
} }
export const search = async ({ query, matches, contextLines, whole }: SearchRequest, domain: string, apiKey: string | undefined = undefined) => sew(() => export const search = async ({ query, matches, contextLines, whole }: SearchRequest) => sew(() =>
withAuth((userId, _apiKeyHash) => withOptionalAuthV2(async ({ org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const transformedQuery = await transformZoektQuery(query, org.id, prisma);
const transformedQuery = await transformZoektQuery(query, org.id);
if (isServiceError(transformedQuery)) { if (isServiceError(transformedQuery)) {
return transformedQuery; return transformedQuery;
} }
@ -203,13 +202,6 @@ export const search = async ({ query, matches, contextLines, whole }: SearchRequ
in: Array.from(repoIdentifiers).filter((id) => typeof id === "number"), in: Array.from(repoIdentifiers).filter((id) => typeof id === "number"),
}, },
orgId: org.id, orgId: org.id,
...(env.EXPERIMENT_PERMISSION_SYNC_ENABLED === 'true' ? {
permittedUsers: {
some: {
userId,
}
}
} : {})
} }
})).forEach(repo => repos.set(repo.id, repo)); })).forEach(repo => repos.set(repo.id, repo));
@ -219,13 +211,6 @@ export const search = async ({ query, matches, contextLines, whole }: SearchRequ
in: Array.from(repoIdentifiers).filter((id) => typeof id === "string"), in: Array.from(repoIdentifiers).filter((id) => typeof id === "string"),
}, },
orgId: org.id, orgId: org.id,
...(env.EXPERIMENT_PERMISSION_SYNC_ENABLED === 'true' ? {
permittedUsers: {
some: {
userId,
}
}
} : {})
} }
})).forEach(repo => repos.set(repo.name, repo)); })).forEach(repo => repos.set(repo.name, repo));
@ -357,5 +342,4 @@ export const search = async ({ query, matches, contextLines, whole }: SearchRequ
}); });
return parser.parseAsync(searchBody); return parser.parseAsync(searchBody);
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowAnonymousAccess = */ true, apiKey ? { apiKey, domain } : undefined) }));
);

View file

@ -1,7 +1,48 @@
import 'server-only'; import 'server-only';
import { PrismaClient } from "@sourcebot/db"; import { env } from "@/env.mjs";
import { Prisma, PrismaClient } from "@sourcebot/db";
// @see: https://authjs.dev/getting-started/adapters/prisma // @see: https://authjs.dev/getting-started/adapters/prisma
const globalForPrisma = globalThis as unknown as { prisma: PrismaClient } const globalForPrisma = globalThis as unknown as { prisma: PrismaClient }
// @NOTE: In almost all cases, the userScopedPrismaClientExtension should be used
// (since actions & queries are scoped to a particular user). There are some exceptions
// (e.g., in initialize.ts).
//
// @todo: we can mark this as `__unsafePrisma` in the future once we've migrated
// all of the actions & queries to use the userScopedPrismaClientExtension to avoid
// accidental misuse.
export const prisma = globalForPrisma.prisma || new PrismaClient() export const prisma = globalForPrisma.prisma || new PrismaClient()
if (process.env.NODE_ENV !== "production") globalForPrisma.prisma = prisma if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma
/**
* Creates a prisma client extension that scopes queries to striclty information
* a given user should be able to access.
*/
export const userScopedPrismaClientExtension = (userId?: string) => {
return Prisma.defineExtension(
(prisma) => {
return prisma.$extends({
query: {
...(env.EXPERIMENT_PERMISSION_SYNC_ENABLED === 'true' ? {
repo: {
$allOperations({ args, query }) {
if ('where' in args) {
args.where = {
...args.where,
permittedUsers: {
some: {
userId,
}
}
}
}
return query(args);
}
}
} : {})
}
})
})
}

View file

@ -1,6 +1,6 @@
import { prisma } from "@/prisma"; import { prisma as __unsafePrisma, userScopedPrismaClientExtension } from "@/prisma";
import { hashSecret } from "@sourcebot/crypto"; import { hashSecret } from "@sourcebot/crypto";
import { ApiKey, Org, OrgRole, User } from "@sourcebot/db"; import { ApiKey, Org, OrgRole, PrismaClient, User } from "@sourcebot/db";
import { headers } from "next/headers"; import { headers } from "next/headers";
import { auth } from "./auth"; import { auth } from "./auth";
import { notAuthenticated, notFound, ServiceError } from "./lib/serviceError"; import { notAuthenticated, notFound, ServiceError } from "./lib/serviceError";
@ -14,12 +14,14 @@ interface OptionalAuthContext {
user?: User; user?: User;
org: Org; org: Org;
role: OrgRole; role: OrgRole;
prisma: PrismaClient;
} }
interface RequiredAuthContext { interface RequiredAuthContext {
user: User; user: User;
org: Org; org: Org;
role: Omit<OrgRole, 'GUEST'>; role: Omit<OrgRole, 'GUEST'>;
prisma: PrismaClient;
} }
export const withAuthV2 = async <T>(fn: (params: RequiredAuthContext) => Promise<T>) => { export const withAuthV2 = async <T>(fn: (params: RequiredAuthContext) => Promise<T>) => {
@ -29,13 +31,13 @@ export const withAuthV2 = async <T>(fn: (params: RequiredAuthContext) => Promise
return authContext; return authContext;
} }
const { user, org, role } = authContext; const { user, org, role, prisma } = authContext;
if (!user || role === OrgRole.GUEST) { if (!user || role === OrgRole.GUEST) {
return notAuthenticated(); return notAuthenticated();
} }
return fn({ user, org, role }); return fn({ user, org, role, prisma });
}; };
export const withOptionalAuthV2 = async <T>(fn: (params: OptionalAuthContext) => Promise<T>) => { export const withOptionalAuthV2 = async <T>(fn: (params: OptionalAuthContext) => Promise<T>) => {
@ -44,7 +46,7 @@ export const withOptionalAuthV2 = async <T>(fn: (params: OptionalAuthContext) =>
return authContext; return authContext;
} }
const { user, org, role } = authContext; const { user, org, role, prisma } = authContext;
const hasAnonymousAccessEntitlement = hasEntitlement("anonymous-access"); const hasAnonymousAccessEntitlement = hasEntitlement("anonymous-access");
const orgMetadata = getOrgMetadata(org); const orgMetadata = getOrgMetadata(org);
@ -61,13 +63,13 @@ export const withOptionalAuthV2 = async <T>(fn: (params: OptionalAuthContext) =>
return notAuthenticated(); return notAuthenticated();
} }
return fn({ user, org, role }); return fn({ user, org, role, prisma });
}; };
export const getAuthContext = async (): Promise<OptionalAuthContext | ServiceError> => { export const getAuthContext = async (): Promise<OptionalAuthContext | ServiceError> => {
const user = await getAuthenticatedUser(); const user = await getAuthenticatedUser();
const org = await prisma.org.findUnique({ const org = await __unsafePrisma.org.findUnique({
where: { where: {
id: SINGLE_TENANT_ORG_ID, id: SINGLE_TENANT_ORG_ID,
} }
@ -77,7 +79,7 @@ export const getAuthContext = async (): Promise<OptionalAuthContext | ServiceErr
return notFound("Organization not found"); return notFound("Organization not found");
} }
const membership = user ? await prisma.userToOrg.findUnique({ const membership = user ? await __unsafePrisma.userToOrg.findUnique({
where: { where: {
orgId_userId: { orgId_userId: {
orgId: org.id, orgId: org.id,
@ -86,10 +88,13 @@ export const getAuthContext = async (): Promise<OptionalAuthContext | ServiceErr
}, },
}) : null; }) : null;
const prisma = __unsafePrisma.$extends(userScopedPrismaClientExtension(user?.id)) as PrismaClient;
return { return {
user: user ?? undefined, user: user ?? undefined,
org, org,
role: membership?.role ?? OrgRole.GUEST, role: membership?.role ?? OrgRole.GUEST,
prisma,
}; };
}; };
@ -98,7 +103,7 @@ export const getAuthenticatedUser = async () => {
const session = await auth(); const session = await auth();
if (session) { if (session) {
const userId = session.user.id; const userId = session.user.id;
const user = await prisma.user.findUnique({ const user = await __unsafePrisma.user.findUnique({
where: { where: {
id: userId, id: userId,
} }
@ -116,7 +121,7 @@ export const getAuthenticatedUser = async () => {
} }
// Attempt to find the user associated with this api key. // Attempt to find the user associated with this api key.
const user = await prisma.user.findUnique({ const user = await __unsafePrisma.user.findUnique({
where: { where: {
id: apiKey.createdById, id: apiKey.createdById,
}, },
@ -127,7 +132,7 @@ export const getAuthenticatedUser = async () => {
} }
// Update the last used at timestamp for this api key. // Update the last used at timestamp for this api key.
await prisma.apiKey.update({ await __unsafePrisma.apiKey.update({
where: { where: {
hash: apiKey.hash, hash: apiKey.hash,
}, },
@ -152,7 +157,7 @@ const getVerifiedApiObject = async (apiKeyString: string): Promise<ApiKey | unde
} }
const hash = hashSecret(parts[1]); const hash = hashSecret(parts[1]);
const apiKey = await prisma.apiKey.findUnique({ const apiKey = await __unsafePrisma.apiKey.findUnique({
where: { where: {
hash, hash,
}, },