36 #ifndef VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX
37 #define VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX
40 #include "array_vector.hxx"
41 #include "python_utility.hxx"
42 #include "axistags.hxx"
// Returns the Python type object that vigra uses as its standard array type.
// Imports the Python module "vigra" and reads its "standardArrayType"
// attribute. numpy's PyArray_Type is passed to pythonGetAttr() as the default
// value (presumably returned when the attribute is unavailable; confirm
// against pythonGetAttr()).
// NOTE(review): this copy of the header is a lossy extract. Every line
// carries a leading original line number, and many lines are missing
// (braces and some statements, e.g. original lines 50 and 53-54 here), so
// it does not compile as-is. Comments describe only the visible lines.
// NOTE(review): this is a non-template function defined in a header. If no
// `inline` specifier exists on a missing line, including the header from
// several translation units would violate the one-definition rule. Confirm.
49 python_ptr getArrayTypeObject()
51 python_ptr arraytype((PyObject*)&PyArray_Type);
52 python_ptr vigra(PyImport_ImportModule(
"vigra"));
55 return pythonGetAttr(vigra,
"standardArrayType", arraytype);
// Returns the "defaultOrder" attribute of the standard array type
// (obtained from getArrayTypeObject()) as a std::string.
// @param defaultValue  passed to pythonGetAttr() as the fallback value
//                      (presumably used when the attribute is missing).
//                      Defaults to "C".
59 std::string defaultOrder(std::string defaultValue =
"C")
61 python_ptr arraytype = getArrayTypeObject();
62 return pythonGetAttr(arraytype,
"defaultOrder", defaultValue);
// Creates default axistags by calling the Python method
// arraytype.defaultAxistags(ndim, order) on the standard array type.
// @param ndim   number of axes, passed to Python as an int.
// @param order  axis-order string. It may be replaced by defaultOrder(); the
//               guarding condition (presumably `order == ""`) is on a line
//               not visible in this copy.
// The call result is held in `axistags` (python_ptr::keep_count). The error
// check and the return statement are on lines not visible here.
// NOTE(review): PyString_FromString / PyInt_FromLong are Python 2 C-API
// names. Presumably a compatibility header maps them for Python 3; confirm.
66 python_ptr defaultAxistags(
int ndim, std::string order =
"")
69 order = defaultOrder();
70 python_ptr arraytype = getArrayTypeObject();
71 python_ptr func(PyString_FromString(
"defaultAxistags"), python_ptr::keep_count);
72 python_ptr d(PyInt_FromLong(ndim), python_ptr::keep_count);
73 python_ptr o(PyString_FromString(order.c_str()), python_ptr::keep_count);
74 python_ptr axistags(PyObject_CallMethodObjArgs(arraytype, func.get(), d.get(), o.get(), NULL),
75 python_ptr::keep_count);
// Creates axistags for `ndim` axes by calling the Python method
// arraytype._empty_axistags(ndim) on the standard array type.
// @param ndim  number of axes, passed to Python as an int.
// The error check and the return statement are on lines not visible in
// this copy.
83 python_ptr emptyAxistags(
int ndim)
85 python_ptr arraytype = getArrayTypeObject();
86 python_ptr func(PyString_FromString(
"_empty_axistags"), python_ptr::keep_count);
87 python_ptr d(PyInt_FromLong(ndim), python_ptr::keep_count);
88 python_ptr axistags(PyObject_CallMethodObjArgs(arraytype, func.get(), d.get(), NULL),
89 python_ptr::keep_count);
// Calls the Python method `object.<name>(type)` and converts the returned
// sequence of ints into an ArrayVector<npy_intp> (the local `res`).
// @param permute       output. The transfer from `res` into `permute` is on
//                      lines not visible in this copy (presumably at the end).
// @param object        Python object whose method is called (an AxisTags
//                      instance at the call sites in this file).
// @param name          name of the Python method, e.g.
//                      "permutationToNormalOrder".
// @param type          AxisInfo::AxisType, passed to Python as an int.
// @param ignoreErrors  when true and the call returned null, the handling on
//                      the missing original lines 107-110 applies (presumably
//                      clear the error and return). Otherwise a failed call
//                      is passed to pythonToCppException().
// Errors: if the result is not a sequence, or an element fails the check on
// the missing original lines 126-129, a Python ValueError is set and
// pythonToCppException(false) is called (presumably throwing).
// NOTE(review): PySequence_Length() returns -1 on failure. No check of that
// value is visible before it sizes `res`.
98 getAxisPermutationImpl(ArrayVector<npy_intp> & permute,
99 python_ptr
object,
const char * name,
100 AxisInfo::AxisType type,
bool ignoreErrors)
102 python_ptr func(PyString_FromString(name), python_ptr::keep_count);
103 python_ptr t(PyInt_FromLong((
long)type), python_ptr::keep_count);
104 python_ptr permutation(PyObject_CallMethodObjArgs(
object, func.get(), t.get(), NULL),
105 python_ptr::keep_count);
106 if(!permutation && ignoreErrors)
111 pythonToCppException(permutation);
113 if(!PySequence_Check(permutation))
117 std::string message = std::string(name) +
"() did not return a sequence.";
118 PyErr_SetString(PyExc_ValueError, message.c_str());
119 pythonToCppException(
false);
122 ArrayVector<npy_intp> res(PySequence_Length(permutation));
123 for(
int k=0; k<(int)res.size(); ++k)
125 python_ptr i(PySequence_GetItem(permutation, k), python_ptr::keep_count);
130 std::string message = std::string(name) +
"() did not return a sequence of int.";
131 PyErr_SetString(PyExc_ValueError, message.c_str());
132 pythonToCppException(
false);
134 res[k] = PyInt_AsLong(i);
// Convenience overload: forwards to the version taking an
// AxisInfo::AxisType, with the type fixed to AxisInfo::AllAxes.
141 getAxisPermutationImpl(ArrayVector<npy_intp> & permute,
142 python_ptr
object,
const char * name,
bool ignoreErrors)
144 getAxisPermutationImpl(permute,
object, name, AxisInfo::AllAxes, ignoreErrors);
// Raw Python object pointer type (PyObject *).
163 typedef PyObject * pointer;
// Wraps a Python AxisTags object.
// @param tags        Python object to wrap. If it fails PySequence_Check(), a
//                    TypeError is set and pythonToCppException(false) is
//                    called (presumably throwing). An empty sequence is
//                    special-cased; that handling (original lines 179-184)
//                    is not visible.
// @param createCopy  defaults to false. The visible code stores the result of
//                    tags.__copy__(). Presumably that path runs only when
//                    createCopy is true and `tags` is stored directly
//                    otherwise; the selecting lines are not visible.
// NOTE(review): original lines 168-171 are missing. Whether a null `tags`
// (the default argument) is rejected before PySequence_Check() cannot be
// seen here.
167 PyAxisTags(python_ptr tags = python_ptr(),
bool createCopy =
false)
172 if(!PySequence_Check(tags))
174 PyErr_SetString(PyExc_TypeError,
175 "PyAxisTags(tags): tags argument must have type 'AxisTags'.");
176 pythonToCppException(
false);
178 else if(PySequence_Length(tags) == 0)
185 python_ptr func(PyString_FromString(
"__copy__"), python_ptr::keep_count);
186 axistags = python_ptr(PyObject_CallMethodObjArgs(tags, func.get(), NULL),
187 python_ptr::keep_count);
// Copy constructor. Because createCopy has a default, this is the class's
// copy constructor and is used by copy-initialization such as
// `PyAxisTags a = b;`.
// @param other       tags to copy from.
// @param createCopy  defaults to false. The visible statements either assign
//                    the result of other.axistags.__copy__() (a new Python
//                    object) or share other.axistags directly. The condition
//                    choosing between them (presumably createCopy) is on
//                    lines not visible in this copy.
195 PyAxisTags(PyAxisTags
const & other,
bool createCopy =
false)
201 python_ptr func(PyString_FromString(
"__copy__"), python_ptr::keep_count);
202 axistags = python_ptr(PyObject_CallMethodObjArgs(other.axistags, func.get(), NULL),
203 python_ptr::keep_count);
207 axistags = other.axistags;
// Creates axistags for `ndim` axes, using either
// detail::defaultAxistags(ndim, order) or detail::emptyAxistags(ndim).
// The condition choosing between them (presumably on `order`) is on a line
// not visible in this copy.
// @param ndim   number of axes.
// @param order  axis-order string; defaults to "".
211 PyAxisTags(
int ndim, std::string
const & order =
"")
214 axistags = detail::defaultAxistags(ndim, order);
216 axistags = detail::emptyAxistags(ndim);
// Fragment, presumably of size(): the visible branch of a conditional yields
// PySequence_Length(axistags). The signature and the other branch are on
// lines not visible in this copy (presumably 0 when no axistags are held;
// confirm).
222 ? PySequence_Length(axistags)
// Returns the Python attribute `axistags.channelIndex` via pythonGetAttr().
// `defaultVal` is pythonGetAttr()'s fallback value (presumably used when the
// attribute is unavailable).
226 long channelIndex(
long defaultVal)
const
228 return pythonGetAttr(axistags,
"channelIndex", defaultVal);
// Index of the channel axis, with size() as the fallback value. A result
// equal to size() therefore means "no channel axis" (hasChannelAxis() tests
// exactly that).
231 long channelIndex()
const
233 return channelIndex(size());
// True unless channelIndex() equals size(), the "no channel axis" fallback
// value.
236 bool hasChannelAxis()
const
238 return channelIndex() != size();
// Returns the Python attribute `axistags.innerNonchannelIndex` via
// pythonGetAttr(). `defaultVal` is the fallback value (presumably used when
// the attribute is unavailable).
241 long innerNonchannelIndex(
long defaultVal)
const
243 return pythonGetAttr(axistags,
"innerNonchannelIndex", defaultVal);
// innerNonchannelIndex() with size() as the fallback value.
246 long innerNonchannelIndex()
const
248 return innerNonchannelIndex(size());
// Calls axistags.setChannelDescription(description) in Python. A failed call
// is passed to pythonToCppException(res).
// NOTE(review): original lines 252-254 are not visible in this copy
// (presumably a guard for missing axistags).
251 void setChannelDescription(std::string
const & description)
255 python_ptr d(PyString_FromString(description.c_str()), python_ptr::keep_count);
256 python_ptr func(PyString_FromString(
"setChannelDescription"), python_ptr::keep_count);
257 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), d.get(), NULL),
258 python_ptr::keep_count);
259 pythonToCppException(res);
// Returns the resolution of axis `index` by calling
// axistags.resolution(index) in Python.
// A failed call is passed to pythonToCppException(res). If the result is not
// a Python float (PyFloat_Check fails, which also rejects Python ints), a
// TypeError is set and pythonToCppException(false) is called.
// NOTE(review): original lines 263-265 are not visible in this copy
// (presumably a default for missing axistags).
262 double resolution(
long index)
266 python_ptr func(PyString_FromString(
"resolution"), python_ptr::keep_count);
267 python_ptr i(PyInt_FromLong(index), python_ptr::keep_count);
268 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), NULL),
269 python_ptr::keep_count);
270 pythonToCppException(res);
271 if(!PyFloat_Check(res))
273 PyErr_SetString(PyExc_TypeError,
"AxisTags.resolution() did not return float.");
274 pythonToCppException(
false);
276 return PyFloat_AsDouble(res);
// Calls axistags.setResolution(index, resolution) in Python. A failed call is
// passed to pythonToCppException(res).
// NOTE(review): original lines 280-282 are not visible in this copy.
279 void setResolution(
long index,
double resolution)
283 python_ptr func(PyString_FromString(
"setResolution"), python_ptr::keep_count);
284 python_ptr i(PyInt_FromLong(index), python_ptr::keep_count);
285 python_ptr r(PyFloat_FromDouble(resolution), python_ptr::keep_count);
286 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), r.get(), NULL),
287 python_ptr::keep_count);
288 pythonToCppException(res);
// Calls axistags.scaleResolution(index, factor) in Python. A failed call is
// passed to pythonToCppException(res).
// NOTE(review): original lines 292-294 are not visible in this copy.
291 void scaleResolution(
long index,
double factor)
295 python_ptr func(PyString_FromString(
"scaleResolution"), python_ptr::keep_count);
296 python_ptr i(PyInt_FromLong(index), python_ptr::keep_count);
297 python_ptr f(PyFloat_FromDouble(factor), python_ptr::keep_count);
298 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), f.get(), NULL),
299 python_ptr::keep_count);
300 pythonToCppException(res);
// Calls axistags.toFrequencyDomain(index, size) in Python when sign == 1, and
// axistags.fromFrequencyDomain(index, size) for any other sign value.
// A failed call is passed to pythonToCppException(res).
// @param index  axis index in the tags.
// @param size   axis length, passed to Python as an int.
// @param sign   1 (default) selects the forward direction.
// NOTE(review): original lines 304-306 are not visible in this copy.
303 void toFrequencyDomain(
long index,
int size,
int sign = 1)
307 python_ptr func(
sign == 1
308 ? PyString_FromString(
"toFrequencyDomain")
309 : PyString_FromString(
"fromFrequencyDomain"),
310 python_ptr::keep_count);
311 python_ptr i(PyInt_FromLong(index), python_ptr::keep_count);
312 python_ptr s(PyInt_FromLong(size), python_ptr::keep_count);
313 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), s.get(), NULL),
314 python_ptr::keep_count);
315 pythonToCppException(res);
// Equivalent to toFrequencyDomain(index, size, -1).
318 void fromFrequencyDomain(
long index,
int size)
320 toFrequencyDomain(index, size, -1);
// Permutation returned by the Python method
// axistags.permutationToNormalOrder, fetched through the
// getAxisPermutationImpl() overload that passes AxisInfo::AllAxes.
// @param ignoreErrors  forwarded to getAxisPermutationImpl().
// The return statement (presumably `return permute;`) is on a line not
// visible in this copy.
323 ArrayVector<npy_intp>
324 permutationToNormalOrder(
bool ignoreErrors =
false)
const
326 ArrayVector<npy_intp> permute;
327 detail::getAxisPermutationImpl(permute, axistags,
"permutationToNormalOrder", ignoreErrors);
// Permutation returned by axistags.permutationToNormalOrder(types),
// restricted to the given axis types.
// @param types         AxisInfo::AxisType forwarded to Python.
// @param ignoreErrors  forwarded to getAxisPermutationImpl().
// The return statement is on a line not visible in this copy.
331 ArrayVector<npy_intp>
332 permutationToNormalOrder(AxisInfo::AxisType types,
bool ignoreErrors =
false)
const
334 ArrayVector<npy_intp> permute;
335 detail::getAxisPermutationImpl(permute, axistags,
336 "permutationToNormalOrder", types, ignoreErrors);
// Permutation returned by the Python method
// axistags.permutationFromNormalOrder, fetched through the
// getAxisPermutationImpl() overload that passes AxisInfo::AllAxes.
// @param ignoreErrors  forwarded to getAxisPermutationImpl().
// The return statement is on a line not visible in this copy.
340 ArrayVector<npy_intp>
341 permutationFromNormalOrder(
bool ignoreErrors =
false)
const
343 ArrayVector<npy_intp> permute;
344 detail::getAxisPermutationImpl(permute, axistags,
345 "permutationFromNormalOrder", ignoreErrors);
// Permutation returned by axistags.permutationFromNormalOrder(types),
// restricted to the given axis types.
// @param types         AxisInfo::AxisType forwarded to Python.
// @param ignoreErrors  forwarded to getAxisPermutationImpl().
// The return statement is on a line not visible in this copy.
349 ArrayVector<npy_intp>
350 permutationFromNormalOrder(AxisInfo::AxisType types,
bool ignoreErrors =
false)
const
352 ArrayVector<npy_intp> permute;
353 detail::getAxisPermutationImpl(permute, axistags,
354 "permutationFromNormalOrder", types, ignoreErrors);
// Calls axistags.dropChannelAxis() in Python. A failed call is passed to
// pythonToCppException(res).
// NOTE(review): original lines 359-361 are not visible in this copy.
358 void dropChannelAxis()
362 python_ptr func(PyString_FromString(
"dropChannelAxis"),
363 python_ptr::keep_count);
364 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), NULL),
365 python_ptr::keep_count);
366 pythonToCppException(res);
// Calls axistags.insertChannelAxis() in Python. A failed call is passed to
// pythonToCppException(res).
// NOTE(review): original lines 370-372 are not visible in this copy.
369 void insertChannelAxis()
373 python_ptr func(PyString_FromString(
"insertChannelAxis"),
374 python_ptr::keep_count);
375 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), NULL),
376 python_ptr::keep_count);
377 pythonToCppException(res);
// Fragment of an accessor whose signature is on a line not visible in this
// copy. It returns the raw pointer held by `axistags` (axistags.get()).
// Presumably this is the conversion that lets a PyAxisTags be tested in a
// condition such as `if(tagged_shape.axistags)`; confirm.
382 return axistags.get();
// Logical negation. The body is on lines not visible in this copy
// (presumably true when no Python object is held).
385 bool operator!()
const
// Position of the channel entry within `shape`. compatible() treats `last` as
// excluding the final entry from the non-channel range. `first` presumably
// excludes shape[0], and `none` presumably means there is no channel entry.
400 enum ChannelAxis { first, last, none };
// `shape` is the current shape; `original_shape` starts as a copy of it.
// scaleAxisResolution() compares the two to rescale axis resolutions.
// NOTE(review): the `axistags` member used throughout (original line 403) is
// not visible in this copy.
402 ArrayVector<npy_intp> shape, original_shape;
404 ChannelAxis channelAxis;
// Written to the axistags by finalizeTaggedShape() when non-empty.
405 std::string channelDescription;
// Constructs from a fixed-size shape and axistags. `shape` and
// `original_shape` both start as copies of `sh`. The remaining initializers
// (presumably axistags and channelAxis) are on lines not visible in this
// copy.
413 template <
class U,
int N>
414 TaggedShape(TinyVector<U, N>
const & sh, PyAxisTags tags)
415 : shape(sh.begin(), sh.end()),
416 original_shape(sh.begin(), sh.end()),
// Constructs from a dynamic-size shape and axistags. `shape` and
// `original_shape` both start as copies of `sh`. The template header
// (original line 421) and the remaining initializers are not visible in
// this copy.
422 TaggedShape(ArrayVector<T>
const & sh, PyAxisTags tags)
423 : shape(sh.begin(), sh.end()),
424 original_shape(sh.begin(), sh.end()),
// Constructs from a fixed-size shape without axistags. `shape` and
// `original_shape` both start as copies of `sh`. The remaining initializers
// are on lines not visible in this copy.
429 template <
class U,
int N>
430 explicit TaggedShape(TinyVector<U, N>
const & sh)
431 : shape(sh.begin(), sh.end()),
432 original_shape(sh.begin(), sh.end()),
// Constructs from a dynamic-size shape without axistags. `shape` and
// `original_shape` both start as copies of `sh`. The template header and the
// remaining initializers are on lines not visible in this copy.
437 explicit TaggedShape(ArrayVector<T>
const & sh)
438 : shape(sh.begin(), sh.end()),
439 original_shape(sh.begin(), sh.end()),
// Overwrites the entries shape[start .. start+N) with `sh`, where [start, stop)
// is the non-channel range derived from channelAxis. This presumably leaves
// the channel entry untouched.
// Precondition: N == stop - start, or the shape is currently empty
// (size() == 0). Otherwise it fails with
// "TaggedShape.resize(): size mismatch.".
// The start/stop ternary tails, the handling of the empty case and the
// return statement are on lines not visible in this copy.
443 template <
class U,
int N>
444 TaggedShape & resize(TinyVector<U, N>
const & sh)
446 int start = channelAxis == first
449 stop = channelAxis == last
453 vigra_precondition(N == stop - start || size() == 0,
454 "TaggedShape.resize(): size mismatch.");
459 for(
int k=0; k<N; ++k)
460 shape[k+start] = sh[k];
// Convenience resize() overloads taking 1 to 4 extents. Each forwards to the
// TinyVector<MultiArrayIndex, N> version. Their signatures are on lines not
// visible in this copy.
467 return resize(TinyVector<MultiArrayIndex, 1>(v1));
472 return resize(TinyVector<MultiArrayIndex, 2>(v1, v2));
477 return resize(TinyVector<MultiArrayIndex, 3>(v1, v2, v3));
483 return resize(TinyVector<MultiArrayIndex, 4>(v1, v2, v3, v4));
// Mutable access to shape entry `i`. The body is not visible in this copy
// (presumably `return shape[i];`).
486 npy_intp & operator[](
int i)
// Read-only access to shape entry `i`, returned by value. The body is not
// visible in this copy (presumably `return shape[i];`).
491 npy_intp operator[](
int i)
const
// Number of shape entries. The body is not visible in this copy (presumably
// shape.size(); rotateToNormalOrder() uses it as the length of `shape`).
496 unsigned int size()
const
// Fragment of a member whose signature (original lines 498-502) is not
// visible in this copy. It computes the [start, stop) range of non-channel
// entries from channelAxis and loops over it. The ternary tails, the loop
// body and the rest of the function are missing.
503 int start = channelAxis == first
506 stop = channelAxis == last
509 for(
int k=start; k<stop; ++k)
// Fragment of a second member whose signature (original lines 517-521) is
// not visible in this copy. It uses the same [start, stop) non-channel range
// and loop header as the preceding fragment. The loop body and the rest of
// the function are missing.
522 int start = channelAxis == first
525 stop = channelAxis == last
528 for(
int k=start; k<stop; ++k)
// If axistags are present and the channel entry is last, moves that entry to
// the front of both `shape` and `original_shape`, shifting every other entry
// one position to the right.
// Updating channelAxis afterwards (presumably to `first`) happens on lines
// not visible in this copy.
// Note: the local `channelCount` shadows the member function channelCount().
// NOTE(review): ndim is taken from size() and is also used to index
// `original_shape`. This assumes both vectors have the same length; confirm.
534 void rotateToNormalOrder()
536 if(axistags && channelAxis == last)
538 int ndim = (int)size();
540 npy_intp channelCount = shape[ndim-1];
541 for(
int k=ndim-1; k>0; --k)
542 shape[k] = shape[k-1];
543 shape[0] = channelCount;
545 channelCount = original_shape[ndim-1];
546 for(
int k=ndim-1; k>0; --k)
547 original_shape[k] = original_shape[k-1];
548 original_shape[0] = channelCount;
// Stores `description`; finalizeTaggedShape() later writes it to the
// axistags if it is non-empty. Original lines 555-557 and the return
// statement are not visible in this copy.
554 TaggedShape & setChannelDescription(std::string
const & description)
558 channelDescription = description;
// The body is on lines not visible in this copy (presumably it sets
// channelAxis = last and returns *this).
562 TaggedShape & setChannelIndexLast()
// Reorders the non-channel shape entries by permutation `p`
// (new entry k = old entry p[k]) and moves the matching axis resolutions in
// the axistags the same way.
// @param p  permutation of length N. N must equal the number of non-channel
//           axes in the tags (ntags - tstart); otherwise it fails with
//           "TaggedShape.transposeShape(): size mismatch.".
// tstart and sstart are the offsets of the first non-channel axis in the tags
// and in `shape`. Their ternary tails are on lines not visible in this copy.
// NOTE(review): newAxistags is built via PyAxisTags(python_ptr, createCopy)
// with createCopy left at its default (false). If that does not copy, then
// newAxistags and axistags refer to the same Python object, and the
// setResolution() calls overwrite resolutions that later iterations still
// read through axistags.resolution(...). Confirm whether an explicit copy
// (createCopy = true) is intended.
// NOTE(review): any separate branch for shapes without axistags is not
// visible in this copy.
570 template <
class U,
int N>
571 TaggedShape & transposeShape(TinyVector<U, N>
const & p)
573 int ntags = axistags.size();
574 ArrayVector<npy_intp> permute = axistags.permutationToNormalOrder();
576 int tstart = (axistags.channelIndex(ntags) < ntags)
579 int sstart = (channelAxis == first)
582 int ndim = ntags - tstart;
584 vigra_precondition(N == ndim,
585 "TaggedShape.transposeShape(): size mismatch.");
587 PyAxisTags newAxistags(axistags.axistags);
588 for(
int k=0; k<ndim; ++k)
590 original_shape[k+sstart] = shape[p[k]+sstart];
591 newAxistags.setResolution(permute[k+tstart], axistags.resolution(permute[p[k]+tstart]));
593 shape = original_shape;
594 axistags = newAxistags;
// Calls PyAxisTags::toFrequencyDomain() for each non-channel axis. Each call
// passes the axis's tag index (mapped through permutationToNormalOrder()) and
// its current length from `shape`.
// @param sign  1 (default) for the forward direction. Any other value makes
//              PyAxisTags::toFrequencyDomain() call fromFrequencyDomain in
//              Python.
// Note: the local `size` shadows the member function size() in this body.
// The tstart/sstart/send ternary tails and the return statement are on lines
// not visible in this copy.
599 TaggedShape & toFrequencyDomain(
int sign = 1)
601 int ntags = axistags.size();
603 ArrayVector<npy_intp> permute = axistags.permutationToNormalOrder();
605 int tstart = (axistags.channelIndex(ntags) < ntags)
608 int sstart = (channelAxis == first)
611 int send = (channelAxis == last)
614 int size = send - sstart;
616 for(
int k=0; k<size; ++k)
618 axistags.toFrequencyDomain(permute[k+tstart], shape[k+sstart],
sign);
// Equivalent to toFrequencyDomain(-1).
624 TaggedShape & fromFrequencyDomain()
626 return toFrequencyDomain(-1);
// Checks whether this shape and `other` agree. Both must have the same
// channelCount(), the same number of non-channel entries, and equal
// non-channel extents in the same order. For a shape whose channelAxis is
// `last`, the non-channel range stops before the final entry.
// The early exits (presumably `return false;`) and the final return are on
// lines not visible in this copy.
629 bool compatible(TaggedShape
const & other)
const
631 if(channelCount() != other.channelCount())
634 int start = channelAxis == first
637 stop = channelAxis == last
640 int ostart = other.channelAxis == first
643 ostop = other.channelAxis == last
644 ? (int)other.size()-1
647 int len = stop - start;
648 if(len != ostop - ostart)
651 for(
int k=0; k<len; ++k)
652 if(shape[k+start] != other.shape[k+ostart])
// Presumably sets the channel entry of the shape to `count`.
// The visible statements do one of the following:
//   (a) erase the first entry of both `shape` and `original_shape`;
//   (b) set the last entry of `shape` to `count`;
//   (c) pop the last entry of `original_shape`;
//   (d) append `count` to both vectors.
// The conditions selecting among them, any updates of channelAxis, and the
// return statement are on lines not visible in this copy.
657 TaggedShape & setChannelCount(
int count)
668 shape.erase(shape.begin());
669 original_shape.erase(original_shape.begin());
676 shape[size()-1] = count;
681 original_shape.pop_back();
688 shape.push_back(count);
689 original_shape.push_back(count);
// Number of channels. Only one branch is visible: it returns the last shape
// entry (presumably when channelAxis == last). The other branches are on
// lines not visible in this copy.
697 int channelCount()
const
704 return shape[size()-1];
// Rescales the resolutions in tagged_shape.axistags for each axis whose
// extent differs between `original_shape` and `shape`, using
// factor = (original - 1) / (new - 1).
// Presumably returns early when the two shapes have different lengths, and
// skips axes whose extent is unchanged. The statements after those `if`s are
// not visible in this copy. `sk` is declared on a missing line (presumably
// k + sstart).
// NOTE(review): if an axis is resized to extent 1 from a different extent,
// the divisor is 0.0 and the factor becomes infinite. Confirm whether that
// case needs a guard.
// NOTE(review): this free function is defined in a header. An `inline`
// specifier, if present, is on a line not visible in this copy.
712 void scaleAxisResolution(TaggedShape & tagged_shape)
714 if(tagged_shape.size() != tagged_shape.original_shape.size())
717 int ntags = tagged_shape.axistags.size();
719 ArrayVector<npy_intp> permute = tagged_shape.axistags.permutationToNormalOrder();
721 int tstart = (tagged_shape.axistags.channelIndex(ntags) < ntags)
724 int sstart = (tagged_shape.channelAxis == TaggedShape::first)
727 int size = (int)tagged_shape.size() - sstart;
729 for(
int k=0; k<size; ++k)
732 if(tagged_shape.shape[sk] == tagged_shape.original_shape[sk])
734 double factor = (tagged_shape.original_shape[sk] - 1.0) / (tagged_shape.shape[sk] - 1.0);
735 tagged_shape.axistags.scaleResolution(permute[k+tstart], factor);
// Makes the number of entries in tagged_shape.shape consistent with the
// number of axes in its axistags, dropping or inserting a channel axis.
// Visible structure (several conditions and else-branches are on lines not
// visible in this copy, so confirm the exact case split):
//  - channelAxis == none and the tags have no channel axis
//    (channelIndex == ntags): requires ndim == ntags.
//  - channelAxis == none otherwise: may call axistags.dropChannelAxis(), and
//    requires ndim == ntags.
//  - a channel entry is present and the tags have no channel axis: requires
//    ndim == ntags+1, then either erases the first shape entry or calls
//    axistags.insertChannelAxis().
//  - the remaining case requires ndim == ntags.
// Precondition failures report "constructArray(): ...", presumably the
// public entry point that reaches this function.
// NOTE(review): `axistags` is copy-initialized from tagged_shape.axistags with
// createCopy defaulted to false, so it presumably shares the Python object.
// dropChannelAxis() and insertChannelAxis() would then also modify
// tagged_shape.axistags. Confirm this is intended.
740 void unifyTaggedShapeSize(TaggedShape & tagged_shape)
742 PyAxisTags axistags = tagged_shape.axistags;
743 ArrayVector<npy_intp> & shape = tagged_shape.shape;
745 int ndim = (int)shape.size();
746 int ntags = axistags.size();
748 long channelIndex = axistags.channelIndex();
750 if(tagged_shape.channelAxis == TaggedShape::none)
753 if(channelIndex == ntags)
757 vigra_precondition(ndim == ntags,
758 "constructArray(): size mismatch between shape and axistags.");
768 axistags.dropChannelAxis();
772 vigra_precondition(ndim == ntags,
773 "constructArray(): size mismatch between shape and axistags.");
780 if(channelIndex == ntags)
784 vigra_precondition(ndim == ntags+1,
785 "constructArray(): size mismatch between shape and axistags.");
791 shape.erase(shape.begin());
798 axistags.insertChannelAxis();
805 vigra_precondition(ndim == ntags,
806 "constructArray(): size mismatch between shape and axistags.");
// Normalizes a TaggedShape (channel position, axis resolutions, shape length
// versus axistags) and returns the resulting shape.
// If axistags are present, the channel entry is first rotated to the front
// (rotateToNormalOrder()). scaleAxisResolution() and unifyTaggedShapeSize()
// are then applied; whether they sit inside the `if(axistags)` block cannot
// be told from this copy (original lines 815-825 are partly missing).
// Finally, a non-empty channelDescription is written to the axistags.
// @return a copy of tagged_shape.shape.
// NOTE(review): this free function is defined in a header. An `inline`
// specifier, if present, is on a line not visible in this copy.
812 ArrayVector<npy_intp> finalizeTaggedShape(TaggedShape & tagged_shape)
814 if(tagged_shape.axistags)
816 tagged_shape.rotateToNormalOrder();
820 scaleAxisResolution(tagged_shape);
824 unifyTaggedShapeSize(tagged_shape);
826 if(tagged_shape.channelDescription !=
"")
827 tagged_shape.axistags.setChannelDescription(tagged_shape.channelDescription);
829 return tagged_shape.shape;
834 #endif // VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX