This commit is contained in:
bkellam 2025-09-23 21:30:00 -07:00
parent 6abe7a40a5
commit e990edbb10
15 changed files with 630 additions and 841 deletions

View file

@ -4,13 +4,12 @@ import { getAuditService } from "@/ee/features/audit/factory";
import { env } from "@/env.mjs"; import { env } from "@/env.mjs";
import { addUserToOrganization, orgHasAvailability } from "@/lib/authUtils"; import { addUserToOrganization, orgHasAvailability } from "@/lib/authUtils";
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { notAuthenticated, notFound, orgNotFound, secretAlreadyExists, ServiceError, unexpectedError } from "@/lib/serviceError"; import { notAuthenticated, notFound, orgNotFound, secretAlreadyExists, ServiceError } from "@/lib/serviceError";
import { CodeHostType, getOrgMetadata, isHttpError, isServiceError } from "@/lib/utils"; import { CodeHostType, getOrgMetadata, isHttpError, isServiceError } from "@/lib/utils";
import { prisma } from "@/prisma"; import { prisma } from "@/prisma";
import { render } from "@react-email/components"; import { render } from "@react-email/components";
import * as Sentry from '@sentry/nextjs'; import { decrypt, encrypt, generateApiKey, getTokenFromConfig } from "@sourcebot/crypto";
import { decrypt, encrypt, generateApiKey, getTokenFromConfig, hashSecret } from "@sourcebot/crypto"; import { ConnectionSyncStatus, OrgRole, Prisma, RepoIndexingStatus, StripeSubscriptionStatus } from "@sourcebot/db";
import { ApiKey, ConnectionSyncStatus, Org, OrgRole, Prisma, RepoIndexingStatus, StripeSubscriptionStatus } from "@sourcebot/db";
import { createLogger } from "@sourcebot/logger"; import { createLogger } from "@sourcebot/logger";
import { azuredevopsSchema } from "@sourcebot/schemas/v3/azuredevops.schema"; import { azuredevopsSchema } from "@sourcebot/schemas/v3/azuredevops.schema";
import { bitbucketSchema } from "@sourcebot/schemas/v3/bitbucket.schema"; import { bitbucketSchema } from "@sourcebot/schemas/v3/bitbucket.schema";
@ -29,7 +28,6 @@ import { StatusCodes } from "http-status-codes";
import { cookies, headers } from "next/headers"; import { cookies, headers } from "next/headers";
import { createTransport } from "nodemailer"; import { createTransport } from "nodemailer";
import { Octokit } from "octokit"; import { Octokit } from "octokit";
import { auth } from "./auth";
import { getConnection } from "./data/connection"; import { getConnection } from "./data/connection";
import { getOrgFromDomain } from "./data/org"; import { getOrgFromDomain } from "./data/org";
import { decrementOrgSeatCount, getSubscriptionForOrg } from "./ee/features/billing/serverUtils"; import { decrementOrgSeatCount, getSubscriptionForOrg } from "./ee/features/billing/serverUtils";
@ -37,9 +35,9 @@ import { IS_BILLING_ENABLED } from "./ee/features/billing/stripe";
import InviteUserEmail from "./emails/inviteUserEmail"; import InviteUserEmail from "./emails/inviteUserEmail";
import JoinRequestApprovedEmail from "./emails/joinRequestApprovedEmail"; import JoinRequestApprovedEmail from "./emails/joinRequestApprovedEmail";
import JoinRequestSubmittedEmail from "./emails/joinRequestSubmittedEmail"; import JoinRequestSubmittedEmail from "./emails/joinRequestSubmittedEmail";
import { AGENTIC_SEARCH_TUTORIAL_DISMISSED_COOKIE_NAME, MOBILE_UNSUPPORTED_SPLASH_SCREEN_DISMISSED_COOKIE_NAME, SEARCH_MODE_COOKIE_NAME, SINGLE_TENANT_ORG_DOMAIN, SOURCEBOT_GUEST_USER_ID, SOURCEBOT_SUPPORT_EMAIL } from "./lib/constants"; import { AGENTIC_SEARCH_TUTORIAL_DISMISSED_COOKIE_NAME, MOBILE_UNSUPPORTED_SPLASH_SCREEN_DISMISSED_COOKIE_NAME, SEARCH_MODE_COOKIE_NAME, SOURCEBOT_SUPPORT_EMAIL } from "./lib/constants";
import { orgNameSchema, repositoryQuerySchema } from "./lib/schemas"; import { orgNameSchema, repositoryQuerySchema } from "./lib/schemas";
import { ApiKeyPayload, TenancyMode } from "./lib/types"; import { sew } from "./sew";
import { withAuthV2, withOptionalAuthV2 } from "./withAuthV2"; import { withAuthV2, withOptionalAuthV2 } from "./withAuthV2";
import { withMinimumOrgRole } from "./withMinimumOrgRole"; import { withMinimumOrgRole } from "./withMinimumOrgRole";
@ -50,142 +48,6 @@ const ajv = new Ajv({
const logger = createLogger('web-actions'); const logger = createLogger('web-actions');
const auditService = getAuditService(); const auditService = getAuditService();
/**
* "Service Error Wrapper".
*
* Captures any thrown exceptions and converts them to a unexpected
* service error. Also logs them with Sentry.
*/
export const sew = async <T>(fn: () => Promise<T>): Promise<T | ServiceError> => {
try {
return await fn();
} catch (e) {
Sentry.captureException(e);
logger.error(e);
if (e instanceof Error) {
return unexpectedError(e.message);
}
return unexpectedError(`An unexpected error occurred. Please try again later.`);
}
}
export const withAuth = async <T>(fn: (userId: string, apiKeyHash: string | undefined) => Promise<T>, allowAnonymousAccess: boolean = false, apiKey: ApiKeyPayload | undefined = undefined) => {
const session = await auth();
if (!session) {
// First we check if public access is enabled and supported. If not, then we check if an api key was provided. If not,
// then this is an invalid unauthed request and we return a 401.
const anonymousAccessEnabled = await getAnonymousAccessStatus(SINGLE_TENANT_ORG_DOMAIN);
if (apiKey) {
const apiKeyOrError = await verifyApiKey(apiKey);
if (isServiceError(apiKeyOrError)) {
logger.error(`Invalid API key: ${JSON.stringify(apiKey)}. Error: ${JSON.stringify(apiKeyOrError)}`);
return notAuthenticated();
}
const user = await prisma.user.findUnique({
where: {
id: apiKeyOrError.apiKey.createdById,
},
});
if (!user) {
logger.error(`No user found for API key: ${apiKey}`);
return notAuthenticated();
}
await prisma.apiKey.update({
where: {
hash: apiKeyOrError.apiKey.hash,
},
data: {
lastUsedAt: new Date(),
},
});
return fn(user.id, apiKeyOrError.apiKey.hash);
} else if (
allowAnonymousAccess &&
!isServiceError(anonymousAccessEnabled) &&
anonymousAccessEnabled
) {
if (!hasEntitlement("anonymous-access")) {
const plan = getPlan();
logger.error(`Anonymous access isn't supported in your current plan: ${plan}. For support, contact ${SOURCEBOT_SUPPORT_EMAIL}.`);
return notAuthenticated();
}
// To support anonymous access a guest user is created in initialize.ts, which we return here
return fn(SOURCEBOT_GUEST_USER_ID, undefined);
}
return notAuthenticated();
}
return fn(session.user.id, undefined);
}
export const withOrgMembership = async <T>(userId: string, domain: string, fn: (params: { userRole: OrgRole, org: Org }) => Promise<T>, minRequiredRole: OrgRole = OrgRole.MEMBER) => {
const org = await prisma.org.findUnique({
where: {
domain,
},
});
if (!org) {
return notFound("Organization not found");
}
const membership = await prisma.userToOrg.findUnique({
where: {
orgId_userId: {
userId,
orgId: org.id,
}
},
});
if (!membership) {
return notFound("User not a member of this organization");
}
const getAuthorizationPrecedence = (role: OrgRole): number => {
switch (role) {
case OrgRole.GUEST:
return 0;
case OrgRole.MEMBER:
return 1;
case OrgRole.OWNER:
return 2;
}
}
if (getAuthorizationPrecedence(membership.role) < getAuthorizationPrecedence(minRequiredRole)) {
return {
statusCode: StatusCodes.FORBIDDEN,
errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS,
message: "You do not have sufficient permissions to perform this action.",
} satisfies ServiceError;
}
return fn({
org: org,
userRole: membership.role,
});
}
export const withTenancyModeEnforcement = async<T>(mode: TenancyMode, fn: () => Promise<T>) => {
if (env.SOURCEBOT_TENANCY_MODE !== mode) {
return {
statusCode: StatusCodes.FORBIDDEN,
errorCode: ErrorCode.ACTION_DISALLOWED_IN_TENANCY_MODE,
message: "This action is not allowed in the current tenancy mode.",
} satisfies ServiceError;
}
return fn();
}
////// Actions /////// ////// Actions ///////
export const updateOrgName = async (name: string) => sew(() => export const updateOrgName = async (name: string) => sew(() =>
@ -322,59 +184,6 @@ export const deleteSecret = async (key: string, domain: string): Promise<{ succe
} }
})); }));
export const verifyApiKey = async (apiKeyPayload: ApiKeyPayload): Promise<{ apiKey: ApiKey } | ServiceError> => sew(async () => {
const parts = apiKeyPayload.apiKey.split("-");
if (parts.length !== 2 || parts[0] !== "sourcebot") {
return {
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_API_KEY,
message: "Invalid API key",
} satisfies ServiceError;
}
const hash = hashSecret(parts[1])
const apiKey = await prisma.apiKey.findUnique({
where: {
hash,
},
});
if (!apiKey) {
return {
statusCode: StatusCodes.UNAUTHORIZED,
errorCode: ErrorCode.INVALID_API_KEY,
message: "Invalid API key",
} satisfies ServiceError;
}
const apiKeyTargetOrg = await prisma.org.findUnique({
where: {
domain: apiKeyPayload.domain,
},
});
if (!apiKeyTargetOrg) {
return {
statusCode: StatusCodes.UNAUTHORIZED,
errorCode: ErrorCode.INVALID_API_KEY,
message: `Invalid API key payload. Provided domain ${apiKeyPayload.domain} does not exist.`,
} satisfies ServiceError;
}
if (apiKey.orgId !== apiKeyTargetOrg.id) {
return {
statusCode: StatusCodes.UNAUTHORIZED,
errorCode: ErrorCode.INVALID_API_KEY,
message: `Invalid API key payload. Provided domain ${apiKeyPayload.domain} does not match the API key's org.`,
} satisfies ServiceError;
}
return {
apiKey,
}
});
export const createApiKey = async (name: string, domain: string): Promise<{ key: string } | ServiceError> => sew(() => export const createApiKey = async (name: string, domain: string): Promise<{ key: string } | ServiceError> => sew(() =>
withAuthV2(async ({ user, org, prisma }) => { withAuthV2(async ({ user, org, prisma }) => {
const userId = user.id; const userId = user.id;
@ -1236,10 +1045,9 @@ export const getMe = async () => sew(() =>
})); }));
export const redeemInvite = async (inviteId: string): Promise<{ success: boolean } | ServiceError> => sew(() => export const redeemInvite = async (inviteId: string): Promise<{ success: boolean } | ServiceError> => sew(() =>
withAuth(async () => { withOptionalAuthV2(async ({ user, prisma }) => {
const user = await getMe(); if (!user) {
if (isServiceError(user)) { return notAuthenticated();
return user;
} }
const invite = await prisma.invite.findUnique({ const invite = await prisma.invite.findUnique({
@ -1315,10 +1123,9 @@ export const redeemInvite = async (inviteId: string): Promise<{ success: boolean
})); }));
export const getInviteInfo = async (inviteId: string) => sew(() => export const getInviteInfo = async (inviteId: string) => sew(() =>
withAuth(async () => { withOptionalAuthV2(async ({ user }) => {
const user = await getMe(); if (!user) {
if (isServiceError(user)) { return notAuthenticated();
return user;
} }
const invite = await prisma.invite.findUnique({ const invite = await prisma.invite.findUnique({
@ -1613,107 +1420,92 @@ export const getOrgAccountRequests = async (domain: string) => sew(() =>
})); }));
})); }));
export const createAccountRequest = async (userId: string, domain: string) => sew(async () => { export const createAccountRequest = async () => sew(() =>
const user = await prisma.user.findUnique({ withOptionalAuthV2(async ({ user, org, prisma }) => {
where: { if (!user) {
id: userId, return notAuthenticated();
},
});
if (!user) {
return notFound("User not found");
}
const org = await prisma.org.findUnique({
where: {
domain,
},
});
if (!org) {
return notFound("Organization not found");
}
const existingRequest = await prisma.accountRequest.findUnique({
where: {
requestedById_orgId: {
requestedById: userId,
orgId: org.id,
},
},
});
if (existingRequest) {
logger.warn(`User ${userId} already has an account request for org ${org.id}. Skipping account request creation.`);
return {
success: true,
existingRequest: true,
} }
}
if (!existingRequest) { const existingRequest = await prisma.accountRequest.findUnique({
await prisma.accountRequest.create({ where: {
data: { requestedById_orgId: {
requestedById: userId, requestedById: user.id,
orgId: org.id, orgId: org.id,
},
}, },
}); });
if (env.SMTP_CONNECTION_URL && env.EMAIL_FROM_ADDRESS) { if (existingRequest) {
// TODO: This is needed because we can't fetch the origin from the request headers when this is called logger.warn(`User ${user.id} already has an account request for org ${org.id}. Skipping account request creation.`);
// on user creation (the header isn't set when next-auth calls onCreateUser for some reason) return {
const deploymentUrl = env.AUTH_URL; success: true,
existingRequest: true,
}
}
const owner = await prisma.user.findFirst({ if (!existingRequest) {
where: { await prisma.accountRequest.create({
orgs: { data: {
some: { requestedById: user.id,
orgId: org.id, orgId: org.id,
role: "OWNER",
},
},
}, },
}); });
if (!owner) { if (env.SMTP_CONNECTION_URL && env.EMAIL_FROM_ADDRESS) {
logger.error(`Failed to find owner for org ${org.id} when drafting email for account request from ${userId}`); // TODO: This is needed because we can't fetch the origin from the request headers when this is called
} else { // on user creation (the header isn't set when next-auth calls onCreateUser for some reason)
const html = await render(JoinRequestSubmittedEmail({ const deploymentUrl = env.AUTH_URL;
baseUrl: deploymentUrl,
requestor: {
name: user.name ?? undefined,
email: user.email!,
avatarUrl: user.image ?? undefined,
},
orgName: org.name,
orgDomain: org.domain,
orgImageUrl: org.imageUrl ?? undefined,
}));
const transport = createTransport(env.SMTP_CONNECTION_URL); const owner = await prisma.user.findFirst({
const result = await transport.sendMail({ where: {
to: owner.email!, orgs: {
from: env.EMAIL_FROM_ADDRESS, some: {
subject: `New account request for ${org.name} on Sourcebot`, orgId: org.id,
html, role: "OWNER",
text: `New account request for ${org.name} on Sourcebot by ${user.name ?? user.email}`, },
},
},
}); });
const failed = result.rejected.concat(result.pending).filter(Boolean); if (!owner) {
if (failed.length > 0) { logger.error(`Failed to find owner for org ${org.id} when drafting email for account request from ${user.id}`);
logger.error(`Failed to send account request email to ${owner.email}: ${failed}`); } else {
} const html = await render(JoinRequestSubmittedEmail({
} baseUrl: deploymentUrl,
} else { requestor: {
logger.warn(`SMTP_CONNECTION_URL or EMAIL_FROM_ADDRESS not set. Skipping account request email to owner`); name: user.name ?? undefined,
} email: user.email!,
} avatarUrl: user.image ?? undefined,
},
orgName: org.name,
orgDomain: org.domain,
orgImageUrl: org.imageUrl ?? undefined,
}));
return { const transport = createTransport(env.SMTP_CONNECTION_URL);
success: true, const result = await transport.sendMail({
existingRequest: false, to: owner.email!,
} from: env.EMAIL_FROM_ADDRESS,
}); subject: `New account request for ${org.name} on Sourcebot`,
html,
text: `New account request for ${org.name} on Sourcebot by ${user.name ?? user.email}`,
});
const failed = result.rejected.concat(result.pending).filter(Boolean);
if (failed.length > 0) {
logger.error(`Failed to send account request email to ${owner.email}: ${failed}`);
}
}
} else {
logger.warn(`SMTP_CONNECTION_URL or EMAIL_FROM_ADDRESS not set. Skipping account request email to owner`);
}
}
return {
success: true,
existingRequest: false,
}
}));
export const getMemberApprovalRequired = async (domain: string): Promise<boolean | ServiceError> => sew(async () => { export const getMemberApprovalRequired = async (domain: string): Promise<boolean | ServiceError> => sew(async () => {
const org = await prisma.org.findUnique({ const org = await prisma.org.findUnique({

View file

@ -10,17 +10,16 @@ import { useRouter } from "next/navigation"
interface SubmitButtonProps { interface SubmitButtonProps {
domain: string domain: string
userId: string
} }
export function SubmitAccountRequestButton({ domain, userId }: SubmitButtonProps) { export function SubmitAccountRequestButton({ domain }: SubmitButtonProps) {
const { toast } = useToast() const { toast } = useToast()
const router = useRouter() const router = useRouter()
const [isSubmitting, setIsSubmitting] = useState(false) const [isSubmitting, setIsSubmitting] = useState(false)
const handleSubmit = async () => { const handleSubmit = async () => {
setIsSubmitting(true) setIsSubmitting(true)
const result = await createAccountRequest(userId, domain) const result = await createAccountRequest();
if (!isServiceError(result)) { if (!isServiceError(result)) {
if (result.existingRequest) { if (result.existingRequest) {
toast({ toast({

View file

@ -1,6 +1,5 @@
import { LogoutEscapeHatch } from "@/app/components/logoutEscapeHatch" import { LogoutEscapeHatch } from "@/app/components/logoutEscapeHatch"
import { SourcebotLogo } from "@/app/components/sourcebotLogo" import { SourcebotLogo } from "@/app/components/sourcebotLogo"
import { auth } from "@/auth"
import { SubmitAccountRequestButton } from "./submitAccountRequestButton" import { SubmitAccountRequestButton } from "./submitAccountRequestButton"
interface SubmitJoinRequestProps { interface SubmitJoinRequestProps {
@ -8,13 +7,6 @@ interface SubmitJoinRequestProps {
} }
export const SubmitJoinRequest = async ({ domain }: SubmitJoinRequestProps) => { export const SubmitJoinRequest = async ({ domain }: SubmitJoinRequestProps) => {
const session = await auth()
const userId = session?.user?.id
if (!userId) {
return null
}
return ( return (
<div className="min-h-screen bg-[var(--background)] flex items-center justify-center p-6"> <div className="min-h-screen bg-[var(--background)] flex items-center justify-center p-6">
<LogoutEscapeHatch className="absolute top-0 right-0 p-6" /> <LogoutEscapeHatch className="absolute top-0 right-0 p-6" />
@ -45,7 +37,7 @@ export const SubmitJoinRequest = async ({ domain }: SubmitJoinRequestProps) => {
<div className="space-y-4"> <div className="space-y-4">
<div className="flex justify-center"> <div className="flex justify-center">
<SubmitAccountRequestButton domain={domain} userId={userId} /> <SubmitAccountRequestButton domain={domain} />
</div> </div>
</div> </div>
</div> </div>

View file

@ -1,4 +1,5 @@
import { sew, withAuth, withOrgMembership } from "@/actions"; import { sew } from "@/sew";
import { withOptionalAuthV2 } from "@/withAuthV2";
import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, updateChatMessages } from "@/features/chat/actions"; import { _getConfiguredLanguageModelsFull, _getAISDKLanguageModelAndOptions, updateChatMessages } from "@/features/chat/actions";
import { createAgentStream } from "@/features/chat/agent"; import { createAgentStream } from "@/features/chat/agent";
import { additionalChatRequestParamsSchema, SBChatMessage, SearchScope } from "@/features/chat/types"; import { additionalChatRequestParamsSchema, SBChatMessage, SearchScope } from "@/features/chat/types";
@ -9,7 +10,6 @@ import { isServiceError } from "@/lib/utils";
import { prisma } from "@/prisma"; import { prisma } from "@/prisma";
import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider"; import { LanguageModelV2 as AISDKLanguageModelV2 } from "@ai-sdk/provider";
import * as Sentry from "@sentry/nextjs"; import * as Sentry from "@sentry/nextjs";
import { OrgRole } from "@sourcebot/db";
import { createLogger } from "@sourcebot/logger"; import { createLogger } from "@sourcebot/logger";
import { import {
createUIMessageStream, createUIMessageStream,
@ -52,56 +52,54 @@ export async function POST(req: Request) {
const { messages, id, selectedSearchScopes, languageModelId } = parsed.data; const { messages, id, selectedSearchScopes, languageModelId } = parsed.data;
const response = await sew(() => const response = await sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { // Validate that the chat exists and is not readonly.
// Validate that the chat exists and is not readonly. const chat = await prisma.chat.findUnique({
const chat = await prisma.chat.findUnique({ where: {
where: {
orgId: org.id,
id,
},
});
if (!chat) {
return notFound();
}
if (chat.isReadonly) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: "Chat is readonly and cannot be edited.",
});
}
// From the language model ID, attempt to find the
// corresponding config in `config.json`.
const languageModelConfig =
(await _getConfiguredLanguageModelsFull())
.find((model) => model.model === languageModelId);
if (!languageModelConfig) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: `Language model ${languageModelId} is not configured.`,
});
}
const { model, providerOptions } = await _getAISDKLanguageModelAndOptions(languageModelConfig, org.id);
return createMessageStreamResponse({
messages,
id,
selectedSearchScopes,
model,
modelName: languageModelConfig.displayName ?? languageModelConfig.model,
modelProviderOptions: providerOptions,
domain,
orgId: org.id, orgId: org.id,
id,
},
});
if (!chat) {
return notFound();
}
if (chat.isReadonly) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: "Chat is readonly and cannot be edited.",
}); });
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true }
)
// From the language model ID, attempt to find the
// corresponding config in `config.json`.
const languageModelConfig =
(await _getConfiguredLanguageModelsFull())
.find((model) => model.model === languageModelId);
if (!languageModelConfig) {
return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: `Language model ${languageModelId} is not configured.`,
});
}
const { model, providerOptions } = await _getAISDKLanguageModelAndOptions(languageModelConfig, org.id);
return createMessageStreamResponse({
messages,
id,
selectedSearchScopes,
model,
modelName: languageModelConfig.displayName ?? languageModelConfig.model,
modelProviderOptions: providerOptions,
domain,
orgId: org.id,
});
})
) )
if (isServiceError(response)) { if (isServiceError(response)) {

View file

@ -1,16 +1,20 @@
"use server"; "use server";
import { withAuth } from "@/actions";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
import { orgNotFound, ServiceError } from "@/lib/serviceError"; import { notAuthenticated, orgNotFound, ServiceError } from "@/lib/serviceError";
import { sew } from "@/actions"; import { sew } from "@/sew";
import { addUserToOrganization } from "@/lib/authUtils"; import { addUserToOrganization } from "@/lib/authUtils";
import { prisma } from "@/prisma"; import { prisma } from "@/prisma";
import { StatusCodes } from "http-status-codes"; import { StatusCodes } from "http-status-codes";
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { withOptionalAuthV2 } from "@/withAuthV2";
export const joinOrganization = async (orgId: number, inviteLinkId?: string) => sew(async () => export const joinOrganization = async (orgId: number, inviteLinkId?: string) => sew(async () =>
withAuth(async (userId) => { withOptionalAuthV2(async ({ user }) => {
if (!user) {
return notAuthenticated();
}
const org = await prisma.org.findUnique({ const org = await prisma.org.findUnique({
where: { where: {
id: orgId, id: orgId,
@ -40,7 +44,7 @@ export const joinOrganization = async (orgId: number, inviteLinkId?: string) =>
} }
} }
const addUserToOrgRes = await addUserToOrganization(userId, org.id); const addUserToOrgRes = await addUserToOrganization(user.id, org.id);
if (isServiceError(addUserToOrgRes)) { if (isServiceError(addUserToOrgRes)) {
return addUserToOrgRes; return addUserToOrgRes;
} }

View file

@ -1,103 +1,101 @@
'use server'; 'use server';
import { sew, withAuth, withOrgMembership } from "@/actions"; import { sew } from "@/sew";
import { OrgRole } from "@sourcebot/db"; import { withAuthV2 } from "@/withAuthV2";
import { prisma } from "@/prisma";
import { ServiceError } from "@/lib/serviceError"; import { ServiceError } from "@/lib/serviceError";
import { AnalyticsResponse } from "./types"; import { AnalyticsResponse } from "./types";
import { hasEntitlement } from "@sourcebot/shared"; import { hasEntitlement } from "@sourcebot/shared";
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { StatusCodes } from "http-status-codes"; import { StatusCodes } from "http-status-codes";
export const getAnalytics = async (domain: string, apiKey: string | undefined = undefined): Promise<AnalyticsResponse | ServiceError> => sew(() => export const getAnalytics = async (domain: string, _apiKey: string | undefined = undefined): Promise<AnalyticsResponse | ServiceError> => sew(() =>
withAuth((userId, _apiKeyHash) => withAuthV2(async ({ org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { if (!hasEntitlement("analytics")) {
if (!hasEntitlement("analytics")) { return {
return { statusCode: StatusCodes.FORBIDDEN,
statusCode: StatusCodes.FORBIDDEN, errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS,
errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, message: "Analytics is not available in your current plan",
message: "Analytics is not available in your current plan", } satisfies ServiceError;
} satisfies ServiceError; }
}
const rows = await prisma.$queryRaw<AnalyticsResponse>` const rows = await prisma.$queryRaw<AnalyticsResponse>`
WITH core AS ( WITH core AS (
SELECT SELECT
date_trunc('day', "timestamp") AS day, date_trunc('day', "timestamp") AS day,
date_trunc('week', "timestamp") AS week, date_trunc('week', "timestamp") AS week,
date_trunc('month', "timestamp") AS month, date_trunc('month', "timestamp") AS month,
action, action,
"actorId" "actorId"
FROM "Audit" FROM "Audit"
WHERE "orgId" = ${org.id} WHERE "orgId" = ${org.id}
AND action IN ( AND action IN (
'user.performed_code_search', 'user.performed_code_search',
'user.performed_find_references', 'user.performed_find_references',
'user.performed_goto_definition' 'user.performed_goto_definition'
) )
), ),
periods AS ( periods AS (
SELECT unnest(array['day', 'week', 'month']) AS period SELECT unnest(array['day', 'week', 'month']) AS period
), ),
buckets AS ( buckets AS (
SELECT SELECT
generate_series( generate_series(
date_trunc('day', (SELECT MIN("timestamp") FROM "Audit" WHERE "orgId" = ${org.id})), date_trunc('day', (SELECT MIN("timestamp") FROM "Audit" WHERE "orgId" = ${org.id})),
date_trunc('day', CURRENT_DATE), date_trunc('day', CURRENT_DATE),
interval '1 day' interval '1 day'
) AS bucket, ) AS bucket,
'day' AS period 'day' AS period
UNION ALL UNION ALL
SELECT SELECT
generate_series( generate_series(
date_trunc('week', (SELECT MIN("timestamp") FROM "Audit" WHERE "orgId" = ${org.id})), date_trunc('week', (SELECT MIN("timestamp") FROM "Audit" WHERE "orgId" = ${org.id})),
date_trunc('week', CURRENT_DATE), date_trunc('week', CURRENT_DATE),
interval '1 week' interval '1 week'
), ),
'week' 'week'
UNION ALL UNION ALL
SELECT SELECT
generate_series( generate_series(
date_trunc('month', (SELECT MIN("timestamp") FROM "Audit" WHERE "orgId" = ${org.id})), date_trunc('month', (SELECT MIN("timestamp") FROM "Audit" WHERE "orgId" = ${org.id})),
date_trunc('month', CURRENT_DATE), date_trunc('month', CURRENT_DATE),
interval '1 month' interval '1 month'
), ),
'month' 'month'
), ),
aggregated AS (
SELECT
b.period,
CASE b.period
WHEN 'day' THEN c.day
WHEN 'week' THEN c.week
ELSE c.month
END AS bucket,
COUNT(*) FILTER (WHERE c.action = 'user.performed_code_search') AS code_searches,
COUNT(*) FILTER (WHERE c.action IN ('user.performed_find_references', 'user.performed_goto_definition')) AS navigations,
COUNT(DISTINCT c."actorId") AS active_users
FROM core c
JOIN LATERAL (
SELECT unnest(array['day', 'week', 'month']) AS period
) b ON true
GROUP BY b.period, bucket
)
aggregated AS (
SELECT SELECT
b.period, b.period,
b.bucket, CASE b.period
COALESCE(a.code_searches, 0)::int AS code_searches, WHEN 'day' THEN c.day
COALESCE(a.navigations, 0)::int AS navigations, WHEN 'week' THEN c.week
COALESCE(a.active_users, 0)::int AS active_users ELSE c.month
FROM buckets b END AS bucket,
LEFT JOIN aggregated a COUNT(*) FILTER (WHERE c.action = 'user.performed_code_search') AS code_searches,
ON a.period = b.period AND a.bucket = b.bucket COUNT(*) FILTER (WHERE c.action IN ('user.performed_find_references', 'user.performed_goto_definition')) AS navigations,
ORDER BY b.period, b.bucket; COUNT(DISTINCT c."actorId") AS active_users
`; FROM core c
JOIN LATERAL (
SELECT unnest(array['day', 'week', 'month']) AS period
) b ON true
GROUP BY b.period, bucket
)
SELECT
b.period,
b.bucket,
COALESCE(a.code_searches, 0)::int AS code_searches,
COALESCE(a.navigations, 0)::int AS navigations,
COALESCE(a.active_users, 0)::int AS active_users
FROM buckets b
LEFT JOIN aggregated a
ON a.period = b.period AND a.bucket = b.bucket
ORDER BY b.period, b.bucket;
`;
return rows; return rows;
}, /* minRequiredRole = */ OrgRole.MEMBER), /* allowAnonymousAccess = */ true, apiKey ? { apiKey, domain } : undefined) })
); );

View file

@ -1,9 +1,10 @@
"use server"; "use server";
import { prisma } from "@/prisma";
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
import { StatusCodes } from "http-status-codes"; import { StatusCodes } from "http-status-codes";
import { sew, withAuth, withOrgMembership } from "@/actions"; import { sew } from "@/sew";
import { withAuthV2 } from "@/withAuthV2";
import { withMinimumOrgRole } from "@/withMinimumOrgRole";
import { OrgRole } from "@sourcebot/db"; import { OrgRole } from "@sourcebot/db";
import { createLogger } from "@sourcebot/logger"; import { createLogger } from "@sourcebot/logger";
import { ServiceError } from "@/lib/serviceError"; import { ServiceError } from "@/lib/serviceError";
@ -13,16 +14,15 @@ import { AuditEvent } from "./types";
const auditService = getAuditService(); const auditService = getAuditService();
const logger = createLogger('audit-utils'); const logger = createLogger('audit-utils');
export const createAuditAction = async (event: Omit<AuditEvent, 'sourcebotVersion' | 'orgId' | 'actor' | 'target'>, domain: string) => sew(async () => export const createAuditAction = async (event: Omit<AuditEvent, 'sourcebotVersion' | 'orgId' | 'actor' | 'target'>, _domain: string) => sew(async () =>
withAuth((userId) => withAuthV2(async ({ user, org }) => {
withOrgMembership(userId, domain, async ({ org }) => { await auditService.createAudit({ ...event, orgId: org.id, actor: { id: user.id, type: "user" }, target: { id: org.id.toString(), type: "org" } })
await auditService.createAudit({ ...event, orgId: org.id, actor: { id: userId, type: "user" }, target: { id: org.id.toString(), type: "org" } }) })
}, /* minRequiredRole = */ OrgRole.MEMBER), /* allowAnonymousAccess = */ true)
); );
export const fetchAuditRecords = async (domain: string, apiKey: string | undefined = undefined) => sew(() => export const fetchAuditRecords = async (domain: string, _apiKey: string | undefined = undefined) => sew(() =>
withAuth((userId) => withAuthV2(async ({ user, org, prisma, role }) =>
withOrgMembership(userId, domain, async ({ org }) => { withMinimumOrgRole(role, OrgRole.OWNER, async () => {
try { try {
const auditRecords = await prisma.audit.findMany({ const auditRecords = await prisma.audit.findMany({
where: { where: {
@ -36,7 +36,7 @@ export const fetchAuditRecords = async (domain: string, apiKey: string | undefin
await auditService.createAudit({ await auditService.createAudit({
action: "audit.fetch", action: "audit.fetch",
actor: { actor: {
id: userId, id: user.id,
type: "user" type: "user"
}, },
target: { target: {
@ -55,5 +55,6 @@ export const fetchAuditRecords = async (domain: string, apiKey: string | undefin
message: "Failed to fetch audit logs", message: "Failed to fetch audit logs",
} satisfies ServiceError; } satisfies ServiceError;
} }
}, /* minRequiredRole = */ OrgRole.OWNER), /* allowAnonymousAccess = */ true, apiKey ? { apiKey, domain } : undefined) })
)
); );

View file

@ -1,9 +1,9 @@
'use server'; 'use server';
import { getMe, sew, withAuth } from "@/actions"; import { sew } from "@/sew";
import { ServiceError, stripeClientNotInitialized, notFound } from "@/lib/serviceError"; import { ServiceError, stripeClientNotInitialized, notFound } from "@/lib/serviceError";
import { withOrgMembership } from "@/actions"; import { withAuthV2 } from "@/withAuthV2";
import { prisma } from "@/prisma"; import { withMinimumOrgRole } from "@/withMinimumOrgRole";
import { OrgRole } from "@sourcebot/db"; import { OrgRole } from "@sourcebot/db";
import { stripeClient } from "./stripe"; import { stripeClient } from "./stripe";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
@ -17,13 +17,8 @@ import { createLogger } from "@sourcebot/logger";
const logger = createLogger('billing-actions'); const logger = createLogger('billing-actions');
export const createOnboardingSubscription = async (domain: string) => sew(() => export const createOnboardingSubscription = async (domain: string) => sew(() =>
withAuth(async (userId) => withAuthV2(async ({ user, org, prisma, role }) =>
withOrgMembership(userId, domain, async ({ org }) => { withMinimumOrgRole(role, OrgRole.OWNER, async () => {
const user = await getMe();
if (isServiceError(user)) {
return user;
}
if (!stripeClient) { if (!stripeClient) {
return stripeClientNotInitialized(); return stripeClientNotInitialized();
} }
@ -108,68 +103,65 @@ export const createOnboardingSubscription = async (domain: string) => sew(() =>
message: "Failed to create subscription", message: "Failed to create subscription",
} satisfies ServiceError; } satisfies ServiceError;
} }
}, /* minRequiredRole = */ OrgRole.OWNER) })));
));
export const createStripeCheckoutSession = async (domain: string) => sew(() => export const createStripeCheckoutSession = async (domain: string) => sew(() =>
withAuth((userId) => withAuthV2(async ({ org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { if (!org.stripeCustomerId) {
if (!org.stripeCustomerId) { return notFound();
return notFound(); }
}
if (!stripeClient) { if (!stripeClient) {
return stripeClientNotInitialized(); return stripeClientNotInitialized();
} }
const orgMembers = await prisma.userToOrg.findMany({ const orgMembers = await prisma.userToOrg.findMany({
where: { where: {
orgId: org.id, orgId: org.id,
}, },
select: { select: {
userId: true, userId: true,
}
});
const numOrgMembers = orgMembers.length;
const origin = (await headers()).get('origin')!;
const prices = await stripeClient.prices.list({
product: env.STRIPE_PRODUCT_ID,
expand: ['data.product'],
});
const stripeSession = await stripeClient.checkout.sessions.create({
customer: org.stripeCustomerId as string,
payment_method_types: ['card'],
line_items: [
{
price: prices.data[0].id,
quantity: numOrgMembers
} }
}); ],
const numOrgMembers = orgMembers.length; mode: 'subscription',
payment_method_collection: 'always',
const origin = (await headers()).get('origin')!; success_url: `${origin}/${domain}/settings/billing`,
const prices = await stripeClient.prices.list({ cancel_url: `${origin}/${domain}`,
product: env.STRIPE_PRODUCT_ID, });
expand: ['data.product'],
});
const stripeSession = await stripeClient.checkout.sessions.create({
customer: org.stripeCustomerId as string,
payment_method_types: ['card'],
line_items: [
{
price: prices.data[0].id,
quantity: numOrgMembers
}
],
mode: 'subscription',
payment_method_collection: 'always',
success_url: `${origin}/${domain}/settings/billing`,
cancel_url: `${origin}/${domain}`,
});
if (!stripeSession.url) {
return {
statusCode: StatusCodes.INTERNAL_SERVER_ERROR,
errorCode: ErrorCode.STRIPE_CHECKOUT_ERROR,
message: "Failed to create checkout session",
} satisfies ServiceError;
}
if (!stripeSession.url) {
return { return {
url: stripeSession.url, statusCode: StatusCodes.INTERNAL_SERVER_ERROR,
} errorCode: ErrorCode.STRIPE_CHECKOUT_ERROR,
}) message: "Failed to create checkout session",
)); } satisfies ServiceError;
}
return {
url: stripeSession.url,
}
}));
export const getCustomerPortalSessionLink = async (domain: string): Promise<string | ServiceError> => sew(() => export const getCustomerPortalSessionLink = async (domain: string): Promise<string | ServiceError> => sew(() =>
withAuth((userId) => withAuthV2(async ({ org, role }) =>
withOrgMembership(userId, domain, async ({ org }) => { withMinimumOrgRole(role, OrgRole.OWNER, async () => {
if (!org.stripeCustomerId) { if (!org.stripeCustomerId) {
return notFound(); return notFound();
} }
@ -185,31 +177,28 @@ export const getCustomerPortalSessionLink = async (domain: string): Promise<stri
}); });
return portalSession.url; return portalSession.url;
}, /* minRequiredRole = */ OrgRole.OWNER) })));
));
export const getSubscriptionBillingEmail = async (domain: string): Promise<string | ServiceError> => sew(() => export const getSubscriptionBillingEmail = async (_domain: string): Promise<string | ServiceError> => sew(() =>
withAuth(async (userId) => withAuthV2(async ({ org }) => {
withOrgMembership(userId, domain, async ({ org }) => { if (!org.stripeCustomerId) {
if (!org.stripeCustomerId) { return notFound();
return notFound(); }
}
if (!stripeClient) { if (!stripeClient) {
return stripeClientNotInitialized(); return stripeClientNotInitialized();
} }
const customer = await stripeClient.customers.retrieve(org.stripeCustomerId); const customer = await stripeClient.customers.retrieve(org.stripeCustomerId);
if (!('email' in customer) || customer.deleted) { if (!('email' in customer) || customer.deleted) {
return notFound(); return notFound();
} }
return customer.email!; return customer.email!;
}) }));
));
export const changeSubscriptionBillingEmail = async (domain: string, newEmail: string): Promise<{ success: boolean } | ServiceError> => sew(() => export const changeSubscriptionBillingEmail = async (domain: string, newEmail: string): Promise<{ success: boolean } | ServiceError> => sew(() =>
withAuth((userId) => withAuthV2(async ({ org, role }) =>
withOrgMembership(userId, domain, async ({ org }) => { withMinimumOrgRole(role, OrgRole.OWNER, async () => {
if (!org.stripeCustomerId) { if (!org.stripeCustomerId) {
return notFound(); return notFound();
} }
@ -225,24 +214,21 @@ export const changeSubscriptionBillingEmail = async (domain: string, newEmail: s
return { return {
success: true, success: true,
} }
}, /* minRequiredRole = */ OrgRole.OWNER) })));
));
export const getSubscriptionInfo = async (domain: string) => sew(() => export const getSubscriptionInfo = async (_domain: string) => sew(() =>
withAuth(async (userId) => withAuthV2(async ({ org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const subscription = await getSubscriptionForOrg(org.id, prisma);
const subscription = await getSubscriptionForOrg(org.id, prisma);
if (isServiceError(subscription)) { if (isServiceError(subscription)) {
return subscription; return subscription;
} }
return { return {
status: subscription.status, status: subscription.status,
plan: "Team", plan: "Team",
seats: subscription.items.data[0].quantity!, seats: subscription.items.data[0].quantity!,
perSeatPrice: subscription.items.data[0].price.unit_amount! / 100, perSeatPrice: subscription.items.data[0].price.unit_amount! / 100,
nextBillingDate: subscription.current_period_end!, nextBillingDate: subscription.current_period_end!,
} }
}) }));
));

View file

@ -1,6 +1,7 @@
'use server'; 'use server';
import { sew, withAuth, withOrgMembership } from "@/actions"; import { sew } from "@/sew";
import { withAuthV2, withOptionalAuthV2 } from "@/withAuthV2";
import { env } from "@/env.mjs"; import { env } from "@/env.mjs";
import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants"; import { SOURCEBOT_GUEST_USER_ID } from "@/lib/constants";
import { ErrorCode } from "@/lib/errorCodes"; import { ErrorCode } from "@/lib/errorCodes";
@ -32,181 +33,177 @@ import path from 'path';
import { LanguageModelInfo, SBChatMessage } from "./types"; import { LanguageModelInfo, SBChatMessage } from "./types";
export const createChat = async (domain: string) => sew(() => export const createChat = async (domain: string) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ user, org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const userId = user?.id ?? SOURCEBOT_GUEST_USER_ID;
const isGuestUser = userId === SOURCEBOT_GUEST_USER_ID;
const isGuestUser = userId === SOURCEBOT_GUEST_USER_ID; const chat = await prisma.chat.create({
data: {
orgId: org.id,
messages: [] as unknown as Prisma.InputJsonValue,
createdById: userId,
visibility: isGuestUser ? ChatVisibility.PUBLIC : ChatVisibility.PRIVATE,
},
});
const chat = await prisma.chat.create({ return {
data: { id: chat.id,
orgId: org.id, }
messages: [] as unknown as Prisma.InputJsonValue, })
createdById: userId,
visibility: isGuestUser ? ChatVisibility.PUBLIC : ChatVisibility.PRIVATE,
},
});
return {
id: chat.id,
}
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
); );
export const getChatInfo = async ({ chatId }: { chatId: string }, domain: string) => sew(() => export const getChatInfo = async ({ chatId }: { chatId: string }, domain: string) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ user, org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const userId = user?.id ?? SOURCEBOT_GUEST_USER_ID;
const chat = await prisma.chat.findUnique({ const chat = await prisma.chat.findUnique({
where: { where: {
id: chatId, id: chatId,
orgId: org.id, orgId: org.id,
}, },
}); });
if (!chat) { if (!chat) {
return notFound(); return notFound();
} }
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) {
return notFound(); return notFound();
} }
return { return {
messages: chat.messages as unknown as SBChatMessage[], messages: chat.messages as unknown as SBChatMessage[],
visibility: chat.visibility, visibility: chat.visibility,
name: chat.name, name: chat.name,
isReadonly: chat.isReadonly, isReadonly: chat.isReadonly,
}; };
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) })
); );
export const updateChatMessages = async ({ chatId, messages }: { chatId: string, messages: SBChatMessage[] }, domain: string) => sew(() => export const updateChatMessages = async ({ chatId, messages }: { chatId: string, messages: SBChatMessage[] }, domain: string) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ user, org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const userId = user?.id ?? SOURCEBOT_GUEST_USER_ID;
const chat = await prisma.chat.findUnique({ const chat = await prisma.chat.findUnique({
where: { where: {
id: chatId, id: chatId,
orgId: org.id, orgId: org.id,
}, },
}); });
if (!chat) { if (!chat) {
return notFound(); return notFound();
}
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) {
return notFound();
}
if (chat.isReadonly) {
return chatIsReadonly();
}
await prisma.chat.update({
where: {
id: chatId,
},
data: {
messages: messages as unknown as Prisma.InputJsonValue,
},
});
if (env.DEBUG_WRITE_CHAT_MESSAGES_TO_FILE) {
const chatDir = path.join(env.DATA_CACHE_DIR, 'chats');
if (!fs.existsSync(chatDir)) {
fs.mkdirSync(chatDir, { recursive: true });
} }
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { const chatFile = path.join(chatDir, `${chatId}.json`);
return notFound(); fs.writeFileSync(chatFile, JSON.stringify(messages, null, 2));
} }
if (chat.isReadonly) { return {
return chatIsReadonly(); success: true,
} }
})
await prisma.chat.update({
where: {
id: chatId,
},
data: {
messages: messages as unknown as Prisma.InputJsonValue,
},
});
if (env.DEBUG_WRITE_CHAT_MESSAGES_TO_FILE) {
const chatDir = path.join(env.DATA_CACHE_DIR, 'chats');
if (!fs.existsSync(chatDir)) {
fs.mkdirSync(chatDir, { recursive: true });
}
const chatFile = path.join(chatDir, `${chatId}.json`);
fs.writeFileSync(chatFile, JSON.stringify(messages, null, 2));
}
return {
success: true,
}
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
); );
export const getUserChatHistory = async (domain: string) => sew(() => export const getUserChatHistory = async (domain: string) => sew(() =>
withAuth((userId) => withAuthV2(async ({ user, org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const chats = await prisma.chat.findMany({
const chats = await prisma.chat.findMany({ where: {
where: { orgId: org.id,
orgId: org.id, createdById: user.id,
createdById: userId, },
}, orderBy: {
orderBy: { updatedAt: 'desc',
updatedAt: 'desc', },
}, });
});
return chats.map((chat) => ({ return chats.map((chat) => ({
id: chat.id, id: chat.id,
createdAt: chat.createdAt, createdAt: chat.createdAt,
name: chat.name, name: chat.name,
visibility: chat.visibility, visibility: chat.visibility,
})) }))
}) })
)
); );
export const updateChatName = async ({ chatId, name }: { chatId: string, name: string }, domain: string) => sew(() => export const updateChatName = async ({ chatId, name }: { chatId: string, name: string }, domain: string) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ user, org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const userId = user?.id ?? SOURCEBOT_GUEST_USER_ID;
const chat = await prisma.chat.findUnique({ const chat = await prisma.chat.findUnique({
where: { where: {
id: chatId, id: chatId,
orgId: org.id, orgId: org.id,
}, },
}); });
if (!chat) { if (!chat) {
return notFound(); return notFound();
} }
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) {
return notFound(); return notFound();
} }
if (chat.isReadonly) { if (chat.isReadonly) {
return chatIsReadonly(); return chatIsReadonly();
} }
await prisma.chat.update({ await prisma.chat.update({
where: { where: {
id: chatId, id: chatId,
orgId: org.id, orgId: org.id,
}, },
data: { data: {
name, name,
}, },
}); });
return { return {
success: true, success: true,
} }
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true) })
); );
export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }, domain: string) => sew(() => export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }, domain: string) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ org }) => {
withOrgMembership(userId, domain, async ({ org }) => { // From the language model ID, attempt to find the
// From the language model ID, attempt to find the // corresponding config in `config.json`.
// corresponding config in `config.json`. const languageModelConfig =
const languageModelConfig = (await _getConfiguredLanguageModelsFull())
(await _getConfiguredLanguageModelsFull()) .find((model) => model.model === languageModelId);
.find((model) => model.model === languageModelId);
if (!languageModelConfig) { if (!languageModelConfig) {
return serviceErrorResponse({ return serviceErrorResponse({
statusCode: StatusCodes.BAD_REQUEST, statusCode: StatusCodes.BAD_REQUEST,
errorCode: ErrorCode.INVALID_REQUEST_BODY, errorCode: ErrorCode.INVALID_REQUEST_BODY,
message: `Language model ${languageModelId} is not configured.`, message: `Language model ${languageModelId} is not configured.`,
}); });
} }
const { model } = await _getAISDKLanguageModelAndOptions(languageModelConfig, org.id); const { model } = await _getAISDKLanguageModelAndOptions(languageModelConfig, org.id);
const prompt = `Convert this question into a short topic title (max 50 characters). const prompt = `Convert this question into a short topic title (max 50 characters).
Rules: Rules:
- Do NOT include question words (what, where, how, why, when, which) - Do NOT include question words (what, where, how, why, when, which)
@ -222,63 +219,61 @@ Examples:
User question: ${message}`; User question: ${message}`;
const result = await generateText({ const result = await generateText({
model, model,
prompt, prompt,
}); });
await updateChatName({ await updateChatName({
chatId, chatId,
name: result.text, name: result.text,
}, domain); }, domain);
return { return {
success: true, success: true,
} }
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true })
)
); );
export const deleteChat = async ({ chatId }: { chatId: string }, domain: string) => sew(() => export const deleteChat = async ({ chatId }: { chatId: string }, domain: string) => sew(() =>
withAuth((userId) => withAuthV2(async ({ user, org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const userId = user.id;
const chat = await prisma.chat.findUnique({ const chat = await prisma.chat.findUnique({
where: { where: {
id: chatId, id: chatId,
orgId: org.id, orgId: org.id,
}, },
}); });
if (!chat) { if (!chat) {
return notFound(); return notFound();
} }
// Public chats cannot be deleted.
if (chat.visibility === ChatVisibility.PUBLIC) {
return {
statusCode: StatusCodes.FORBIDDEN,
errorCode: ErrorCode.UNEXPECTED_ERROR,
message: 'You are not allowed to delete this chat.',
} satisfies ServiceError;
}
// Only the creator of a chat can delete it.
if (chat.createdById !== userId) {
return notFound();
}
await prisma.chat.delete({
where: {
id: chatId,
orgId: org.id,
},
});
// Public chats cannot be deleted.
if (chat.visibility === ChatVisibility.PUBLIC) {
return { return {
success: true, statusCode: StatusCodes.FORBIDDEN,
} errorCode: ErrorCode.UNEXPECTED_ERROR,
}) message: 'You are not allowed to delete this chat.',
) } satisfies ServiceError;
}
// Only the creator of a chat can delete it.
if (chat.createdById !== userId) {
return notFound();
}
await prisma.chat.delete({
where: {
id: chatId,
orgId: org.id,
},
});
return {
success: true,
}
})
); );
export const submitFeedback = async ({ export const submitFeedback = async ({
@ -290,54 +285,54 @@ export const submitFeedback = async ({
messageId: string, messageId: string,
feedbackType: 'like' | 'dislike' feedbackType: 'like' | 'dislike'
}, domain: string) => sew(() => }, domain: string) => sew(() =>
withAuth((userId) => withOptionalAuthV2(async ({ user, org, prisma }) => {
withOrgMembership(userId, domain, async ({ org }) => { const userId = user?.id ?? SOURCEBOT_GUEST_USER_ID;
const chat = await prisma.chat.findUnique({ const chat = await prisma.chat.findUnique({
where: { where: {
id: chatId, id: chatId,
orgId: org.id, orgId: org.id,
}, },
}); });
if (!chat) { if (!chat) {
return notFound(); return notFound();
}
// When a chat is private, only the creator can submit feedback.
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) {
return notFound();
}
const messages = chat.messages as unknown as SBChatMessage[];
const updatedMessages = messages.map(message => {
if (message.id === messageId && message.role === 'assistant') {
return {
...message,
metadata: {
...message.metadata,
feedback: [
...(message.metadata?.feedback ?? []),
{
type: feedbackType,
timestamp: new Date().toISOString(),
userId: userId,
}
]
}
} satisfies SBChatMessage;
} }
return message;
});
// When a chat is private, only the creator can submit feedback. await prisma.chat.update({
if (chat.visibility === ChatVisibility.PRIVATE && chat.createdById !== userId) { where: { id: chatId },
return notFound(); data: {
} messages: updatedMessages as unknown as Prisma.InputJsonValue,
},
});
const messages = chat.messages as unknown as SBChatMessage[]; return { success: true };
const updatedMessages = messages.map(message => { })
if (message.id === messageId && message.role === 'assistant') {
return {
...message,
metadata: {
...message.metadata,
feedback: [
...(message.metadata?.feedback ?? []),
{
type: feedbackType,
timestamp: new Date().toISOString(),
userId: userId,
}
]
}
} satisfies SBChatMessage;
}
return message;
});
await prisma.chat.update({
where: { id: chatId },
data: {
messages: updatedMessages as unknown as Prisma.InputJsonValue,
},
});
return { success: true };
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowSingleTenantUnauthedAccess = */ true)
); );
/** /**

View file

@ -1,13 +1,13 @@
'use server'; 'use server';
import { sew, withAuth, withOrgMembership } from "@/actions"; import { sew } from "@/sew";
import { withOptionalAuthV2 } from "@/withAuthV2";
import { searchResponseSchema } from "@/features/search/schemas"; import { searchResponseSchema } from "@/features/search/schemas";
import { search } from "@/features/search/searchApi"; import { search } from "@/features/search/searchApi";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
import { FindRelatedSymbolsResponse } from "./types"; import { FindRelatedSymbolsResponse } from "./types";
import { ServiceError } from "@/lib/serviceError"; import { ServiceError } from "@/lib/serviceError";
import { SearchResponse } from "../search/types"; import { SearchResponse } from "../search/types";
import { OrgRole } from "@sourcebot/db";
// The maximum number of matches to return from the search API. // The maximum number of matches to return from the search API.
const MAX_REFERENCE_COUNT = 1000; const MAX_REFERENCE_COUNT = 1000;
@ -20,28 +20,27 @@ export const findSearchBasedSymbolReferences = async (
}, },
domain: string, domain: string,
): Promise<FindRelatedSymbolsResponse | ServiceError> => sew(() => ): Promise<FindRelatedSymbolsResponse | ServiceError> => sew(() =>
withAuth((session) => withOptionalAuthV2(async () => {
withOrgMembership(session, domain, async () => { const {
const { symbolName,
symbolName, language,
language, revisionName = "HEAD",
revisionName = "HEAD", } = props;
} = props;
const query = `\\b${symbolName}\\b rev:${revisionName} ${getExpandedLanguageFilter(language)} case:yes`; const query = `\\b${symbolName}\\b rev:${revisionName} ${getExpandedLanguageFilter(language)} case:yes`;
const searchResult = await search({ const searchResult = await search({
query, query,
matches: MAX_REFERENCE_COUNT, matches: MAX_REFERENCE_COUNT,
contextLines: 0, contextLines: 0,
}); });
if (isServiceError(searchResult)) { if (isServiceError(searchResult)) {
return searchResult; return searchResult;
} }
return parseRelatedSymbolsSearchResponse(searchResult); return parseRelatedSymbolsSearchResponse(searchResult);
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowAnonymousAccess = */ true) })
); );
@ -53,28 +52,27 @@ export const findSearchBasedSymbolDefinitions = async (
}, },
domain: string, domain: string,
): Promise<FindRelatedSymbolsResponse | ServiceError> => sew(() => ): Promise<FindRelatedSymbolsResponse | ServiceError> => sew(() =>
withAuth((session) => withOptionalAuthV2(async () => {
withOrgMembership(session, domain, async () => { const {
const { symbolName,
symbolName, language,
language, revisionName = "HEAD",
revisionName = "HEAD", } = props;
} = props;
const query = `sym:\\b${symbolName}\\b rev:${revisionName} ${getExpandedLanguageFilter(language)}`; const query = `sym:\\b${symbolName}\\b rev:${revisionName} ${getExpandedLanguageFilter(language)}`;
const searchResult = await search({ const searchResult = await search({
query, query,
matches: MAX_REFERENCE_COUNT, matches: MAX_REFERENCE_COUNT,
contextLines: 0, contextLines: 0,
}); });
if (isServiceError(searchResult)) { if (isServiceError(searchResult)) {
return searchResult; return searchResult;
} }
return parseRelatedSymbolsSearchResponse(searchResult); return parseRelatedSymbolsSearchResponse(searchResult);
}, /* minRequiredRole = */ OrgRole.GUEST), /* allowAnonymousAccess = */ true) })
); );
const parseRelatedSymbolsSearchResponse = (searchResult: SearchResponse) => { const parseRelatedSymbolsSearchResponse = (searchResult: SearchResponse) => {

View file

@ -1,6 +1,6 @@
'use server'; 'use server';
import { sew } from '@/actions'; import { sew } from "@/sew";
import { env } from '@/env.mjs'; import { env } from '@/env.mjs';
import { notFound, unexpectedError } from '@/lib/serviceError'; import { notFound, unexpectedError } from '@/lib/serviceError';
import { withOptionalAuthV2 } from '@/withAuthV2'; import { withOptionalAuthV2 } from '@/withAuthV2';

View file

@ -5,7 +5,7 @@ import { fileNotFound, ServiceError, unexpectedError } from "../../lib/serviceEr
import { FileSourceRequest, FileSourceResponse } from "./types"; import { FileSourceRequest, FileSourceResponse } from "./types";
import { isServiceError } from "../../lib/utils"; import { isServiceError } from "../../lib/utils";
import { search } from "./searchApi"; import { search } from "./searchApi";
import { sew } from "@/actions"; import { sew } from "@/sew";
import { withOptionalAuthV2 } from "@/withAuthV2"; import { withOptionalAuthV2 } from "@/withAuthV2";
// @todo (bkellam) : We should really be using `git show <hash>:<path>` to fetch file contents here. // @todo (bkellam) : We should really be using `git show <hash>:<path>` to fetch file contents here.
// This will allow us to support permalinks to files at a specific revision that may not be indexed // This will allow us to support permalinks to files at a specific revision that may not be indexed

View file

@ -9,7 +9,7 @@ import { StatusCodes } from "http-status-codes";
import { zoektSearchResponseSchema } from "./zoektSchema"; import { zoektSearchResponseSchema } from "./zoektSchema";
import { SearchRequest, SearchResponse, SourceRange } from "./types"; import { SearchRequest, SearchResponse, SourceRange } from "./types";
import { PrismaClient, Repo } from "@sourcebot/db"; import { PrismaClient, Repo } from "@sourcebot/db";
import { sew } from "@/actions"; import { sew } from "@/sew";
import { base64Decode } from "@sourcebot/shared"; import { base64Decode } from "@sourcebot/shared";
import { withOptionalAuthV2 } from "@/withAuthV2"; import { withOptionalAuthV2 } from "@/withAuthV2";

28
packages/web/src/sew.ts Normal file
View file

@ -0,0 +1,28 @@
'use server';
import * as Sentry from "@sentry/nextjs";
import { ServiceError, unexpectedError } from "./lib/serviceError";
import { createLogger } from "@sourcebot/logger";
const logger = createLogger('service-error-wrapper');
/**
* "Service Error Wrapper".
*
* Captures any thrown exceptions and converts them to a unexpected
* service error. Also logs them with Sentry.
*/
export const sew = async <T>(fn: () => Promise<T>): Promise<T | ServiceError> => {
try {
return await fn();
} catch (e) {
Sentry.captureException(e);
logger.error(e);
if (e instanceof Error) {
return unexpectedError(e.message);
}
return unexpectedError(`An unexpected error occurred. Please try again later.`);
}
};

View file

@ -5,8 +5,6 @@ import { headers } from "next/headers";
import { auth } from "./auth"; import { auth } from "./auth";
import { notAuthenticated, notFound, ServiceError } from "./lib/serviceError"; import { notAuthenticated, notFound, ServiceError } from "./lib/serviceError";
import { SINGLE_TENANT_ORG_ID } from "./lib/constants"; import { SINGLE_TENANT_ORG_ID } from "./lib/constants";
import { StatusCodes } from "http-status-codes";
import { ErrorCode } from "./lib/errorCodes";
import { getOrgMetadata, isServiceError } from "./lib/utils"; import { getOrgMetadata, isServiceError } from "./lib/utils";
import { hasEntitlement } from "@sourcebot/shared"; import { hasEntitlement } from "@sourcebot/shared";