5 #include "flame/base.h"
8 #define NO_IMPORT_ARRAY
9 #define PY_ARRAY_UNIQUE_SYMBOL FLAME_PyArray_API
10 #define NPY_NO_DEPRECATED_API NPY_1_6_API_VERSION
11 #include <numpy/ndarrayobject.h>
// Pick the NumPy integer type whose width matches size_t, so that
// StateBase::ArrayInfo::Sizet members can be exposed to Python as NumPy
// arrays (NPY_SIZE_T is used as the dtype for Sizet in getattro/setattro).
13 #if SIZE_MAX==NPY_MAX_UINT32
14 #define NPY_SIZE_T NPY_UINT32
15 #elif SIZE_MAX==NPY_MAX_UINT64
16 #define NPY_SIZE_T NPY_UINT64
// NOTE(review): the "#else" before this #error and the closing "#endif" are
// not part of this excerpt; confirm they are present in the real file.
18 #error logic error with SIZE_MAX
// TRY casts the enclosing function's `raw` argument to `PyState *state` and
// opens a try block, so it only works in functions with a parameter named
// `raw`. The matching handler is presumably one of the CATCH* macros
// (defined elsewhere) that close those functions.
21 #define TRY PyState *state = (PyState*)raw; try
// Storage for the instance __dict__ and the __weakref__ list; their offsets
// are registered as tp_dictoffset / tp_weaklistoffset in registerModState().
27 PyObject *dict, *weak;
// tp_traverse hook for the cyclic garbage collector.
// Reports the Python objects this PyState references: the `attrs`
// (parameter name -> index) dict and the instance `dict`.
// Py_VISIT returns early with visit's result if that result is non-zero.
33 int PyState_traverse(PyObject *raw, visitproc visit,
void *arg)
35 PyState *state = (PyState*)raw;
36 Py_VISIT(state->attrs);
37 Py_VISIT(state->dict);
// tp_clear hook: drops this object's references to the instance `dict` and
// the `attrs` dict so the GC can break reference cycles through them.
// Py_CLEAR also resets each pointer to NULL, so later cleanup treats them as
// already released.
42 int PyState_clear(PyObject *raw)
44 PyState *state = (PyState*)raw;
45 Py_CLEAR(state->dict);
46 Py_CLEAR(state->attrs);
// tp_dealloc hook for State objects.
// Clears any weak references and releases the Python object's memory via the
// type's tp_free. The wrapped C++ StateBase is destroyed when the scope ends.
// NOTE(review): several original lines (52-53, 55-57, 59-61) are missing from
// this excerpt; `state` is presumably introduced by the TRY macro.
// NOTE(review): the type sets Py_TPFLAGS_HAVE_GC, and CPython expects
// PyObject_GC_UnTrack(raw) before an object is torn down. No such call is
// visible here; confirm it is among the missing lines.
51 void PyState_free(PyObject *raw)
// Take ownership of the C++ state: it is deleted when S goes out of scope,
// including when an exception propagates.
54 std::unique_ptr<StateBase> S(state->state);
58 PyObject_ClearWeakRefs(raw);
62 Py_TYPE(raw)->tp_free(raw);
// Presumably translates a caught std::exception into a Python RuntimeError
// (CATCH2V is defined elsewhere).
63 } CATCH2V(std::exception, RuntimeError)
// tp_getattro hook for State objects.
// Names registered in state->attrs (name -> getArray() index, built when the
// object is wrapped) are read straight out of the wrapped C++ StateBase:
//  - scalar members come back as a Python float (Double) or int (Sizet);
//  - array members are copied into a newly allocated NumPy array.
// Every other name falls back to PyObject_GenericGetAttr (methods, the
// instance __dict__, ...).
// Returns a new reference, or NULL with a Python exception set.
// NOTE(review): several original lines (declarations, if/switch headers,
// closing braces) are missing from this excerpt.
67 PyObject *PyState_getattro(PyObject *raw, PyObject *attr)
// PyDict_GetItem returns a borrowed reference, or NULL when `attr` is not a
// registered parameter name.
70 PyObject *idx = PyDict_GetItem(state->attrs, attr);
// Not a StateBase parameter: use the normal attribute machinery.
72 return PyObject_GenericGetAttr(raw, attr);
// NOTE(review): the PyInt_AsLong() error return (-1 with an exception set) is
// not checked before `i` is used.
74 int i = PyInt_AsLong(idx);
// getArray() fills `info` for index i. A false return presumably means the
// sub-class stopped describing a parameter it advertised at wrap time.
79 if(!state->state->getArray(i, info))
80 return PyErr_Format(PyExc_RuntimeError,
"invalid attribute name (sub-class forgot %d)", i);
// Scalar member: the enclosing condition (presumably info.ndim==0) is not
// part of this excerpt. Convert the value to a new Python number.
84 case StateBase::ArrayInfo::Double:
85 return PyFloat_FromDouble(*(
double*)info.
ptr);
86 case StateBase::ArrayInfo::Sizet:
87 return PyLong_FromSize_t(*(
size_t*)info.
ptr);
// Any other element type is rejected.
89 return PyErr_Format(PyExc_TypeError,
"unsupported type code %d", info.type);
// Array member: pick the NumPy dtype that matches the C++ element type.
94 case StateBase::ArrayInfo::Double: pytype = NPY_DOUBLE;
break;
95 case StateBase::ArrayInfo::Sizet: pytype = NPY_SIZE_T;
break;
97 return PyErr_Format(PyExc_TypeError,
"unsupported type code %d", info.type);
// Convert the C++ dimensions (size_t) into NumPy's npy_intp shape array.
100 npy_intp dims[StateBase::ArrayInfo::maxdims];
102 info.
dim+StateBase::ArrayInfo::maxdims,
// Allocate a new (C-contiguous) NumPy array with that rank, shape and dtype.
107 PyRef<PyArrayObject> obj(PyArray_SimpleNew(info.
ndim, dims, pytype));
// Describe the NumPy buffer as a second ArrayInfo (data pointer, rank, shape,
// byte strides), so both buffers can be addressed by the same index.
111 pyinfo.
ptr = PyArray_BYTES(obj.py());
112 pyinfo.
ndim= PyArray_NDIM(obj.get());
113 std::copy(PyArray_DIMS(obj.get()),
114 PyArray_DIMS(obj.get())+pyinfo.
ndim,
116 std::copy(PyArray_STRIDES(obj.get()),
117 PyArray_STRIDES(obj.get())+pyinfo.
ndim,
// Visit every element index and copy one element at a time. Going through
// raw() on each side means the C++ and NumPy strides need not agree.
122 for(; !idxiter.done; idxiter.next()) {
123 void *dest = pyinfo.raw(idxiter.index);
124 const void *src = info .raw(idxiter.index);
127 case StateBase::ArrayInfo::Double: *(
double*)dest = *(
double*)src;
break;
128 case StateBase::ArrayInfo::Sizet: *(
size_t*)dest = *(
size_t*)src;
break;
// Transfer the new reference to the caller.
132 return obj.releasePy();
// tp_setattro hook for State objects.
// Assignments to names registered in state->attrs are written into the
// wrapped C++ StateBase:
//  - scalars accept a Python float or int;
//  - arrays accept anything PyArray_FromObject() can convert to the matching
//    dtype with exactly info.ndim dimensions and the same shape.
// Other names fall back to PyObject_GenericSetAttr.
// Returns 0 on success, or -1 with a Python exception set.
// NOTE(review): several original lines (declarations, if/else/switch
// headers, early returns, closing braces) are missing from this excerpt.
// NOTE(review): `val` is NULL when the attribute is deleted (`del st.x`).
// No NULL check is visible before PyFloat_Check(val) or
// PyArray_FromObject(val, ...); confirm one exists in the missing lines.
137 int PyState_setattro(PyObject *raw, PyObject *attr, PyObject *val)
// Borrowed reference, or NULL if `attr` is not a registered parameter name.
140 PyObject *idx = PyDict_GetItem(state->attrs, attr);
// Not a StateBase parameter: use the normal attribute machinery.
142 return PyObject_GenericSetAttr(raw, attr, val);
// NOTE(review): a PyInt_AsLong() failure (-1 with an exception set) is not
// checked here.
143 int i = PyInt_AsLong(idx);
147 if(!state->state->getArray(i, info)) {
148 PyErr_Format(PyExc_RuntimeError,
"invalid attribute name (sub-class forgot %d)", i);
// Scalar assignment: the enclosing condition is not part of this excerpt.
// Convert `val` according to the member's C++ type. Conversion errors are
// left set and presumably reported by the PyErr_Occurred() check below.
156 case StateBase::ArrayInfo::Double: {
157 double *dest = (
double*)info.
ptr;
158 if(PyFloat_Check(val))
159 *dest = PyFloat_AsDouble(val);
160 else if(PyLong_Check(val))
161 *dest = PyLong_AsDouble(val);
162 else if(PyInt_Check(val))
163 *dest = PyInt_AsLong(val);
165 PyErr_Format(PyExc_ValueError,
"Can't assign to double field");
// NOTE(review): a float is converted to size_t implicitly, which is undefined
// behaviour for negative or out-of-range values. A negative int from
// PyInt_AsLong also wraps around to a huge size_t.
168 case StateBase::ArrayInfo::Sizet: {
169 size_t *dest = (
size_t*)info.
ptr;
170 if(PyFloat_Check(val))
171 *dest = PyFloat_AsDouble(val);
172 else if(PyLong_Check(val))
173 *dest = PyLong_AsUnsignedLongLong(val);
174 else if(PyInt_Check(val))
175 *dest = PyInt_AsLong(val);
// NOTE(review): copy/paste error - this message says "double field" but this
// is the size_t branch.
177 PyErr_Format(PyExc_ValueError,
"Can't assign to double field");
181 PyErr_Format(PyExc_TypeError,
"unsupported type code %d", info.type);
// Any error raised by a conversion above (e.g. an OverflowError from
// PyLong_AsUnsignedLongLong) becomes a -1 return.
184 return PyErr_Occurred() ? -1 : 0;
// Array assignment: pick the NumPy dtype that matches the C++ element type.
190 case StateBase::ArrayInfo::Double: pytype = NPY_DOUBLE;
break;
191 case StateBase::ArrayInfo::Sizet: pytype = NPY_SIZE_T;
break;
193 PyErr_Format(PyExc_TypeError,
"unsupported type code %d", info.type);
// Convert `val` to an array of that dtype.
// Passing info.ndim as both min_depth and max_depth makes NumPy itself reject
// inputs of a different rank. The result is not necessarily contiguous,
// hence the stride-aware copies below.
199 PyRef<PyArrayObject> arr(PyArray_FromObject(val, pytype, info.
ndim, info.
ndim));
// NOTE(review): the two error messages below are garbled ("cardinality don't
// match", "shape does not match don't match").
201 if(info.
ndim!=(
size_t)PyArray_NDIM(arr.py())) {
202 PyErr_Format(PyExc_ValueError,
"cardinality don't match");
204 }
else if(!std::equal(info.
dim, info.
dim+info.
ndim,
205 PyArray_DIMS(arr.py()))) {
206 PyErr_Format(PyExc_ValueError,
"shape does not match don't match");
// Describe the NumPy buffer as a second ArrayInfo (data pointer, rank, shape,
// byte strides), so both buffers can be addressed by the same index.
212 pyinfo.
ptr = PyArray_BYTES(arr.py());
213 pyinfo.
ndim= PyArray_NDIM(arr.get());
214 std::copy(PyArray_DIMS(arr.get()),
215 PyArray_DIMS(arr.get())+pyinfo.
ndim,
217 std::copy(PyArray_STRIDES(arr.get()),
218 PyArray_STRIDES(arr.get())+pyinfo.
ndim,
// Generic N-d copy: visit every element index and copy one element at a time
// from the NumPy buffer into the C++ member.
223 for(; !idxiter.done; idxiter.next()) {
224 const void *src = pyinfo .raw(idxiter.index);
225 void *dest = info.raw(idxiter.index);
228 case StateBase::ArrayInfo::Double: *(
double*)dest = *(
double*)src;
break;
229 case StateBase::ArrayInfo::Sizet: *(
size_t*)dest = *(
size_t*)src;
break;
// Hand-written 1-D and 2-D copy paths follow. How they are selected relative
// to the generic loop above (e.g. #if'd out, or an earlier return) is not
// visible in this excerpt.
// NOTE(review): this loop's `size_t i` shadows the outer `int i` index.
235 for(
size_t i=0; i<info.
dim[0]; i++) {
236 const void *src = PyArray_GETPTR1(arr.py(), i);
237 void *dest = info.raw(&i);
239 case StateBase::ArrayInfo::Double: *(
double*)dest = *(
double*)src;
break;
240 case StateBase::ArrayInfo::Sizet: *(
size_t*)dest = *(
size_t*)src;
break;
243 }
else if(info.
ndim==2) {
// 2-D path: row-major walk over both dimensions.
245 for(idx[0]=0; idx[0]<info.
dim[0]; idx[0]++) {
246 for(idx[1]=0; idx[1]<info.
dim[1]; idx[1]++) {
247 const void *src = PyArray_GETPTR2(arr.py(), idx[0], idx[1]);
248 void *dest = info.raw(idx);
250 case StateBase::ArrayInfo::Double: *(
double*)dest = *(
double*)src;
break;
251 case StateBase::ArrayInfo::Sizet: *(
size_t*)dest = *(
size_t*)src;
break;
// Presumably translates a caught std::exception into a Python RuntimeError
// and returns -1 (CATCH3 is defined elsewhere).
258 } CATCH3(std::exception, RuntimeError, -1)
// tp_str (and tp_repr) hook: renders the wrapped C++ state with
// StateBase::show() at level 0 and returns the text as a Python string.
262 PyObject* PyState_str(PyObject *raw)
265 std::ostringstream strm;
266 state->state->show(strm, 0);
267 return PyString_FromString(strm.str().c_str());
// tp_iter hook: iterating a State yields the keys of state->attrs, i.e. the
// parameter names registered when the object was wrapped.
272 PyObject* PyState_iter(PyObject *raw)
275 return PyObject_GetIter(state->attrs);
// Number of registered StateBase parameters (entries in state->attrs).
// Presumably installed as sq_length in PyState_seq; confirm.
280 Py_ssize_t PyState_len(PyObject *raw)
283 return PyObject_Length(state->attrs);
// Sequence-protocol table installed as tp_as_sequence. Its entries are not
// part of this excerpt (presumably sq_length = PyState_len; confirm).
287 static PySequenceMethods PyState_seq = {
// State.clone() (METH_NOARGS): asks the C++ object for a copy via
// StateBase::clone() and wraps the copy in a new Python State.
// NOTE(review): `newstate` still owns the copy after wrapstate() returns.
// The hand-over (presumably newstate.release() once wrapping succeeds) is not
// part of this excerpt. Confirm it, otherwise the StateBase would be deleted
// while the new Python object still points at it.
292 PyObject* PyState_clone(PyObject *raw, PyObject *unused)
295 std::unique_ptr<StateBase> newstate(state->state->clone());
297 PyObject *ret = wrapstate(newstate.get());
// State.show(level=1): returns the text produced by StateBase::show() at the
// requested detail level. `level` is optional, may be given by position or
// keyword, and is parsed as an unsigned long ("k" format).
304 PyObject* PyState_show(PyObject *raw, PyObject *args, PyObject *kws)
307 unsigned long level = 1;
308 const char *names[] = {
"level", NULL};
// Older Python versions declare the keyword list as non-const char**,
// hence the cast.
309 if(!PyArg_ParseTupleAndKeywords(args, kws,
"|k", (
char**)names, &level))
312 std::ostringstream strm;
313 state->state->show(strm, level);
314 return PyString_FromString(strm.str().c_str());
// Python-visible methods of State; the table ends with the all-NULL sentinel.
318 static PyMethodDef PyState_methods[] = {
319 {
"clone", (PyCFunction)&PyState_clone, METH_NOARGS,
321 "Returns a new State instance which is a copy of this one"
// show() accepts `level` as a keyword, hence METH_KEYWORDS.
323 {
"show", (PyCFunction)&PyState_show, METH_VARARGS|METH_KEYWORDS,
326 {NULL, NULL, 0, NULL}
// Type object for flame._internal.State.
// Only the header and type name are visible in this excerpt.
// registerModState() fills in the behavioural slots (dealloc,
// getattro/setattro, GC hooks, flags, methods, ...) at runtime before calling
// PyType_Ready().
329 static PyTypeObject PyStateType = {
330 #if PY_MAJOR_VERSION >= 3
331 PyVarObject_HEAD_INIT(NULL, 0)
333 PyObject_HEAD_INIT(NULL)
336 "flame._internal.State",
// Presumably the body of wrapstate(), which PyState_clone calls; its
// signature is not part of this excerpt.
// Allocates a new State object and builds state->attrs: a dict mapping each
// parameter name reported by StateBase::getArray() to its integer index.
// The getattro/setattro hooks use that dict to reach the C++ members.
346 PyRef<PyState> state(PyStateType.tp_alloc(&PyStateType, 0));
// Explicitly null the attrs/weak/dict members before anything below can fail.
349 state->attrs = state->weak = state->dict = 0;
351 state->attrs = PyDict_New();
// Probe getArray(0), getArray(1), ... in turn. The loop exit (presumably when
// getArray() returns false) is not part of this excerpt.
355 for(
unsigned i=0;
true; i++)
// Parameters with more than 3 dimensions are not exposed. The switch below
// presumably also skips element types other than Double and Sizet.
362 bool skip = info.
ndim>3;
364 case StateBase::ArrayInfo::Double:
365 case StateBase::ArrayInfo::Sizet:
// Register name -> index. PyDict_SetItemString does not steal `name`; the
// PyRef presumably drops its own reference when it goes out of scope.
373 PyRef<> name(PyInt_FromLong(i));
374 if(PyDict_SetItemString(state->attrs, info.
name, name.py()))
375 throw std::runtime_error(
"Failed to insert into Dict");
// Transfer ownership of the new reference to the caller.
379 return state.releasePy();
// Presumably the body of unwrapstate(); its signature is not part of this
// excerpt.
// Rejects any object that is not a State before casting it to PyState*.
// PyObject_TypeCheck also accepts sub-classes.
386 if(!PyObject_TypeCheck(raw, &PyStateType))
387 throw std::invalid_argument(
"Argument is not a State");
388 PyState *state = (PyState*)raw;
// Class docstring for State (installed as tp_doc in registerModState()).
// NOTE(review): the text says "Machine::getArray()", but the introspection
// method used throughout this file is StateBase::getArray().
392 static const char pymdoc[] =
393 "The interface to a sub-class of C++ StateBase.\n"
394 "Can't be constructed from python, see Machine.allocState()\n"
396 "Provides access to some C++ member variables via the Machine::getArray() interface.\n"
// Finish initialising PyStateType (slots not set in its static initializer),
// ready it, and publish it on `mod` as "State".
// NOTE(review): some original lines, including the return statements, are
// missing from this excerpt.
399 int registerModState(PyObject *mod)
401 PyStateType.tp_doc = pymdoc;
// str() and repr() both render the state via StateBase::show().
403 PyStateType.tp_str = &PyState_str;
404 PyStateType.tp_repr = &PyState_str;
405 PyStateType.tp_dealloc = &PyState_free;
407 PyStateType.tp_iter = &PyState_iter;
408 PyStateType.tp_as_sequence = &PyState_seq;
// Weak-reference and cyclic-GC support (Py_TPFLAGS_HAVE_GC is set below).
410 PyStateType.tp_weaklistoffset = offsetof(PyState, weak);
411 PyStateType.tp_traverse = &PyState_traverse;
412 PyStateType.tp_clear = &PyState_clear;
// Instance __dict__, plus custom attribute hooks that route registered
// StateBase parameter names to the C++ object.
414 PyStateType.tp_dictoffset = offsetof(PyState, dict);
415 PyStateType.tp_getattro = &PyState_getattro;
416 PyStateType.tp_setattro = &PyState_setattro;
// BASETYPE allows Python sub-classes of State.
418 PyStateType.tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC;
419 PyStateType.tp_methods = PyState_methods;
421 if(PyType_Ready(&PyStateType))
// PyModule_AddObject steals a reference only on success, so take one for the
// module first and give it back if the call fails.
424 Py_INCREF(&PyStateType);
425 if(PyModule_AddObject(mod,
"State", (PyObject*)&PyStateType)) {
426 Py_DECREF(&PyStateType);
The abstract base class for all simulation state objects.
Used with StateBase::getArray() to describe a single parameter.
Helper to step through the indices of an N-dimensional array.
size_t dim[maxdims]
Array dimensions in elements.
size_t stride[maxdims]
Array strides in bytes.
virtual bool getArray(unsigned index, ArrayInfo &Info)
Introspect named parameter of the derived class.
const char * name
The parameter name.