Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ extend-select = [
"UP035",
# Missing function argument type-annotation
"ANN001",
"ANN002",
"ANN003",
"ANN201",
"ANN202",
"ANN204",
"ANN205",
"ANN206",
# Using except without specifying an exception type to catch
"BLE001",
]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
realisation_to_srf,
],
)
def test_invocation_of_script(script: Callable):
def test_invocation_of_script(script: Callable) -> None:
"""Basic check that the scripts can be invoked."""
runner = CliRunner()
result = runner.invoke(script.app, ["--help"])
Expand Down
30 changes: 15 additions & 15 deletions tests/test_log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@


@log_utils.log_call()
def foo(a: int, b: int):
def foo(a: int, b: int) -> int:
return a + b


def test_basic_log():
def test_basic_log() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -35,11 +35,11 @@ def test_basic_log():


@log_utils.log_call(exclude_args={"b"})
def foo_less_b(a: int, b: int):
def foo_less_b(a: int, b: int) -> int:
return a + b


def test_excluded_log():
def test_excluded_log() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -59,11 +59,11 @@ def test_excluded_log():


@log_utils.log_call(action_name="FOOBAR")
def bar(a: Any):
def bar(a: Any) -> None:
pass


def test_renamed_bar():
def test_renamed_bar() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -81,11 +81,11 @@ def test_renamed_bar():


@log_utils.log_call(include_result=False)
def baz(a: Any):
def baz(a: Any) -> int:
return 1


def test_no_result():
def test_no_result() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -104,11 +104,11 @@ def test_no_result():


@log_utils.log_call()
def failing_function():
def failing_function() -> None:
raise ValueError("This function should fail!")


def test_failing_function():
def test_failing_function() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -122,7 +122,7 @@ def test_failing_function():
assert "error" in return_log


def test_successful_check_call_log(tmp_path: Path):
def test_successful_check_call_log(tmp_path: Path) -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -140,7 +140,7 @@ def test_successful_check_call_log(tmp_path: Path):
assert "stdout" in completion_message and "test.txt" in completion_message["stdout"]


def test_failing_check_call_log():
def test_failing_check_call_log() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -156,7 +156,7 @@ def test_failing_check_call_log():
)


def test_repeated_logs():
def test_repeated_logs() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -177,12 +177,12 @@ def test_repeated_logs():
)


def _thread_worker(logger_name: str):
def _thread_worker(logger_name: str) -> None:
logger = log_utils.get_logger(logger_name)
logger.info("Threaded log message")


def test_thread_safety():
def test_thread_safety() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand Down
44 changes: 22 additions & 22 deletions tests/test_realisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from workflow import defaults, realisations


