Skip to content

Commit

Permalink
community: Add conversions from GVS to networkx (#26906)
Browse files Browse the repository at this point in the history
These allow converting linked documents (such as those used with
GraphVectorStore) to networkx for rendering and/or in-memory graph
algorithms such as community detection.
  • Loading branch information
bjchambers authored Sep 27, 2024
1 parent 7809b31 commit 29bf89d
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 0 deletions.
1 change: 1 addition & 0 deletions libs/community/extended_testing_deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ motor>=3.3.1,<4
msal>=1.25.0,<2
mwparserfromhell>=0.6.4,<0.7
mwxml>=0.3.3,<0.4
networkx>=3.2.1,<4
newspaper3k>=0.2.8,<0.3
numexpr>=2.8.6,<3
nvidia-riva-client>=2.14.0,<3
Expand Down
84 changes: 84 additions & 0 deletions libs/community/langchain_community/graph_vectorstores/networkx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Utilities for using Graph Vector Stores with networkx."""

import typing

from langchain_core.documents import Document

from langchain_community.graph_vectorstores.links import get_links

if typing.TYPE_CHECKING:
import networkx as nx


def documents_to_networkx(
    documents: typing.Iterable[Document],
    *,
    tag_nodes: bool = True,
) -> "nx.DiGraph":
    """Return the networkx directed graph corresponding to the documents.

    Args:
        documents: The documents to convert to networkx. Any iterable is
            accepted, including one-shot iterators such as generators.
        tag_nodes: If `True`, each tag will be rendered as a node, with edges
            to/from the corresponding documents. If `False`, edges will be
            between documents, with a label corresponding to the tag(s)
            connecting them.

    Returns:
        A directed graph. Each document becomes a node keyed by its ID with a
        `text` attribute holding the page content. When `tag_nodes` is `True`,
        each unique (kind, tag) pair becomes a node named `tag_<n>` with a
        `label` attribute of the form `kind:tag`. When `tag_nodes` is `False`,
        each document-to-document edge carries a `label` attribute containing
        the stringified list of `kind:tag` labels that connect the pair.

    Raises:
        ValueError: If any document has no ID.
    """
    import networkx as nx

    # The documents are traversed twice below. Materialize them so that
    # one-shot iterables (e.g. generators) are not exhausted by the first pass.
    documents = list(documents)

    graph = nx.DiGraph()

    tag_ids: typing.Dict[typing.Tuple[str, str], str] = {}
    tag_labels: typing.Dict[str, str] = {}
    documents_by_incoming: typing.Dict[str, typing.Set[str]] = {}

    # First pass:
    # - Register tag IDs for each unique (kind, tag).
    # - If rendering tag nodes, add them to the graph.
    # - If not rendering tag nodes, create a dictionary of documents by incoming tags.
    for document in documents:
        if document.id is None:
            raise ValueError(f"Illegal graph document without ID: {document}")

        for link in get_links(document):
            tag_key = (link.kind, link.tag)
            tag_id = tag_ids.get(tag_key)
            if tag_id is None:
                tag_id = f"tag_{len(tag_ids)}"
                tag_ids[tag_key] = tag_id

                if tag_nodes:
                    graph.add_node(tag_id, label=f"{link.kind}:{link.tag}")

            if not tag_nodes and link.direction in ("in", "bidir"):
                tag_labels[tag_id] = f"{link.kind}:{link.tag}"
                documents_by_incoming.setdefault(tag_id, set()).add(document.id)

    # Second pass:
    # - Render document nodes
    # - If rendering tag nodes, render edges to/from documents and tag nodes.
    # - If not rendering tag nodes, render edges to/from documents based on tags.
    for document in documents:
        graph.add_node(document.id, text=document.page_content)

        targets: typing.Dict[str, typing.List[str]] = {}
        for link in get_links(document):
            tag_id = tag_ids[(link.kind, link.tag)]
            if tag_nodes:
                if link.direction in ("in", "bidir"):
                    graph.add_edge(tag_id, document.id)
                if link.direction in ("out", "bidir"):
                    graph.add_edge(document.id, tag_id)
            elif link.direction in ("out", "bidir"):
                # An outgoing tag that no document accepts as incoming has no
                # entry here; it simply produces no edges.
                incoming = documents_by_incoming.get(tag_id)
                if not incoming:
                    continue
                label = tag_labels[tag_id]
                for target in incoming:
                    if target != document.id:
                        targets.setdefault(target, []).append(label)

        # Avoid a multigraph by collecting the list of labels for each edge.
        if not tag_nodes:
            for target, labels in targets.items():
                graph.add_edge(document.id, target, label=str(labels))

    return graph
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import pytest
from langchain_core.documents import Document

from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_community.graph_vectorstores.networkx import documents_to_networkx


@pytest.mark.requires("networkx")
def test_documents_to_networkx() -> None:
    """Check the graphs built from two linked documents, with and without tag nodes."""
    import networkx as nx

    def sorted_nodes(data: dict) -> list:
        """Return the serialized nodes ordered by node ID."""
        return sorted(data["nodes"], key=lambda n: n["id"])

    def sorted_edges(data: dict) -> list:
        """Return the serialized edges ordered by (source, target)."""
        # Older networkx releases serialize the edge list under "links";
        # newer releases within the supported range use "edges" by default.
        edges = data["links"] if "links" in data else data["edges"]
        return sorted(edges, key=lambda n: (n["source"], n["target"]))

    doc1 = Document(
        id="a",
        page_content="some content",
        metadata={
            METADATA_LINKS_KEY: [
                Link.incoming("href", "a"),
                Link.bidir("kw", "foo"),
            ]
        },
    )
    doc2 = Document(
        id="b",
        page_content="<some\n more content>",
        metadata={
            METADATA_LINKS_KEY: [
                Link.incoming("href", "b"),
                Link.outgoing("href", "a"),
                Link.bidir("kw", "foo"),
                Link.bidir("kw", "bar"),
            ]
        },
    )

    # With tag nodes: every unique (kind, tag) is its own node, and edges run
    # between documents and tags according to each link's direction.
    graph_with_tags = documents_to_networkx([doc1, doc2], tag_nodes=True)
    link_data = nx.node_link_data(graph_with_tags)
    assert link_data["directed"]
    assert not link_data["multigraph"]

    assert sorted_nodes(link_data) == [
        {"id": "a", "text": "some content"},
        {"id": "b", "text": "<some\n more content>"},
        {"id": "tag_0", "label": "href:a"},
        {"id": "tag_1", "label": "kw:foo"},
        {"id": "tag_2", "label": "href:b"},
        {"id": "tag_3", "label": "kw:bar"},
    ]
    assert sorted_edges(link_data) == [
        {"source": "a", "target": "tag_1"},
        {"source": "b", "target": "tag_0"},
        {"source": "b", "target": "tag_1"},
        {"source": "b", "target": "tag_3"},
        {"source": "tag_0", "target": "a"},
        {"source": "tag_1", "target": "a"},
        {"source": "tag_1", "target": "b"},
        {"source": "tag_2", "target": "b"},
        {"source": "tag_3", "target": "b"},
    ]

    # Without tag nodes: only documents are nodes, and each edge is labelled
    # with the tags that connect its endpoints.
    graph_without_tags = documents_to_networkx([doc1, doc2], tag_nodes=False)
    link_data = nx.node_link_data(graph_without_tags)
    assert link_data["directed"]
    assert not link_data["multigraph"]

    assert sorted_nodes(link_data) == [
        {"id": "a", "text": "some content"},
        {"id": "b", "text": "<some\n more content>"},
    ]

    assert sorted_edges(link_data) == [
        {"source": "a", "target": "b", "label": "['kw:foo']"},
        {"source": "b", "target": "a", "label": "['href:a', 'kw:foo']"},
    ]

0 comments on commit 29bf89d

Please sign in to comment.