enforce tenancy on search and repo listing endpoints (#181)

* enforce tenancy on search and repo listing

* remove orgId from request schemas
This commit is contained in:
Michael Sukkarieh 2025-01-28 10:39:59 -08:00 committed by GitHub
parent a88f9e6677
commit 75d4189f25
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 207 additions and 211 deletions

View file

@ -7,9 +7,11 @@
"build": "yarn workspaces run build", "build": "yarn workspaces run build",
"test": "yarn workspaces run test", "test": "yarn workspaces run test",
"dev": "cross-env SOURCEBOT_TENANT_MODE=single npm-run-all --print-label dev:start", "dev": "cross-env SOURCEBOT_TENANT_MODE=single npm-run-all --print-label dev:start",
"dev:mt": "cross-env SOURCEBOT_TENANT_MODE=multi npm-run-all --print-label dev:start", "dev:mt": "cross-env SOURCEBOT_TENANT_MODE=multi npm-run-all --print-label dev:start:mt",
"dev:start": "yarn workspace @sourcebot/db prisma:migrate:dev && cross-env npm-run-all --print-label --parallel dev:zoekt dev:backend dev:web", "dev:start": "yarn workspace @sourcebot/db prisma:migrate:dev && cross-env npm-run-all --print-label --parallel dev:zoekt dev:backend dev:web",
"dev:start:mt": "yarn workspace @sourcebot/db prisma:migrate:dev && cross-env npm-run-all --print-label --parallel dev:zoekt:mt dev:backend dev:web",
"dev:zoekt": "export PATH=\"$PWD/bin:$PATH\" && export SRC_TENANT_ENFORCEMENT_MODE=none && zoekt-webserver -index .sourcebot/index -rpc", "dev:zoekt": "export PATH=\"$PWD/bin:$PATH\" && export SRC_TENANT_ENFORCEMENT_MODE=none && zoekt-webserver -index .sourcebot/index -rpc",
"dev:zoekt:mt": "export PATH=\"$PWD/bin:$PATH\" && export SRC_TENANT_ENFORCEMENT_MODE=strict && zoekt-webserver -index .sourcebot/index -rpc",
"dev:backend": "yarn workspace @sourcebot/backend dev:watch", "dev:backend": "yarn workspace @sourcebot/backend dev:watch",
"dev:web": "yarn workspace @sourcebot/web dev" "dev:web": "yarn workspace @sourcebot/web dev"
}, },

View file

