# Source code for asdf.tags.core.ndarray

# Licensed under a 3-clause BSD style license - see LICENSE.rst
# -*- coding: utf-8 -*-

import sys
import weakref

import numpy as np
from numpy import ma

from jsonschema import ValidationError

from ...types import AsdfType
from ... import schema
from ... import util
from ... import yamlutil


# Lookup table from ASDF scalar datatype names to the numpy type code
# (kind character followed by the item size in bytes).  The byte order
# prefix is added separately when a dtype is built from these codes.
_datatype_names = {
    # signed integers
    'int8': 'i1',
    'int16': 'i2',
    'int32': 'i4',
    'int64': 'i8',
    # unsigned integers
    'uint8': 'u1',
    'uint16': 'u2',
    'uint32': 'u4',
    'uint64': 'u8',
    # floating point
    'float32': 'f4',
    'float64': 'f8',
    # complex
    'complex64': 'c8',
    'complex128': 'c16',
    # boolean
    'bool8': 'b1',
}


# Lookup table from ASDF string datatype names to the numpy kind
# character; the string length is appended when the dtype is built.
_string_datatype_names = {
    'ascii': 'S',
    'ucs4': 'U',
}


def asdf_byteorder_to_numpy_byteorder(byteorder):
    """
    Translate an ASDF byte order name into numpy's byte order prefix.

    Parameters
    ----------
    byteorder : str
        Either ``'big'`` or ``'little'``.

    Returns
    -------
    str
        ``'>'`` for big endian, ``'<'`` for little endian.

    Raises
    ------
    ValueError
        If ``byteorder`` is neither of the two accepted names.
    """
    if byteorder == 'little':
        return '<'
    if byteorder == 'big':
        return '>'
    raise ValueError("Invalid ASDF byteorder '{0}'".format(byteorder))


def asdf_datatype_to_numpy_dtype(datatype, byteorder=None):
    """
    Convert an ASDF datatype description into a numpy dtype.

    Parameters
    ----------
    datatype : str, list or dict
        One of:

        - the name of a scalar type listed in ``_datatype_names``;
        - a two-element list ``[kind, length]`` where ``kind`` is a key
          of ``_string_datatype_names``;
        - a dict describing a single field (``datatype`` is required,
          ``name``, ``byteorder`` and ``shape`` are optional);
        - a list of any of the above, describing a structured dtype.

    byteorder : str, optional
        ASDF byte order name.  Defaults to ``sys.byteorder``.

    Returns
    -------
    numpy.dtype or tuple
        A dtype, except for the dict form, which returns a
        ``(name, dtype)`` or ``(name, dtype, shape)`` tuple suitable as
        one entry of a structured dtype specification.

    Raises
    ------
    ValueError
        If a field dict lacks ``datatype`` or the description is not
        recognized.
    """
    if byteorder is None:
        byteorder = sys.byteorder

    # Named scalar type.
    if isinstance(datatype, str) and datatype in _datatype_names:
        type_code = _datatype_names[datatype]
        prefix = asdf_byteorder_to_numpy_byteorder(byteorder)
        return np.dtype(str(prefix + type_code))

    # Fixed-length string: [kind, length].
    looks_like_string = (
        isinstance(datatype, list) and
        len(datatype) == 2 and
        isinstance(datatype[0], str) and
        isinstance(datatype[1], int) and
        datatype[0] in _string_datatype_names)
    if looks_like_string:
        kind, length = datatype
        prefix = asdf_byteorder_to_numpy_byteorder(byteorder)
        spec = str(prefix) + str(_string_datatype_names[kind]) + str(length)
        return np.dtype(spec)

    # Single field of a structured dtype.
    if isinstance(datatype, dict):
        if 'datatype' not in datatype:
            raise ValueError("Field entry has no datatype: '{0}'".format(datatype))
        field_name = datatype.get('name', '')
        # A field may override the byte order inherited from its parent.
        field_byteorder = datatype.get('byteorder', byteorder)
        field_shape = datatype.get('shape')
        field_dtype = asdf_datatype_to_numpy_dtype(
            datatype['datatype'], field_byteorder)
        if field_shape is None:
            return (str(field_name), field_dtype)
        return (str(field_name), field_dtype, tuple(field_shape))

    # Structured dtype given as a list of sub-datatypes.
    if isinstance(datatype, list):
        entries = []
        for subdatatype in datatype:
            converted = asdf_datatype_to_numpy_dtype(subdatatype, byteorder)
            if isinstance(converted, tuple):
                entries.append(converted)
            elif isinstance(converted, np.dtype):
                # Unnamed field: numpy assigns a default name.
                entries.append((str(''), converted))
            else:
                raise RuntimeError("Error parsing asdf datatype")
        return np.dtype(entries)

    raise ValueError("Unknown datatype {0}".format(datatype))


def numpy_byteorder_to_asdf_byteorder(byteorder):
    """
    Translate a numpy byte order character into an ASDF byte order name.

    ``'='`` maps to ``sys.byteorder``, ``'<'`` maps to ``'little'`` and
    every other value maps to ``'big'``.
    """
    if byteorder == '<':
        return 'little'
    if byteorder == '=':
        return sys.byteorder
    return 'big'


def numpy_dtype_to_asdf_datatype(dtype, include_byteorder=True):
    """
    Convert a numpy dtype into an ASDF datatype description.

    Parameters
    ----------
    dtype : numpy.dtype or anything accepted by ``numpy.dtype``
        The dtype to describe.

    include_byteorder : bool, optional
        When `True` (default), each field entry of a structured dtype
        carries its own ``byteorder`` key.

    Returns
    -------
    tuple
        ``(datatype, byteorder)``, where ``datatype`` is a name, a
        ``[kind, length]`` list, or a list of field dicts, and
        ``byteorder`` is an ASDF byte order name.

    Raises
    ------
    ValueError
        If the dtype has no ASDF equivalent handled here.
    """
    dtype = np.dtype(dtype)

    # Structured dtype: describe every field in declaration order.
    if dtype.names is not None:
        described = []
        for field_name in dtype.names:
            field = dtype.fields[field_name][0]
            field_datatype, field_byteorder = numpy_dtype_to_asdf_datatype(field)
            entry = {'name': field_name, 'datatype': field_datatype}
            if include_byteorder:
                entry['byteorder'] = field_byteorder
            if field.shape:
                entry['shape'] = list(field.shape)
            described.append(entry)
        return described, numpy_byteorder_to_asdf_byteorder(dtype.byteorder)

    # Sub-array dtype: describe the element type only.
    if dtype.subdtype is not None:
        return numpy_dtype_to_asdf_datatype(dtype.subdtype[0])

    type_name = dtype.name

    if type_name in _datatype_names:
        return type_name, numpy_byteorder_to_asdf_byteorder(dtype.byteorder)

    if type_name == 'bool':
        return 'bool8', numpy_byteorder_to_asdf_byteorder(dtype.byteorder)

    if type_name.startswith(('string', 'bytes')):
        return ['ascii', dtype.itemsize], 'big'

    if type_name.startswith(('unicode', 'str')):
        # Each UCS4 character occupies four bytes.
        n_chars = int(dtype.itemsize / 4)
        return (['ucs4', n_chars],
                numpy_byteorder_to_asdf_byteorder(dtype.byteorder))

    raise ValueError("Unknown dtype {0}".format(dtype))


def inline_data_asarray(inline, dtype=None):
    """
    Convert inline (nested list) data into a numpy array.

    Parameters
    ----------
    inline : list
        Possibly nested list holding the array elements.

    dtype : numpy.dtype, optional
        Requested dtype of the result.

    Returns
    -------
    numpy.ndarray or numpy.ma.MaskedArray
        For non-structured dtypes, a masked array is returned only when
        at least one element ended up masked (a `None` in the input);
        otherwise the plain underlying array is returned.

    Raises
    ------
    ValueError
        If a structured ``dtype`` is given and no nesting level of
        ``inline`` converts to it.
    """
    if dtype is None or dtype.fields is None:
        def mask_nones(node):
            # Lists containing None become masked arrays with the None
            # entries replaced by 0 and flagged in the mask.
            if not isinstance(node, list):
                return node
            if None in node:
                as_array = np.asarray(node)
                is_none = np.equal(as_array, None)
                return np.ma.array(np.where(is_none, 0, node),
                                   mask=is_none)
            return [mask_nones(child) for child in node]

        converted = np.ma.asarray(mask_nones(inline), dtype=dtype)
        if ma.is_masked(converted):
            return converted
        return converted.data

    # np.asarray doesn't handle structured arrays unless the innermost
    # elements are tuples.  We therefore descend along the first element
    # of each level until one level converts to the structured dtype,
    # then turn every item at that level into a tuple.  This probably
    # breaks for nested structured dtypes, and does not work with object
    # dtypes (which ASDF explicitly excludes).
    def record_level(node, level=0):
        if not isinstance(node, list) or not len(node):
            raise ValueError(
                "data can not be converted to structured array")
        try:
            np.asarray(tuple(node), dtype=dtype)
        except ValueError:
            return record_level(node[0], level + 1)
        return level

    def tuples_at(node, remaining):
        if remaining == 0:
            return tuple(node)
        return [tuples_at(child, remaining - 1) for child in node]

    records = tuples_at(inline, record_level(inline))
    return np.asarray(records, dtype=dtype)


