-
Notifications
You must be signed in to change notification settings - Fork 15k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community: Add conversions from GVS to networkx (#26906)
These allow converting linked documents (such as those used with GraphVectorStore) to networkx for rendering and/or in-memory graph algorithms such as community detection.
- Loading branch information
1 parent
7809b31
commit 29bf89d
Showing
3 changed files
with
162 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
84 changes: 84 additions & 0 deletions
84
libs/community/langchain_community/graph_vectorstores/networkx.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
"""Utilities for using Graph Vector Stores with networkx.""" | ||
|
||
import typing | ||
|
||
from langchain_core.documents import Document | ||
|
||
from langchain_community.graph_vectorstores.links import get_links | ||
|
||
if typing.TYPE_CHECKING: | ||
import networkx as nx | ||
|
||
|
||
def documents_to_networkx( | ||
documents: typing.Iterable[Document], | ||
*, | ||
tag_nodes: bool = True, | ||
) -> "nx.DiGraph": | ||
"""Return the networkx directed graph corresponding to the documents. | ||
Args: | ||
documents: The documents to convenrt to networkx. | ||
tag_nodes: If `True`, each tag will be rendered as a node, with edges | ||
to/from the corresponding documents. If `False`, edges will be | ||
between documents, with a label corresponding to the tag(s) | ||
connecting them. | ||
""" | ||
import networkx as nx | ||
|
||
graph = nx.DiGraph() | ||
|
||
tag_ids: typing.Dict[typing.Tuple[str, str], str] = {} | ||
tag_labels: typing.Dict[str, str] = {} | ||
documents_by_incoming: typing.Dict[str, typing.Set[str]] = {} | ||
|
||
# First pass: | ||
# - Register tag IDs for each unique (kind, tag). | ||
# - If rendering tag nodes, add them to the graph. | ||
# - If not rendering tag nodes, create a dictionary of documents by incoming tags. | ||
for document in documents: | ||
if document.id is None: | ||
raise ValueError(f"Illegal graph document without ID: {document}") | ||
|
||
for link in get_links(document): | ||
tag_key = (link.kind, link.tag) | ||
tag_id = tag_ids.get(tag_key) | ||
if tag_id is None: | ||
tag_id = f"tag_{len(tag_ids)}" | ||
tag_ids[tag_key] = tag_id | ||
|
||
if tag_nodes: | ||
graph.add_node(tag_id, label=f"{link.kind}:{link.tag}") | ||
|
||
if not tag_nodes and (link.direction == "in" or link.direction == "bidir"): | ||
tag_labels[tag_id] = f"{link.kind}:{link.tag}" | ||
documents_by_incoming.setdefault(tag_id, set()).add(document.id) | ||
|
||
# Second pass: | ||
# - Render document nodes | ||
# - If rendering tag nodes, render edges to/from documents and tag nodes. | ||
# - If not rendering tag nodes, render edges to/from documents based on tags. | ||
for document in documents: | ||
graph.add_node(document.id, text=document.page_content) | ||
|
||
targets: typing.Dict[str, typing.List[str]] = {} | ||
for link in get_links(document): | ||
tag_id = tag_ids[(link.kind, link.tag)] | ||
if tag_nodes: | ||
if link.direction == "in" or link.direction == "bidir": | ||
graph.add_edge(tag_id, document.id) | ||
if link.direction == "out" or link.direction == "bidir": | ||
graph.add_edge(document.id, tag_id) | ||
else: | ||
if link.direction == "out" or link.direction == "bidir": | ||
label = tag_labels[tag_id] | ||
for target in documents_by_incoming[tag_id]: | ||
if target != document.id: | ||
targets.setdefault(target, []).append(label) | ||
|
||
# Avoid a multigraph by collecting the list of labels for each edge. | ||
if not tag_nodes: | ||
for target, labels in targets.items(): | ||
graph.add_edge(document.id, target, label=str(labels)) | ||
|
||
return graph |
77 changes: 77 additions & 0 deletions
77
libs/community/tests/unit_tests/graph_vectorstores/test_networkx.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import pytest | ||
from langchain_core.documents import Document | ||
|
||
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link | ||
from langchain_community.graph_vectorstores.networkx import documents_to_networkx | ||
|
||
|
||
@pytest.mark.requires("networkx") | ||
def test_documents_to_networkx() -> None: | ||
import networkx as nx | ||
|
||
doc1 = Document( | ||
id="a", | ||
page_content="some content", | ||
metadata={ | ||
METADATA_LINKS_KEY: [ | ||
Link.incoming("href", "a"), | ||
Link.bidir("kw", "foo"), | ||
] | ||
}, | ||
) | ||
doc2 = Document( | ||
id="b", | ||
page_content="<some\n more content>", | ||
metadata={ | ||
METADATA_LINKS_KEY: [ | ||
Link.incoming("href", "b"), | ||
Link.outgoing("href", "a"), | ||
Link.bidir("kw", "foo"), | ||
Link.bidir("kw", "bar"), | ||
] | ||
}, | ||
) | ||
|
||
graph_with_tags = documents_to_networkx([doc1, doc2], tag_nodes=True) | ||
link_data = nx.node_link_data(graph_with_tags) | ||
assert link_data["directed"] | ||
assert not link_data["multigraph"] | ||
|
||
link_data["nodes"].sort(key=lambda n: n["id"]) | ||
assert link_data["nodes"] == [ | ||
{"id": "a", "text": "some content"}, | ||
{"id": "b", "text": "<some\n more content>"}, | ||
{"id": "tag_0", "label": "href:a"}, | ||
{"id": "tag_1", "label": "kw:foo"}, | ||
{"id": "tag_2", "label": "href:b"}, | ||
{"id": "tag_3", "label": "kw:bar"}, | ||
] | ||
link_data["links"].sort(key=lambda n: (n["source"], n["target"])) | ||
assert link_data["links"] == [ | ||
{"source": "a", "target": "tag_1"}, | ||
{"source": "b", "target": "tag_0"}, | ||
{"source": "b", "target": "tag_1"}, | ||
{"source": "b", "target": "tag_3"}, | ||
{"source": "tag_0", "target": "a"}, | ||
{"source": "tag_1", "target": "a"}, | ||
{"source": "tag_1", "target": "b"}, | ||
{"source": "tag_2", "target": "b"}, | ||
{"source": "tag_3", "target": "b"}, | ||
] | ||
|
||
graph_without_tags = documents_to_networkx([doc1, doc2], tag_nodes=False) | ||
link_data = nx.node_link_data(graph_without_tags) | ||
assert link_data["directed"] | ||
assert not link_data["multigraph"] | ||
|
||
link_data["nodes"].sort(key=lambda n: n["id"]) | ||
assert link_data["nodes"] == [ | ||
{"id": "a", "text": "some content"}, | ||
{"id": "b", "text": "<some\n more content>"}, | ||
] | ||
|
||
link_data["links"].sort(key=lambda n: (n["source"], n["target"])) | ||
assert link_data["links"] == [ | ||
{"source": "a", "target": "b", "label": "['kw:foo']"}, | ||
{"source": "b", "target": "a", "label": "['href:a', 'kw:foo']"}, | ||
] |