Skip to content

Commit 4eacba3

Browse files
committed
Implement util method that create tensor from metadata
1 parent 1a8258c commit 4eacba3

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

Packages/com.github.asus4.onnxruntime.unity/Runtime/ImageInference.cs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ private static (string[], OrtValue[]) AllocateTensors(IReadOnlyDictionary<string
116116
if (meta.IsTensor)
117117
{
118118
names.Add(kv.Key);
119-
values.Add(TensorFromMetadata(meta));
119+
values.Add(meta.CreateTensorOrtValue());
120120
}
121121
else
122122
{
@@ -126,14 +126,6 @@ private static (string[], OrtValue[]) AllocateTensors(IReadOnlyDictionary<string
126126
return (names.ToArray(), values.ToArray());
127127
}
128128

129-
private static OrtValue TensorFromMetadata(NodeMetadata metadata)
130-
{
131-
long[] shape = metadata.Dimensions.Select(x => (long)x).ToArray();
132-
var ortValue = OrtValue.CreateAllocatedTensorValue(
133-
OrtAllocator.DefaultInstance, metadata.ElementDataType, shape);
134-
return ortValue;
135-
}
136-
137129
private static bool IsSupportedImage(int[] shape)
138130
{
139131
int channels = shape.Length switch

Packages/com.github.asus4.onnxruntime.unity/Runtime/InferenceSessionExtension.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
using System;
12
using System.Diagnostics;
3+
using System.Linq;
24
using System.Text;
35

46
namespace Microsoft.ML.OnnxRuntime.Unity
@@ -37,5 +39,22 @@ public static void LogIOInfo(this InferenceSession session)
3739

3840
UnityEngine.Debug.Log(sb.ToString());
3941
}
42+
43+
/// <summary>
44+
/// Create OrtValue from NodeMetadata
45+
/// </summary>
46+
/// <param name="metadata">A metadata</param>
47+
/// <returns>Allocated OrtValue, should be disposed.</returns>
48+
public static OrtValue CreateTensorOrtValue(this NodeMetadata metadata)
49+
{
50+
if (!metadata.IsTensor)
51+
{
52+
throw new ArgumentException("metadata must be tensor");
53+
}
54+
long[] shape = metadata.Dimensions.Select(x => (long)x).ToArray();
55+
var ortValue = OrtValue.CreateAllocatedTensorValue(
56+
OrtAllocator.DefaultInstance, metadata.ElementDataType, shape);
57+
return ortValue;
58+
}
4059
}
4160
}

0 commit comments

Comments
 (0)