def numpy_array_to_list(array):
    """
    Convert an array into nested Python lists for inline storage.

    Byte strings are turned into unicode strings, since YAML doesn't
    handle the former, and the result is passed through
    ``schema.validate_large_literals`` before being returned.
    """
    def to_nested_list(value):
        if isinstance(value, (np.ndarray, NDArrayType)):
            # Byte string arrays are cast to unicode before conversion.
            source = value.astype('U') if value.dtype.char == 'S' else value
            value = source.tolist()

        if isinstance(value, (list, tuple)):
            return [to_nested_list(item) for item in value]
        return value

    def decode_bytes(value):
        # Catches byte strings that survived the array-level cast above.
        if isinstance(value, list):
            return [decode_bytes(item) for item in value]
        if isinstance(value, bytes):
            return value.decode('ascii')
        return value

    result = decode_bytes(to_nested_list(array))

    schema.validate_large_literals(result)

    return result


class NDArrayType(AsdfType):
    """
    ASDF type for the ``core/ndarray`` tag.

    An instance stands in for a numpy array whose data comes either from
    an ASDF block or from inline data in the tree.  For block-backed
    arrays the numpy array is only built when it is first needed (see
    `_make_array`), unless the file's block manager has lazy loading
    turned off.  Attribute access and operators are forwarded to the
    underlying array.
    """
    name = 'core/ndarray'
    version = '1.0.0'
    types = [np.ndarray, ma.MaskedArray]

    def __init__(self, source, shape, dtype, offset, strides,
                 order, mask, asdffile):
        """
        Parameters
        ----------
        source : list or object
            Inline data (when a list), otherwise the value passed to
            ``asdffile.blocks.get_block`` to locate the data block.

        shape : list or None
            Expected shape.  The first entry may be ``'*'``.

        dtype : numpy.dtype or None
            Element type of the array.

        offset : int
            Offset passed to ``numpy.ndarray`` when building the array.

        strides : list or None
            Strides passed to ``numpy.ndarray`` when building the array.

        order : str
            Memory order passed to ``numpy.ndarray``.

        mask : numpy.ndarray, NDArrayType, scalar or None
            Mask description, see `_apply_mask`.

        asdffile : object
            The owning file; its ``blocks`` attribute is used to find
            and register blocks.

        Raises
        ------
        ValueError
            If inline data does not match ``shape``.
        """
        self._asdffile = asdffile
        self._source = source
        self._block = None
        self._block_data_weakref = None
        self._array = None
        self._mask = mask

        if isinstance(source, list):
            self._array = inline_data_asarray(source, dtype)
            self._array = self._apply_mask(self._array, self._mask)
            self._block = asdffile.blocks.add_inline(self._array)
            if shape is not None:
                expected = tuple(shape)
                actual = self._array.shape
                if expected[:1] == ('*',):
                    # A leading '*' matches any length of the first
                    # axis, so only rank and trailing axes are compared.
                    matches = (len(actual) == len(expected) and
                               actual[1:] == expected[1:])
                else:
                    matches = actual == expected
                if not matches:
                    raise ValueError(
                        "inline data doesn't match the given shape")

        self._shape = shape
        self._dtype = dtype
        self._offset = offset
        self._strides = strides
        self._order = order
        if not asdffile.blocks.lazy_load:
            self._make_array()

    def _make_array(self):
        """
        Return the underlying numpy array, building it from the block
        data if it has not been built yet or if the block data object
        has been replaced since.
        """
        # If the ASDF file has been updated in-place, then there's
        # a chance that the block's original data object has been
        # closed and replaced.  We need to check here and re-generate
        # the array if necessary, otherwise we risk segfaults when
        # memory mapping.
        if self._block_data_weakref is not None:
            block_data = self._block_data_weakref()
            if block_data is not self.block.data:
                self._array = None
                self._block_data_weakref = None

        if self._array is None:
            block = self.block
            shape = self.get_actual_shape(
                self._shape, self._strides, self._dtype, len(block))
            self._array = np.ndarray(
                shape, self._dtype, block.data,
                self._offset, self._strides, self._order)
            self._block_data_weakref = weakref.ref(block.data)
            self._array = self._apply_mask(self._array, self._mask)
            if block.readonly:
                self._array.setflags(write=False)
        return self._array

    def _apply_mask(self, array, mask):
        """
        Return ``array`` combined with ``mask``.

        An array mask is applied element-wise.  A scalar NaN masks the
        NaN entries of ``array``; any other scalar masks the entries
        equal to it.  Any other mask (including `None`) leaves
        ``array`` unchanged.
        """
        if isinstance(mask, (np.ndarray, NDArrayType)):
            # Use "mask.view()" here so the underlying possibly
            # memmapped mask array is freed properly when the masked
            # array goes away.
            array = ma.array(array, mask=mask.view())
            # assert util.get_array_base(array.mask) is util.get_array_base(mask)
            return array
        elif np.isscalar(mask):
            if np.isnan(mask):
                return ma.array(array, mask=np.isnan(array))
            else:
                return ma.masked_values(array, mask)
        return array

    def __array__(self):
        return self._make_array()

    def _describe_unloaded(self):
        # Placeholder text shared by __repr__ and __str__ for arrays
        # whose data has not been loaded.
        return "<{0} (unloaded) shape: {1} dtype: {2}>".format(
            'array' if self._mask is None else 'masked array',
            self._shape, self._dtype)

    def __repr__(self):
        # repr alone should not force loading of the data
        if self._array is None:
            return self._describe_unloaded()
        return repr(self._array)

    def __str__(self):
        # str alone should not force loading of the data
        if self._array is None:
            return self._describe_unloaded()
        return str(self._array)

    def get_actual_shape(self, shape, strides, dtype, block_size):
        """
        Get the actual shape of an array, by computing it against the
        block_size if it contains a ``*``.

        Parameters
        ----------
        shape : list or tuple
            Shape whose first entry may be ``'*'``.

        strides : sequence or None
            When given, ``strides[0]`` is the byte size of one step
            along the first axis.

        dtype : numpy.dtype
            Used for its ``itemsize`` when ``strides`` is `None`.

        block_size : int
            Size of the data block in bytes.

        Returns
        -------
        list or tuple
            ``shape`` itself when it contains no ``'*'``, otherwise a
            new list with the first entry resolved.

        Raises
        ------
        ValueError
            If ``'*'`` appears more than once or anywhere but first.
        """
        num_stars = shape.count('*')
        if num_stars == 0:
            return shape
        elif num_stars == 1:
            if shape[0] != '*':
                raise ValueError("'*' may only be in first entry of shape")
            if strides is not None:
                stride = strides[0]
            else:
                # np.prod (np.product no longer exists in NumPy 2.0);
                # an integer dtype keeps the empty product an int.
                stride = int(np.prod(shape[1:], dtype=np.int64)) * dtype.itemsize
            missing = block_size // stride
            return [missing] + list(shape[1:])
        raise ValueError("Invalid shape '{0}'".format(shape))

    @property
    def block(self):
        """The data block, looked up from the source on first access."""
        if self._block is None:
            self._block = self._asdffile.blocks.get_block(self._source)
        return self._block

    @property
    def shape(self):
        """
        Shape of the array as a tuple.  The array is only built when no
        shape was recorded; a ``'*'`` entry is resolved from the block
        size.
        """
        if self._shape is None:
            return self.__array__().shape
        if '*' in self._shape:
            return tuple(self.get_actual_shape(
                self._shape, self._strides, self._dtype, len(self.block)))
        return tuple(self._shape)

    @property
    def dtype(self):
        """The recorded dtype, or the array's dtype once it is built."""
        if self._array is None:
            return self._dtype
        else:
            return self._array.dtype

    def __len__(self):
        if self._array is None:
            # Go through the ``shape`` property so that a '*' first
            # entry or a missing shape is resolved to an actual length.
            return self.shape[0]
        else:
            return len(self._array)

    def __getattr__(self, attr):
        # We need to ignore __array_struct__, or unicode arrays end up
        # getting "double casted" and upsized.  This also reduces the
        # number of array creations in the general case.
        if attr == '__array_struct__':
            raise AttributeError()
        return getattr(self._make_array(), attr)

    def __setitem__(self, *args):
        # This workaround appears to be necessary in order to avoid a segfault
        # in the case that array assignment causes an exception. The segfault
        # originates from the call to __repr__ inside the traceback report.
        try:
            self._make_array().__setitem__(*args)
        except Exception as e:
            self._array = None
            self._block_data_weakref = None
            raise e from None

    @classmethod
    def from_tree(cls, node, ctx):
        """
        Build an instance from a tree node.

        ``node`` is either a list of inline data or a dict with
        ``source`` or ``data`` plus optional ``shape``, ``datatype``,
        ``offset``, ``strides`` and ``mask``.  ``byteorder`` is required
        unless ``data`` is given.

        Raises
        ------
        ValueError
            If both ``source`` and ``data`` are present.
        TypeError
            If ``node`` is neither a list nor a dict.
        """
        if isinstance(node, list):
            return cls(node, None, None, None, None, None, None, ctx)

        elif isinstance(node, dict):
            source = node.get('source')
            data = node.get('data')
            # Compare against None explicitly: a source of 0 and empty
            # inline data ([]) are both falsy but still present.
            if source is not None and data is not None:
                raise ValueError(
                    "Both source and data may not be provided "
                    "at the same time")
            if data is not None:
                source = data
                # Inline data carries no byte order of its own.
                byteorder = sys.byteorder
            else:
                byteorder = node['byteorder']
            shape = node.get('shape', None)
            if 'datatype' in node:
                dtype = asdf_datatype_to_numpy_dtype(
                    node['datatype'], byteorder)
            else:
                dtype = None
            offset = node.get('offset', 0)
            strides = node.get('strides', None)
            mask = node.get('mask', None)

            return cls(source, shape, dtype, offset, strides, 'A', mask, ctx)

        raise TypeError("Invalid ndarray description.")

    @classmethod
    def reserve_blocks(cls, data, ctx):
        """Yield the block used by ``data``, creating one if needed."""
        # Find all of the used data buffers so we can add or rearrange
        # them if necessary
        if isinstance(data, np.ndarray):
            yield ctx.blocks.find_or_create_block_for_array(data, ctx)
        elif isinstance(data, NDArrayType):
            yield data.block

    @classmethod
    def to_tree(cls, data, ctx):
        """
        Build the tree representation (a dict) of ``data``.

        Inline arrays are written under ``data``; all others are written
        as a reference to their block under ``source`` together with
        ``byteorder`` and, when needed, ``offset`` and ``strides``.
        """
        # We do not want to encode strides for broadcasted arrays, so
        # those are copied into a contiguous array.  This has to happen
        # before the base and offset are looked up, otherwise the offset
        # would describe the original array instead of the copy.
        if not data.flags.c_contiguous and not all(data.strides):
            data = np.ascontiguousarray(data)

        base = util.get_array_base(data)
        shape = data.shape
        dtype = data.dtype
        offset = data.ctypes.data - base.ctypes.data
        strides = None if data.flags.c_contiguous else data.strides

        block = ctx.blocks.find_or_create_block_for_array(data, ctx)

        result = {}

        result['shape'] = list(shape)
        if block.array_storage == 'streamed':
            result['shape'][0] = '*'

        dtype, byteorder = numpy_dtype_to_asdf_datatype(
            dtype, include_byteorder=(block.array_storage != 'inline'))

        byteorder = block.override_byteorder(byteorder)

        if block.array_storage == 'inline':
            listdata = numpy_array_to_list(data)
            result['data'] = yamlutil.custom_tree_to_tagged_tree(
                listdata, ctx)
            result['datatype'] = dtype
        else:
            result['source'] = ctx.blocks.get_source(block)
            result['datatype'] = dtype
            result['byteorder'] = byteorder

            if offset > 0:
                result['offset'] = offset

            if strides is not None:
                result['strides'] = list(strides)

        if isinstance(data, ma.MaskedArray):
            if np.any(data.mask):
                if block.array_storage == 'inline':
                    # Keep the mask inline alongside inline data.
                    ctx.blocks.set_array_storage(ctx.blocks[data.mask], 'inline')
                result['mask'] = yamlutil.custom_tree_to_tagged_tree(
                    data.mask, ctx)

        return result

    @classmethod
    def _assert_equality(cls, old, new, func):
        """
        Compare ``old`` and ``new`` with ``func``, descending into
        structured arrays and comparing string arrays as unicode lists.

        Raises
        ------
        AssertionError
            If the arrays differ.
        """
        if old.dtype.fields:
            if not new.dtype.fields:
                # Raised explicitly so the check survives ``python -O``.
                raise AssertionError("arrays not equal")
            for a, b in zip(old, new):
                cls._assert_equality(a, b, func)
        else:
            old = old.__array__()
            new = new.__array__()
            if old.dtype.char in 'SU':
                if old.dtype.char == 'S':
                    old = old.astype('U')
                if new.dtype.char == 'S':
                    new = new.astype('U')
                old = old.tolist()
                new = new.tolist()
                if old != new:
                    raise AssertionError("arrays not equal")
            else:
                func(old, new)

    @classmethod
    def assert_equal(cls, old, new):
        """Assert that ``old`` and ``new`` are exactly equal."""
        from numpy.testing import assert_array_equal

        cls._assert_equality(old, new, assert_array_equal)

    @classmethod
    def assert_allclose(cls, old, new):
        """
        Assert that ``old`` and ``new`` are close; integer arrays are
        compared exactly.
        """
        from numpy.testing import assert_allclose, assert_array_equal

        if (old.dtype.kind in 'iu' and
            new.dtype.kind in 'iu'):
            cls._assert_equality(old, new, assert_array_equal)
        else:
            cls._assert_equality(old, new, assert_allclose)

    @classmethod
    def copy_to_new_asdf(cls, node, asdffile):
        """
        Return the plain array behind ``node`` for use in ``asdffile``,
        carrying over the array storage setting of the node's block.
        Nodes that are not `NDArrayType` are returned unchanged.
        """
        if isinstance(node, NDArrayType):
            array = node._make_array()
            asdffile.blocks.set_array_storage(asdffile.blocks[array],
                                              node.block.array_storage)
            return array
        return node


