Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions graphrag/query/indexer_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,18 @@ def read_indexer_reports(

if not dynamic_community_selection:
# perform community level roll up
nodes_df.loc[:, "community"] = nodes_df["community"].fillna(-1)
nodes_df.loc[:, "community"] = nodes_df["community"].astype(int)

nodes_df = nodes_df.groupby(["title"]).agg({"community": "max"}).reset_index()
filtered_community_df = nodes_df["community"].drop_duplicates()

reports_df = reports_df.merge(
filtered_community_df, on="community", how="inner"
)
nodes_df["community"] = nodes_df["community"].fillna(-1).astype(int)
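# Keep the last row seen for each title.
# NOTE: this equals the per-title max community only if nodes_df is ordered by community within each title.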
max_community_df = nodes_df.drop_duplicates(subset=["title"], keep="last")[
["title", "community"]
]
# Keep only the reports whose community is among the communities selected from the nodes above
filtered_community = max_community_df["community"].unique()
reports_df = reports_df[reports_df["community"].isin(filtered_community)]

if config and (
content_embedding_col not in reports_df.columns
or reports_df.loc[:, content_embedding_col].isna().any()
or reports_df[content_embedding_col].isna().any()
):
# TODO: Find a way to retrieve the right embedding model id.
embedding_model_settings = config.get_language_model_config(
Expand All @@ -115,9 +114,17 @@ def read_indexer_reports(
model_type=embedding_model_settings.type,
config=embedding_model_settings,
)
reports_df = embed_community_reports(
reports_df, embedder, embedding_col=content_embedding_col
)
# Only embed missing embeddings for optimization
if content_embedding_col not in reports_df.columns:
reports_df[content_embedding_col] = reports_df["full_content"].apply(
embedder.embed
)
elif reports_df[content_embedding_col].isna().any():
missing_idx = reports_df[content_embedding_col].isna()
# Only embed missing rows
reports_df.loc[missing_idx, content_embedding_col] = reports_df.loc[
missing_idx, "full_content"
].apply(embedder.embed)

return read_community_reports(
df=reports_df,
Expand Down
51 changes: 37 additions & 14 deletions graphrag/query/input/loaders/dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,25 +202,48 @@ def read_community_reports(
) -> list[CommunityReport]:
"""Read community reports from a dataframe using pre-converted records."""
records = _prepare_records(df)
# Bind dict.get, CommunityReport, and the to_* converters to local names (performance: fewer repeated lookups in the comprehensions below)
get = dict.get
CommunityReport_ = CommunityReport # Localize for speed in tight loop
to_optional_float_ = to_optional_float
to_optional_list_ = to_optional_list
to_optional_str_ = to_optional_str
to_str_ = to_str
# Minor: branch on attributes_cols once here, instead of re-testing it for every row inside the comprehension
if attributes_cols:
return [
CommunityReport_(
id=to_str_(row, id_col),
short_id=to_optional_str_(row, short_id_col)
if short_id_col
else str(row["Index"]),
title=to_str_(row, title_col),
community_id=to_str_(row, community_col),
summary=to_str_(row, summary_col),
full_content=to_str_(row, content_col),
rank=to_optional_float_(row, rank_col),
full_content_embedding=to_optional_list_(
row, content_embedding_col, item_type=float
),
attributes={col: get(row, col) for col in attributes_cols},
)
for row in records
]
return [
CommunityReport(
id=to_str(row, id_col),
short_id=to_optional_str(row, short_id_col)
CommunityReport_(
id=to_str_(row, id_col),
short_id=to_optional_str_(row, short_id_col)
if short_id_col
else str(row["Index"]),
title=to_str(row, title_col),
community_id=to_str(row, community_col),
summary=to_str(row, summary_col),
full_content=to_str(row, content_col),
rank=to_optional_float(row, rank_col),
full_content_embedding=to_optional_list(
title=to_str_(row, title_col),
community_id=to_str_(row, community_col),
summary=to_str_(row, summary_col),
full_content=to_str_(row, content_col),
rank=to_optional_float_(row, rank_col),
full_content_embedding=to_optional_list_(
row, content_embedding_col, item_type=float
),
attributes=(
{col: row.get(col) for col in attributes_cols}
if attributes_cols
else None
),
attributes=None,
)
for row in records
]
Expand Down