wip refresh oauth tokens

This commit is contained in:
msukkari 2025-11-03 18:27:13 -08:00
parent 6cc9d0b267
commit e6498531aa
10 changed files with 143 additions and 9 deletions

View file

@ -24,7 +24,7 @@ import { LogoutEscapeHatch } from "@/app/components/logoutEscapeHatch";
import { GitHubStarToast } from "./components/githubStarToast"; import { GitHubStarToast } from "./components/githubStarToast";
import { UpgradeToast } from "./components/upgradeToast"; import { UpgradeToast } from "./components/upgradeToast";
import { getIntegrationProviderStates } from "@/ee/features/permissionSyncing/actions"; import { getIntegrationProviderStates } from "@/ee/features/permissionSyncing/actions";
import { LinkAccounts } from "@/ee/features/permissionSyncing/linkAccounts"; import { LinkAccounts } from "@/ee/features/permissionSyncing/components/linkAccounts";
interface LayoutProps { interface LayoutProps {
children: React.ReactNode, children: React.ReactNode,

View file

@ -1,6 +1,6 @@
import { hasEntitlement } from "@sourcebot/shared"; import { hasEntitlement } from "@sourcebot/shared";
import { notFound } from "@/lib/serviceError"; import { notFound } from "@/lib/serviceError";
import { LinkedAccountsSettings } from "@/ee/features/permissionSyncing/linkedAccountsSettings"; import { LinkedAccountsSettings } from "@/ee/features/permissionSyncing/components/linkedAccountsSettings";
export default async function PermissionSyncingPage() { export default async function PermissionSyncingPage() {
const hasPermissionSyncingEntitlement = await hasEntitlement("permission-syncing"); const hasPermissionSyncingEntitlement = await hasEntitlement("permission-syncing");

View file

@ -18,6 +18,7 @@ import { hasEntitlement } from '@sourcebot/shared';
import { onCreateUser } from '@/lib/authUtils'; import { onCreateUser } from '@/lib/authUtils';
import { getAuditService } from '@/ee/features/audit/factory'; import { getAuditService } from '@/ee/features/audit/factory';
import { SINGLE_TENANT_ORG_ID } from './lib/constants'; import { SINGLE_TENANT_ORG_ID } from './lib/constants';
import { refreshOAuthToken } from '@/ee/features/permissionSyncing/actions';
const auditService = getAuditService(); const auditService = getAuditService();
const eeIdentityProviders = hasEntitlement("sso") ? await getEEIdentityProviders() : []; const eeIdentityProviders = hasEntitlement("sso") ? await getEEIdentityProviders() : [];
@ -40,7 +41,12 @@ declare module 'next-auth' {
declare module 'next-auth/jwt' { declare module 'next-auth/jwt' {
interface JWT { interface JWT {
userId: string userId: string;
accessToken?: string;
refreshToken?: string;
expiresAt?: number;
provider?: string;
error?: string;
} }
} }
@ -179,13 +185,51 @@ export const { handlers, signIn, signOut, auth } = NextAuth({
} }
}, },
callbacks: { callbacks: {
async jwt({ token, user: _user }) { async jwt({ token, user: _user, account }) {
const user = _user as User | undefined; const user = _user as User | undefined;
// @note: `user` will be available on signUp or signIn triggers. // @note: `user` will be available on signUp or signIn triggers.
// Cache the userId in the JWT for later use. // Cache the userId in the JWT for later use.
if (user) { if (user) {
token.userId = user.id; token.userId = user.id;
} }
if (account) {
token.accessToken = account.access_token;
token.refreshToken = account.refresh_token;
token.expiresAt = account.expires_at;
token.provider = account.provider;
}
if (hasEntitlement('permission-syncing') &&
token.provider &&
['github', 'gitlab'].includes(token.provider) &&
token.expiresAt &&
token.refreshToken) {
const now = Math.floor(Date.now() / 1000);
const bufferTimeS = 5 * 60;
if (now >= (token.expiresAt - bufferTimeS)) {
try {
const refreshedTokens = await refreshOAuthToken(
token.provider,
token.refreshToken,
token.userId
);
if (refreshedTokens) {
token.accessToken = refreshedTokens.accessToken;
token.refreshToken = refreshedTokens.refreshToken ?? token.refreshToken;
token.expiresAt = refreshedTokens.expiresAt;
} else {
token.error = 'RefreshTokenError';
}
} catch (error) {
console.error('Error refreshing token:', error);
token.error = 'RefreshTokenError';
}
}
}
return token; return token;
}, },
async session({ session, token }) { async session({ session, token }) {

View file

@ -8,7 +8,9 @@ import { env } from "@/env.mjs";
import { OrgRole } from "@sourcebot/db"; import { OrgRole } from "@sourcebot/db";
import { cookies } from "next/headers"; import { cookies } from "next/headers";
import { OPTIONAL_PROVIDERS_LINK_SKIPPED_COOKIE_NAME } from "@/lib/constants"; import { OPTIONAL_PROVIDERS_LINK_SKIPPED_COOKIE_NAME } from "@/lib/constants";
import { getTokenFromConfig } from '@sourcebot/crypto';
import { IntegrationIdentityProviderState } from "@/ee/features/permissionSyncing/types"; import { IntegrationIdentityProviderState } from "@/ee/features/permissionSyncing/types";
import { GitHubIdentityProviderConfig, GitLabIdentityProviderConfig } from "@sourcebot/schemas/v3/index.type";
const logger = createLogger('web-ee-permission-syncing-actions'); const logger = createLogger('web-ee-permission-syncing-actions');
@ -95,4 +97,92 @@ export const skipOptionalProvidersLink = async () => sew(async () => {
maxAge: 365 * 24 * 60 * 60, // 1 year in seconds maxAge: 365 * 24 * 60 * 60, // 1 year in seconds
}); });
return true; return true;
}); });
export const refreshOAuthToken = async (
provider: string,
refreshToken: string,
userId: string
): Promise<{ accessToken: string; refreshToken: string | null; expiresAt: number } | null> => {
try {
// Load config and find the provider configuration
const config = await loadConfig(env.CONFIG_PATH);
const identityProviders = config?.identityProviders ?? [];
const providerConfig = identityProviders.find(
idp => idp.provider === provider
) as GitHubIdentityProviderConfig | GitLabIdentityProviderConfig;
if (!providerConfig || !('clientId' in providerConfig) || !('clientSecret' in providerConfig)) {
logger.error(`Provider config not found or invalid for: ${provider}`);
return null;
}
// Get client credentials from config
const clientId = await getTokenFromConfig(providerConfig.clientId);
const clientSecret = await getTokenFromConfig(providerConfig.clientSecret);
const baseUrl = 'baseUrl' in providerConfig && providerConfig.baseUrl
? await getTokenFromConfig(providerConfig.baseUrl)
: undefined;
let url: string;
if (baseUrl) {
url = provider === 'github'
? `${baseUrl}/login/oauth/access_token`
: `${baseUrl}/oauth/token`;
} else if (provider === 'github') {
url = 'https://github.com/login/oauth/access_token';
} else if (provider === 'gitlab') {
url = 'https://gitlab.com/oauth/token';
} else {
logger.error(`Unsupported provider for token refresh: ${provider}`);
return null;
}
const response = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json',
},
body: new URLSearchParams({
client_id: clientId,
client_secret: clientSecret,
grant_type: 'refresh_token',
refresh_token: refreshToken,
}),
});
if (!response.ok) {
const errorText = await response.text();
logger.error(`Failed to refresh ${provider} token: ${response.status} ${errorText}`);
return null;
}
const data = await response.json();
const result = {
accessToken: data.access_token,
refreshToken: data.refresh_token ?? null,
expiresAt: data.expires_in ? Math.floor(Date.now() / 1000) + data.expires_in : 0,
};
const { prisma } = await import('@/prisma');
await prisma.account.updateMany({
where: {
userId: userId,
provider: provider,
},
data: {
access_token: result.accessToken,
refresh_token: result.refreshToken,
expires_at: result.expiresAt,
},
});
return result;
} catch (error) {
logger.error(`Error refreshing ${provider} token:`, error);
return null;
}
};

