Source code for cvxpy.atoms.affine.einsum

"""
Copyright, the CVXPY authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from cvxpy.atoms.affine.binary_operators import multiply
from cvxpy.atoms.affine.reshape import reshape
from cvxpy.atoms.affine.sum import sum as cvxpy_sum
from cvxpy.atoms.affine.transpose import permute_dims
from cvxpy.utilities.einsum_utilities import (
    find_contraction,
    greedy_path,
    optimal_path,
    parse_einsum_input,
)


def einsum(subscripts, *exprs, optimize="greedy"):
    """Evaluates the Einstein summation convention on the given expressions.

    This atom is the CVXPY analog of NumPy's einsum function `numpy.einsum` [1],
    and it maintains the same syntax and semantics.

    The einsum operation is evaluated by contracting pairs of expressions by
    elementwise multiplication and summation. The order in which the
    contractions are performed affects both the memory usage and the FLOP
    count required. The optimize parameter determines whether to contract
    expressions using the optimal or greedy ordering. The cost to compute the
    optimal path is exponential in the number of distinct subscripts, while
    the cost to compute the greedy path is cubic in the number of distinct
    subscripts. We typically expect the greedy search to produce the optimal
    path for most problems.

    Examples
    --------
    >>> import cvxpy as cp
    >>> A = cp.Variable((3, 4))
    >>> B = cp.Variable((4, 5))
    >>> # Matrix multiplication
    >>> result = cp.einsum('ij,jk->ik', A, B)
    >>>
    >>> # Trace of a matrix
    >>> C = cp.Variable((3, 3))
    >>> trace = cp.einsum('ii->', C)

    References
    ----------
    [1] https://numpy.org/doc/stable/reference/generated/numpy.einsum.html

    Parameters
    ----------
    subscripts : str
        The subscripts for the einsum operation.
    exprs : Expression
        The expressions to contract.
    optimize : {bool, 'greedy', 'optimal'}, optional
        Whether to contract the expressions using the optimal or greedy
        ordering. Defaults to "greedy".

    Returns
    -------
    Expression
        The contracted expression.
    """
    # Note for maintainers:
    #
    # Here, einsum is implemented using the CVXPY sum, multiply, permute_dims,
    # and reshape atoms. The implementation proceeds as follows:
    #
    # 1. Parse and validate the subscripts and the shapes and count of the
    #    expressions. The core logic is
    #        input_subscripts, output_subscript = subscripts.split("->")
    #    The rest is validation.
    #
    # 2. Reduce duplicate indices in the expressions. Duplicated indices in a
    #    subscript pattern represent indexing along a diagonal of the
    #    corresponding dimensions. For example, the subscript pattern 'ii->i'
    #    extracts the diagonal of a matrix and 'ii->' takes the trace. For the
    #    following steps, it is necessary that every dimension of a tensor is
    #    uniquely indexed (within the tensor) or not indexed (ellipsis). We
    #    reduce duplicated indices by creating a new tensor of reduced
    #    dimension by taking the diagonal elements.
    #
    # 3. Contraction.
    #    A. If only one input, simply perform an axis sum.
    #    B. Otherwise, we iterate over pairs of tensors and contract them.
    #       Contracting two tensors involves (i) reshaping and permuting them
    #       to compatible shapes where corresponding indices align,
    #       (ii) performing elementwise multiplication, and (iii) summing over
    #       the contracted dimensions. The order of the contractions is given
    #       by the contraction path.
    #
    # 4. Permute the final result to match output subscript order.

    # 1. Initial parsing. Only the shapes matter to the parser, so it is fed
    # empty placeholder arrays instead of the expressions themselves.
    placeholders = [np.empty(expr.shape, dtype=np.dtype([])) for expr in exprs]
    input_subscripts, output_subscript, _ = parse_einsum_input((subscripts, *placeholders))
    input_list = input_subscripts.split(",")
    output_set = set(output_subscript)
    dimension_dict = _validate_arguments(input_list, exprs)

    # 2. Reduce duplicate indices
    reduced = [
        _initial_reduction(expr, pattern, dimension_dict)
        for expr, pattern in zip(exprs, input_list, strict=True)
    ]
    operands = [operand for operand, _ in reduced]
    reduced_inputs = [pattern for _, pattern in reduced]
    input_sets = [set(pattern) for pattern in reduced_inputs]

    # 3.A. If only one input, simply perform an axis sum
    if len(operands) == 1:
        return _sum_single_operand(operands[0], reduced_inputs[0], output_subscript)

    path = _get_path(input_sets, output_set, dimension_dict, optimize)

    # 3.B. Contract tensors
    for contraction_inds in path:
        # Results of contraction go to the last of the list
        _, input_sets, to_remove, all_labels = find_contraction(
            contraction_inds, input_sets, output_set
        )
        picked_operands = [operands[i] for i in contraction_inds]
        picked_inputs = [reduced_inputs[i] for i in contraction_inds]
        new_operand, new_input = _contract_pair(
            picked_operands, picked_inputs, to_remove, all_labels, dimension_dict
        )

        operands = [op for i, op in enumerate(operands) if i not in contraction_inds]
        operands.append(new_operand)
        reduced_inputs = [
            pattern for i, pattern in enumerate(reduced_inputs) if i not in contraction_inds
        ]
        reduced_inputs.append(new_input)

    # 4. After all contractions, permute the final result to match the output
    # subscript order (skipped when the order already agrees).
    result, result_subscript = operands[0], reduced_inputs[0]
    if result_subscript == output_subscript:
        return result
    perm = [result_subscript.index(label) for label in output_subscript]
    return permute_dims(result, axes=perm)


def _validate_arguments(input_list, exprs):
    """Validate the input arguments for einsum operation.

    This function checks that the number of expressions matches the number of
    input subscript patterns, and that the dimensions of each expression are
    consistent with the subscript patterns and with each other.

    Parameters
    ----------
    input_list : list of str
        List of subscript patterns for each input expression.
    exprs : list of Expression
        List of CVXPY expressions to be contracted. Only the ``shape``
        attribute of each expression is read here.

    Returns
    -------
    dict
        Dictionary mapping subscript characters to their corresponding
        dimensions.

    Raises
    ------
    ValueError
        If the number of expressions doesn't match the number of input
        patterns, if a subscript pattern doesn't match the shape of its
        corresponding expression, or if the same subscript character has
        inconsistent dimensions across expressions.
    """
    if len(exprs) != len(input_list):
        raise ValueError(
            f"Number of arguments ({len(exprs)}) doesn't match "
            f"number of input patterns ({len(input_list)})"
        )

    dimension_dict = {}
    for tnum, (term, expr) in enumerate(zip(input_list, exprs)):
        sh = expr.shape
        if len(sh) != len(term):
            raise ValueError(
                f"Einstein sum subscript {term} does not contain the "
                f"correct number of indices for operand {tnum}."
            )
        for char, dim in zip(term, sh):
            # setdefault records the first size seen for a label and returns
            # the recorded size on every later occurrence.
            known = dimension_dict.setdefault(char, dim)
            if known != dim:
                raise ValueError(
                    f"Size of label '{char}' for operand {tnum} ({dim}) "
                    f"does not match previous terms ({known})."
                )
    return dimension_dict


def _get_path(input_sets, output_set, dimension_dict, optimize):
    """Get the contraction path for the einsum operation.

    This function determines the order in which pairs of tensors should be
    contracted to minimize computational cost. The path optimization can use
    either a greedy or optimal path.

    Parameters
    ----------
    input_sets : list of set
        List of sets containing the subscript characters for each input tensor.
    output_set : set
        Set of subscript characters that should appear in the final output.
    dimension_dict : dict
        Dictionary mapping subscript characters to their corresponding
        dimensions.
    optimize : {bool, 'greedy', 'optimal'}
        Optimization strategy for determining the contraction path.

        - True or 'optimal': Use optimal path (exponential cost)
        - False or 'greedy': Use greedy path (cubic cost)

    Returns
    -------
    list of tuple
        List of tuples, where each tuple contains the indices of the tensors
        to contract in each step of the contraction process.

    Raises
    ------
    ValueError
        If optimize has an invalid value.
    """
    # Fourth argument of the path helpers; presumably the size limit on
    # intermediate tensors as in NumPy's einsum path code -- TODO confirm
    # against cvxpy.utilities.einsum_utilities.
    max_int32 = np.iinfo(np.int32).max

    # Tuple (not set) membership: a set lookup hashes ``optimize`` and would
    # raise TypeError for unhashable values such as a list, instead of the
    # documented ValueError below.
    if optimize in (True, "optimal"):
        return optimal_path(input_sets, output_set, dimension_dict, max_int32)
    if optimize in (False, "greedy"):
        return greedy_path(input_sets, output_set, dimension_dict, max_int32)
    raise ValueError("Invalid value for optimize. Must be True, False, 'optimal', or 'greedy'.")


def _initial_reduction(operand, inputs, dimension_dict):
    """Reduce operands with repeated indices by taking diagonal elements.

    This function handles the case where a subscript pattern contains
    repeated indices (e.g., 'ii' or 'ijj'). For repeated indices, the
    corresponding dimensions are reduced by taking diagonal elements.

    Parameters
    ----------
    operand : Expression
        The CVXPY expression to reduce.
    inputs : str
        The subscript pattern for this operand.
    dimension_dict : dict
        Dictionary mapping subscript characters to their corresponding
        dimensions.

    Returns
    -------
    tuple
        A tuple containing:

        - Expression: The reduced operand with repeated indices eliminated
        - str: The updated subscript pattern with repeated indices removed
    """
    # Repeated labels, in order of first appearance.
    repeats = [x for x in dict.fromkeys(inputs) if inputs.count(x) > 1]

    # If there are no repeated indices, return the operand and inputs as-is
    if not repeats:
        return operand, inputs

    # For each repeated index, get the diagonal reduction
    for x in repeats:
        to_reduce = [i for i, y in enumerate(inputs) if y == x]
        to_keep = [i for i, y in enumerate(inputs) if y != x]
        # Move the repeated axes to the front so that indexing them jointly
        # with the diagonal indices leaves a single leading axis for ``x``.
        perm = to_reduce + to_keep
        idxs = np.diag_indices(n=dimension_dict[x], ndim=len(to_reduce))
        inputs = x + "".join(inputs[i] for i in to_keep)
        operand = permute_dims(operand, axes=perm)
        operand = operand[idxs]

    # Return the reduced operand and inputs
    return operand, inputs


def _sum_single_operand(operand, input_subscript, output_subscript):
    """Sum a single operand along specified axes.

    This function handles the case where there is only one input operand. It
    sums the operand along axes that appear in the input subscript but not in
    the output subscript, and then permutes the remaining dimensions to match
    the output subscript order.

    Parameters
    ----------
    operand : Expression
        The single CVXPY expression to sum.
    input_subscript : str
        The subscript pattern for the input operand.
    output_subscript : str
        The desired subscript pattern for the output.

    Returns
    -------
    Expression
        The summed and permuted expression with shape matching the output
        subscript.
    """
    if len(output_subscript) < len(input_subscript):
        idxs = [i for i, x in enumerate(input_subscript) if x not in output_subscript]
        operand = cvxpy_sum(operand, axis=tuple(idxs), keepdims=False)
    elif len(output_subscript) == 0:
        # Only reached when the input subscript is empty as well.
        operand = cvxpy_sum(operand, axis=None, keepdims=False)

    # Labels that survive the sum, in the order they occupy in the operand.
    remaining = [x for x in input_subscript if x in output_subscript]
    perm = [remaining.index(x) for x in output_subscript]
    operand = permute_dims(operand, axes=perm)
    return operand


def _contract_pair(operands, input_lists, to_remove, all_labels, dimension_dict):
    """Contract a pair of operands by elementwise multiplication and summation.

    This function performs the core contraction operation between two tensors.
    It aligns the operands to compatible shapes, performs elementwise
    multiplication, and then sums over the to-be-removed dimensions.

    Parameters
    ----------
    operands : list of Expression
        List containing exactly two CVXPY expressions to contract.
    input_lists : list of str
        List of subscript patterns for each operand.
    to_remove : set
        Set of subscript characters that should be summed out (contracted).
    all_labels : str or set
        All subscript characters involved in this contraction. A set (or
        frozenset) is converted to a sorted string.
    dimension_dict : dict
        Dictionary mapping subscript characters to their corresponding
        dimensions. Currently unused; kept so the signature stays stable.

    Returns
    -------
    tuple
        A tuple containing:

        - Expression: The contracted result
        - str: The subscript pattern for the contracted result
    """
    # Convert all_labels from set to sorted string for deterministic ordering
    if isinstance(all_labels, (set, frozenset)):
        all_labels = "".join(sorted(all_labels))

    # Align and permute the operands to compatible shapes
    aligned_operands = []
    for operand, input_list in zip(operands, input_lists, strict=True):
        # Labels this operand lacks become trailing length-1 axes so that the
        # elementwise product can broadcast along them.
        missing = "".join(x for x in all_labels if x not in input_list)
        if missing:
            shape = operand.shape + (1,) * len(missing)
            input_list += missing
            operand = reshape(operand, shape=shape, order="C")
        perm = [input_list.index(x) for x in all_labels]
        if perm != list(range(len(perm))):
            operand = permute_dims(operand, axes=perm)
        aligned_operands.append(operand)

    # Elementwise multiply the operands
    new_operand = multiply(aligned_operands[0], aligned_operands[1])

    # Sum the operands along the axes to be removed
    if len(to_remove) == len(all_labels):
        new_operand = cvxpy_sum(new_operand, axis=None, keepdims=False)
    elif len(to_remove) > 0:
        axes_to_remove = [i for i, x in enumerate(all_labels) if x in to_remove]
        # Reduce single-element axis tuple to int
        axis = axes_to_remove[0] if len(axes_to_remove) == 1 else tuple(axes_to_remove)
        new_operand = cvxpy_sum(new_operand, axis=axis, keepdims=False)

    new_input = "".join(x for x in all_labels if x not in to_remove)

    # Return the contracted operand and inputs
    return new_operand, new_input