12 changes: 7 additions & 5 deletions src/unitxt/api.py
@@ -22,7 +22,7 @@
 from .logging_utils import get_logger
 from .metric_utils import EvaluationResults, _compute, _inference_post_process
 from .operator import SourceOperator
-from .schema import loads_batch
+from .schema import SerializeInstancesBeforeDump, loads_batch
 from .settings_utils import get_constants, get_settings
 from .standard import DatasetRecipe
 from .task import Task
@@ -204,7 +204,7 @@ def _source_to_dataset(
     )
     if split is not None:
         stream = {split: stream[split]}
-    ds_builder._generators = stream
+    ds_builder._generators = SerializeInstancesBeforeDump()(stream)

     try:
         ds_builder.download_and_prepare(
@@ -354,10 +354,12 @@ def produce(
     if not is_list:
         instance_or_instances = [instance_or_instances]
     dataset_recipe = _get_recipe_with_cache(dataset_query, **kwargs)
-    result = dataset_recipe.produce(instance_or_instances)
+    instances = dataset_recipe.produce(instance_or_instances)
+    serialize = SerializeInstancesBeforeDump()
+    instances = [serialize.process_instance(instance) for instance in instances]
     if not is_list:
-        return result[0]
-    return Dataset.from_list(result).with_transform(loads_batch)
+        return instances[0]
+    return Dataset.from_list(instances).with_transform(loads_batch)


 def infer(
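Note on the api.py changes above: `produce()` now runs each produced instance through `SerializeInstancesBeforeDump` before building the HF `Dataset`, so `task_data` (and a non-string `source`) are stored as JSON strings, and the existing `with_transform(loads_batch)` hook parses them back when rows are read. A minimal sketch of that round trip; the dataset query string is illustrative, not something defined in this PR:

    from unitxt.api import produce

    # hypothetical card/template names; any valid dataset query would do
    dataset = produce(
        [{"question": "2+2=", "answer": "4"}],
        "card=cards.some_card,template=templates.some_template",
    )
    row = dataset[0]  # loads_batch turns the stored JSON strings back into objects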
7 changes: 5 additions & 2 deletions src/unitxt/dataset.py
@@ -47,13 +47,14 @@
 from .random_utils import __file__ as _
 from .recipe import __file__ as _
 from .register import __file__ as _
-from .schema import loads_batch, loads_instance
+from .schema import SerializeInstancesBeforeDump, loads_batch, loads_instance
 from .serializers import __file__ as _
 from .settings_utils import get_constants
 from .span_lableing_operators import __file__ as _
 from .split_utils import __file__ as _
 from .splitters import __file__ as _
 from .standard import __file__ as _
+from .stream import MultiStream
 from .stream import __file__ as _
 from .stream_operators import __file__ as _
 from .string_operators import __file__ as _
@@ -92,7 +93,9 @@ def generators(self):
         logger.info("Loading with huggingface unitxt copy...")
         dataset = get_dataset_artifact(self.config.name)

-        self._generators = dataset()
+        multi_stream: MultiStream = dataset()
+
+        self._generators = SerializeInstancesBeforeDump()(multi_stream)

         return self._generators
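Note on the dataset.py changes above: `dataset()` returns a `MultiStream`, and calling an `InstanceOperator` such as `SerializeInstancesBeforeDump` on it yields a new `MultiStream`, so the HF builder's `_generators` emit already-serialized instances as they are iterated. A fragment-level sketch, under the assumption (suggested by its use here) that `MultiStream` is dict-like, mapping split names to streams:

    multi_stream = dataset()                                 # MultiStream of splits
    generators = SerializeInstancesBeforeDump()(multi_stream)
    for split_name, stream in generators.items():            # assumes dict-like API
        first = next(iter(stream))                           # task_data is now a JSON string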
29 changes: 17 additions & 12 deletions src/unitxt/schema.py
@@ -7,7 +7,7 @@
 from .artifact import Artifact
 from .dict_utils import dict_get
 from .image_operators import ImageDataString
-from .operator import InstanceOperatorValidator
+from .operator import InstanceOperator, InstanceOperatorValidator
 from .settings_utils import get_constants, get_settings
 from .type_utils import isoftype
 from .types import Image
@@ -68,6 +68,18 @@ def load_chat_source(chat_str):
     return chat


+class SerializeInstancesBeforeDump(InstanceOperator):
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        if settings.task_data_as_text:
+            instance["task_data"] = json.dumps(instance["task_data"])
+
+        if not isinstance(instance["source"], str):
+            instance["source"] = json.dumps(instance["source"])
+        return instance
+
+
 def loads_batch(batch):
     if (
         "source" in batch
@@ -148,14 +160,6 @@ def _get_instance_task_data(
             task_data["__tools__"] = instance["__tools__"]
         return task_data

-    def serialize_instance_fields(self, instance, task_data):
-        if settings.task_data_as_text:
-            instance["task_data"] = json.dumps(task_data)
-
-        if not isinstance(instance["source"], str):
-            instance["source"] = json.dumps(instance["source"])
-        return instance
-
     def process(
         self, instance: Dict[str, Any], stream_name: Optional[str] = None
     ) -> Dict[str, Any]:
@@ -179,7 +183,7 @@ def process(
             for instance in instance.pop(constants.demos_field)
         ]

-        instance = self.serialize_instance_fields(instance, task_data)
+        instance["task_data"] = task_data

         if self.remove_unnecessary_fields:
             keys_to_delete = []
@@ -224,7 +228,8 @@ def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
             instance, dict
         ), f"Instance should be a dict, got {type(instance)}"
         schema = get_schema(stream_name)
+
         assert all(
             key in instance for key in schema
-        ), f"Instance should have the following keys: {schema}. Instance is: {instance}"
-        schema.encode_example(instance)
+        ), f"Instance should have the following keys: {schema.keys()}. Instance is: {instance.keys()}"
+        # schema.encode_example(instance)
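Note on the schema.py changes above: the serialization that used to happen inside the schema operator's `process` (via `serialize_instance_fields`) now lives in the standalone `SerializeInstancesBeforeDump` operator, applied only where instances are actually dumped to a HF dataset; `process` now keeps `task_data` as a plain dict. A minimal sketch of the operator's effect on one instance, assuming `settings.task_data_as_text` is enabled:

    from unitxt.schema import SerializeInstancesBeforeDump

    serialize = SerializeInstancesBeforeDump()
    instance = {
        "source": [{"role": "user", "content": "2+2="}],  # chat-format, not a str
        "task_data": {"question": "2+2=", "answer": "4"},
    }
    out = serialize.process_instance(instance)
    # out["task_data"] and out["source"] are both JSON strings now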