Skip to content

Commit

Permalink
fix: modified DirectoryDataComponent to support user defined custom file types (#4017)
Browse files Browse the repository at this point in the history

test_data_components.py: Updated 'test_directory_component_build_with_multithreading' to expect an extra argument
  • Loading branch information
EDLLT authored Oct 7, 2024
1 parent b482862 commit 79a1257
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
10 changes: 6 additions & 4 deletions src/backend/base/langflow/components/data/Directory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langflow.base.data.utils import parallel_load_data, parse_text_file_to_data, retrieve_file_paths
from langflow.base.data.utils import TEXT_FILE_TYPES, parallel_load_data, parse_text_file_to_data, retrieve_file_paths
from langflow.custom import Component
from langflow.io import BoolInput, IntInput, MessageTextInput
from langflow.schema import Data
Expand All @@ -20,7 +20,7 @@ class DirectoryComponent(Component):
MessageTextInput(
name="types",
display_name="Types",
info="File types to load. Leave empty to load all types.",
info="File types to load. Leave empty to load all default supported types.",
is_list=True,
),
IntInput(
Expand Down Expand Up @@ -68,7 +68,9 @@ class DirectoryComponent(Component):

def load_directory(self) -> list[Data]:
path = self.path
types = self.types or [] # self.types is already a list due to is_list=True
types = (
self.types if self.types and self.types != [""] else TEXT_FILE_TYPES
) # self.types is already a list due to is_list=True
depth = self.depth
max_concurrency = self.max_concurrency
load_hidden = self.load_hidden
Expand All @@ -77,7 +79,7 @@ def load_directory(self) -> list[Data]:
use_multithreading = self.use_multithreading

resolved_path = self.resolve_path(path)
file_paths = retrieve_file_paths(resolved_path, load_hidden, recursive, depth)
file_paths = retrieve_file_paths(resolved_path, load_hidden, recursive, depth, types)

if types:
file_paths = [fp for fp in file_paths if any(fp.endswith(ext) for ext in types)]
Expand Down
4 changes: 2 additions & 2 deletions src/backend/tests/unit/test_data_components.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch
from unittest.mock import Mock, patch, ANY

import httpx
import pytest
Expand Down Expand Up @@ -158,7 +158,7 @@ def test_directory_component_build_with_multithreading(

# Assert
mock_resolve_path.assert_called_once_with(path)
mock_retrieve_file_paths.assert_called_once_with(path, load_hidden, recursive, depth)
mock_retrieve_file_paths.assert_called_once_with(path, load_hidden, recursive, depth, ANY)
mock_parallel_load_data.assert_called_once_with(
mock_retrieve_file_paths.return_value, silent_errors, max_concurrency
)
Expand Down

0 comments on commit 79a1257

Please sign in to comment.