# -*- coding: utf-8 -*-
import math
import datetime
import numpy as np
import pandas as pd
import sqlalchemy
'''
engine: SQLAlchemy Engine
buffer_size: 缓存条目数,当缓存满时自动flush
update_on_duplicate: 当唯一键重复时的行为,默认是update,设置为False表示不更新,即忽略插入失败。
'''
def create_upsert_handler(engine, buffer_size=5000, update_on_duplicate=True):
    """Return the upsert handler best matching the engine's SQL dialect.

    engine: SQLAlchemy Engine; its dialect name selects the implementation.
    buffer_size: number of rows to buffer before an automatic flush.
    update_on_duplicate: True -> update the existing row on a unique-key
        conflict; False -> silently skip the failed insert.
    """
    dialect = engine.dialect.name.lower()
    if "mysql" in dialect:
        return MySQLUpsertHandler(engine, buffer_size, update_on_duplicate)
    if "postgresql" in dialect:
        return PSQLUpsertHandler(engine, buffer_size, update_on_duplicate)
    # No dialect-specific implementation: warn and fall back to the
    # generic UPDATE-then-INSERT strategy.
    print(f"没有为{engine.dialect.name}实现特殊的Upsert,使用默认版本,请确认可以正常工作,建议特化一个专门版本")
    return UpsertHandlerBase(engine, buffer_size, update_on_duplicate)
def is_duplicate_key(e):
    """Return True if exception *e* signals a duplicate-key violation.

    Asks every dialect-specific handler class first, then falls back to
    the generic message-based check.
    """
    dialect_handlers = UpsertHandlerBase.__subclasses__()
    if any(handler.is_duplicate_key(e) for handler in dialect_handlers):
        return True
    return UpsertHandlerBase.is_duplicate_key(e)