def _make_operation(name):
    """
    Build a method that forwards the special method ``name`` to the
    array returned by the instance's ``_make_array()``.
    """
    def __operation__(self, *args):
        forwarded = getattr(self._make_array(), name)
        return forwarded(*args)
    return __operation__


# Forward the special methods below to the underlying array.  They need
# to be set on the class itself, because implicit special method lookup
# bypasses ``__getattr__``.
for op in (
        # unary operators
        '__neg__', '__pos__', '__abs__', '__invert__',
        # conversions
        '__complex__', '__int__', '__long__', '__float__',
        '__oct__', '__hex__',
        # comparisons
        '__lt__', '__le__', '__eq__', '__ne__', '__gt__', '__ge__',
        '__cmp__', '__rcmp__',
        # binary operators
        '__add__', '__sub__', '__mul__', '__floordiv__', '__mod__',
        '__divmod__', '__pow__', '__lshift__', '__rshift__',
        '__and__', '__xor__', '__or__', '__div__', '__truediv__',
        # reflected binary operators
        '__radd__', '__rsub__', '__rmul__', '__rdiv__',
        '__rtruediv__', '__rfloordiv__', '__rmod__', '__rdivmod__',
        '__rpow__', '__rlshift__', '__rrshift__', '__rand__',
        '__rxor__', '__ror__',
        # in-place operators
        '__iadd__', '__isub__', '__imul__', '__idiv__',
        '__itruediv__', '__ifloordiv__', '__imod__', '__ipow__',
        '__ilshift__', '__irshift__', '__iand__', '__ixor__',
        '__ior__',
        # container protocol
        '__getitem__', '__delitem__', '__contains__'):
    setattr(NDArrayType, op, _make_operation(op))


