7
7
from datetime import datetime , timezone
8
8
from types import TracebackType
9
9
from typing import Any , AsyncIterator , Iterable , Optional , Sequence , cast
10
- from ulid import ULID
11
10
12
11
from langgraph .store .base import (
13
12
BaseStore ,
29
28
from redisvl .query import FilterQuery , VectorQuery
30
29
from redisvl .redis .connection import RedisConnectionFactory
31
30
from redisvl .utils .token_escaper import TokenEscaper
31
+ from ulid import ULID
32
32
33
33
from langgraph .store .redis .base import (
34
34
REDIS_KEY_SEPARATOR ,
@@ -57,14 +57,15 @@ class AsyncRedisStore(
57
57
58
58
store_index : AsyncSearchIndex
59
59
vector_index : AsyncSearchIndex
60
- _owns_client : bool
60
+ _owns_its_client : bool
61
61
62
62
def __init__ (
63
63
self ,
64
64
redis_url : Optional [str ] = None ,
65
65
* ,
66
66
redis_client : Optional [AsyncRedis ] = None ,
67
67
index : Optional [IndexConfig ] = None ,
68
+ connection_args : Optional [dict [str , Any ]] = None ,
68
69
) -> None :
69
70
"""Initialize store with Redis connection and optional index config."""
70
71
if redis_url is None and redis_client is None :
@@ -94,10 +95,16 @@ def __init__(
94
95
]
95
96
96
97
# Configure client
97
- self .configure_client (redis_url = redis_url , redis_client = redis_client )
98
+ self .configure_client (
99
+ redis_url = redis_url ,
100
+ redis_client = redis_client ,
101
+ connection_args = connection_args or {},
102
+ )
98
103
99
104
# Create store index
100
- self .store_index = AsyncSearchIndex .from_dict (self .SCHEMAS [0 ])
105
+ self .store_index = AsyncSearchIndex .from_dict (
106
+ self .SCHEMAS [0 ], redis_client = self ._redis
107
+ )
101
108
102
109
# Configure vector index if needed
103
110
if self .index_config :
@@ -131,7 +138,9 @@ def __init__(
131
138
vector_field ["attrs" ].update (self .index_config ["ann_index_config" ])
132
139
133
140
try :
134
- self .vector_index = AsyncSearchIndex .from_dict (vector_schema )
141
+ self .vector_index = AsyncSearchIndex .from_dict (
142
+ vector_schema , redis_client = self ._redis
143
+ )
135
144
except Exception as e :
136
145
raise ValueError (
137
146
f"Failed to create vector index with schema: { vector_schema } . Error: { str (e )} "
@@ -145,11 +154,12 @@ def configure_client(
145
154
self ,
146
155
redis_url : Optional [str ] = None ,
147
156
redis_client : Optional [AsyncRedis ] = None ,
157
+ connection_args : Optional [dict [str , Any ]] = None ,
148
158
) -> None :
149
159
"""Configure the Redis client."""
150
- self ._owns_client = redis_client is None
160
+ self ._owns_its_client = redis_client is None
151
161
self ._redis = redis_client or RedisConnectionFactory .get_async_redis_connection (
152
- redis_url
162
+ redis_url , ** connection_args
153
163
)
154
164
155
165
async def setup (self ) -> None :
@@ -160,11 +170,6 @@ async def setup(self) -> None:
160
170
self .index_config .get ("embed" ),
161
171
)
162
172
163
- # Now connect Redis client to indices
164
- await self .store_index .set_client (self ._redis )
165
- if self .index_config :
166
- await self .vector_index .set_client (self ._redis )
167
-
168
173
# Create indices in Redis
169
174
await self .store_index .create (overwrite = False )
170
175
if self .index_config :
@@ -188,9 +193,13 @@ async def from_conn_string(
188
193
189
194
def create_indexes (self ) -> None :
190
195
"""Create async indices."""
191
- self .store_index = AsyncSearchIndex .from_dict (self .SCHEMAS [0 ])
196
+ self .store_index = AsyncSearchIndex .from_dict (
197
+ self .SCHEMAS [0 ], redis_client = self ._redis
198
+ )
192
199
if self .index_config :
193
- self .vector_index = AsyncSearchIndex .from_dict (self .SCHEMAS [1 ])
200
+ self .vector_index = AsyncSearchIndex .from_dict (
201
+ self .SCHEMAS [1 ], redis_client = self ._redis
202
+ )
194
203
195
204
async def __aenter__ (self ) -> AsyncRedisStore :
196
205
"""Async context manager enter."""
@@ -210,7 +219,7 @@ async def __aexit__(
210
219
except asyncio .CancelledError :
211
220
pass
212
221
213
- if self ._owns_client :
222
+ if self ._owns_its_client :
214
223
await self ._redis .aclose () # type: ignore[attr-defined]
215
224
await self ._redis .connection_pool .disconnect ()
216
225
0 commit comments