一文彻底学会Python的Dataclass

一文彻底学会Python的Dataclass

Dataclass是什么?

一句话,dataclass是一个代码生成器code generator, 一个用来生成python class的代码生成器。如果你用过Python的collection模块namedtuple,就会比较容易理解代码生成器的概念。

class Student:

    def __init__(self, name, age):
        self.name = name
        self.age = age


s1 = Student('Jack', 30)
s2 = Student('Tim', 20)

print(s1.name)
print(s2.age)

# namedtuple可以简化一个class的创建

from collections import namedtuple


Student = namedtuple('Student',['name', 'age'])

s3 = Student('Paul', 55)
s4 = Student('Nico', 40)

print(s3.name)
print(s4.age)

Dataclass基础

如名字dataclasses所示,dataclasses 是用于数据结构的类(在某些方面类似于namedtuple,但dataclasses 提供了更多的可能,因为它们归根结底还是常规的 Python 类,而不是元组)

让我们看看如何使用 dataclasses 生成标准的 Python 类,看看它可以帮我们简化多少工作量。

作为一个例子,我们先创建一个标准的position类,这个类有两个实例属性,xy,都是float类型。

class Position:
    def __init__(self, x: float = 0, y: float = 0,):
        self.x = x
        self.y = y

创建一个实例

>>> p = Position()
>>> p
<__main__.Position object at 0x10459ae40>
>>> print(p)
<__main__.Position object at 0x10459ae40>
>>>

那为了让它打印出来的内容更好看一些,我们一般可以给这个Class添加__str____repr__方法

class Position:
    def __init__(self, x: float = 0, y: float = 0,):
        self.x = x
        self.y = y
    def __str__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
    def __repr__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
>>> p = Position()
>>> p
Position(x=0, y=0)
>>> print(p)
Position(x=0, y=0)

拿上面这个Python类,如果用dataclass实现,就会简单很多,我们看一下。

from dataclasses import dataclass

@dataclass
class Position:
    x: float = 0
    y: float = 0
>>> p = Position()
>>> p
Position(x=0, y=0)
>>> print(p)
Position(x=0, y=0)

就这么简单几行代码,实现了前面的一个传统的Python类。

那为什么说dataclass是一个代码生成器code generator呢?因为它实际上是通过一个装饰器 @dataclass 对我们原始定义的类进行了一个修改,然后返回了这个修改后的同一个类,有点拗口,我们证明一下

首先,我们定一个没有装饰器的传统python类

class Position:
    x: float = 0
    y: float = 0

通过id可以查看这个类在内存中的位置

>>> id(Position)
409177792

然后我们把这个Position类作为参数传递给dataclass, 返回一个新类 NewPosition ,然后查看它内存的位置

>>> NewPosition = dataclass(Position)
>>> id(NewPosition)
409177792

在内存中的位置完全一样,这说明dataclass装饰器并没有创建一个新类,而是在原先类的基础上进行修改。

__eq__ 比较

dataclass不光帮我们默认实现了__str____repr__方法,而且连__eq__ 默认也实现了,而且比较的原则是已实例的属性值为依据,而不是传统class默认的以实例的ID为依据

@dataclass
class Position:
    x: float = 0
    y: float = 0
>>> p1 = Position(x=1, y=-2)
>>> p2 = Position(x=1, y=-2)
>>> p1 == p2
True

传统class要实现这一的比较,需要我们自己实现__eq__方法,如下:

class Position:
    def __init__(self, x: float = 0, y: float = 0,):
        self.x = x
        self.y = y
    def __str__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
    def __repr__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"

默认比较是比较两个实例的ID,所以肯定是不同的

>>> p1 = Position(x=1, y=-2)
>>> p2 = Position(x=1, y=-2)
>>> p1 == p2
False

添加__eq__方法后,才实现dataclass的效果

class Position:
    def __init__(self, x: float = 0, y: float = 0,):
        self.x = x
        self.y = y
    def __str__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
    def __repr__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
    def __eq__(self, other):
        if self.__class__ == other.__class__:
            return (self.x, self.y) == (other.x, other.y)
        return NotImplemented
>>> p1 = Position(x=1, y=-2)
>>> p2 = Position(x=1, y=-2)
>>> p1 == p2
True

p1 == p2p1 is p2 不是一回事,请大家接着看。

__hash__ 方法

如果我们手动添加__eq__方法,那么__hash__ 方法我们也需要重写,为啥呢?

what is __hash__

首先我们先解释了__hash__ 方法是干啥的。每一个python object都有一个属性 __hash__

>>> dir(object)
['__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__']

但并不代表每一个python object都是可哈希的。比如一个list,就不是可哈希的,因为list是可变数据类型,它的__hash__属性是None

