diff --git a/packages/web/src/app/api/(server)/search/route.ts b/packages/web/src/app/api/(server)/search/route.ts index 027096f0..b78a2991 100644 --- a/packages/web/src/app/api/(server)/search/route.ts +++ b/packages/web/src/app/api/(server)/search/route.ts @@ -14,8 +14,18 @@ export const POST = async (request: NextRequest) => { schemaValidationError(parsed.error) ); } + + const { + query, + ...options + } = parsed.data; - const response = await search(parsed.data); + const response = await search({ + queryType: 'string', + query, + options, + }); + if (isServiceError(response)) { return serviceErrorResponse(response); } diff --git a/packages/web/src/app/api/(server)/stream_search/route.ts b/packages/web/src/app/api/(server)/stream_search/route.ts index 2e09b078..409079a7 100644 --- a/packages/web/src/app/api/(server)/stream_search/route.ts +++ b/packages/web/src/app/api/(server)/stream_search/route.ts @@ -14,7 +14,17 @@ export const POST = async (request: NextRequest) => { return serviceErrorResponse(schemaValidationError(parsed.error)); } - const stream = await streamSearch(parsed.data); + const { + query, + ...options + } = parsed.data; + + const stream = await streamSearch({ + queryType: 'string', + query, + options, + }); + if (isServiceError(stream)) { return serviceErrorResponse(stream); } diff --git a/packages/web/src/features/chat/tools.ts b/packages/web/src/features/chat/tools.ts index 94e33e73..60ed8714 100644 --- a/packages/web/src/features/chat/tools.ts +++ b/packages/web/src/features/chat/tools.ts @@ -178,12 +178,15 @@ Multiple expressions can be or'd together with or, negated with -, or grouped wi }); const response = await search({ + queryType: 'string', query, - matches: limit ?? 100, - contextLines: 3, - whole: false, - isCaseSensitivityEnabled: true, - isRegexEnabled: true, + options: { + matches: limit ?? 
100, + contextLines: 3, + whole: false, + isCaseSensitivityEnabled: true, + isRegexEnabled: true, + } }); if (isServiceError(response)) { @@ -219,11 +222,11 @@ export const searchReposTool = tool({ }), execute: async ({ query, limit }) => { const reposResponse = await getRepos(); - + if (isServiceError(reposResponse)) { return reposResponse; } - + // Configure Fuse.js for fuzzy searching const fuse = new Fuse(reposResponse, { keys: [ @@ -234,7 +237,7 @@ export const searchReposTool = tool({ includeScore: true, minMatchCharLength: 1, }); - + const searchResults = fuse.search(query, { limit: limit ?? 10 }); searchResults.sort((a, b) => (a.score ?? 0) - (b.score ?? 0)); @@ -253,11 +256,11 @@ export const listAllReposTool = tool({ inputSchema: z.object({}), execute: async () => { const reposResponse = await getRepos(); - + if (isServiceError(reposResponse)) { return reposResponse; } - + return reposResponse.map((repo) => repo.repoName); } }); diff --git a/packages/web/src/features/codeNav/api.ts b/packages/web/src/features/codeNav/api.ts index ace85293..ab2cd7a9 100644 --- a/packages/web/src/features/codeNav/api.ts +++ b/packages/web/src/features/codeNav/api.ts @@ -7,6 +7,7 @@ import { isServiceError } from "@/lib/utils"; import { withOptionalAuthV2 } from "@/withAuthV2"; import { SearchResponse } from "../search/types"; import { FindRelatedSymbolsRequest, FindRelatedSymbolsResponse } from "./types"; +import { QueryIR } from '../search/ir'; // The maximum number of matches to return from the search API. 
const MAX_REFERENCE_COUNT = 1000; @@ -19,14 +20,37 @@ export const findSearchBasedSymbolReferences = async (props: FindRelatedSymbolsR revisionName = "HEAD", } = props; - const query = `\\b${symbolName}\\b rev:${revisionName} ${getExpandedLanguageFilter(language)}`; + const languageFilter = getExpandedLanguageFilter(language); + + const query: QueryIR = { + and: { + children: [ + { + regexp: { + regexp: `\\b${symbolName}\\b`, + case_sensitive: true, + file_name: false, + content: true, + } + }, + { + branch: { + pattern: revisionName, + exact: true, + } + }, + languageFilter, + ] + } + } const searchResult = await search({ + queryType: 'ir', query, - matches: MAX_REFERENCE_COUNT, - contextLines: 0, - isCaseSensitivityEnabled: true, - isRegexEnabled: true, + options: { + matches: MAX_REFERENCE_COUNT, + contextLines: 0, + } }); if (isServiceError(searchResult)) { @@ -39,27 +63,54 @@ export const findSearchBasedSymbolReferences = async (props: FindRelatedSymbolsR export const findSearchBasedSymbolDefinitions = async (props: FindRelatedSymbolsRequest): Promise => sew(() => withOptionalAuthV2(async () => { - const { - symbolName, - language, - revisionName = "HEAD", - } = props; + const { + symbolName, + language, + revisionName = "HEAD", + } = props; - const query = `sym:\\b${symbolName}\\b rev:${revisionName} ${getExpandedLanguageFilter(language)}`; + const languageFilter = getExpandedLanguageFilter(language); - const searchResult = await search({ - query, + const query: QueryIR = { + and: { + children: [ + { + symbol: { + expr: { + regexp: { + regexp: `\\b${symbolName}\\b`, + case_sensitive: true, + file_name: false, + content: true, + } + }, + } + }, + { + branch: { + pattern: revisionName, + exact: true, + } + }, + languageFilter, + ] + } + } + + const searchResult = await search({ + queryType: 'ir', + query, + options: { matches: MAX_REFERENCE_COUNT, contextLines: 0, - isCaseSensitivityEnabled: true, - isRegexEnabled: true, - }); - - if 
(isServiceError(searchResult)) { - return searchResult; } + }); - return parseRelatedSymbolsSearchResponse(searchResult); + if (isServiceError(searchResult)) { + return searchResult; + } + + return parseRelatedSymbolsSearchResponse(searchResult); })); const parseRelatedSymbolsSearchResponse = (searchResult: SearchResponse): FindRelatedSymbolsResponse => { @@ -89,14 +140,43 @@ const parseRelatedSymbolsSearchResponse = (searchResult: SearchResponse): FindRe } // Expands the language filter to include all variants of the language. -const getExpandedLanguageFilter = (language: string) => { +const getExpandedLanguageFilter = (language: string): QueryIR => { switch (language) { case "TypeScript": case "JavaScript": case "JSX": case "TSX": - return `(lang:TypeScript or lang:JavaScript or lang:JSX or lang:TSX)` + return { + or: { + children: [ + { + language: { + language: "TypeScript", + } + }, + { + language: { + language: "JavaScript", + } + }, + { + language: { + language: "JSX", + } + }, + { + language: { + language: "TSX", + } + }, + ] + }, + } default: - return `lang:${language}` + return { + language: { + language: language, + }, + } } } \ No newline at end of file diff --git a/packages/web/src/features/search/fileSourceApi.ts b/packages/web/src/features/search/fileSourceApi.ts index 1cb887ec..2c9b153f 100644 --- a/packages/web/src/features/search/fileSourceApi.ts +++ b/packages/web/src/features/search/fileSourceApi.ts @@ -5,23 +5,46 @@ import { isServiceError } from "../../lib/utils"; import { search } from "./searchApi"; import { sew } from "@/actions"; import { withOptionalAuthV2 } from "@/withAuthV2"; +import { QueryIR } from './ir'; // @todo (bkellam) #574 : We should really be using `git show :` to fetch file contents here. // This will allow us to support permalinks to files at a specific revision that may not be indexed // by zoekt. 
export const getFileSource = async ({ fileName, repository, branch }: FileSourceRequest): Promise => sew(() => withOptionalAuthV2(async () => { - let query = `file:${fileName} repo:^${repository}$`; - if (branch) { - query = query.concat(` rev:${branch}`); + const query: QueryIR = { + and: { + children: [ + { + repo: { + regexp: `^${repository}$`, + }, + }, + { + regexp: { + regexp: fileName, + case_sensitive: true, + file_name: true, + content: false + }, + }, + ...(branch ? [{ + branch: { + pattern: branch, + exact: true, + }, + }]: []) + ] + } } const searchResponse = await search({ + queryType: 'ir', query, - matches: 1, - whole: true, - isCaseSensitivityEnabled: true, - isRegexEnabled: true, + options: { + matches: 1, + whole: true, + } }); if (isServiceError(searchResponse)) { diff --git a/packages/web/src/features/search/ir.ts b/packages/web/src/features/search/ir.ts new file mode 100644 index 00000000..75c3b27e --- /dev/null +++ b/packages/web/src/features/search/ir.ts @@ -0,0 +1,185 @@ +import { Q as QueryIR } from '@/proto/zoekt/webserver/v1/Q'; + +export type { + QueryIR, +} + +/** + * Visitor pattern for traversing a QueryIR tree. + * Return false from any method to stop traversal early. 
+ */ +export type QueryVisitor = { + onRawConfig?: (query: QueryIR) => boolean | void; + onRegexp?: (query: QueryIR) => boolean | void; + onSymbol?: (query: QueryIR) => boolean | void; + onLanguage?: (query: QueryIR) => boolean | void; + onConst?: (query: QueryIR) => boolean | void; + onRepo?: (query: QueryIR) => boolean | void; + onRepoRegexp?: (query: QueryIR) => boolean | void; + onBranchesRepos?: (query: QueryIR) => boolean | void; + onRepoIds?: (query: QueryIR) => boolean | void; + onRepoSet?: (query: QueryIR) => boolean | void; + onFileNameSet?: (query: QueryIR) => boolean | void; + onType?: (query: QueryIR) => boolean | void; + onSubstring?: (query: QueryIR) => boolean | void; + onAnd?: (query: QueryIR) => boolean | void; + onOr?: (query: QueryIR) => boolean | void; + onNot?: (query: QueryIR) => boolean | void; + onBranch?: (query: QueryIR) => boolean | void; + onBoost?: (query: QueryIR) => boolean | void; +}; + +/** + * Traverses a QueryIR tree using the visitor pattern. + * @param query The query to traverse + * @param visitor An object with optional callback methods for each query type + * @returns false if traversal was stopped early, true otherwise + */ +export function traverseQueryIR( + query: QueryIR, + visitor: QueryVisitor +): boolean { + if (!query.query) { + return true; + } + + // Call the appropriate visitor method + let shouldContinue: boolean | void = true; + + switch (query.query) { + case 'raw_config': + shouldContinue = visitor.onRawConfig?.(query); + break; + case 'regexp': + shouldContinue = visitor.onRegexp?.(query); + if (shouldContinue !== false && query.regexp) { + // Symbol expressions contain nested queries + if (query.regexp) { + shouldContinue = true; + } + } + break; + case 'symbol': + shouldContinue = visitor.onSymbol?.(query); + if (shouldContinue !== false && query.symbol?.expr) { + shouldContinue = traverseQueryIR(query.symbol.expr, visitor); + } + break; + case 'language': + shouldContinue = visitor.onLanguage?.(query); + 
break; + case 'const': + shouldContinue = visitor.onConst?.(query); + break; + case 'repo': + shouldContinue = visitor.onRepo?.(query); + break; + case 'repo_regexp': + shouldContinue = visitor.onRepoRegexp?.(query); + break; + case 'branches_repos': + shouldContinue = visitor.onBranchesRepos?.(query); + break; + case 'repo_ids': + shouldContinue = visitor.onRepoIds?.(query); + break; + case 'repo_set': + shouldContinue = visitor.onRepoSet?.(query); + break; + case 'file_name_set': + shouldContinue = visitor.onFileNameSet?.(query); + break; + case 'type': + shouldContinue = visitor.onType?.(query); + break; + case 'substring': + shouldContinue = visitor.onSubstring?.(query); + break; + case 'and': + shouldContinue = visitor.onAnd?.(query); + if (shouldContinue !== false && query.and?.children) { + for (const child of query.and.children) { + if (!traverseQueryIR(child, visitor)) { + return false; + } + } + } + break; + case 'or': + shouldContinue = visitor.onOr?.(query); + if (shouldContinue !== false && query.or?.children) { + for (const child of query.or.children) { + if (!traverseQueryIR(child, visitor)) { + return false; + } + } + } + break; + case 'not': + shouldContinue = visitor.onNot?.(query); + if (shouldContinue !== false && query.not?.child) { + shouldContinue = traverseQueryIR(query.not.child, visitor); + } + break; + case 'branch': + shouldContinue = visitor.onBranch?.(query); + break; + case 'boost': + shouldContinue = visitor.onBoost?.(query); + if (shouldContinue !== false && query.boost?.child) { + shouldContinue = traverseQueryIR(query.boost.child, visitor); + } + break; + } + + return shouldContinue !== false; +} + +/** + * Finds a node in the query tree that matches the predicate. 
+ * @param query The query to search + * @param predicate A function that returns true if the node matches + * @returns The first matching query node, or undefined if none found + */ +export function findInQueryIR( + query: QueryIR, + predicate: (query: QueryIR) => boolean +): QueryIR | undefined { + let found: QueryIR | undefined; + + traverseQueryIR(query, { + onRawConfig: (q) => { if (predicate(q)) { found = q; return false; } }, + onRegexp: (q) => { if (predicate(q)) { found = q; return false; } }, + onSymbol: (q) => { if (predicate(q)) { found = q; return false; } }, + onLanguage: (q) => { if (predicate(q)) { found = q; return false; } }, + onConst: (q) => { if (predicate(q)) { found = q; return false; } }, + onRepo: (q) => { if (predicate(q)) { found = q; return false; } }, + onRepoRegexp: (q) => { if (predicate(q)) { found = q; return false; } }, + onBranchesRepos: (q) => { if (predicate(q)) { found = q; return false; } }, + onRepoIds: (q) => { if (predicate(q)) { found = q; return false; } }, + onRepoSet: (q) => { if (predicate(q)) { found = q; return false; } }, + onFileNameSet: (q) => { if (predicate(q)) { found = q; return false; } }, + onType: (q) => { if (predicate(q)) { found = q; return false; } }, + onSubstring: (q) => { if (predicate(q)) { found = q; return false; } }, + onAnd: (q) => { if (predicate(q)) { found = q; return false; } }, + onOr: (q) => { if (predicate(q)) { found = q; return false; } }, + onNot: (q) => { if (predicate(q)) { found = q; return false; } }, + onBranch: (q) => { if (predicate(q)) { found = q; return false; } }, + onBoost: (q) => { if (predicate(q)) { found = q; return false; } }, + }); + + return found; +} + +/** + * Checks if any node in the query tree matches the predicate. 
+ * @param query The query to search + * @param predicate A function that returns true if the node matches + * @returns true if any node matches, false otherwise + */ +export function someInQueryIR( + query: QueryIR, + predicate: (query: QueryIR) => boolean +): boolean { + return findInQueryIR(query, predicate) !== undefined; +} diff --git a/packages/web/src/features/search/query.ts b/packages/web/src/features/search/parser.ts similarity index 85% rename from packages/web/src/features/search/query.ts rename to packages/web/src/features/search/parser.ts index 3505e451..09c88d07 100644 --- a/packages/web/src/features/search/query.ts +++ b/packages/web/src/features/search/parser.ts @@ -1,4 +1,4 @@ -import { Q as ZoektGrpcQuery } from '@/proto/zoekt/webserver/v1/Q'; +import { QueryIR } from './ir'; import { AndExpr, ArchivedExpr, @@ -21,13 +21,15 @@ import { Tree, VisibilityExpr, } from '@sourcebot/query-language'; -import { parser as _lezerQueryParser } from '@sourcebot/query-language'; +import { parser as _parser } from '@sourcebot/query-language'; +import { PrismaClient } from '@sourcebot/db'; +import { SINGLE_TENANT_ORG_ID } from '@/lib/constants'; -const lezerQueryParser = _lezerQueryParser.configure({ +// Configure the parser to throw errors when encountering invalid syntax. +const parser = _parser.configure({ strict: true, }); - type ArchivedValue = 'yes' | 'no' | 'only'; type VisibilityValue = 'public' | 'private' | 'any'; type ForkValue = 'yes' | 'no' | 'only'; @@ -44,11 +46,56 @@ const isForkValue = (value: string): value is ForkValue => { return value === 'yes' || value === 'no' || value === 'only'; } -export const parseQueryIntoLezerTree = (query: string): Tree => { - return lezerQueryParser.parse(query); +/** + * Given a query string, parses it into the query intermediate representation. 
+ */ +export const parseQuerySyntaxIntoIR = async ({ + query, + options, + prisma, +}: { + query: string, + options: { + isCaseSensitivityEnabled?: boolean; + isRegexEnabled?: boolean; + }, + prisma: PrismaClient, +}): Promise => { + // First parse the query into a Lezer tree. + const tree = parser.parse(query); + + // Then transform the tree into the intermediate representation. + return transformTreeToIR({ + tree, + input: query, + isCaseSensitivityEnabled: options.isCaseSensitivityEnabled ?? false, + isRegexEnabled: options.isRegexEnabled ?? false, + onExpandSearchContext: async (contextName: string) => { + const context = await prisma.searchContext.findUnique({ + where: { + name_orgId: { + name: contextName, + orgId: SINGLE_TENANT_ORG_ID, + } + }, + include: { + repos: true, + } + }); + + if (!context) { + throw new Error(`Search context "${contextName}" not found`); + } + + return context.repos.map((repo) => repo.name); + }, + }); } -export const transformLezerTreeToZoektGrpcQuery = async ({ +/** + * Given a Lezer tree, transforms it into the query intermediate representation. 
+ */ +const transformTreeToIR = async ({ tree, input, isCaseSensitivityEnabled, @@ -60,8 +107,8 @@ export const transformLezerTreeToZoektGrpcQuery = async ({ isCaseSensitivityEnabled: boolean; isRegexEnabled: boolean; onExpandSearchContext: (contextName: string) => Promise; -}): Promise => { - const transformNode = async (node: SyntaxNode): Promise => { +}): Promise => { + const transformNode = async (node: SyntaxNode): Promise => { switch (node.type.id) { case Program: { // Program wraps the actual query - transform its child @@ -140,7 +187,7 @@ export const transformLezerTreeToZoektGrpcQuery = async ({ } } - const transformPrefixExpr = async (node: SyntaxNode): Promise => { + const transformPrefixExpr = async (node: SyntaxNode): Promise => { // Find which specific prefix type this is const prefixNode = node.firstChild; if (!prefixNode) { diff --git a/packages/web/src/features/search/searchApi.ts b/packages/web/src/features/search/searchApi.ts index e0b48955..01f0fb5d 100644 --- a/packages/web/src/features/search/searchApi.ts +++ b/packages/web/src/features/search/searchApi.ts @@ -1,566 +1,74 @@ import { sew } from "@/actions"; -import { SINGLE_TENANT_ORG_ID } from '@/lib/constants'; -import type { ProtoGrpcType } from '@/proto/webserver'; -import { FileMatch__Output as ZoektGrpcFileMatch } from "@/proto/zoekt/webserver/v1/FileMatch"; -import { Range__Output as ZoektGrpcRange } from "@/proto/zoekt/webserver/v1/Range"; -import type { SearchRequest as ZoektGrpcSearchRequest } from '@/proto/zoekt/webserver/v1/SearchRequest'; -import { SearchResponse__Output as ZoektGrpcSearchResponse } from "@/proto/zoekt/webserver/v1/SearchResponse"; -import { StreamSearchRequest as ZoektGrpcStreamSearchRequest } from "@/proto/zoekt/webserver/v1/StreamSearchRequest"; -import { StreamSearchResponse__Output as ZoektGrpcStreamSearchResponse } from "@/proto/zoekt/webserver/v1/StreamSearchResponse"; -import { WebserverServiceClient } from '@/proto/zoekt/webserver/v1/WebserverService'; 
-import { withOptionalAuthV2 } from "@/withAuthV2"; -import * as grpc from '@grpc/grpc-js'; -import * as protoLoader from '@grpc/proto-loader'; -import * as Sentry from '@sentry/nextjs'; -import { PrismaClient, Repo, UserWithAccounts } from "@sourcebot/db"; -import { createLogger, env, hasEntitlement } from "@sourcebot/shared"; -import path from 'path'; -import { parseQueryIntoLezerTree, transformLezerTreeToZoektGrpcQuery } from './query'; -import { RepositoryInfo, SearchRequest, SearchResponse, SearchResultFile, SearchStats, SourceRange, StreamedSearchResponse } from "./types"; -import { FlushReason as ZoektFlushReason } from "@/proto/zoekt/webserver/v1/FlushReason"; -import { RevisionExpr } from "@sourcebot/query-language"; -import { getCodeHostBrowseFileAtBranchUrl } from "@/lib/utils"; import { getRepoPermissionFilterForUser } from "@/prisma"; +import { withOptionalAuthV2 } from "@/withAuthV2"; +import { PrismaClient, UserWithAccounts } from "@sourcebot/db"; +import { createLogger, env, hasEntitlement } from "@sourcebot/shared"; +import { QueryIR } from './ir'; +import { parseQuerySyntaxIntoIR } from './parser'; +import { SearchOptions } from "./types"; +import { createZoektSearchRequest, zoektSearch, zoektStreamSearch } from './zoektSearcher'; const logger = createLogger("searchApi"); -export const search = (searchRequest: SearchRequest) => sew(() => +type QueryStringSearchRequest = { + queryType: 'string'; + query: string; + options: SearchOptions; +} + +type QueryIRSearchRequest = { + queryType: 'ir'; + query: QueryIR; + // Omit options that are specific to query syntax parsing. 
+ options: Omit; +} + +type SearchRequest = QueryStringSearchRequest | QueryIRSearchRequest; + +export const search = (request: SearchRequest) => sew(() => withOptionalAuthV2(async ({ prisma, user }) => { const repoSearchScope = await getAccessibleRepoNamesForUser({ user, prisma }); - const zoektSearchRequest = await createZoektSearchRequest({ - searchRequest, + // If needed, parse the query syntax into the query intermediate representation. + const query = request.queryType === 'string' ? await parseQuerySyntaxIntoIR({ + query: request.query, + options: request.options, prisma, + }) : request.query; + + const zoektSearchRequest = await createZoektSearchRequest({ + query, + options: request.options, repoSearchScope, }); - logger.debug(`zoektSearchRequest:\n${JSON.stringify(zoektSearchRequest, null, 2)}`); return zoektSearch(zoektSearchRequest, prisma); })); -export const streamSearch = (searchRequest: SearchRequest) => sew(() => +export const streamSearch = (request: SearchRequest) => sew(() => withOptionalAuthV2(async ({ prisma, user }) => { const repoSearchScope = await getAccessibleRepoNamesForUser({ user, prisma }); - const zoektSearchRequest = await createZoektSearchRequest({ - searchRequest, + // If needed, parse the query syntax into the query intermediate representation. + const query = request.queryType === 'string' ? 
await parseQuerySyntaxIntoIR({ + query: request.query, + options: request.options, prisma, + }) : request.query; + + const zoektSearchRequest = await createZoektSearchRequest({ + query, + options: request.options, repoSearchScope, }); - console.log(`zoektStreamSearchRequest:\n${JSON.stringify(zoektSearchRequest, null, 2)}`); + logger.debug(`zoektStreamSearchRequest:\n${JSON.stringify(zoektSearchRequest, null, 2)}`); return zoektStreamSearch(zoektSearchRequest, prisma); })); - -const zoektSearch = async (searchRequest: ZoektGrpcSearchRequest, prisma: PrismaClient): Promise => { - const client = createGrpcClient(); - const metadata = new grpc.Metadata(); - - return new Promise((resolve, reject) => { - client.Search(searchRequest, metadata, async (error, response) => { - if (error || !response) { - reject(error || new Error('No response received')); - return; - } - - const reposMapCache = await createReposMapForChunk(response, new Map(), prisma); - const { stats, files, repositoryInfo } = await transformZoektSearchResponse(response, reposMapCache); - - resolve({ - stats, - files, - repositoryInfo, - isSearchExhaustive: stats.actualMatchCount <= stats.totalMatchCount, - } satisfies SearchResponse); - }); - }); -} - -const zoektStreamSearch = async (searchRequest: ZoektGrpcSearchRequest, prisma: PrismaClient): Promise => { - const client = createGrpcClient(); - let grpcStream: ReturnType | null = null; - let isStreamActive = true; - let pendingChunks = 0; - let accumulatedStats: SearchStats = { - actualMatchCount: 0, - totalMatchCount: 0, - duration: 0, - fileCount: 0, - filesSkipped: 0, - contentBytesLoaded: 0, - indexBytesLoaded: 0, - crashes: 0, - shardFilesConsidered: 0, - filesConsidered: 0, - filesLoaded: 0, - shardsScanned: 0, - shardsSkipped: 0, - shardsSkippedFilter: 0, - ngramMatches: 0, - ngramLookups: 0, - wait: 0, - matchTreeConstruction: 0, - matchTreeSearch: 0, - regexpsConsidered: 0, - flushReason: ZoektFlushReason.FLUSH_REASON_UNKNOWN_UNSPECIFIED, - }; 
- - return new ReadableStream({ - async start(controller) { - const tryCloseController = () => { - if (!isStreamActive && pendingChunks === 0) { - const finalResponse: StreamedSearchResponse = { - type: 'final', - accumulatedStats, - isSearchExhaustive: accumulatedStats.totalMatchCount <= accumulatedStats.actualMatchCount, - } - - controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify(finalResponse)}\n\n`)); - controller.enqueue(new TextEncoder().encode('data: [DONE]\n\n')); - controller.close(); - client.close(); - logger.debug('SSE stream closed'); - } - }; - - try { - const metadata = new grpc.Metadata(); - - const streamRequest: ZoektGrpcStreamSearchRequest = { - request: searchRequest, - }; - - grpcStream = client.StreamSearch(streamRequest, metadata); - - // `_reposMapCache` is used to cache repository metadata across all chunks. - // This reduces the number of database queries required to transform file matches. - const _reposMapCache = new Map(); - - // Handle incoming data chunks - grpcStream.on('data', async (chunk: ZoektGrpcStreamSearchResponse) => { - if (!isStreamActive) { - logger.debug('SSE stream closed, skipping chunk'); - return; - } - - // Track that we're processing a chunk - pendingChunks++; - - // grpcStream.on doesn't actually await on our handler, so we need to - // explicitly pause the stream here to prevent the stream from completing - // prior to our asynchronous work being completed. 
- grpcStream?.pause(); - - try { - if (!chunk.response_chunk) { - logger.warn('No response chunk received'); - return; - } - - const reposMapCache = await createReposMapForChunk(chunk.response_chunk, _reposMapCache, prisma); - const { stats, files, repositoryInfo } = await transformZoektSearchResponse(chunk.response_chunk, reposMapCache); - - accumulatedStats = accumulateStats(accumulatedStats, stats); - - const response: StreamedSearchResponse = { - type: 'chunk', - files, - repositoryInfo, - stats - } - - const sseData = `data: ${JSON.stringify(response)}\n\n`; - controller.enqueue(new TextEncoder().encode(sseData)); - } catch (error) { - console.error('Error encoding chunk:', error); - } finally { - pendingChunks--; - grpcStream?.resume(); - - // @note: we were hitting "Controller is already closed" errors when calling - // `controller.enqueue` above for the last chunk. The reasoning was the event - // handler for 'end' was being invoked prior to the completion of the last chunk, - // resulting in the controller being closed prematurely. The workaround was to - // keep track of the number of pending chunks and only close the controller - // when there are no more chunks to process. We need to explicitly call - // `tryCloseController` since there _seems_ to be no ordering guarantees between - // the 'end' event handler and this callback. 
- tryCloseController(); - } - }); - - // Handle stream completion - grpcStream.on('end', () => { - if (!isStreamActive) { - return; - } - isStreamActive = false; - tryCloseController(); - }); - - // Handle errors - grpcStream.on('error', (error: grpc.ServiceError) => { - logger.error('gRPC stream error:', error); - Sentry.captureException(error); - - if (!isStreamActive) { - return; - } - isStreamActive = false; - - // Send error as SSE event - const errorData = `data: ${JSON.stringify({ - error: { - code: error.code, - message: error.details || error.message, - } - })}\n\n`; - controller.enqueue(new TextEncoder().encode(errorData)); - - controller.close(); - client.close(); - }); - } catch (error) { - logger.error('Stream initialization error:', error); - - const errorMessage = error instanceof Error ? error.message : 'Unknown error'; - const errorData = `data: ${JSON.stringify({ - error: { message: errorMessage } - })}\n\n`; - controller.enqueue(new TextEncoder().encode(errorData)); - - controller.close(); - client.close(); - } - }, - cancel() { - logger.warn('SSE stream cancelled by client'); - isStreamActive = false; - - // Cancel the gRPC stream to stop receiving data - if (grpcStream) { - grpcStream.cancel(); - } - - client.close(); - } - }); -} - -// Creates a mapping between all repository ids in a given response -// chunk. The mapping allows us to efficiently lookup repository metadata. -const createReposMapForChunk = async (chunk: ZoektGrpcSearchResponse, reposMapCache: Map, prisma: PrismaClient): Promise> => { - const reposMap = new Map(); - await Promise.all(chunk.files.map(async (file) => { - const id = getRepoIdForFile(file); - - const repo = await (async () => { - // If it's in the cache, return the cached value. - if (reposMapCache.has(id)) { - return reposMapCache.get(id); - } - - // Otherwise, query the database for the record. - const repo = typeof id === 'number' ? 
- await prisma.repo.findUnique({ - where: { - id: id, - }, - }) : - await prisma.repo.findFirst({ - where: { - name: id, - }, - }); - - // If a repository is found, cache it for future lookups. - if (repo) { - reposMapCache.set(id, repo); - } - - return repo; - })(); - - // Only add the repository to the map if it was found. - if (repo) { - reposMap.set(id, repo); - } - })); - - return reposMap; -} - -const transformZoektSearchResponse = async (response: ZoektGrpcSearchResponse, reposMapCache: Map): Promise<{ - stats: SearchStats, - files: SearchResultFile[], - repositoryInfo: RepositoryInfo[], -}> => { - const files = response.files.map((file) => { - const fileNameChunks = file.chunk_matches.filter((chunk) => chunk.file_name); - const repoId = getRepoIdForFile(file); - const repo = reposMapCache.get(repoId); - - // This should never happen. - if (!repo) { - throw new Error(`Repository not found for file: ${file.file_name}`); - } - - // @todo: address "file_name might not be a valid UTF-8 string" warning. - const fileName = file.file_name.toString('utf-8'); - - const convertRange = (range: ZoektGrpcRange): SourceRange => ({ - start: { - byteOffset: range.start?.byte_offset ?? 0, - column: range.start?.column ?? 1, - lineNumber: range.start?.line_number ?? 1, - }, - end: { - byteOffset: range.end?.byte_offset ?? 0, - column: range.end?.column ?? 1, - lineNumber: range.end?.line_number ?? 1, - } - }) - - return { - fileName: { - text: fileName, - matchRanges: fileNameChunks.length === 1 ? fileNameChunks[0].ranges.map(convertRange) : [], - }, - repository: repo.name, - repositoryId: repo.id, - language: file.language, - webUrl: getCodeHostBrowseFileAtBranchUrl({ - webUrl: repo.webUrl, - codeHostType: repo.external_codeHostType, - // If a file has multiple branches, default to the first one. - branchName: file.branches?.[0] ?? 'HEAD', - filePath: fileName, - }), - chunks: file.chunk_matches - .filter((chunk) => !chunk.file_name) // filter out filename chunks. 
- .map((chunk) => { - return { - content: chunk.content.toString('utf-8'), - matchRanges: chunk.ranges.map(convertRange), - contentStart: chunk.content_start ? { - byteOffset: chunk.content_start.byte_offset, - column: chunk.content_start.column, - lineNumber: chunk.content_start.line_number, - } : { - byteOffset: 0, - column: 1, - lineNumber: 1, - }, - symbols: chunk.symbol_info.map((symbol) => { - return { - symbol: symbol.sym, - kind: symbol.kind, - parent: symbol.parent ? { - symbol: symbol.parent, - kind: symbol.parent_kind, - } : undefined, - } - }) - } - }), - branches: file.branches, - content: file.content ? file.content.toString('utf-8') : undefined, - } - }).filter(file => file !== undefined); - - const actualMatchCount = files.reduce( - (acc, file) => - // Match count is the sum of the number of chunk matches and file name matches. - acc + file.chunks.reduce( - (acc, chunk) => acc + chunk.matchRanges.length, - 0, - ) + file.fileName.matchRanges.length, - 0, - ); - - const stats: SearchStats = { - actualMatchCount, - totalMatchCount: response.stats?.match_count ?? 0, - duration: response.stats?.duration?.nanos ?? 0, - fileCount: response.stats?.file_count ?? 0, - filesSkipped: response.stats?.files_skipped ?? 0, - contentBytesLoaded: response.stats?.content_bytes_loaded ?? 0, - indexBytesLoaded: response.stats?.index_bytes_loaded ?? 0, - crashes: response.stats?.crashes ?? 0, - shardFilesConsidered: response.stats?.shard_files_considered ?? 0, - filesConsidered: response.stats?.files_considered ?? 0, - filesLoaded: response.stats?.files_loaded ?? 0, - shardsScanned: response.stats?.shards_scanned ?? 0, - shardsSkipped: response.stats?.shards_skipped ?? 0, - shardsSkippedFilter: response.stats?.shards_skipped_filter ?? 0, - ngramMatches: response.stats?.ngram_matches ?? 0, - ngramLookups: response.stats?.ngram_lookups ?? 0, - wait: response.stats?.wait?.nanos ?? 0, - matchTreeConstruction: response.stats?.match_tree_construction?.nanos ?? 
0, - matchTreeSearch: response.stats?.match_tree_search?.nanos ?? 0, - regexpsConsidered: response.stats?.regexps_considered ?? 0, - flushReason: response.stats?.flush_reason?.toString() ?? ZoektFlushReason.FLUSH_REASON_UNKNOWN_UNSPECIFIED, - } - - return { - files, - repositoryInfo: Array.from(reposMapCache.values()).map((repo) => ({ - id: repo.id, - codeHostType: repo.external_codeHostType, - name: repo.name, - displayName: repo.displayName ?? undefined, - webUrl: repo.webUrl ?? undefined, - })), - stats, - } -} - -// @note (2025-05-12): in zoekt, repositories are identified by the `RepositoryID` field -// which corresponds to the `id` in the Repo table. In order to efficiently fetch repository -// metadata when transforming (potentially thousands) of file matches, we aggregate a unique -// set of repository ids* and map them to their corresponding Repo record. -// -// *Q: Why is `RepositoryID` optional? And why are we falling back to `Repository`? -// A: Prior to this change, the repository id was not plumbed into zoekt, so RepositoryID was -// always undefined. To make this a non-breaking change, we fallback to using the repository's name -// (`Repository`) as the identifier in these cases. This is not guaranteed to be unique, but in -// practice it is since the repository name includes the host and path (e.g., 'github.com/org/repo', -// 'gitea.com/org/repo', etc.). -// -// Note: When a repository is re-indexed (every hour) this ID will be populated. -// @see: https://github.com/sourcebot-dev/zoekt/pull/6 -const getRepoIdForFile = (file: ZoektGrpcFileMatch): string | number => { - return file.repository_id ?? file.repository; -} - -const createZoektSearchRequest = async ({ - searchRequest, - prisma, - repoSearchScope, -}: { - searchRequest: SearchRequest; - prisma: PrismaClient; - // Allows the caller to scope the search to a specific set of repositories. 
- repoSearchScope?: string[]; -}) => { - const tree = parseQueryIntoLezerTree(searchRequest.query); - const zoektQuery = await transformLezerTreeToZoektGrpcQuery({ - tree, - input: searchRequest.query, - isCaseSensitivityEnabled: searchRequest.isCaseSensitivityEnabled ?? false, - isRegexEnabled: searchRequest.isRegexEnabled ?? false, - onExpandSearchContext: async (contextName: string) => { - const context = await prisma.searchContext.findUnique({ - where: { - name_orgId: { - name: contextName, - orgId: SINGLE_TENANT_ORG_ID, - } - }, - include: { - repos: true, - } - }); - - if (!context) { - throw new Error(`Search context "${contextName}" not found`); - } - - return context.repos.map((repo) => repo.name); - }, - }); - - // Find if there are any `rev:` filters in the query. - let containsRevExpression = false; - tree.iterate({ - enter: (node) => { - if (node.type.id === RevisionExpr) { - containsRevExpression = true; - // false to stop the iteration. - return false; - } - } - }); - - const zoektSearchRequest: ZoektGrpcSearchRequest = { - query: { - and: { - children: [ - zoektQuery, - // If the query does not contain a `rev:` filter, we default to searching `HEAD`. - ...(!containsRevExpression ? [{ - branch: { - pattern: 'HEAD', - exact: true, - } - }] : []), - ...(repoSearchScope ? [{ - repo_set: { - set: repoSearchScope.reduce((acc, repo) => { - acc[repo] = true; - return acc; - }, {} as Record) - } - }] : []), - ] - } - }, - opts: { - chunk_matches: true, - // @note: Zoekt has several different ways to limit a given search. The two that - // we care about are `MaxMatchDisplayCount` and `TotalMaxMatchCount`: - // - `MaxMatchDisplayCount` truncates the number of matches AFTER performing - // a search (specifically, after collating and sorting the results). The number of - // results returned by the API will be less than or equal to this value. - // - // - `TotalMaxMatchCount` truncates the number of matches DURING a search. 
The results - // returned by the API the API can be less than, equal to, or greater than this value. - // Why greater? Because this value is compared _after_ a given shard has finished - // being processed, the number of matches returned by the last shard may have exceeded - // this value. - // - // Let's define two variables: - // - `actualMatchCount` : The number of matches that are returned by the API. This is - // always less than or equal to `MaxMatchDisplayCount`. - // - `totalMatchCount` : The number of matches that zoekt found before it either - // 1) found all matches or 2) hit the `TotalMaxMatchCount` limit. This number is - // not bounded and can be less than, equal to, or greater than both `TotalMaxMatchCount` - // and `MaxMatchDisplayCount`. - // - // - // Our challenge is to determine whether or not the search returned all possible matches/ - // (it was exaustive) or if it was truncated. By setting the `TotalMaxMatchCount` to - // `MaxMatchDisplayCount + 1`, we can determine which of these occurred by comparing - // `totalMatchCount` to `MaxMatchDisplayCount`. - // - // if (totalMatchCount ≤ actualMatchCount): - // Search is EXHAUSTIVE (found all possible matches) - // Proof: totalMatchCount ≤ MaxMatchDisplayCount < TotalMaxMatchCount - // Therefore Zoekt stopped naturally, not due to limit - // - // if (totalMatchCount > actualMatchCount): - // Search is TRUNCATED (more matches exist) - // Proof: totalMatchCount > MaxMatchDisplayCount + 1 = TotalMaxMatchCount - // Therefore Zoekt hit the limit and stopped searching - // - max_match_display_count: searchRequest.matches, - total_max_match_count: searchRequest.matches + 1, - num_context_lines: searchRequest.contextLines ?? 0, - whole: !!searchRequest.whole, - shard_max_match_count: -1, - max_wall_time: { - seconds: 0, - } - }, - }; - - return zoektSearchRequest; -} - /** * Returns a list of repository names that the user has access to. * If permission syncing is disabled, returns undefined. 
@@ -581,65 +89,3 @@ const getAccessibleRepoNamesForUser = async ({ user, prisma }: { user?: UserWith }); return accessibleRepos.map(repo => repo.name); } - -const createGrpcClient = (): WebserverServiceClient => { - // Path to proto files - these should match your monorepo structure - const protoBasePath = path.join(process.cwd(), '../../vendor/zoekt/grpc/protos'); - const protoPath = path.join(protoBasePath, 'zoekt/webserver/v1/webserver.proto'); - - const packageDefinition = protoLoader.loadSync(protoPath, { - keepCase: true, - longs: Number, - enums: String, - defaults: true, - oneofs: true, - includeDirs: [protoBasePath], - }); - - const proto = grpc.loadPackageDefinition(packageDefinition) as unknown as ProtoGrpcType; - - // Extract host and port from ZOEKT_WEBSERVER_URL - const zoektUrl = new URL(env.ZOEKT_WEBSERVER_URL); - const grpcAddress = `${zoektUrl.hostname}:${zoektUrl.port}`; - - return new proto.zoekt.webserver.v1.WebserverService( - grpcAddress, - grpc.credentials.createInsecure(), - { - 'grpc.max_receive_message_length': 500 * 1024 * 1024, // 500MB - 'grpc.max_send_message_length': 500 * 1024 * 1024, // 500MB - } - ); -} - - -const accumulateStats = (a: SearchStats, b: SearchStats): SearchStats => { - return { - actualMatchCount: a.actualMatchCount + b.actualMatchCount, - totalMatchCount: a.totalMatchCount + b.totalMatchCount, - duration: a.duration + b.duration, - fileCount: a.fileCount + b.fileCount, - filesSkipped: a.filesSkipped + b.filesSkipped, - contentBytesLoaded: a.contentBytesLoaded + b.contentBytesLoaded, - indexBytesLoaded: a.indexBytesLoaded + b.indexBytesLoaded, - crashes: a.crashes + b.crashes, - shardFilesConsidered: a.shardFilesConsidered + b.shardFilesConsidered, - filesConsidered: a.filesConsidered + b.filesConsidered, - filesLoaded: a.filesLoaded + b.filesLoaded, - shardsScanned: a.shardsScanned + b.shardsScanned, - shardsSkipped: a.shardsSkipped + b.shardsSkipped, - shardsSkippedFilter: a.shardsSkippedFilter + 
b.shardsSkippedFilter, - ngramMatches: a.ngramMatches + b.ngramMatches, - ngramLookups: a.ngramLookups + b.ngramLookups, - wait: a.wait + b.wait, - matchTreeConstruction: a.matchTreeConstruction + b.matchTreeConstruction, - matchTreeSearch: a.matchTreeSearch + b.matchTreeSearch, - regexpsConsidered: a.regexpsConsidered + b.regexpsConsidered, - // Capture the first non-unknown flush reason. - ...(a.flushReason === ZoektFlushReason.FLUSH_REASON_UNKNOWN_UNSPECIFIED ? { - flushReason: b.flushReason - } : { - flushReason: a.flushReason, - }), - } -} diff --git a/packages/web/src/features/search/types.ts b/packages/web/src/features/search/types.ts index c4b49630..77d91faa 100644 --- a/packages/web/src/features/search/types.ts +++ b/packages/web/src/features/search/types.ts @@ -82,14 +82,19 @@ export const searchFileSchema = z.object({ export type SearchResultFile = z.infer; export type SearchResultChunk = SearchResultFile["chunks"][number]; -export const searchRequestSchema = z.object({ - query: z.string(), // The zoekt query to execute. +export const searchOptionsSchema = z.object({ matches: z.number(), // The number of matches to return. contextLines: z.number().optional(), // The number of context lines to return. whole: z.boolean().optional(), // Whether to return the whole file as part of the response. isRegexEnabled: z.boolean().optional(), // Whether to enable regular expression search. isCaseSensitivityEnabled: z.boolean().optional(), // Whether to enable case sensitivity. }); +export type SearchOptions = z.infer; + +export const searchRequestSchema = z.object({ + query: z.string(), // The zoekt query to execute. 
+ ...searchOptionsSchema.shape, +}); export type SearchRequest = z.infer; export const searchResponseSchema = z.object({ diff --git a/packages/web/src/features/search/zoektSearcher.ts b/packages/web/src/features/search/zoektSearcher.ts new file mode 100644 index 00000000..ec588a73 --- /dev/null +++ b/packages/web/src/features/search/zoektSearcher.ts @@ -0,0 +1,558 @@ +import { getCodeHostBrowseFileAtBranchUrl } from "@/lib/utils"; +import type { ProtoGrpcType } from '@/proto/webserver'; +import { FileMatch__Output as ZoektGrpcFileMatch } from "@/proto/zoekt/webserver/v1/FileMatch"; +import { FlushReason as ZoektGrpcFlushReason } from "@/proto/zoekt/webserver/v1/FlushReason"; +import { Range__Output as ZoektGrpcRange } from "@/proto/zoekt/webserver/v1/Range"; +import type { SearchRequest as ZoektGrpcSearchRequest } from '@/proto/zoekt/webserver/v1/SearchRequest'; +import { SearchResponse__Output as ZoektGrpcSearchResponse } from "@/proto/zoekt/webserver/v1/SearchResponse"; +import { StreamSearchRequest as ZoektGrpcStreamSearchRequest } from "@/proto/zoekt/webserver/v1/StreamSearchRequest"; +import { StreamSearchResponse__Output as ZoektGrpcStreamSearchResponse } from "@/proto/zoekt/webserver/v1/StreamSearchResponse"; +import { WebserverServiceClient } from '@/proto/zoekt/webserver/v1/WebserverService'; +import * as grpc from '@grpc/grpc-js'; +import * as protoLoader from '@grpc/proto-loader'; +import * as Sentry from '@sentry/nextjs'; +import { PrismaClient, Repo } from "@sourcebot/db"; +import { createLogger, env } from "@sourcebot/shared"; +import path from 'path'; +import { QueryIR, someInQueryIR } from './ir'; +import { RepositoryInfo, SearchResponse, SearchResultFile, SearchStats, SourceRange, StreamedSearchResponse } from "./types"; + +const logger = createLogger("zoekt-searcher"); + +/** + * Creates a ZoektGrpcSearchRequest given a query IR. 
+ */
+export const createZoektSearchRequest = async ({
+    query,
+    options,
+    repoSearchScope,
+}: {
+    query: QueryIR;
+    options: {
+        matches: number,
+        contextLines?: number,
+        whole?: boolean,
+    };
+    // Allows the caller to scope the search to a specific set of repositories.
+    repoSearchScope?: string[];
+}) => {
+    // Find if there are any `rev:` filters in the query.
+    const containsRevExpression = someInQueryIR(query, (q) => q.query === 'branch');
+
+    const zoektSearchRequest: ZoektGrpcSearchRequest = {
+        query: {
+            and: {
+                children: [
+                    query,
+                    // If the query does not contain a `rev:` filter, we default to searching `HEAD`.
+                    ...(!containsRevExpression ? [{
+                        branch: {
+                            pattern: 'HEAD',
+                            exact: true,
+                        }
+                    }] : []),
+                    ...(repoSearchScope ? [{
+                        repo_set: {
+                            set: repoSearchScope.reduce((acc, repo) => {
+                                acc[repo] = true;
+                                return acc;
+                            }, {} as Record<string, boolean>)
+                        }
+                    }] : []),
+                ]
+            }
+        },
+        opts: {
+            chunk_matches: true,
+            // @note: Zoekt has several different ways to limit a given search. The two that
+            // we care about are `MaxMatchDisplayCount` and `TotalMaxMatchCount`:
+            //   - `MaxMatchDisplayCount` truncates the number of matches AFTER performing
+            //     a search (specifically, after collating and sorting the results). The number of
+            //     results returned by the API will be less than or equal to this value.
+            //
+            //   - `TotalMaxMatchCount` truncates the number of matches DURING a search. The results
+            //     returned by the API can be less than, equal to, or greater than this value.
+            //     Why greater? Because this value is compared _after_ a given shard has finished
+            //     being processed, the number of matches returned by the last shard may have exceeded
+            //     this value.
+            //
+            // Let's define two variables:
+            //   - `actualMatchCount` : The number of matches that are returned by the API. This is
+            //     always less than or equal to `MaxMatchDisplayCount`.
+            //   - `totalMatchCount` : The number of matches that zoekt found before it either
+            //     1) found all matches or 2) hit the `TotalMaxMatchCount` limit. This number is
+            //     not bounded and can be less than, equal to, or greater than both `TotalMaxMatchCount`
+            //     and `MaxMatchDisplayCount`.
+            //
+            //
+            // Our challenge is to determine whether or not the search returned all possible matches
+            // (it was exhaustive) or if it was truncated. By setting the `TotalMaxMatchCount` to
+            // `MaxMatchDisplayCount + 1`, we can determine which of these occurred by comparing
+            // `totalMatchCount` to `MaxMatchDisplayCount`.
+            //
+            // if (totalMatchCount ≤ actualMatchCount):
+            //     Search is EXHAUSTIVE (found all possible matches)
+            //     Proof: totalMatchCount ≤ MaxMatchDisplayCount < TotalMaxMatchCount
+            //     Therefore Zoekt stopped naturally, not due to limit
+            //
+            // if (totalMatchCount > actualMatchCount):
+            //     Search is TRUNCATED (more matches exist)
+            //     Proof: totalMatchCount > MaxMatchDisplayCount + 1 = TotalMaxMatchCount
+            //     Therefore Zoekt hit the limit and stopped searching
+            //
+            max_match_display_count: options.matches,
+            total_max_match_count: options.matches + 1,
+            num_context_lines: options.contextLines ??
0,
+            whole: !!options.whole,
+            shard_max_match_count: -1,
+            max_wall_time: {
+                seconds: 0,
+            }
+        },
+    };
+
+    return zoektSearchRequest;
+}
+
+export const zoektSearch = async (searchRequest: ZoektGrpcSearchRequest, prisma: PrismaClient): Promise<SearchResponse> => {
+    const client = createGrpcClient();
+    const metadata = new grpc.Metadata();
+
+    return new Promise((resolve, reject) => {
+        client.Search(searchRequest, metadata, async (error, response) => {
+            if (error || !response) {
+                reject(error || new Error('No response received'));
+                return;
+            }
+
+            const reposMapCache = await createReposMapForChunk(response, new Map(), prisma);
+            const { stats, files, repositoryInfo } = await transformZoektSearchResponse(response, reposMapCache);
+
+            resolve({
+                stats,
+                files,
+                repositoryInfo,
+                // Exhaustive iff zoekt found no more matches than were returned
+                // (see the `TotalMaxMatchCount` note above); mirrors the streaming path.
+                isSearchExhaustive: stats.totalMatchCount <= stats.actualMatchCount,
+            } satisfies SearchResponse);
+        });
+    });
+}
+
+export const zoektStreamSearch = async (searchRequest: ZoektGrpcSearchRequest, prisma: PrismaClient): Promise<ReadableStream> => {
+    const client = createGrpcClient();
+    let grpcStream: ReturnType<typeof client.StreamSearch> | null = null;
+    let isStreamActive = true;
+    let pendingChunks = 0;
+    let accumulatedStats: SearchStats = {
+        actualMatchCount: 0,
+        totalMatchCount: 0,
+        duration: 0,
+        fileCount: 0,
+        filesSkipped: 0,
+        contentBytesLoaded: 0,
+        indexBytesLoaded: 0,
+        crashes: 0,
+        shardFilesConsidered: 0,
+        filesConsidered: 0,
+        filesLoaded: 0,
+        shardsScanned: 0,
+        shardsSkipped: 0,
+        shardsSkippedFilter: 0,
+        ngramMatches: 0,
+        ngramLookups: 0,
+        wait: 0,
+        matchTreeConstruction: 0,
+        matchTreeSearch: 0,
+        regexpsConsidered: 0,
+        flushReason: ZoektGrpcFlushReason.FLUSH_REASON_UNKNOWN_UNSPECIFIED,
+    };
+
+    return new ReadableStream({
+        async start(controller) {
+            const tryCloseController = () => {
+                if (!isStreamActive && pendingChunks === 0) {
+                    const finalResponse: StreamedSearchResponse = {
+                        type: 'final',
+                        accumulatedStats,
+                        isSearchExhaustive: accumulatedStats.totalMatchCount <=
accumulatedStats.actualMatchCount, + } + + controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify(finalResponse)}\n\n`)); + controller.enqueue(new TextEncoder().encode('data: [DONE]\n\n')); + controller.close(); + client.close(); + logger.debug('SSE stream closed'); + } + }; + + try { + const metadata = new grpc.Metadata(); + + const streamRequest: ZoektGrpcStreamSearchRequest = { + request: searchRequest, + }; + + grpcStream = client.StreamSearch(streamRequest, metadata); + + // `_reposMapCache` is used to cache repository metadata across all chunks. + // This reduces the number of database queries required to transform file matches. + const _reposMapCache = new Map(); + + // Handle incoming data chunks + grpcStream.on('data', async (chunk: ZoektGrpcStreamSearchResponse) => { + if (!isStreamActive) { + logger.debug('SSE stream closed, skipping chunk'); + return; + } + + // Track that we're processing a chunk + pendingChunks++; + + // grpcStream.on doesn't actually await on our handler, so we need to + // explicitly pause the stream here to prevent the stream from completing + // prior to our asynchronous work being completed. 
+ grpcStream?.pause(); + + try { + if (!chunk.response_chunk) { + logger.warn('No response chunk received'); + return; + } + + const reposMapCache = await createReposMapForChunk(chunk.response_chunk, _reposMapCache, prisma); + const { stats, files, repositoryInfo } = await transformZoektSearchResponse(chunk.response_chunk, reposMapCache); + + accumulatedStats = accumulateStats(accumulatedStats, stats); + + const response: StreamedSearchResponse = { + type: 'chunk', + files, + repositoryInfo, + stats + } + + const sseData = `data: ${JSON.stringify(response)}\n\n`; + controller.enqueue(new TextEncoder().encode(sseData)); + } catch (error) { + console.error('Error encoding chunk:', error); + } finally { + pendingChunks--; + grpcStream?.resume(); + + // @note: we were hitting "Controller is already closed" errors when calling + // `controller.enqueue` above for the last chunk. The reasoning was the event + // handler for 'end' was being invoked prior to the completion of the last chunk, + // resulting in the controller being closed prematurely. The workaround was to + // keep track of the number of pending chunks and only close the controller + // when there are no more chunks to process. We need to explicitly call + // `tryCloseController` since there _seems_ to be no ordering guarantees between + // the 'end' event handler and this callback. 
+ tryCloseController(); + } + }); + + // Handle stream completion + grpcStream.on('end', () => { + if (!isStreamActive) { + return; + } + isStreamActive = false; + tryCloseController(); + }); + + // Handle errors + grpcStream.on('error', (error: grpc.ServiceError) => { + logger.error('gRPC stream error:', error); + Sentry.captureException(error); + + if (!isStreamActive) { + return; + } + isStreamActive = false; + + // Send error as SSE event + const errorData = `data: ${JSON.stringify({ + error: { + code: error.code, + message: error.details || error.message, + } + })}\n\n`; + controller.enqueue(new TextEncoder().encode(errorData)); + + controller.close(); + client.close(); + }); + } catch (error) { + logger.error('Stream initialization error:', error); + + const errorMessage = error instanceof Error ? error.message : 'Unknown error'; + const errorData = `data: ${JSON.stringify({ + error: { message: errorMessage } + })}\n\n`; + controller.enqueue(new TextEncoder().encode(errorData)); + + controller.close(); + client.close(); + } + }, + cancel() { + logger.warn('SSE stream cancelled by client'); + isStreamActive = false; + + // Cancel the gRPC stream to stop receiving data + if (grpcStream) { + grpcStream.cancel(); + } + + client.close(); + } + }); +} + +// Creates a mapping between all repository ids in a given response +// chunk. The mapping allows us to efficiently lookup repository metadata. +const createReposMapForChunk = async (chunk: ZoektGrpcSearchResponse, reposMapCache: Map, prisma: PrismaClient): Promise> => { + const reposMap = new Map(); + await Promise.all(chunk.files.map(async (file) => { + const id = getRepoIdForFile(file); + + const repo = await (async () => { + // If it's in the cache, return the cached value. + if (reposMapCache.has(id)) { + return reposMapCache.get(id); + } + + // Otherwise, query the database for the record. + const repo = typeof id === 'number' ? 
+ await prisma.repo.findUnique({ + where: { + id: id, + }, + }) : + await prisma.repo.findFirst({ + where: { + name: id, + }, + }); + + // If a repository is found, cache it for future lookups. + if (repo) { + reposMapCache.set(id, repo); + } + + return repo; + })(); + + // Only add the repository to the map if it was found. + if (repo) { + reposMap.set(id, repo); + } + })); + + return reposMap; +} + +const transformZoektSearchResponse = async (response: ZoektGrpcSearchResponse, reposMapCache: Map): Promise<{ + stats: SearchStats, + files: SearchResultFile[], + repositoryInfo: RepositoryInfo[], +}> => { + const files = response.files.map((file) => { + const fileNameChunks = file.chunk_matches.filter((chunk) => chunk.file_name); + const repoId = getRepoIdForFile(file); + const repo = reposMapCache.get(repoId); + + // This should never happen. + if (!repo) { + throw new Error(`Repository not found for file: ${file.file_name}`); + } + + // @todo: address "file_name might not be a valid UTF-8 string" warning. + const fileName = file.file_name.toString('utf-8'); + + const convertRange = (range: ZoektGrpcRange): SourceRange => ({ + start: { + byteOffset: range.start?.byte_offset ?? 0, + column: range.start?.column ?? 1, + lineNumber: range.start?.line_number ?? 1, + }, + end: { + byteOffset: range.end?.byte_offset ?? 0, + column: range.end?.column ?? 1, + lineNumber: range.end?.line_number ?? 1, + } + }) + + return { + fileName: { + text: fileName, + matchRanges: fileNameChunks.length === 1 ? fileNameChunks[0].ranges.map(convertRange) : [], + }, + repository: repo.name, + repositoryId: repo.id, + language: file.language, + webUrl: getCodeHostBrowseFileAtBranchUrl({ + webUrl: repo.webUrl, + codeHostType: repo.external_codeHostType, + // If a file has multiple branches, default to the first one. + branchName: file.branches?.[0] ?? 'HEAD', + filePath: fileName, + }), + chunks: file.chunk_matches + .filter((chunk) => !chunk.file_name) // filter out filename chunks. 
+ .map((chunk) => { + return { + content: chunk.content.toString('utf-8'), + matchRanges: chunk.ranges.map(convertRange), + contentStart: chunk.content_start ? { + byteOffset: chunk.content_start.byte_offset, + column: chunk.content_start.column, + lineNumber: chunk.content_start.line_number, + } : { + byteOffset: 0, + column: 1, + lineNumber: 1, + }, + symbols: chunk.symbol_info.map((symbol) => { + return { + symbol: symbol.sym, + kind: symbol.kind, + parent: symbol.parent ? { + symbol: symbol.parent, + kind: symbol.parent_kind, + } : undefined, + } + }) + } + }), + branches: file.branches, + content: file.content ? file.content.toString('utf-8') : undefined, + } + }).filter(file => file !== undefined); + + const actualMatchCount = files.reduce( + (acc, file) => + // Match count is the sum of the number of chunk matches and file name matches. + acc + file.chunks.reduce( + (acc, chunk) => acc + chunk.matchRanges.length, + 0, + ) + file.fileName.matchRanges.length, + 0, + ); + + const stats: SearchStats = { + actualMatchCount, + totalMatchCount: response.stats?.match_count ?? 0, + duration: response.stats?.duration?.nanos ?? 0, + fileCount: response.stats?.file_count ?? 0, + filesSkipped: response.stats?.files_skipped ?? 0, + contentBytesLoaded: response.stats?.content_bytes_loaded ?? 0, + indexBytesLoaded: response.stats?.index_bytes_loaded ?? 0, + crashes: response.stats?.crashes ?? 0, + shardFilesConsidered: response.stats?.shard_files_considered ?? 0, + filesConsidered: response.stats?.files_considered ?? 0, + filesLoaded: response.stats?.files_loaded ?? 0, + shardsScanned: response.stats?.shards_scanned ?? 0, + shardsSkipped: response.stats?.shards_skipped ?? 0, + shardsSkippedFilter: response.stats?.shards_skipped_filter ?? 0, + ngramMatches: response.stats?.ngram_matches ?? 0, + ngramLookups: response.stats?.ngram_lookups ?? 0, + wait: response.stats?.wait?.nanos ?? 0, + matchTreeConstruction: response.stats?.match_tree_construction?.nanos ?? 
0, + matchTreeSearch: response.stats?.match_tree_search?.nanos ?? 0, + regexpsConsidered: response.stats?.regexps_considered ?? 0, + flushReason: response.stats?.flush_reason?.toString() ?? ZoektGrpcFlushReason.FLUSH_REASON_UNKNOWN_UNSPECIFIED, + } + + return { + files, + repositoryInfo: Array.from(reposMapCache.values()).map((repo) => ({ + id: repo.id, + codeHostType: repo.external_codeHostType, + name: repo.name, + displayName: repo.displayName ?? undefined, + webUrl: repo.webUrl ?? undefined, + })), + stats, + } +} + +// @note (2025-05-12): in zoekt, repositories are identified by the `RepositoryID` field +// which corresponds to the `id` in the Repo table. In order to efficiently fetch repository +// metadata when transforming (potentially thousands) of file matches, we aggregate a unique +// set of repository ids* and map them to their corresponding Repo record. +// +// *Q: Why is `RepositoryID` optional? And why are we falling back to `Repository`? +// A: Prior to this change, the repository id was not plumbed into zoekt, so RepositoryID was +// always undefined. To make this a non-breaking change, we fallback to using the repository's name +// (`Repository`) as the identifier in these cases. This is not guaranteed to be unique, but in +// practice it is since the repository name includes the host and path (e.g., 'github.com/org/repo', +// 'gitea.com/org/repo', etc.). +// +// Note: When a repository is re-indexed (every hour) this ID will be populated. +// @see: https://github.com/sourcebot-dev/zoekt/pull/6 +const getRepoIdForFile = (file: ZoektGrpcFileMatch): string | number => { + return file.repository_id ?? 
file.repository; +} + +const createGrpcClient = (): WebserverServiceClient => { + // Path to proto files - these should match your monorepo structure + const protoBasePath = path.join(process.cwd(), '../../vendor/zoekt/grpc/protos'); + const protoPath = path.join(protoBasePath, 'zoekt/webserver/v1/webserver.proto'); + + const packageDefinition = protoLoader.loadSync(protoPath, { + keepCase: true, + longs: Number, + enums: String, + defaults: true, + oneofs: true, + includeDirs: [protoBasePath], + }); + + const proto = grpc.loadPackageDefinition(packageDefinition) as unknown as ProtoGrpcType; + + // Extract host and port from ZOEKT_WEBSERVER_URL + const zoektUrl = new URL(env.ZOEKT_WEBSERVER_URL); + const grpcAddress = `${zoektUrl.hostname}:${zoektUrl.port}`; + + return new proto.zoekt.webserver.v1.WebserverService( + grpcAddress, + grpc.credentials.createInsecure(), + { + 'grpc.max_receive_message_length': 500 * 1024 * 1024, // 500MB + 'grpc.max_send_message_length': 500 * 1024 * 1024, // 500MB + } + ); +} + + +const accumulateStats = (a: SearchStats, b: SearchStats): SearchStats => { + return { + actualMatchCount: a.actualMatchCount + b.actualMatchCount, + totalMatchCount: a.totalMatchCount + b.totalMatchCount, + duration: a.duration + b.duration, + fileCount: a.fileCount + b.fileCount, + filesSkipped: a.filesSkipped + b.filesSkipped, + contentBytesLoaded: a.contentBytesLoaded + b.contentBytesLoaded, + indexBytesLoaded: a.indexBytesLoaded + b.indexBytesLoaded, + crashes: a.crashes + b.crashes, + shardFilesConsidered: a.shardFilesConsidered + b.shardFilesConsidered, + filesConsidered: a.filesConsidered + b.filesConsidered, + filesLoaded: a.filesLoaded + b.filesLoaded, + shardsScanned: a.shardsScanned + b.shardsScanned, + shardsSkipped: a.shardsSkipped + b.shardsSkipped, + shardsSkippedFilter: a.shardsSkippedFilter + b.shardsSkippedFilter, + ngramMatches: a.ngramMatches + b.ngramMatches, + ngramLookups: a.ngramLookups + b.ngramLookups, + wait: a.wait + b.wait, + 
matchTreeConstruction: a.matchTreeConstruction + b.matchTreeConstruction, + matchTreeSearch: a.matchTreeSearch + b.matchTreeSearch, + regexpsConsidered: a.regexpsConsidered + b.regexpsConsidered, + // Capture the first non-unknown flush reason. + ...(a.flushReason === ZoektGrpcFlushReason.FLUSH_REASON_UNKNOWN_UNSPECIFIED ? { + flushReason: b.flushReason + } : { + flushReason: a.flushReason, + }), + } +} \ No newline at end of file