diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 000000000..2e930a447 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,36 @@ +language: python +sudo: required +dist: trusty +python: + - "2.7" + - "3.3" + - "3.4" + - "3.5" +before_install: + # get a working ffmpeg + - sudo add-apt-repository --yes ppa:mc3man/trusty-media + - sudo apt-get update -qq + - sudo apt-get install -qq ffmpeg + # install numpy etc. via miniconda + - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then + wget http://repo.continuum.io/miniconda/Miniconda-3.8.3-Linux-x86_64.sh -O miniconda.sh; + else + wget http://repo.continuum.io/miniconda/Miniconda3-3.8.3-Linux-x86_64.sh -O miniconda.sh; + fi + - bash miniconda.sh -b -p $HOME/miniconda + - export PATH="$HOME/miniconda/bin:$PATH" + - hash -r + - conda config --set always_yes yes --set changeps1 no + - conda update -q conda + - conda config --add channels pypi + - conda info -a + - deps='pip cython numpy scipy nose pep8' + - conda create -q -n test-environment "python=$TRAVIS_PYTHON_VERSION" $deps + - source activate test-environment +install: + - pip install -e . +before_script: + - pep8 --ignore=E402 madmom tests bin +script: + - nosetests + # TODO: add executable programs diff --git a/README.rst b/README.rst index d847060d3..c8a9ddca5 100644 --- a/README.rst +++ b/README.rst @@ -45,8 +45,8 @@ Installation Prerequisites ------------- -To install the ``madmom`` package, you must have Python version 2.7 and the -following packages installed: +To install the ``madmom`` package, you must have either Python 2.7 or Python +3.3 or newer and the following packages installed: - `numpy `_ - `scipy `_ @@ -54,7 +54,8 @@ following packages installed: - `nose `_ (to run the tests) If you need support for audio files other than ``.wav`` with a sample rate of -44.1kHz and 16 bit depth, you need ``ffmpeg`` (or ``avconv`` on Ubuntu Linux). +44.1kHz and 16 bit depth, you need ``ffmpeg`` (``avconv`` on Ubuntu Linux has +some decoding bugs, so we advise not to use it!). Please refer to the `requirements.txt `_ file for the minimum required versions and make sure that these modules are up to date, otherwise it diff --git a/bin/BeatDetector b/bin/BeatDetector index 59f534133..3de6c7c21 100755 --- a/bin/BeatDetector +++ b/bin/BeatDetector @@ -5,6 +5,8 @@ BeatDetector beat tracking algorithm. """ +from __future__ import absolute_import, division, print_function + import glob import argparse @@ -94,7 +96,7 @@ def main(): # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/BeatTracker b/bin/BeatTracker index 9f8ba2122..a1d1b046c 100755 --- a/bin/BeatTracker +++ b/bin/BeatTracker @@ -5,6 +5,8 @@ BeatTracker beat tracking algorithm. """ +from __future__ import absolute_import, division, print_function + import glob import argparse @@ -92,7 +94,7 @@ def main(): # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/CRFBeatDetector b/bin/CRFBeatDetector index 548bcc755..48225ba1a 100755 --- a/bin/CRFBeatDetector +++ b/bin/CRFBeatDetector @@ -5,6 +5,8 @@ CRFBeatDetector beat tracking algorithm. """ +from __future__ import absolute_import, division, print_function + import glob import argparse @@ -95,7 +97,7 @@ def main(): # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/ComplexFlux b/bin/ComplexFlux index c2dc279ae..ce9ef2cd4 100755 --- a/bin/ComplexFlux +++ b/bin/ComplexFlux @@ -5,6 +5,8 @@ ComplexFlux onset detection algorithm. """ +from __future__ import absolute_import, division, print_function + import argparse from madmom.processors import IOProcessor, io_arguments @@ -72,7 +74,7 @@ def main(): args.online = False # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/DBNBeatTracker b/bin/DBNBeatTracker index 507b7a268..6c23b8010 100755 --- a/bin/DBNBeatTracker +++ b/bin/DBNBeatTracker @@ -5,6 +5,8 @@ DBNBeatTracker beat tracking algorithm. """ +from __future__ import absolute_import, division, print_function + import glob import argparse @@ -96,7 +98,7 @@ def main(): # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/DownBeatTracker b/bin/DownBeatTracker index e588bf659..8033f0e6d 100755 --- a/bin/DownBeatTracker +++ b/bin/DownBeatTracker @@ -5,6 +5,8 @@ DownBeatTracker (down-)beat tracking algorithm. """ +from __future__ import absolute_import, division, print_function + import glob import argparse @@ -92,7 +94,7 @@ def main(): # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/LogFiltSpecFlux b/bin/LogFiltSpecFlux index 2dd444def..67af25866 100755 --- a/bin/LogFiltSpecFlux +++ b/bin/LogFiltSpecFlux @@ -5,6 +5,8 @@ LogFiltSpecFlux onset detection algorithm. """ +from __future__ import absolute_import, division, print_function + import argparse from madmom.processors import IOProcessor, io_arguments @@ -71,7 +73,7 @@ def main(): args.online = False # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/MMBeatTracker b/bin/MMBeatTracker index 0fd552efb..7c518fdd7 100755 --- a/bin/MMBeatTracker +++ b/bin/MMBeatTracker @@ -5,6 +5,8 @@ MMBeatTracker multi model beat tracking algorithm. """ +from __future__ import absolute_import, division, print_function + import glob import argparse @@ -104,7 +106,7 @@ def main(): # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/OnsetDetector b/bin/OnsetDetector index 3fdabcf18..5e3137d7b 100755 --- a/bin/OnsetDetector +++ b/bin/OnsetDetector @@ -5,6 +5,8 @@ OnsetDetector onset detection algorithm. """ +from __future__ import absolute_import, division, print_function + import glob import argparse @@ -88,7 +90,7 @@ def main(): # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/OnsetDetectorLL b/bin/OnsetDetectorLL index c52e96f26..ed021989e 100755 --- a/bin/OnsetDetectorLL +++ b/bin/OnsetDetectorLL @@ -5,6 +5,8 @@ OnsetDetectorLL online onset detection algorithm. """ +from __future__ import absolute_import, division, print_function + import glob import argparse @@ -100,7 +102,7 @@ def main(): # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/PianoTranscriptor b/bin/PianoTranscriptor index 68af3139f..5e177a6b9 100755 --- a/bin/PianoTranscriptor +++ b/bin/PianoTranscriptor @@ -5,6 +5,8 @@ PianoTranscriptor (piano) note transcription algorithm. """ +from __future__ import absolute_import, division, print_function + import glob import argparse @@ -17,7 +19,7 @@ from madmom.audio.spectrogram import (LogarithmicFilteredSpectrogramProcessor, from madmom.ml.rnn import RNNProcessor, average_predictions from madmom.features import ActivationsProcessor from madmom.features.onsets import PeakPickingProcessor -from madmom.features.notes import (write_midi, write_notes, write_frequencies, +from madmom.features.notes import (write_midi, write_notes, write_mirex_format, note_reshaper) @@ -102,7 +104,7 @@ def main(): # print arguments if args.verbose: - print args + print(args) # input processor if args.load: @@ -136,7 +138,7 @@ def main(): elif args.output_format == 'midi': output = write_midi elif args.output_format == 'mirex': - output = write_frequencies + output = write_mirex_format else: raise ValueError('unknown output format: %s' % args.output_format) out_processor = [peak_picking, output] diff --git a/bin/PickleProcessor b/bin/PickleProcessor index 7d98c5bab..fd2afacde 100755 --- a/bin/PickleProcessor +++ b/bin/PickleProcessor @@ -5,7 +5,9 @@ PickleProcessor. """ -import cPickle +from __future__ import absolute_import, division, print_function + +import pickle import argparse from madmom.processors import io_arguments @@ -28,14 +30,14 @@ def main(): help='pickled processor') parser.add_argument('--version', action='version', version='PickleProcessor v0.1') - sp = io_arguments(parser, output_suffix='.txt', pickle=False) + io_arguments(parser, output_suffix='.txt', pickle=False) # parse arguments args = parser.parse_args() kwargs = vars(args) # create a processor - processor = cPickle.load(kwargs.pop('processor')) + processor = pickle.load(kwargs.pop('processor')) # and call the processing function args.func(processor, **kwargs) diff --git a/bin/SpectralOnsetDetection b/bin/SpectralOnsetDetection index 08d154743..0f7b8c6c4 100755 --- a/bin/SpectralOnsetDetection +++ b/bin/SpectralOnsetDetection @@ -5,6 +5,8 @@ Spectral onset detection script. """ +from __future__ import absolute_import, division, print_function + import argparse from madmom.processors import IOProcessor, io_arguments @@ -84,7 +86,7 @@ def main(): '%s.' % args.onset_method) # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/SuperFlux b/bin/SuperFlux index f46ceb600..70c8e09af 100755 --- a/bin/SuperFlux +++ b/bin/SuperFlux @@ -5,6 +5,8 @@ SuperFlux onset detection algorithm. """ +from __future__ import absolute_import, division, print_function + import argparse from madmom.processors import IOProcessor, io_arguments @@ -71,7 +73,7 @@ def main(): args.online = False # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/SuperFluxNN b/bin/SuperFluxNN index 8893db846..abcc775ef 100755 --- a/bin/SuperFluxNN +++ b/bin/SuperFluxNN @@ -5,6 +5,8 @@ SuperFlux with neural network based peak picking onset detection algorithm. """ +from __future__ import absolute_import, division, print_function + import glob import argparse @@ -85,7 +87,7 @@ def main(): # print arguments if args.verbose: - print args + print(args) # input processor if args.load: diff --git a/bin/TempoDetector b/bin/TempoDetector index d041520ee..c3e66ba7e 100755 --- a/bin/TempoDetector +++ b/bin/TempoDetector @@ -5,6 +5,8 @@ TempoDetector tempo estimation algorithm. """ +from __future__ import absolute_import, division, print_function + import glob import argparse @@ -91,7 +93,7 @@ def main(): # print arguments if args.verbose: - print args + print(args) # input processor if args.load: @@ -124,8 +126,8 @@ def main(): from functools import partial writer = partial(write_tempo, mirex=True) elif args.tempo_format in ('raw', 'all'): - # borrow the note writer for outputting multiple values - from madmom.features.notes import write_notes as writer + # borrow the event writer for outputting multiple values + from madmom.utils import write_events as writer else: # normal output writer = write_tempo diff --git a/bin/evaluate b/bin/evaluate index 3a3792768..7752847c7 100755 --- a/bin/evaluate +++ b/bin/evaluate @@ -5,6 +5,8 @@ Evaluation script. """ +from __future__ import absolute_import, division, print_function + import os import sys import argparse @@ -50,7 +52,7 @@ def main(): # print the arguments if args.verbose >= 2: - print args + print(args) if args.quiet: warnings.filterwarnings("ignore") @@ -63,7 +65,7 @@ def main(): ann_files = search_files(args.ann_dir, args.ann_suffix) # quit if no files are found if len(ann_files) == 0: - print "no files to evaluate. exiting." + print("no files to evaluate. exiting.") exit() # list to collect the individual evaluation objects diff --git a/madmom/__init__.py b/madmom/__init__.py index 97b0c320a..0f0e68513 100644 --- a/madmom/__init__.py +++ b/madmom/__init__.py @@ -11,6 +11,8 @@ """ +from __future__ import absolute_import, division, print_function + import os import pkg_resources diff --git a/madmom/audio/__init__.py b/madmom/audio/__init__.py index 27d8f874b..ccd3919da 100644 --- a/madmom/audio/__init__.py +++ b/madmom/audio/__init__.py @@ -6,5 +6,7 @@ """ +from __future__ import absolute_import, division, print_function + # import the submodules from . import signal, ffmpeg, filters, comb_filters, stft, spectrogram diff --git a/madmom/audio/cepstrogram.py b/madmom/audio/cepstrogram.py index 1ce456d5b..146d67f21 100644 --- a/madmom/audio/cepstrogram.py +++ b/madmom/audio/cepstrogram.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains all cepstrogram related functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np from scipy.fftpack import dct diff --git a/madmom/audio/chroma.py b/madmom/audio/chroma.py index f46c07c89..0643c5a38 100755 --- a/madmom/audio/chroma.py +++ b/madmom/audio/chroma.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains chroma related functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np from madmom.audio.spectrogram import Spectrogram, FilteredSpectrogram diff --git a/madmom/audio/comb_filters.pyx b/madmom/audio/comb_filters.pyx index eb0cd8345..43c3fe7ad 100644 --- a/madmom/audio/comb_filters.pyx +++ b/madmom/audio/comb_filters.pyx @@ -1,10 +1,11 @@ # encoding: utf-8 - """ This file contains comb-filter and comb-filterbank functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np cimport cython diff --git a/madmom/audio/ffmpeg.py b/madmom/audio/ffmpeg.py index 0ddc78866..6208ee514 100644 --- a/madmom/audio/ffmpeg.py +++ b/madmom/audio/ffmpeg.py @@ -1,12 +1,13 @@ # encoding: utf-8 # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains audio handling via ffmpeg functionality. """ +from __future__ import absolute_import, division, print_function + import tempfile import subprocess import os @@ -90,7 +91,7 @@ def decode_to_memory(infile, fmt='f32le', sample_rate=None, num_channels=1, """ # check input file type - if isinstance(infile, file): + if not isinstance(infile, str): raise ValueError("only file names are supported as 'infile', not %s." % infile) # assemble ffmpeg call @@ -179,7 +180,7 @@ def _assemble_ffmpeg_call(infile, output, fmt='f32le', sample_rate=None, raise RuntimeError('avconv has a bug, which results in wrong audio ' 'slices! Decode the audio files to .wav first or ' 'use ffmpeg.') - if isinstance(infile, unicode): + if isinstance(infile, str): infile = infile.encode(sys.getfilesystemencoding()) else: infile = str(infile) @@ -211,7 +212,7 @@ def get_file_info(infile, cmd='ffprobe'): """ # check input file type - if isinstance(infile, file): + if not isinstance(infile, str): raise ValueError("only file names are supported as 'infile', not %s." % infile) # init dictionary @@ -221,9 +222,9 @@ def get_file_info(infile, cmd='ffprobe'): infile]) # parse information for line in output.split(): - if line.startswith('channels='): + if line.startswith(b'channels='): info['num_channels'] = int(line[len('channels='):]) - if line.startswith('sample_rate='): + if line.startswith(b'sample_rate='): info['sample_rate'] = float(line[len('sample_rate='):]) # return the dictionary return info diff --git a/madmom/audio/filters.py b/madmom/audio/filters.py index e1ac815b7..81d70c38f 100644 --- a/madmom/audio/filters.py +++ b/madmom/audio/filters.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains filter and filterbank related functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np @@ -462,8 +463,8 @@ def band_bins(cls, bins, overlap=True): # create non-overlapping filters if not overlap: # re-arrange the start and stop positions - start = int(round((center + start) / 2.)) - stop = int(round((center + stop) / 2.)) + start = int(np.floor((center + start) / 2.)) + stop = int(np.ceil((center + stop) / 2.)) # consistently handle too-small filters if stop - start < 2: center = start diff --git a/madmom/audio/hpss.py b/madmom/audio/hpss.py index d5052869d..f16238a7c 100644 --- a/madmom/audio/hpss.py +++ b/madmom/audio/hpss.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains all harmonic/percussive source separation functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np from madmom.processors import Processor diff --git a/madmom/audio/signal.py b/madmom/audio/signal.py index 1e6690deb..aeb8ab297 100644 --- a/madmom/audio/signal.py +++ b/madmom/audio/signal.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains basic signal processing functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np from madmom.processors import Processor @@ -270,7 +271,7 @@ def load_audio_file(filename, sample_rate=None, num_channels=None, start=None, from .ffmpeg import load_ffmpeg_file # determine the name of the file if it is a file handle - if isinstance(filename, file): + if not isinstance(filename, str): # close the file handle if it is open filename.close() # use the file name diff --git a/madmom/audio/spectrogram.py b/madmom/audio/spectrogram.py index c95b106a7..754aa7653 100644 --- a/madmom/audio/spectrogram.py +++ b/madmom/audio/spectrogram.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains spectrogram related functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np from madmom.processors import Processor, SequentialProcessor, ParallelProcessor @@ -873,27 +874,28 @@ def __new__(cls, spectrogram, diff_ratio=DIFF_RATIO, spectrogram = Spectrogram(spectrogram, **kwargs) # calculate the number of diff frames to use - if not diff_frames: + if diff_frames is None: # calculate the number of diff_frames on basis of the diff_ratio # get the first sample with a higher magnitude than given ratio window = spectrogram.stft.window - sample = np.argmax(window > diff_ratio * max(window)) + sample = np.argmax(window > float(diff_ratio) * max(window)) diff_samples = len(spectrogram.stft.window) / 2 - sample # convert to frames hop_size = spectrogram.stft.frames.hop_size - diff_frames = int(round(diff_samples / hop_size)) - # always set the minimum to 1 - if diff_frames < 1: - diff_frames = 1 + diff_frames = round(diff_samples / hop_size) + + # use at least 1 frame + diff_frames = max(1, int(diff_frames)) # init matrix diff = np.zeros_like(spectrogram) # apply a maximum filter to diff_spec if needed - if diff_max_bins > 1: + if diff_max_bins is not None and diff_max_bins > 1: from scipy.ndimage.filters import maximum_filter # widen the spectrogram in frequency dimension - diff_spec = maximum_filter(spectrogram, size=[1, diff_max_bins]) + size = [1, int(diff_max_bins)] + diff_spec = maximum_filter(spectrogram, size=size) else: diff_spec = spectrogram # calculate the diff diff --git a/madmom/audio/stft.py b/madmom/audio/stft.py index fee7b25d8..d87cf705b 100644 --- a/madmom/audio/stft.py +++ b/madmom/audio/stft.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains Short-Time Fourier Transform (STFT) related functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np import scipy.fftpack as fft diff --git a/madmom/evaluation/__init__.py b/madmom/evaluation/__init__.py index f4e7510df..690c3ce8a 100644 --- a/madmom/evaluation/__init__.py +++ b/madmom/evaluation/__init__.py @@ -2,13 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ Evaluation package. """ -import abc +from __future__ import absolute_import, division, print_function + import numpy as np @@ -142,9 +142,12 @@ def calc_relative_errors(detections, annotations, matches=None): # abstract evaluation base class -class EvaluationABC(object): +class EvaluationMixin(object): """ - Evaluation abstract base class. + Evaluation mixin class. + + This class has a `name` attribute which is used for display purposes and + defaults to 'None'. `METRIC_NAMES` is a list of tuples, containing the attribute's name and the corresponding label, e.g.: @@ -156,10 +159,12 @@ class EvaluationABC(object): ] The attributes defined in `METRIC_NAMES` will be provided as an ordered - dictionary as the `metrics` attribute of the + dictionary as the `metrics` property unless the subclass overwrites the + property. + + `FLOAT_FORMAT` is used to format floats. """ - __metaclass__ = abc.ABCMeta name = None METRIC_NAMES = [] @@ -168,7 +173,6 @@ class EvaluationABC(object): @property def metrics(self): """Metrics as a dictionary.""" - # TODO: use an ordered dict? from collections import OrderedDict metrics = OrderedDict() # metrics = {} @@ -176,10 +180,9 @@ def metrics(self): metrics[metric] = getattr(self, metric) return metrics - @abc.abstractmethod def __len__(self): """Length of the evaluation object.""" - return + raise NotImplementedError('must be implemented by subclass.') def tostring(self, **kwargs): """ @@ -189,8 +192,8 @@ def tostring(self, **kwargs): :return: evaluation metrics formatted as a human readable string Note: This is a fallback method formatting the 'metrics' dictionary in - a human readable way. Classes implementing this abstract base - class should provide a better suitable method. + a human readable way. Classes inheriting from this mixin class + should provide a method better suitable. """ # pylint: disable=unused-argument @@ -200,7 +203,7 @@ class should provide a better suitable method. # evaluation classes -class SimpleEvaluation(EvaluationABC): +class SimpleEvaluation(EvaluationMixin): """ Simple Precision, Recall, F-measure and Accuracy evaluation based on the numbers of true/false positive/negative detections. @@ -679,7 +682,7 @@ def tocsv(eval_objects, metric_names=None, float_format='{:.3f}', **kwargs): if metric_names is None: # get the evaluation metrics from the first evaluation object metric_names = eval_objects[0].METRIC_NAMES - metric_names, metric_labels = zip(*metric_names) + metric_names, metric_labels = list(zip(*metric_names)) # add header lines = ['Name,' + ','.join(metric_labels)] # TODO: use e.metrics dict? @@ -712,7 +715,7 @@ def totex(eval_objects, metric_names=None, float_format='{:.3f}', **kwargs): if metric_names is None: # get the evaluation metrics from the first evaluation object metric_names = eval_objects[0].METRIC_NAMES - metric_names, metric_labels = zip(*metric_names) + metric_names, metric_labels = list(zip(*metric_names)) # add header lines = ['Name & ' + ' & '.join(metric_labels) + '\\\\'] # TODO: use e.metrics dict diff --git a/madmom/evaluation/alignment.py b/madmom/evaluation/alignment.py index 588ed0b8a..ca690895b 100755 --- a/madmom/evaluation/alignment.py +++ b/madmom/evaluation/alignment.py @@ -2,15 +2,17 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains global alignment evaluation functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np -from . import EvaluationABC +from . import EvaluationMixin + # constants for the data format _TIME = 0 @@ -181,7 +183,7 @@ def compute_metrics(event_alignment, ground_truth, window, err_hist_bins): # convert possibly masked values to NaN. A masked value can occur when # computing the mean or stddev of values that are all masked - for k, v in results.iteritems(): + for k, v in results.items(): if v is np.ma.masked_singleton: results[k] = np.NaN @@ -201,7 +203,7 @@ def compute_metrics(event_alignment, ground_truth, window, err_hist_bins): return results -class AlignmentEvaluation(EvaluationABC): +class AlignmentEvaluation(EvaluationMixin): """ Alignment evaluation class for beat-level alignments. Beat-level aligners output beat positions for points in time, rather than computing a time step @@ -349,7 +351,7 @@ def _combine_metrics(eval_objects, piecewise): else: total_weight = sum(len(e) for e in eval_objects) for e in eval_objects: - for name, val in e.metrics.iteritems(): + for name, val in e.metrics.items(): if isinstance(val, np.ndarray) or not np.isnan(val): weight = 1.0 if piecewise else float(len(e)) metrics[name] = \ diff --git a/madmom/evaluation/beats.py b/madmom/evaluation/beats.py index 804f7623c..33e647b70 100755 --- a/madmom/evaluation/beats.py +++ b/madmom/evaluation/beats.py @@ -2,7 +2,6 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This software serves as a Python implementation of the beat evaluation toolkit, which can be downloaded from: @@ -24,6 +23,8 @@ """ +from __future__ import absolute_import, division, print_function + import warnings import numpy as np @@ -246,7 +247,9 @@ def find_longest_continuous_segment(sequence_indices): # lengths of the individual segments segment_lengths = np.diff(boundaries) # return the length and start position of the longest continuous segment - return np.max(segment_lengths), boundaries[np.argmax(segment_lengths)] + length = int(np.max(segment_lengths)) + start_pos = int(boundaries[np.argmax(segment_lengths)]) + return length, start_pos def calc_relative_errors(detections, annotations, matches=None): @@ -331,7 +334,7 @@ def pscore(detections, annotations, tolerance=PSCORE_TOLERANCE): "P-score.") # tolerance must be greater than 0 - if tolerance <= 0: + if float(tolerance) <= 0: raise ValueError("Tolerance must be greater than 0.") # make sure the annotations and detections have a float dtype @@ -372,7 +375,7 @@ def cemgil(detections, annotations, sigma=CEMGIL_SIGMA): return 0. # sigma must be greater than 0 - if sigma <= 0: + if float(sigma) <= 0: raise ValueError("Sigma must be greater than 0.") # make sure the annotations and detections have a float dtype @@ -426,7 +429,7 @@ def goto(detections, annotations, threshold=GOTO_THRESHOLD, sigma=GOTO_SIGMA, "score.") # threshold, sigma and mu must be greater than 0 - if threshold < 0 or sigma < 0 or mu < 0: + if float(threshold) <= 0 or float(sigma) <= 0 or float(mu) <= 0: raise ValueError("Threshold, sigma and mu must be positive.") # make sure the annotations and detections have a float dtype @@ -509,7 +512,7 @@ def cml(detections, annotations, phase_tolerance=CONTINUITY_PHASE_TOLERANCE, "continuity scores, %s given." % detections) # tolerances must be greater than 0 - if tempo_tolerance <= 0 or phase_tolerance <= 0: + if float(tempo_tolerance) <= 0 or float(phase_tolerance) <= 0: raise ValueError("Tempo and phase tolerances must be greater than 0") # make sure the annotations and detections have a float dtype diff --git a/madmom/evaluation/notes.py b/madmom/evaluation/notes.py index 4f75f3ffb..ef4f58497 100755 --- a/madmom/evaluation/notes.py +++ b/madmom/evaluation/notes.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains note evaluation functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np import warnings diff --git a/madmom/evaluation/onsets.py b/madmom/evaluation/onsets.py index 3cfc43257..8fbe1e750 100755 --- a/madmom/evaluation/onsets.py +++ b/madmom/evaluation/onsets.py @@ -2,7 +2,6 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains onset evaluation functionality. @@ -15,18 +14,14 @@ """ -import numpy as np +from __future__ import absolute_import, division, print_function +import numpy as np from . import evaluation_io, Evaluation, SumEvaluation, MeanEvaluation from ..utils import suppress_warnings, combine_events -# default onset evaluation values -WINDOW = 0.025 -COMBINE = 0.03 - - @suppress_warnings def load_onsets(values): """ @@ -57,6 +52,11 @@ def load_onsets(values): return values +# default onset evaluation values +WINDOW = 0.025 +COMBINE = 0.03 + + # onset evaluation function def onset_evaluation(detections, annotations, window=WINDOW): """ @@ -104,6 +104,10 @@ def onset_evaluation(detections, annotations, window=WINDOW): # all annotations are FN return tp, fp, tn, annotations, errors + # window must be greater than 0 + if float(window) <= 0: + raise ValueError('window must be greater than 0') + # sort the detections and annotations det = np.sort(detections) ann = np.sort(annotations) diff --git a/madmom/evaluation/tempo.py b/madmom/evaluation/tempo.py index b513a7771..f8dd2cb49 100755 --- a/madmom/evaluation/tempo.py +++ b/madmom/evaluation/tempo.py @@ -2,16 +2,17 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains tempo evaluation functionality. """ +from __future__ import absolute_import, division, print_function + import warnings import numpy as np -from . import EvaluationABC, MeanEvaluation, evaluation_io +from . import EvaluationMixin, MeanEvaluation, evaluation_io def load_tempo(values, split_value=1., sort=False, norm_strengths=False, @@ -122,7 +123,7 @@ def tempo_evaluation(detections, annotations, tolerance=TOLERANCE): # worst result return 0., False, False # tolerance must be greater than 0 - if not tolerance > 0: + if float(tolerance) <= 0: raise ValueError('tolerance must be greater than 0') # make sure the annotations and detections have a float dtype detections = np.asarray(detections, dtype=np.float) @@ -158,7 +159,7 @@ def tempo_evaluation(detections, annotations, tolerance=TOLERANCE): # basic tempo evaluation -class TempoEvaluation(EvaluationABC): +class TempoEvaluation(EvaluationMixin): """ Tempo evaluation class. diff --git a/madmom/features/__init__.py b/madmom/features/__init__.py index 30539373a..b6fb6b166 100644 --- a/madmom/features/__init__.py +++ b/madmom/features/__init__.py @@ -2,13 +2,14 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This package includes higher level features. Your definition of "higher" may vary, but all "lower" level features can be found the `audio` package. """ +from __future__ import absolute_import, division, print_function + import numpy as np from madmom.processors import Processor @@ -46,12 +47,14 @@ def __init__(self, data, fps=None, sep=None, dtype=np.float32): # the initialisation is done in __new__() and __array_finalize__() def __new__(cls, data, fps=None, sep=None, dtype=np.float32): + import io + # check the type of the given data if isinstance(data, np.ndarray): # cast to Activations obj = np.asarray(data, dtype=dtype).view(cls) obj.fps = fps - elif isinstance(data, (basestring, file)): + elif isinstance(data, (str, io.IOBase)): # read from file or file handle obj = cls.load(data, fps, sep) else: diff --git a/madmom/features/beats.py b/madmom/features/beats.py index 187084b78..f97673da2 100755 --- a/madmom/features/beats.py +++ b/madmom/features/beats.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains all beat tracking related functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np from madmom.processors import Processor @@ -481,10 +482,9 @@ def process(self, activations): # since the cython code uses memory views, we need to make sure that # the activations are C-contiguous and of C-type float (np.float32) contiguous_act = np.ascontiguousarray(activations, dtype=np.float32) - results = self.map(_process_crf, - it.izip(it.repeat(contiguous_act), - possible_intervals, - it.repeat(self.interval_sigma))) + results = list(self.map( + _process_crf, zip(it.repeat(contiguous_act), possible_intervals, + it.repeat(self.interval_sigma)))) # normalize their probabilities normalized_seq_probabilities = np.array([r[1] / r[0].shape[0] @@ -846,6 +846,8 @@ def __init__(self, pattern_files, min_bpm=MIN_BPM, max_bpm=MAX_BPM, # pylint: disable=unused-argument # pylint: disable=no-name-in-module + import pickle + from madmom.ml.hmm import HiddenMarkovModel as Hmm from .beats_hmm import (DownBeatTrackingStateSpace as St, DownBeatTrackingTransitionModel as Tm, @@ -862,12 +864,19 @@ def __init__(self, pattern_files, min_bpm=MIN_BPM, max_bpm=MAX_BPM, raise ValueError('`min_bpm`, `max_bpm`, `num_tempo_states` and ' '`transition_lambda` must have the same length ' 'as number of patterns.') + # load the patterns - import cPickle patterns = [] for pattern_file in pattern_files: - with open(pattern_file, 'r') as f: - patterns.append(cPickle.load(f)) + with open(pattern_file, 'rb') as f: + # Python 2 and 3 behave differently + # TODO: use some other format to save the GMMs (.npz, .hdf5) + try: + # Python 3 + patterns.append(pickle.load(f, encoding='latin1')) + except TypeError: + # Python 2 doesn't have/need the encoding + patterns.append(pickle.load(f)) if len(patterns) == 0: raise ValueError('at least one rhythmical pattern must be given.') # extract the GMMs and number of beats @@ -920,7 +929,7 @@ def process(self, activations): if self.downbeats: return beats[beat_numbers == 1] else: - return zip(beats, beat_numbers) + return np.vstack(zip(beats, beat_numbers)) @classmethod def add_arguments(cls, parser, pattern_files=None, min_bpm=MIN_BPM, diff --git a/madmom/features/beats_crf.pyx b/madmom/features/beats_crf.pyx index 97e773301..df204e6b1 100644 --- a/madmom/features/beats_crf.pyx +++ b/madmom/features/beats_crf.pyx @@ -1,5 +1,4 @@ # encoding: utf-8 - """ This file contains the speed crucial Viterbi functionality for the CRFBeatDetector plus some functions computing the distributions and @@ -7,6 +6,8 @@ normalisation factors... """ +from __future__ import absolute_import, division, print_function + import numpy as np cimport numpy as np cimport cython diff --git a/madmom/features/beats_hmm.pyx b/madmom/features/beats_hmm.pyx index d66ff96f8..e0e2397e6 100644 --- a/madmom/features/beats_hmm.pyx +++ b/madmom/features/beats_hmm.pyx @@ -1,5 +1,4 @@ # encoding: utf-8 - """ This file contains HMM state space, transition and observation models used for beat and downbeat tracking. @@ -15,6 +14,8 @@ If you want to change this module and use it interactively, use pyximport. """ +from __future__ import absolute_import, division, print_function + import numpy as np cimport numpy as np cimport cython diff --git a/madmom/features/notes.py b/madmom/features/notes.py index bdca1f76f..b8d6cda76 100755 --- a/madmom/features/notes.py +++ b/madmom/features/notes.py @@ -2,15 +2,16 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains note transcription related functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np -from madmom.utils import suppress_warnings, open +from madmom.utils import suppress_warnings @suppress_warnings @@ -22,24 +23,38 @@ def load_notes(filename): :return: numpy array with notes """ - with open(filename, 'rb') as f: - return np.loadtxt(f) + return np.loadtxt(filename) -def write_notes(notes, filename, sep='\t'): +def write_notes(notes, filename, sep='\t', + fmt=list(('%.3f', '%d', '%.3f', '%d'))): """ Write the detected notes to a file. - :param notes: list with notes - :param filename: output file name or file handle - :param sep: separator for the fields [default='\t'] + :param notes: 2D numpy array with notes + :param filename: output file name or file handle + :param sep: separator for the fields [default='\t'] + :param fmt: format of the fields (i.e. columns, see below) + + + Note: The `notes` must be a 2D numpy array with the individual notes as + rows, and the columns defined as: + + 'note_time' 'MIDI_note' ['duration' ['MIDI_velocity']] + + whith the duration and velocity being optional. """ - # write the notes to the output - if filename is not None: - with open(filename, 'wb') as f: - for note in notes: - f.write(sep.join([str(x) for x in note]) + '\n') + from madmom.utils import write_events + # truncate to the number of colums given + if notes.ndim == 1: + fmt = '%f' + elif notes.ndim == 2: + fmt = sep.join(fmt[:notes.shape[1]]) + else: + raise ValueError('unknown format for notes') + # write the notes + write_events(notes, filename, fmt=fmt) # also return them return notes @@ -48,7 +63,7 @@ def write_midi(notes, filename, note_length=0.6, note_velocity=100): """ Write the notes to a MIDI file. - :param notes: detected notes + :param notes: 2D numpy array with notes :param filename: output filename :param note_velocity: default velocity of the notes :param note_length: default length of the notes @@ -67,9 +82,9 @@ def write_midi(notes, filename, note_length=0.6, note_velocity=100): return process_notes(notes, filename) -def write_frequencies(notes, filename, note_length=0.6): +def write_mirex_format(notes, filename, note_length=0.6): """ - Write the frequencies of the notes to file (i.e. MIREX format). + Write the frequencies of the notes to file (in MIREX format). :param notes: detected notes :param filename: output filename @@ -79,18 +94,15 @@ def write_frequencies(notes, filename, note_length=0.6): """ from madmom.audio.filters import midi2hz # MIREX format: onset \t offset \t frequency - with open(filename, 'wb') as f: - for note in notes: - onset, midi_note = note - offset = onset + note_length - frequency = midi2hz(midi_note) - f.write('%.2f\t%.2f\t%.2f\n' % (onset, offset, frequency)) + notes = np.vstack((notes[:, 0], notes[:, 0] + note_length, + midi2hz(notes[:, 1]))).T + write_notes(notes, filename, fmt=list(('%.3f', '%.3f', '%.1f', ))) return notes def note_reshaper(notes): """ - Reshapes the activations produced by a RNN to ave the right shape. + Reshapes the activations produced by a RNN to have the right shape. :param notes: numpy array with note activations :return: reshaped array to represent the 88 MIDI notes diff --git a/madmom/features/onsets.py b/madmom/features/onsets.py index 0820ba9b3..b5377356a 100755 --- a/madmom/features/onsets.py +++ b/madmom/features/onsets.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains onset detection related functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np from scipy.ndimage import uniform_filter from scipy.ndimage.filters import maximum_filter @@ -126,6 +127,7 @@ def spectral_flux(spectrogram, diff_frames=None): Spectral Flux. :param spectrogram: Spectrogram instance + :param diff_frames: number of frames to calculate the diff to [int] :return: spectral flux onset detection function "Computer Modeling of Sound for Transformation and Synthesis of Musical @@ -668,11 +670,12 @@ def process(self, activations): np.diff(note_onsets) > self.combine] # zip the onsets with the MIDI note number and add them to # the list of detections - detections.extend(zip(combined_note_onsets, - [note] * len(combined_note_onsets))) + notes = zip(combined_note_onsets, + [note] * len(combined_note_onsets)) + detections.extend(list(notes)) else: # just zip all detected notes - detections = zip(onsets, midi_notes) + detections = list(zip(onsets, midi_notes)) # sort the detections and save as numpy array detections = np.asarray(sorted(detections)) else: diff --git a/madmom/features/tempo.py b/madmom/features/tempo.py index 723575c7e..94b2aea72 100755 --- a/madmom/features/tempo.py +++ b/madmom/features/tempo.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains tempo related functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np from scipy.signal import argrelmax @@ -55,7 +56,7 @@ def interval_histogram_acf(activations, min_tau=1, max_tau=None): if max_tau is None: max_tau = len(activations) - min_tau # test all possible delays - taus = range(min_tau, max_tau + 1) + taus = list(range(min_tau, max_tau + 1)) bins = [] # Note: this is faster than: # corr = np.correlate(activations, activations, mode='full') @@ -170,7 +171,7 @@ def detect_tempo(histogram, fps): strengths = bins[sorted_peaks] strengths /= np.sum(strengths) # return the tempi and their normalized strengths - ret = np.asarray(zip(tempi[sorted_peaks], strengths)) + ret = np.asarray(list(zip(tempi[sorted_peaks], strengths))) # return the tempi return np.atleast_2d(ret) @@ -352,7 +353,6 @@ def write_tempo(tempi, filename, mirex=False): :return: the most dominant tempi and the relative strength """ - from madmom.utils import open # default values t1, t2, strength = 0., 0., 1. # only one tempo was detected @@ -371,8 +371,9 @@ def write_tempo(tempi, filename, mirex=False): # for MIREX, the lower tempo must be given first if mirex and t1 > t2: t1, t2, strength = t2, t1, 1. - strength + # format as a numpy array + out = np.array([t1, t2, strength], ndmin=2) # write to output - with open(filename, 'wb') as f: - f.write("%.2f\t%.2f\t%.2f\n" % (t1, t2, strength)) + np.savetxt(filename, out, fmt='%.2f\t%.2f\t%.2f') # also return the tempi & strength return t1, t2, strength diff --git a/madmom/ml/__init__.py b/madmom/ml/__init__.py index e6c3966d5..0e2ab7403 100644 --- a/madmom/ml/__init__.py +++ b/madmom/ml/__init__.py @@ -1,9 +1,10 @@ # encoding: utf-8 - """ Machine learning package. """ +from __future__ import absolute_import, division, print_function + # import the submodules from . import rnn, hmm, gmm diff --git a/madmom/ml/gmm.py b/madmom/ml/gmm.py index 58fbba73b..4a8ce12cb 100644 --- a/madmom/ml/gmm.py +++ b/madmom/ml/gmm.py @@ -2,7 +2,6 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains functionality needed for fitting and scoring Gaussian Mixture Models (GMMs) (needed e.g. in madmom.features.dbn). @@ -19,6 +18,8 @@ """ +from __future__ import absolute_import, division, print_function + import numpy as np from scipy import linalg diff --git a/madmom/ml/hmm.pyx b/madmom/ml/hmm.pyx index a80c2c684..0d2d0c121 100644 --- a/madmom/ml/hmm.pyx +++ b/madmom/ml/hmm.pyx @@ -1,5 +1,4 @@ # encoding: utf-8 - """ This file contains Hidden Markov Model (HMM) functionality. @@ -11,7 +10,8 @@ If you want to change this module and use it interactively, use pyximport. """ -import abc +from __future__ import absolute_import, division, print_function + import numpy as np cimport numpy as np cimport cython @@ -95,7 +95,8 @@ class TransitionModel(object): if not np.allclose(np.bincount(prev_states, weights=probabilities), 1): raise ValueError('Not a probability distribution.') # convert everything into a sparse CSR matrix - transitions = csr_matrix((probabilities, (states, prev_states))) + transitions = csr_matrix((np.array(probabilities), + (np.array(states), np.array(prev_states)))) # convert to correct types states = transitions.indices.astype(np.uint32) pointers = transitions.indptr.astype(np.uint32) @@ -143,7 +144,6 @@ class ObservationModel(object): different observation probability (log) densities. Type must be np.float. """ - __metaclass__ = abc.ABCMeta def __init__(self, pointers): """ @@ -155,7 +155,6 @@ class ObservationModel(object): self.pointers = pointers - @abc.abstractmethod def log_densities(self, observations): """ Log densities (or probabilities) of the observations for each state. @@ -167,7 +166,7 @@ class ObservationModel(object): observation log probability densities. The type must be np.float. """ - return + raise NotImplementedError('must be implemented by subclass') def densities(self, observations): """ diff --git a/madmom/ml/io.py b/madmom/ml/io.py index 37f1cffe7..2df414b06 100644 --- a/madmom/ml/io.py +++ b/madmom/ml/io.py @@ -2,7 +2,6 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains functionality needed for the conversion of from the universal .h5 format to the .npz format understood by madmom.ml.rnn. @@ -33,6 +32,8 @@ """ +from __future__ import absolute_import, division, print_function + import os import numpy as np @@ -51,16 +52,16 @@ def convert_model(infile, outfile=None, compressed=False): # read in model with h5py.File(infile, 'r') as h5: # model attributes - for attr in h5['model'].attrs.keys(): + for attr in list(h5['model'].attrs.keys()): npz['model_%s' % attr] = h5['model'].attrs[attr] # layers - for l in h5['layer'].keys(): + for l in list(h5['layer'].keys()): layer = h5['layer'][l] # each layer has some attributes - for attr in layer.attrs.keys(): + for attr in list(layer.attrs.keys()): npz['layer_%s_%s' % (l, attr)] = layer.attrs[attr] # and some data sets (i.e. different weights) - for data in layer.keys(): + for data in list(layer.keys()): npz['layer_%s_%s' % (l, data)] = layer[data].value # save the model to .npz format if outfile is None: diff --git a/madmom/ml/rnn.pxd b/madmom/ml/rnn.pxd index 33b18a130..14762225c 100644 --- a/madmom/ml/rnn.pxd +++ b/madmom/ml/rnn.pxd @@ -7,6 +7,9 @@ Note: right now, this file is just an empty augmenting file. However, it increases performance when cython is used to compile the normal rnn.py file. """ + +from __future__ import absolute_import, division, print_function + import cython import numpy as np diff --git a/madmom/ml/rnn.py b/madmom/ml/rnn.py index 4c8911cdb..3655c8bbd 100644 --- a/madmom/ml/rnn.py +++ b/madmom/ml/rnn.py @@ -3,7 +3,6 @@ # pylint: disable=invalid-name # pylint: disable=too-many-arguments # pylint: disable=too-few-public-methods - """ This file contains recurrent neural network (RNN) related functionality. @@ -20,7 +19,8 @@ """ -import abc +from __future__ import absolute_import, division, print_function + import numpy as np from ..processors import Processor, ParallelProcessor @@ -128,26 +128,7 @@ def softmax(x, out=None): # network layer classes -class Layer(object): - """ - Generic network Layer. - - """ - __metaclass__ = abc.ABCMeta - - @abc.abstractmethod - def activate(self, data): - """ - Activate the layer. - - :param data: activate with this data - :return: activations for this data - - """ - return - - -class FeedForwardLayer(Layer): +class FeedForwardLayer(object): """ Feed-forward network layer. @@ -217,7 +198,7 @@ def activate(self, data): out = np.dot(data, self.weights) out += self.bias # loop through each time step - for i in xrange(size): + for i in range(size): # add the weighted previous step if i >= 1: # np.dot(out[i - 1], self.recurrent_weights, out=tmp) @@ -229,7 +210,7 @@ def activate(self, data): return out -class BidirectionalLayer(Layer): +class BidirectionalLayer(object): """ Bidirectional network layer. @@ -339,7 +320,7 @@ def __init__(self, weights, bias, recurrent_weights, peephole_weights, self.peephole_weights = peephole_weights.flatten() -class LSTMLayer(Layer): +class LSTMLayer(object): """ Recurrent network layer with Long Short-Term Memory units. @@ -388,7 +369,7 @@ def activate(self, data): # state (of the previous time step) state_ = np.zeros(self.cell.bias.size, dtype=NN_DTYPE) # process the input data - for i in xrange(size): + for i in range(size): # cache input data data_ = data[i] # input gate: @@ -444,8 +425,8 @@ def load(cls, filename): data = np.load(filename) # determine the number of layers (i.e. all "layer_%d_" occurrences) - num_layers = max([int(re.findall(r'layer_(\d+)_', k)[0]) - for k in data.keys() if k.startswith('layer_')]) + num_layers = max([int(re.findall(r'layer_(\d+)_', k)[0]) for + k in list(data.keys()) if k.startswith('layer_')]) # function for layer creation with the given parameters def create_layer(params): @@ -459,24 +440,26 @@ def create_layer(params): # first check if we need to create a bidirectional layer bwd_layer = None - if '%s_type' % REVERSE in params.keys(): + if '%s_type' % REVERSE in list(params.keys()): # pop the parameters needed for the reverse (backward) layer - bwd_type = str(params.pop('%s_type' % REVERSE)) - bwd_transfer_fn = str(params.pop('%s_transfer_fn' % REVERSE)) + bwd_type = bytes(params.pop('%s_type' % REVERSE)) + bwd_transfer_fn = bytes(params.pop('%s_transfer_fn' % + REVERSE)) bwd_params = dict((k.split('_', 1)[1], params.pop(k)) - for k in params.keys() if + for k in list(params.keys()) if k.startswith('%s_' % REVERSE)) - bwd_params['transfer_fn'] = globals()[bwd_transfer_fn] + bwd_params['transfer_fn'] = globals()[bwd_transfer_fn.decode()] # construct the layer - bwd_layer = globals()['%sLayer' % bwd_type](**bwd_params) + bwd_layer = globals()['%sLayer' % bwd_type.decode()]( + **bwd_params) # pop the parameters needed for the normal (forward) layer - fwd_type = str(params.pop('type')) - fwd_transfer_fn = str(params.pop('transfer_fn')) + fwd_type = bytes(params.pop('type')) + fwd_transfer_fn = bytes(params.pop('transfer_fn')) fwd_params = params - fwd_params['transfer_fn'] = globals()[fwd_transfer_fn] + fwd_params['transfer_fn'] = globals()[fwd_transfer_fn.decode()] # construct the layer - fwd_layer = globals()['%sLayer' % fwd_type](**fwd_params) + fwd_layer = globals()['%sLayer' % fwd_type.decode()](**fwd_params) # return the (bidirectional) layer if bwd_layer is not None: @@ -489,10 +472,10 @@ def create_layer(params): # loop over all layers layers = [] - for i in xrange(num_layers + 1): + for i in range(num_layers + 1): # get all parameters for that layer layer_params = dict((k.split('_', 2)[2], data[k]) - for k in data.keys() if + for k in list(data.keys()) if k.startswith('layer_%d' % i)) # create a layer from these parameters layer = create_layer(layer_params) @@ -542,6 +525,8 @@ def __init__(self, nn_files, num_threads=None, **kwargs): """ # pylint: disable=unused-argument + if len(nn_files) == 0: + raise ValueError('at least one RNN model must be given.') nn_models = [] for nn_file in nn_files: nn_models.append(RecurrentNeuralNetwork.load(nn_file)) diff --git a/madmom/ml/rnnlib.py b/madmom/ml/rnnlib.py index 52da37f32..69818e24f 100644 --- a/madmom/ml/rnnlib.py +++ b/madmom/ml/rnnlib.py @@ -3,7 +3,6 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains all functionality needed for interaction with RNNLIB. @@ -19,6 +18,8 @@ """ +from __future__ import absolute_import, division, print_function + import os.path import re import shutil @@ -694,7 +695,7 @@ def run(self): # TODO: make regression task work as well # until then just output the log with open("%s/log" % tmp_work_path, 'rb') as log: - print log.read() + print(log.read()) raise RuntimeError("Error while RNNLIB processing.") finally: # put a tuple with nc file, nn file and activations @@ -1247,7 +1248,7 @@ def test_save_files(files, out_dir=None, file_set='test', threads=THREADS, # cast the activations to an Activations instance (we only passed # one .nc file, so it's the first activation in the returned list) if verbose: - print act_file + print(act_file) Activations(activations[0], fps=fps).save(act_file) @@ -1315,7 +1316,7 @@ def create_config(files, config, out_dir, num_folds=8, randomize=False, try: folds[fold].append(str(nc_file[0])) except IndexError: - print "can't find .nc file for file: %s" % line + print("can't find .nc file for file: %s" % line) else: # use a standard splits for fold in range(num_folds): @@ -1410,9 +1411,9 @@ def create_nc_files(files, annotations, out_dir, norm=False, att=0, from the `diff_ratio`) :param diff_max_bins: apply a maximum filter with this width (in bins in frequency dimension) before calculating the diff; - (e.g. for the difference spectrogram of the SuperFlux - algorithm 3 `max_bins` are used together with a 24 - band logarithmic filterbank) + (e.g. for the difference spectrogram of the + SuperFlux algorithm 3 `max_bins` are used together + with a 24 band logarithmic filterbank) :param positive_diffs: keep only the positive differences, i.e. set all diff values < 0 to 0 @@ -1471,11 +1472,11 @@ def create_nc_files(files, annotations, out_dir, norm=False, att=0, wav_files = match_file(f, files, annotation, '.flac') # no wav file found if len(wav_files) < 1: - print "can't find audio file for %s" % f + print("can't find audio file for %s" % f) exit() # print file if verbose: - print f + print(f) # create the data for the .nc file from the .wav file nc_data = processor.process(wav_files[0]) diff --git a/madmom/models b/madmom/models index 65b8e88d2..17ad0ba2a 160000 --- a/madmom/models +++ b/madmom/models @@ -1 +1 @@ -Subproject commit 65b8e88d2ffe7d4c955e596f60422458120447b8 +Subproject commit 17ad0ba2a048dffc9ed94156aec8a885ec21f145 diff --git a/madmom/processors.py b/madmom/processors.py index a91045da2..7a64685cd 100644 --- a/madmom/processors.py +++ b/madmom/processors.py @@ -2,7 +2,6 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains all processor related functionality. @@ -12,9 +11,10 @@ """ +from __future__ import absolute_import, division, print_function + import os import sys -import abc import argparse import multiprocessing as mp @@ -26,7 +26,6 @@ class Processor(object): Abstract base class for processing data. """ - __metaclass__ = abc.ABCMeta @classmethod def load(cls, infile): @@ -41,13 +40,13 @@ def load(cls, infile): :return: Processor instance """ - import cPickle + import pickle # close the open file if needed and use its name - if not isinstance(infile, basestring): + if not isinstance(infile, str): infile.close() infile = infile.name # instantiate a new Processor and return it - return cPickle.load(open(infile, 'rb')) + return pickle.load(open(infile, 'rb')) def dump(self, outfile): """ @@ -60,19 +59,18 @@ def dump(self, outfile): :param outfile: output file name or file handle """ - import cPickle + import pickle import warnings warnings.warn('The resulting file is considered a model file, please ' 'see the LICENSE file for details!') # close the open file if needed and use its name - if not isinstance(outfile, basestring): + if not isinstance(outfile, str): outfile.close() outfile = outfile.name # dump the Processor to the given file - cPickle.dump(self, open(outfile, 'wb'), - protocol=cPickle.HIGHEST_PROTOCOL) + pickle.dump(self, open(outfile, 'wb'), + protocol=pickle.HIGHEST_PROTOCOL) - @abc.abstractmethod def process(self, data): """ Process the data. @@ -84,7 +82,7 @@ def process(self, data): :return: processed data """ - return data + raise NotImplementedError('must be implemented by subclass.') def __call__(self, *args): """This magic method makes a Processor instance callable.""" @@ -97,7 +95,6 @@ class OutputProcessor(Processor): """ - @abc.abstractmethod def process(self, data, output): """ Processes the data and feeds it to output. @@ -109,8 +106,7 @@ def process(self, data, output): """ # pylint: disable=arguments-differ - # also return the data - return data + raise NotImplementedError('must be implemented by subclass.') # functions for processing file(s) with a Processor @@ -282,7 +278,7 @@ def process(self, data): """ import itertools as it # process data in parallel and return a list with processed data - return self.map(_process, it.izip(self.processors, it.repeat(data))) + return list(self.map(_process, zip(self.processors, it.repeat(data)))) @classmethod def add_arguments(cls, parser, num_threads=NUM_THREADS): @@ -527,6 +523,11 @@ def io_arguments(parser, output_suffix='.txt', pickle=True): :param pickle: add 'pickle' subparser [bool] """ + # default output + try: + output = sys.stdout.buffer + except AttributeError: + output = sys.stdout # add general options parser.add_argument('-v', dest='verbose', action='count', help='increase verbosity level') @@ -538,17 +539,17 @@ def io_arguments(parser, output_suffix='.txt', pickle=True): sp.set_defaults(func=pickle_processor) # Note: requiring '-o' is a simple safety measure to not overwrite # existing audio files after using the processor in 'batch' mode - sp.add_argument('-o', dest='outfile', type=argparse.FileType('w'), - help='file to pickle the processor to') + sp.add_argument('-o', dest='outfile', type=argparse.FileType('wb'), + default=output, help='output file [default: STDOUT]') # single file processing options sp = sub_parsers.add_parser('single', help='single file processing') sp.set_defaults(func=process_single) - sp.add_argument('infile', type=argparse.FileType('r'), + sp.add_argument('infile', type=argparse.FileType('rb'), help='input audio file') # Note: requiring '-o' is a simple safety measure to not overwrite existing # audio files after using the processor in 'batch' mode - sp.add_argument('-o', dest='outfile', type=argparse.FileType('w'), - default=sys.stdout, help='output file [default: STDOUT]') + sp.add_argument('-o', dest='outfile', type=argparse.FileType('wb'), + default=output, help='output file [default: STDOUT]') sp.add_argument('-j', dest='num_threads', type=int, default=mp.cpu_count(), help='number of parallel threads [default=%(default)s]') # batch file processing options diff --git a/madmom/utils/__init__.py b/madmom/utils/__init__.py index affec1016..07604708d 100644 --- a/madmom/utils/__init__.py +++ b/madmom/utils/__init__.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ Utility package. """ +from __future__ import absolute_import, division, print_function + import argparse import contextlib import numpy as np @@ -40,33 +41,6 @@ def decorator_function(*args, **kwargs): return decorator_function -# overwrite the built-in open() to transparently apply some magic file handling -@contextlib.contextmanager -def open(filename, mode='r'): - """ - Context manager which yields an open file or handle with the given mode - and closes it if needed afterwards. - - :param filename: file name or open file handle - :param mode: mode in which to open the file - :return: an open file handle - - """ - import __builtin__ - # check if we need to open the file - if isinstance(filename, basestring): - f = fid = __builtin__.open(filename, mode) - else: - f = filename - fid = None - # TODO: include automatic (un-)zipping here? - # yield an open file handle - yield f - # close the file if needed - if fid: - fid.close() - - # file handling routines def search_files(path, suffix=None): """ @@ -178,25 +152,24 @@ def load_events(filename): ignored (i.e. only the first column is returned). """ - with open(filename, 'rb') as f: - # read in the events, one per line - events = np.loadtxt(f, ndmin=2) - # 1st column is the event's time, the rest is ignored - return events[:, 0] + # read in the events, one per line + events = np.loadtxt(filename, ndmin=2) + # 1st column is the event's time, the rest is ignored + return events[:, 0] -def write_events(events, filename): +def write_events(events, filename, fmt='%.3f'): """ Write a list of events to a text file, one floating point number per line. - :param events: list of events [seconds] - :param filename: output file name or file handle + :param events: events [seconds, list or numpy array] + :param filename: output file name or open file handle + :param fmt: format to be written + :return: return the events """ # write the events to the output - if filename is not None: - with open(filename, 'wb') as f: - f.writelines('%g\n' % e for e in events) + np.savetxt(filename, np.asarray(events), fmt=fmt) # also return them return events @@ -245,9 +218,11 @@ def quantize_events(events, fps, length=None, shift=None): :return: a quantized numpy array """ + # convert to numpy array if needed + events = np.asarray(events, dtype=np.float) # shift all events if needed if shift: - events = np.asarray(events) + shift + events += shift # determine the length for the quantized array if length is None: # set the length to be long enough to cover all events @@ -258,14 +233,11 @@ def quantize_events(events, fps, length=None, shift=None): events = events[:np.searchsorted(events, float(length - 0.5) / fps)] # init array quantized = np.zeros(length) - # set the events - for event in events: - idx = int(round(event * float(fps))) - try: - quantized[idx] = 1 - except IndexError: - # ignore out of range indices - pass + # quantize + events *= fps + # indices to be set in the quantized array + idx = np.unique(np.round(events).astype(np.int)) + quantized[idx] = 1 # return the quantized array return quantized @@ -304,7 +276,7 @@ def __call__(self, parser, namespace, value, option_string=None): try: cur_values.extend([self.list_type(v) for v in value.split(self.sep)]) - except ValueError, e: + except ValueError as e: raise argparse.ArgumentError(self, e) diff --git a/madmom/utils/midi.py b/madmom/utils/midi.py index 74c4f4db4..965459e0d 100644 --- a/madmom/utils/midi.py +++ b/madmom/utils/midi.py @@ -3,14 +3,26 @@ # pylint: disable=invalid-name # pylint: disable=too-many-arguments # pylint: disable=too-few-public-methods - """ This file contains MIDI functionality. Almost all code is taken from Giles Hall's python-midi package: https://github.com/vishnubob/python-midi -The last merged commit is 3053fefe8cd829ff891ac4fe58dc230744fce0e6 +It combines the complete package in a single file, to make it easier to +distribute. Most notable changes are `MIDITrack` and `MIDIFile` classes which +handle all data i/o and provide a interface which allows to read/display all +notes as simple numpy arrays. Also, the EventRegistry is handled differently. + +The last merged commit is 3053fefe. + +Since then the following commits have been added functionality-wise: +- 0964c0b (prevent multiple tick conversions) +- c43bf37 (add pitch and value properties to AfterTouchEvent) +- 40111c6 (add 0x08 MetaEvent: ProgramNameEvent) +- 43de818 (handle unknown MIDI meta events gracefully) + +Additionally, the module has been updated to work with Python3. The MIT License (MIT) Copyright (c) 2013 Giles F. Hall @@ -35,21 +47,22 @@ """ +from __future__ import absolute_import, division, print_function + import math import struct import numpy as np -from ..utils import open # constants OCTAVE_MAX_VALUE = 12 -OCTAVE_VALUES = range(OCTAVE_MAX_VALUE) +OCTAVE_VALUES = list(range(OCTAVE_MAX_VALUE)) NOTE_NAMES = ['C', 'Cs', 'D', 'Ds', 'E', 'F', 'Fs', 'G', 'Gs', 'A', 'As', 'B'] WHITE_KEYS = [0, 2, 4, 5, 7, 9, 11] BLACK_KEYS = [1, 3, 6, 8, 10] NOTE_PER_OCTAVE = len(NOTE_NAMES) -NOTE_VALUES = range(OCTAVE_MAX_VALUE * NOTE_PER_OCTAVE) +NOTE_VALUES = list(range(OCTAVE_MAX_VALUE * NOTE_PER_OCTAVE)) NOTE_NAME_MAP_FLAT = {} NOTE_VALUE_MAP_FLAT = [] NOTE_NAME_MAP_SHARP = {} @@ -58,22 +71,22 @@ for index in range(128): note_idx = index % NOTE_PER_OCTAVE oct_idx = index / OCTAVE_MAX_VALUE - name = NOTE_NAMES[note_idx] - if len(name) == 2: + note_name = NOTE_NAMES[note_idx] + if len(note_name) == 2: # sharp note flat = NOTE_NAMES[note_idx + 1] + 'b' NOTE_NAME_MAP_FLAT['%s_%d' % (flat, oct_idx)] = index - NOTE_NAME_MAP_SHARP['%s_%d' % (name, oct_idx)] = index + NOTE_NAME_MAP_SHARP['%s_%d' % (note_name, oct_idx)] = index NOTE_VALUE_MAP_FLAT.append('%s_%d' % (flat, oct_idx)) - NOTE_VALUE_MAP_SHARP.append('%s_%d' % (name, oct_idx)) - globals()['%s_%d' % (name[0] + 's', oct_idx)] = index + NOTE_VALUE_MAP_SHARP.append('%s_%d' % (note_name, oct_idx)) + globals()['%s_%d' % (note_name[0] + 's', oct_idx)] = index globals()['%s_%d' % (flat, oct_idx)] = index else: - NOTE_NAME_MAP_FLAT['%s_%d' % (name, oct_idx)] = index - NOTE_NAME_MAP_SHARP['%s_%d' % (name, oct_idx)] = index - NOTE_VALUE_MAP_FLAT.append('%s_%d' % (name, oct_idx)) - NOTE_VALUE_MAP_SHARP.append('%s_%d' % (name, oct_idx)) - globals()['%s_%d' % (name, oct_idx)] = index + NOTE_NAME_MAP_FLAT['%s_%d' % (note_name, oct_idx)] = index + NOTE_NAME_MAP_SHARP['%s_%d' % (note_name, oct_idx)] = index + NOTE_VALUE_MAP_FLAT.append('%s_%d' % (note_name, oct_idx)) + NOTE_VALUE_MAP_SHARP.append('%s_%d' % (note_name, oct_idx)) + globals()['%s_%d' % (note_name, oct_idx)] = index BEAT_NAMES = ['whole', 'half', 'quarter', 'eighth', 'sixteenth', 'thirty-second', 'sixty-fourth'] @@ -107,7 +120,7 @@ def read_variable_length(data): next_byte = 1 value = 0 while next_byte: - next_value = ord(data.next()) + next_value = next(data) # is the hi-bit set? if not next_value & 0x80: # no next BYTE @@ -123,70 +136,67 @@ def read_variable_length(data): def write_variable_length(value): """ + Write a variable length variable. :param value: :return: + """ - chr1 = chr(value & 0x7F) + result = bytearray() + result.insert(0, value & 0x7F) value >>= 7 if value: - chr2 = chr((value & 0x7F) | 0x80) + result.insert(0, (value & 0x7F) | 0x80) value >>= 7 if value: - chr3 = chr((value & 0x7F) | 0x80) + result.insert(0, (value & 0x7F) | 0x80) value >>= 7 if value: - chr4 = chr((value & 0x7F) | 0x80) - result = chr4 + chr3 + chr2 + chr1 - else: - result = chr3 + chr2 + chr1 - else: - result = chr2 + chr1 - else: - result = chr1 + result.insert(0, (value & 0x7F) | 0x80) return result -class EventRegistry(type): +class EventRegistry(object): """ - Class for automatically registering usable Events. + Class for registering Events. + + Event classes should be registered manually by calling + EventRegistry.register_event(EventClass) after the class definition. """ Events = {} MetaEvents = {} - def __init__(cls, name, bases, dct): + @classmethod + def register_event(cls, event): """ Registers an event in the registry. - :param name: the name of the event to register - :param bases: the base class(es) - :param dct: dictionary with all the stuff + :param event: the event to register :raise ValueError: for unknown events """ - super(EventRegistry, cls).__init__(name, bases, dct) - # register the event - if cls.register: - # normal events - if any(x in [Event, NoteEvent] for x in bases): - # raise an error if the event class is registered already - if cls.status_msg in EventRegistry.Events: - raise AssertionError("Event %s already registered" % - cls.name) - # register the Event - EventRegistry.Events[cls.status_msg] = cls - # meta events - elif any(x in [MetaEvent, MetaEventWithText] for x in bases): - # raise an error if the meta event class is registered already - if cls.meta_command in EventRegistry.MetaEvents: + # normal events + if any(b in (Event, NoteEvent) for b in event.__bases__): + # raise an error if the event class is registered already + if event.status_msg in cls.Events: + raise AssertionError("Event %s already registered" % + event.name) + # register the Event + cls.Events[event.status_msg] = event + # meta events + elif any(b in (MetaEvent, MetaEventWithText) for b in event.__bases__): + # raise an error if the meta event class is registered already + if event.meta_command is not None: + if event.meta_command in EventRegistry.MetaEvents: raise AssertionError("Event %s already registered" % - cls.name) - # register the MetaEvent - EventRegistry.MetaEvents[cls.meta_command] = cls - else: - # raise an error - raise ValueError("Unknown base class in event type: %s" % name) + event.name) + # register the MetaEvent + cls.MetaEvents[event.meta_command] = event + else: + # raise an error + raise ValueError("Unknown base class in event type: %s" % + event.__bases__) class AbstractEvent(object): @@ -194,12 +204,9 @@ class AbstractEvent(object): Abstract Event. """ - __metaclass__ = EventRegistry - __slots__ = ['tick', 'data'] name = "Generic MIDI Event" length = 0 status_msg = 0x0 - register = False def __init__(self, **kwargs): if isinstance(self.length, int): @@ -212,23 +219,26 @@ def __init__(self, **kwargs): setattr(self, key, kwargs[key]) def __cmp__(self, other): - if self.tick < other.tick: - return -1 - elif self.tick > other.tick: - return 1 - return cmp(self.data, other.data) + raise RuntimeError('add missing comparison operators') + + def __lt__(self, other): + return self.tick < other.tick + + def __gt__(self, other): + return self.tick > other.tick def __str__(self): return "%s: tick: %s data: %s" % (self.__class__.__name__, self.tick, self.data) +# do not register AbstractEvent + class Event(AbstractEvent): """ Event. """ - __slots__ = ['channel'] name = 'Event' def __init__(self, **kwargs): @@ -260,6 +270,8 @@ def is_event(cls, status_msg): """ return cls.status_msg == (status_msg & 0xF0) +# do not register Event + class MetaEvent(AbstractEvent): """ @@ -282,6 +294,24 @@ def is_event(cls, status_msg): """ return cls.status_msg == status_msg +# do not register MetaEvent + + +class MetaEventWithText(MetaEvent): + """ + Meta Event With Text. + + """ + def __init__(self, **kwargs): + super(MetaEventWithText, self).__init__(**kwargs) + if 'text' not in kwargs: + self.text = ''.join(chr(datum) for datum in self.data) + + def __str__(self): + return "%s: %s" % (self.__class__.__name__, self.text) + +# do not register MetaEventWithText + class NoteEvent(Event): """ @@ -289,7 +319,6 @@ class NoteEvent(Event): concrete class. It defines the generalities of NoteOn and NoteOff events. """ - __slots__ = ['pitch', 'velocity'] length = 2 @property @@ -328,45 +357,84 @@ def velocity(self, velocity): """ self.data[1] = velocity +# do not register NoteEvent + class NoteOnEvent(NoteEvent): """ Note On Event. """ - register = True status_msg = 0x90 name = 'Note On' +EventRegistry.register_event(NoteOnEvent) + class NoteOffEvent(NoteEvent): """ Note Off Event. """ - register = True status_msg = 0x80 name = 'Note Off' +EventRegistry.register_event(NoteOffEvent) + class AfterTouchEvent(Event): """ After Touch Event. """ - register = True status_msg = 0xA0 length = 2 name = 'After Touch' + @property + def pitch(self): + """ + Pitch of the after touch event. + + """ + return self.data[0] + + @pitch.setter + def pitch(self, pitch): + """ + Set the pitch of the after touch event. + + :param pitch: pitch of the after touch event. + + """ + self.data[0] = pitch + + @property + def value(self): + """ + Value of the after touch event. + + """ + return self.data[1] + + @value.setter + def value(self, value): + """ + Set the value of the after touch event. + + :param value: value of the after touch event. + + """ + self.data[1] = value + +EventRegistry.register_event(AfterTouchEvent) + class ControlChangeEvent(Event): """ Control Change Event. """ - __slots__ = ['control', 'value'] - register = True status_msg = 0xB0 length = 2 name = 'Control Change' @@ -407,14 +475,14 @@ def value(self, value): """ self.data[1] = value +EventRegistry.register_event(ControlChangeEvent) + class ProgramChangeEvent(Event): """ Program Change Event. """ - __slots__ = ['value'] - register = True status_msg = 0xC0 length = 1 name = 'Program Change' @@ -436,14 +504,14 @@ def value(self, value): """ self.data[0] = value +EventRegistry.register_event(ProgramChangeEvent) + class ChannelAfterTouchEvent(Event): """ Channel After Touch Event. """ - __slots__ = ['value'] - register = True status_msg = 0xD0 length = 1 name = 'Channel After Touch' @@ -465,14 +533,14 @@ def value(self, value): """ self.data[0] = value +EventRegistry.register_event(ChannelAfterTouchEvent) + class PitchWheelEvent(Event): """ Pitch Wheel Event. """ - __slots__ = ['pitch'] - register = True status_msg = 0xE0 length = 2 name = 'Pitch Wheel' @@ -497,13 +565,14 @@ def pitch(self, pitch): self.data[0] = value & 0x7F self.data[1] = (value >> 7) & 0x7F +EventRegistry.register_event(PitchWheelEvent) + class SysExEvent(Event): """ System Exclusive Event. """ - register = True status_msg = 0xF0 length = 'variable' name = 'SysEx' @@ -519,19 +588,7 @@ def is_event(cls, status_msg): """ return cls.status_msg == status_msg - -class MetaEventWithText(MetaEvent): - """ - Meta Event With Text. - - """ - def __init__(self, **kwargs): - super(MetaEventWithText, self).__init__(**kwargs) - if 'text' not in kwargs: - self.text = ''.join(chr(datum) for datum in self.data) - - def __str__(self): - return "%s: %s" % (self.__class__.__name__, self.text) +EventRegistry.register_event(SysExEvent) class SequenceNumberMetaEvent(MetaEvent): @@ -539,97 +596,132 @@ class SequenceNumberMetaEvent(MetaEvent): Sequence Number Meta Event. """ - register = True meta_command = 0x00 length = 2 name = 'Sequence Number' +EventRegistry.register_event(SequenceNumberMetaEvent) + class TextMetaEvent(MetaEventWithText): """ Text Meta Event. """ - register = True meta_command = 0x01 length = 'variable' name = 'Text' +EventRegistry.register_event(TextMetaEvent) + class CopyrightMetaEvent(MetaEventWithText): """ Copyright Meta Event. """ - register = True meta_command = 0x02 length = 'variable' name = 'Copyright Notice' +EventRegistry.register_event(CopyrightMetaEvent) + class TrackNameEvent(MetaEventWithText): """ Track Name Event. """ - register = True meta_command = 0x03 length = 'variable' name = 'Track Name' +EventRegistry.register_event(TrackNameEvent) + class InstrumentNameEvent(MetaEventWithText): """ Instrument Name Event. """ - register = True meta_command = 0x04 length = 'variable' name = 'Instrument Name' +EventRegistry.register_event(InstrumentNameEvent) + class LyricsEvent(MetaEventWithText): """ Lyrics Event. """ - register = True meta_command = 0x05 length = 'variable' name = 'Lyrics' +EventRegistry.register_event(LyricsEvent) + class MarkerEvent(MetaEventWithText): """ Marker Event. """ - register = True meta_command = 0x06 length = 'variable' name = 'Marker' +EventRegistry.register_event(MarkerEvent) + class CuePointEvent(MetaEventWithText): """ Cue Point Event. """ - register = True meta_command = 0x07 length = 'variable' name = 'Cue Point' +EventRegistry.register_event(CuePointEvent) + -class SomethingEvent(MetaEvent): +class ProgramNameEvent(MetaEventWithText): """ - Something Event. + Program Name Event. """ - register = True - meta_command = 0x09 - name = 'Something' + meta_command = 0x08 + length = 'varlen' + name = 'Program Name' + +EventRegistry.register_event(ProgramNameEvent) + + +class UnknownMetaEvent(MetaEvent): + """ + Unknown Meta Event. + + The `meta_command` class variable must be set by the constructor of + inherited classes. + + """ + meta_command = None + name = 'Unknown' + + def __init__(self, **kwargs): + """ + + """ + super(MetaEvent, self).__init__(**kwargs) + self.meta_command = kwargs['meta_command'] + + def copy(self, **kwargs): + kwargs['meta_command'] = self.meta_command + return super(UnknownMetaEvent, self).copy(kwargs) + +EventRegistry.register_event(UnknownMetaEvent) class ChannelPrefixEvent(MetaEvent): @@ -637,49 +729,51 @@ class ChannelPrefixEvent(MetaEvent): Channel Prefix Event. """ - register = True meta_command = 0x20 length = 1 name = 'Channel Prefix' +EventRegistry.register_event(ChannelPrefixEvent) + class PortEvent(MetaEvent): """ Port Event. """ - register = True meta_command = 0x21 name = 'MIDI Port/Cable' +EventRegistry.register_event(PortEvent) + class TrackLoopEvent(MetaEvent): """ Track Loop Event. """ - register = True meta_command = 0x2E name = 'Track Loop' +EventRegistry.register_event(TrackLoopEvent) + class EndOfTrackEvent(MetaEvent): """ End Of Track Event. """ - register = True meta_command = 0x2F name = 'End of Track' +EventRegistry.register_event(EndOfTrackEvent) + class SetTempoEvent(MetaEvent): """ Set Tempo Event. """ - __slots__ = ['microseconds_per_quarter_note'] - register = True meta_command = 0x51 length = 3 name = 'Set Tempo' @@ -691,7 +785,7 @@ def microseconds_per_quarter_note(self): """ assert len(self.data) == 3 - values = [self.data[x] << (16 - (8 * x)) for x in xrange(3)] + values = [self.data[x] << (16 - (8 * x)) for x in range(3)] return sum(values) @microseconds_per_quarter_note.setter @@ -704,24 +798,25 @@ def microseconds_per_quarter_note(self, microseconds): """ self.data = [(microseconds >> (16 - (8 * x)) & 0xFF) for x in range(3)] +EventRegistry.register_event(SetTempoEvent) + class SmpteOffsetEvent(MetaEvent): """ SMPTE Offset Event. """ - register = True meta_command = 0x54 name = 'SMPTE Offset' +EventRegistry.register_event(SmpteOffsetEvent) + class TimeSignatureEvent(MetaEvent): """ Time Signature Event. """ - __slots__ = ['numerator', 'denominator', 'metronome', 'thirty_seconds'] - register = True meta_command = 0x58 length = 4 name = 'Time Signature' @@ -794,14 +889,14 @@ def thirty_seconds(self, thirty_seconds): """ self.data[3] = thirty_seconds +EventRegistry.register_event(TimeSignatureEvent) + class KeySignatureEvent(MetaEvent): """ Key Signature Event. """ - __slots__ = ['alternatives', 'minor'] - register = True meta_command = 0x59 length = 2 name = 'Key Signature' @@ -842,16 +937,19 @@ def minor(self, val): """ self.data[1] = val +EventRegistry.register_event(KeySignatureEvent) + class SequencerSpecificEvent(MetaEvent): """ Sequencer Specific Event. """ - register = True meta_command = 0x7F name = 'Sequencer Specific' +EventRegistry.register_event(SequencerSpecificEvent) + # MIDI Track class MIDITrack(object): @@ -869,31 +967,31 @@ def __init__(self, events=None, relative_timing=True): self.events = [] else: self.events = events - self._relative_timing = relative_timing + self.relative_timing = relative_timing def make_ticks_abs(self): """ Make the track's timing information absolute. """ - if self._relative_timing: + if self.relative_timing: running_tick = 0 for event in self.events: event.tick += running_tick running_tick = event.tick - self._relative_timing = False + self.relative_timing = False def make_ticks_rel(self): """ Make the track's timing information relative. """ - if not self._relative_timing: + if not self.relative_timing: running_tick = 0 for event in self.events: event.tick -= running_tick running_tick += event.tick - self._relative_timing = True + self.relative_timing = True @property def data_stream(self): @@ -906,32 +1004,32 @@ def data_stream(self): # and unset the status message status = None # then encode all events of the track - track_data = '' + track_data = bytearray() for event in self.events: # encode the event data, first the timing information - track_data += write_variable_length(event.tick) + track_data.extend(write_variable_length(event.tick)) # is the event a MetaEvent? if isinstance(event, MetaEvent): - track_data += chr(event.status_msg) - track_data += chr(event.meta_command) - track_data += write_variable_length(len(event.data)) - track_data += ''.join([chr(data) for data in event.data]) + track_data.append(event.status_msg) + track_data.append(event.meta_command) + track_data.extend(write_variable_length(len(event.data))) + track_data.extend(event.data) # is this event a SysEx Event? elif isinstance(event, SysExEvent): - track_data += chr(0xF0) - track_data += ''.join([chr(data) for data in event.data]) - track_data += chr(0xF7) + track_data.append(0xF0) + track_data.extend(event.data) + track_data.append(0xF7) # not a meta or SysEx event, must be a general message elif isinstance(event, Event): if not status or status.status_msg != event.status_msg or \ status.channel != event.channel: status = event - track_data += chr(event.status_msg | event.channel) - track_data += ''.join([chr(data) for data in event.data]) + track_data.append(event.status_msg | event.channel) + track_data.extend(event.data) else: raise ValueError("Unknown MIDI Event: " + str(event)) # prepare the track header - track_header = 'MTrk%s' % struct.pack(">L", len(track_data)) + track_header = b'MTrk%s' % struct.pack(">L", len(track_data)) # return the track header + data return track_header + track_data @@ -949,7 +1047,7 @@ def from_file(cls, midi_stream): status = None # first four bytes are Track header chunk = midi_stream.read(4) - if chunk != 'MTrk': + if chunk != b'MTrk': raise TypeError("Bad track header in MIDI file: %s" % chunk) # next four bytes are track size track_size = struct.unpack(">L", midi_stream.read(4))[0] @@ -960,22 +1058,26 @@ def from_file(cls, midi_stream): # first datum is variable length representing the delta-time tick = read_variable_length(track_data) # next byte is status message - status_msg = ord(track_data.next()) + status_msg = next(track_data) # is the event a MetaEvent? if MetaEvent.is_event(status_msg): - cmd = ord(track_data.next()) + cmd = next(track_data) if cmd not in EventRegistry.MetaEvents: - raise Warning("Unknown Meta MIDI Event: %s" % cmd) - event_cls = EventRegistry.MetaEvents[cmd] + import warnings + warnings.warn("Unknown Meta MIDI Event: %s" % cmd) + event_cls = UnknownMetaEvent + else: + event_cls = EventRegistry.MetaEvents[cmd] data_len = read_variable_length(track_data) - data = [ord(track_data.next()) for _ in range(data_len)] + data = [next(track_data) for _ in range(data_len)] # create an event and append it to the list - events.append(event_cls(tick=tick, data=data)) + events.append(event_cls(tick=tick, data=data, + meta_command=cmd)) # is this event a SysEx Event? elif SysExEvent.is_event(status_msg): data = [] while True: - datum = ord(track_data.next()) + datum = next(track_data) if datum == 0xF7: break data.append(datum) @@ -991,7 +1093,7 @@ def from_file(cls, midi_stream): event_cls = EventRegistry.Events[key] channel = status & 0x0F data.append(status_msg) - data += [ord(track_data.next()) for _ in + data += [next(track_data) for _ in range(event_cls.length - 1)] # create an event and append it to the list events.append(event_cls(tick=tick, channel=channel, @@ -1000,7 +1102,7 @@ def from_file(cls, midi_stream): status = status_msg event_cls = EventRegistry.Events[key] channel = status & 0x0F - data = [ord(track_data.next()) for _ in + data = [next(track_data) for _ in range(event_cls.length)] # create an event and append it to the list events.append(event_cls(tick=tick, channel=channel, @@ -1322,20 +1424,9 @@ def data_stream(self): MIDI data stream representation of the MIDI file. """ - # from StringIO import StringIO - # str_buffer = StringIO() - # # put the MIDI header in the stream - # header_data = struct.pack(">LHHH", 6, self.format, - # len(self.tracks), self.resolution) - # str_buffer.write('MThd%s' % header_data) - # # put each track in the stream - # for track in self.tracks: - # str_buffer.write(track.data_stream) - # # return the string buffer - # return str_buffer.getvalue() # generate a MIDI header - data = 'MThd%s' % struct.pack(">LHHH", 6, self.format, - len(self.tracks), self.resolution) + data = b'MThd%s' % struct.pack(">LHHH", 6, self.format, + len(self.tracks), self.resolution) # append the tracks for track in self.tracks: data += track.data_stream @@ -1349,8 +1440,13 @@ def write(self, midi_file): :param midi_file: the MIDI file name """ - with open(midi_file, 'wb') as midi_file: - midi_file.write(self.data_stream) + # if we get a filename, open the file + if not hasattr(midi_file, 'write'): + midi_file = open(midi_file, 'wb') + # write the MIDI stream + midi_file.write(self.data_stream) + # close the file + midi_file.close() @classmethod def from_file(cls, midi_file): @@ -1368,8 +1464,8 @@ def from_file(cls, midi_file): # read in file header # first four bytes are MIDI header chunk = midi_file.read(4) - if chunk != 'MThd': - raise TypeError("Bad header in MIDI file.") + if chunk != b'MThd': + raise TypeError("Bad header in MIDI file: %s", chunk) # next four bytes are header size # next two bytes specify the format version # next two bytes specify the number of tracks diff --git a/madmom/utils/stats.py b/madmom/utils/stats.py index fc3a6a481..c5cfe408c 100644 --- a/madmom/utils/stats.py +++ b/madmom/utils/stats.py @@ -2,12 +2,13 @@ # pylint: disable=no-member # pylint: disable=invalid-name # pylint: disable=too-many-arguments - """ This file contains some statistical functionality. """ +from __future__ import absolute_import, division, print_function + import numpy as np diff --git a/setup.py b/setup.py index 992b7f1cd..82061c6f3 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,9 @@ # some PyPI metadata classifiers = ['Development Status :: 3 - Alpha', 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', 'Environment :: Console', 'License :: OSI Approved :: BSD License', 'License :: Free for non-commercial use', diff --git a/tests/__init__.py b/tests/__init__.py index 8d23ab6f5..c898770d5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,9 +1,11 @@ # encoding: utf-8 +# pylint: skip-file """ This module contains tests. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import os diff --git a/tests/data/events.txt b/tests/data/events.txt index d484a812c..f896c7957 100644 --- a/tests/data/events.txt +++ b/tests/data/events.txt @@ -1,8 +1,8 @@ -1 -1.02 -1.5 -2 -2.03 -2.05 -2.5 -3 +1.000 +1.020 +1.500 +2.000 +2.030 +2.050 +2.500 +3.000 diff --git a/tests/test_audio_comb_filters.py b/tests/test_audio_comb_filters.py index 1a0276dbc..d195ef0a7 100644 --- a/tests/test_audio_comb_filters.py +++ b/tests/test_audio_comb_filters.py @@ -1,9 +1,11 @@ # encoding: utf-8 +# pylint: skip-file """ This file contains tests for the madmom.audio.comb_filters module. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import unittest diff --git a/tests/test_audio_filters.py b/tests/test_audio_filters.py index 99a0a45de..2b30f9f13 100644 --- a/tests/test_audio_filters.py +++ b/tests/test_audio_filters.py @@ -1,13 +1,15 @@ # encoding: utf-8 +# pylint: skip-file """ This file contains tests for the madmom.audio.filters module. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import unittest import types -import cPickle +import pickle import tempfile from madmom.audio.filters import * @@ -436,9 +438,9 @@ def test_normalization(self): def test_pickling(self): f, filename = tempfile.mkstemp() filt = Filter(np.arange(5)) - cPickle.dump(filt, open(filename, 'w'), - protocol=cPickle.HIGHEST_PROTOCOL) - filt_ = cPickle.load(open(filename)) + pickle.dump(filt, open(filename, 'wb'), + protocol=pickle.HIGHEST_PROTOCOL) + filt_ = pickle.load(open(filename, 'rb')) self.assertTrue(np.allclose(filt, filt_)) self.assertTrue(np.allclose(filt.start, filt_.start)) @@ -508,9 +510,9 @@ def test_values(self): def test_pickling(self): f, filename = tempfile.mkstemp() filt = TriangularFilter(1, 4, 10) - cPickle.dump(filt, open(filename, 'w'), - protocol=cPickle.HIGHEST_PROTOCOL) - filt_ = cPickle.load(open(filename)) + pickle.dump(filt, open(filename, 'wb'), + protocol=pickle.HIGHEST_PROTOCOL) + filt_ = pickle.load(open(filename, 'rb')) self.assertTrue(np.allclose(filt, filt_)) self.assertTrue(np.allclose(filt.start, filt_.start)) self.assertTrue(np.allclose(filt.stop, filt_.stop)) @@ -518,30 +520,19 @@ def test_pickling(self): def test_band_bins_method_too_few_bins(self): with self.assertRaises(ValueError): - result = TriangularFilter.band_bins(np.arange(2)) - result.next() + list(TriangularFilter.band_bins(np.arange(2))) def test_band_bins_method_overlap(self): # test overlapping - result = TriangularFilter.band_bins(self.bins) - self.assertTrue(result.next() == (0, 1, 2)) - self.assertTrue(result.next() == (1, 2, 3)) - self.assertTrue(result.next() == (2, 3, 4)) - self.assertTrue(result.next() == (3, 4, 6)) - self.assertTrue(result.next() == (4, 6, 9)) - with self.assertRaises(StopIteration): - result.next() + result = list(TriangularFilter.band_bins(self.bins)) + self.assertTrue(result == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 6), + (4, 6, 9)]) def test_band_bins_method_non_overlap(self): # test non-overlapping - result = TriangularFilter.band_bins(self.bins, overlap=False) - self.assertTrue(result.next() == (1, 1, 2)) - self.assertTrue(result.next() == (2, 2, 3)) - self.assertTrue(result.next() == (3, 3, 4)) - self.assertTrue(result.next() == (4, 4, 5)) - self.assertTrue(result.next() == (5, 6, 8)) - with self.assertRaises(StopIteration): - result.next() + result = list(TriangularFilter.band_bins(self.bins, overlap=False)) + self.assertTrue(result == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), + (5, 6, 8)]) def test_filters_method_normalized(self): # normalized filters @@ -624,34 +615,25 @@ def test_values(self): def test_pickling(self): f, filename = tempfile.mkstemp() filt = RectangularFilter(5, 10) - cPickle.dump(filt, open(filename, 'w'), - protocol=cPickle.HIGHEST_PROTOCOL) - filt_ = cPickle.load(open(filename)) + pickle.dump(filt, open(filename, 'wb'), + protocol=pickle.HIGHEST_PROTOCOL) + filt_ = pickle.load(open(filename, 'rb')) self.assertTrue(np.allclose(filt, filt_)) self.assertTrue(np.allclose(filt.start, filt_.start)) self.assertTrue(np.allclose(filt.stop, filt_.stop)) def test_band_bins_method_too_few_bins(self): with self.assertRaises(ValueError): - result = RectangularFilter.band_bins(np.arange(1)) - result.next() + list(RectangularFilter.band_bins(np.arange(1))) def test_band_bins_method_overlap(self): - result = RectangularFilter.band_bins(self.bins, overlap=True) with self.assertRaises(NotImplementedError): - # TODO: write test when implemented - result.next() + list(RectangularFilter.band_bins(self.bins, overlap=True)) def test_band_bins_method(self): - result = RectangularFilter.band_bins(self.bins) - self.assertTrue(result.next() == (0, 1)) - self.assertTrue(result.next() == (1, 2)) - self.assertTrue(result.next() == (2, 3)) - self.assertTrue(result.next() == (3, 4)) - self.assertTrue(result.next() == (4, 6)) - self.assertTrue(result.next() == (6, 9)) - with self.assertRaises(StopIteration): - result.next() + result = list(RectangularFilter.band_bins(self.bins)) + self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (3, 4), (4, 6), + (6, 9)]) def test_filters_method_normalized(self): # normalized filters @@ -814,14 +796,14 @@ def test_values_rectangular_and_triangular(self): self.assertTrue(np.allclose(filt.max(), 1)) # all triangular filters correct = np.zeros(100) - correct[1:70] = [1./6, 2./6, 3./6, 4./6, 5./6, 1., 8./9, 7./9, 6./9, - 5./9, 5./9, 6./9, 7./9, 8./9, 1., 0.9, 0.8, 0.7, 0.6, - 0.5, 0.6, 0.7, 0.8, 0.9, 1., 0.96, 0.92, 0.88, 0.84, - 0.8, 0.76, 0.72, 0.68, 0.64, 0.6, 0.56, 0.52, 0.52, - 0.56, 0.6, 0.64, 0.68, 0.72, 0.76, 0.8, 0.84, 0.88, - 0.92, 0.96, 1., 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, - 0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15, - 0.1, 0.05] + correct[1:70] = [1. / 6, 2. / 6, 3. / 6, 4. / 6, 5. / 6, 1., 8. / 9, + 7. / 9, 6. / 9, 5. / 9, 5. / 9, 6. / 9, 7. / 9, + 8. / 9, 1., 0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, + 0.9, 1., 0.96, 0.92, 0.88, 0.84, 0.8, 0.76, 0.72, + 0.68, 0.64, 0.6, 0.56, 0.52, 0.52, 0.56, 0.6, 0.64, + 0.68, 0.72, 0.76, 0.8, 0.84, 0.88, 0.92, 0.96, 1., + 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, + 0.5, 0.45, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15, 0.1, 0.05] self.assertTrue(np.allclose(filt[:, 0], correct)) # all rectangular filters are 1 self.assertTrue(np.allclose(filt[:, 1], np.ones(100))) @@ -833,9 +815,9 @@ def test_values_rectangular_and_triangular(self): def test_pickling(self): filt = Filterbank.from_filters(self.triang_filters, np.arange(100)) f, filename = tempfile.mkstemp() - cPickle.dump(filt, open(filename, 'w'), - protocol=cPickle.HIGHEST_PROTOCOL) - filt_ = cPickle.load(open(filename)) + pickle.dump(filt, open(filename, 'wb'), + protocol=pickle.HIGHEST_PROTOCOL) + filt_ = pickle.load(open(filename, 'rb')) self.assertTrue(np.allclose(filt, filt_)) self.assertTrue(np.allclose(filt.bin_frequencies, filt_.bin_frequencies)) @@ -888,9 +870,9 @@ def test_values(self): def test_pickling(self): filt = MelFilterbank(np.arange(1000) * 20, 10) f, filename = tempfile.mkstemp() - cPickle.dump(filt, open(filename, 'w'), - protocol=cPickle.HIGHEST_PROTOCOL) - filt_ = cPickle.load(open(filename)) + pickle.dump(filt, open(filename, 'wb'), + protocol=pickle.HIGHEST_PROTOCOL) + filt_ = pickle.load(open(filename, 'rb')) self.assertTrue(np.allclose(filt, filt_)) self.assertTrue(np.allclose(filt.bin_frequencies, filt_.bin_frequencies)) @@ -928,7 +910,7 @@ def test_default_values(self): [11304.931640, 13285.986328], [12273.925781, 14427.246093], [13329.052734, 15654.638671], [14470.312500, 16968.164062]] self.assertTrue(np.allclose(filt.min(), 0)) - self.assertTrue(np.allclose(filt.max(), 1./3)) + self.assertTrue(np.allclose(filt.max(), 1. / 3)) self.assertTrue(np.allclose(filt.center_frequencies, center)) self.assertTrue(np.allclose(filt.corner_frequencies, corner)) @@ -962,9 +944,9 @@ def test_constant_values(self): def test_pickling(self): filt = BarkFilterbank(FFT_FREQS_1024) f, filename = tempfile.mkstemp() - cPickle.dump(filt, open(filename, 'w'), - protocol=cPickle.HIGHEST_PROTOCOL) - filt_ = cPickle.load(open(filename)) + pickle.dump(filt, open(filename, 'wb'), + protocol=pickle.HIGHEST_PROTOCOL) + filt_ = pickle.load(open(filename, 'rb')) self.assertTrue(np.allclose(filt, filt_)) self.assertTrue(np.allclose(filt.bin_frequencies, filt_.bin_frequencies)) @@ -1075,9 +1057,9 @@ def test_constant_values(self): def test_pickling(self): filt = LogarithmicFilterbank(FFT_FREQS_1024) f, filename = tempfile.mkstemp() - cPickle.dump(filt, open(filename, 'w'), - protocol=cPickle.HIGHEST_PROTOCOL) - filt_ = cPickle.load(open(filename)) + pickle.dump(filt, open(filename, 'wb'), + protocol=pickle.HIGHEST_PROTOCOL) + filt_ = pickle.load(open(filename, 'rb')) self.assertTrue(np.allclose(filt, filt_)) self.assertTrue(np.allclose(filt.bin_frequencies, filt_.bin_frequencies)) @@ -1180,9 +1162,9 @@ def test_types(self): def test_pickling(self): filt = RectangularFilterbank(FFT_FREQS_1024, [100, 1000]) f, filename = tempfile.mkstemp() - cPickle.dump(filt, open(filename, 'w'), - protocol=cPickle.HIGHEST_PROTOCOL) - filt_ = cPickle.load(open(filename)) + pickle.dump(filt, open(filename, 'wb'), + protocol=pickle.HIGHEST_PROTOCOL) + filt_ = pickle.load(open(filename, 'rb')) self.assertTrue(np.allclose(filt, filt_)) self.assertTrue(np.allclose(filt.bin_frequencies, filt_.bin_frequencies)) @@ -1197,15 +1179,13 @@ def test_values(self): self.assertEqual(filt.shape, (100, 3)) self.assertTrue(np.allclose(filt.bin_frequencies, np.arange(0, 2000, 20))) - print filt.crossover_frequencies self.assertTrue(np.allclose(filt.crossover_frequencies, [100, 1000])) - def test_values_unique_filters(self): filt = RectangularFilterbank(np.arange(0, 2000, 20), [100, 101, 1000], unique_filters=False) self.assertTrue(np.allclose(filt.min(), 0)) - self.assertTrue(np.allclose(filt.max(), 1./3)) + self.assertTrue(np.allclose(filt.max(), 1. / 3)) # second band must be 0 self.assertTrue(np.allclose(filt[:, 1], np.zeros(100))) self.assertEqual(filt.shape, (100, 4)) @@ -1258,9 +1238,9 @@ def test_values(self): def test_pickling(self): filt = PitchClassProfileFilterbank(FFT_FREQS_1024) f, filename = tempfile.mkstemp() - cPickle.dump(filt, open(filename, 'w'), - protocol=cPickle.HIGHEST_PROTOCOL) - filt_ = cPickle.load(open(filename)) + pickle.dump(filt, open(filename, 'wb'), + protocol=pickle.HIGHEST_PROTOCOL) + filt_ = pickle.load(open(filename, 'rb')) self.assertTrue(np.allclose(filt, filt_)) self.assertTrue(np.allclose(filt.bin_frequencies, filt_.bin_frequencies)) @@ -1300,9 +1280,9 @@ def test_values(self): def test_pickling(self): filt = HarmonicPitchClassProfileFilterbank(FFT_FREQS_1024) f, filename = tempfile.mkstemp() - cPickle.dump(filt, open(filename, 'w'), - protocol=cPickle.HIGHEST_PROTOCOL) - filt_ = cPickle.load(open(filename)) + pickle.dump(filt, open(filename, 'wb'), + protocol=pickle.HIGHEST_PROTOCOL) + filt_ = pickle.load(open(filename, 'rb')) self.assertTrue(np.allclose(filt, filt_)) self.assertTrue(np.allclose(filt.bin_frequencies, filt_.bin_frequencies)) diff --git a/tests/test_audio_signal.py b/tests/test_audio_signal.py index 54111d80c..eee4f8c68 100644 --- a/tests/test_audio_signal.py +++ b/tests/test_audio_signal.py @@ -1,12 +1,13 @@ # encoding: utf-8 +# pylint: skip-file """ This file contains tests for the madmom.audio.signal module. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import unittest -import __builtin__ from . import DATA_PATH from .test_audio_comb_filters import sig_1d, sig_2d @@ -335,12 +336,12 @@ def test_values(self): result = sound_pressure_level(np.zeros(100)) self.assertTrue(np.allclose(result, -np.finfo(float).max)) # maximum float amplitude, alternating between -1 and 1 - sig = np.cos(np.linspace(0, 2*np.pi*100, 2*100+1)) + sig = np.cos(np.linspace(0, 2 * np.pi * 100, 2 * 100 + 1)) result = sound_pressure_level(sig) self.assertTrue(np.allclose(result, 0.)) # maximum float amplitude, alternating between -1 and 1 - sig = np.cos(np.linspace(0, 2*np.pi*100, 2*100+1)) * \ - np.iinfo(np.int16).max + sig = (np.cos(np.linspace(0, 2 * np.pi * 100, 2 * 100 + 1)) * + np.iinfo(np.int16).max) result = sound_pressure_level(sig.astype(np.int16)) self.assertTrue(np.allclose(result, 0.)) # multi-channel signals @@ -373,7 +374,7 @@ def test_file_handle(self): # test wave loader f = DATA_PATH + '/sample.wav' # open file handle - file_handle = __builtin__.open(f) + file_handle = open(f) signal, sample_rate = load_audio_file(file_handle) self.assertIsInstance(signal, np.ndarray) self.assertTrue(signal.dtype == np.int16) @@ -387,7 +388,7 @@ def test_file_handle(self): # test ffmpeg loader f = DATA_PATH + '/stereo_sample.flac' # open file handle - file_handle = __builtin__.open(f) + file_handle = open(f) signal, sample_rate = load_audio_file(file_handle) self.assertIsInstance(signal, np.ndarray) self.assertTrue(signal.dtype == np.int16) @@ -428,13 +429,15 @@ def test_values(self): def test_start_stop(self): # test wave loader f = DATA_PATH + '/sample.wav' - signal, sample_rate = load_audio_file(f, start=1./44100, stop=5./44100) + signal, sample_rate = load_audio_file(f, start=1. / 44100, + stop=5. / 44100) self.assertTrue(np.allclose(signal, [-2510, -2484, -2678, -2833])) self.assertTrue(len(signal) == 4) self.assertTrue(sample_rate == 44100) # test ffmpeg loader f = DATA_PATH + '/stereo_sample.flac' - signal, sample_rate = load_audio_file(f, start=1./44100, stop=4./44100) + signal, sample_rate = load_audio_file(f, start=1. / 44100, + stop=4. / 44100) self.assertTrue(np.allclose(signal, [[35, 36], [29, 34], [36, 31]])) self.assertTrue(len(signal) == 3) self.assertTrue(sample_rate == 44100) @@ -450,11 +453,11 @@ def test_downmix(self): # test ffmpeg loader f = DATA_PATH + '/stereo_sample.flac' signal, sample_rate = load_audio_file(f, num_channels=1) - # TODO: is it a problem that the results are rounded differently? - self.assertTrue(np.allclose(signal[:5], [36, 36, 32, 34, 34])) - self.assertTrue(len(signal) == 182919) + # results are rounded differently, thus allow atol=1 + self.assertTrue(np.allclose(signal[:5], [35, 35, 31, 33, 33], atol=1)) + # avconv results in a different length of 182909 samples + self.assertTrue(np.allclose(len(signal), 182919, atol=10)) self.assertTrue(sample_rate == 44100) - self.assertTrue(signal.shape == (182919, )) def test_upmix(self): f = DATA_PATH + '/sample.wav' @@ -555,13 +558,13 @@ def test_values_file(self): self.assertTrue(np.allclose(result.length, 2.8)) def test_pickling(self): - import cPickle + import pickle import tempfile result = Signal(DATA_PATH + '/sample.wav') f, filename = tempfile.mkstemp() - cPickle.dump(result, open(filename, 'w'), - protocol=cPickle.HIGHEST_PROTOCOL) - result_ = cPickle.load(open(filename)) + pickle.dump(result, open(filename, 'wb'), + protocol=pickle.HIGHEST_PROTOCOL) + result_ = pickle.load(open(filename, 'rb')) self.assertTrue(np.allclose(result, result_)) self.assertTrue(result.sample_rate == result_.sample_rate) diff --git a/tests/test_audio_spectrogram.py b/tests/test_audio_spectrogram.py index bd6520519..38e05177b 100644 --- a/tests/test_audio_spectrogram.py +++ b/tests/test_audio_spectrogram.py @@ -1,12 +1,14 @@ # encoding: utf-8 +# pylint: skip-file """ This file contains tests for the madmom.audio.spectrogram module. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import unittest -import cPickle +import pickle from . import DATA_PATH from madmom.audio.spectrogram import * @@ -148,8 +150,8 @@ def test_values(self): def test_pickle(self): result = Spectrogram(DATA_PATH + '/sample.wav') - dump = cPickle.dumps(result, protocol=cPickle.HIGHEST_PROTOCOL) - dump = cPickle.loads(dump) + dump = pickle.dumps(result, protocol=pickle.HIGHEST_PROTOCOL) + dump = pickle.loads(dump) self.assertTrue(np.allclose(result, dump)) def test_methods(self): @@ -209,8 +211,8 @@ def test_pickle(self): from madmom.audio.filters import MelFilterbank result = FilteredSpectrogram(DATA_PATH + '/sample.wav', filterbank=MelFilterbank) - dump = cPickle.dumps(result, protocol=cPickle.HIGHEST_PROTOCOL) - dump = cPickle.loads(dump) + dump = pickle.dumps(result, protocol=pickle.HIGHEST_PROTOCOL) + dump = pickle.loads(dump) self.assertTrue(np.allclose(result, dump)) # additional attributes self.assertTrue(np.allclose(result.filterbank, dump.filterbank)) @@ -282,8 +284,8 @@ def test_pickle(self): # test with non-default values result = LogarithmicSpectrogram(DATA_PATH + '/sample.wav', mul=2, add=2) - dump = cPickle.dumps(result, protocol=cPickle.HIGHEST_PROTOCOL) - dump = cPickle.loads(dump) + dump = pickle.dumps(result, protocol=pickle.HIGHEST_PROTOCOL) + dump = pickle.loads(dump) self.assertTrue(np.allclose(result, dump)) self.assertTrue(result.mul == dump.mul) self.assertTrue(result.add == dump.add) @@ -345,8 +347,8 @@ def test_pickle(self): # test with non-default values result = LogarithmicFilteredSpectrogram(DATA_PATH + '/sample.wav', mul=2, add=2) - dump = cPickle.dumps(result, protocol=cPickle.HIGHEST_PROTOCOL) - dump = cPickle.loads(dump) + dump = pickle.dumps(result, protocol=pickle.HIGHEST_PROTOCOL) + dump = pickle.loads(dump) self.assertTrue(np.allclose(result, dump)) self.assertTrue(np.allclose(result.filterbank, dump.filterbank)) self.assertTrue(result.mul == dump.mul) @@ -416,8 +418,8 @@ def test_pickle(self): result = SpectrogramDifference(DATA_PATH + '/sample.wav', diff_ratio=0.7, diff_frames=3, diff_max_bins=2, positive_diffs=True) - dump = cPickle.dumps(result, protocol=cPickle.HIGHEST_PROTOCOL) - dump = cPickle.loads(dump) + dump = pickle.dumps(result, protocol=pickle.HIGHEST_PROTOCOL) + dump = pickle.loads(dump) self.assertTrue(np.allclose(result, dump)) self.assertTrue(result.diff_ratio == dump.diff_ratio) self.assertTrue(result.diff_frames == dump.diff_frames) @@ -502,8 +504,8 @@ def test_values(self): def test_pickle(self): # test with non-default values result = MultiBandSpectrogram(DATA_PATH + '/sample.wav', [200, 1000]) - dump = cPickle.dumps(result, protocol=cPickle.HIGHEST_PROTOCOL) - dump = cPickle.loads(dump) + dump = pickle.dumps(result, protocol=pickle.HIGHEST_PROTOCOL) + dump = pickle.loads(dump) self.assertTrue(np.allclose(result, dump)) self.assertTrue(result.crossover_frequencies == dump.crossover_frequencies) diff --git a/tests/test_audio_stft.py b/tests/test_audio_stft.py index 0aefe0588..2f86ecc6d 100644 --- a/tests/test_audio_stft.py +++ b/tests/test_audio_stft.py @@ -1,12 +1,14 @@ # encoding: utf-8 +# pylint: skip-file """ This file contains tests for the madmom.audio.stft module. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import unittest -import cPickle +import pickle from . import DATA_PATH from madmom.audio.stft import * @@ -66,14 +68,14 @@ def test_value(self): # signal length and FFT size = 12 # fft_freqs: 0, 1/12, 2/12, 3/12, 4/12, 5/12 # [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0] every 4th bin => 3/12 - res = [3.+0.j, 0.+0.j, 0.-0.j, 3+0.j, 0.+0.j, 0.+0.j] + res = [3. + 0.j, 0. + 0.j, 0. - 0.j, 3 + 0.j, 0. + 0.j, 0. + 0.j] self.assertTrue(np.allclose(result[0], res)) # [1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0] every erd bin => 4/12 - res = [4.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 4.+0.j, 0.+0.j] + res = [4. + 0.j, 0. + 0.j, 0. + 0.j, 0. + 0.j, 4. + 0.j, 0. + 0.j] self.assertTrue(np.allclose(result[1], res)) # [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0] every 2nd bin => 6/12 # can't resolve any more - res = [6.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j] + res = [6. + 0.j, 0. + 0.j, 0. + 0.j, 0. + 0.j, 0. + 0.j, 0. + 0.j] self.assertTrue(np.allclose(result[2], res)) def test_circular_shift(self): @@ -81,14 +83,14 @@ def test_circular_shift(self): # signal length and FFT size = 12 # fft_freqs: 0, 1/12, 2/12, 3/12, 4/12, 5/12 # [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0] every 4th bin => 3/12 - res = [3.+0.j, 0.+0.j, 0.+0j, -3.+0.j, 0.+0.j, 0.+0.j] + res = [3. + 0.j, 0. + 0.j, 0. + 0j, -3. + 0.j, 0. + 0.j, 0. + 0.j] self.assertTrue(np.allclose(result[0], res)) # [1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0] every erd bin => 4/12 - res = [4.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 4.+0.j, 0.+0.j] + res = [4. + 0.j, 0. + 0.j, 0. + 0.j, 0. + 0.j, 4. + 0.j, 0. + 0.j] self.assertTrue(np.allclose(result[1], res)) # [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0] every 2nd bin => 6/12 # can't resolve any more - res = [6.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j] + res = [6. + 0.j, 0. + 0.j, 0. + 0.j, 0. + 0.j, 0. + 0.j, 0. + 0.j] self.assertTrue(np.allclose(result[2], res)) @@ -176,8 +178,8 @@ def test_pickle(self): result = ShortTimeFourierTransform(DATA_PATH + '/sample.wav', window=np.hamming, fft_size=4096, circular_shift=True) - dump = cPickle.dumps(result, protocol=cPickle.HIGHEST_PROTOCOL) - dump = cPickle.loads(dump) + dump = pickle.dumps(result, protocol=pickle.HIGHEST_PROTOCOL) + dump = pickle.loads(dump) self.assertTrue(np.allclose(result, dump)) # additional attributes self.assertTrue(np.allclose(result.window, dump.window)) @@ -258,8 +260,8 @@ def test_values(self): def test_pickle(self): result = Phase(DATA_PATH + '/sample.wav') - dump = cPickle.dumps(result, protocol=cPickle.HIGHEST_PROTOCOL) - dump = cPickle.loads(dump) + dump = pickle.dumps(result, protocol=pickle.HIGHEST_PROTOCOL) + dump = pickle.loads(dump) self.assertTrue(np.allclose(result, dump)) def test_methods(self): @@ -297,6 +299,6 @@ def test_values(self): def test_pickle(self): result = LocalGroupDelay(DATA_PATH + '/sample.wav') - dump = cPickle.dumps(result, protocol=cPickle.HIGHEST_PROTOCOL) - dump = cPickle.loads(dump) + dump = pickle.dumps(result, protocol=pickle.HIGHEST_PROTOCOL) + dump = pickle.loads(dump) self.assertTrue(np.allclose(result, dump)) diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index 578b6ab9c..bf3bc55e5 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -1,9 +1,11 @@ # encoding: utf-8 +# pylint: skip-file """ This file contains tests for the madmom.evaluation module. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import unittest import math @@ -165,10 +167,10 @@ def test_results(self): # (TP + TN) / (TP + FP + TN + FN) self.assertEqual(e.accuracy, 1) # metric dictionary - self.assertEqual(e.metrics.keys(), ['num_tp', 'num_fp', 'num_tn', - 'num_fn', 'num_annotations', - 'precision', 'recall', 'fmeasure', - 'accuracy']) + self.assertEqual(list(e.metrics.keys()), + ['num_tp', 'num_fp', 'num_tn', 'num_fn', + 'num_annotations', 'precision', 'recall', + 'fmeasure', 'accuracy']) correct = OrderedDict([('num_tp', 0), ('num_fp', 0), ('num_tn', 0), ('num_fn', 0), ('num_annotations', 0), ('precision', 1.0), ('recall', 1.0), @@ -262,10 +264,10 @@ def test_results(self): # acc: (TP + TN) / (TP + FP + TN + FN) self.assertEqual(e.accuracy, 1) # test metric dictionary keys - self.assertEqual(e.metrics.keys(), ['tp', 'fp', 'tn', 'fn', 'num_tp', - 'num_fp', 'num_tn', 'num_fn', - 'num_annotations', 'precision', - 'recall', 'fmeasure', 'accuracy']) + self.assertEqual(list(e.metrics.keys()), + ['tp', 'fp', 'tn', 'fn', 'num_tp', 'num_fp', 'num_tn', + 'num_fn', 'num_annotations', 'precision', 'recall', + 'fmeasure', 'accuracy']) # test with other values e = Evaluation(tp=[1, 2, 3.0], fp=[1.5], fn=[0, 3.1]) tp = np.asarray([1, 2, 3], dtype=np.float) diff --git a/tests/test_evaluation_beats.py b/tests/test_evaluation_beats.py index cf7bf28b1..b37107bb4 100644 --- a/tests/test_evaluation_beats.py +++ b/tests/test_evaluation_beats.py @@ -1,9 +1,11 @@ # encoding: utf-8 +# pylint: skip-file """ This file contains tests for the madmom.evaluation.beats module. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import unittest import math @@ -169,6 +171,8 @@ def test_types(self): length, start = find_longest_continuous_segment([]) self.assertIsInstance(length, int) self.assertIsInstance(start, int) + + def test_errors(self): # events must be correct type with self.assertRaises(IndexError): find_longest_continuous_segment(None) @@ -265,24 +269,26 @@ def test_types(self): self.assertIsInstance(score, float) score = pscore({}, {}, 0.2) self.assertIsInstance(score, float) - with self.assertRaises(TypeError): - pscore(None, ANNOTATIONS, 0.2) - with self.assertRaises(TypeError): - pscore(DETECTIONS, None, 0.2) - # tolerance must be correct type + # tolerance must be convertible to float score = pscore(DETECTIONS, ANNOTATIONS, int(1.2)) self.assertIsInstance(score, float) - with self.assertRaises(ValueError): - pscore(DETECTIONS, ANNOTATIONS, []) - with self.assertRaises(TypeError): - pscore(DETECTIONS, ANNOTATIONS, {}) def test_errors(self): # tolerance must be > 0 with self.assertRaises(ValueError): pscore(DETECTIONS, ANNOTATIONS, 0) - with self.assertRaises(ValueError): + # tolerance must be convertible to float + with self.assertRaises(TypeError): pscore(DETECTIONS, ANNOTATIONS, None) + with self.assertRaises(TypeError): + pscore(DETECTIONS, ANNOTATIONS, []) + with self.assertRaises(TypeError): + pscore(DETECTIONS, ANNOTATIONS, {}) + # detections / annotations must be correct type + with self.assertRaises(TypeError): + pscore(None, ANNOTATIONS, 0.2) + with self.assertRaises(TypeError): + pscore(DETECTIONS, None, 0.2) # score relies on intervals, hence at least 2 annotations must be given with self.assertRaises(BeatIntervalError): pscore(DETECTIONS, [1], 0.2) @@ -312,13 +318,23 @@ def test_types(self): self.assertIsInstance(score, float) score = cemgil({}, {}, 0.04) self.assertIsInstance(score, float) + # sigma must be correct type + score = cemgil(DETECTIONS, ANNOTATIONS, int(1)) + self.assertIsInstance(score, float) + + def test_errors(self): + # sigma must not be None + with self.assertRaises(TypeError): + cemgil(DETECTIONS, ANNOTATIONS, None) + # sigma must be greater than 0 + with self.assertRaises(ValueError): + cemgil(DETECTIONS, ANNOTATIONS, 0) + # detections / annotations must be correct type with self.assertRaises(TypeError): cemgil(None, ANNOTATIONS, 0.04) with self.assertRaises(TypeError): cemgil(DETECTIONS, None, 0.04) - # tolerance must be correct type - score = cemgil(DETECTIONS, ANNOTATIONS, int(1)) - self.assertIsInstance(score, float) + # sigma must be correct type with self.assertRaises(TypeError): cemgil(DETECTIONS, ANNOTATIONS, [0.04]) with self.assertRaises(TypeError): @@ -326,13 +342,6 @@ def test_types(self): with self.assertRaises(TypeError): cemgil(DETECTIONS, ANNOTATIONS, {0.04: 0}) - def test_errors(self): - # sigma must be greater than 0 - with self.assertRaises(ValueError): - cemgil(DETECTIONS, ANNOTATIONS, 0) - with self.assertRaises(ValueError): - cemgil(DETECTIONS, ANNOTATIONS, None) - def test_values(self): # two empty sequences should have a perfect score score = cemgil([], [], 0.04) @@ -359,25 +368,21 @@ def test_types(self): self.assertIsInstance(score, float) score = goto({}, {}, 0.175, 0.2, 0.2) self.assertIsInstance(score, float) - with self.assertRaises(TypeError): - goto(None, ANNOTATIONS, 0.175, 0.2, 0.2) - with self.assertRaises(TypeError): - goto(DETECTIONS, None, 0.175, 0.2, 0.2) # parameters must be correct type - score = goto(DETECTIONS, ANNOTATIONS, int(0.175), 0.2, 0.2) + score = goto(DETECTIONS, ANNOTATIONS, int(1.175), 0.2, 0.2) self.assertIsInstance(score, float) - score = goto(DETECTIONS, ANNOTATIONS, 0.175, int(0.2), 0.2) + score = goto(DETECTIONS, ANNOTATIONS, 0.175, int(1.2), 0.2) self.assertIsInstance(score, float) - score = goto(DETECTIONS, ANNOTATIONS, 0.175, 0.2, int(0.2)) + score = goto(DETECTIONS, ANNOTATIONS, 0.175, 0.2, int(1.2)) self.assertIsInstance(score, float) def test_errors(self): # parameters must not be None - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): goto(DETECTIONS, ANNOTATIONS, None, 0.2, 0.2) - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): goto(DETECTIONS, ANNOTATIONS, 0.175, None, 0.2) - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): goto(DETECTIONS, ANNOTATIONS, 0.175, 0.2, None) # parameters must be positive with self.assertRaises(ValueError): @@ -386,6 +391,11 @@ def test_errors(self): goto(DETECTIONS, ANNOTATIONS, 0.175, -1, 0.2) with self.assertRaises(ValueError): goto(DETECTIONS, ANNOTATIONS, 0.175, 0.2, -1) + # detections / annotations must be correct type + with self.assertRaises(TypeError): + goto(None, ANNOTATIONS, 0.175, 0.2, 0.2) + with self.assertRaises(TypeError): + goto(DETECTIONS, None, 0.175, 0.2, 0.2) # score relies on intervals, hence at least 2 annotations must be given with self.assertRaises(BeatIntervalError): goto(DETECTIONS, [1], 0.175, 0.2, 0.2) @@ -426,26 +436,31 @@ def test_types(self): cmlc, cmlt = cml({}, {}, 0.175, 0.175) self.assertIsInstance(cmlc, float) self.assertIsInstance(cmlt, float) - with self.assertRaises(TypeError): - cml(None, ANNOTATIONS, 0.175, 0.175) - with self.assertRaises(TypeError): - cml(DETECTIONS, None, 0.175, 0.175) # tolerances must be correct type cmlc, cmlt = cml(DETECTIONS, ANNOTATIONS, int(1), int(1)) self.assertIsInstance(cmlc, float) self.assertIsInstance(cmlt, float) - cmlc, cmlt = cml(DETECTIONS, ANNOTATIONS, [0.175], [0.175]) - self.assertIsInstance(cmlc, float) - self.assertIsInstance(cmlt, float) with self.assertRaises(TypeError): cml(DETECTIONS, ANNOTATIONS, {}, {}) + with self.assertRaises(TypeError): + cml(DETECTIONS, ANNOTATIONS, [0.175], [0.175]) def test_errors(self): + # tolerances must not be None + with self.assertRaises(TypeError): + cml(DETECTIONS, ANNOTATIONS, 0.1, None) + with self.assertRaises(TypeError): + cml(DETECTIONS, ANNOTATIONS, None, 0.1) # tolerances must be greater than 0 with self.assertRaises(ValueError): - cml(DETECTIONS, ANNOTATIONS, 0, None) + cml(DETECTIONS, ANNOTATIONS, 0, 1) with self.assertRaises(ValueError): - cml(DETECTIONS, ANNOTATIONS, None, 0) + cml(DETECTIONS, ANNOTATIONS, 1, 0) + # detections / annotations must be correct type + with self.assertRaises(TypeError): + cml(None, ANNOTATIONS, 0.175, 0.175) + with self.assertRaises(TypeError): + cml(DETECTIONS, None, 0.175, 0.175) # score relies on intervals, hence at least 2 ann/det must be given with self.assertRaises(BeatIntervalError): cml(DETECTIONS, [1.], 0.175, 0.175) @@ -487,10 +502,6 @@ def test_types(self): self.assertIsInstance(cmlt, float) self.assertIsInstance(amlc, float) self.assertIsInstance(amlt, float) - with self.assertRaises(TypeError): - continuity(None, ANNOTATIONS, 0.175, 0.175) - with self.assertRaises(TypeError): - continuity(DETECTIONS, None, 0.175, 0.175) # tolerances must be correct type scores = continuity(DETECTIONS, ANNOTATIONS, int(1), int(1)) cmlc, cmlt, amlc, amlt = scores @@ -498,21 +509,36 @@ def test_types(self): self.assertIsInstance(cmlt, float) self.assertIsInstance(amlc, float) self.assertIsInstance(amlt, float) - scores = continuity(DETECTIONS, ANNOTATIONS, [0.175], [0.175]) - cmlc, cmlt, amlc, amlt = scores - self.assertIsInstance(cmlc, float) - self.assertIsInstance(cmlt, float) - self.assertIsInstance(amlc, float) - self.assertIsInstance(amlt, float) - with self.assertRaises(TypeError): - continuity(DETECTIONS, ANNOTATIONS, {}, {}) def test_errors(self): + # tolerances must not be None + with self.assertRaises(TypeError): + continuity(DETECTIONS, ANNOTATIONS, 0.1, None) + with self.assertRaises(TypeError): + continuity(DETECTIONS, ANNOTATIONS, None, 0.1) # tolerances must be greater than 0 with self.assertRaises(ValueError): - continuity(DETECTIONS, ANNOTATIONS, 0, None) + continuity(DETECTIONS, ANNOTATIONS, 1, 0) with self.assertRaises(ValueError): - continuity(DETECTIONS, ANNOTATIONS, None, 0) + continuity(DETECTIONS, ANNOTATIONS, 0, 1) + # tolerances must be correct type + with self.assertRaises(TypeError): + continuity(DETECTIONS, ANNOTATIONS, [0.175], 1) + with self.assertRaises(TypeError): + continuity(DETECTIONS, ANNOTATIONS, 1, [0.175]) + with self.assertRaises(TypeError): + continuity(DETECTIONS, ANNOTATIONS, None, 1) + with self.assertRaises(TypeError): + continuity(DETECTIONS, ANNOTATIONS, 1, None) + with self.assertRaises(TypeError): + continuity(DETECTIONS, ANNOTATIONS, {}, 1) + with self.assertRaises(TypeError): + continuity(DETECTIONS, ANNOTATIONS, 1, {}) + # detections / annotations must be correct type + with self.assertRaises(TypeError): + continuity(None, ANNOTATIONS, 0.175, 0.175) + with self.assertRaises(TypeError): + continuity(DETECTIONS, None, 0.175, 0.175) # score relies on intervals, hence at least 2 ann/det must be given with self.assertRaises(BeatIntervalError): continuity(DETECTIONS, [1.], 0.175, 0.175) @@ -754,10 +780,6 @@ def test_types(self): ig, histogram = information_gain({}, {}, 40) self.assertIsInstance(ig, float) self.assertIsInstance(histogram, np.ndarray) - with self.assertRaises(TypeError): - information_gain(None, ANNOTATIONS, 40) - with self.assertRaises(TypeError): - information_gain(DETECTIONS, None, 40) # tolerances must be correct type ig, histogram = information_gain(DETECTIONS, ANNOTATIONS, 40) self.assertIsInstance(ig, float) @@ -767,6 +789,19 @@ def test_types(self): self.assertIsInstance(histogram, np.ndarray) def test_errors(self): + # num_bins must not be None + with self.assertRaises(TypeError): + information_gain(DETECTIONS, ANNOTATIONS, None) + # num_bins must be correct type + with self.assertRaises(TypeError): + information_gain(DETECTIONS, ANNOTATIONS, [10]) + with self.assertRaises(TypeError): + information_gain(DETECTIONS, ANNOTATIONS, {10}) + # detections / annotations must be correct type + with self.assertRaises(TypeError): + information_gain(None, ANNOTATIONS, 40) + with self.assertRaises(TypeError): + information_gain(DETECTIONS, None, 40) # score relies on intervals, hence at least 2 annotations must be given with self.assertRaises(BeatIntervalError): information_gain([1.], ANNOTATIONS, 4) @@ -925,7 +960,7 @@ def test_results(self): self.assertTrue(np.allclose(e.error_histogram, error_histogram_)) def test_tostring(self): - print BeatEvaluation([], []) + print(BeatEvaluation([], [])) class TestBeatMeanEvaluationClass(unittest.TestCase): @@ -1022,4 +1057,4 @@ def test_results(self): self.assertEqual(len(e), 2) def test_tostring(self): - print BeatMeanEvaluation([]) + print(BeatMeanEvaluation([])) diff --git a/tests/test_evaluation_notes.py b/tests/test_evaluation_notes.py index b5c1f33e6..365d7832e 100644 --- a/tests/test_evaluation_notes.py +++ b/tests/test_evaluation_notes.py @@ -1,12 +1,13 @@ # encoding: utf-8 +# pylint: skip-file """ This file contains tests for the madmom.evaluation.notes module. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import unittest -import __builtin__ import math from madmom.evaluation.notes import * @@ -40,7 +41,7 @@ def test_load_notes_from_file(self): self.assertIsInstance(annotations, np.ndarray) def test_load_notes_from_file_handle(self): - file_handle = __builtin__.open(DATA_PATH + 'stereo_sample.notes') + file_handle = open(DATA_PATH + 'stereo_sample.notes') annotations = load_notes(file_handle) self.assertIsInstance(annotations, np.ndarray) file_handle.close() @@ -224,7 +225,7 @@ def test_results(self): np.std([0, 0.014, -0.001, 0]))) def test_tostring(self): - print NoteEvaluation([], []) + print(NoteEvaluation([], [])) class TestNoteSumEvaluationClass(unittest.TestCase): @@ -289,7 +290,7 @@ def test_results(self): self.assertEqual(e.std_error, e2.std_error) def test_tostring(self): - print NoteSumEvaluation([]) + print(NoteSumEvaluation([])) class TestNoteMeanEvaluationClass(unittest.TestCase): @@ -364,4 +365,4 @@ def test_results(self): self.assertEqual(e.std_error, e2.std_error) def test_tostring(self): - print NoteMeanEvaluation([]) + print(NoteMeanEvaluation([])) diff --git a/tests/test_evaluation_onsets.py b/tests/test_evaluation_onsets.py index 146db5532..00b944042 100644 --- a/tests/test_evaluation_onsets.py +++ b/tests/test_evaluation_onsets.py @@ -1,9 +1,11 @@ # encoding: utf-8 +# pylint: skip-file """ This file contains tests for the madmom.evaluation.onsets module. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import unittest import math @@ -27,6 +29,23 @@ def test_values(self): # test evaluation function class TestOnsetEvaluationFunction(unittest.TestCase): + def test_errors(self): + # detections / annotations must not be None + with self.assertRaises(TypeError): + onset_evaluation(None, ANNOTATIONS) + with self.assertRaises(TypeError): + onset_evaluation(DETECTIONS, None) + # tolerance must be > 0 + with self.assertRaises(ValueError): + onset_evaluation(DETECTIONS, ANNOTATIONS, 0) + # tolerance must be correct type + with self.assertRaises(TypeError): + onset_evaluation(DETECTIONS, ANNOTATIONS, None) + with self.assertRaises(TypeError): + onset_evaluation(DETECTIONS, ANNOTATIONS, []) + with self.assertRaises(TypeError): + onset_evaluation(DETECTIONS, ANNOTATIONS, {}) + def test_results(self): # default window tp, fp, tn, fn, errors = onset_evaluation(DETECTIONS, ANNOTATIONS) @@ -148,7 +167,7 @@ def test_results(self): self.assertEqual(e.std_error, std) def test_tostring(self): - print OnsetEvaluation([], []) + print(OnsetEvaluation([], [])) class TestOnsetSumEvaluationClass(unittest.TestCase): @@ -218,7 +237,7 @@ def test_results(self): self.assertEqual(e.std_error, e2.std_error) def test_tostring(self): - print OnsetSumEvaluation([]) + print(OnsetSumEvaluation([])) class TestOnsetMeanEvaluationClass(unittest.TestCase): @@ -298,4 +317,4 @@ def test_results(self): np.mean([e_.std_error for e_ in [e2, e3]])) def test_tostring(self): - print OnsetMeanEvaluation([]) + print(OnsetMeanEvaluation([])) diff --git a/tests/test_evaluation_tempo.py b/tests/test_evaluation_tempo.py index 2daa3cba9..48ba7e9e3 100644 --- a/tests/test_evaluation_tempo.py +++ b/tests/test_evaluation_tempo.py @@ -1,12 +1,13 @@ # encoding: utf-8 +# pylint: skip-file """ This file contains tests for the madmom.evaluation.tempo module. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import unittest -import __builtin__ import math from madmom.evaluation.tempo import * @@ -28,7 +29,7 @@ def test_load_tempo_from_file(self): self.assertIsInstance(annotations, np.ndarray) def test_load_tempo_from_file_handle(self): - file_handle = __builtin__.open(DATA_PATH + 'sample.tempo') + file_handle = open(DATA_PATH + 'sample.tempo') annotations = load_tempo(file_handle) self.assertIsInstance(annotations, np.ndarray) file_handle.close() @@ -158,33 +159,28 @@ def test_types(self): self.assertIsInstance(scores, tuple) scores = tempo_evaluation({}, {}) self.assertIsInstance(scores, tuple) - # # we do not support normal non-empty lists - # with self.assertRaises(TypeError): - # tempo_evaluation(DETECTIONS.tolist(), ANNOTATIONS.tolist(), 0.08) - # detections must not be None + # tolerance must be correct type + scores = tempo_evaluation(DETECTIONS, ANNOTATIONS, int(1.2)) + self.assertIsInstance(scores, tuple) + + def test_errors(self): + # detections / annotations must not be None with self.assertRaises(TypeError): tempo_evaluation(None, ANN_TEMPI) - # annotations must not be None with self.assertRaises(TypeError): tempo_evaluation(DETECTIONS, None) - # tolerance must be correct type - scores = tempo_evaluation(DETECTIONS, ANNOTATIONS, int(1.2)) - self.assertIsInstance(scores, tuple) - # various not supported versions + # tolerance must be > 0 with self.assertRaises(ValueError): + tempo_evaluation(DETECTIONS, ANNOTATIONS, 0) + # tolerance must be correct type + with self.assertRaises(TypeError): + tempo_evaluation(DETECTIONS, ANN_TEMPI, None) + with self.assertRaises(TypeError): tempo_evaluation(DETECTIONS, ANN_TEMPI, []) - # with self.assertRaises(ValueError): - # tempo_evaluation(DETECTIONS, ANN_TEMPI) - # TODO: what should happen if we supply a dictionary? - # with self.assertRaises(TypeError): - # tempo_evaluation(DETECTIONS, ANN_TEMPI, ANN_STRENGTHS, {}) + with self.assertRaises(TypeError): + tempo_evaluation(DETECTIONS, ANN_TEMPI, {}) def test_values(self): - # tolerance must be > 0 - with self.assertRaises(ValueError): - tempo_evaluation(DETECTIONS, ANNOTATIONS, 0) - with self.assertRaises(ValueError): - tempo_evaluation(DETECTIONS, ANNOTATIONS, None) # no tempi should return perfect score scores = tempo_evaluation([], []) self.assertEqual(scores, (1, True, True)) @@ -319,7 +315,7 @@ def test_results_no_triple(self): self.assertEqual(e.acc2, False) def test_tostring(self): - print TempoEvaluation([], []) + print(TempoEvaluation([], [])) class TestMeanTempoEvaluationClass(unittest.TestCase): @@ -361,4 +357,4 @@ def test_results(self): self.assertEqual(len(e), 2) def test_tostring(self): - print TempoMeanEvaluation([]) + print(TempoMeanEvaluation([])) diff --git a/tests/test_ml_hmm.py b/tests/test_ml_hmm.py index 5a4ba48ec..fb12992cc 100644 --- a/tests/test_ml_hmm.py +++ b/tests/test_ml_hmm.py @@ -1,9 +1,11 @@ # encoding: utf-8 +# pylint: skip-file """ This file contains tests for the madmom.ml.hmm module. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import unittest from madmom.ml.hmm import * @@ -20,44 +22,43 @@ (2, 2, 0.7)] OBS_PROB = np.array([[0.7, 0.15, 0.15], - [0.3, 0.5, 0.2], - [0.2, 0.4, 0.4]]) + [0.3, 0.5, 0.2], + [0.2, 0.4, 0.4]]) OBS_SEQ = np.array([0, 2, 2, 0, 0, 1, 1, 2, 0, 2, 1, 1, 1, 2, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0]) -CORRECT_FWD = np.array( - [[0.6754386, 0.23684211, 0.0877193], - [0.369291, 0.36798608, 0.26272292], - [0.18146746, 0.33625874, 0.4822738], - [0.35097423, 0.37533682, 0.27368895], - [0.51780506, 0.32329768, 0.15889725], - [0.17366244, 0.58209473, 0.24424283], - [0.06699296, 0.58957189, 0.34343515], - [0.05708114, 0.3428725, 0.60004636], - [0.18734426, 0.43567034, 0.3769854], - [0.09699435, 0.31882203, 0.58418362], - [0.03609747, 0.47711943, 0.4867831], - [0.02569311, 0.52002881, 0.45427808], - [0.02452257, 0.53259115, 0.44288628], - [0.03637171, 0.31660931, 0.64701899], - [0.02015006, 0.46444741, 0.51540253], - [0.02118133, 0.51228818, 0.46653049], - [0.16609052, 0.48889238, 0.3450171], - [0.06141349, 0.55365814, 0.38492837], - [0.2327641, 0.47273564, 0.29450026], - [0.42127593, 0.37947727, 0.1992468], - [0.57132392, 0.30444215, 0.12423393], - [0.66310201, 0.25840843, 0.07848956], - [0.23315472, 0.59876843, 0.16807684], - [0.43437318, 0.40024174, 0.16538507], - [0.58171672, 0.30436365, 0.11391962]]) +CORRECT_FWD = np.array([[0.6754386, 0.23684211, 0.0877193], + [0.369291, 0.36798608, 0.26272292], + [0.18146746, 0.33625874, 0.4822738], + [0.35097423, 0.37533682, 0.27368895], + [0.51780506, 0.32329768, 0.15889725], + [0.17366244, 0.58209473, 0.24424283], + [0.06699296, 0.58957189, 0.34343515], + [0.05708114, 0.3428725, 0.60004636], + [0.18734426, 0.43567034, 0.3769854], + [0.09699435, 0.31882203, 0.58418362], + [0.03609747, 0.47711943, 0.4867831], + [0.02569311, 0.52002881, 0.45427808], + [0.02452257, 0.53259115, 0.44288628], + [0.03637171, 0.31660931, 0.64701899], + [0.02015006, 0.46444741, 0.51540253], + [0.02118133, 0.51228818, 0.46653049], + [0.16609052, 0.48889238, 0.3450171], + [0.06141349, 0.55365814, 0.38492837], + [0.2327641, 0.47273564, 0.29450026], + [0.42127593, 0.37947727, 0.1992468], + [0.57132392, 0.30444215, 0.12423393], + [0.66310201, 0.25840843, 0.07848956], + [0.23315472, 0.59876843, 0.16807684], + [0.43437318, 0.40024174, 0.16538507], + [0.58171672, 0.30436365, 0.11391962]]) class TestHmmInference(unittest.TestCase): def setUp(self): - frm, to, prob = zip(*TRANSITIONS) + frm, to, prob = list(zip(*TRANSITIONS)) tm = TransitionModel.from_dense(to, frm, prob) om = DiscreteObservationModel(OBS_PROB) @@ -79,5 +80,6 @@ def test_forward(self): self.assertTrue(np.allclose(fwd, CORRECT_FWD)) def test_forward_generator(self): - fwd = np.vstack(list(self.hmm.forward_generator(OBS_SEQ, block_size=5))) + fwd = np.vstack(list(self.hmm.forward_generator(OBS_SEQ, + block_size=5))) self.assertTrue(np.allclose(fwd, CORRECT_FWD)) diff --git a/tests/test_utils.py b/tests/test_utils.py index f362dd79b..ecdc2e50b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,12 +1,13 @@ # encoding: utf-8 +# pylint: skip-file """ This file contains test functions for the madmom.utils module. """ -# pylint: skip-file + +from __future__ import absolute_import, division, print_function import unittest -import __builtin__ from madmom.utils import * @@ -126,7 +127,7 @@ def test_read_events_from_file(self): self.assertIsInstance(events, np.ndarray) def test_read_events_from_file_handle(self): - file_handle = __builtin__.open(DATA_PATH + 'events.txt') + file_handle = open(DATA_PATH + 'events.txt') events = load_events(file_handle) self.assertIsInstance(events, np.ndarray) file_handle.close() @@ -159,7 +160,7 @@ def test_write_events_to_file(self): self.assertEqual(EVENTS, result) def test_write_events_to_file_handle(self): - file_handle = __builtin__.open(DATA_PATH + 'events.txt', 'w') + file_handle = open(DATA_PATH + 'events.txt', 'wb') result = write_events(EVENTS, file_handle) self.assertEqual(EVENTS, result) file_handle.close() @@ -205,8 +206,7 @@ def test_fps(self): quantized = quantize_events(EVENTS, 10, length=None) idx = np.nonzero(quantized)[0] # tar: [1, 1.02, 1.5, 2.0, 2.03, 2.05, 2.5, 3] - correct = [10, 15, 20, 21, 25, 30] - self.assertTrue(np.allclose(idx, correct)) + self.assertTrue(np.allclose(idx, [10, 15, 20, 25, 30])) # 100 FPS quantized = quantize_events(EVENTS, 100, length=None) idx = np.nonzero(quantized)[0] @@ -219,36 +219,31 @@ def test_length(self): quantized = quantize_events(EVENTS, 100, length=280) idx = np.nonzero(quantized)[0] # targets: [1, 1.02, 1.5, 2.0, 2.03, 2.05, 2.5, 3] - correct = [100, 102, 150, 200, 203, 205, 250] - self.assertTrue(np.allclose(idx, correct)) + self.assertTrue(np.allclose(idx, [100, 102, 150, 200, 203, 205, 250])) def test_rounding(self): # without length quantized = quantize_events([3.95], 10, length=None) idx = np.nonzero(quantized)[0] - correct = [40] - self.assertTrue(np.allclose(idx, correct)) + self.assertTrue(np.allclose(idx, [40])) # with length - quantized = quantize_events([3.95], 10, length=40) + quantized = quantize_events([3.95], 10, length=39) idx = np.nonzero(quantized)[0] - correct = [] - self.assertTrue(np.allclose(idx, correct)) + self.assertTrue(np.allclose(idx, [])) # round down with length quantized = quantize_events([3.9499999], 10, length=40) idx = np.nonzero(quantized)[0] - correct = [39] - self.assertTrue(np.allclose(idx, correct)) + self.assertTrue(np.allclose(idx, [39])) def test_shift(self): # no length quantized = quantize_events(EVENTS, 10, shift=1) idx = np.nonzero(quantized)[0] - correct = [20, 25, 30, 31, 35, 40] - self.assertTrue(np.allclose(idx, correct)) + self.assertTrue(np.allclose(idx, [20, 25, 30, 35, 40])) # limited length quantized = quantize_events(EVENTS, 10, shift=1, length=35) idx = np.nonzero(quantized)[0] - correct = [20, 25, 30, 31] + correct = [20, 25, 30] self.assertTrue(np.allclose(idx, correct))