>>> a = [1, 2]
>>> print(a.__hash__)
None
>>> hash(a)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unhashable type: 'list'

如果一个python object不可哈希,它也就不能作为python dict的key,或者set的元素。

>>> a = [1, 2]
>>> b = {a: 1}
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unhashable type: 'list'

Python默认class是可哈希的

class Position:
    def __init__(self, x: float = 0, y: float = 0,):
        self.x = x
        self.y = y
>>> p1 = Position(1,1)
>>> p2 = Position(1,1)
>>> p1 == p2
False
>>>> p1 is p2
False
>>> hash(p1), hash(p2)  # 对各自对象所在的内容地址进行hash
(17592085467597, 17592085467665)
>>> a = {p1:p2}  # 可哈希,所以就可以作为dict的key

为啥class添加了__eq__就不可哈希了呢?

一句话总结,就是出现了矛盾,当我们添加了我们的__eq__方法后,

class Position:
    def __init__(self, x: float = 0, y: float = 0,):
        self.x = x
        self.y = y
    def __eq__(self, other):
        if self.__class__ == other.__class__:
            return (self.x, self.y) == (other.x, other.y)
        return NotImplemented

p1是等于p2了,但是p1和p2的内存地址还是不同的,导致两个对象如果按照内存地址算hash,hash就不同了。

>>> p1 == p2
True
>>> p1 is p2
False

为避免出现这个矛盾,所以干脆就让这个class的__hash___None,也就是不可哈希了

