Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 42 additions & 6 deletions src/lightning_app/components/database/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
import os
import sqlite3
import sys
import tempfile
import threading
from typing import List, Optional, Type, Union

import uvicorn
Expand Down Expand Up @@ -38,6 +41,7 @@ def __init__(
self,
models: Union[Type["SQLModel"], List[Type["SQLModel"]]],
db_filename: str = "database.db",
store_interval: int = 10,
debug: bool = False,
) -> None:
"""The Database Component enables you to interact with an SQLite database to store some structured information
Expand All @@ -48,6 +52,8 @@ def __init__(
Arguments:
models: A SQLModel table or a list of SQLModel tables to be added to the database.
db_filename: The name of the SQLite database.
store_interval: Time interval (in seconds) at which the database is periodically synchronized to the Drive.
Note that the database is also always synchronized on exit.
debug: Whether to run the database in debug mode.

Example::
Expand Down Expand Up @@ -132,18 +138,44 @@ class CounterModel(SQLModel, table=True):
"""
super().__init__(parallel=True, cloud_build_config=BuildConfig(["sqlmodel"]))
self.db_filename = db_filename
self._root_folder = os.path.dirname(db_filename)
self.debug = debug
self.store_interval = store_interval
self._models = models if isinstance(models, list) else [models]
self.drive = None
self._store_thread = None
self._exit_event = None

def store_database(self):
    """Copy the SQLite database into a temporary directory and upload that copy to the Drive.

    The copy is produced with the ``sqlite3`` backup API rather than by uploading
    ``self.db_filename`` directly. Both connections are closed even when opening the
    destination or running the backup fails, so no SQLite handle is leaked and the
    temporary directory can always be removed.
    """
    # Local import: the file-level import block is not fully visible from here.
    from contextlib import closing

    db_basename = os.path.basename(self.db_filename)

    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_db_filename = os.path.join(tmpdir, db_basename)

        # ``closing`` guarantees ``close()`` on both connections, including on error paths.
        with closing(sqlite3.connect(self.db_filename)) as source:
            with closing(sqlite3.connect(tmp_db_filename)) as dest:
                source.backup(dest)

        # The Drive is rooted at the temporary directory, so the file is put by its base name.
        drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir)
        drive.put(db_basename)

    print("Stored the database to the Drive.")

def periodic_store_database(self, store_interval):
    """Keep storing the database until the exit event is set.

    Arguments:
        store_interval: Timeout (in seconds) passed to the exit event's ``wait`` between two stores.
    """
    stop_requested = self._exit_event
    while True:
        if stop_requested.is_set():
            break
        self.store_database()
        # Waiting on the event (instead of sleeping) lets a stop request end the pause early.
        stop_requested.wait(store_interval)

def run(self, token: Optional[str] = None) -> None:
"""
Arguments:
token: Token used to protect the database access. Ensure you don't expose it through the App State.
"""
self.drive = Drive("lit://database")
if self.drive.list(component_name=self.name):
self.drive.get(self.db_filename)
drive = Drive("lit://database", component_name=self.name, root_folder=self._root_folder)
filenames = drive.list(component_name=self.name)
if self.db_filename in filenames:
drive.get(self.db_filename)
print("Retrieved the database from Drive.")

app = FastAPI()
Expand All @@ -157,6 +189,10 @@ def run(self, token: Optional[str] = None) -> None:

sys.modules["uvicorn.main"].Server = _DatabaseUvicornServer

self._exit_event = threading.Event()
self._store_thread = threading.Thread(target=self.periodic_store_database, args=(self.store_interval,))
self._store_thread.start()

run(app, host=self.host, port=self.port, log_level="error")

def alive(self) -> bool:
Expand All @@ -173,5 +209,5 @@ def db_url(self) -> Optional[str]:
return self.internal_ip

def on_exit(self):
self.drive.put(self.db_filename)
print("Stored the database to the Drive.")
self._exit_event.set()
self.store_database()
64 changes: 51 additions & 13 deletions tests/tests_app/components/database/test_client_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import sys
import tempfile
import time
from pathlib import Path
from time import sleep
from typing import List, Optional
Expand Down Expand Up @@ -123,9 +125,10 @@ def test_work_database_restart():
id = str(uuid4()).split("-")[0]

class Flow(LightningFlow):
def __init__(self, restart=False):
def __init__(self, db_root=".", restart=False):
super().__init__()
self.db = Database(db_filename=id, models=[TestConfig])
self._db_filename = os.path.join(db_root, id)
self.db = Database(db_filename=self._db_filename, models=[TestConfig])
self._client = None
self.restart = restart

Expand All @@ -141,22 +144,57 @@ def run(self):
self._client.insert(TestConfig(name="echo", secrets=[Secret(name="example", value="secret")]))
self._exit()
else:
assert os.path.exists(id)
assert os.path.exists(self._db_filename)
assert len(self._client.select_all()) == 1
self._exit()

app = LightningApp(Flow())
MultiProcessRuntime(app).dispatch()
with tempfile.TemporaryDirectory() as tmpdir:
app = LightningApp(Flow(db_root=tmpdir))
MultiProcessRuntime(app).dispatch()

# Note: Waiting for SIGTERM signal to be handled
sleep(2)

app = LightningApp(Flow(db_root=tmpdir, restart=True))
MultiProcessRuntime(app).dispatch()

# Note: Waiting for SIGTERM signal to be handled
sleep(2)


@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
@pytest.mark.skipif(not _is_sqlmodel_available(), reason="sqlmodel is required for this test.")
def test_work_database_periodic_store():

id = str(uuid4()).split("-")[0]

class Flow(LightningFlow):
def __init__(self, db_root="."):
super().__init__()
self._db_filename = os.path.join(db_root, id)
self.db = Database(db_filename=self._db_filename, models=[TestConfig], store_interval=1)
self._client = None
self._start_time = None

def run(self):
self.db.run()

# Note: Waiting for SIGTERM signal to be handled
sleep(2)
if not self.db.alive():
return
elif not self._client:
self._client = DatabaseClient(self.db.db_url, None, model=TestConfig)

os.remove(id)
if self._start_time is None:
self._client.insert(TestConfig(name="echo", secrets=[Secret(name="example", value="secret")]))
self._start_time = time.time()
elif time.time() - self._start_time > 2:
assert os.path.exists(self._db_filename)
assert len(self._client.select_all()) == 1
self._exit()

app = LightningApp(Flow(restart=True))
MultiProcessRuntime(app).dispatch()
with tempfile.TemporaryDirectory() as tmpdir:

# Note: Waiting for SIGTERM signal to be handled
sleep(2)
app = LightningApp(Flow(tmpdir))
MultiProcessRuntime(app).dispatch()

os.remove(id)
sleep(2)