Filter functions based on files we scan.

Currently we extract all functions from the compilation unit - it doesn't really make sense as it will try to process functions in included files too.
As we require sapi_in flag, we have information which files are of interest.

This change stores the top path in _TranslationUnit and uses it when looking for function definitions to filter only functions from the path we provided.

PiperOrigin-RevId: 297342507
Change-Id: Ie411321d375168f413f9f153a606c1113f55e79a
This commit is contained in:
Maciej Szawłowski 2020-02-26 06:01:39 -08:00 committed by Copybara-Service
parent 6332df5ef6
commit edd6b437ae
3 changed files with 131 additions and 80 deletions

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module related to code analysis and generation."""
from __future__ import absolute_import
@ -23,12 +22,13 @@ import os
from clang import cindex
# pylint: disable=unused-import
from typing import (Text, List, Optional, Set, Dict, Callable, IO,
Generator as Gen, Tuple, Union, Sequence)
from typing import (Text, List, Optional, Set, Dict, Callable, IO, Generator as
Gen, Tuple, Union, Sequence)
# pylint: enable=unused-import
_PARSE_OPTIONS = (cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES |
cindex.TranslationUnit.PARSE_INCOMPLETE |
_PARSE_OPTIONS = (
cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES
| cindex.TranslationUnit.PARSE_INCOMPLETE |
# for include directives
cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD)
@ -80,6 +80,7 @@ def _stringify_tokens(tokens, separator='\n'):
return separator.join(str(l) for l in lines)
TYPE_MAPPING = {
cindex.TypeKind.VOID: '::sapi::v::Void',
cindex.TypeKind.CHAR_S: '::sapi::v::Char',
@ -331,8 +332,10 @@ class Type(object):
"""Returns string representation of the Type."""
# (szwl): as simple as possible, keeps macros in separate lines not to
# break things; this will go through clang format nevertheless
tokens = [x for x in self._get_declaration().get_tokens()
if x.kind is not cindex.TokenKind.COMMENT]
tokens = [
x for x in self._get_declaration().get_tokens()
if x.kind is not cindex.TokenKind.COMMENT
]
return _stringify_tokens(tokens)
@ -368,7 +371,7 @@ class OutputLine(object):
def __str__(self):
# type: () -> Text
tabs = ('\t'*self.tab) if not self.define else ''
tabs = ('\t' * self.tab) if not self.define else ''
return tabs + ''.join(t for t in self.spellings)
@ -426,8 +429,9 @@ class ArgumentType(Type):
type_ = type_.get_canonical()
if type_.kind == cindex.TypeKind.ENUM:
return '::sapi::v::IntBase<{}>'.format(self._clang_type.spelling)
if type_.kind in [cindex.TypeKind.CONSTANTARRAY,
cindex.TypeKind.INCOMPLETEARRAY]:
if type_.kind in [
cindex.TypeKind.CONSTANTARRAY, cindex.TypeKind.INCOMPLETEARRAY
]:
return '::sapi::v::Reg<{}>'.format(self._clang_type.spelling)
if type_.kind == cindex.TypeKind.LVALUEREFERENCE:
@ -489,12 +493,13 @@ class Function(object):
self.name = cursor.spelling # type: Text
self.mangled_name = cursor.mangled_name # type: Text
self.result = ReturnType(self, cursor.result_type)
self.original_definition = '{} {}'.format(cursor.result_type.spelling,
self.cursor.displayname) # type: Text
self.original_definition = '{} {}'.format(
cursor.result_type.spelling, self.cursor.displayname) # type: Text
types = self.cursor.get_arguments()
self.argument_types = [ArgumentType(self, i, t.type, t.spelling) for i, t
in enumerate(types)]
self.argument_types = [
ArgumentType(self, i, t.type, t.spelling) for i, t in enumerate(types)
]
def translation_unit(self):
# type: () -> _TranslationUnit
@ -550,8 +555,10 @@ class Function(object):
class _TranslationUnit(object):
"""Class wrapping clang's _TranslationUnit. Provides extra utilities."""
def __init__(self, tu):
# type: (cindex.TranslatioUnit) -> None
def __init__(self, path, tu, limit_scan_depth=False):
# type: (Text, cindex.TranslationUnit, bool) -> None
self.path = path
self.limit_scan_depth = limit_scan_depth
self._tu = tu
self._processed = False
self.forward_decls = dict()
@ -584,9 +591,12 @@ class _TranslationUnit(object):
if (cursor.kind == cindex.CursorKind.STRUCT_DECL and
not cursor.is_definition()):
self.forward_decls[Type(self, cursor.type)] = cursor
if (cursor.kind == cindex.CursorKind.FUNCTION_DECL and
cursor.linkage != cindex.LinkageKind.INTERNAL):
if self.limit_scan_depth:
if (cursor.location and cursor.location.file.name == self.path):
self.functions.add(Function(self, cursor))
else:
self.functions.add(Function(self, cursor))
def get_functions(self):
@ -617,20 +627,26 @@ class Analyzer(object):
"""Class responsible for analysis."""
@staticmethod
def process_files(input_paths, compile_flags):
# type: (Text, List[Text]) -> List[_TranslationUnit]
def process_files(input_paths, compile_flags, limit_scan_depth=False):
# type: (Text, List[Text], bool) -> List[_TranslationUnit]
"""Processes files with libclang and returns TranslationUnit objects."""
_init_libclang()
return [Analyzer._analyze_file_for_tu(path, compile_flags=compile_flags)
for path in input_paths]
tus = []
for path in input_paths:
tu = Analyzer._analyze_file_for_tu(
path, compile_flags=compile_flags, limit_scan_depth=limit_scan_depth)
tus.append(tu)
return tus
# pylint: disable=line-too-long
@staticmethod
def _analyze_file_for_tu(path,
compile_flags=None,
test_file_existence=True,
unsaved_files=None
):
# type: (Text, Optional[List[Text]], bool, Optional[Tuple[Text, Union[Text, IO[Text]]]]) -> _TranslationUnit
unsaved_files=None,
limit_scan_depth=False):
# type: (Text, Optional[List[Text]], bool, Optional[Tuple[Text, Union[Text, IO[Text]]]], bool) -> _TranslationUnit
"""Returns Analysis object for given path."""
compile_flags = compile_flags or []
if test_file_existence and not os.path.isfile(path):
@ -645,9 +661,14 @@ class Analyzer(object):
args = [lang]
args += compile_flags
args.append('-I.')
return _TranslationUnit(index.parse(path, args=args,
return _TranslationUnit(
path,
index.parse(
path,
args=args,
unsaved_files=unsaved_files,
options=_PARSE_OPTIONS))
options=_PARSE_OPTIONS),
limit_scan_depth=limit_scan_depth)
class Generator(object):
@ -656,8 +677,7 @@ class Generator(object):
AUTO_GENERATED = ('// AUTO-GENERATED by the Sandboxed API generator.\n'
'// Edits will be discarded when regenerating this file.\n')
GUARD_START = ('#ifndef {0}\n'
'#define {0}')
GUARD_START = ('#ifndef {0}\n' '#define {0}')
GUARD_END = '#endif // {}'
EMBED_INCLUDE = '#include \"{}/{}_embed.h"'
EMBED_CLASS = ('class {0}Sandbox : public ::sapi::Sandbox {{\n'
@ -678,9 +698,14 @@ class Generator(object):
self.functions = None
_init_libclang()
def generate(self,
name,
function_names,
namespace=None,
output_file=None,
embed_dir=None,
embed_name=None):
# pylint: disable=line-too-long
def generate(self, name, function_names, namespace=None, output_file=None,
embed_dir=None, embed_name=None):
# type: (Text, List[Text], Optional[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text
"""Generates structures, functions and typedefs.
@ -689,11 +714,11 @@ class Generator(object):
function_names: list of function names to export to the interface
namespace: namespace of the interface
output_file: path to the output file, used to generate header guards;
defaults to None that does not generate the guard
#include directives; defaults to None that causes to emit the whole file
path
defaults to None that does not generate the guard #include directives;
defaults to None that causes to emit the whole file path
embed_dir: path to directory with embed includes
embed_name: name of the embed object
Returns:
generated interface as a string
"""
@ -722,8 +747,10 @@ class Generator(object):
self.functions = []
# TODO(szwl): for d in translation_unit.diagnostics:, handle that
for translation_unit in self.translation_units:
self.functions += [f for f in translation_unit.get_functions()
if not func_names or f.name in func_names]
self.functions += [
f for f in translation_unit.get_functions()
if not func_names or f.name in func_names
]
# allow only nonmangled functions - C++ overloads are not handled in
# code generation
self.functions = [f for f in self.functions if not f.is_mangled()]
@ -769,6 +796,7 @@ class Generator(object):
Returns:
list of #define string representations
"""
def make_sort_condition(translation_unit):
return lambda cursor: translation_unit.order[cursor.hash]
@ -781,8 +809,8 @@ class Generator(object):
define = tu.defines[name]
tmp_result.append(define)
for define in sorted(tmp_result, key=sort_condition):
result.append('#define ' + _stringify_tokens(define.get_tokens(),
separator=' \\\n'))
result.append('#define ' +
_stringify_tokens(define.get_tokens(), separator=' \\\n'))
return result
def _get_forward_decls(self, types):

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for code."""
from __future__ import absolute_import
@ -23,7 +22,6 @@ from clang import cindex
import code
import code_test_util
CODE = """
typedef int(fun*)(int,int);
extern "C" int function_a(int x, int y) { return x + y; }
@ -35,14 +33,15 @@ struct a {
"""
def analyze_string(content, path='tmp.cc'):
def analyze_string(content, path='tmp.cc', limit_scan_depth=False):
"""Returns Analysis object for in memory content."""
return analyze_strings(path, [(path, content)])
return analyze_strings(path, [(path, content)], limit_scan_depth)
def analyze_strings(path, unsaved_files):
def analyze_strings(path, unsaved_files, limit_scan_depth=False):
"""Returns Analysis object for in memory content."""
return code.Analyzer._analyze_file_for_tu(path, None, False, unsaved_files)
return code.Analyzer._analyze_file_for_tu(path, None, False, unsaved_files,
limit_scan_depth)
class CodeAnalysisTest(parameterized.TestCase):
@ -84,12 +83,16 @@ class CodeAnalysisTest(parameterized.TestCase):
def testExternC(self):
translation_unit = analyze_string('extern "C" int function(char* a);')
cursor_kinds = [x.kind for x in translation_unit._walk_preorder()
if x.kind != cindex.CursorKind.MACRO_DEFINITION]
self.assertListEqual(cursor_kinds, [cindex.CursorKind.TRANSLATION_UNIT,
cindex.CursorKind.UNEXPOSED_DECL,
cindex.CursorKind.FUNCTION_DECL,
cindex.CursorKind.PARM_DECL])
cursor_kinds = [
x.kind
for x in translation_unit._walk_preorder()
if x.kind != cindex.CursorKind.MACRO_DEFINITION
]
self.assertListEqual(cursor_kinds, [
cindex.CursorKind.TRANSLATION_UNIT, cindex.CursorKind.UNEXPOSED_DECL,
cindex.CursorKind.FUNCTION_DECL, cindex.CursorKind.PARM_DECL
])
@parameterized.named_parameters(
('1:', '/tmp/test.h', 'tmp', 'tmp/test.h'),
('2:', '/a/b/c/d/tmp/test.h', 'c/d', 'c/d/tmp/test.h'),
@ -97,7 +100,6 @@ class CodeAnalysisTest(parameterized.TestCase):
('4:', '/tmp/test.h', '', '/tmp/test.h'),
('5:', '/tmp/test.h', 'xxx', 'xxx/test.h'),
)
def testGetIncludes(self, path, prefix, expected):
function_body = 'extern "C" int function(bool a1) { return a1 ? 1 : 2; }'
translation_unit = analyze_string(function_body)
@ -120,8 +122,10 @@ class CodeAnalysisTest(parameterized.TestCase):
void types_6(char* a0);
}
"""
functions = ['function_a', 'types_1', 'types_2', 'types_3', 'types_4',
'types_5', 'types_6']
functions = [
'function_a', 'types_1', 'types_2', 'types_3', 'types_4', 'types_5',
'types_6'
]
generator = code.Generator([analyze_string(body)])
result = generator.generate('Test', functions, 'sapi::Tests', None, None)
self.assertMultiLineEqual(code_test_util.CODE_GOLD, result)
@ -132,7 +136,7 @@ class CodeAnalysisTest(parameterized.TestCase):
extern "C" int function(struct x a) { return a.a; }
"""
generator = code.Generator([analyze_string(body)])
with self.assertRaisesRegexp(ValueError, r'Elaborate.*mapped.*'):
with self.assertRaisesRegex(ValueError, r'Elaborate.*mapped.*'):
generator.generate('Test', ['function'], 'sapi::Tests', None, None)
def testElaboratedArgument2(self):
@ -141,7 +145,7 @@ class CodeAnalysisTest(parameterized.TestCase):
extern "C" int function(x a) { return a.a; }
"""
generator = code.Generator([analyze_string(body)])
with self.assertRaisesRegexp(ValueError, r'Elaborate.*mapped.*'):
with self.assertRaisesRegex(ValueError, r'Elaborate.*mapped.*'):
generator.generate('Test', ['function'], 'sapi::Tests', None, None)
def testGetMappedType(self):
@ -170,8 +174,7 @@ class CodeAnalysisTest(parameterized.TestCase):
'extern "C" int function(bool arg_bool, char* arg_ptr);',
['arg_bool', 'arg_ptr']),
('function without return value and no arguments',
'extern "C" void function();',
[]),
'extern "C" void function();', []),
)
def testArgumentNames(self, body, names):
generator = code.Generator([analyze_string(body)])
@ -436,19 +439,19 @@ class CodeAnalysisTest(parameterized.TestCase):
types = args[0].get_related_types()
names = [t._clang_type.spelling for t in types]
self.assertLen(types, 4)
self.assertSameElements(names, ['struct_6p', 'struct_6',
'struct struct_6_def', 'function_p3'])
self.assertSameElements(
names, ['struct_6p', 'struct_6', 'struct struct_6_def', 'function_p3'])
self.assertLen(generator.translation_units, 1)
self.assertLen(generator.translation_units[0].forward_decls, 1)
t = next(x for x in types
if x._clang_type.spelling == 'struct struct_6_def')
t = next(
x for x in types if x._clang_type.spelling == 'struct struct_6_def')
self.assertIn(t, generator.translation_units[0].forward_decls)
names = [t._clang_type.spelling for t in generator._get_related_types()]
self.assertEqual(names, ['struct_6', 'struct_6p',
'function_p3', 'struct struct_6_def'])
self.assertEqual(
names, ['struct_6', 'struct_6p', 'function_p3', 'struct struct_6_def'])
# Extra check for generation, in case rendering throws error for this test.
forward_decls = generator._get_forward_decls(generator._get_related_types())
@ -565,8 +568,8 @@ class CodeAnalysisTest(parameterized.TestCase):
typedef unsigned char uchar;"""
file3_code = 'typedef unsigned long ulong;'
file4_code = 'typedef char chr;'
files = [('f1.h', file1_code), ('/f2.h', file2_code),
('/f3.h', file3_code), ('/f4.h', file4_code)]
files = [('f1.h', file1_code), ('/f2.h', file2_code), ('/f3.h', file3_code),
('/f4.h', file4_code)]
generator = code.Generator([analyze_strings('f1.h', files)])
functions = generator._get_functions()
self.assertLen(functions, 1)
@ -576,6 +579,25 @@ class CodeAnalysisTest(parameterized.TestCase):
# Extra check for generation, in case rendering throws error for this test.
generator.generate('Test', [], 'sapi::Tests', None, None)
def testFilterFunctionsFromInputFilesOnly(self):
file1_code = """
#include "/f2.h"
extern "C" int function1();
"""
file2_code = """
extern "C" int function2();
"""
files = [('f1.h', file1_code), ('/f2.h', file2_code)]
generator = code.Generator([analyze_strings('f1.h', files)])
functions = generator._get_functions()
self.assertLen(functions, 2)
generator = code.Generator([analyze_strings('f1.h', files, True)])
functions = generator._get_functions()
self.assertLen(functions, 1)
def testTypeToString(self):
body = """
#define SIZE 1024

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SAPI interface header generator.
Parses headers to extract type information from functions and generate a SAPI
@ -34,6 +33,9 @@ flags.DEFINE_list('sapi_functions', [], 'function list to analyze')
flags.DEFINE_list('sapi_in', None, 'input files to analyze')
flags.DEFINE_string('sapi_embed_dir', None, 'directory with embed includes')
flags.DEFINE_string('sapi_embed_name', None, 'name of the embed object')
flags.DEFINE_bool(
'sapi_limit_scan_depth', False,
'scan only functions from top level file in compilation unit')
def extract_includes(path, array):
@ -52,14 +54,12 @@ def main(c_flags):
c_flags.pop(0)
logging.debug(FLAGS.sapi_functions)
extract_includes(FLAGS.sapi_isystem, c_flags)
tus = code.Analyzer.process_files(FLAGS.sapi_in, c_flags)
tus = code.Analyzer.process_files(FLAGS.sapi_in, c_flags,
FLAGS.sapi_limit_scan_depth)
generator = code.Generator(tus)
result = generator.generate(FLAGS.sapi_name,
FLAGS.sapi_functions,
FLAGS.sapi_ns,
FLAGS.sapi_out,
FLAGS.sapi_embed_dir,
FLAGS.sapi_embed_name)
result = generator.generate(FLAGS.sapi_name, FLAGS.sapi_functions,
FLAGS.sapi_ns, FLAGS.sapi_out,
FLAGS.sapi_embed_dir, FLAGS.sapi_embed_name)
if FLAGS.sapi_out:
with open(FLAGS.sapi_out, 'w') as out_file:
@ -67,6 +67,7 @@ def main(c_flags):
else:
sys.stdout.write(result)
if __name__ == '__main__':
flags.mark_flags_as_required(['sapi_name', 'sapi_in'])
app.run(main)