Skip to content

Commit c3bcba0

Browse files
authored
Added get_dependency_min_version_spec function (#62)
1 parent 4de6e2b commit c3bcba0

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

src/lightning_utilities/core/imports.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
import importlib
55
import operator
66
from functools import lru_cache
7+
from importlib import metadata
78
from importlib.util import find_spec
89
from typing import Callable
910

1011
import pkg_resources
12+
from packaging.requirements import Requirement
1113
from packaging.version import Version
1214

1315

@@ -107,3 +109,21 @@ def __str__(self) -> str:
107109

108110
def __repr__(self) -> str:
109111
return self.__str__()
112+
113+
114+
def get_dependency_min_version_spec(package_name: str, dependency_name: str) -> str:
115+
"""Returns the minimum version specifier of a dependency of a package.
116+
117+
>>> get_dependency_min_version_spec("pytorch-lightning", "jsonargparse")
118+
'>=4.12.0'
119+
"""
120+
dependencies = metadata.requires(package_name) or []
121+
for dep in dependencies:
122+
dependency = Requirement(dep)
123+
if dependency.name == dependency_name:
124+
spec = [str(s) for s in dependency.specifier if str(s)[0] == ">"]
125+
return spec[0] if spec else ""
126+
raise ValueError(
127+
"This is an internal error. Please file a GitHub issue with the error message. Dependency "
128+
f"{dependency_name!r} not found in package {package_name!r}."
129+
)

tests/unittests/core/test_imports.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
import operator
2+
import re
3+
from importlib.metadata import PackageNotFoundError
24

3-
from lightning_utilities.core.imports import compare_version, module_available, RequirementCache
5+
import pytest
6+
7+
from lightning_utilities.core.imports import (
8+
compare_version,
9+
get_dependency_min_version_spec,
10+
module_available,
11+
RequirementCache,
12+
)
413

514

615
def test_module_exists():
@@ -38,3 +47,14 @@ def test_requirement_cache():
3847
assert RequirementCache(f"pytest>={pytest.__version__}")
3948
assert not RequirementCache(f"pytest<{pytest.__version__}")
4049
assert "pip install -U '-'" in str(RequirementCache("-"))
50+
51+
52+
def test_get_dependency_min_version_spec():
53+
attrs_min_version_spec = get_dependency_min_version_spec("pytest", "attrs")
54+
assert re.match(r"^>=[\d.]+$", attrs_min_version_spec)
55+
56+
with pytest.raises(ValueError, match="'invalid' not found in package 'pytest'"):
57+
get_dependency_min_version_spec("pytest", "invalid")
58+
59+
with pytest.raises(PackageNotFoundError, match="invalid"):
60+
get_dependency_min_version_spec("invalid", "invalid")

0 commit comments

Comments
 (0)