def _get_ndim(instance):
    """
    Return the number of dimensions of ``instance``.

    ``instance`` may be inline list data, a dict with ``shape`` or
    ``data``, or an array.  `None` is returned when the number of
    dimensions can not be determined.
    """
    if isinstance(instance, list):
        return inline_data_asarray(instance).ndim

    if isinstance(instance, dict):
        # A declared shape wins over inspecting the inline data.
        if 'shape' in instance:
            return len(instance['shape'])
        if 'data' in instance:
            return inline_data_asarray(instance['data']).ndim
        return None

    if isinstance(instance, (np.ndarray, NDArrayType)):
        return len(instance.shape)

    return None


def validate_ndim(validator, ndim, instance, schema):
    """
    Schema validator for ``ndim``: yield a `ValidationError` unless
    ``instance`` has exactly ``ndim`` dimensions.
    """
    in_ndim = _get_ndim(instance)

    if in_ndim == ndim:
        return

    message = "Wrong number of dimensions: Expected {0}, got {1}".format(
        ndim, in_ndim)
    yield ValidationError(message, instance=repr(instance))


def validate_max_ndim(validator, max_ndim, instance, schema):
    """
    Schema validator for ``max_ndim``: yield a `ValidationError` when
    ``instance`` has more than ``max_ndim`` dimensions.
    """
    in_ndim = _get_ndim(instance)

    if not (in_ndim > max_ndim):
        return

    message = "Wrong number of dimensions: Expected max of {0}, got {1}".format(
        max_ndim, in_ndim)
    yield ValidationError(message, instance=repr(instance))