View file

@ -1,8 +1,8 @@
import { getAuthProviderInfo } from "@/lib/utils"; import { getAuthProviderInfo } from "@/lib/utils";
import { Check, X } from "lucide-react"; import { Check, X } from "lucide-react";
import { Card, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Card, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { ProviderIcon } from "./components/providerIcon"; import { ProviderIcon } from "./providerIcon";
import { ProviderInfo } from "./components/providerInfo"; import { ProviderInfo } from "./providerInfo";
import { UnlinkButton } from "./unlinkButton"; import { UnlinkButton } from "./unlinkButton";
import { LinkButton } from "./linkButton"; import { LinkButton } from "./linkButton";
import { IntegrationIdentityProviderState } from "@/ee/features/permissionSyncing/types" import { IntegrationIdentityProviderState } from "@/ee/features/permissionSyncing/types"

View file

@ -3,7 +3,7 @@
import { useState } from "react"; import { useState } from "react";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Unlink, Loader2 } from "lucide-react"; import { Unlink, Loader2 } from "lucide-react";
import { unlinkIntegrationProvider } from "./actions"; import { unlinkIntegrationProvider } from "../actions";
import { isServiceError } from "@/lib/utils"; import { isServiceError } from "@/lib/utils";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useToast } from "@/components/hooks/use-toast"; import { useToast } from "@/components/hooks/use-toast";

View file

@ -189,7 +189,7 @@ export const updateChatName = async ({ chatId, name }: { chatId: string, name: s
export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }, domain: string) => sew(() => export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }, domain: string) => sew(() =>
withAuth((userId) => withAuth((userId) =>
withOrgMembership(userId, domain, async ({ org }) => { withOrgMembership(userId, domain, async () => {
// From the language model ID, attempt to find the // From the language model ID, attempt to find the
// corresponding config in `config.json`. // corresponding config in `config.json`.
const languageModelConfig = const languageModelConfig =