mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-12-12 02:45:18 +00:00
feat: Qdrant support (#2022)
Signed-off-by: Anush008 <anushshetty90@gmail.com>
This commit is contained in:
parent
9ad8e921b5
commit
6dabc7b1ae
5 changed files with 238 additions and 1 deletions
|
|
@ -27,6 +27,7 @@ Choose from the following Vector Databases:
|
|||
|
||||
1. LanceDB
|
||||
2. Pinecone
|
||||
3. Qdrant
|
||||
|
||||
#### Pinecone Configuration
|
||||
|
||||
|
|
@ -40,6 +41,25 @@ environment = "..."
|
|||
|
||||
These parameters can be obtained by registering to [Pinecone](https://app.pinecone.io/?sessionType=signup/).
|
||||
|
||||
#### Qdrant Configuration
|
||||
|
||||
To use Qdrant with the `similar issue` tool, add these credentials to `.secrets.toml` (or set as environment variables):
|
||||
|
||||
```
|
||||
[qdrant]
|
||||
url = "https://YOUR-QDRANT-URL" # e.g., https://xxxxxxxx-xxxxxxxx.eu-central-1-0.aws.cloud.qdrant.io
|
||||
api_key = "..."
|
||||
```
|
||||
|
||||
Then select Qdrant in `configuration.toml`:
|
||||
|
||||
```
|
||||
[pr_similar_issue]
|
||||
vectordb = "qdrant"
|
||||
```
|
||||
|
||||
You can get a free managed Qdrant instance from [Qdrant Cloud](https://cloud.qdrant.io/).
|
||||
|
||||
## How to use
|
||||
|
||||
- To invoke the 'similar issue' tool from **CLI**, run:
|
||||
|
|
|
|||
|
|
@ -25,6 +25,11 @@ key = "" # Acquire through https://platform.openai.com
|
|||
api_key = "..."
|
||||
environment = "gcp-starter"
|
||||
|
||||
[qdrant]
|
||||
# For Qdrant Cloud or self-hosted Qdrant
|
||||
url = "" # e.g., https://xxxxxxxx-xxxxxxxx.eu-central-1-0.aws.cloud.qdrant.io
|
||||
api_key = ""
|
||||
|
||||
[anthropic]
|
||||
key = "" # Optional, uncomment if you want to use Anthropic. Acquire through https://www.anthropic.com/
|
||||
|
||||
|
|
|
|||
|
|
@ -345,7 +345,7 @@ service_callback = []
|
|||
skip_comments = false
|
||||
force_update_dataset = false
|
||||
max_issues_to_scan = 500
|
||||
vectordb = "pinecone"
|
||||
vectordb = "pinecone" # options: "pinecone", "lancedb", "qdrant"
|
||||
|
||||
[pr_find_similar_component]
|
||||
class_name = ""
|
||||
|
|
@ -363,6 +363,11 @@ number_of_results = 5
|
|||
[lancedb]
|
||||
uri = "./lancedb"
|
||||
|
||||
[qdrant]
|
||||
# fill and place credentials in .secrets.toml
|
||||
# url = "https://YOUR-QDRANT-URL"
|
||||
# api_key = "..."
|
||||
|
||||
[best_practices]
|
||||
content = ""
|
||||
organization_name = ""
|
||||
|
|
|
|||
|
|
@ -174,6 +174,87 @@ class PRSimilarIssue:
|
|||
else:
|
||||
get_logger().info('No new issues to update')
|
||||
|
||||
elif get_settings().pr_similar_issue.vectordb == "qdrant":
|
||||
try:
|
||||
import qdrant_client
|
||||
from qdrant_client.models import (Distance, FieldCondition,
|
||||
Filter, MatchValue,
|
||||
PointStruct, VectorParams)
|
||||
except Exception:
|
||||
raise Exception("Please install qdrant-client to use qdrant as vectordb")
|
||||
|
||||
api_key = None
|
||||
url = None
|
||||
try:
|
||||
api_key = get_settings().qdrant.api_key
|
||||
url = get_settings().qdrant.url
|
||||
except Exception:
|
||||
if not self.cli_mode:
|
||||
repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1])
|
||||
issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
|
||||
issue_main.create_comment("Please set qdrant url and api key in secrets file")
|
||||
raise Exception("Please set qdrant url and api key in secrets file")
|
||||
|
||||
self.qdrant = qdrant_client.QdrantClient(url=url, api_key=api_key)
|
||||
|
||||
run_from_scratch = False
|
||||
ingest = True
|
||||
|
||||
if not self.qdrant.collection_exists(collection_name=self.index_name):
|
||||
run_from_scratch = True
|
||||
ingest = False
|
||||
self.qdrant.create_collection(
|
||||
collection_name=self.index_name,
|
||||
vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
|
||||
)
|
||||
else:
|
||||
if get_settings().pr_similar_issue.force_update_dataset:
|
||||
ingest = True
|
||||
else:
|
||||
response = self.qdrant.count(
|
||||
collection_name=self.index_name,
|
||||
count_filter=Filter(must=[
|
||||
FieldCondition(key="metadata.repo", match=MatchValue(value=repo_name_for_index)),
|
||||
FieldCondition(key="id", match=MatchValue(value=f"example_issue_{repo_name_for_index}")),
|
||||
]),
|
||||
)
|
||||
ingest = True if response.count == 0 else False
|
||||
|
||||
if run_from_scratch or ingest:
|
||||
get_logger().info('Indexing the entire repo...')
|
||||
get_logger().info('Getting issues...')
|
||||
issues = list(repo_obj.get_issues(state='all'))
|
||||
get_logger().info('Done')
|
||||
self._update_qdrant_with_issues(issues, repo_name_for_index, ingest=ingest)
|
||||
else:
|
||||
issues_to_update = []
|
||||
issues_paginated_list = repo_obj.get_issues(state='all')
|
||||
counter = 1
|
||||
for issue in issues_paginated_list:
|
||||
if issue.pull_request:
|
||||
continue
|
||||
issue_str, comments, number = self._process_issue(issue)
|
||||
issue_key = f"issue_{number}"
|
||||
point_id = issue_key + "." + "issue"
|
||||
response = self.qdrant.count(
|
||||
collection_name=self.index_name,
|
||||
count_filter=Filter(must=[
|
||||
FieldCondition(key="id", match=MatchValue(value=point_id)),
|
||||
FieldCondition(key="metadata.repo", match=MatchValue(value=repo_name_for_index)),
|
||||
]),
|
||||
)
|
||||
if response.count == 0:
|
||||
counter += 1
|
||||
issues_to_update.append(issue)
|
||||
else:
|
||||
break
|
||||
|
||||
if issues_to_update:
|
||||
get_logger().info(f'Updating index with {counter} new issues...')
|
||||
self._update_qdrant_with_issues(issues_to_update, repo_name_for_index, ingest=True)
|
||||
else:
|
||||
get_logger().info('No new issues to update')
|
||||
|
||||
|
||||
async def run(self):
|
||||
get_logger().info('Getting issue...')
|
||||
|
|
@ -246,6 +327,36 @@ class PRSimilarIssue:
|
|||
score_list.append(str("{:.2f}".format(1-r['_distance'])))
|
||||
get_logger().info('Done')
|
||||
|
||||
elif get_settings().pr_similar_issue.vectordb == "qdrant":
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
||||
res = self.qdrant.search(
|
||||
collection_name=self.index_name,
|
||||
query_vector=embeds[0],
|
||||
limit=5,
|
||||
query_filter=Filter(must=[FieldCondition(key="metadata.repo", match=MatchValue(value=self.repo_name_for_index))]),
|
||||
with_payload=True,
|
||||
)
|
||||
|
||||
for r in res:
|
||||
rid = r.payload.get("id", "")
|
||||
if 'example_issue_' in rid:
|
||||
continue
|
||||
try:
|
||||
issue_number = int(rid.split('.')[0].split('_')[-1])
|
||||
except Exception:
|
||||
get_logger().debug(f"Failed to parse issue number from {rid}")
|
||||
continue
|
||||
if original_issue_number == issue_number:
|
||||
continue
|
||||
if issue_number not in relevant_issues_number_list:
|
||||
relevant_issues_number_list.append(issue_number)
|
||||
if 'comment' in rid:
|
||||
relevant_comment_number_list.append(int(rid.split('.')[1].split('_')[-1]))
|
||||
else:
|
||||
relevant_comment_number_list.append(-1)
|
||||
score_list.append(str("{:.2f}".format(r.score)))
|
||||
get_logger().info('Done')
|
||||
|
||||
get_logger().info('Publishing response...')
|
||||
similar_issues_str = "### Similar Issues\n___\n\n"
|
||||
|
||||
|
|
@ -458,6 +569,101 @@ class PRSimilarIssue:
|
|||
get_logger().info('Done')
|
||||
|
||||
|
||||
def _update_qdrant_with_issues(self, issues_list, repo_name_for_index, ingest=False):
|
||||
try:
|
||||
import uuid
|
||||
|
||||
import pandas as pd
|
||||
from qdrant_client.models import PointStruct
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
get_logger().info('Processing issues...')
|
||||
corpus = Corpus()
|
||||
example_issue_record = Record(
|
||||
id=f"example_issue_{repo_name_for_index}",
|
||||
text="example_issue",
|
||||
metadata=Metadata(repo=repo_name_for_index)
|
||||
)
|
||||
corpus.append(example_issue_record)
|
||||
|
||||
counter = 0
|
||||
for issue in issues_list:
|
||||
if issue.pull_request:
|
||||
continue
|
||||
|
||||
counter += 1
|
||||
if counter % 100 == 0:
|
||||
get_logger().info(f"Scanned {counter} issues")
|
||||
if counter >= self.max_issues_to_scan:
|
||||
get_logger().info(f"Scanned {self.max_issues_to_scan} issues, stopping")
|
||||
break
|
||||
|
||||
issue_str, comments, number = self._process_issue(issue)
|
||||
issue_key = f"issue_{number}"
|
||||
username = issue.user.login
|
||||
created_at = str(issue.created_at)
|
||||
if len(issue_str) < 8000 or \
|
||||
self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL):
|
||||
issue_record = Record(
|
||||
id=issue_key + "." + "issue",
|
||||
text=issue_str,
|
||||
metadata=Metadata(repo=repo_name_for_index,
|
||||
username=username,
|
||||
created_at=created_at,
|
||||
level=IssueLevel.ISSUE)
|
||||
)
|
||||
corpus.append(issue_record)
|
||||
if comments:
|
||||
for j, comment in enumerate(comments):
|
||||
comment_body = comment.body
|
||||
num_words_comment = len(comment_body.split())
|
||||
if num_words_comment < 10 or not isinstance(comment_body, str):
|
||||
continue
|
||||
|
||||
if len(comment_body) < 8000 or \
|
||||
self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]:
|
||||
comment_record = Record(
|
||||
id=issue_key + ".comment_" + str(j + 1),
|
||||
text=comment_body,
|
||||
metadata=Metadata(repo=repo_name_for_index,
|
||||
username=username,
|
||||
created_at=created_at,
|
||||
level=IssueLevel.COMMENT)
|
||||
)
|
||||
corpus.append(comment_record)
|
||||
|
||||
df = pd.DataFrame(corpus.dict()["documents"])
|
||||
get_logger().info('Done')
|
||||
|
||||
get_logger().info('Embedding...')
|
||||
openai.api_key = get_settings().openai.key
|
||||
list_to_encode = list(df["text"].values)
|
||||
try:
|
||||
res = openai.Embedding.create(input=list_to_encode, engine=MODEL)
|
||||
embeds = [record['embedding'] for record in res['data']]
|
||||
except Exception:
|
||||
embeds = []
|
||||
get_logger().error('Failed to embed entire list, embedding one by one...')
|
||||
for i, text in enumerate(list_to_encode):
|
||||
try:
|
||||
res = openai.Embedding.create(input=[text], engine=MODEL)
|
||||
embeds.append(res['data'][0]['embedding'])
|
||||
except Exception:
|
||||
embeds.append([0] * 1536)
|
||||
df["vector"] = embeds
|
||||
get_logger().info('Done')
|
||||
|
||||
get_logger().info('Upserting into Qdrant...')
|
||||
points = []
|
||||
for row in df.to_dict(orient="records"):
|
||||
points.append(
|
||||
PointStruct(id=uuid.uuid5(uuid.NAMESPACE_DNS, row["id"]).hex, vector=row["vector"], payload={"id": row["id"], "text": row["text"], "metadata": row["metadata"]})
|
||||
)
|
||||
self.qdrant.upsert(collection_name=self.index_name, points=points)
|
||||
get_logger().info('Done')
|
||||
|
||||
|
||||
class IssueLevel(str, Enum):
|
||||
ISSUE = "issue"
|
||||
COMMENT = "comment"
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ giteapy==1.0.8
|
|||
# pinecone-client
|
||||
# pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main
|
||||
# lancedb==0.5.1
|
||||
# qdrant-client==1.15.1
|
||||
# uncomment this to support language LangChainOpenAIHandler
|
||||
# langchain==0.2.0
|
||||
# langchain-core==0.2.28
|
||||
|
|
|
|||
Loading…
Reference in a new issue