FLAME  devel
 All Classes Functions Variables Typedefs Enumerations Pages
modstate.cpp
1 
2 #include <string>
3 #include <sstream>
4 
5 #include "flame/base.h"
6 #include "pyflame.h"
7 
8 #define NO_IMPORT_ARRAY
9 #define PY_ARRAY_UNIQUE_SYMBOL FLAME_PyArray_API
10 #define NPY_NO_DEPRECATED_API NPY_1_6_API_VERSION
11 #include <numpy/ndarrayobject.h>
12 
13 #if SIZE_MAX==NPY_MAX_UINT32
14 #define NPY_SIZE_T NPY_UINT32
15 #elif SIZE_MAX==NPY_MAX_UINT64
16 #define NPY_SIZE_T NPY_UINT64
17 #else
18 #error logic error with SIZE_MAX
19 #endif
20 
// Function-body prologue for the PyState slot functions: binds the local
// 'state' to the PyState* behind the 'raw' parameter and opens a try-block,
// which each user closes with one of the CATCH*() macros.
#define TRY PyState *state = reinterpret_cast<PyState*>(raw); try
22 
23 namespace {
24 
25 struct PyState {
26  PyObject_HEAD
27  PyObject *dict, *weak; // __dict__ and __weakref__
28  PyObject *attrs; // lookup name to attribute index (for StateBase)
29  StateBase *state;
30 };
31 
32 static
33 int PyState_traverse(PyObject *raw, visitproc visit, void *arg)
34 {
35  PyState *state = (PyState*)raw;
36  Py_VISIT(state->attrs);
37  Py_VISIT(state->dict);
38  return 0;
39 }
40 
41 static
42 int PyState_clear(PyObject *raw)
43 {
44  PyState *state = (PyState*)raw;
45  Py_CLEAR(state->dict);
46  Py_CLEAR(state->attrs);
47  return 0;
48 }
49 
50 static
51 void PyState_free(PyObject *raw)
52 {
53  TRY {
54  std::unique_ptr<StateBase> S(state->state);
55  state->state = NULL;
56 
57  if(state->weak)
58  PyObject_ClearWeakRefs(raw);
59 
60  PyState_clear(raw);
61 
62  Py_TYPE(raw)->tp_free(raw);
63  } CATCH2V(std::exception, RuntimeError)
64 }
65 
66 static
67 PyObject *PyState_getattro(PyObject *raw, PyObject *attr)
68 {
69  TRY {
70  PyObject *idx = PyDict_GetItem(state->attrs, attr);
71  if(!idx) {
72  return PyObject_GenericGetAttr(raw, attr);
73  }
74  int i = PyInt_AsLong(idx);
75 
76 
78 
79  if(!state->state->getArray(i, info))
80  return PyErr_Format(PyExc_RuntimeError, "invalid attribute name (sub-class forgot %d)", i);
81 
82  if(info.ndim==0) { // Scalar
83  switch(info.type) {
84  case StateBase::ArrayInfo::Double:
85  return PyFloat_FromDouble(*(double*)info.ptr);
86  case StateBase::ArrayInfo::Sizet:
87  return PyLong_FromSize_t(*(size_t*)info.ptr);
88  }
89  return PyErr_Format(PyExc_TypeError, "unsupported type code %d", info.type);
90  }
91 
92  int pytype;
93  switch(info.type) {
94  case StateBase::ArrayInfo::Double: pytype = NPY_DOUBLE; break;
95  case StateBase::ArrayInfo::Sizet: pytype = NPY_SIZE_T; break;
96  default:
97  return PyErr_Format(PyExc_TypeError, "unsupported type code %d", info.type);
98  }
99 
100  npy_intp dims[StateBase::ArrayInfo::maxdims];
101  std::copy(info.dim,
102  info.dim+StateBase::ArrayInfo::maxdims,
103  dims);
104 
105  // Alloc new array and copy in
106 
107  PyRef<PyArrayObject> obj(PyArray_SimpleNew(info.ndim, dims, pytype));
108 
109  // pull parts from PyArray into ArrayInfo so we can use ArrayInfo::raw() to access
110  StateBase::ArrayInfo pyinfo;
111  pyinfo.ptr = PyArray_BYTES(obj.py());
112  pyinfo.ndim= PyArray_NDIM(obj.get());
113  std::copy(PyArray_DIMS(obj.get()),
114  PyArray_DIMS(obj.get())+pyinfo.ndim,
115  pyinfo.dim);
116  std::copy(PyArray_STRIDES(obj.get()),
117  PyArray_STRIDES(obj.get())+pyinfo.ndim,
118  pyinfo.stride);
119 
121 
122  for(; !idxiter.done; idxiter.next()) {
123  void *dest = pyinfo.raw(idxiter.index);
124  const void *src = info .raw(idxiter.index);
125 
126  switch(info.type) {
127  case StateBase::ArrayInfo::Double: *(double*)dest = *(double*)src; break;
128  case StateBase::ArrayInfo::Sizet: *(size_t*)dest = *(size_t*)src; break;
129  }
130  }
131 
132  return obj.releasePy();
133  } CATCH()
134 }
135 
136 static
137 int PyState_setattro(PyObject *raw, PyObject *attr, PyObject *val)
138 {
139  TRY {
140  PyObject *idx = PyDict_GetItem(state->attrs, attr);
141  if(!idx)
142  return PyObject_GenericSetAttr(raw, attr, val);
143  int i = PyInt_AsLong(idx);
144 
146 
147  if(!state->state->getArray(i, info)) {
148  PyErr_Format(PyExc_RuntimeError, "invalid attribute name (sub-class forgot %d)", i);
149  return -1;
150  }
151 
152  if(info.ndim==0) {
153  // Scalar (use python primative types)
154 
155  switch(info.type) {
156  case StateBase::ArrayInfo::Double: {
157  double *dest = (double*)info.ptr;
158  if(PyFloat_Check(val))
159  *dest = PyFloat_AsDouble(val);
160  else if(PyLong_Check(val))
161  *dest = PyLong_AsDouble(val);
162  else if(PyInt_Check(val))
163  *dest = PyInt_AsLong(val);
164  else
165  PyErr_Format(PyExc_ValueError, "Can't assign to double field");
166  }
167  break;
168  case StateBase::ArrayInfo::Sizet: {
169  size_t *dest = (size_t*)info.ptr;
170  if(PyFloat_Check(val))
171  *dest = PyFloat_AsDouble(val);
172  else if(PyLong_Check(val))
173  *dest = PyLong_AsUnsignedLongLong(val);
174  else if(PyInt_Check(val))
175  *dest = PyInt_AsLong(val);
176  else
177  PyErr_Format(PyExc_ValueError, "Can't assign to double field");
178  }
179  break;
180  default:
181  PyErr_Format(PyExc_TypeError, "unsupported type code %d", info.type);
182  }
183 
184  return PyErr_Occurred() ? -1 : 0;
185  }
186  // array (use numpy)
187 
188  int pytype;
189  switch(info.type) {
190  case StateBase::ArrayInfo::Double: pytype = NPY_DOUBLE; break;
191  case StateBase::ArrayInfo::Sizet: pytype = NPY_SIZE_T; break;
192  default:
193  PyErr_Format(PyExc_TypeError, "unsupported type code %d", info.type);
194  return -1;
195  }
196 
197  // ValueError: object too deep for desired array
198  // means assignment with wrong cardinality
199  PyRef<PyArrayObject> arr(PyArray_FromObject(val, pytype, info.ndim, info.ndim));
200 
201  if(info.ndim!=(size_t)PyArray_NDIM(arr.py())) {
202  PyErr_Format(PyExc_ValueError, "cardinality don't match");
203  return -1;
204  } else if(!std::equal(info.dim, info.dim+info.ndim,
205  PyArray_DIMS(arr.py()))) {
206  PyErr_Format(PyExc_ValueError, "shape does not match don't match");
207  return -1;
208  }
209 
210  // pull parts from PyArray into ArrayInfo so we can use ArrayInfo::raw() to access
211  StateBase::ArrayInfo pyinfo;
212  pyinfo.ptr = PyArray_BYTES(arr.py());
213  pyinfo.ndim= PyArray_NDIM(arr.get());
214  std::copy(PyArray_DIMS(arr.get()),
215  PyArray_DIMS(arr.get())+pyinfo.ndim,
216  pyinfo.dim);
217  std::copy(PyArray_STRIDES(arr.get()),
218  PyArray_STRIDES(arr.get())+pyinfo.ndim,
219  pyinfo.stride);
220 
222 
223  for(; !idxiter.done; idxiter.next()) {
224  const void *src = pyinfo .raw(idxiter.index);
225  void *dest = info.raw(idxiter.index);
226 
227  switch(info.type) {
228  case StateBase::ArrayInfo::Double: *(double*)dest = *(double*)src; break;
229  case StateBase::ArrayInfo::Sizet: *(size_t*)dest = *(size_t*)src; break;
230  }
231  }
232 
233 
234  if(info.ndim==1) {
235  for(size_t i=0; i<info.dim[0]; i++) {
236  const void *src = PyArray_GETPTR1(arr.py(), i);
237  void *dest = info.raw(&i);
238  switch(info.type) {
239  case StateBase::ArrayInfo::Double: *(double*)dest = *(double*)src; break;
240  case StateBase::ArrayInfo::Sizet: *(size_t*)dest = *(size_t*)src; break;
241  }
242  }
243  } else if(info.ndim==2) {
244  size_t idx[2];
245  for(idx[0]=0; idx[0]<info.dim[0]; idx[0]++) {
246  for(idx[1]=0; idx[1]<info.dim[1]; idx[1]++) {
247  const void *src = PyArray_GETPTR2(arr.py(), idx[0], idx[1]);
248  void *dest = info.raw(idx);
249  switch(info.type) {
250  case StateBase::ArrayInfo::Double: *(double*)dest = *(double*)src; break;
251  case StateBase::ArrayInfo::Sizet: *(size_t*)dest = *(size_t*)src; break;
252  }
253  }
254  }
255  }
256 
257  return 0;
258  } CATCH3(std::exception, RuntimeError, -1)
259 }
260 
261 static
262 PyObject* PyState_str(PyObject *raw)
263 {
264  TRY {
265  std::ostringstream strm;
266  state->state->show(strm, 0);
267  return PyString_FromString(strm.str().c_str());
268  } CATCH()
269 }
270 
271 static
272 PyObject* PyState_iter(PyObject *raw)
273 {
274  TRY {
275  return PyObject_GetIter(state->attrs);
276  }CATCH()
277 }
278 
279 static
280 Py_ssize_t PyState_len(PyObject *raw)
281 {
282  TRY{
283  return PyObject_Length(state->attrs);
284  }CATCH1(-1)
285 }
286 
287 static PySequenceMethods PyState_seq = {
288  &PyState_len
289 };
290 
291 static
292 PyObject* PyState_clone(PyObject *raw, PyObject *unused)
293 {
294  TRY {
295  std::unique_ptr<StateBase> newstate(state->state->clone());
296 
297  PyObject *ret = wrapstate(newstate.get());
298  newstate.release();
299  return ret;
300  } CATCH()
301 }
302 
303 static
304 PyObject* PyState_show(PyObject *raw, PyObject *args, PyObject *kws)
305 {
306  TRY {
307  unsigned long level = 1;
308  const char *names[] = {"level", NULL};
309  if(!PyArg_ParseTupleAndKeywords(args, kws, "|k", (char**)names, &level))
310  return NULL;
311 
312  std::ostringstream strm;
313  state->state->show(strm, level);
314  return PyString_FromString(strm.str().c_str());
315  } CATCH()
316 }
317 
318 static PyMethodDef PyState_methods[] = {
319  {"clone", (PyCFunction)&PyState_clone, METH_NOARGS,
320  "clone()\n\n"
321  "Returns a new State instance which is a copy of this one"
322  },
323  {"show", (PyCFunction)&PyState_show, METH_VARARGS|METH_KEYWORDS,
324  "show(level=1)"
325  },
326  {NULL, NULL, 0, NULL}
327 };
328 
329 static PyTypeObject PyStateType = {
330 #if PY_MAJOR_VERSION >= 3
331  PyVarObject_HEAD_INIT(NULL, 0)
332 #else
333  PyObject_HEAD_INIT(NULL)
334  0,
335 #endif
336  "flame._internal.State",
337  sizeof(PyState),
338 };
339 
340 } // namespace
341 
342 PyObject* wrapstate(StateBase* b)
343 {
344  try {
345 
346  PyRef<PyState> state(PyStateType.tp_alloc(&PyStateType, 0));
347 
348  state->state = b;
349  state->attrs = state->weak = state->dict = 0;
350 
351  state->attrs = PyDict_New();
352  if(!state->attrs)
353  return NULL;
354 
355  for(unsigned i=0; true; i++)
356  {
358 
359  if(!b->getArray(i, info))
360  break;
361 
362  bool skip = info.ndim>3;
363  switch(info.type) {
364  case StateBase::ArrayInfo::Double:
365  case StateBase::ArrayInfo::Sizet:
366  break;
367  default:
368  skip = true;
369  }
370 
371  if(skip) continue;
372 
373  PyRef<> name(PyInt_FromLong(i));
374  if(PyDict_SetItemString(state->attrs, info.name, name.py()))
375  throw std::runtime_error("Failed to insert into Dict");
376 
377  }
378 
379  return state.releasePy();
380  } CATCH()
381 }
382 
383 
384 StateBase* unwrapstate(PyObject* raw)
385 {
386  if(!PyObject_TypeCheck(raw, &PyStateType))
387  throw std::invalid_argument("Argument is not a State");
388  PyState *state = (PyState*)raw;
389  return state->state;
390 }
391 
//! Docstring of the python State type (used as its tp_doc).
//! The attribute access is implemented through StateBase::getArray(), so that
//! is the interface named here (it previously named Machine::getArray()).
static const char pymdoc[] =
        "The interface to a sub-class of C++ StateBase.\n"
        "Can't be constructed from python, see Machine.allocState()\n"
        "\n"
        "Provides access to some C++ member variables via the StateBase::getArray() interface.\n"
        ;