def test_bounding_box_example(tmp_path: Path):
def test_bounding_box_example(tmp_path: Path) -> None:
domain_parameters = realisations.DomainParameters(
resolution=0.1, # a 0.1km resolution
domain=bounding_box.BoundingBox.from_centroid_bearing_extents(
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_bounding_box_example(tmp_path: Path):
).all()


def test_domain_parameters_properties():
def test_domain_parameters_properties() -> None:
domain_parameters = realisations.DomainParameters(
resolution=0.1, # a 0.1km resolution
domain=bounding_box.BoundingBox.from_centroid_bearing_extents(
Expand All @@ -78,7 +78,7 @@ def test_domain_parameters_properties():
assert domain_parameters.nz == 400


def test_srf_config_example(tmp_path: Path):
def test_srf_config_example(tmp_path: Path) -> None:
domain_parameters = realisations.DomainParameters(
resolution=0.1, # a 0.1km resolution
domain=bounding_box.BoundingBox.from_centroid_bearing_extents(
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_srf_config_example(tmp_path: Path):
assert realisations.SRFConfig.read_from_realisation(realisation_ffp) == srf_config


def test_bad_domain_parameters(tmp_path: Path):
def test_bad_domain_parameters(tmp_path: Path) -> None:
bad_json = tmp_path / "bad_domain_parameters.json"
bad_json.write_text(
json.dumps(
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_bad_domain_parameters(tmp_path: Path):
realisations.DomainParameters.read_from_realisation(bad_json)


def test_bad_config_key(tmp_path: Path):
def test_bad_config_key(tmp_path: Path) -> None:
bad_json = tmp_path / "bad_domain_parameters.json"
bad_json.write_text(
json.dumps(
Expand Down Expand Up @@ -196,7 +196,7 @@ def test_bad_config_key(tmp_path: Path):
realisations.DomainParameters.read_from_realisation(bad_json)


def test_metadata(tmp_path: Path):
def test_metadata(tmp_path: Path) -> None:
metadata = realisations.RealisationMetadata(
name="consecutive write test",
version="1",
Expand All @@ -220,7 +220,7 @@ def test_metadata(tmp_path: Path):
)


def test_velocity_model(tmp_path: Path):
def test_velocity_model(tmp_path: Path) -> None:
velocity_model = realisations.VelocityModelParameters(
min_vs=1.0,
version="2.06",
Expand Down Expand Up @@ -257,7 +257,7 @@ def test_velocity_model(tmp_path: Path):
)


def test_rupture_prop_config(tmp_path: Path):
def test_rupture_prop_config(tmp_path: Path) -> None:
rup_prop = realisations.RupturePropagationConfig(
rupture_causality_tree={"A": None, "B": "A", "C": "B"},
jump_points={
Expand Down Expand Up @@ -307,7 +307,7 @@ def test_rupture_prop_config(tmp_path: Path):
assert rupture_prop_config.hypocentre.tolist() == [0.0, 0.6]


def test_rupture_prop_properties():
def test_rupture_prop_properties() -> None:
rup_prop = realisations.RupturePropagationConfig(
rupture_causality_tree={"A": None, "B": "A", "C": "B"},
jump_points={
Expand All @@ -325,7 +325,7 @@ def test_rupture_prop_properties():
assert rup_prop.initial_fault == "A"


def test_hf_config(tmp_path: Path):
def test_hf_config(tmp_path: Path) -> None:
test_realisation = tmp_path / "realisation.json"
test_realisation.write_text("{}")
hf_config = realisations.HFConfig.read_from_realisation_or_defaults(
Expand All @@ -344,7 +344,7 @@ def test_hf_config(tmp_path: Path):
)


def test_emod3d(tmp_path: Path):
def test_emod3d(tmp_path: Path) -> None:
test_realisation = tmp_path / "realisation.json"
test_realisation.write_text("{}")
emod3d = realisations.EMOD3DParameters.read_from_realisation_or_defaults(
Expand All @@ -363,7 +363,7 @@ def test_emod3d(tmp_path: Path):
)


def test_broadband_parameters(tmp_path: Path):
def test_broadband_parameters(tmp_path: Path) -> None:
test_realisation = tmp_path / "realisation.json"
broadband_parameters = realisations.BroadbandParameters(
flo=0.5, dt=0.005, fmidbot=0.5, fmin=0.25, site_amp_version="2014"
Expand All @@ -385,14 +385,14 @@ def test_broadband_parameters(tmp_path: Path):
)


def test_logtrail_init_empty():
def test_logtrail_init_empty() -> None:
"""Test LogTrail initialization with no log provided."""
trail = realisations.LogTrail([])
assert trail.log == []
assert trail._config_key == "log_trail"


def test_logtrail_init_with_log_entries():
def test_logtrail_init_with_log_entries() -> None:
"""Test LogTrail initialization with a list of LogEntry objects."""
entry1 = realisations.LogEntry(
utility="util1", args=["a"], version="1", timestamp=datetime.now()
Expand All @@ -404,7 +404,7 @@ def test_logtrail_init_with_log_entries():
assert trail.log == [entry1, entry2]


def test_logtrail_init_with_dicts_post_init():
def test_logtrail_init_with_dicts_post_init() -> None:
"""Test LogTrail post_init conversion of dicts to LogEntry objects."""
log_data = [
{
Expand All @@ -431,7 +431,7 @@ def test_logtrail_init_with_dicts_post_init():
assert trail.log[1].args == ["b"]


def test_logtrail_log_entry_method():
def test_logtrail_log_entry_method() -> None:
"""Test adding an entry using the log_entry method."""
trail = realisations.LogTrail([])
trail.log_entry("my_util", ["--flag", "value"])
Expand All @@ -442,7 +442,7 @@ def test_logtrail_log_entry_method():
assert isinstance(trail.log[0].timestamp, datetime)


def test_logtrail_to_dict():
def test_logtrail_to_dict() -> None:
"""Test converting LogTrail to a dictionary."""
ts = datetime.now()
entry1 = realisations.LogEntry(
Expand Down Expand Up @@ -480,7 +480,7 @@ def test_logtrail_to_dict():

def test_append_log_entry_file_exists_no_key(
tmp_path: Path,
):
) -> None:
"""Test append_log_entry when file exists but lacks the 'log_trail' key."""
realisation_file = tmp_path / "test_realisation.json"
# Create a file with unrelated content
Expand All @@ -504,15 +504,15 @@ def test_append_log_entry_file_exists_no_key(
assert data["log_trail"]["log"][0]["utility"] == "script_name.py"


def test_seeds():
def test_seeds() -> None:
seeds = realisations.Seeds.random_seeds()
assert all(
0 <= seed <= 2 ** (struct.Struct("i").size * 8 - 1) - 1
for seed in seeds.to_dict().values()
)


def test_velocity_model_1d(tmp_path: Path):
def test_velocity_model_1d(tmp_path: Path) -> None:
velocity_model_1d = realisations.VelocityModel1D(
model=pd.DataFrame(
{
Expand Down Expand Up @@ -564,7 +564,7 @@ def test_velocity_model_1d(tmp_path: Path):
)


def test_intensity_measure_calculation_parameters(tmp_path: Path):
def test_intensity_measure_calculation_parameters(tmp_path: Path) -> None:
im_calc_params = realisations.IntensityMeasureCalculationParameters(
ims=[im_calculation.IM("PGA"), im_calculation.IM("PGV")],
valid_periods=np.array([0.1, 0.2, 0.3]),
Expand Down Expand Up @@ -605,5 +605,5 @@ def test_defaults_are_loadable(
tmp_path: Path,
realisation_config: realisations.RealisationConfiguration,
defaults_version: defaults.DefaultsVersion,
):
) -> None:
realisation_config.read_from_defaults(defaults_version)
8 changes: 5 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
from workflow import utils


def test_get_available_cores_slurm_cpus_on_node():
def test_get_available_cores_slurm_cpus_on_node() -> None:
with patch.dict(os.environ, {"SLURM_CPUS_ON_NODE": "4"}):
assert utils.get_available_cores() == 4

def get_available_cores_slurm_nprocs():

def get_available_cores_slurm_nprocs() -> None:
with patch.dict(os.environ, {"SLURM_NPROCS": "8"}):
assert utils.get_available_cores() == 8

def get_available_cores_no_slurm():

def get_available_cores_no_slurm() -> None:
with patch("multiprocessing.cpu_count", return_value=16):
assert utils.get_available_cores() == 16
4 changes: 3 additions & 1 deletion workflow/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def log_call(

def decorator(f: Callable) -> Callable: # numpydoc ignore=GL08
@functools.wraps(f)
def wrapper(*args, **kwargs): # numpydoc ignore=GL08
def wrapper(
*args: list[Any], **kwargs: dict[str, Any]
) -> Callable: # numpydoc ignore=GL08
nonlocal exclude_args
signature = inspect.signature(f)
function_id = str(uuid.uuid4())
Expand Down
8 changes: 4 additions & 4 deletions workflow/realisations.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ class Seeds(RealisationConfiguration):

@classmethod
def read_from_realisation_or_defaults(
cls, realisation_ffp: Path, *args
cls, realisation_ffp: Path, *args: list[Any]
) -> Self: # *args is to maintain compat with superclass (remove this and see the error in mypy).
"""Read seeds configuration from a realisation file or generate random seeds if not present.

Expand All @@ -256,7 +256,7 @@ def read_from_realisation_or_defaults(
----------
realisation_ffp : Path
The realisation filepath to read from.
*args : Any
*args : list
Ignored arguments.

Returns
Expand Down Expand Up @@ -315,7 +315,7 @@ class SourceConfig(RealisationConfiguration):
source_geometries: dict[str, IsSource]
"""Dictionary mapping source names to their definitions."""

def to_dict(self):
def to_dict(self) -> dict[str, Any]:
"""
Convert the object to a dictionary representation.

Expand Down Expand Up @@ -522,7 +522,7 @@ class VelocityModel1D(RealisationConfiguration):

model: pd.DataFrame

def write_velocity_model(self, velocity_model_path: Path):
def write_velocity_model(self, velocity_model_path: Path) -> None:
"""Write a 1D velocity model to the specified path.

Parameters
Expand Down
Loading
Loading