Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Add conversions from GVS to networkx #26906

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libs/community/extended_testing_deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ motor>=3.3.1,<4
msal>=1.25.0,<2
mwparserfromhell>=0.6.4,<0.7
mwxml>=0.3.3,<0.4
networkx>=3.2.1,<4
newspaper3k>=0.2.8,<0.3
numexpr>=2.8.6,<3
nvidia-riva-client>=2.14.0,<3
Expand Down
84 changes: 84 additions & 0 deletions libs/community/langchain_community/graph_vectorstores/networkx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Utilities for using Graph Vector Stores with networkx."""

import typing

from langchain_core.documents import Document

from langchain_community.graph_vectorstores.links import get_links

if typing.TYPE_CHECKING:
import networkx as nx


def documents_to_networkx(
    documents: typing.Iterable[Document],
    *,
    tag_nodes: bool = True,
) -> "nx.DiGraph":
    """Return the networkx directed graph corresponding to the documents.

    Args:
        documents: The documents to convert to networkx. Any iterable is
            accepted, including generators and other one-shot iterators; it
            is materialized once because the conversion needs two passes.
        tag_nodes: If `True`, each tag will be rendered as a node, with edges
            to/from the corresponding documents. If `False`, edges will be
            between documents, with a label corresponding to the tag(s)
            connecting them.

    Returns:
        A directed graph with one node per document (attribute `text` holds
        the page content). With `tag_nodes=True` it also has one node per
        unique `(kind, tag)` pair (attribute `label` is `"kind:tag"`). With
        `tag_nodes=False` each document-to-document edge carries a `label`
        attribute holding the string form of the list of connecting tags.

    Raises:
        ValueError: If any document has no ID.
    """
    import networkx as nx

    # The input is walked twice; copy it so one-shot iterators are not
    # silently exhausted by the first pass.
    docs = list(documents)

    graph = nx.DiGraph()

    tag_ids: typing.Dict[typing.Tuple[str, str], str] = {}
    documents_by_incoming: typing.Dict[str, typing.Set[str]] = {}

    # First pass:
    # - Register tag IDs for each unique (kind, tag).
    # - If rendering tag nodes, add them to the graph.
    # - If not rendering tag nodes, index documents by their incoming tags.
    for document in docs:
        if document.id is None:
            raise ValueError(f"Illegal graph document without ID: {document}")

        for link in get_links(document):
            tag_key = (link.kind, link.tag)
            tag_id = tag_ids.get(tag_key)
            if tag_id is None:
                # NOTE(review): tag node IDs share the graph's node namespace
                # with document IDs; a document whose ID is "tag_<n>" would
                # collide with a tag node -- confirm whether that matters.
                tag_id = f"tag_{len(tag_ids)}"
                tag_ids[tag_key] = tag_id

                if tag_nodes:
                    graph.add_node(tag_id, label=f"{link.kind}:{link.tag}")

            if not tag_nodes and link.direction in ("in", "bidir"):
                documents_by_incoming.setdefault(tag_id, set()).add(document.id)

    # Second pass:
    # - Render document nodes.
    # - If rendering tag nodes, render edges to/from documents and tag nodes.
    # - If not rendering tag nodes, render edges between documents based on
    #   the tags they share.
    for document in docs:
        graph.add_node(document.id, text=document.page_content)

        targets: typing.Dict[str, typing.List[str]] = {}
        for link in get_links(document):
            tag_id = tag_ids[(link.kind, link.tag)]
            if tag_nodes:
                if link.direction in ("in", "bidir"):
                    graph.add_edge(tag_id, document.id)
                if link.direction in ("out", "bidir"):
                    graph.add_edge(document.id, tag_id)
            elif link.direction in ("out", "bidir"):
                label = f"{link.kind}:{link.tag}"
                # An outgoing tag that no document accepts simply yields no
                # edges; `.get` avoids a KeyError for such dangling links.
                for target in documents_by_incoming.get(tag_id, ()):
                    if target != document.id:
                        targets.setdefault(target, []).append(label)

        # Avoid a multigraph by collecting the list of labels for each edge.
        if not tag_nodes:
            for target, labels in targets.items():
                graph.add_edge(document.id, target, label=str(labels))

    return graph
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import pytest
from langchain_core.documents import Document

from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_community.graph_vectorstores.networkx import documents_to_networkx


@pytest.mark.requires("networkx")
def test_documents_to_networkx() -> None:
    """Check both renderings of `documents_to_networkx` on two linked docs.

    The graphs are compared through networkx's node-link serialization, with
    nodes and edges sorted so the assertions do not depend on insertion order.
    """
    import networkx as nx

    def node_link_data(graph: "nx.DiGraph") -> dict:
        """Serialize `graph` with its edge list stored under "links".

        Newer networkx releases take an `edges` keyword and are moving the
        default key name away from "links"; older releases in the supported
        range reject the keyword and always use "links". Pinning it where
        possible keeps the assertions below valid on both.
        """
        try:
            return nx.node_link_data(graph, edges="links")
        except TypeError:
            return nx.node_link_data(graph)

    doc1 = Document(
        id="a",
        page_content="some content",
        metadata={
            METADATA_LINKS_KEY: [
                Link.incoming("href", "a"),
                Link.bidir("kw", "foo"),
            ]
        },
    )
    doc2 = Document(
        id="b",
        page_content="<some\n more content>",
        metadata={
            METADATA_LINKS_KEY: [
                Link.incoming("href", "b"),
                Link.outgoing("href", "a"),
                Link.bidir("kw", "foo"),
                Link.bidir("kw", "bar"),
            ]
        },
    )

    # Tags rendered as their own nodes: documents connect only via tag nodes.
    graph_with_tags = documents_to_networkx([doc1, doc2], tag_nodes=True)
    link_data = node_link_data(graph_with_tags)
    assert link_data["directed"]
    assert not link_data["multigraph"]

    link_data["nodes"].sort(key=lambda n: n["id"])
    assert link_data["nodes"] == [
        {"id": "a", "text": "some content"},
        {"id": "b", "text": "<some\n more content>"},
        {"id": "tag_0", "label": "href:a"},
        {"id": "tag_1", "label": "kw:foo"},
        {"id": "tag_2", "label": "href:b"},
        {"id": "tag_3", "label": "kw:bar"},
    ]
    link_data["links"].sort(key=lambda n: (n["source"], n["target"]))
    assert link_data["links"] == [
        {"source": "a", "target": "tag_1"},
        {"source": "b", "target": "tag_0"},
        {"source": "b", "target": "tag_1"},
        {"source": "b", "target": "tag_3"},
        {"source": "tag_0", "target": "a"},
        {"source": "tag_1", "target": "a"},
        {"source": "tag_1", "target": "b"},
        {"source": "tag_2", "target": "b"},
        {"source": "tag_3", "target": "b"},
    ]

    # Tags collapsed into edge labels: documents connect directly.
    graph_without_tags = documents_to_networkx([doc1, doc2], tag_nodes=False)
    link_data = node_link_data(graph_without_tags)
    assert link_data["directed"]
    assert not link_data["multigraph"]

    link_data["nodes"].sort(key=lambda n: n["id"])
    assert link_data["nodes"] == [
        {"id": "a", "text": "some content"},
        {"id": "b", "text": "<some\n more content>"},
    ]

    link_data["links"].sort(key=lambda n: (n["source"], n["target"]))
    assert link_data["links"] == [
        {"source": "a", "target": "b", "label": "['kw:foo']"},
        {"source": "b", "target": "a", "label": "['href:a', 'kw:foo']"},
    ]
Loading