398 
399 int registerModState(PyObject *mod)
400 {
401  PyStateType.tp_doc = pymdoc;
402 
403  PyStateType.tp_str = &PyState_str;
404  PyStateType.tp_repr = &PyState_str;
405  PyStateType.tp_dealloc = &PyState_free;
406 
407  PyStateType.tp_iter = &PyState_iter;
408  PyStateType.tp_as_sequence = &PyState_seq;
409 
410  PyStateType.tp_weaklistoffset = offsetof(PyState, weak);
411  PyStateType.tp_traverse = &PyState_traverse;
412  PyStateType.tp_clear = &PyState_clear;
413 
414  PyStateType.tp_dictoffset = offsetof(PyState, dict);
415  PyStateType.tp_getattro = &PyState_getattro;
416  PyStateType.tp_setattro = &PyState_setattro;
417 
418  PyStateType.tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC;
419  PyStateType.tp_methods = PyState_methods;
420 
421  if(PyType_Ready(&PyStateType))
422  return -1;
423 
424  Py_INCREF(&PyStateType);
425  if(PyModule_AddObject(mod, "State", (PyObject*)&PyStateType)) {
426  Py_DECREF(&PyStateType);
427  return -1;
428  }
429 
430  return 0;
431 }
The abstract base class for all simulation state objects.
Definition: base.h:29
Used with StateBase::getArray() to describe a single parameter.
Definition: base.h:51
Helper to step through the indices of an Nd array.
Definition: util.h:112
size_t dim[maxdims]
Array dimensions in elements.
Definition: base.h:69
size_t stride[maxdims]
Array strides in bytes.
Definition: base.h:71
virtual bool getArray(unsigned index, ArrayInfo &Info)
Introspect named parameter of the derived class.
Definition: base.cpp:37
const char * name
The parameter name.
Definition: base.h:55
unsigned ndim
Definition: base.h:67