diff --git a/nptdms/tdms.py b/nptdms/tdms.py index 70bc07b..72a8d27 100644 --- a/nptdms/tdms.py +++ b/nptdms/tdms.py @@ -214,6 +214,13 @@ def __len__(self): """ return len(self._groups) + def __contains__(self, group_name): + """ Check if TDMS file contains groupp + + :rtype: Boolean + """ + return group_name in self._groups + def __iter__(self): """ Returns an iterator over the names of groups in this file """ @@ -355,6 +362,13 @@ def __init__(self, path, properties, channels): def __repr__(self): return "" % self.path + def __contains__(self, channel_name): + """ Check if group contains channel + + :rtype: Boolean + """ + return channel_name in self._channels + @property def path(self): """ Path to the TDMS object for this group diff --git a/nptdms/test/test_tdms_file.py b/nptdms/test/test_tdms_file.py index 42238f4..956a430 100644 --- a/nptdms/test/test_tdms_file.py +++ b/nptdms/test/test_tdms_file.py @@ -899,6 +899,28 @@ def test_object_repr(): assert repr(channel) == "" +def test_group_in_tdms_object(): + """Test in for TDMS object + """ + test_file = GeneratedFile() + test_file.add_segment(*basic_segment()) + tdms_data = test_file.load() + + assert 'Group' in tdms_data + assert 'group' not in tdms_data + + +def test_channel_in_group_object(): + """Test in for group + """ + test_file = GeneratedFile() + test_file.add_segment(*basic_segment()) + tdms_data = test_file.load() + + assert 'Channel1' in tdms_data['Group'] + assert 'channel1' not in tdms_data['Group'] + + def test_data_read_from_bytes_io(): """Test reading data"""