diff --git a/tests/test_util.py b/tests/test_util.py index c08426d..c00d5f6 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -39,3 +39,7 @@ class TestUtil(): def test_is_iterable(self): assert_equal(Util.is_iterable('foo'), True) assert_equal(Util.is_iterable(7), False) + + def test_convert_to_list(obj): + assert_equal(isinstance(Util.convert_to_list('foo'), list), True) + assert_equal(isinstance(Util.convert_to_list(7), list), False) \ No newline at end of file diff --git a/util.py b/util.py index 5d0de5b..bef6d01 100644 --- a/util.py +++ b/util.py @@ -7,3 +7,13 @@ class Util: return True except TypeError: return False + + @classmethod + def convert_to_list(self, obj): + """Useful when writing functions that can accept multiple types of + input (list, tuple, ndarray, iterator). Checks if the object is a list. + If it is not a list, converts it to a list. + """ + if not isinstance(obj, list) and self.is_iterable(obj): + obj = list(obj) + return obj