'''
class UpsertHandler:
# 传入的engine类型应该和使用的UpsertHandler支持的数据库类型相匹配
# buffer_size表示插入或更新数据缓存到多少才flush(即向数据库插入或更新),None表示在析构时flush,0表示不缓存
# update_on_duplicate当唯一键重复时的行为,默认是update,设置为False表示不更新,即忽略插入失败。
def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
pass
# tablename为数据库表名
# pk为主键的元组,可以不是真正的表主键,但是可以用来判重决定insert还是update,例如('exchange_id', 'trade_id')
# data为单条数据,dict的形式,例如{'exchange_id': 'DCE', 'trade_id': ' 1', 'price': 1.2, 'volume': 1}
def upsert(self, tablename, pk, data):
pass
# 立即把缓冲器的数据推到数据库,会在buffer_size满了或者析构时自动调用,也可以手动调用
def flush(self):
pass
'''
class UpsertHandlerBase:
    """Dialect-agnostic buffered insert-or-update helper.

    Rows are buffered per table and written with multi-row INSERT
    statements; when a batch hits a duplicate key, it is bisected and
    ultimately upserted row by row (UPDATE first, then INSERT).

    engine: SQLAlchemy Engine.
    buffer_size: flush a table's buffer once it holds this many rows;
        None flushes only on destruction / manual flush(); 0 disables
        buffering (every upsert hits the database immediately).
    update_on_duplicate: True -> update the row on a unique-key
        conflict; False -> silently skip the failed insert.
    """
    def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
        self.engine = engine
        self.tablename2pk = {}      # table name -> tuple of dedup-key column names
        self.tablename2datas = {}   # table name -> list of buffered row dicts
        self.buffer_size = buffer_size
        self.update_on_duplicate = update_on_duplicate
    def __del__(self):
        # Best-effort final flush; a destructor must never raise (it can
        # run during interpreter shutdown when globals are already gone).
        try:
            self.flush()
        except Exception:
            pass
    def flush(self):
        """Push every buffered row to the database immediately.

        Iterates the data buffers (not the pk registry) so rows buffered
        for a table that never registered a pk are still written.
        """
        for tablename, datas in self.tablename2datas.items():
            if not datas:
                continue
            pk = self.tablename2pk.get(tablename, ())
            with self.engine.connect() as conn:
                self._flush(conn, tablename, pk, datas)
            self.tablename2datas[tablename] = []
    def _flush(self, conn, tablename, pk, datas):
        """Write *datas* with one multi-row INSERT; on a duplicate key,
        bisect large batches and finally fall back to row-wise upserts.

        NOTE(review): SQL is assembled by string interpolation. Values
        are escaped in _format_value, but identifiers (table/column
        names) are not -- they must come from trusted code only.
        """
        columns = datas[0].keys()
        sql = f"""INSERT INTO {tablename}({", ".join(columns)}) VALUES\n"""
        for i, data in enumerate(datas):
            if i != len(datas) - 1:
                sql += f""" ({self._format_values(data.values())}),\n"""
            else:
                sql += f""" ({self._format_values(data.values())});\n"""
        try:
            conn.execute(sql)
        except sqlalchemy.exc.IntegrityError as e:
            if not self.is_duplicate_key(e):
                raise
            # Batch contains a duplicate key: small batches go row by
            # row; large ones are halved so most rows still benefit from
            # the fast multi-row INSERT.
            if len(datas) <= 500:
                for data in datas:
                    self.upsert_one(conn, tablename, pk, data)
            else:
                mid = len(datas) // 2
                self._flush(conn, tablename, pk, datas[:mid])
                self._flush(conn, tablename, pk, datas[mid:])
    def upsert_one(self, conn, tablename, pk, data):
        """Upsert a single row: UPDATE matched on *pk* first, then
        INSERT if nothing was updated; a duplicate-key failure of the
        INSERT is deliberately ignored (row already present and
        update_on_duplicate is False, or a concurrent writer won)."""
        r = None
        # An UPDATE is only meaningful with key columns to match on and
        # at least one non-key column to set (empty pk would otherwise
        # produce a malformed WHERE clause).
        if self.update_on_duplicate and pk:
            update_str = self._format_update_values(pk, data)
            if update_str.strip():
                r = conn.execute(f"UPDATE {tablename} SET {update_str} WHERE {self._format_update_conditions(pk, data)}")
        if r is None or r.rowcount == 0:
            try:
                r = conn.execute(f"INSERT INTO {tablename}({', '.join(data.keys())}) VALUES({self._format_values(data.values())})")
            except sqlalchemy.exc.IntegrityError as e:
                if not self.is_duplicate_key(e):
                    raise
    @staticmethod
    def is_duplicate_key(e):
        """Generic duplicate-key detection: look for 'duplicate' in the
        driver-level error message."""
        if type(e) != sqlalchemy.exc.IntegrityError:
            return False
        return "duplicate" in str(e.orig).lower()
    def _isinf(self, x):
        # Values at the int64 extremes are treated as sentinel
        # "infinity" markers and stored as NULL.
        return x >= 9223372036854775807 or x <= -9223372036854775808
    def _format_value(self, v):
        """Render one Python value as a SQL literal fragment.

        NaN/inf and int64-extreme sentinels become NULL; strings have
        single quotes doubled so embedded quotes cannot break (or
        inject into) the statement. numpy scalar types are accepted
        because upsert_dataframe's row.to_dict() yields them.
        """
        if v is None:
            return "null"
        elif isinstance(v, bool):
            # bool is an int subclass; keep the historical quoted repr.
            return f"'{v}'"
        elif isinstance(v, (float, np.floating)):
            v = float(v)
            if math.isnan(v) or math.isinf(v) or self._isinf(v):
                return "null"
            else:
                return f"{v}"
        elif isinstance(v, (int, np.integer)):
            v = int(v)
            if self._isinf(v):
                return "null"
            else:
                return f"{v}"
        elif isinstance(v, datetime.datetime):
            # Also covers pd.Timestamp (a datetime subclass) with the
            # same second-resolution format as before.
            return "'" + v.strftime("%Y-%m-%d %H:%M:%S") + "'"
        elif isinstance(v, datetime.date):
            return "'" + v.strftime("%Y-%m-%d") + "'"
        elif isinstance(v, str):
            return "'" + v.replace("'", "''") + "'"
        else:
            return "'" + str(v).replace("'", "''") + "'"
    def _format_values(self, data):
        """Comma-join the SQL literals for an iterable of values."""
        return ", ".join(self._format_value(v) for v in data)
    def _format_update_values(self, pk, data):
        """Render 'col=value, ...' for every non-key column of *data*."""
        return ", ".join(f"{k}={self._format_value(v)}"
                         for k, v in data.items() if k not in pk)
    def _format_update_conditions(self, pk, data):
        """Render 'col=value and ...' over the key columns of *data*."""
        return " and ".join(f"{k}={self._format_value(v)}"
                            for k, v in data.items() if k in pk)
    def upsert(self, tablename, pk, data):
        """Buffer one row for insert-or-update (flushes when full).

        tablename: database table name.
        pk: tuple of column names used to detect duplicates; need not be
            the real primary key, e.g. ('exchange_id', 'trade_id').
        data: one row as a dict, e.g. {'exchange_id': 'DCE',
            'trade_id': ' 1', 'price': 1.2, 'volume': 1}.
        """
        if self.buffer_size == 0:
            # Unbuffered mode: write through immediately.
            with self.engine.connect() as conn:
                self.upsert_one(conn, tablename, pk, data)
            return
        if pk:
            self.tablename2pk[tablename] = pk
        self.tablename2datas.setdefault(tablename, []).append(data)
        if self.buffer_size is not None and len(self.tablename2datas[tablename]) >= self.buffer_size:
            with self.engine.connect() as conn:
                # .get avoids a KeyError when no pk was ever registered
                # for this table.
                self._flush(conn, tablename, self.tablename2pk.get(tablename, ()),
                            self.tablename2datas[tablename])
            self.tablename2datas[tablename] = []
    def upsert_dataframe(self, tablename, pk, df):
        """Upsert every row of DataFrame *df* (bypasses the buffer).

        Large frames are split recursively so each connection performs
        at most 2000 row-wise upserts.
        """
        if len(df) <= 2000:
            with self.engine.connect() as conn:
                for _, row in df.iterrows():
                    self.upsert_one(conn, tablename, pk, row.to_dict())
        else:
            mid = len(df) // 2
            self.upsert_dataframe(tablename, pk, df[:mid])
            self.upsert_dataframe(tablename, pk, df[mid:])
