Skip to content

Commit 30f5c68

Browse files
authored
change CohereHook.get_conn type to method (#50470)
1 parent 21c1e1c commit 30f5c68

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

providers/cohere/src/airflow/providers/cohere/hooks/cohere.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import logging
2121
import warnings
22-
from functools import cached_property
2322
from typing import TYPE_CHECKING, Any
2423

2524
import cohere
@@ -65,6 +64,7 @@ def __init__(
6564
self.timeout = timeout
6665
self.max_retries = max_retries
6766
self.request_options = request_options
67+
self._client: cohere.ClientV2 | None = None
6868

6969
if self.max_retries:
7070
warnings.warn(
@@ -77,20 +77,23 @@ def __init__(
7777
else:
7878
self.request_options.update({"max_retries": self.max_retries})
7979

80-
@cached_property
81-
def get_conn(self) -> cohere.ClientV2: # type: ignore[override]
82-
conn = self.get_connection(self.conn_id)
83-
return cohere.ClientV2(
84-
api_key=conn.password,
85-
timeout=self.timeout,
86-
base_url=conn.host or None,
87-
)
80+
def get_conn(self) -> cohere.ClientV2:
81+
"""Return a new or cached Cohere client instance."""
82+
if self._client is None:
83+
# create a new client instance if there is no existing client
84+
conn = self.get_connection(self.conn_id)
85+
self._client = cohere.ClientV2(
86+
api_key=conn.password,
87+
timeout=self.timeout,
88+
base_url=conn.host or None,
89+
)
90+
return self._client
8891

8992
def create_embeddings(
9093
self, texts: list[str], model: str = "embed-multilingual-v3.0"
9194
) -> EmbedByTypeResponseEmbeddings:
9295
logger.info("Creating embeddings with model: embed-multilingual-v3.0")
93-
response = self.get_conn.embed(
96+
response = self.get_conn().embed(
9497
texts=texts,
9598
model=model,
9699
input_type="search_document",
@@ -117,7 +120,7 @@ def test_connection(
117120
try:
118121
if messages is None:
119122
messages = [UserChatMessageV2(role="user", content="hello world!")]
120-
self.get_conn.chat(model=model, messages=messages)
123+
self.get_conn().chat(model=model, messages=messages)
121124
return True, "Connection successfully established."
122125
except Exception as e:
123126
return False, f"Unexpected error: {str(e)}"

providers/cohere/tests/unit/cohere/hooks/test_cohere.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,5 @@ def test__get_api_key(self):
4242
patch("cohere.ClientV2") as client,
4343
):
4444
hook = CohereHook(timeout=timeout)
45-
_ = hook.get_conn
45+
_ = hook.get_conn()
4646
client.assert_called_once_with(api_key=api_key, timeout=timeout, base_url=base_url)

0 commit comments

Comments
 (0)