@ -1,12 +1,12 @@
'use server'; 'use server';
import Ajv from "ajv"; import Ajv from "ajv";
import { getUser } from "./data/user"; import { auth, getCurrentUserOrg } from "./auth";
import { auth } from "./auth"; import { notAuthenticated, notFound, ServiceError, unexpectedError } from "@/lib/serviceError";
import { notAuthenticated, notFound, ServiceError, unexpectedError } from "./lib/serviceError";
import { prisma } from "@/prisma"; import { prisma } from "@/prisma";
import { StatusCodes } from "http-status-codes"; import { StatusCodes } from "http-status-codes";
import { ErrorCode } from "./lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { isServiceError } from "@/lib/utils";
import { githubSchema } from "@sourcebot/schemas/v3/github.schema"; import { githubSchema } from "@sourcebot/schemas/v3/github.schema";
import { encrypt } from "@sourcebot/crypto" import { encrypt } from "@sourcebot/crypto"
@ -15,31 +15,9 @@ const ajv = new Ajv({
}); });
export const createSecret = async (key: string, value: string): Promise<{ success: boolean } | ServiceError> => { export const createSecret = async (key: string, value: string): Promise<{ success: boolean } | ServiceError> => {
const session = await auth(); const orgId = await getCurrentUserOrg();
if (!session) { if (isServiceError(orgId)) {
return notAuthenticated(); return orgId;
}
const user = await getUser(session.user.id);
if (!user) {
return unexpectedError("User not found");
}
const orgId = user.activeOrgId;
if (!orgId) {
return unexpectedError("User has no active org");
}
// @todo: refactor this into a shared function
const membership = await prisma.userToOrg.findUnique({
where: {
orgId_userId: {
userId: session.user.id,
orgId,
}
},
});
if (!membership) {
return notFound();
} }
try { try {
@ -62,30 +40,9 @@ export const createSecret = async (key: string, value: string): Promise<{ succes
} }
export const getSecrets = async (): Promise<{ createdAt: Date; key: string; }[] | ServiceError> => { export const getSecrets = async (): Promise<{ createdAt: Date; key: string; }[] | ServiceError> => {
const session = await auth(); const orgId = await getCurrentUserOrg();
if (!session) { if (isServiceError(orgId)) {
return notAuthenticated(); return orgId;
}
const user = await getUser(session.user.id);
if (!user) {
return unexpectedError("User not found");
}
const orgId = user.activeOrgId;
if (!orgId) {
return unexpectedError("User has no active org");
}
const membership = await prisma.userToOrg.findUnique({
where: {
orgId_userId: {
userId: session.user.id,
orgId,
}
},
});
if (!membership) {
return notFound();
} }
const secrets = await prisma.secret.findMany({ const secrets = await prisma.secret.findMany({
@ -105,30 +62,9 @@ export const getSecrets = async (): Promise<{ createdAt: Date; key: string; }[]
} }
export const deleteSecret = async (key: string): Promise<{ success: boolean } | ServiceError> => { export const deleteSecret = async (key: string): Promise<{ success: boolean } | ServiceError> => {
const session = await auth(); const orgId = await getCurrentUserOrg();
if (!session) { if (isServiceError(orgId)) {
return notAuthenticated(); return orgId;
}
const user = await getUser(session.user.id);
if (!user) {
return unexpectedError("User not found");
}
const orgId = user.activeOrgId;
if (!orgId) {
return unexpectedError("User has no active org");
}
const membership = await prisma.userToOrg.findUnique({
where: {
orgId_userId: {
userId: session.user.id,
orgId,
}
},
});
if (!membership) {
return notFound();
} }
await prisma.secret.delete({ await prisma.secret.delete({
@ -206,31 +142,9 @@ export const switchActiveOrg = async (orgId: number): Promise<{ id: number } | S
} }
export const createConnection = async (config: string): Promise<{ id: number } | ServiceError> => { export const createConnection = async (config: string): Promise<{ id: number } | ServiceError> => {
const session = await auth(); const orgId = await getCurrentUserOrg();
if (!session) { if (isServiceError(orgId)) {
return notAuthenticated(); return orgId;
}
const user = await getUser(session.user.id);
if (!user) {
return unexpectedError("User not found");
}
const orgId = user.activeOrgId;
if (!orgId) {
return unexpectedError("User has no active org");
}
// @todo: refactor this into a shared function
const membership = await prisma.userToOrg.findUnique({
where: {
orgId_userId: {
userId: session.user.id,
orgId,
}
},
});
if (!membership) {
return notFound();
} }
let parsedConfig; let parsedConfig;

View file

@ -1,8 +1,15 @@
'use server'; 'use server';
import { listRepositories } from "@/lib/server/searchService"; import { listRepositories } from "@/lib/server/searchService";
import { getCurrentUserOrg } from "../../../../auth";
import { isServiceError } from "@/lib/utils";
export const GET = async () => { export const GET = async () => {
const response = await listRepositories(); const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return orgId;
}
const response = await listRepositories(orgId);
return Response.json(response); return Response.json(response);
} }

View file

@ -5,19 +5,17 @@ import { searchRequestSchema } from "@/lib/schemas";
import { schemaValidationError, serviceErrorResponse } from "@/lib/serviceError"; import { schemaValidationError, serviceErrorResponse } from "@/lib/serviceError";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
import { NextRequest } from "next/server"; import { NextRequest } from "next/server";
import { getCurrentUserOrg } from "../../../../auth";
export const POST = async (request: NextRequest) => { export const POST = async (request: NextRequest) => {
const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return orgId;
}
console.log(`Searching for org ${orgId}`);
const body = await request.json(); const body = await request.json();
const tenantId = request.headers.get("X-Tenant-ID"); const parsed = await searchRequestSchema.safeParseAsync(body);
console.log(`Search request received. Tenant ID: ${tenantId}`);
const parsed = await searchRequestSchema.safeParseAsync({
...body,
...(tenantId ? {
tenantId: parseInt(tenantId)
} : {}),
});
if (!parsed.success) { if (!parsed.success) {
return serviceErrorResponse( return serviceErrorResponse(
schemaValidationError(parsed.error) schemaValidationError(parsed.error)
@ -25,7 +23,7 @@ export const POST = async (request: NextRequest) => {
} }
const response = await search(parsed.data); const response = await search(parsed.data, orgId);
if (isServiceError(response)) { if (isServiceError(response)) {
return serviceErrorResponse(response); return serviceErrorResponse(response);
} }

View file

@ -5,8 +5,14 @@ import { getFileSource } from "@/lib/server/searchService";
import { schemaValidationError, serviceErrorResponse } from "@/lib/serviceError"; import { schemaValidationError, serviceErrorResponse } from "@/lib/serviceError";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
import { NextRequest } from "next/server"; import { NextRequest } from "next/server";
import { getCurrentUserOrg } from "@/auth";
export const POST = async (request: NextRequest) => { export const POST = async (request: NextRequest) => {
const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return orgId;
}
const body = await request.json(); const body = await request.json();
const parsed = await fileSourceRequestSchema.safeParseAsync(body); const parsed = await fileSourceRequestSchema.safeParseAsync(body);
if (!parsed.success) { if (!parsed.success) {
@ -15,7 +21,7 @@ export const POST = async (request: NextRequest) => {
); );
} }
const response = await getFileSource(parsed.data); const response = await getFileSource(parsed.data, orgId);
if (isServiceError(response)) { if (isServiceError(response)) {
return serviceErrorResponse(response); return serviceErrorResponse(response);
} }

View file

@ -7,6 +7,7 @@ import { CodePreview } from "./codePreview";
import { PageNotFound } from "@/app/components/pageNotFound"; import { PageNotFound } from "@/app/components/pageNotFound";
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { LuFileX2, LuBookX } from "react-icons/lu"; import { LuFileX2, LuBookX } from "react-icons/lu";
import { getCurrentUserOrg } from "@/auth";
interface BrowsePageProps { interface BrowsePageProps {
params: { params: {
@ -44,9 +45,18 @@ export default async function BrowsePage({
} }
})(); })();
const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return (
<>
Error: {orgId.message}
</>
)
}
// @todo (bkellam) : We should probably have a endpoint to fetch repository metadata // @todo (bkellam) : We should probably have a endpoint to fetch repository metadata
// given it's name or id. // given it's name or id.
const reposResponse = await listRepositories(); const reposResponse = await listRepositories(orgId);
if (isServiceError(reposResponse)) { if (isServiceError(reposResponse)) {
// @todo : proper error handling // @todo : proper error handling
return ( return (
@ -98,6 +108,7 @@ export default async function BrowsePage({
path={path} path={path}
repoName={repoName} repoName={repoName}
revisionName={revisionName ?? 'HEAD'} revisionName={revisionName ?? 'HEAD'}
orgId={orgId}
/> />
)} )}
</div> </div>
@ -108,19 +119,21 @@ interface CodePreviewWrapper {
path: string, path: string,
repoName: string, repoName: string,
revisionName: string, revisionName: string,
orgId: number,
} }
const CodePreviewWrapper = async ({ const CodePreviewWrapper = async ({
path, path,
repoName, repoName,
revisionName, revisionName,
orgId,
}: CodePreviewWrapper) => { }: CodePreviewWrapper) => {
// @todo: this will depend on `pathType`. // @todo: this will depend on `pathType`.
const fileSourceResponse = await getFileSource({ const fileSourceResponse = await getFileSource({
fileName: path, fileName: path,
repository: repoName, repository: repoName,
branch: revisionName, branch: revisionName,
}); }, orgId);
if (isServiceError(fileSourceResponse)) { if (isServiceError(fileSourceResponse)) {
if (fileSourceResponse.errorCode === ErrorCode.FILE_NOT_FOUND) { if (fileSourceResponse.errorCode === ErrorCode.FILE_NOT_FOUND) {

View file

@ -11,14 +11,22 @@ import { Separator } from "@/components/ui/separator";
import { SymbolIcon } from "@radix-ui/react-icons"; import { SymbolIcon } from "@radix-ui/react-icons";
import { UpgradeToast } from "./components/upgradeToast"; import { UpgradeToast } from "./components/upgradeToast";
import Link from "next/link"; import Link from "next/link";
import { getCurrentUserOrg } from "../auth"
export default async function Home() { export default async function Home() {
const orgId = await getCurrentUserOrg();
return ( return (
<div className="flex flex-col items-center overflow-hidden min-h-screen"> <div className="flex flex-col items-center overflow-hidden min-h-screen">
<NavigationMenu /> <NavigationMenu />
<UpgradeToast /> <UpgradeToast />
{isServiceError(orgId) ? (
<div className="mt-8 text-red-500">
You are not authenticated. Please log in to continue.
</div>
) : (
<div className="flex flex-col justify-center items-center mt-8 mb-8 md:mt-18 w-full px-5"> <div className="flex flex-col justify-center items-center mt-8 mb-8 md:mt-18 w-full px-5">
<div className="max-h-44 w-auto"> <div className="max-h-44 w-auto">
<Image <Image
@ -40,7 +48,7 @@ export default async function Home() {
/> />
<div className="mt-8"> <div className="mt-8">
<Suspense fallback={<div>...</div>}> <Suspense fallback={<div>...</div>}>
<RepositoryList /> <RepositoryList orgId={orgId}/>
</Suspense> </Suspense>
</div> </div>
<div className="flex flex-col items-center w-fit gap-6"> <div className="flex flex-col items-center w-fit gap-6">
@ -98,6 +106,7 @@ export default async function Home() {
</div> </div>
</div> </div>
</div> </div>
)}
<footer className="w-full mt-auto py-4 flex flex-row justify-center items-center gap-4"> <footer className="w-full mt-auto py-4 flex flex-row justify-center items-center gap-4">
<Link href="https://sourcebot.dev" className="text-gray-400 text-sm hover:underline">About</Link> <Link href="https://sourcebot.dev" className="text-gray-400 text-sm hover:underline">About</Link>
@ -110,8 +119,8 @@ export default async function Home() {
) )
} }
const RepositoryList = async () => { const RepositoryList = async ({ orgId }: { orgId: number}) => {
const _repos = await listRepositories(); const _repos = await listRepositories(orgId);
if (isServiceError(_repos)) { if (isServiceError(_repos)) {
return null; return null;

View file

@ -1,14 +1,25 @@
import { Suspense } from "react"; import { Suspense } from "react";
import { NavigationMenu } from "../components/navigationMenu"; import { NavigationMenu } from "../components/navigationMenu";
import { RepositoryTable } from "./repositoryTable"; import { RepositoryTable } from "./repositoryTable";
import { getCurrentUserOrg } from "@/auth";
import { isServiceError } from "@/lib/utils";
export default async function ReposPage() {
const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return (
<>
Error: {orgId.message}
</>
)
}
export default function ReposPage() {
return ( return (
<div className="h-screen flex flex-col items-center"> <div className="h-screen flex flex-col items-center">
<NavigationMenu /> <NavigationMenu />
<Suspense fallback={<div>Loading...</div>}> <Suspense fallback={<div>Loading...</div>}>
<div className="max-w-[90%]"> <div className="max-w-[90%]">
<RepositoryTable /> <RepositoryTable orgId={ orgId }/>
</div> </div>
</Suspense> </Suspense>
</div> </div>

View file

@ -3,8 +3,8 @@ import { columns, RepositoryColumnInfo } from "./columns";
import { listRepositories } from "@/lib/server/searchService"; import { listRepositories } from "@/lib/server/searchService";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
export const RepositoryTable = async () => { export const RepositoryTable = async ({ orgId }: { orgId: number }) => {
const _repos = await listRepositories(); const _repos = await listRepositories(orgId);
if (isServiceError(_repos)) { if (isServiceError(_repos)) {
return <div>Error fetching repositories</div>; return <div>Error fetching repositories</div>;

View file

@ -1,6 +1,6 @@
import { NavigationMenu } from "../components/navigationMenu"; import { NavigationMenu } from "../components/navigationMenu";
import { SecretsTable } from "./secretsTable"; import { SecretsTable } from "./secretsTable";
import { getSecrets, createSecret } from "../../actions" import { getSecrets } from "../../actions"
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
export interface SecretsTableProps { export interface SecretsTableProps {

View file

@ -6,6 +6,8 @@ import { prisma } from "@/prisma";
import type { Provider } from "next-auth/providers" import type { Provider } from "next-auth/providers"
import { AUTH_GITHUB_CLIENT_ID, AUTH_GITHUB_CLIENT_SECRET, AUTH_SECRET } from "./lib/environment"; import { AUTH_GITHUB_CLIENT_ID, AUTH_GITHUB_CLIENT_SECRET, AUTH_SECRET } from "./lib/environment";
import { User } from '@sourcebot/db'; import { User } from '@sourcebot/db';
import { notAuthenticated, notFound, unexpectedError } from "@/lib/serviceError";
import { getUser } from "./data/user";
declare module 'next-auth' { declare module 'next-auth' {
interface Session { interface Session {
@ -116,3 +118,33 @@ export const { handlers, signIn, signOut, auth } = NextAuth({
signIn: "/login" signIn: "/login"
} }
}); });
export const getCurrentUserOrg = async () => {
const session = await auth();
if (!session) {
return notAuthenticated();
}
const user = await getUser(session.user.id);
if (!user) {
return unexpectedError("User not found");
}
const orgId = user.activeOrgId;
if (!orgId) {
return unexpectedError("User has no active org");
}
const membership = await prisma.userToOrg.findUnique({
where: {
orgId_userId: {
userId: session.user.id,
orgId,
}
},
});
if (!membership) {
return notFound();
}
return orgId;
}

View file

@ -4,7 +4,6 @@ export const searchRequestSchema = z.object({
query: z.string(), query: z.string(),
maxMatchDisplayCount: z.number(), maxMatchDisplayCount: z.number(),
whole: z.boolean().optional(), whole: z.boolean().optional(),
tenantId: z.number().optional(),
}); });

View file

@ -34,7 +34,7 @@ const aliasPrefixMappings: Record<string, zoektPrefixes> = {
"revision:": zoektPrefixes.branch, "revision:": zoektPrefixes.branch,
} }
export const search = async ({ query, maxMatchDisplayCount, whole, tenantId }: SearchRequest): Promise<SearchResponse | ServiceError> => { export const search = async ({ query, maxMatchDisplayCount, whole}: SearchRequest, orgId: number): Promise<SearchResponse | ServiceError> => {
// Replace any alias prefixes with their corresponding zoekt prefixes. // Replace any alias prefixes with their corresponding zoekt prefixes.
for (const [prefix, zoektPrefix] of Object.entries(aliasPrefixMappings)) { for (const [prefix, zoektPrefix] of Object.entries(aliasPrefixMappings)) {
query = query.replaceAll(prefix, zoektPrefix); query = query.replaceAll(prefix, zoektPrefix);
@ -54,11 +54,9 @@ export const search = async ({ query, maxMatchDisplayCount, whole, tenantId }: S
}); });
let header: Record<string, string> = {}; let header: Record<string, string> = {};
if (tenantId) {
header = { header = {
"X-Tenant-ID": tenantId.toString() "X-Tenant-ID": orgId.toString()
}; };
}
const searchResponse = await zoektFetch({ const searchResponse = await zoektFetch({
path: "/api/search", path: "/api/search",
@ -92,7 +90,7 @@ export const search = async ({ query, maxMatchDisplayCount, whole, tenantId }: S
// @todo (bkellam) : We should really be using `git show <hash>:<path>` to fetch file contents here. // @todo (bkellam) : We should really be using `git show <hash>:<path>` to fetch file contents here.
// This will allow us to support permalinks to files at a specific revision that may not be indexed // This will allow us to support permalinks to files at a specific revision that may not be indexed
// by zoekt. // by zoekt.
export const getFileSource = async ({ fileName, repository, branch }: FileSourceRequest): Promise<FileSourceResponse | ServiceError> => { export const getFileSource = async ({ fileName, repository, branch }: FileSourceRequest, orgId: number): Promise<FileSourceResponse | ServiceError> => {
const escapedFileName = escapeStringRegexp(fileName); const escapedFileName = escapeStringRegexp(fileName);
const escapedRepository = escapeStringRegexp(repository); const escapedRepository = escapeStringRegexp(repository);
@ -105,7 +103,7 @@ export const getFileSource = async ({ fileName, repository, branch }: FileSource
query, query,
maxMatchDisplayCount: 1, maxMatchDisplayCount: 1,
whole: true, whole: true,
}); }, orgId);
if (isServiceError(searchResponse)) { if (isServiceError(searchResponse)) {
return searchResponse; return searchResponse;
@ -126,15 +124,22 @@ export const getFileSource = async ({ fileName, repository, branch }: FileSource
} }
} }
export const listRepositories = async (): Promise<ListRepositoriesResponse | ServiceError> => { export const listRepositories = async (orgId: number): Promise<ListRepositoriesResponse | ServiceError> => {
const body = JSON.stringify({ const body = JSON.stringify({
opts: { opts: {
Field: 0, Field: 0,
} }
}); });
let header: Record<string, string> = {};
header = {
"X-Tenant-ID": orgId.toString()
};
const listResponse = await zoektFetch({ const listResponse = await zoektFetch({
path: "/api/list", path: "/api/list",
body, body,
header,
method: "POST", method: "POST",
cache: "no-store", cache: "no-store",
}); });