>>> hash(p1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unhashable type: 'Position'

所以回到前面,当我们重写了__eq__方法后,__hash___方法也需要重写。特别是,我们如果需要我们的class是可哈希的话。简单重写如下:

class Position:
    def __init__(self, x: float = 0, y: float = 0,):
        self.x = x
        self.y = y
    def __str__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
    def __repr__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
        
    def __eq__(self, other):
        if self.__class__ == other.__class__:
            return (self.x, self.y) == (other.x, other.y)
        return NotImplemented
    def __hash__(self):
        return hash((self.x, self.y))

这时,它们就是可哈希的,而且我们用x和y作为一个元祖值去计算哈希,所以当两个对象的x和y相同时,它们的哈希值也一样。

>>> p1 = Position(1,1)
>>> p2 = Position(1,1)
>>> p1 == p2
True
>>> p1 is p2
False
>>> hash(p1), hash(p2)
(8389048192121911274, 8389048192121911274)

当然也可以作为dict的key出现,因为p1和p2的哈希一样,所以就是重复的key

>>> p1 = Position(1,1)
>>> p2 = Position(1,1)
>>>
>>> d = {p1: 1, p2: 2}
>>> d
{Position(x=1, y=1): 2}

还有一个潜在问题

一般来说,可哈希的对象,一般是immutable的。但是我们这个position对象,却可以改变它的x和y值

比如我们可以改变p1的x值,如此又导致了各种矛盾问题。

>>> p1.x
1
>>> p1.x = 2
>>> d
{Position(x=2, y=1): 2}
>>> p1 == p2
False
>>> hash(p1), hash(p2)
(6794810172467074373, 8389048192121911274)
>>>

你还可以把p1的值再改回去,等等会出现各种奇怪现象,那为了尽量避免这个现象,我们需要引入property来保护一下我们的x和y

class Position:
    def __init__(self, x: float = 0, y: float = 0,):
        self._x = x
        self._y = y
    @property
    def x(self):
        return self._x
    @property
    def y(self):
        return self._y
    def __str__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
    def __repr__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
        
    def __eq__(self, other):
        if self.__class__ == other.__class__:
            return (self.x, self.y) == (other.x, other.y)
        return NotImplemented
    def __hash__(self):
        return hash((self.x, self.y))
>>> p1 = Position(1,1)
>>> p2 = Position(1,1)
>>> p1 == p2
True
>>> d = {p1: 1}
>>> p1.x
1
>>> p1.x = 2
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: property 'x' of 'Position' object has no setter

OK,上面我们花了这么大的篇幅,来实现一个Position类,而且很容易出错,或者忘记,但是如果用dataclass,就可以非常简单了,只需要下面几行代码就可以实现我们上面一个包含7个def的class

用dataclass四行代码搞定

from dataclasses import dataclass

@dataclass(frozen=True)
class Position:
    x: float = 0
    y: float = 0

测试一下,所有之前需要__init__, property, __eq__, __str__, __repr__, __hash__等 ,全部省略,dataclass都帮我们做了。

>>> p1 = Position(1, 1)
>>> p2 = Position(1, 1)
>>> p1 == p2
True
>>> hash(p1), hash(p2)
(8389048192121911274, 8389048192121911274)
>>>
>>> d = {p1: 1}
>>> p1.x
1
>>> p1.x = 2
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<string>", line 4, in __setattr__
dataclasses.FrozenInstanceError: cannot assign to field 'x'
>>>

比较和排序

想要dataclass支持排序是非常简单的,只需要加上order=True

from dataclasses import dataclass

@dataclass(frozen=True, order=True)
class Position:
    x: float = 0
    y: float = 0
>>> p1 = Position(1, 2)
>>> p2 = Position(1, 3)
>>> p1 == p2, p1 > p2, p1 < p2, p1>=p2, p1<=p2
(False, False, True, False, True)
>>> sorted([p1, p2])
[Position(x=1, y=2), Position(x=1, y=3)]

dataclass默认的比较是根据tuple (x, y)的大小比较,实现简单(dataclass的比较函数我们是可以重置的,也就是根据具体需求去重新自己实现,但是这里我们就不演示了),但是即使如此,传统class要实现这个功能还是需要很多额外代码的,这里我们直接把实现贴出来(这里用了total_ordering简化了步骤,否则我们要实现的内置函数会更多,大于,大于等于,小于,小于等于)。

from functools import total_ordering

@total_ordering
class Position:
    def __init__(self, x: float = 0, y: float = 0,):
        self._x = x
        self._y = y
    @property
    def x(self):
        return self._x
    @property
    def y(self):
        return self._y
    def __str__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
    def __repr__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
        
    def __eq__(self, other):
        if self.__class__ == other.__class__:
            return (self.x, self.y) == (other.x, other.y)
        return NotImplemented
    def __hash__(self):
        return hash((self.x, self.y))
    def __lt__(self, other):
        if self.__class__ == other.__class__:
            return (self.x, self.y) < (other.x, other.y)
        return NotImplemented

测试

>>> p1 = Position(1, 2)
>>> p2 = Position(1, 3)
>>> p1 == p2, p1 > p2, p1 < p2, p1>=p2, p1<=p2
(False, False, True, False, True)
>>> sorted([p1, p2])
[Position(x=1, y=2), Position(x=1, y=3)]
>>>

序列化成字典或者元组

dataclass的实例可以非常方便的序列化成字典或者元组。

from dataclasses import dataclass, asdict, astuple

@dataclass(frozen=True, order=True)
class Position:
    x: float = 0
    y: float = 0

转化成字典或者元组

>>> p = Position(1, 2)
>>> asdict(p)
{'x': 1, 'y': 2}
>>> astuple(p)
(1, 2)
>>>

而传统的class要实现这个功能则需要我们自己写代码,比如

from functools import total_ordering

@total_ordering
class Position:
    def __init__(self, x: float = 0, y: float = 0,):
        self._x = x
        self._y = y
    @property
    def x(self):
        return self._x
    @property
    def y(self):
        return self._y
    def __str__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
    def __repr__(self):
        return f"{self.__class__.__qualname__}(x={self.x}, y={self.y})"
        
    def __eq__(self, other):
        if self.__class__ == other.__class__:
            return (self.x, self.y) == (other.x, other.y)
        return NotImplemented
    def __hash__(self):
        return hash((self.x, self.y))
    def __lt__(self, other):
        if self.__class__ == other.__class__:
            return (self.x, self.y) < (other.x, other.y)
        return NotImplemented
    def asdict(self):
        return {
            'x': self.x,
            'y': self.y
        }
    def astuple(self):
        return (self.x, self.y)

测试

>>> p = Position(1, 2)
>>> p
Position(x=1, y=2)
>>> p.asdict()
{'x': 1, 'y': 2}
>>> p.astuple()
(1, 2)

只允许关键字参数

在dataclass里通过添加参数kw_only=True,可以限制实例初始化的时候只允许关键字参数(keyword only argument),而不能使用位置参数(positional argument).

from dataclasses import dataclass, asdict, astuple

@dataclass(frozen=True, order=True,kw_only=True)
class Position:
    x: float = 0
    y: float = 0

测试

>>> p = Position(1, 2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: Position.__init__() takes 1 positional argument but 3 were given
>>> p = Position(x=1, y=2)
>>> p
Position(x=1, y=2)
>>>

当然,也可以让部分参数是keyword only argument (比如,只让y是keyword only argument)

from dataclasses import dataclass, KW_ONLY

@dataclass(frozen=True, order=True)
class Position:
    x: float = 0
    _: KW_ONLY
    y: float = 0

测试

>>> p = Position(1, 2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: Position.__init__() takes from 1 to 2 positional arguments but 3 were given
>>> p = Position(1, y=2)
>>> p
Position(x=1, y=2)
>>>

Discussion