diff --git a/graphrag/index/operations/layout_graph/zero.py b/graphrag/index/operations/layout_graph/zero.py index 934df0030f..004abbedd1 100644 --- a/graphrag/index/operations/layout_graph/zero.py +++ b/graphrag/index/operations/layout_graph/zero.py @@ -67,30 +67,37 @@ def get_zero_positions( three_d: bool | None = False, ) -> list[NodePosition]: """Project embedding vectors down to 2D/3D using UMAP.""" - embedding_position_data: list[NodePosition] = [] - for index, node_name in enumerate(node_labels): - node_category = 1 if node_categories is None else node_categories[index] - node_size = 1 if node_sizes is None else node_sizes[index] - - if not three_d: - embedding_position_data.append( - NodePosition( - label=str(node_name), - x=0, - y=0, - cluster=str(int(node_category)), - size=int(node_size), - ) - ) - else: - embedding_position_data.append( - NodePosition( - label=str(node_name), - x=0, - y=0, - z=0, - cluster=str(int(node_category)), - size=int(node_size), - ) + n = len(node_labels) + + if node_categories is None: + category_values = ["1"] * n + else: + category_values = [str(int(cat)) for cat in node_categories] + + if node_sizes is None: + size_values = [1] * n + else: + size_values = [int(sz) for sz in node_sizes] + + if not three_d: + return [ + NodePosition( + label=str(node_labels[i]), + x=0, + y=0, + cluster=category_values[i], + size=size_values[i], ) - return embedding_position_data + for i in range(n) + ] + return [ + NodePosition( + label=str(node_labels[i]), + x=0, + y=0, + z=0, + cluster=category_values[i], + size=size_values[i], + ) + for i in range(n) + ]