class MySQLUpsertHandler(UpsertHandlerBase):
    """Upsert handler specialised for MySQL, using the server-side
    INSERT ... ON DUPLICATE KEY UPDATE conflict resolution."""
    def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
        super().__init__(engine, buffer_size, update_on_duplicate)
    def __del__(self):
        super().__del__()
    @staticmethod
    def is_duplicate_key(e):
        """MySQL reports duplicates as (errno, 'Duplicate entry ...')."""
        if type(e) != sqlalchemy.exc.IntegrityError:
            return False
        args = e.orig.args
        return len(args) > 1 and str(args[1]).startswith("Duplicate entry")
    def upsert_one(self, conn, tablename, pk, data):
        """Insert one row and let the server resolve key conflicts."""
        update_str = self._format_update_values(pk, data) if self.update_on_duplicate else ""
        if update_str.strip():
            duplicate_do_str = f"UPDATE {update_str}"
        else:
            # Assigning a key column to itself is MySQL's idiom for
            # "do nothing" on conflict.
            duplicate_do_str = f"UPDATE {pk[0]}=VALUES({pk[0]})"
        sql = f"""INSERT INTO {tablename}({", ".join(data.keys())}) VALUES
            ({self._format_values(data.values())})
            ON DUPLICATE KEY
            {duplicate_do_str}\n"""
        conn.execute(sql)
class PSQLUpsertHandler(UpsertHandlerBase):
    """Upsert handler specialised for PostgreSQL, using the server-side
    INSERT ... ON CONFLICT conflict resolution."""
    def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
        super().__init__(engine, buffer_size, update_on_duplicate)
    def __del__(self):
        super().__del__()
    @staticmethod
    def is_duplicate_key(e):
        """PostgreSQL prefixes duplicate-key errors with 'duplicate key'."""
        if type(e) != sqlalchemy.exc.IntegrityError:
            return False
        return str(e.orig).startswith("duplicate key")
    def upsert_one(self, conn, tablename, pk, data):
        """Insert one row, resolving conflicts on the pk columns."""
        update_str = self._format_update_values(pk, data) if self.update_on_duplicate else ""
        if update_str.strip():
            duplicate_do_str = f"do update set {update_str}"
        else:
            duplicate_do_str = "do nothing"
        sql = f"""INSERT INTO {tablename}({", ".join(data.keys())}) VALUES
            ({self._format_values(data.values())})
            on conflict ({", ".join(pk)})
            {duplicate_do_str}\n"""
        conn.execute(sql)