import numpy as np

def shiftByPlusOne(A, axes):

    """Shift array A forward by one position along each axis in `axes`.

    For every axis i in `axes`, the result satisfies
    out[..., k, ...] = A[..., k-1, ...] for k >= 1 (k in the ith position),
    and the first slice out[..., 0, ...] is filled with zeros. The last slice
    of A along that axis is discarded. Shifts for several axes are applied
    one after another; an axis listed twice is shifted twice.

    Parameters
    ----------
    A : array_like
        Input array. It is not modified; a shifted copy is returned.
    axes : iterable of int
        Axes (counting from 0; negative values allowed) to shift along.

    Returns
    -------
    numpy.ndarray
        Shifted copy of A with the same shape and dtype.
    """

    # Work on a copy: np.swapaxes returns a view, so writing through it
    # would otherwise mutate the caller's array.
    Q = np.array(A, copy=True)

    for axis in axes:

        # Bring the target axis to the front so the shift is a plain slice
        # assignment; `Q[1:]` works for any number of dimensions.
        Q = np.swapaxes(Q, 0, axis)
        if Q.shape[0] > 0:
            # NumPy detects the overlapping source/destination and buffers
            # as needed, so this is a correct in-place shift.
            Q[1:] = Q[:-1]
            Q[0] = 0
        Q = np.swapaxes(Q, 0, axis)

    return Q


def arrConcat(A, m, j):

    """Concatenate m copies of array A end-to-end along axis j.

    For a k-fold array A of shape (n_1, ..., n_k), the result has the same
    shape except that the jth dimension becomes m * n_j. The whole of A is
    repeated as a block, e.g. [a, b] -> [a, b, a, b] for m = 2.

    Parameters
    ----------
    A : array_like
        Input array.
    m : int
        Number of copies.
    j : int
        Axis along which to concatenate (negative values count from the end).

    Returns
    -------
    numpy.ndarray
        The tiled array.

    Raises
    ------
    IndexError
        If j is not a valid axis of A.
    """

    A = np.asarray(A)

    # Integer reps tuple (1, ..., 1, m, 1, ..., 1) with m in the jth position,
    # one entry per axis of A, as expected by np.tile.
    reps = [1] * A.ndim
    reps[j] = m

    return np.tile(A, tuple(reps))


def arrRepeatSlice(A, m, j):

    """Repeat each slice of array A m times along axis j.

    For a k-fold array A of shape (n_1, ..., n_k), the result has the same
    shape except that the jth dimension becomes m * n_j. Unlike arrConcat,
    each individual slice is repeated in place, e.g. [a, b] -> [a, a, b, b]
    for m = 2.

    Parameters
    ----------
    A : array_like
        Input array.
    m : int
        Number of times each slice is repeated.
    j : int
        Axis along which to repeat (negative values count from the end).

    Returns
    -------
    numpy.ndarray
        The array with every slice along axis j repeated m times.
    """

    return np.repeat(np.asarray(A), m, axis=j)
