[go: up one dir, main page]

Skip to content

Commit

Permalink
Another fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
Ritwik Das committed May 8, 2023
1 parent 2e2ad05 commit b02b922
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions hatlib/arg_value.py
Columns: original file line number · diff line number · diff line change
Expand Up @@ -145,7 +145,7 @@ def get_dimension_arg_indices(array_arg: ArgInfo, all_arguments: List[ArgInfo])
return indices


def _gen_random_data(dtype, shape, strides):
def _gen_random_data(dtype, shape, strides=None):
dtype = np.uint16 if dtype == "bfloat16" else dtype
if isinstance(dtype, np.dtype):
dtype = dtype.type
Expand All @@ -157,7 +157,7 @@ def _gen_random_data(dtype, shape, strides):
else:
data = np.random.random(tuple(shape)).astype(dtype)

return np.lib.stride_tricks.as_strided(data, strides=strides)
return np.lib.stride_tricks.as_strided(data, strides=strides) if strides is not None else data


def generate_arg_values(arguments: List[ArgInfo], dim_names_to_values={}) -> List[ArgValue]:
Expand Down Expand Up @@ -198,7 +198,7 @@ def generate_dim_value():
shape.append(v if isinstance(v, np.integer) or type(v) == int else v[0])

# materialize an array input using the generated shape
runtime_array_inputs = _gen_random_data(arg.numpy_dtype, shape, arg.numpy_strides)
runtime_array_inputs = _gen_random_data(arg.numpy_dtype, shape)
values.append(ArgValue(arg, runtime_array_inputs))

elif arg.name in dim_names_to_values and arg.usage == hat_file.UsageType.Input:
Expand Down

0 comments on commit b02b922

Please sign in to comment.