diff --git a/README.md b/README.md index 883912b..7503a10 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,8 @@ This is a test and benchmark suite for MPI. The intention is to not rely on an entire library, but only incremental steps. -Lemons are both tests that fail, MPI functions that don't do the correct thing and -bad performance implementations. +Lemons are both tests that fail, MPI functions that don't do the correct thing +and bad performance implementations. # Usage @@ -17,7 +17,8 @@ python -m lemonspotter [path_to_database] ``` # Requirements -To run Lemonspotter only python 3.8.0 is needed. To contribute to development, please install all packages listed in the `requirements.txt` file. +To run Lemonspotter only python 3.8.0 is needed. To contribute to development, +please install all packages listed in the `requirements.txt` file. ## Arguments @@ -26,5 +27,3 @@ To run Lemonspotter only python 3.8.0 is needed. To contribute to development, p #### Print Lemonspotter Version ```-v, --version``` - - diff --git a/lemonspotter/__main__.py b/lemonspotter/__main__.py index 5d07911..019b966 100644 --- a/lemonspotter/__main__.py +++ b/lemonspotter/__main__.py @@ -109,6 +109,15 @@ def set_logging_level(log_level: str): logging.basicConfig(level=numeric_level) +def check_version() -> None: + """ + Check that current version of Python runtime is at least 3.8. + """ + + if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 8): + raise RuntimeError('Python 3.8 is required to run LemonSpotter.') + + def main() -> None: """ This function is the workflow of LemonSpotter. diff --git a/lemonspotter/core/argument.py b/lemonspotter/core/argument.py new file mode 100644 index 0000000..fbabe8b --- /dev/null +++ b/lemonspotter/core/argument.py @@ -0,0 +1,52 @@ +""" +This module contains the class definition of an Argument. +""" + +from typing import Sequence, Optional + +from lemonspotter.core.variable import Variable + + +class Argument: + """ + This class represents parameters of Function objects translated from the specification. + """ + + # TODO Arguments need to also support code fragements? + # some dependencies will require executing additional code. + + def __init__(self, + variable: Variable, + dependencies: Optional[Sequence[Variable]] = None): + self._variable = variable + self._dependencies = dependencies + + def __str__(self) -> str: + return self._variable.name + + def __repr__(self) -> str: + return self._variable.name + + @property + def variable(self) -> Variable: + """ + """ + + return self._variable + + @property + def has_dependencies(self) -> bool: + """ + """ + + return self._dependencies is not None + + @property + def dependencies(self) -> Sequence[Variable]: + """ + """ + + if self._dependencies is None: + raise RuntimeError('No dependencies for this argument.') + + return self._dependencies diff --git a/lemonspotter/core/function.py b/lemonspotter/core/function.py index 3a51a8e..0bec0e6 100644 --- a/lemonspotter/core/function.py +++ b/lemonspotter/core/function.py @@ -2,12 +2,12 @@ This module defines the function class which respresents functions from the specification. """ -from typing import Mapping, Any, AbstractSet, Sequence, Iterable -from functools import lru_cache +from typing import Mapping, Any, AbstractSet, Sequence, Iterable, Optional from lemonspotter.core.database import Database from lemonspotter.core.type import Type from lemonspotter.core.parameter import Parameter +from lemonspotter.core.parameter import Direction class Function: @@ -22,6 +22,12 @@ def __init__(self, json: Mapping[str, Any]) -> None: self._json: Mapping[str, Any] = json self.properties: Mapping[str, Any] = {} + # cache + self._parameters: Optional[Sequence[Parameter]] = None + self._in_parameters: Optional[Sequence[Parameter]] = None + self._inout_parameters: Optional[Sequence[Parameter]] = None + self._out_parameters: Optional[Sequence[Parameter]] = None + def __repr__(self) -> str: """ Defines informal string behavior for function type @@ -41,7 +47,7 @@ def name(self) -> str: """This property provides access to the Function name.""" if 'name' not in self._json: - raise Exception('Function name is not in JSON.') + raise RuntimeError('Function name is not in JSON.') return self._json['name'] @@ -49,70 +55,126 @@ def name(self) -> str: def has_parameters(self) -> bool: """""" - return self._json.get('parameters', False) + return 'parameters' in self._json + + @property + def has_in_parameters(self) -> bool: + """""" + + return len(self.in_parameters) > 0 + + @property + def has_inout_parameters(self) -> bool: + """""" + + return len(self.inout_parameters) > 0 + + @property + def has_out_parameters(self) -> bool: + """""" + + return len(self.out_parameters) > 0 @property # type: ignore - @lru_cache() def parameters(self) -> Sequence[Parameter]: """This property provides access to the parameter list of this Function object.""" if 'parameters' not in self._json: - raise Exception('Parameters are not in JSON.') + raise RuntimeError('Parameters are not in JSON.') + + if self._parameters is None: + self._parameters = tuple(Parameter(parameter) + for parameter in self._json['parameters']) + + return self._parameters + + @property + def in_parameters(self) -> Sequence[Parameter]: + """ + Find all parameters which are with IN parameters. + """ + + if self._in_parameters is None: + self._in_parameters = tuple(parameter + for parameter in self.parameters + if parameter.direction is Direction.IN) + + return self._in_parameters + + @property + def inout_parameters(self) -> Sequence[Parameter]: + """ + Find all parameters which are with INOUT parameters. + """ + + if self._inout_parameters is None: + self._inout_parameters = tuple(parameter + for parameter in self.parameters + if parameter.direction is Direction.INOUT) + + return self._inout_parameters + + @property + def out_parameters(self) -> Sequence[Parameter]: + """ + Find all parameters which are either OUT parameters. + """ + + if self._out_parameters is None: + self._out_parameters = tuple(parameter + for parameter in self.parameters + if parameter.direction is Direction.OUT) - return tuple(Parameter(parameter) for parameter in self._json['parameters']) + return self._out_parameters @property def return_type(self) -> Type: """This property provides the Type object of the return of this Function.""" if 'return' not in self._json: - raise Exception('Return is not in JSON.') + raise RuntimeError('Return is not in JSON.') return Database().get_type(self._json['return']) @property # type: ignore - @lru_cache() def needs_any(self) -> AbstractSet['Function']: """This property provides access to the any set of needed Function objects.""" if 'needs_any' not in self._json: - raise Exception('Needs any is not in JSON.') + raise RuntimeError('Needs any is not in JSON.') subset = filter(lambda name: Database().has_function(name), self._json['needs_any']) return set(Database().get_function(func_name) for func_name in subset) @property # type: ignore - @lru_cache() def needs_all(self) -> AbstractSet['Function']: """This property provides access to the all set of needed Function objects.""" if 'needs_all' not in self._json: - raise Exception('Needs all is not in JSON.') + raise RuntimeError('Needs all is not in JSON.') subset = filter(lambda name: Database().has_function(name), self._json['needs_all']) return set(Database().get_function(func_name) for func_name in subset) @property # type: ignore - @lru_cache() def leads_any(self) -> AbstractSet['Function']: """This property provides access to the any set of lead Function objects.""" if 'leads_any' not in self._json: - raise Exception('Leads any is not in JSON.') + raise RuntimeError('Leads any is not in JSON.') subset = filter(lambda name: Database().has_function(name), self._json['leads_any']) return set(Database().get_function(func_name) for func_name in subset) @property # type: ignore - @lru_cache() def leads_all(self) -> AbstractSet['Function']: """This property provides access to the all set of lead the Function objects.""" if 'leads_all' not in self._json: - raise Exception('Leads all is not in JSON.') + raise RuntimeError('Leads all is not in JSON.') subset = filter(lambda name: Database().has_function(name), self._json['leads_all']) diff --git a/lemonspotter/core/parameter.py b/lemonspotter/core/parameter.py index 6b967b7..96cb528 100644 --- a/lemonspotter/core/parameter.py +++ b/lemonspotter/core/parameter.py @@ -10,6 +10,9 @@ class Direction(Enum): + """ + """ + IN = 'in' OUT = 'out' INOUT = 'inout' @@ -28,7 +31,7 @@ def name(self) -> str: """This property provides the name of the Parameter.""" if 'name' not in self._json: - raise Exception('Name is not in JSON.') + raise RuntimeError('Name is not in JSON.') return self._json['name'] @@ -37,7 +40,7 @@ def type(self) -> Type: """This property provides the Type object of the abstract type of this Parameter.""" if 'abstract_type' not in self._json: - raise Exception('Abstract type is not in JSON.') + raise RuntimeError('Abstract type is not in JSON.') return Database().get_type(self._json['abstract_type']) @@ -46,7 +49,7 @@ def direction(self) -> Direction: """This property provides the direction of the Parameter.""" if 'direction' not in self._json: - raise Exception('Direction is not in JSON.') + raise RuntimeError('Direction is not in JSON.') return Direction(self._json['direction']) diff --git a/lemonspotter/core/partition.py b/lemonspotter/core/partition.py index f6d3135..b04218e 100644 --- a/lemonspotter/core/partition.py +++ b/lemonspotter/core/partition.py @@ -50,6 +50,7 @@ def symbol(cls, operand: 'Operand') -> str: return symbols[operand] +# TODO should be multiple subclasses of Partition class PartitionType(Enum): NUMERIC = 'numeric' LITERAL = 'literal' diff --git a/lemonspotter/core/sample.py b/lemonspotter/core/sample.py index 97e3205..c071a32 100644 --- a/lemonspotter/core/sample.py +++ b/lemonspotter/core/sample.py @@ -1,10 +1,14 @@ """ +Defines the FunctionSample class. """ -from typing import Optional, Sequence, Callable, Iterable +from typing import Optional, Callable, Mapping, MutableMapping import logging +from itertools import chain from lemonspotter.core.function import Function +from lemonspotter.core.parameter import Parameter +from lemonspotter.core.argument import Argument from lemonspotter.core.parameter import Direction from lemonspotter.core.variable import Variable from lemonspotter.core.statement import (ConditionStatement, @@ -16,68 +20,56 @@ class FunctionSample: - """""" + """ + Container for a function and arguments. + """ def __init__(self, function: Function, valid: bool, - variables: Iterable[Variable] = set(), - arguments: Sequence[Variable] = [], + arguments: Optional[MutableMapping[str, Argument]] = None, evaluator: Optional[Callable[[], bool]] = None, ) -> None: self._function = function self._valid = valid - self._variables = variables - self._arguments = arguments + self._arguments: MutableMapping[str, Argument] = arguments if arguments else {} self._evaluator = evaluator - self._return_variable: Variable = Variable(function.return_type) + # generate out variables, return, and out parameters + self._return_variable = Variable(function.return_type, + 'return_' + function.name) + self._generate_out_arguments() @property def function(self) -> Function: - """""" + """ + Retrieve the Function object which this FunctionSample is a sample of. + """ return self._function @property def return_variable(self) -> Variable: - """""" + """ + Retrieve the Variable instance which is the return variable of this + FunctionSample instance. + """ return self._return_variable @property - def arguments(self) -> Sequence[Variable]: - """""" + def arguments(self) -> Mapping[str, Argument]: + """ + Gets the arguments from this FunctionSample. + """ return self._arguments - @arguments.setter - def arguments(self, arguments: Sequence[Variable]) -> None: - """""" - - if arguments is None: - raise RuntimeError('Arguments given to Sample are None.') - - self._arguments = arguments - - @property - def variables(self) -> Iterable[Variable]: - """""" - - return self._variables - - @variables.setter - def variables(self, variables: Iterable[Variable]) -> None: - """""" - - if variables is None: - raise RuntimeError('Variables given to Sample is None.') - - self._variables = variables - @property def evaluator(self) -> Callable[[], bool]: - """""" + """ + Gets the evaluator function for this FunctionSample. + """ if self._evaluator is None: raise RuntimeError('Evaluator is None. Needs to be assigned.') @@ -86,7 +78,9 @@ def evaluator(self) -> Callable[[], bool]: @evaluator.setter def evaluator(self, evaluator: Callable[[], bool]) -> None: - """""" + """ + Sets the evaluator function for this FunctionSample. + """ if evaluator is None: raise RuntimeError('Evaluator given to Sample is None.') @@ -94,46 +88,56 @@ def evaluator(self, evaluator: Callable[[], bool]) -> None: self._evaluator = evaluator def generate_source(self, source: BlockStatement, comment: str = None) -> None: - """""" - - # assign predefined arguments and check for collisions - def check_argument(arg): - if arg.predefined: - predef = source.get_variable(arg.value) - - if predef is None: - raise RuntimeError('Predefined variable not present in source.') + """ + Expresses the FunctionSample as a FunctionStatement and adds it to the given + Source. + """ - if predef.type is arg.type: - return predef + # add arguments to source + logging.debug('Arguments: %s', self.arguments) + logging.debug('Source variables %s', source._variables) + for parameter in self.function.parameters: + # fetch argument + argument = self.arguments[parameter.name] + + # add dependent variables to source + # TODO + # same procedure as below? + + # add argument variable to source + if argument.variable.predefined: + # argument.variable.value has to match a source variable + # get variable from source + source_variable = source.get_variable(argument.variable.name) + if source_variable is None: + raise RuntimeError('Predefined argument %s not found in source %s.', + argument.variable.name, + source._variables) + + if source_variable is argument.variable.type: + # variable found is identical to required + pass else: - # try referencing - if predef.type.referencable: - # create variable - var = Variable(predef.type.reference(), - f'{predef.name}_ref', - f'&{predef.name}') - - return var + if source_variable.type.referencable: + transition_variable = Variable(source_variable.type.reference(), + f'{source_variable.name}_ref', + f'&{source_variable.name}') + source.add_at_start(DeclarationAssignmentStatement(transition_variable)) + # this variable needs to be referenced in the function statement! - elif source.get_variable(arg.name) is not None: - raise RuntimeError('Name collision found between argument and variable.') + self._arguments[parameter.name] = Argument(transition_variable) - else: - return arg + else: + raise NotImplementedError('variable of predefined found, but not same ' + 'type.') - self.arguments = [check_argument(argument) for argument in self.arguments] - - # add arguments to source - for variable in self.arguments: - existing = source.get_variable(variable.name) - if not existing: - if variable.value: - source.add_at_start(DeclarationAssignmentStatement(variable)) + elif not source.get_variable(argument.variable.name): + if argument.variable.value: + source.add_at_start(DeclarationAssignmentStatement(argument.variable)) else: - source.add_at_start(DeclarationStatement(variable)) + source.add_at_start(DeclarationStatement(argument.variable)) # add function call to source source.add_at_start(self._generate_statement(source, comment)) @@ -141,9 +145,9 @@ def check_argument(arg): # add outputs source.add_at_start(FunctionStatement.generate_print(self._return_variable)) - for parameter, argument in zip(self.function.parameters, self.arguments): # type: ignore - if parameter.direction is Direction.OUT or parameter.direction is Direction.INOUT: - source.add_at_start(FunctionStatement.generate_print(argument)) + for parameter in chain(self.function.out_parameters, self.function.inout_parameters): + source.add_at_start(FunctionStatement.generate_print( + self.arguments[parameter.name].variable)) # add check statements to call source.add_at_start(self._generate_return_check()) @@ -178,42 +182,60 @@ def _generate_statement(self, Generates a compilable expression of the function with the given arguments. """ - self.return_variable.name = f'return_{self.function.name}' - if source.get_variable(self.return_variable.name) is not None: # todo rename output return name, we have control over this above - raise NotImplementedError('Test if the variable already exists.') + raise NotImplementedError('Return variable already exists!.') + # TODO need to check out arguments as well - statement = (f'{self.function.return_type.language_type} {self.return_variable.name}' - f' = {self.function.name}(') - - # add arguments - logging.debug('arguments %s', str(self.arguments)) - logging.debug('parameters %s', str(self.function.parameters)) + statement = [] + statement.append((f'{self.function.return_type.language_type} ' + f'{self.return_variable.name} ' + f'= {self.function.name}(')) # fill arguments - if self.arguments is not None: - pairs = zip(self.arguments, self.function.parameters) # type: ignore - for idx, (argument, parameter) in enumerate(pairs): - mod = '' -# pointer_diff = argument.pointer_level - parameter.pointer_level -# if pointer_diff < 0: -# # addressof & -# mod += '&' -# -# elif pointer_diff > 0: -# # dereference * -# raise NotImplementedError - - statement += (mod + argument.name) + if self.function.has_parameters: + for idx, parameter in enumerate(self.function.parameters): + # get argument for parameter + argument = self.arguments[parameter.name] + + statement.append(argument.variable.name) # if not last argument then add comma - if (idx + 1) != len(self._arguments): - statement += ', ' + if (idx + 1) != len(self.function.parameters): + statement.append(', ') + + statement.append(');') - statement += ');' + logging.debug('function statement %s', ''.join(statement)) return FunctionStatement(self._function.name, - statement, + ''.join(statement), {self.return_variable.name: self.return_variable}, comment) + + def _generate_out_arguments(self) -> None: + """ + Generates appropriate out variables for the out parameters. + """ + + logging.debug('generating out arguments for %s', self.function.name) + logging.debug('out parameters %s', self.function.out_parameters) + + for parameter in self.function.out_parameters: + logging.debug('generating out argument for parameter %s', parameter.name) + + # generate out variable and its dependents + value = None + + # TODO this doesn't really capture it, some types are dereferencable, + # but need to be passed in as is + if parameter.type.abstract_type == 'STRING': + value = f'malloc({parameter.length} * sizeof(char))' + + elif parameter.type.dereferencable: + value = f'malloc(sizeof({parameter.type.dereference().language_type}))' + + self._arguments[parameter.name] = Argument(Variable( + parameter.type, + parameter.name + '_out', + value)) diff --git a/lemonspotter/core/sampler.py b/lemonspotter/core/sampler.py index 6beda07..0683245 100644 --- a/lemonspotter/core/sampler.py +++ b/lemonspotter/core/sampler.py @@ -3,7 +3,9 @@ from abc import ABC, abstractmethod from typing import Iterable +import logging +from lemonspotter.core.database import Database from lemonspotter.core.function import Function from lemonspotter.core.sample import FunctionSample @@ -14,4 +16,27 @@ class Sampler(ABC): @abstractmethod def generate_samples(self, function: Function) -> Iterable[FunctionSample]: - return [] + """ + Generate FunctionSample objects from a given function according to an + individual Samplers specification. + """ + + return set() + + @classmethod + def _generate_empty_sample(cls, function: Function) -> Iterable[FunctionSample]: + """ + Generate a FunctionSample for functions which don't have any parameters. + """ + + logging.debug('%s has no arguments.', function.name) + + sample = FunctionSample(function, True) + + def evaluator(sample=sample) -> bool: + return (sample.return_variable.value == + Database().get_constant('MPI_SUCCESS').value) + + sample.evaluator = evaluator + + return {sample} diff --git a/lemonspotter/core/statement.py b/lemonspotter/core/statement.py index f18482e..55da2ab 100644 --- a/lemonspotter/core/statement.py +++ b/lemonspotter/core/statement.py @@ -23,6 +23,10 @@ def __init__(self, variables: Dict[str, Variable] = None, comment: str = None) - self._statement: Optional[str] = None self._comment: str = comment.strip() if comment else '' + def __str__(self) -> str: + assert self._statement is not None + return self._statement + @property def comment(self) -> str: return self._comment @@ -236,7 +240,7 @@ def __init__(self, expression: str, comment: str = None) -> None: super().__init__(comment=comment) if not expression: - raise Exception('Expression passed to ReturnStatement is empty.') + raise RuntimeError('Expression passed to ReturnStatement is empty.') self._statement = f'return {expression};' diff --git a/lemonspotter/core/test.py b/lemonspotter/core/test.py index 43df3cc..9ba99b2 100644 --- a/lemonspotter/core/test.py +++ b/lemonspotter/core/test.py @@ -134,7 +134,7 @@ def source(self) -> Source: """This property provides the Source of the Test.""" if self._source is None: - raise Exception('Source is None.') + raise RuntimeError('Source is None.') return self._source diff --git a/lemonspotter/core/type.py b/lemonspotter/core/type.py index 9887114..5312f61 100644 --- a/lemonspotter/core/type.py +++ b/lemonspotter/core/type.py @@ -3,7 +3,6 @@ """ import logging -from functools import lru_cache from typing import TYPE_CHECKING, Mapping, Any, Iterable from lemonspotter.core.database import Database @@ -75,7 +74,6 @@ def constants(self) -> Iterable['Constant']: return Database().get_constants(self.abstract_type) @property # type: ignore - @lru_cache() def partitions(self) -> Iterable[Partition]: """""" @@ -102,7 +100,7 @@ def reference(self) -> 'Type': """""" if 'reference' not in self._json: - raise Exception(f'Type {self._json["name"]} cannot be referenced.') + raise RuntimeError(f'Type {self._json["name"]} cannot be referenced.') return Database().get_type(self._json['reference']) @@ -116,6 +114,6 @@ def dereference(self) -> 'Type': """""" if 'dereference' not in self._json: - raise Exception(f'Type {self._json["name"]} cannot be dereferenced.') + raise RuntimeError(f'Type {self._json["name"]} cannot be dereferenced.') return Database().get_type(self._json['dereference']) diff --git a/lemonspotter/core/variable.py b/lemonspotter/core/variable.py index b080528..eaf3a98 100644 --- a/lemonspotter/core/variable.py +++ b/lemonspotter/core/variable.py @@ -40,7 +40,7 @@ def name(self) -> str: """This property provides the name of the Variable.""" if self._name is None: - raise Exception('Name is None.') + raise RuntimeError('Name is None.') return self._name @@ -49,7 +49,7 @@ def name(self, name: str) -> None: """""" if name is None: - raise Exception('Assigning None to name of Variable.') + raise RuntimeError('Assigning None to name of Variable.') self._name = name @@ -64,7 +64,7 @@ def value(self, value: str) -> None: """This allows setting the value of the Variable.""" if value is None: - raise Exception('Assigning None to value of Variable.') + raise RuntimeError('Assigning None to value of Variable.') self._value = value diff --git a/lemonspotter/samplers/declare.py b/lemonspotter/samplers/declare.py index af950ed..6c038f3 100644 --- a/lemonspotter/samplers/declare.py +++ b/lemonspotter/samplers/declare.py @@ -3,13 +3,16 @@ """ import logging -from typing import Iterable +from typing import Sequence, Iterable +from itertools import chain from lemonspotter.core.parameter import Direction from lemonspotter.core.sampler import Sampler from lemonspotter.core.variable import Variable from lemonspotter.core.function import Function from lemonspotter.core.sample import FunctionSample +from lemonspotter.core.parameter import Parameter +from lemonspotter.core.argument import Argument class DeclarationSampler(Sampler): @@ -23,6 +26,8 @@ def __str__(self) -> str: def generate_samples(self, function: Function) -> Iterable[FunctionSample]: """ + Generate a FunctionSample which correctly declares everything such + that it is able to be used to build a source fragment. """ logging.debug('DeclarationSampler used for %s', function.name) @@ -32,22 +37,22 @@ def evaluator() -> bool: 'code, not runnable.') # generate valid but empty arguments - arguments = [] - variables = set() + arguments = {} - for parameter in function.parameters: # type: ignore - if parameter.direction == Direction.OUT and parameter.type.dereferencable: - mem_alloc = f'malloc(sizeof({parameter.type.dereference().language_type}))' + for parameter in chain(function.in_parameters, + function.inout_parameters): # type: ignore + arguments[parameter.name] = self._generate_argument(parameter) - variable = Variable(parameter.type, f'arg_{parameter.name}', mem_alloc) - variables.add(variable) - else: - variable = Variable(parameter.type, f'arg_{parameter.name}') - variables.add(variable) + sample = FunctionSample(function, True, arguments, evaluator) - logging.debug('declaring variable argument: %s', variable.name) - arguments.append(variable) + return set([sample]) - sample = FunctionSample(function, True, variables, arguments, evaluator) + @classmethod + def _generate_argument(cls, parameter: Parameter) -> Argument: + """ + Generate all variables required for this parameter. + """ - return set([sample]) + assert parameter.direction is not Direction.OUT + + return Argument(Variable(parameter.type, f'arg_{parameter.name}')) diff --git a/lemonspotter/samplers/valid.py b/lemonspotter/samplers/valid.py index 1b03fe7..ab476fd 100644 --- a/lemonspotter/samplers/valid.py +++ b/lemonspotter/samplers/valid.py @@ -3,18 +3,27 @@ """ import logging -from typing import Iterable -import itertools +from typing import Iterable, Mapping +from itertools import chain, product from lemonspotter.core.sampler import Sampler from lemonspotter.core.database import Database from lemonspotter.core.variable import Variable from lemonspotter.core.function import Function +from lemonspotter.core.argument import Argument from lemonspotter.core.sample import FunctionSample from lemonspotter.core.parameter import Parameter, Direction from lemonspotter.core.partition import PartitionType +def cartesian_dict_product(inp): + """ + https://stackoverflow.com/questions/5228158/cartesian-product-of-a-dictionary-of-lists + """ + + return (dict(zip(inp.keys(), values)) for values in product(*inp.values())) + + class ValidSampler(Sampler): """ """ @@ -28,66 +37,51 @@ def generate_samples(self, function: Function) -> Iterable[FunctionSample]: logging.debug('generating samples of parameters for %s', function.name) - if not function.parameters: - logging.debug('%s has no arguments.', function.name) - - sample = FunctionSample(function, True, {}, []) - - def evaluator(sample=sample) -> bool: - return (sample.return_variable.value == - Database().get_constant('MPI_SUCCESS').value) - - sample.evaluator = evaluator - - return {sample} - - argument_lists = [] + if not (function.has_in_parameters or function.has_inout_parameters): + return self._generate_empty_sample(function) - for parameter in function.parameters: # type: ignore - argument_lists.append(self.generate_sample(parameter)) + arguments = {} + for parameter in chain(function.in_parameters, + function.inout_parameters): # type: ignore + arguments[parameter.name] = self._generate_arguments(parameter) - if not argument_lists: - raise Exception('No arguments generated from a function with parameters.') + if not arguments: + raise RuntimeError('No arguments generated from a function with parameters.') - logging.debug('pre cartesian product: %s', str(argument_lists)) + logging.debug('pre cartesian product: %s', arguments) # cartesian product of all arguments - combined = set(itertools.product(*argument_lists)) - logging.debug('prefiltering argument lists: %s', str(combined)) + combined = cartesian_dict_product(arguments) + logging.debug('prefiltering argument lists: %s', combined) # respect filters of Function - def argument_filter(argument_list: Iterable) -> bool: + def argument_filter(arguments: Mapping[str, Argument]) -> bool: for sieve in function.filters: # any sieve needs to be True # go through parameters/argument mapping, needs to match all requirements - for parameter, argument in zip(function.parameters, argument_list): # type: ignore - if parameter.direction is Direction.OUT: - # note we don't write out arguments in function filters - continue - - elif sieve[parameter.name]['value'] == 'any': + for parameter_name, argument in arguments.items(): # type: ignore + if sieve[parameter_name]['value'] == 'any': continue - elif sieve[parameter.name]['value'] != argument.value: + if sieve[parameter_name]['value'] != argument.variable.value: break else: # sieve applies return True - else: - logging.debug('%s has no sieve allowed.', argument_list) - return False + logging.debug('%s has no sieve allowed.', arguments) + return False filtered = filter(argument_filter, combined) # convert to FunctionSample samples = set() - for argument_list in filtered: - sample = FunctionSample(function, True, set(argument_list), argument_list) + for argument in filtered: + sample = FunctionSample(function, True, argument) # function without parameters - # NOTE sample=sample is done to avoid late binding closure behaviour! + # note sample=sample is done to avoid late binding closure behaviour! def evaluator(sample=sample) -> bool: # todo use valid error lookup rule logging.debug('evaluator for function %s', function.name) @@ -98,62 +92,48 @@ def evaluator(sample=sample) -> bool: Database().get_constant('MPI_SUCCESS').value) sample.evaluator = evaluator - samples.add(sample) return samples - def generate_sample(self, parameter: Parameter) -> Iterable[Variable]: + def _generate_arguments(self, parameter: Parameter) -> Iterable[Argument]: """""" - type_samples = [] + assert parameter.direction is not Direction.OUT - if parameter.direction == Direction.OUT: - if parameter.type.abstract_type == 'STRING': - mem_alloc = f'malloc({parameter.length} * sizeof(char))' - var = Variable(parameter.type, parameter.name + '_out', mem_alloc) - type_samples.append(var) - - elif parameter.type.dereferencable: - mem_alloc = f'malloc(sizeof({parameter.type.dereference().language_type}))' - var = Variable(parameter.type, parameter.name + '_out', mem_alloc) - type_samples.append(var) - - else: - # generate out variable - var = Variable(parameter.type, parameter.name + '_out') - type_samples.append(var) - - return type_samples + arguments = set() # TODO partition should return str for value # it is in charge of interpreting PartitionType, not here + # this should be in an object oriented pattern, not an large if block for partition in parameter.type.partitions: # type: ignore if partition.type is PartitionType.LITERAL: name = f'{parameter.name}_arg_{partition.value}' - var = Variable(parameter.type, name, partition.value) - type_samples.append(var) + arguments.add(Argument(Variable(parameter.type, name, partition.value))) elif partition.type is PartitionType.NUMERIC: name = f'{parameter.name}_arg_{partition.value}' - var = Variable(parameter.type, name, partition.value) - type_samples.append(var) + arguments.add(Argument(Variable(parameter.type, name, partition.value))) elif partition.type is PartitionType.PREDEFINED: - name = f'{parameter.name}_arg_{partition.value}' - var = Variable(parameter.type, name, partition.value, predefined=True) + name = f'{partition.value}' - type_samples.append(var) + arguments.add(Argument(Variable(parameter.type, + name, + partition.value, + predefined=True))) elif partition.type is PartitionType.CONSTANT: for constant in parameter.type.constants: name = f'{parameter.name}_arg_{constant.name}' - type_samples.append(constant.generate_variable(name)) + + arguments.add(Argument(constant.generate_variable(name))) else: - logging.error(('Trying to generate variable from unknown' - ' partition type in ValidSampler.') + str(partition.type)) + logging.error(('Trying to generate variable from unknown partition' + ' type %s in ValidSampler.'), + partition.type) - return type_samples + return arguments