[go: up one dir, main page]

Skip to content

Commit

Permalink
Generates random data for constant shaped args (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
kernhanda committed Dec 1, 2022
1 parent 6d93d29 commit 1f59252
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 20 deletions.
34 changes: 21 additions & 13 deletions hatlib/arg_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,21 @@ def get_dimension_arg_indices(array_arg: ArgInfo, all_arguments: List[ArgInfo])
return indices


def _gen_random_data(dtype, shape):
dtype = np.uint16 if dtype == "bfloat16" else dtype
if isinstance(dtype, np.dtype):
dtype = dtype.type
if isinstance(dtype, type) and issubclass(dtype, np.integer):
iinfo = np.iinfo(dtype)
min_num = iinfo.min
max_num = iinfo.max
data = np.random.randint(low=min_num, high=max_num, size=tuple(shape), dtype=dtype)
else:
data = np.random.random(tuple(shape)).astype(dtype)

return data


def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> List[ArgValue]:
"""Generate argument values from argument descriptions
Input and input/output affine_arrays: initialized with random inputs
Expand Down Expand Up @@ -176,18 +191,7 @@ def generate_dim_value():
shape.append(v if isinstance(v, np.integer) or type(v) == int else v[0])

# materialize an array input using the generated shape
numpy_dtype = np.uint16 if arg.numpy_dtype == "bfloat16" else arg.numpy_dtype
if (isinstance(numpy_dtype, np.dtype)
and issubclass(numpy_dtype.type, np.integer)) or (isinstance(numpy_dtype, type)
and issubclass(numpy_dtype, np.integer)):
iinfo = np.iinfo(numpy_dtype)
min_num = iinfo.min
max_num = iinfo.max
runtime_array_inputs = np.random.randint(
low=min_num, high=max_num, size=tuple(shape), dtype=numpy_dtype
)
else:
runtime_array_inputs = np.random.random(tuple(shape)).astype(numpy_dtype)
runtime_array_inputs = _gen_random_data(arg.numpy_dtype, shape)
values.append(ArgValue(arg, runtime_array_inputs))

elif arg.name in dim_names_to_values:
Expand All @@ -202,7 +206,11 @@ def generate_dim_value():
if not hasattr(arg, 'numpy_strides'):
arg.numpy_strides = list(map(lambda x: x * arg.element_num_bytes, arg.shape[1:] + [1]))

values.append(ArgValue(arg))
if arg.usage != hat_file.UsageType.Output:
arg_data = _gen_random_data(arg.numpy_dtype, arg.shape)
values.append(ArgValue(arg, arg_data))
else:
values.append(ArgValue(arg))

# collect the dimension ArgValues for each output runtime_array ArgValue
for value in values:
Expand Down
17 changes: 10 additions & 7 deletions test/test_verify_hat.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,22 +180,25 @@ def test_runtime_array(self):
{
/* Range */
/* Ensure we don't crash with random inputs */
int32_t delta0;
if (limit[0] < start[0]) {
delta0 = delta[0] <= 0 ? delta[0] : -delta[0];
delta0 = delta0 == 0 ? -1 : delta[0];
} else {
delta0 = delta[0] >= 0 ? delta[0] : -delta[0];
delta0 = delta0 == 0 ? 1 : delta[0];
}
int32_t start0 = start[0];
int32_t delta0 = delta[0] == 0 ? 1 : delta[0];
int32_t limit0 = (limit[0] <= start0) ? (start0 + delta0 * 25) : limit[0];
int32_t limit0 = limit[0];
*output_dim = (limit0 - start0) / delta0;
*output = (int32_t*)ALLOC(*output_dim * sizeof(int32_t));
printf(\"Allocated %d output elements\\n\", *output_dim);
printf(\"Allocated %u output elements\\n\", *output_dim);
printf(\"start=%d, limit=%d, delta=%d\\n\", start0, limit0, delta0);
for (uint32_t i = 0; i < *output_dim; ++i) {
(*output)[i] = start0 + (i * delta0);
}
for (uint32_t i = 0; i < *output_dim; ++i) {
(*output)[i] = start0 + (i * delta0);
}
}
'''
decl_code = '''#endif // TOML
Expand Down

0 comments on commit 1f59252

Please sign in to comment.