get rid of sql_alchemy
This commit is contained in:
@@ -11,4 +11,3 @@ aiohttp_socks
|
||||
certifi
|
||||
bitstring
|
||||
pycryptodomex>=3.7
|
||||
sqlalchemy>=1.3.0b3
|
||||
|
||||
@@ -36,10 +36,6 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
|
||||
import binascii
|
||||
import base64
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
|
||||
from sqlalchemy.orm.query import Query
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.sql import not_, or_
|
||||
|
||||
from .sql_db import SqlDB, sql
|
||||
from . import constants
|
||||
@@ -66,7 +62,6 @@ def validate_features(features : int):
|
||||
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
|
||||
raise UnknownEvenFeatureBits()
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
FLAG_DISABLE = 1 << 1
|
||||
FLAG_DIRECTION = 1 << 0
|
||||
@@ -193,57 +188,45 @@ class Address(NamedTuple):
|
||||
port: int
|
||||
last_connected_date: int
|
||||
|
||||
create_channel_info = """
|
||||
CREATE TABLE IF NOT EXISTS channel_info (
|
||||
short_channel_id VARCHAR(64),
|
||||
node1_id VARCHAR(66),
|
||||
node2_id VARCHAR(66),
|
||||
capacity_sat INTEGER,
|
||||
PRIMARY KEY(short_channel_id)
|
||||
)"""
|
||||
|
||||
class ChannelInfoBase(Base):
|
||||
__tablename__ = 'channel_info'
|
||||
short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
|
||||
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
|
||||
node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
|
||||
capacity_sat = Column(Integer)
|
||||
def to_nametuple(self):
|
||||
return ChannelInfo(
|
||||
short_channel_id=self.short_channel_id,
|
||||
node1_id=self.node1_id,
|
||||
node2_id=self.node2_id,
|
||||
capacity_sat=self.capacity_sat
|
||||
)
|
||||
create_policy = """
|
||||
CREATE TABLE IF NOT EXISTS policy (
|
||||
key VARCHAR(66),
|
||||
cltv_expiry_delta INTEGER NOT NULL,
|
||||
htlc_minimum_msat INTEGER NOT NULL,
|
||||
htlc_maximum_msat INTEGER,
|
||||
fee_base_msat INTEGER NOT NULL,
|
||||
fee_proportional_millionths INTEGER NOT NULL,
|
||||
channel_flags INTEGER NOT NULL,
|
||||
timestamp INTEGER NOT NULL,
|
||||
PRIMARY KEY(key)
|
||||
)"""
|
||||
|
||||
class PolicyBase(Base):
|
||||
__tablename__ = 'policy'
|
||||
key = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
|
||||
cltv_expiry_delta = Column(Integer, nullable=False)
|
||||
htlc_minimum_msat = Column(Integer, nullable=False)
|
||||
htlc_maximum_msat = Column(Integer)
|
||||
fee_base_msat = Column(Integer, nullable=False)
|
||||
fee_proportional_millionths = Column(Integer, nullable=False)
|
||||
channel_flags = Column(Integer, nullable=False)
|
||||
timestamp = Column(Integer, nullable=False)
|
||||
create_address = """
|
||||
CREATE TABLE IF NOT EXISTS address (
|
||||
node_id VARCHAR(66),
|
||||
host STRING(256),
|
||||
port INTEGER NOT NULL,
|
||||
timestamp INTEGER,
|
||||
PRIMARY KEY(node_id, host, port)
|
||||
)"""
|
||||
|
||||
def to_nametuple(self):
|
||||
return Policy(
|
||||
key=self.key,
|
||||
cltv_expiry_delta=self.cltv_expiry_delta,
|
||||
htlc_minimum_msat=self.htlc_minimum_msat,
|
||||
htlc_maximum_msat=self.htlc_maximum_msat,
|
||||
fee_base_msat= self.fee_base_msat,
|
||||
fee_proportional_millionths = self.fee_proportional_millionths,
|
||||
channel_flags=self.channel_flags,
|
||||
timestamp=self.timestamp
|
||||
)
|
||||
|
||||
class NodeInfoBase(Base):
|
||||
__tablename__ = 'node_info'
|
||||
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
|
||||
features = Column(Integer, nullable=False)
|
||||
timestamp = Column(Integer, nullable=False)
|
||||
alias = Column(String(64), nullable=False)
|
||||
|
||||
class AddressBase(Base):
|
||||
__tablename__ = 'address'
|
||||
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
|
||||
host = Column(String(256))
|
||||
port = Column(Integer)
|
||||
last_connected_date = Column(Integer(), nullable=True)
|
||||
create_node_info = """
|
||||
CREATE TABLE IF NOT EXISTS node_info (
|
||||
node_id VARCHAR(66),
|
||||
features INTEGER NOT NULL,
|
||||
timestamp INTEGER NOT NULL,
|
||||
alias STRING(64),
|
||||
PRIMARY KEY(node_id)
|
||||
)"""
|
||||
|
||||
|
||||
class ChannelDB(SqlDB):
|
||||
@@ -252,7 +235,7 @@ class ChannelDB(SqlDB):
|
||||
|
||||
def __init__(self, network: 'Network'):
|
||||
path = os.path.join(get_headers_dir(network.config), 'channel_db')
|
||||
super().__init__(network, path, Base, commit_interval=100)
|
||||
super().__init__(network, path, commit_interval=100)
|
||||
self.num_nodes = 0
|
||||
self.num_channels = 0
|
||||
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
|
||||
@@ -276,16 +259,7 @@ class ChannelDB(SqlDB):
|
||||
now = int(time.time())
|
||||
node_id = peer.pubkey
|
||||
self._addresses[node_id].add((peer.host, peer.port, now))
|
||||
self.save_address(node_id, peer, now)
|
||||
|
||||
@sql
|
||||
def save_address(self, node_id, peer, now):
|
||||
addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
|
||||
if addr:
|
||||
addr.last_connected_date = now
|
||||
else:
|
||||
addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
|
||||
self.DBSession.add(addr)
|
||||
self.save_node_address(node_id, peer, now)
|
||||
|
||||
def get_200_randomly_sorted_nodes_not_in(self, node_ids):
|
||||
unshuffled = set(self._nodes.keys()) - node_ids
|
||||
@@ -394,17 +368,47 @@ class ChannelDB(SqlDB):
|
||||
orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False)
|
||||
assert len(good) == 1
|
||||
|
||||
def create_database(self):
|
||||
c = self.conn.cursor()
|
||||
c.execute(create_node_info)
|
||||
c.execute(create_address)
|
||||
c.execute(create_policy)
|
||||
c.execute(create_channel_info)
|
||||
self.conn.commit()
|
||||
|
||||
@sql
|
||||
def save_policy(self, policy):
|
||||
self.DBSession.execute(PolicyBase.__table__.insert().values(policy))
|
||||
c = self.conn.cursor()
|
||||
c.execute("""REPLACE INTO policy (key, cltv_expiry_delta, htlc_minimum_msat, htlc_maximum_msat, fee_base_msat, fee_proportional_millionths, channel_flags, timestamp) VALUES (?,?,?,?,?,?, ?, ?)""", list(policy))
|
||||
|
||||
@sql
|
||||
def delete_policy(self, short_channel_id, node_id):
|
||||
self.DBSession.execute(PolicyBase.__table__.delete().values(policy))
|
||||
c = self.conn.cursor()
|
||||
c.execute("""DELETE FROM policy WHERE key=?""", (key,))
|
||||
|
||||
@sql
|
||||
def save_channel(self, channel_info):
|
||||
self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info))
|
||||
c = self.conn.cursor()
|
||||
c.execute("REPLACE INTO channel_info (short_channel_id, node1_id, node2_id, capacity_sat) VALUES (?,?,?,?)", list(channel_info))
|
||||
|
||||
@sql
|
||||
def save_node(self, node_info):
|
||||
c = self.conn.cursor()
|
||||
c.execute("REPLACE INTO node_info (node_id, features, timestamp, alias) VALUES (?,?,?,?)", list(node_info))
|
||||
|
||||
@sql
|
||||
def save_node_address(self, node_id, peer, now):
|
||||
c = self.conn.cursor()
|
||||
c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (node_id, peer.host, peer.port, now))
|
||||
|
||||
@sql
|
||||
def save_node_addresses(self, node_id, node_addresses):
|
||||
c = self.conn.cursor()
|
||||
for addr in node_addresses:
|
||||
c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.node_id, addr.host, addr.port))
|
||||
r = c.fetchall()
|
||||
if r == []:
|
||||
c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.node_id, addr.host, addr.port, 0))
|
||||
|
||||
def verify_channel_update(self, payload):
|
||||
short_channel_id = payload['short_channel_id']
|
||||
@@ -418,7 +422,6 @@ class ChannelDB(SqlDB):
|
||||
msg_payloads = [msg_payloads]
|
||||
old_addr = None
|
||||
new_nodes = {}
|
||||
new_addresses = {}
|
||||
for msg_payload in msg_payloads:
|
||||
try:
|
||||
node_info, node_addresses = NodeInfo.from_msg(msg_payload)
|
||||
@@ -445,17 +448,6 @@ class ChannelDB(SqlDB):
|
||||
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
|
||||
self.update_counts()
|
||||
|
||||
@sql
|
||||
def save_node_addresses(self, node_if, node_addresses):
|
||||
for new_addr in node_addresses:
|
||||
old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
|
||||
if not old_addr:
|
||||
self.DBSession.execute(AddressBase.__table__.insert().values(new_addr))
|
||||
|
||||
@sql
|
||||
def save_node(self, node_info):
|
||||
self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info))
|
||||
|
||||
def get_routing_policy_for_channel(self, start_node_id: bytes,
|
||||
short_channel_id: bytes) -> Optional[bytes]:
|
||||
if not start_node_id or not short_channel_id: return None
|
||||
@@ -506,12 +498,18 @@ class ChannelDB(SqlDB):
|
||||
@sql
|
||||
@profiler
|
||||
def load_data(self):
|
||||
for x in self.DBSession.query(AddressBase).all():
|
||||
self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0)))
|
||||
for x in self.DBSession.query(ChannelInfoBase).all():
|
||||
self._channels[x.short_channel_id] = x.to_nametuple()
|
||||
for x in self.DBSession.query(PolicyBase).filter_by().all():
|
||||
p = x.to_nametuple()
|
||||
c = self.conn.cursor()
|
||||
c.execute("""SELECT * FROM address""")
|
||||
for x in c:
|
||||
node_id, host, port, timestamp = x
|
||||
self._addresses[node_id].add((str(host), int(port), int(timestamp or 0)))
|
||||
c.execute("""SELECT * FROM channel_info""")
|
||||
for x in c:
|
||||
ci = ChannelInfo(*x)
|
||||
self._channels[ci.short_channel_id] = ci
|
||||
c.execute("""SELECT * FROM policy""")
|
||||
for x in c:
|
||||
p = Policy(*x)
|
||||
self._policies[(p.start_node, p.short_channel_id)] = p
|
||||
for channel_info in self._channels.values():
|
||||
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
|
||||
|
||||
@@ -13,12 +13,7 @@ from enum import IntEnum, auto
|
||||
from typing import NamedTuple, Dict
|
||||
import jsonrpclib
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
|
||||
from sqlalchemy.orm.query import Query
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.sql import not_, or_
|
||||
from .sql_db import SqlDB, sql
|
||||
|
||||
from .util import bh2u, bfh, log_exceptions, ignore_exceptions
|
||||
from . import wallet
|
||||
from .storage import WalletStorage
|
||||
@@ -42,80 +37,105 @@ class TxMinedDepth(IntEnum):
|
||||
FREE = auto()
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class SweepTx(Base):
|
||||
__tablename__ = 'sweep_txs'
|
||||
funding_outpoint = Column(String(34), primary_key=True)
|
||||
index = Column(Integer(), primary_key=True)
|
||||
prevout = Column(String(34))
|
||||
tx = Column(String())
|
||||
|
||||
class ChannelInfo(Base):
|
||||
__tablename__ = 'channel_info'
|
||||
outpoint = Column(String(34), primary_key=True)
|
||||
address = Column(String(32))
|
||||
create_sweep_txs="""
|
||||
CREATE TABLE IF NOT EXISTS sweep_txs (
|
||||
funding_outpoint VARCHAR(34) NOT NULL,
|
||||
"index" INTEGER NOT NULL,
|
||||
prevout VARCHAR(34),
|
||||
tx VARCHAR,
|
||||
PRIMARY KEY(funding_outpoint, "index")
|
||||
)"""
|
||||
|
||||
create_channel_info="""
|
||||
CREATE TABLE IF NOT EXISTS channel_info (
|
||||
outpoint VARCHAR(34) NOT NULL,
|
||||
address VARCHAR(32),
|
||||
PRIMARY KEY(outpoint)
|
||||
)"""
|
||||
|
||||
|
||||
class SweepStore(SqlDB):
|
||||
|
||||
def __init__(self, path, network):
|
||||
super().__init__(network, path, Base)
|
||||
super().__init__(network, path)
|
||||
|
||||
def create_database(self):
|
||||
c = self.conn.cursor()
|
||||
c.execute(create_channel_info)
|
||||
c.execute(create_sweep_txs)
|
||||
self.conn.commit()
|
||||
|
||||
@sql
|
||||
def get_sweep_tx(self, funding_outpoint, prevout):
|
||||
return [Transaction(bh2u(r.tx)) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prevout==prevout).all()]
|
||||
c = self.conn.cursor()
|
||||
c.execute("SELECT tx FROM sweep_txs WHERE funding_outpoint=? AND prevout=?", (funding_outpoint, prevout))
|
||||
return [Transaction(bh2u(r[0])) for r in c.fetchall()]
|
||||
|
||||
@sql
|
||||
def get_tx_by_index(self, funding_outpoint, index):
|
||||
r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none()
|
||||
return str(r.prevout), bh2u(r.tx)
|
||||
c = self.conn.cursor()
|
||||
c.execute("""SELECT prevout, tx FROM sweep_txs WHERE funding_outpoint=? AND "index"=?""", (funding_outpoint, index))
|
||||
r = c.fetchone()[0]
|
||||
return str(r[0]), bh2u(r[1])
|
||||
|
||||
@sql
|
||||
def list_sweep_tx(self):
|
||||
return set(str(r.funding_outpoint) for r in self.DBSession.query(SweepTx).all())
|
||||
c = self.conn.cursor()
|
||||
c.execute("SELECT funding_outpoint FROM sweep_txs")
|
||||
return set([r[0] for r in c.fetchall()])
|
||||
|
||||
@sql
|
||||
def add_sweep_tx(self, funding_outpoint, prevout, tx):
|
||||
n = self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
|
||||
self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, index=n, prevout=prevout, tx=bfh(tx)))
|
||||
self.DBSession.commit()
|
||||
c = self.conn.cursor()
|
||||
c.execute("SELECT count(*) FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
|
||||
n = int(c.fetchone()[0])
|
||||
c.execute("""INSERT INTO sweep_txs (funding_outpoint, "index", prevout, tx) VALUES (?,?,?,?)""", (funding_outpoint, n, prevout, bfh(str(tx))))
|
||||
self.conn.commit()
|
||||
|
||||
@sql
|
||||
def get_num_tx(self, funding_outpoint):
|
||||
return int(self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count())
|
||||
c = self.conn.cursor()
|
||||
c.execute("SELECT count(*) FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
|
||||
return int(c.fetchone()[0])
|
||||
|
||||
@sql
|
||||
def remove_sweep_tx(self, funding_outpoint):
|
||||
r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all()
|
||||
for x in r:
|
||||
self.DBSession.delete(x)
|
||||
self.DBSession.commit()
|
||||
c = self.conn.cursor()
|
||||
c.execute("DELETE FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
|
||||
self.conn.commit()
|
||||
|
||||
@sql
|
||||
def add_channel(self, outpoint, address):
|
||||
self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint))
|
||||
self.DBSession.commit()
|
||||
c = self.conn.cursor()
|
||||
c.execute("INSERT INTO channel_info (address, outpoint) VALUES (?,?)", (address, outpoint))
|
||||
self.conn.commit()
|
||||
|
||||
@sql
|
||||
def remove_channel(self, outpoint):
|
||||
v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
|
||||
self.DBSession.delete(v)
|
||||
self.DBSession.commit()
|
||||
c = self.conn.cursor()
|
||||
c.execute("DELETE FROM channel_info WHERE outpoint=?", (outpoint,))
|
||||
self.conn.commit()
|
||||
|
||||
@sql
|
||||
def has_channel(self, outpoint):
|
||||
return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none())
|
||||
c = self.conn.cursor()
|
||||
c.execute("SELECT * FROM channel_info WHERE outpoint=?", (outpoint,))
|
||||
r = c.fetchone()
|
||||
return r is not None
|
||||
|
||||
@sql
|
||||
def get_address(self, outpoint):
|
||||
r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
|
||||
return str(r.address) if r else None
|
||||
c = self.conn.cursor()
|
||||
c.execute("SELECT address FROM channel_info WHERE outpoint=?", (outpoint,))
|
||||
r = c.fetchone()
|
||||
return r[0] if r else None
|
||||
|
||||
@sql
|
||||
def list_channel_info(self):
|
||||
return [(str(r.address), str(r.outpoint)) for r in self.DBSession.query(ChannelInfo).all()]
|
||||
c = self.conn.cursor()
|
||||
c.execute("SELECT address, outpoint FROM channel_info")
|
||||
return [(r[0], r[1]) for r in c.fetchall()]
|
||||
|
||||
|
||||
|
||||
class LNWatcher(AddressSynchronizer):
|
||||
|
||||
@@ -3,18 +3,11 @@ import concurrent
|
||||
import queue
|
||||
import threading
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
import sqlite3
|
||||
|
||||
from .logging import Logger
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/26971050/sqlalchemy-sqlite-too-many-sql-variables
|
||||
SQLITE_LIMIT_VARIABLE_NUMBER = 999
|
||||
|
||||
|
||||
def sql(func):
|
||||
"""wrapper for sql methods"""
|
||||
def wrapper(self, *args, **kwargs):
|
||||
@@ -26,9 +19,8 @@ def sql(func):
|
||||
|
||||
class SqlDB(Logger):
|
||||
|
||||
def __init__(self, network, path, base, commit_interval=None):
|
||||
def __init__(self, network, path, commit_interval=None):
|
||||
Logger.__init__(self)
|
||||
self.base = base
|
||||
self.network = network
|
||||
self.path = path
|
||||
self.commit_interval = commit_interval
|
||||
@@ -37,13 +29,10 @@ class SqlDB(Logger):
|
||||
self.sql_thread.start()
|
||||
|
||||
def run_sql(self):
|
||||
#return
|
||||
self.logger.info("SQL thread started")
|
||||
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
|
||||
DBSession = sessionmaker(bind=engine, autoflush=False)
|
||||
if not os.path.exists(self.path):
|
||||
self.base.metadata.create_all(engine)
|
||||
self.DBSession = DBSession()
|
||||
self.conn = sqlite3.connect(self.path)
|
||||
self.logger.info("Creating database")
|
||||
self.create_database()
|
||||
i = 0
|
||||
while self.network.asyncio_loop.is_running():
|
||||
try:
|
||||
@@ -62,7 +51,7 @@ class SqlDB(Logger):
|
||||
if self.commit_interval:
|
||||
i = (i + 1) % self.commit_interval
|
||||
if i == 0:
|
||||
self.DBSession.commit()
|
||||
self.conn.commit()
|
||||
# write
|
||||
self.DBSession.commit()
|
||||
self.conn.commit()
|
||||
self.logger.info("SQL thread terminated")
|
||||
|
||||
Reference in New Issue
Block a user