diff --git a/contrib/pyln-testing/pyln/testing/db.py b/contrib/pyln-testing/pyln/testing/db.py index 7110a8e52..96abc7a5c 100644 --- a/contrib/pyln-testing/pyln/testing/db.py +++ b/contrib/pyln-testing/pyln/testing/db.py @@ -14,7 +14,12 @@ import time from typing import Dict, List, Optional, Union -class Sqlite3Db(object): +class BaseDb(object): + def wipe_db(self): + raise NotImplementedError("wipe_db method must be implemented by the subclass") + + +class Sqlite3Db(BaseDb): def __init__(self, path: str) -> None: self.path = path self.provider = None @@ -32,6 +37,8 @@ class Sqlite3Db(object): db.row_factory = sqlite3.Row c = db.cursor() + # Don't get upset by concurrent writes; wait for up to 5 seconds! + c.execute("PRAGMA busy_timeout = 5000") c.execute(query) rows = c.fetchall() @@ -55,8 +62,12 @@ class Sqlite3Db(object): def stop(self): pass + def wipe_db(self): + if os.path.exists(self.path): + os.remove(self.path) -class PostgresDb(object): + +class PostgresDb(BaseDb): def __init__(self, dbname, port): self.dbname = dbname self.port = port @@ -102,6 +113,12 @@ class PostgresDb(object): cur.execute("DROP DATABASE {};".format(self.dbname)) cur.close() + def wipe_db(self): + cur = self.conn.cursor() + cur.execute(f"DROP DATABASE IF EXISTS {self.dbname};") + cur.execute(f"CREATE DATABASE {self.dbname};") + cur.close() + class SqliteDbProvider(object): def __init__(self, directory: str) -> None: diff --git a/tests/db.py b/tests/db.py index 364beef44..a18fa7427 100644 --- a/tests/db.py +++ b/tests/db.py @@ -12,8 +12,11 @@ import string import subprocess import time +class BaseDb(object): + def wipe_db(self): + raise NotImplementedError("wipe_db method must be implemented by the subclass") -class Sqlite3Db(object): +class Sqlite3Db(BaseDb): def __init__(self, path): self.path = path self.provider = None @@ -50,8 +53,11 @@ class Sqlite3Db(object): c.close() db.close() + def wipe_db(self): + if os.path.exists(self.path): + os.remove(self.path) -class PostgresDb(object): +class PostgresDb(BaseDb): def __init__(self, dbname, port): self.dbname = dbname self.port = port @@ -89,6 +95,12 @@ class PostgresDb(object): cur.execute(query) + def wipe_db(self): + cur = self.conn.cursor() + cur.execute(f"DROP DATABASE IF EXISTS {self.dbname};") + cur.execute(f"CREATE DATABASE {self.dbname};") + cur.close() + class SqliteDbProvider(object): def __init__(self, directory): self.directory = directory