7
7
from .artifact import Artifact
8
8
from .dict_utils import dict_get
9
9
from .image_operators import ImageDataString
10
- from .operator import InstanceOperatorValidator
10
+ from .operator import InstanceOperator , InstanceOperatorValidator
11
11
from .settings_utils import get_constants , get_settings
12
12
from .type_utils import isoftype
13
13
from .types import Image
@@ -87,6 +87,18 @@ def loads_instance(batch):
87
87
return batch
88
88
89
89
90
class SerializeInstancesBeforeDump(InstanceOperator):
    """Convert instance fields to JSON text before the instance is dumped.

    ``task_data`` is JSON-encoded only when ``settings.task_data_as_text`` is
    truthy; ``source`` is JSON-encoded whenever it is not already a string.
    """

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Serialize ``task_data`` and ``source`` in place and return the instance.

        Args:
            instance: The instance dict; must contain ``"source"`` and, when
                ``settings.task_data_as_text`` is set, ``"task_data"``.
            stream_name: Not used by this operator.

        Returns:
            The same ``instance`` object, mutated in place.
        """
        if settings.task_data_as_text:
            task_data_text = json.dumps(instance["task_data"])
            instance["task_data"] = task_data_text

        source = instance["source"]
        if isinstance(source, str):
            # Already textual; nothing left to encode.
            return instance

        instance["source"] = json.dumps(source)
        return instance
101
+
90
102
class FinalizeDataset (InstanceOperatorValidator ):
91
103
group_by : List [List [str ]]
92
104
remove_unnecessary_fields : bool = True
@@ -126,13 +138,6 @@ def _get_instance_task_data(
126
138
task_data = {** task_data , ** instance ["reference_fields" ]}
127
139
return task_data
128
140
129
- def serialize_instance_fields (self , instance , task_data ):
130
- if settings .task_data_as_text :
131
- instance ["task_data" ] = json .dumps (task_data )
132
-
133
- if not isinstance (instance ["source" ], str ):
134
- instance ["source" ] = json .dumps (instance ["source" ])
135
- return instance
136
141
137
142
def process (
138
143
self , instance : Dict [str , Any ], stream_name : Optional [str ] = None
@@ -157,7 +162,7 @@ def process(
157
162
for instance in instance .pop (constants .demos_field )
158
163
]
159
164
160
- instance = self . serialize_instance_fields ( instance , task_data )
165
+ instance [ "task_data" ] = task_data
161
166
162
167
if self .remove_unnecessary_fields :
163
168
keys_to_delete = []
@@ -202,7 +207,8 @@ def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
202
207
instance , dict
203
208
), f"Instance should be a dict, got { type (instance )} "
204
209
schema = get_schema (stream_name )
210
+
205
211
assert all (
206
212
key in instance for key in schema
207
- ), f"Instance should have the following keys: { schema } . Instance is: { instance } "
208
- schema .encode_example (instance )
213
+ ), f"Instance should have the following keys: { schema . keys () } . Instance is: { instance . keys () } "
214
+ # schema.encode_example(instance)
0 commit comments