JSON encoder and decoder for complex numpy arrays

I am trying to JSON-encode a complex (complex-valued) numpy array, and I found a utility from astropy (http://astropy.readthedocs.org/en/latest/_modules/astropy/utils/misc.html#JsonCustomEncoder) for this purpose:

import json

import numpy as np

class JsonCustomEncoder(json.JSONEncoder):
    """ <cropped for brevity> """

    def default(self, obj):
        if isinstance(obj, (np.ndarray, np.number)):
            return obj.tolist()
        # np.complex was only an alias of the built-in complex
        # (removed in NumPy 1.24), so checking complex is enough
        elif isinstance(obj, complex):
            return [obj.real, obj.imag]
        elif isinstance(obj, set):
            return list(obj)
        elif isinstance(obj, bytes):  # pragma: py3
            return obj.decode()
        return json.JSONEncoder.default(self, obj)

This works well for a complex numpy array:

test = {'some_key':np.array([1+1j,2+5j, 3-4j])}

Dumping it yields:

encoded = json.dumps(test, cls=JsonCustomEncoder)

print(encoded)

>>> {"some_key": [[1.0, 1.0], [2.0, 5.0], [3.0, -4.0]]}

The problem is that I cannot automatically read it back into a complex array. For example:

json.loads(encoded)

>>> {"some_key": [[1.0, 1.0], [2.0, 5.0], [3.0, -4.0]]}

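Can you help me figure out how to override loading/decoding so that it infers that this must be a complex array? That is, instead of a list of 2-element items, it should just put them back into a complex array. JsonCustomDecoder has no default() method to override, and the documentation on encoding has too much jargon for me.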
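To make the goal concrete, here is a minimal sketch of the manual conversion that the question wants `json.loads` to perform automatically. It assumes the `[real, imag]` pair layout produced by the encoder above; the names `pairs` and `restored` are only illustrative.

```python
import json

import numpy as np

# JSON text in the [real, imag] pair layout shown in the question
encoded = '{"some_key": [[1.0, 1.0], [2.0, 5.0], [3.0, -4.0]]}'
decoded = json.loads(encoded)

# Each item is a [real, imag] pair, so the array has shape (3, 2)
pairs = np.asarray(decoded["some_key"], dtype=float)

# Combine the two columns into a single complex128 array
restored = pairs[:, 0] + 1j * pairs[:, 1]

print(restored)
```

Nothing in the JSON itself says that these pairs are complex numbers rather than ordinary 2-element lists, which is why a decoder cannot infer this step without extra information.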

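Answer:

This is my final solution, adapted from hpaulj's answer and from his answer to this thread: https://stackoverflow.com/a/24375113/901925

This will encode/decode arrays of any data type that are nested at arbitrary depth inside dictionaries.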

import base64
import json

import numpy as np

class NumpyEncoder(json.JSONEncoder):

    def default(self, obj):
        """
        If the input object is an ndarray, it is converted into a dict
        holding the dtype, the shape and the base64-encoded data.
        """
        if isinstance(obj, np.ndarray):
            # tobytes() yields C-ordered bytes even for non-contiguous arrays;
            # decode to str because json cannot serialize bytes on Python 3
            data_b64 = base64.b64encode(obj.tobytes()).decode('ascii')
            return dict(__ndarray__=data_b64,
                        dtype=str(obj.dtype),
                        shape=obj.shape)
        # Let the base class default method raise the TypeError
        return json.JSONEncoder.default(self, obj)

def json_numpy_obj_hook(dct):
    """
    Decodes a previously encoded numpy ndarray
    with proper shape and dtype

    :param dct: (dict) json encoded ndarray
    :return: (ndarray) if input was an encoded ndarray
    """
    if isinstance(dct, dict) and '__ndarray__' in dct:
        data = base64.b64decode(dct['__ndarray__'])
        # frombuffer returns a read-only array; call .copy() for a writable one
        return np.frombuffer(data, dct['dtype']).reshape(dct['shape'])
    return dct

# Wrap dump/load so that they use this behavior by default.

def dumps(*args, **kwargs):
    kwargs.setdefault('cls', NumpyEncoder)
    return json.dumps(*args, **kwargs)

def loads(*args, **kwargs):
    kwargs.setdefault('object_hook', json_numpy_obj_hook)
    return json.loads(*args, **kwargs)

def dump(*args, **kwargs):
    kwargs.setdefault('cls', NumpyEncoder)
    return json.dump(*args, **kwargs)

def load(*args, **kwargs):
    kwargs.setdefault('object_hook', json_numpy_obj_hook)
    return json.load(*args, **kwargs)

if __name__ == '__main__':
    # np.complex was removed from NumPy; the built-in complex is equivalent
    data = np.arange(3, dtype=complex)

    one_level = {'level1': data, 'foo': 'bar'}
    two_level = {'level2': one_level}

    dumped = dumps(two_level)
    result = loads(dumped)

    print('\noriginal data', data)
    print('\nnested dict of dict complex array', two_level)
    print('\ndecoded nested data', result)

This produces the following output (from a Python 2 run, hence the u'' prefixes):

original data [ 0.+0.j  1.+0.j  2.+0.j]

nested dict of dict complex array {'level2': {'level1': array([ 0.+0.j, 1.+0.j, 2.+0.j]), 'foo': 'bar'}}

decoded nested data {u'level2': {u'level1': array([ 0.+0.j, 1.+0.j, 2.+0.j]), u'foo': u'bar'}}
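The base64 approach above preserves dtype and shape exactly, but the JSON is no longer human-readable. If readable output matters more, one possible alternative is to tag complex arrays explicitly so that an `object_hook` can recognise them. The following is only a sketch under that assumption: the names `ComplexArrayEncoder`, `complex_array_hook` and the `__complex_ndarray__` key are hypothetical and appear in neither astropy nor the answer above.

```python
import json

import numpy as np


class ComplexArrayEncoder(json.JSONEncoder):
    """Hypothetical encoder that tags complex arrays so a decoder can spot them."""

    def default(self, obj):
        if isinstance(obj, np.ndarray) and np.iscomplexobj(obj):
            # Store real and imaginary parts as separate (possibly nested) lists
            return {"__complex_ndarray__": True,
                    "real": obj.real.tolist(),
                    "imag": obj.imag.tolist()}
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Let the base class raise TypeError for unsupported types
        return super().default(obj)


def complex_array_hook(dct):
    """Hypothetical object_hook: rebuild a complex array from a tagged dict."""
    if dct.get("__complex_ndarray__"):
        return np.asarray(dct["real"]) + 1j * np.asarray(dct["imag"])
    return dct


test = {"some_key": np.array([1 + 1j, 2 + 5j, 3 - 4j])}
encoded = json.dumps(test, cls=ComplexArrayEncoder)
decoded = json.loads(encoded, object_hook=complex_array_hook)

print(encoded)
print(decoded["some_key"])
```

The trade-off is that this sketch always decodes to complex128 (a complex64 input is widened) and that non-complex arrays come back as plain lists, whereas the base64 version restores the original dtype.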
