diff --git a/pyttb/ktensor.py b/pyttb/ktensor.py index 2507885..74264e5 100644 --- a/pyttb/ktensor.py +++ b/pyttb/ktensor.py @@ -889,21 +889,24 @@ def full(self) -> ttb.tensor: [63. 85.]] """ + def min_split_dims(dims): """ solve min_{i in range(1,d)} product(dims[:i]) + product(dims[i:]) to minimize the memory footprint of the intermediate matrix """ - sum_of_prods = [np.prod(dims[:i])+np.prod(dims[i:]) - for i in range(1,len(dims))] + sum_of_prods = [ + np.prod(dims[:i]) + np.prod(dims[i:]) for i in range(1, len(dims)) + ] i_min = np.argmin(sum_of_prods) + 1 # note range above starts at 1 return i_min i_split = min_split_dims(self.shape) - data = ( (ttb.khatrirao(*self.factor_matrices[:i_split],reverse=True) * w) @ - ttb.khatrirao(*self.factor_matrices[i_split:],reverse=True).T ) - return pyttb.tensor(data, self.shape, copy=False) + data = ( + ttb.khatrirao(*self.factor_matrices[:i_split], reverse=True) * self.weights + ) @ ttb.khatrirao(*self.factor_matrices[i_split:], reverse=True).T + return ttb.tensor(data, self.shape, copy=False) def innerprod( self, other: Union[ttb.tensor, ttb.sptensor, ktensor, ttb.ttensor]