@@ -239,18 +239,18 @@ def enhanced(*args, **kwargs):
239
239
240
240
241
241
def dataset_to_point_list (
242
- ds : xarray .Dataset , sample_dims : Sequence [str ]
242
+ ds : Union [ xarray .Dataset , dict [ str , xarray . DataArray ]] , sample_dims : Sequence [str ]
243
243
) -> Tuple [List [Dict [str , np .ndarray ]], Dict [str , Any ]]:
244
244
# All keys of the dataset must be a str
245
- var_names = list (ds .keys ())
245
+ var_names = cast ( List [ str ], list (ds .keys () ))
246
246
for vn in var_names :
247
247
if not isinstance (vn , str ):
248
248
raise ValueError (f"Variable names must be str, but dataset key { vn } is a { type (vn )} ." )
249
249
num_sample_dims = len (sample_dims )
250
- stacked_dims = {dim_name : ds [dim_name ] for dim_name in sample_dims }
251
- ds = ds .transpose (* sample_dims , ...)
250
+ stacked_dims = {dim_name : ds [var_names [0 ]][dim_name ] for dim_name in sample_dims }
252
251
stacked_dict = {
253
- vn : da .values .reshape ((- 1 , * da .shape [num_sample_dims :])) for vn , da in ds .items ()
252
+ vn : da .transpose (* sample_dims , ...).values .reshape ((- 1 , * da .shape [num_sample_dims :]))
253
+ for vn , da in ds .items ()
254
254
}
255
255
points = [
256
256
{vn : stacked_dict [vn ][i , ...] for vn in var_names }
0 commit comments