Commit 7a980ef

BugFix: Use dumping of task data and source only when dumping
Signed-off-by: elronbandel <[email protected]>
1 parent a3a9f78 commit 7a980ef

File tree

3 files changed: +29, -18 lines


src/unitxt/api.py

Lines changed: 7 additions & 5 deletions
@@ -21,7 +21,7 @@
 from .logging_utils import get_logger
 from .metric_utils import EvaluationResults, _compute, _inference_post_process
 from .operator import SourceOperator
-from .schema import loads_instance
+from .schema import SerializeInstancesBeforeDump, loads_instance
 from .settings_utils import get_constants, get_settings
 from .standard import DatasetRecipe
 from .task import Task
@@ -151,7 +151,7 @@ def _source_to_dataset(
     )
     if split is not None:
         stream = {split: stream[split]}
-    ds_builder._generators = stream
+    ds_builder._generators = SerializeInstancesBeforeDump()(stream)
 
     ds_builder.download_and_prepare(
         verification_mode="no_checks",
@@ -280,10 +280,12 @@ def produce(
     is_list = isinstance(instance_or_instances, list)
     if not is_list:
         instance_or_instances = [instance_or_instances]
-    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
+    instances = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
+    serialize = SerializeInstancesBeforeDump()
+    instances = [serialize.process_instance(instance) for instance in instances]
     if not is_list:
-        return result[0]
-    return Dataset.from_list(result).with_transform(loads_instance)
+        return instances[0]
+    return Dataset.from_list(instances).with_transform(loads_instance)
 
 
 def infer(
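
Taken together, the api.py changes confine serialization to the two dump boundaries: handing a stream to the HF dataset builder, and materializing instances in produce. A minimal sketch of the produce-side effect, assuming settings.task_data_as_text is enabled (as the new operator's logic requires) and using an illustrative instance rather than one from the repo:

    from unitxt.schema import SerializeInstancesBeforeDump

    # Illustrative, already-finalized instance: "source" is a structured
    # (non-string) payload and "task_data" is still a plain dict.
    instance = {
        "source": [{"role": "user", "content": "2+2="}],
        "task_data": {"answer": "4"},
    }

    serialize = SerializeInstancesBeforeDump()
    dumped = serialize.process_instance(instance)

    # At the dump boundary both fields become JSON strings, which is what
    # Dataset.from_list(...).with_transform(loads_instance) parses back.
    assert isinstance(dumped["source"], str)
    assert isinstance(dumped["task_data"], str)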

src/unitxt/dataset.py

Lines changed: 5 additions & 2 deletions
@@ -46,14 +46,15 @@
 from .random_utils import __file__ as _
 from .recipe import __file__ as _
 from .register import __file__ as _
-from .schema import loads_instance
+from .schema import SerializeInstancesBeforeDump, loads_instance
 from .serializers import __file__ as _
 from .settings_utils import get_constants
 from .span_lableing_operators import __file__ as _
 from .split_utils import __file__ as _
 from .splitters import __file__ as _
 from .sql_utils import __file__ as _
 from .standard import __file__ as _
+from .stream import MultiStream
 from .stream import __file__ as _
 from .stream_operators import __file__ as _
 from .string_operators import __file__ as _
@@ -91,7 +92,9 @@ def generators(self):
         logger.info("Loading with huggingface unitxt copy...")
         dataset = get_dataset_artifact(self.config.name)
 
-        self._generators = dataset()
+        multi_stream: MultiStream = dataset()
+
+        self._generators = SerializeInstancesBeforeDump()(multi_stream)
 
         return self._generators
 
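
On the builder path, the operator is applied to the whole MultiStream, so instances are stringified lazily as download_and_prepare iterates them instead of eagerly inside the recipe. A sketch of that wrapping on a toy in-memory stream; MultiStream.from_iterables and the sample data are assumptions for illustration, not taken from this commit:

    from unitxt.schema import SerializeInstancesBeforeDump
    from unitxt.stream import MultiStream

    # Toy stand-in for the recipe output returned by dataset().
    multi_stream = MultiStream.from_iterables(
        {"test": [{"source": {"question": "2+2="}, "task_data": {"answer": "4"}}]}
    )

    # The call returns a new MultiStream; fields are only serialized
    # when the dataset builder actually consumes the stream.
    serialized = SerializeInstancesBeforeDump()(multi_stream)
    for item in serialized["test"]:
        assert isinstance(item["source"], str)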

src/unitxt/schema.py

Lines changed: 17 additions & 11 deletions
@@ -7,7 +7,7 @@
 from .artifact import Artifact
 from .dict_utils import dict_get
 from .image_operators import ImageDataString
-from .operator import InstanceOperatorValidator
+from .operator import InstanceOperator, InstanceOperatorValidator
 from .settings_utils import get_constants, get_settings
 from .type_utils import isoftype
 from .types import Image
@@ -87,6 +87,18 @@ def loads_instance(batch):
     return batch
 
 
+class SerializeInstancesBeforeDump(InstanceOperator):
+
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        if settings.task_data_as_text:
+            instance["task_data"] = json.dumps(instance["task_data"])
+
+        if not isinstance(instance["source"], str):
+            instance["source"] = json.dumps(instance["source"])
+        return instance
+
 class FinalizeDataset(InstanceOperatorValidator):
     group_by: List[List[str]]
     remove_unnecessary_fields: bool = True
@@ -126,13 +138,6 @@ def _get_instance_task_data(
         task_data = {**task_data, **instance["reference_fields"]}
         return task_data
 
-    def serialize_instance_fields(self, instance, task_data):
-        if settings.task_data_as_text:
-            instance["task_data"] = json.dumps(task_data)
-
-        if not isinstance(instance["source"], str):
-            instance["source"] = json.dumps(instance["source"])
-        return instance
 
     def process(
         self, instance: Dict[str, Any], stream_name: Optional[str] = None
@@ -157,7 +162,7 @@ def process(
             for instance in instance.pop(constants.demos_field)
         ]
 
-        instance = self.serialize_instance_fields(instance, task_data)
+        instance["task_data"] = task_data
 
         if self.remove_unnecessary_fields:
             keys_to_delete = []
@@ -202,7 +207,8 @@ def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
             instance, dict
         ), f"Instance should be a dict, got {type(instance)}"
         schema = get_schema(stream_name)
+
         assert all(
             key in instance for key in schema
-        ), f"Instance should have the following keys: {schema}. Instance is: {instance}"
-        schema.encode_example(instance)
+        ), f"Instance should have the following keys: {schema.keys()}. Instance is: {instance.keys()}"
+        # schema.encode_example(instance)
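
The schema.py hunk is the heart of the fix: FinalizeDataset.process now stores task_data as a plain dict, and the stringification that used to live in serialize_instance_fields moved into the new SerializeInstancesBeforeDump operator. The round trip this sets up can be sketched with json directly (values are illustrative):

    import json

    # What FinalizeDataset.process now emits: structured task_data.
    finalized = {"source": "2+2=", "task_data": {"answer": "4"}}

    # What SerializeInstancesBeforeDump does at the dump boundary.
    dumped = dict(finalized)
    dumped["task_data"] = json.dumps(dumped["task_data"])

    # What loads_instance (applied via with_transform) recovers on read.
    assert json.loads(dumped["task_data"]) == finalized["task_data"]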
