Writing a Python ORM Framework


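ORM stands for "Object Relational Mapping". It maps a row in a relational database to an object, so that each class corresponds to a table. This makes code simpler to write, because you no longer have to work with SQL statements directly.

To write an ORM framework, every model class has to be defined dynamically, because only the user of the framework knows the table structure well enough to define the corresponding class. In Python, dynamic class definition can be implemented with metaclasses.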

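Metaclasses

type

The type function can do more than report the type of a value or variable: it can also create classes dynamically. type accepts the description of a class as its arguments and returns a class.

Used this way, type works like this:

type(class name, tuple of parent classes (for inheritance; may be empty), dict of attributes (names and values))

For example, the following code: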

class Hello(object):
    def hello(self, name='world'):
        print('Hello, %s' % name)

can also be created by hand like this:

>>> def fn(self, name='world'):
...     print('Hello, %s' % name)
...
>>> Hello = type('Hello', (object,), dict(hello=fn))
>>> h = Hello()
>>> h.hello()
Hello, world
>>> print(type(Hello))
<class 'type'>
>>> print(type(h))
<class '__main__.Hello'>

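As you can see, a class created with the type function is exactly the same as one written with a class statement, because type is the metaclass Python uses behind the scenes to create classes by default. When the Python interpreter encounters a class definition, it executes the class body to collect the attributes and then calls type() to create the class.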
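To make the three arguments of type() concrete, here is a minimal sketch that builds a small class hierarchy at runtime. The names Base, Child, init, and describe are made up for this illustration.

```python
# Build two classes at runtime with type(name, bases, attrs).
def init(self, label):
    self.label = label


def describe(self):
    return '%s(%r)' % (type(self).__name__, self.label)


# Equivalent to writing "class Base(object):" with __init__ and describe methods
Base = type('Base', (object,), {'__init__': init, 'describe': describe})

# Equivalent to writing "class Child(Base):" with a class attribute kind = 'child'
Child = type('Child', (Base,), {'kind': 'child'})

c = Child('demo')
print(c.describe())             # Child('demo')
print(issubclass(Child, Base))  # True
print(type(Child) is type)      # True
```

The second argument is what makes inheritance work: Child gets __init__ and describe from Base even though its own attribute dict only contains kind.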

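metaclass

Besides creating classes dynamically with type(), you can control how classes are created by using a metaclass. A metaclass lets you create classes or modify them. In other words, you can think of a class as an "instance" created by its metaclass.

The following code uses a metaclass to convert the attribute names of a class to uppercase: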

class UpperAttrMetaClass(type):

    def __new__(cls, name, bases, attrs):
        # Upper-case every attribute name except the special "__xxx__" ones
        uppercase_attrs = {
            (key if key.startswith('__') else key.upper()): value
            for key, value in attrs.items()
        }
        return type.__new__(cls, name, bases, uppercase_attrs)

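With UpperAttrMetaClass defined, a class must still state that it wants UpperAttrMetaClass to customize it, by passing the keyword argument metaclass:

class Foo(object, metaclass=UpperAttrMetaClass):
    bar = 'bip'

When the metaclass keyword argument is passed, it tells the Python interpreter to create the Foo class through UpperAttrMetaClass.__new__(). There, the attribute names of the class are converted to uppercase and the modified definition is returned.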

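The arguments received by __new__() are, in order:

  1. the metaclass that is about to create the class (cls);

  2. the name of the class;

  3. the tuple of parent classes the class inherits from;

  4. the dict of class attributes.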

>>> print(hasattr(Foo, 'bar'))
False
>>> print(hasattr(Foo, 'BAR'))
True
>>> f = Foo()
>>> print(f.BAR)
bip
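To see these four arguments in practice, the following sketch uses a hypothetical metaclass, TracingMeta, that simply records what it receives. Both TracingMeta and Point are invented for this illustration.

```python
class TracingMeta(type):
    """Hypothetical metaclass that records the arguments passed to __new__."""
    calls = []

    def __new__(cls, name, bases, attrs):
        # Keep only the user-defined attribute names so the record stays readable
        user_attrs = sorted(k for k in attrs if not k.startswith('__'))
        TracingMeta.calls.append((cls.__name__, name, bases, user_attrs))
        return super().__new__(cls, name, bases, attrs)


class Point(object, metaclass=TracingMeta):
    x = 0
    y = 0


print(TracingMeta.calls[0])
# ('TracingMeta', 'Point', (<class 'object'>,), ['x', 'y'])
print(type(Point) is TracingMeta)  # True
print(isinstance(Point, type))     # True
```

The record is filled in as soon as the class statement finishes, before any Point instance exists, which is why a metaclass is a convenient place to inspect and rewrite a class definition.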

ORM

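An ORM should be designed from the point of view of the caller at the top. For example, to define a User class that operates on the corresponding database table users, we would like the caller to be able to write code like this: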

class User(Model):
    """
    Map class attributes to table columns:
    """
    __table__ = 'users'

    id = StringField(primary_key=True, default=next_id, ddl='varchar(50)')
    email = StringField(ddl='varchar(50)')
    passwd = StringField(ddl='varchar(50)')
    name = StringField(ddl='varchar(50)')
    created_at = FloatField(default=time.time)


# Create an instance:
user = User(email='test@test.com', passwd='123456', name='test')

# Save it to the database:
user.save()

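In the User class, __table__, id, and name are attributes of the class, not of its instances. Attributes defined at class level describe the mapping between User objects and the table, while instance attributes are initialized through the __init__() method, so the two do not interfere with each other.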
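The separation between class-level and instance-level attributes can be checked with plain Python objects. This is a small sketch; Column and Account are hypothetical stand-ins, not part of the ORM.

```python
class Column(object):
    """Hypothetical stand-in for a Field: it only remembers a column type."""
    def __init__(self, ddl):
        self.ddl = ddl


class Account(object):
    # Class-level attribute: describes the mapping, shared by all instances
    name = Column('varchar(50)')

    def __init__(self, name):
        # Instance-level attribute: holds the actual value for one row
        self.name = name


a = Account('alice')
print(a.name)            # alice
print(Account.name.ddl)  # varchar(50)
print(vars(a))           # {'name': 'alice'}
```

Looking up a.name finds the instance value first, while Account.name still returns the Column object that describes the mapping.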

Defining the Field class

The Field class stores the name and the type of a database table column:

class Field(object):
    """
    Column name, column type, primary-key flag, and default value of a table column
    """
    def __init__(self, name, column_type, primary_key, default):
        self.name = name
        self.column_type = column_type
        self.primary_key = primary_key
        self.default = default

    def __str__(self):
        return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)

StringField and FloatField

StringField maps to varchar, and FloatField maps to real:

class StringField(Field):
    """
    Maps to varchar
    """
    def __init__(self, name=None, ddl='varchar(100)', primary_key=False, default=None):
        super(StringField, self).__init__(name, ddl, primary_key, default)


class FloatField(Field):
    """
    Maps to real
    """
    def __init__(self, name=None, primary_key=False, default=0.0):
        super(FloatField, self).__init__(name, 'real', primary_key, default)

On top of Field, further field types can be defined in the same way, such as IntegerField.
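As an illustration of how further field types might look, here is a sketch of three more subclasses. The column types chosen here (bigint, boolean, text) and the class names BooleanField and TextField are assumptions made for this example; Field is repeated only to keep the block self-contained.

```python
class Field(object):
    """Same shape as the Field class above, repeated to keep the sketch self-contained."""
    def __init__(self, name, column_type, primary_key, default):
        self.name = name
        self.column_type = column_type
        self.primary_key = primary_key
        self.default = default

    def __str__(self):
        return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)


class IntegerField(Field):
    """Assumed mapping: integer values stored as bigint."""
    def __init__(self, name=None, primary_key=False, default=0):
        super().__init__(name, 'bigint', primary_key, default)


class BooleanField(Field):
    """Assumed mapping: boolean values; never used as a primary key here."""
    def __init__(self, name=None, default=False):
        super().__init__(name, 'boolean', False, default)


class TextField(Field):
    """Assumed mapping: long strings stored as text."""
    def __init__(self, name=None, default=None):
        super().__init__(name, 'text', False, default)


age = IntegerField(name='age')
print(age)                     # <IntegerField, bigint:age>
print(BooleanField().default)  # False
```

Each subclass only fixes the column type and picks sensible defaults, so adding a new type is a matter of a few lines.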

ModelMetaClass

ModelMetaClass stores the mapping information of subclasses such as User:

import logging


def create_args_string(num):
    # Build a placeholder list such as '?, ?, ?'
    return ', '.join(['?'] * num)


class ModelMetaClass(type):
    """
    Collects the mapping information
    """
    def __new__(cls, name, bases, attrs):
        # Leave the Model base class itself untouched
        if name == 'Model':
            return type.__new__(cls, name, bases, attrs)

        table_name = attrs.get('__table__', None) or name
        logging.info('found model: %s (table: %s)' % (name, table_name))

        mappings = dict()
        fields = []
        primary_key = None

        for k, v in attrs.items():
            if isinstance(v, Field):
                logging.info('  found mapping: %s ==> %s' % (k, v))
                mappings[k] = v

                if v.primary_key:
                    # Found the primary key
                    if primary_key:
                        raise Exception('Duplicate primary key for field: %s' % k)
                    primary_key = k
                else:
                    fields.append(k)

        if not primary_key:
            raise Exception('Primary key not found.')

        for k in mappings.keys():
            attrs.pop(k)

        escaped_fields = list(map(lambda field: '`%s`' % field, fields))
        attrs['__mappings__'] = mappings  # mapping between attributes and columns
        attrs['__table__'] = table_name  # table name
        attrs['__primary_key__'] = primary_key  # attribute name of the primary key
        attrs['__fields__'] = fields  # attribute names other than the primary key

        # SQL for the select operation
        attrs['__select__'] = 'select `%s`, %s from `%s`' % (
            primary_key,
            ', '.join(escaped_fields),
            table_name)

        # SQL for the insert operation
        attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (
            table_name,
            ', '.join(escaped_fields),
            primary_key,
            create_args_string(len(escaped_fields) + 1))

        # SQL for the update operation
        attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (
            table_name,
            ','.join(map(lambda field: '`%s`=?' % (mappings.get(field).name or field), fields)),
            primary_key)

        # SQL for the delete operation
        attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (table_name, primary_key)
        return type.__new__(cls, name, bases, attrs)

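ModelMetaClass does several things:

  1. It skips the Model class itself and leaves it unmodified.

  2. It saves the table name in __table__; if the current class does not define a __table__ attribute, the class name is used as the table name.

  3. It looks through all attributes defined on the current class (for example User). Whenever it finds a Field attribute, it saves it in the __mappings__ dict, and it rejects a duplicate primary key.

  4. It removes those Field attributes from the class attributes. Otherwise runtime errors are likely: the class-level Field object would be found by normal attribute lookup before __getattr__ is ever called, so it would be returned in place of the instance's value.

  5. It builds __select__, __insert__, __update__, and __delete__, the SQL statements for operating on the table.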
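The steps above can be followed on a condensed sketch that runs without a database. SketchMeta and Book are hypothetical names, the Field class is reduced to two attributes, and only the insert statement is generated; it is meant to show the mechanism, not to replace the full metaclass.

```python
class Field(object):
    """Reduced Field: just a column type and a primary-key flag."""
    def __init__(self, column_type, primary_key=False):
        self.column_type = column_type
        self.primary_key = primary_key


class SketchMeta(type):
    """Condensed, hypothetical version of the metaclass: only builds __insert__."""
    def __new__(cls, name, bases, attrs):
        if name == 'Model':                      # step 1: skip the base class
            return type.__new__(cls, name, bases, attrs)
        table = attrs.get('__table__') or name   # step 2: table name
        # step 3: collect the Field attributes
        mappings = {k: v for k, v in attrs.items() if isinstance(v, Field)}
        keys = [k for k, v in mappings.items() if v.primary_key]
        if len(keys) != 1:
            raise Exception('Expected exactly one primary key, found %d' % len(keys))
        primary_key = keys[0]
        fields = [k for k in mappings if k != primary_key]
        for k in mappings:                       # step 4: remove them from the class
            attrs.pop(k)
        # step 5: build the SQL, primary key last
        columns = ', '.join('`%s`' % f for f in fields + [primary_key])
        placeholders = ', '.join(['?'] * (len(fields) + 1))
        attrs['__table__'] = table
        attrs['__mappings__'] = mappings
        attrs['__insert__'] = 'insert into `%s` (%s) values (%s)' % (
            table, columns, placeholders)
        return type.__new__(cls, name, bases, attrs)


class Model(dict, metaclass=SketchMeta):
    pass


class Book(Model):
    __table__ = 'books'
    id = Field('varchar(50)', primary_key=True)
    title = Field('varchar(100)')
    price = Field('real')


print(Book.__insert__)
# insert into `books` (`title`, `price`, `id`) values (?, ?, ?)
print(hasattr(Book, 'title'))     # False
print(sorted(Book.__mappings__))  # ['id', 'price', 'title']
```

After the class statement finishes, Book no longer has a title attribute; the Field objects live only in __mappings__, and the SQL text has been computed once for the whole class.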

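Defining Model

When a user defines class User(Model), the Python interpreter first looks for a metaclass in the definition of User itself. Finding none, it goes on to look in the parent class Model, finds one there, and uses the metaclass defined on Model, ModelMetaClass, to create the User class. In other words, a metaclass is implicitly inherited by subclasses.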

class Model(dict, metaclass=ModelMetaClass):

    def __init__(self, **kw):
        super(Model, self).__init__(**kw)

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(r"'Model' object has no attribute '%s'" % key)

    def __setattr__(self, key, value):
        self[key] = value

    def get_value(self, key):
        return getattr(self, key, None)

    def get_value_or_default(self, key):
        """
        Return the attribute value; if it is missing, fall back to the field's default value, if any
        """
        value = getattr(self, key, None)
        if value is None:
            field = self.__mappings__[key]
            if field.default is not None:
                value = field.default() if callable(field.default) else field.default
                setattr(self, key, value)
        return value

    async def save(self):
        # Non-key values first, primary key last, matching the order in __insert__
        args = list(map(self.get_value_or_default, self.__fields__))
        args.append(self.get_value_or_default(self.__primary_key__))
        rows = await execute(self.__insert__, args)
        if rows != 1:
            logging.warning('failed to insert record: affected rows: %s' % rows)

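Model inherits from dict, so it has all the capabilities of a dict. It also implements the special methods __getattr__() and __setattr__(), so values can be read and written like ordinary attributes as well: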

>>> user['id']
123
>>> user.id
123
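The same behavior can be tried in isolation with a small dict subclass. AttrDict is a hypothetical name used only for this sketch.

```python
class AttrDict(dict):
    """Hypothetical dict subclass mirroring the attribute access used by Model."""

    def __getattr__(self, key):
        # Called only when normal attribute lookup fails
        try:
            return self[key]
        except KeyError:
            raise AttributeError("'AttrDict' object has no attribute '%s'" % key)

    def __setattr__(self, key, value):
        # Every attribute assignment is stored as a dict item
        self[key] = value


row = AttrDict(id=123, name='test')
row.email = 'test@test.com'

print(row['id'], row.id)              # 123 123
print(row['email'])                   # test@test.com
print(getattr(row, 'missing', None))  # None
print(sorted(row.keys()))             # ['email', 'id', 'name']
```

One consequence of this design is that a column whose name matches a dict method, such as keys or items, can only be reached with row['keys'], because normal attribute lookup finds the method first.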

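The Model class defines an instance method named save, which stores a User instance in the database. save uses asynchronous I/O to talk to the database:

user = User(email='test@test.com', passwd='123456', name='test')
await user.save()

The get_value_or_default method returns the value of the given attribute. If that value is None, it falls back to the default recorded in __mappings__, and if the default is callable, it uses the result of calling it. The id attribute of the User class, for example, is declared with default=next_id.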
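The default-resolution rule can be sketched on plain dicts. The helper names next_number, resolve_default, and value_or_default are made up for this illustration; the logic follows the description above.

```python
import itertools

_counter = itertools.count(1)


def next_number():
    """Hypothetical callable default: returns 1, 2, 3, ... on successive calls."""
    return next(_counter)


def resolve_default(default):
    # A callable default is called to produce a fresh value; anything else is used as-is
    return default() if callable(default) else default


def value_or_default(row, defaults, key):
    value = row.get(key)
    if value is None and defaults.get(key) is not None:
        value = resolve_default(defaults[key])
        row[key] = value  # remember the value so later reads agree
    return value


defaults = {'id': next_number, 'score': 0.0, 'name': None}
row = {'name': 'test'}

print(value_or_default(row, defaults, 'id'))     # 1
print(value_or_default(row, defaults, 'id'))     # 1
print(value_or_default(row, defaults, 'score'))  # 0.0
print(value_or_default(row, defaults, 'name'))   # test
print(row)  # {'name': 'test', 'id': 1, 'score': 0.0}
```

The second lookup of id returns 1 again because the generated value was stored on first access. This matters for defaults such as next_id or time.time, which would otherwise produce a different value on every call.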

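Creating the connection pool

The save method above uses asynchronous I/O to talk to the database. In Python, aiomysql provides an asynchronous I/O driver for MySQL.

The following code creates a connection pool and stores it in the global variable __pool: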

import logging

import aiomysql


async def create_pool(loop, **kw):
    """
    Create the database connection pool
    """
    logging.info('create database connection pool...')
    global __pool
    __pool = await aiomysql.create_pool(
        host=kw.get('host', 'localhost'),
        port=kw.get('port', 3306),
        user=kw['user'],
        password=kw['password'],
        db=kw['db'],
        charset=kw.get('charset', 'utf8'),
        autocommit=kw.get('autocommit', True),
        maxsize=kw.get('maxsize', 10),
        minsize=kw.get('minsize', 1),
        loop=loop
    )

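Executing INSERT statements

The save instance method of the Model class above stores an instance in the database; the insert statement itself is run by the execute function.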

def log(sql, args=()):
    logging.info('SQL: %s' % sql)


async def execute(sql, args, autocommit=True):
    log(sql)
    # acquire() is the documented way to borrow a connection from an aiomysql pool
    async with __pool.acquire() as conn:
        if not autocommit:
            await conn.begin()
        try:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                # The ORM writes '?' placeholders; the MySQL driver expects '%s'
                await cur.execute(sql.replace('?', '%s'), args)
                affected = cur.rowcount
            if not autocommit:
                await conn.commit()
        except BaseException:
            if not autocommit:
                await conn.rollback()
            raise
        return affected
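To try the same execute pattern without a MySQL server, here is a synchronous sketch that swaps aiomysql for the standard-library sqlite3 module. That swap is an assumption made purely for illustration: sqlite3 accepts '?' placeholders natively, so the replacement step is shown separately in the hypothetical helper to_mysql_placeholders.

```python
import sqlite3


def to_mysql_placeholders(sql):
    # The ORM writes '?'; drivers using the "format" paramstyle expect '%s'
    return sql.replace('?', '%s')


def execute(conn, sql, args):
    """Run one data-changing statement and return the number of affected rows."""
    try:
        cur = conn.execute(sql, args)  # sqlite3 understands '?' natively
        conn.commit()
        return cur.rowcount
    except Exception:
        conn.rollback()                # undo the partial work, then re-raise
        raise


conn = sqlite3.connect(':memory:')
conn.execute('create table users (id text primary key, name text not null)')

insert = 'insert into users (name, id) values (?, ?)'
print(to_mysql_placeholders(insert))
# insert into users (name, id) values (%s, %s)
print(execute(conn, insert, ['test', 'u1']))  # 1

try:
    execute(conn, insert, ['again', 'u1'])    # same primary key
except sqlite3.IntegrityError as e:
    print('rolled back:', type(e).__name__)   # rolled back: IntegrityError

count = conn.execute('select count(*) from users').fetchone()[0]
print(count)  # 1
```

Note that a plain string replace also rewrites any literal '?' that appears inside the SQL text itself, so statements passed to execute should contain '?' only as a placeholder.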

Writing a Model

Create the Model that corresponds to the users table:

import time
import uuid
from orm import Model, StringField, FloatField


def next_id():
    # 15-digit millisecond timestamp + 32 hex characters + '000'
    return '%015d%s000' % (int(time.time() * 1000), uuid.uuid4().hex)


class User(Model):
    """
    Map class attributes to table columns:
    """
    __table__ = 'users'

    id = StringField(primary_key=True, default=next_id, ddl='varchar(50)')
    email = StringField(ddl='varchar(50)')
    passwd = StringField(ddl='varchar(50)')
    name = StringField(ddl='varchar(50)')
    created_at = FloatField(default=time.time)
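A quick check of what next_id produces shows why the id column is declared as varchar(50). This is a small sketch that repeats next_id to stay self-contained.

```python
import time
import uuid


def next_id():
    # 15-digit millisecond timestamp + 32 hex characters + '000'
    return '%015d%s000' % (int(time.time() * 1000), uuid.uuid4().hex)


first = next_id()
time.sleep(0.002)  # give the millisecond timestamp a chance to advance
second = next_id()

print(len(first))                  # 50
print(first[:15].isdigit())        # True
print(first[:15] <= second[:15])   # True
print(first != second)             # True
```

The length is 15 + 32 + 3 = 50 characters, exactly the width of the column. Because the timestamp comes first and is zero-padded, ids created later do not sort before earlier ones, and the random part keeps ids created in the same millisecond distinct.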

Initializing the database table

drop database if exists test;

create database test;

use test;

grant select, insert, update, delete on test.* to 'heqingliang'@'%' identified by 'heqingliang';

create table users (
    `id` varchar(50) not null,
    `email` varchar(50) not null,
    `passwd` varchar(50) not null,
    `name` varchar(50) not null,
    `created_at` real not null,
    unique key `idx_email` (`email`),
    key `idx_created_at` (`created_at`),
    primary key (`id`)
) engine=innodb default charset=utf8;

Save the SQL script as schema.sql and run it with the MySQL command-line client:

mysql -u root -p < schema.sql

In the database you can see that a test database has been created, along with a users table whose columns correspond to the attributes of the User class:

mysql> USE test;
mysql> DESC users;
+------------+-------------+------+-----+---------+-------+
| Field      | Type        | Null | Key | Default | Extra |
+------------+-------------+------+-----+---------+-------+
| id         | varchar(50) | NO   | PRI | NULL    |       |
| email      | varchar(50) | NO   | UNI | NULL    |       |
| passwd     | varchar(50) | NO   |     | NULL    |       |
| name       | varchar(50) | NO   |     | NULL    |       |
| created_at | double      | NO   | MUL | NULL    |       |
+------------+-------------+------+-----+---------+-------+
5 rows in set (0.00 sec)

Writing the data access code

Now we can really start writing code that works with objects. For User objects, for example, we can do the following:

import asyncio
import orm
from models import User


async def test(loop):
    await orm.create_pool(user='heqingliang', password='heqingliang', db='test', loop=loop)

    users = [
        User(email='test1@test.com', passwd='123456', name='test1'),
        User(email='test2@test.com', passwd='123456', name='test2'),
        User(email='test3@test.com', passwd='123456', name='test3'),
        User(email='test4@test.com', passwd='123456', name='test2'),
    ]

    for u in users:
        await u.save()

if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    loop.run_until_complete(test(loop))
    # Keeps the event loop alive after the inserts finish; stop it with Ctrl+C
    loop.run_forever()

After it runs successfully, you can see the following log:

INFO:root:found model: User (table: users)
INFO:root:  found mapping: passwd ==> <StringField, varchar(50):None>
INFO:root:  found mapping: id ==> <StringField, varchar(50):None>
INFO:root:  found mapping: email ==> <StringField, varchar(50):None>
INFO:root:  found mapping: name ==> <StringField, varchar(50):None>
INFO:root:  found mapping: created_at ==> <FloatField, real:None>
INFO:root:create database connection pool...
INFO:root:SQL: insert into `users` (`passwd`, `email`, `name`, `created_at`, `id`) values (?, ?, ?, ?, ?)
INFO:root:SQL: insert into `users` (`passwd`, `email`, `name`, `created_at`, `id`) values (?, ?, ?, ?, ?)
INFO:root:SQL: insert into `users` (`passwd`, `email`, `name`, `created_at`, `id`) values (?, ?, ?, ?, ?)
INFO:root:SQL: insert into `users` (`passwd`, `email`, `name`, `created_at`, `id`) values (?, ?, ?, ?, ?)

Looking at the database, you can see that the rows were inserted:

mysql> SELECT * FROM users;
+----------------------------------------------------+----------------+--------+-------+------------------+
| id                                                 | email          | passwd | name  | created_at       |
+----------------------------------------------------+----------------+--------+-------+------------------+
| 001527820029534e0c48803c8cc4c44a278d1491190a069000 | test1@test.com | 123456 | test1 | 1527820029.53404 |
| 0015278200295492bdd1c9ba409444eb35211844501a663000 | test2@test.com | 123456 | test2 | 1527820029.54944 |
| 0015278200295619ab4469834ec4b3290a460d079c91f3b000 | test3@test.com | 123456 | test3 | 1527820029.56111 |
| 0015278200295635ed7e1100f294efb8b6ef8f57dbd7188000 | test4@test.com | 123456 | test2 | 1527820029.56394 |
+----------------------------------------------------+----------------+--------+-------+------------------+
4 rows in set (0.00 sec)

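SELECT statements

Finding by primary key

Define a class method named find on the Model class. find runs the select function with the given primary key and returns the matching record.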

class Model(dict, metaclass=ModelMetaClass):
    @classmethod
    async def find(cls, pk):
        """
        Find a record by primary key
        """
        rs = await select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1)
        if len(rs) == 0:
            return None
        return cls(**rs[0])

The select function

The select function looks up the matching records in the database:

async def select(sql, args, size=None):
    log(sql, args)
    global __pool
    # acquire() is the documented way to borrow a connection from an aiomysql pool
    async with __pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cur:
            await cur.execute(sql.replace('?', '%s'), args or ())
            # Fetch at most `size` rows when a size is given, otherwise all rows
            if size:
                rs = await cur.fetchmany(size)
            else:
                rs = await cur.fetchall()
        logging.info('rows returned: %s' % len(rs))
        return rs
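The fetchmany versus fetchall choice can be tried without MySQL. The sketch below is a synchronous stand-in that uses the standard-library sqlite3 module in place of aiomysql (an assumption made for illustration), with sqlite3.Row playing the role of DictCursor.

```python
import sqlite3


def select(conn, sql, args=None, size=None):
    """Return rows as dicts; fetch at most `size` rows when size is given."""
    cur = conn.execute(sql, args or ())
    rows = cur.fetchmany(size) if size else cur.fetchall()
    return [dict(row) for row in rows]


conn = sqlite3.connect(':memory:')
conn.row_factory = sqlite3.Row  # lets each row be converted to a dict
conn.execute('create table users (id text primary key, name text)')
conn.executemany('insert into users (id, name) values (?, ?)',
                 [('u1', 'test1'), ('u2', 'test2'), ('u3', 'test2')])

print(select(conn, 'select id, name from users where name=? order by id', ['test2']))
# [{'id': 'u2', 'name': 'test2'}, {'id': 'u3', 'name': 'test2'}]
print(select(conn, 'select id, name from users order by id', size=1))
# [{'id': 'u1', 'name': 'test1'}]
```

Returning dicts is what makes cls(**row) possible in find: each column name becomes a keyword argument of the model constructor.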

Run the following call to the find method:

async def test(loop):
    await orm.create_pool(user='heqingliang', password='heqingliang', db='test', loop=loop)

    user = await User.find('001527820029534e0c48803c8cc4c44a278d1491190a069000')
    print(user)

The output is:

INFO:root:create database connection pool...
INFO:root:SQL: select `id`, `name`, `created_at`, `email`, `passwd` from `users` where `id`=?
INFO:root:rows returned: 1
{'created_at': 1527820029.53404, 'id': '001527820029534e0c48803c8cc4c44a278d1491190a069000', 'email': 'test1@test.com', 'passwd': '123456', 'name': 'test1'}
Finding by where, order by, and limit

class Model(dict, metaclass=ModelMetaClass):

    @classmethod
    async def find_all(cls, where=None, args=None, **kw):
        """
        Find records that match the given conditions
        """
        sql = [cls.__select__]
        if where:
            sql.append('where')
            sql.append(where)

        # Copy the arguments so the caller's list is not modified below
        args = [] if args is None else list(args)

        order_by = kw.get('order_by', None)
        if order_by:
            sql.append('order by')
            sql.append(order_by)

        limit = kw.get('limit', None)
        if limit is not None:
            sql.append('limit')
            if isinstance(limit, int):
                sql.append('?')
                args.append(limit)
            elif isinstance(limit, tuple) and len(limit) == 2:
                sql.append('?, ?')
                args.extend(limit)
            else:
                raise ValueError('Invalid limit value: %s' % str(limit))

        rs = await select(' '.join(sql), args)
        return [cls(**r) for r in rs]
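The way find_all assembles its statement can be examined on its own. build_select is a hypothetical helper written for this sketch; it applies the same rules to a base select statement and returns the SQL text together with the argument list.

```python
def build_select(base, where=None, args=None, order_by=None, limit=None):
    """Return (sql, args) assembled the same way find_all does."""
    parts = [base]
    params = list(args) if args else []  # copy so the caller's list is not modified
    if where:
        parts += ['where', where]
    if order_by:
        parts += ['order by', order_by]
    if limit is not None:
        if isinstance(limit, int):
            parts += ['limit', '?']
            params.append(limit)
        elif isinstance(limit, tuple) and len(limit) == 2:
            parts += ['limit', '?, ?']
            params.extend(limit)
        else:
            raise ValueError('Invalid limit value: %s' % str(limit))
    return ' '.join(parts), params


base = 'select `id`, `name` from `users`'
print(build_select(base, 'name=?', ['test2'], limit=1))
# ('select `id`, `name` from `users` where name=? limit ?', ['test2', 1])
print(build_select(base, order_by='name', limit=(1, 2)))
# ('select `id`, `name` from `users` order by name limit ?, ?', [1, 2])
```

Only the values in args and limit travel as parameters. The where and order_by strings are pasted into the SQL text as they are, so they should be written by the programmer and never built from untrusted input.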

By default, all records in the table are returned, as in the following code:

async def test(loop):
    await orm.create_pool(user='heqingliang', password='heqingliang', db='test', loop=loop)

    users = await User.find_all()
    print(users)

The output is:

INFO:root:create database connection pool...
INFO:root:SQL: select `id`, `email`, `passwd`, `created_at`, `name` from `users`
INFO:root:rows returned: 4
[{'id': '001527820029534e0c48803c8cc4c44a278d1491190a069000', 'email': 'test1@test.com', 'passwd': '123456', 'created_at': 1527820029.53404, 'name': 'test1'}, {'id': '0015278200295492bdd1c9ba409444eb35211844501a663000', 'email': 'test2@test.com', 'passwd': '123456', 'created_at': 1527820029.54944, 'name': 'test2'}, {'id': '0015278200295619ab4469834ec4b3290a460d079c91f3b000', 'email': 'test3@test.com', 'passwd': '123456', 'created_at': 1527820029.56111, 'name': 'test3'}, {'id': '0015278200295635ed7e1100f294efb8b6ef8f57dbd7188000', 'email': 'test4@test.com', 'passwd': '123456', 'created_at': 1527820029.56394, 'name': 'test2'}]

To search with a where condition, use code like this:

async def test(loop):
    await orm.create_pool(user='heqingliang', password='heqingliang', db='test', loop=loop)

    users = await User.find_all('name=?', ['test2'])
    print(users)

The output is:

INFO:root:create database connection pool...
INFO:root:SQL: select `id`, `created_at`, `email`, `passwd`, `name` from `users` where name=?
INFO:root:rows returned: 2
[{'created_at': 1527820029.54944, 'name': 'test2', 'email': 'test2@test.com', 'passwd': '123456', 'id': '0015278200295492bdd1c9ba409444eb35211844501a663000'}, {'created_at': 1527820029.56394, 'name': 'test2', 'email': 'test4@test.com', 'passwd': '123456', 'id': '0015278200295635ed7e1100f294efb8b6ef8f57dbd7188000'}]

To use limit, use code like this:

async def test(loop):
    await orm.create_pool(user='heqingliang', password='heqingliang', db='test', loop=loop)

    users = await User.find_all(limit=(1, 2))
    print(users)

The output is:

INFO:root:create database connection pool...
INFO:root:SQL: select `id`, `passwd`, `email`, `created_at`, `name` from `users` limit ?, ?
INFO:root:rows returned: 2
[{'passwd': '123456', 'email': 'test2@test.com', 'name': 'test2', 'created_at': 1527820029.54944, 'id': '0015278200295492bdd1c9ba409444eb35211844501a663000'}, {'passwd': '123456', 'email': 'test3@test.com', 'name': 'test3', 'created_at': 1527820029.56111, 'id': '0015278200295619ab4469834ec4b3290a460d079c91f3b000'}]

To combine where and limit, use code like this:

async def test(loop):
    await orm.create_pool(user='heqingliang', password='heqingliang', db='test', loop=loop)

    users = await User.find_all('name=?', ['test2'], limit=1)
    print(users)

The output is:

INFO:root:create database connection pool...
INFO:root:SQL: select `id`, `passwd`, `created_at`, `name`, `email` from `users` where name=? limit ?
INFO:root:rows returned: 1
[{'name': 'test2', 'email': 'test2@test.com', 'id': '0015278200295492bdd1c9ba409444eb35211844501a663000', 'passwd': '123456', 'created_at': 1527820029.54944}]

To use order by, use code like this:

async def test(loop):
    await orm.create_pool(user='heqingliang', password='heqingliang', db='test', loop=loop)

    users = await User.find_all(order_by='name')
    print(users)

The output is:

INFO:root:create database connection pool...
INFO:root:SQL: select `id`, `created_at`, `name`, `email`, `passwd` from `users` order by name
INFO:root:rows returned: 4
[{'passwd': '123456', 'created_at': 1527820029.53404, 'id': '001527820029534e0c48803c8cc4c44a278d1491190a069000', 'name': 'test1', 'email': 'test1@test.com'}, {'passwd': '123456', 'created_at': 1527820029.54944, 'id': '0015278200295492bdd1c9ba409444eb35211844501a663000', 'name': 'test2', 'email': 'test2@test.com'}, {'passwd': '123456', 'created_at': 1527820029.56394, 'id': '0015278200295635ed7e1100f294efb8b6ef8f57dbd7188000', 'name': 'test2', 'email': 'test4@test.com'}, {'passwd': '123456', 'created_at': 1527820029.56111, 'id': '0015278200295619ab4469834ec4b3290a460d079c91f3b000', 'name': 'test3', 'email': 'test3@test.com'}]

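Other SQL statements

The code above implements the insert and select operations. Other SQL operations, such as update and delete, can be defined in the same way. The implementation looks like this: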

class Model(dict, metaclass=ModelMetaClass):

    async def update(self):
        # Non-key values first, primary key last, matching the order in __update__
        args = list(map(self.get_value, self.__fields__))
        args.append(self.get_value(self.__primary_key__))
        rows = await execute(self.__update__, args)
        if rows != 1:
            logging.warning('failed to update by primary key: affected rows: %s' % rows)

    async def remove(self):
        args = [self.get_value(self.__primary_key__)]
        rows = await execute(self.__delete__, args)
        if rows != 1:
            logging.warning('failed to remove by primary key: affected rows: %s' % rows)
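The argument order used by update and remove, and the affected-row check, can be tried against an in-memory database. This sketch uses the standard-library sqlite3 module in place of aiomysql (an assumption for illustration only; SQLite accepts the backtick quoting used here), and the variable names are made up for the example.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('create table users (id text primary key, name text, email text)')
conn.execute("insert into users (id, name, email) values ('u1', 'test1', 'test1@test.com')")

fields = ['name', 'email']  # non-key columns, in the order used by the SQL
primary_key = 'id'
update_sql = 'update `users` set %s where `%s`=?' % (
    ', '.join('`%s`=?' % f for f in fields), primary_key)
delete_sql = 'delete from `users` where `%s`=?' % primary_key

row = {'id': 'u1', 'name': 'renamed', 'email': 'new@test.com'}

# Arguments follow the placeholders: every non-key value first, the key last
args = [row[f] for f in fields] + [row[primary_key]]
print(update_sql)  # update `users` set `name`=?, `email`=? where `id`=?

updated = conn.execute(update_sql, args).rowcount
missing = conn.execute(delete_sql, ['missing']).rowcount
removed = conn.execute(delete_sql, [row['id']]).rowcount
print(updated, missing, removed)  # 1 0 1
```

An affected-row count of 0 is how a missing primary key shows up, which is why both methods log a warning when rows is not 1.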