def validate_datatype(validator, datatype, instance, schema):
    """
    Schema validator for ``datatype``.

    The datatype of ``instance`` is taken from its ``datatype`` entry
    (dict), from its inline data (list, or dict with ``data``) or from
    its dtype (array).  Validation passes when it equals ``datatype``
    or, unless the schema sets ``exact_datatype``, when it can be safely
    cast to ``datatype``.  Structured datatypes are compared field by
    field, in order.

    Raises
    ------
    ValidationError
        If ``instance`` is not an array description at all.  All other
        failures are yielded rather than raised.
    """
    if isinstance(instance, list):
        array = inline_data_asarray(instance)
        in_datatype, _ = numpy_dtype_to_asdf_datatype(array.dtype)
    elif isinstance(instance, dict):
        if 'datatype' in instance:
            in_datatype = instance['datatype']
        elif 'data' in instance:
            array = inline_data_asarray(instance['data'])
            in_datatype, _ = numpy_dtype_to_asdf_datatype(array.dtype)
        else:
            raise ValidationError("Not an array")
    elif isinstance(instance, (np.ndarray, NDArrayType)):
        in_datatype, _ = numpy_dtype_to_asdf_datatype(instance.dtype)
    else:
        raise ValidationError("Not an array")

    if datatype == in_datatype:
        return

    if schema.get('exact_datatype', False):
        yield ValidationError(
            "Expected datatype '{0}', got '{1}'".format(
                datatype, in_datatype))
        # The mismatch is already reported; the cast checks below would
        # only repeat it.
        return

    np_datatype = asdf_datatype_to_numpy_dtype(datatype)
    np_in_datatype = asdf_datatype_to_numpy_dtype(in_datatype)

    if not np_datatype.fields:
        if np_in_datatype.fields:
            yield ValidationError(
                "Expected scalar datatype '{0}', got '{1}'".format(
                    datatype, in_datatype))
            return

        if not np.can_cast(np_in_datatype, np_datatype, 'safe'):
            yield ValidationError(
                "Can not safely cast from '{0}' to '{1}' ".format(
                    in_datatype, datatype))

    else:
        if not np_in_datatype.fields:
            yield ValidationError(
                "Expected structured datatype '{0}', got '{1}'".format(
                    datatype, in_datatype))
            # Without fields there is nothing to count or compare;
            # continuing would fail on len(None).
            return

        if len(np_in_datatype.fields) != len(np_datatype.fields):
            yield ValidationError(
                "Mismatch in number of columns: "
                "Expected {0}, got {1}".format(
                    len(datatype), len(in_datatype)))
            # The per-field comparison below indexes both dtypes by
            # position and requires equal field counts.
            return

        for i in range(len(np_datatype.fields)):
            in_type = np_in_datatype[i]
            out_type = np_datatype[i]
            if not np.can_cast(in_type, out_type, 'safe'):
                yield ValidationError(
                    "Can not safely cast to expected datatype: "
                    "Expected {0}, got {1}".format(
                        numpy_dtype_to_asdf_datatype(out_type)[0],
                        numpy_dtype_to_asdf_datatype(in_type)[0]))


NDArrayType.validators = {
    'ndim': validate_ndim,
    'max_ndim': validate_max_ndim,
    'datatype': validate_datatype
}