一文彻底学会Python的Dataclass
Dataclass是什么?
一句话,dataclass
是一个代码生成器code generator
, 一个用来生成python class的代码生成器。如果你用过Python的collection模块namedtuple,就会比较容易理解代码生成器的概念。
class Student:
def __init__(self, name, age):
self.name = name
self.age = age
s1 = Student('Jack', 30)
s2 = Student('Tim', 20)
print(s1.name)
print(s2.age)
# namedtuple可以简化一个class的创建
from collections import namedtuple
Student = namedtuple('Student',['name', 'age'])
s3 = Student('Paul', 55)
s4 = Student('Nico', 40)
print(s3.name)
print(s4.age)
Dataclass基础
如名字dataclasses
所示,dataclasses 是用于数据结构的类(在某些方面类似于namedtuple,但dataclasses 提供了更多的可能,因为它们归根结底还是常规的 Python 类,而不是元组)
让我们看看如何使用 dataclasses 生成标准的 Python 类,看看它可以帮我们简化多少工作量。
作为一个例子,我们先创建一个标准的position
类,这个类有两个实例属性,x
和y
,都是float
类型。
class Position:
def __init__(self, x: float = 0, y: float = 0,):
self.x = x
self.y = y
创建一个实例
>>> p = Position()
>>> p
<__main__.Position object at 0x10459ae40>
>>> print(p)
<__main__.Position object at 0x10459ae40>
>>>
那为了让它打印出来的内容更好看一些,我们一般可以给这个Class添加__str__
和__repr__
方法
class Position:
def __init__(self, x: float = 0, y: float = 0,):
self.x = x
self.y = y
def __str__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
def __repr__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
>>> p = Position()
>>> p
Position(x=0, y=0)
>>> print(p)
Position(x=0, y=0)
拿上面这个Python类,如果用dataclass实现,就会简单很多,我们看一下。
from dataclasses import dataclass
@dataclass
class Position:
x: float = 0
y: float = 0
>>> p = Position()
>>> p
Position(x=0, y=0)
>>> print(p)
Position(x=0, y=0)
就这么简单几行代码,实现了前面的一个传统的Python类。
那为什么说dataclass是一个代码生成器
code generator
呢?因为它实际上是通过一个装饰器@dataclass
对我们原始定义的类进行了一个修改,然后返回了这个修改后的同一个类,有点拗口,我们证明一下
首先,我们定一个没有装饰器的传统python类
class Position:
x: float = 0
y: float = 0
通过id可以查看这个类在内存中的位置
>>> id(Position)
409177792
然后我们把这个Position类作为参数传递给dataclass, 返回一个新类 NewPosition ,然后查看它内存的位置
>>> NewPosition = dataclass(Position)
>>> id(NewPosition)
409177792
在内存中的位置完全一样,这说明dataclass装饰器并没有创建一个新类,而是在原先类的基础上进行修改。
__eq__
比较
dataclass不光帮我们默认实现了__str__
和__repr__
方法,而且连__eq__
默认也实现了,而且比较的原则是已实例的属性值为依据,而不是传统class默认的以实例的ID为依据
@dataclass
class Position:
x: float = 0
y: float = 0
>>> p1 = Position(x=1, y=-2)
>>> p2 = Position(x=1, y=-2)
>>> p1 == p2
True
传统class要实现这一的比较,需要我们自己实现__eq__
方法,如下:
class Position:
def __init__(self, x: float = 0, y: float = 0,):
self.x = x
self.y = y
def __str__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
def __repr__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
默认比较是比较两个实例的ID,所以肯定是不同的
>>> p1 = Position(x=1, y=-2)
>>> p2 = Position(x=1, y=-2)
>>> p1 == p2
False
添加__eq__
方法后,才实现dataclass的效果
class Position:
def __init__(self, x: float = 0, y: float = 0,):
self.x = x
self.y = y
def __str__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
def __repr__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
def __eq__(self, other):
if self.__class__ == other.__class__:
return (self.x, self.y) == (other.x, other.y)
return NotImplemented
>>> p1 = Position(x=1, y=-2)
>>> p2 = Position(x=1, y=-2)
>>> p1 == p2
True
p1 == p2
和p1 is p2
不是一回事,请大家接着看。
__hash__
方法
如果我们手动添加__eq__
方法,那么__hash__
方法我们也需要重写,为啥呢?
what is __hash__
首先我们先解释了__hash__
方法是干啥的。每一个python object都有一个属性 __hash__
>>> dir(object)
['__class__',
'__delattr__',
'__dir__',
'__doc__',
'__eq__',
'__format__',
'__ge__',
'__getattribute__',
'__getstate__',
'__gt__',
'__hash__',
'__init__',
'__init_subclass__',
'__le__',
'__lt__',
'__ne__',
'__new__',
'__reduce__',
'__reduce_ex__',
'__repr__',
'__setattr__',
'__sizeof__',
'__str__',
'__subclasshook__']
但并不代表每一个python object都是可哈希的。比如一个list,就不是可哈希的,因为list是可变数据类型,它的__hash__
属性是None
>>> a = [1, 2]
>>> print(a.__hash__)
None
>>> hash(a)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: unhashable type: 'list'
如果一个python object不可哈希,它也就不能作为python dict的key,或者set的元素。
>>> a = [1, 2]
>>> b = {a: 1}
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: unhashable type: 'list'
Python默认class是可哈希的
class Position:
def __init__(self, x: float = 0, y: float = 0,):
self.x = x
self.y = y
>>> p1 = Position(1,1)
>>> p2 = Position(1,1)
>>> p1 == p2
False
>>>> p1 is p2
False
>>> hash(p1), hash(p2) # 对各自对象所在的内容地址进行hash
(17592085467597, 17592085467665)
>>> a = {p1:p2} # 可哈希,所以就可以作为dict的key
为啥class添加了__eq__
就不可哈希了呢?
一句话总结,就是出现了矛盾,当我们添加了我们的__eq__
方法后,
class Position:
def __init__(self, x: float = 0, y: float = 0,):
self.x = x
self.y = y
def __eq__(self, other):
if self.__class__ == other.__class__:
return (self.x, self.y) == (other.x, other.y)
return NotImplemented
p1是等于p2了,但是p1和p2的内存地址还是不同的,导致两个对象如果按照内存地址算hash,hash就不同了。
>>> p1 == p2
True
>>> p1 is p2
False
为避免出现这个矛盾,所以干脆就让这个class的__hash___
是None
,也就是不可哈希了
>>> hash(p1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: unhashable type: 'Position'
所以回到前面,当我们重写了__eq__
方法后,__hash___
方法也需要重写。特别是,我们如果需要我们的class是可哈希的话。简单重写如下:
class Position:
def __init__(self, x: float = 0, y: float = 0,):
self.x = x
self.y = y
def __str__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
def __repr__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
def __eq__(self, other):
if self.__class__ == other.__class__:
return (self.x, self.y) == (other.x, other.y)
return NotImplemented
def __hash__(self):
return hash((self.x, self.y))
这时,它们就是可哈希的,而且我们用x和y作为一个元祖值去计算哈希,所以当两个对象的x和y相同时,它们的哈希值也一样。
>>> p1 = Position(1,1)
>>> p2 = Position(1,1)
>>> p1 == p2
True
>>> p1 is p2
False
>>> hash(p1), hash(p2)
(8389048192121911274, 8389048192121911274)
当然也可以作为dict的key出现,因为p1和p2的哈希一样,所以就是重复的key
>>> p1 = Position(1,1)
>>> p2 = Position(1,1)
>>>
>>> d = {p1: 1, p2: 2}
>>> d
{Position(x=1, y=1): 2}
还有一个潜在问题
一般来说,可哈希的对象,一般是immutable的。但是我们这个position对象,却可以改变它的x和y值
比如我们可以改变p1的x值,如此又导致了各种矛盾问题。
>>> p1.x
1
>>> p1.x = 2
>>> d
{Position(x=2, y=1): 2}
>>> p1 == p2
False
>>> hash(p1), hash(p2)
(6794810172467074373, 8389048192121911274)
>>>
你还可以把p1的值再改回去,等等会出现各种奇怪现象,那为了尽量避免这个现象,我们需要引入property
来保护一下我们的x和y
class Position:
def __init__(self, x: float = 0, y: float = 0,):
self._x = x
self._y = y
@property
def x(self):
return self._x
@property
def y(self):
return self._y
def __str__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
def __repr__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
def __eq__(self, other):
if self.__class__ == other.__class__:
return (self.x, self.y) == (other.x, other.y)
return NotImplemented
def __hash__(self):
return hash((self.x, self.y))
>>> p1 = Position(1,1)
>>> p2 = Position(1,1)
>>> p1 == p2
True
>>> d = {p1: 1}
>>> p1.x
1
>>> p1.x = 2
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: property 'x' of 'Position' object has no setter
OK,上面我们花了这么大的篇幅,来实现一个Position
类,而且很容易出错,或者忘记,但是如果用dataclass
,就可以非常简单了,只需要下面几行代码就可以实现我们上面一个包含7个def的class
用dataclass四行代码搞定
from dataclasses import dataclass
@dataclass(frozen=True)
class Position:
x: float = 0
y: float = 0
测试一下,所有之前需要__init__
, property
, __eq__
, __str__
, __repr__
, __hash__
等 ,全部省略,dataclass都帮我们做了。
>>> p1 = Position(1, 1)
>>> p2 = Position(1, 1)
>>> p1 == p2
True
>>> hash(p1), hash(p2)
(8389048192121911274, 8389048192121911274)
>>>
>>> d = {p1: 1}
>>> p1.x
1
>>> p1.x = 2
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<string>", line 4, in __setattr__
dataclasses.FrozenInstanceError: cannot assign to field 'x'
>>>
比较和排序
想要dataclass支持排序是非常简单的,只需要加上order=True
from dataclasses import dataclass
@dataclass(frozen=True, order=True)
class Position:
x: float = 0
y: float = 0
>>> p1 = Position(1, 2)
>>> p2 = Position(1, 3)
>>> p1 == p2, p1 > p2, p1 < p2, p1>=p2, p1<=p2
(False, False, True, False, True)
>>> sorted([p1, p2])
[Position(x=1, y=2), Position(x=1, y=3)]
dataclass默认的比较是根据tuple (x, y)的大小比较,实现简单(dataclass的比较函数我们是可以重置的,也就是根据具体需求去重新自己实现,但是这里我们就不演示了
),但是即使如此,传统class要实现这个功能还是需要很多额外代码的,这里我们直接把实现贴出来(这里用了total_ordering简化了步骤,否则我们要实现的内置函数会更多,大于,大于等于,小于,小于等于)。
from functools import total_ordering
@total_ordering
class Position:
def __init__(self, x: float = 0, y: float = 0,):
self._x = x
self._y = y
@property
def x(self):
return self._x
@property
def y(self):
return self._y
def __str__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
def __repr__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
def __eq__(self, other):
if self.__class__ == other.__class__:
return (self.x, self.y) == (other.x, other.y)
return NotImplemented
def __hash__(self):
return hash((self.x, self.y))
def __lt__(self, other):
if self.__class__ == other.__class__:
return (self.x, self.y) < (other.x, other.y)
return NotImplemented
测试
>>> p1 = Position(1, 2)
>>> p2 = Position(1, 3)
>>> p1 == p2, p1 > p2, p1 < p2, p1>=p2, p1<=p2
(False, False, True, False, True)
>>> sorted([p1, p2])
[Position(x=1, y=2), Position(x=1, y=3)]
>>>
序列化成字典或者元组
dataclass的实例可以非常方便的序列化成字典或者元组。
from dataclasses import dataclass, asdict, astuple
@dataclass(frozen=True, order=True)
class Position:
x: float = 0
y: float = 0
转化成字典或者元组
>>> p = Position(1, 2)
>>> asdict(p)
{'x': 1, 'y': 2}
>>> astuple(p)
(1, 2)
>>>
而传统的class要实现这个功能则需要我们自己写代码,比如
from functools import total_ordering
@total_ordering
class Position:
def __init__(self, x: float = 0, y: float = 0,):
self._x = x
self._y = y
@property
def x(self):
return self._x
@property
def y(self):
return self._y
def __str__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
def __repr__(self):
return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
def __eq__(self, other):
if self.__class__ == other.__class__:
return (self.x, self.y) == (other.x, other.y)
return NotImplemented
def __hash__(self):
return hash((self.x, self.y))
def __lt__(self, other):
if self.__class__ == other.__class__:
return (self.x, self.y) < (other.x, other.y)
return NotImplemented
def asdict(self):
return {
'x': self.x,
'y': self.y
}
def astuple(self):
return (self.x, self.y)
测试
>>> p = Position(1, 2)
>>> p
Position(x=1, y=2)
>>> p.asdict()
{'x': 1, 'y': 2}
>>> p.astuple()
(1, 2)
只允许关键字参数
在dataclass里通过添加参数kw_only=True
,可以限制实例初始化的时候只允许关键字参数(keyword only argument
),而不能使用位置参数(positional argument
).
from dataclasses import dataclass, asdict, astuple
@dataclass(frozen=True, order=True,kw_only=True)
class Position:
x: float = 0
y: float = 0
测试
>>> p = Position(1, 2)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: Position.__init__() takes 1 positional argument but 3 were given
>>> p = Position(x=1, y=2)
>>> p
Position(x=1, y=2)
>>>
当然,也可以让部分参数是keyword only argument (比如,只让y是keyword only argument)
from dataclasses import dataclass, KW_ONLY
@dataclass(frozen=True, order=True)
class Position:
x: float = 0
_: KW_ONLY
y: float = 0
测试
>>> p = Position(1, 2)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: Position.__init__() takes from 1 to 2 positional arguments but 3 were given
>>> p = Position(1, y=2)
>>> p
Position(x=1, y=2)
>>>
Discussion