14 changes: 5 additions & 9 deletions packages/models-library/src/models_library/projects.py
@@ -10,6 +10,7 @@
from common_library.basic_types import DEFAULT_FACTORY
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
HttpUrl,
@@ -85,13 +86,17 @@ class BaseProjectModel(BaseModel):
]
description: Annotated[
str,
BeforeValidator(none_to_empty_str_pre_validator),
Field(
description="longer one-line description about the project",
examples=["Dabbling in temporal transitions ..."],
),
]
thumbnail: Annotated[
HttpUrl | None,
BeforeValidator(
empty_str_to_none_pre_validator,
),
Field(
description="url of the project thumbnail",
examples=["https://placeimg.com/171/96/tech/grayscale/?0.jpg"],
@@ -104,15 +109,6 @@ class BaseProjectModel(BaseModel):
# Pipeline of nodes (SEE projects_nodes.py)
workbench: Annotated[NodesDict, Field(description="Project's pipeline")]

# validators
_empty_thumbnail_is_none = field_validator("thumbnail", mode="before")(
empty_str_to_none_pre_validator
)

_none_description_is_empty = field_validator("description", mode="before")(
none_to_empty_str_pre_validator
)


class ProjectAtDB(BaseProjectModel):
# Model used to READ from database
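For context, the class-level field_validator hooks removed above are folded into the Annotated metadata of each field. A minimal, self-contained sketch of that pattern (the stand-in validator below only approximates empty_str_to_none_pre_validator, which the project imports from its common validators module):

from typing import Annotated

from pydantic import BaseModel, BeforeValidator, HttpUrl


def _empty_str_to_none(value):
    # approximation of empty_str_to_none_pre_validator: "" -> None before HttpUrl parsing
    if isinstance(value, str) and value.strip() == "":
        return None
    return value


class _ThumbnailSketch(BaseModel):
    thumbnail: Annotated[HttpUrl | None, BeforeValidator(_empty_str_to_none)] = None


assert _ThumbnailSketch(thumbnail="").thumbnail is None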
15 changes: 11 additions & 4 deletions packages/models-library/src/models_library/projects_nodes.py
@@ -302,22 +302,28 @@ class Node(BaseModel):
Field(default_factory=dict, description="values of output properties"),
] = DEFAULT_FACTORY

output_node: Annotated[bool | None, Field(deprecated=True, alias="outputNode")] = (
None # <-- (DEPRECATED) Can be removed
)
output_node: Annotated[
bool | None,
Field(
deprecated=True,
alias="outputNode",
),
] = None # <-- (DEPRECATED) Can be removed

output_nodes: Annotated[ # <-- (DEPRECATED) Can be removed
list[NodeID] | None,
Field(
description="Used in group-nodes. Node IDs of those connected to the output",
alias="outputNodes",
deprecated=True,
),
] = None

parent: Annotated[ # <-- (DEPRECATED) Can be removed
NodeID | None,
Field(
description="Parent's (group-nodes') node ID s. Used to group",
deprecated=True,
),
] = None

@@ -453,7 +459,8 @@ def _update_json_schema_extra(schema: JsonDict) -> None:

model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
validate_by_name=True,
validate_by_alias=True,
json_schema_extra=_update_json_schema_extra,
)

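The model_config switch from populate_by_name to the explicit validate_by_name/validate_by_alias pair keeps both spellings accepted on input. A small sketch of the effect, assuming Pydantic >= 2.11 where these ConfigDict keys exist:

from typing import Annotated

from pydantic import BaseModel, ConfigDict, Field


class _NodeSketch(BaseModel):
    model_config = ConfigDict(validate_by_name=True, validate_by_alias=True)

    output_node: Annotated[bool | None, Field(alias="outputNode")] = None


# both the field name and its alias are accepted during validation
assert _NodeSketch.model_validate({"outputNode": True}).output_node is True
assert _NodeSketch.model_validate({"output_node": True}).output_node is True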
@@ -30,10 +30,9 @@
UUID_RE,
)

NodeID = UUID

UUIDStr: TypeAlias = Annotated[str, StringConstraints(pattern=UUID_RE)]

NodeID: TypeAlias = UUID
NodeIDStr: TypeAlias = UUIDStr

LocationID: TypeAlias = int
21 changes: 19 additions & 2 deletions packages/models-library/tests/test_services_types.py
@@ -1,9 +1,13 @@
import pytest
from models_library.projects import ProjectID
from models_library.projects_nodes import NodeID
from models_library.services_types import ServiceRunID
from models_library.services_types import ServiceKey, ServiceRunID, ServiceVersion
from models_library.users import UserID
from pydantic import PositiveInt
from pydantic import PositiveInt, TypeAdapter
from pytest_simcore.helpers.faker_factories import (
random_service_key,
random_service_version,
)


@pytest.mark.parametrize(
@@ -38,3 +42,16 @@ def test_get_resource_tracking_run_id_for_dynamic():
assert isinstance(
ServiceRunID.get_resource_tracking_run_id_for_dynamic(), ServiceRunID
)


@pytest.mark.parametrize(
"service_key, service_version",
[(random_service_key(), random_service_version()) for _ in range(10)],
)
def test_faker_factory_service_key_and_version_are_in_sync(
service_key: ServiceKey, service_version: ServiceVersion
):
TypeAdapter(ServiceKey).validate_python(service_key)
TypeAdapter(ServiceVersion).validate_python(service_version)

assert service_key.startswith("simcore/services/")
@@ -7,7 +7,7 @@
from sqlalchemy.ext.asyncio import AsyncConnection

from .models.projects import projects
from .utils_repos import transaction_context
from .utils_repos import pass_or_acquire_connection, transaction_context


class DBBaseProjectError(OsparcErrorMixin, Exception):
@@ -22,6 +22,23 @@ class ProjectsRepo:
def __init__(self, engine):
self.engine = engine

async def exists(
self,
project_uuid: uuid.UUID,
*,
connection: AsyncConnection | None = None,
) -> bool:
async with pass_or_acquire_connection(self.engine, connection) as conn:
return (
await conn.scalar(
sa.select(1)
.select_from(projects)
.where(projects.c.uuid == f"{project_uuid}")
.limit(1)
)
is not None
)

async def get_project_last_change_date(
self,
project_uuid: uuid.UUID,
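A hypothetical caller of the new exists() helper (everything except ProjectsRepo.exists below is illustrative):

import uuid


async def ensure_project_exists(asyncpg_engine, project_uuid: uuid.UUID) -> None:
    repo = ProjectsRepo(asyncpg_engine)
    # passing no connection lets pass_or_acquire_connection acquire one internally
    if not await repo.exists(project_uuid):
        msg = f"project {project_uuid} not found"
        raise ValueError(msg)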
@@ -4,19 +4,18 @@
from typing import Annotated, Any

import asyncpg.exceptions # type: ignore[import-untyped]
import sqlalchemy
import sqlalchemy.exc
from common_library.async_tools import maybe_await
from common_library.basic_types import DEFAULT_FACTORY
from common_library.errors_classes import OsparcErrorMixin
from pydantic import BaseModel, ConfigDict, Field
from simcore_postgres_database.utils_aiosqlalchemy import map_db_exception
from sqlalchemy.dialects.postgresql import insert as pg_insert

from ._protocols import DBConnection
from .aiopg_errors import ForeignKeyViolation, UniqueViolation
from .models.projects_node_to_pricing_unit import projects_node_to_pricing_unit
from .models.projects_nodes import projects_nodes
from .utils_aiosqlalchemy import map_db_exception


#
@@ -67,11 +66,32 @@ class ProjectNodeCreate(BaseModel):
parent: str | None = None
boot_options: dict[str, Any] | None = None

model_config = ConfigDict(frozen=True)

@classmethod
def get_field_names(cls, *, exclude: set[str]) -> set[str]:
return cls.model_fields.keys() - exclude

model_config = ConfigDict(frozen=True)
def _get_node_exclude_fields(self) -> set[str]: # pylint: disable=no-self-use
"""Get the base fields to exclude when converting to Node model."""
return {"node_id", "required_resources"}

def model_dump_as_node(self) -> dict[str, Any]:
"""Converts a ProjectNode from the database to a Node model for the API.

Usage:
Node.model_validate(project_node_create.model_dump_as_node(), by_name=True)

Handles field mapping and excludes database-specific fields that are not
part of the Node model.
"""
return self.model_dump(
# NOTE: this setup ensures using the defaults provided in Node model when the db does not
# provide them, e.g. `state`
exclude=self._get_node_exclude_fields(),
exclude_none=True,
exclude_unset=True,
)


class ProjectNode(ProjectNodeCreate):
@@ -80,6 +100,13 @@ class ProjectNode(ProjectNodeCreate):

model_config = ConfigDict(from_attributes=True)

def _get_node_exclude_fields(self) -> set[str]: # pylint: disable=no-self-use
"""Get the fields to exclude when converting to Node model, including DB-specific fields."""
base_excludes = super()._get_node_exclude_fields()
return base_excludes | {"created", "modified"}

# NOTE: model_dump_as_node is inherited from ProjectNodeCreate and uses the overridden _get_node_exclude_fields


@dataclass(frozen=True, kw_only=True)
class ProjectNodesRepo:
@@ -103,17 +130,26 @@ async def add(
"""
if not nodes:
return []

values = [
{
"project_uuid": f"{self.project_uuid}",
**node.model_dump(mode="json"),
}
for node in nodes
]

# check that all rows provide the same set of keys (required for a single multi-row insert)
field_names = set(values[0].keys())
for v in values:
if set(v.keys()) != field_names:
msg = f"All rows in batch-insert MUST have same keys. Inconsistent keys in node values: {set(v.keys())} != {field_names}"
raise ValueError(msg)

# statement
insert_stmt = (
projects_nodes.insert()
.values(
[
{
"project_uuid": f"{self.project_uuid}",
**node.model_dump(exclude_unset=True, mode="json"),
}
for node in nodes
]
)
.values(values)
.returning(
*[
c
@@ -129,14 +165,17 @@ async def add(
rows = await maybe_await(result.fetchall())
assert isinstance(rows, list) # nosec
return [ProjectNode.model_validate(r) for r in rows]

except ForeignKeyViolation as exc:
# this happens when the project does not exist, as we first check the node exists
raise ProjectNodesProjectNotFoundError(
project_uuid=self.project_uuid
) from exc

except UniqueViolation as exc:
# this happens if the node already exists on creation
raise ProjectNodesDuplicateNodeError from exc

except sqlalchemy.exc.IntegrityError as exc:
raise map_db_exception(
exc,
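The _get_node_exclude_fields hook is the extension point in this file: ProjectNode only widens the exclude set, while the dump logic stays in ProjectNodeCreate. A stripped-down sketch of the same shape (field names reduced for brevity, not the real models):

from typing import Any

from pydantic import BaseModel


class _CreateSketch(BaseModel):
    node_id: str
    key: str

    def _get_node_exclude_fields(self) -> set[str]:
        return {"node_id"}

    def model_dump_as_node(self) -> dict[str, Any]:
        # drop DB-only columns so the payload matches the API-facing Node model
        return self.model_dump(exclude=self._get_node_exclude_fields(), exclude_none=True)


class _DbSketch(_CreateSketch):
    created: str = "2024-01-01T00:00:00Z"

    def _get_node_exclude_fields(self) -> set[str]:
        return super()._get_node_exclude_fields() | {"created"}


assert _DbSketch(node_id="n1", key="simcore/services/comp/x").model_dump_as_node() == {
    "key": "simcore/services/comp/x"
}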
36 changes: 36 additions & 0 deletions packages/postgres-database/tests/test_utils_projects.py
@@ -74,6 +74,42 @@ async def test_get_project_trashed_column_can_be_converted_to_datetime(
assert trashed == expected


@pytest.mark.parametrize("with_explicit_connection", [True, False])
async def test_projects_repo_exists_with_existing_project(
asyncpg_engine: AsyncEngine,
registered_project: dict,
with_explicit_connection: bool,
):
projects_repo = ProjectsRepo(asyncpg_engine)
project_uuid = registered_project["uuid"]

if with_explicit_connection:
async with transaction_context(asyncpg_engine) as conn:
exists = await projects_repo.exists(project_uuid, connection=conn)
else:
exists = await projects_repo.exists(project_uuid)

assert exists is True


@pytest.mark.parametrize("with_explicit_connection", [True, False])
async def test_projects_repo_exists_with_non_existing_project(
asyncpg_engine: AsyncEngine,
faker: Faker,
with_explicit_connection: bool,
):
projects_repo = ProjectsRepo(asyncpg_engine)
non_existing_uuid = faker.uuid4()

if with_explicit_connection:
async with transaction_context(asyncpg_engine) as conn:
exists = await projects_repo.exists(non_existing_uuid, connection=conn)
else:
exists = await projects_repo.exists(non_existing_uuid)

assert exists is False


async def test_get_project_last_change_date(
asyncpg_engine: AsyncEngine, registered_project: dict, faker: Faker
):
@@ -391,3 +391,19 @@ async def test_not_implemented_use_cases(
parent_project_uuid=missing_parent_project["uuid"],
parent_node_id=missing_parent_node.node_id,
)


async def test_model_dump_as_node(
connection: SAConnection,
create_fake_user: Callable[..., Awaitable[RowProxy]],
create_fake_project: Callable[..., Awaitable[RowProxy]],
create_fake_projects_node: Callable[[uuid.UUID], Awaitable[ProjectNode]],
):
user: RowProxy = await create_fake_user(connection)
project: RowProxy = await create_fake_project(connection, user, hidden=True)
project_node = await create_fake_projects_node(project["uuid"])

node_data = project_node.model_dump_as_node()
assert isinstance(node_data, dict)
assert node_data["key"] == project_node.key
assert "node_id" not in node_data, "this is only in ProjectNode but not in Node!"