Filter functions based on files we scan.

Currently we extract all functions from the compilation unit - it doesn't really make sense as it will try to process functions in included files too.
As we require sapi_in flag, we have information which files are of interest.

This change stores the top path in _TranslationUnit and uses it when looking for function definitions to filter only functions from the path we provided.

PiperOrigin-RevId: 297342507
Change-Id: Ie411321d375168f413f9f153a606c1113f55e79a
This commit is contained in:
Maciej Szawłowski 2020-02-26 06:01:39 -08:00 committed by Copybara-Service
parent 6332df5ef6
commit edd6b437ae
3 changed files with 131 additions and 80 deletions

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Module related to code analysis and generation.""" """Module related to code analysis and generation."""
from __future__ import absolute_import from __future__ import absolute_import
@ -23,12 +22,13 @@ import os
from clang import cindex from clang import cindex
# pylint: disable=unused-import # pylint: disable=unused-import
from typing import (Text, List, Optional, Set, Dict, Callable, IO, from typing import (Text, List, Optional, Set, Dict, Callable, IO, Generator as
Generator as Gen, Tuple, Union, Sequence) Gen, Tuple, Union, Sequence)
# pylint: enable=unused-import # pylint: enable=unused-import
_PARSE_OPTIONS = (cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES | _PARSE_OPTIONS = (
cindex.TranslationUnit.PARSE_INCOMPLETE | cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES
| cindex.TranslationUnit.PARSE_INCOMPLETE |
# for include directives # for include directives
cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD) cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD)
@ -80,6 +80,7 @@ def _stringify_tokens(tokens, separator='\n'):
return separator.join(str(l) for l in lines) return separator.join(str(l) for l in lines)
TYPE_MAPPING = { TYPE_MAPPING = {
cindex.TypeKind.VOID: '::sapi::v::Void', cindex.TypeKind.VOID: '::sapi::v::Void',
cindex.TypeKind.CHAR_S: '::sapi::v::Char', cindex.TypeKind.CHAR_S: '::sapi::v::Char',
@ -331,8 +332,10 @@ class Type(object):
"""Returns string representation of the Type.""" """Returns string representation of the Type."""
# (szwl): as simple as possible, keeps macros in separate lines not to # (szwl): as simple as possible, keeps macros in separate lines not to
# break things; this will go through clang format nevertheless # break things; this will go through clang format nevertheless
tokens = [x for x in self._get_declaration().get_tokens() tokens = [
if x.kind is not cindex.TokenKind.COMMENT] x for x in self._get_declaration().get_tokens()
if x.kind is not cindex.TokenKind.COMMENT
]
return _stringify_tokens(tokens) return _stringify_tokens(tokens)
@ -368,7 +371,7 @@ class OutputLine(object):
def __str__(self): def __str__(self):
# type: () -> Text # type: () -> Text
tabs = ('\t'*self.tab) if not self.define else '' tabs = ('\t' * self.tab) if not self.define else ''
return tabs + ''.join(t for t in self.spellings) return tabs + ''.join(t for t in self.spellings)
@ -426,8 +429,9 @@ class ArgumentType(Type):
type_ = type_.get_canonical() type_ = type_.get_canonical()
if type_.kind == cindex.TypeKind.ENUM: if type_.kind == cindex.TypeKind.ENUM:
return '::sapi::v::IntBase<{}>'.format(self._clang_type.spelling) return '::sapi::v::IntBase<{}>'.format(self._clang_type.spelling)
if type_.kind in [cindex.TypeKind.CONSTANTARRAY, if type_.kind in [
cindex.TypeKind.INCOMPLETEARRAY]: cindex.TypeKind.CONSTANTARRAY, cindex.TypeKind.INCOMPLETEARRAY
]:
return '::sapi::v::Reg<{}>'.format(self._clang_type.spelling) return '::sapi::v::Reg<{}>'.format(self._clang_type.spelling)
if type_.kind == cindex.TypeKind.LVALUEREFERENCE: if type_.kind == cindex.TypeKind.LVALUEREFERENCE:
@ -489,12 +493,13 @@ class Function(object):
self.name = cursor.spelling # type: Text self.name = cursor.spelling # type: Text
self.mangled_name = cursor.mangled_name # type: Text self.mangled_name = cursor.mangled_name # type: Text
self.result = ReturnType(self, cursor.result_type) self.result = ReturnType(self, cursor.result_type)
self.original_definition = '{} {}'.format(cursor.result_type.spelling, self.original_definition = '{} {}'.format(
self.cursor.displayname) # type: Text cursor.result_type.spelling, self.cursor.displayname) # type: Text
types = self.cursor.get_arguments() types = self.cursor.get_arguments()
self.argument_types = [ArgumentType(self, i, t.type, t.spelling) for i, t self.argument_types = [
in enumerate(types)] ArgumentType(self, i, t.type, t.spelling) for i, t in enumerate(types)
]
def translation_unit(self): def translation_unit(self):
# type: () -> _TranslationUnit # type: () -> _TranslationUnit
@ -550,8 +555,10 @@ class Function(object):
class _TranslationUnit(object): class _TranslationUnit(object):
"""Class wrapping clang's _TranslationUnit. Provides extra utilities.""" """Class wrapping clang's _TranslationUnit. Provides extra utilities."""
def __init__(self, tu): def __init__(self, path, tu, limit_scan_depth=False):
# type: (cindex.TranslatioUnit) -> None # type: (Text, cindex.TranslationUnit, bool) -> None
self.path = path
self.limit_scan_depth = limit_scan_depth
self._tu = tu self._tu = tu
self._processed = False self._processed = False
self.forward_decls = dict() self.forward_decls = dict()
@ -584,9 +591,12 @@ class _TranslationUnit(object):
if (cursor.kind == cindex.CursorKind.STRUCT_DECL and if (cursor.kind == cindex.CursorKind.STRUCT_DECL and
not cursor.is_definition()): not cursor.is_definition()):
self.forward_decls[Type(self, cursor.type)] = cursor self.forward_decls[Type(self, cursor.type)] = cursor
if (cursor.kind == cindex.CursorKind.FUNCTION_DECL and if (cursor.kind == cindex.CursorKind.FUNCTION_DECL and
cursor.linkage != cindex.LinkageKind.INTERNAL): cursor.linkage != cindex.LinkageKind.INTERNAL):
if self.limit_scan_depth:
if (cursor.location and cursor.location.file.name == self.path):
self.functions.add(Function(self, cursor))
else:
self.functions.add(Function(self, cursor)) self.functions.add(Function(self, cursor))
def get_functions(self): def get_functions(self):
@ -617,20 +627,26 @@ class Analyzer(object):
"""Class responsible for analysis.""" """Class responsible for analysis."""
@staticmethod @staticmethod
def process_files(input_paths, compile_flags): def process_files(input_paths, compile_flags, limit_scan_depth=False):
# type: (Text, List[Text]) -> List[_TranslationUnit] # type: (Text, List[Text], bool) -> List[_TranslationUnit]
"""Processes files with libclang and returns TranslationUnit objects."""
_init_libclang() _init_libclang()
return [Analyzer._analyze_file_for_tu(path, compile_flags=compile_flags)
for path in input_paths] tus = []
for path in input_paths:
tu = Analyzer._analyze_file_for_tu(
path, compile_flags=compile_flags, limit_scan_depth=limit_scan_depth)
tus.append(tu)
return tus
# pylint: disable=line-too-long # pylint: disable=line-too-long
@staticmethod @staticmethod
def _analyze_file_for_tu(path, def _analyze_file_for_tu(path,
compile_flags=None, compile_flags=None,
test_file_existence=True, test_file_existence=True,
unsaved_files=None unsaved_files=None,
): limit_scan_depth=False):
# type: (Text, Optional[List[Text]], bool, Optional[Tuple[Text, Union[Text, IO[Text]]]]) -> _TranslationUnit # type: (Text, Optional[List[Text]], bool, Optional[Tuple[Text, Union[Text, IO[Text]]]], bool) -> _TranslationUnit
"""Returns Analysis object for given path.""" """Returns Analysis object for given path."""
compile_flags = compile_flags or [] compile_flags = compile_flags or []
if test_file_existence and not os.path.isfile(path): if test_file_existence and not os.path.isfile(path):
@ -645,9 +661,14 @@ class Analyzer(object):
args = [lang] args = [lang]
args += compile_flags args += compile_flags
args.append('-I.') args.append('-I.')
return _TranslationUnit(index.parse(path, args=args, return _TranslationUnit(
path,
index.parse(
path,
args=args,
unsaved_files=unsaved_files, unsaved_files=unsaved_files,
options=_PARSE_OPTIONS)) options=_PARSE_OPTIONS),
limit_scan_depth=limit_scan_depth)
class Generator(object): class Generator(object):
@ -656,8 +677,7 @@ class Generator(object):
AUTO_GENERATED = ('// AUTO-GENERATED by the Sandboxed API generator.\n' AUTO_GENERATED = ('// AUTO-GENERATED by the Sandboxed API generator.\n'
'// Edits will be discarded when regenerating this file.\n') '// Edits will be discarded when regenerating this file.\n')
GUARD_START = ('#ifndef {0}\n' GUARD_START = ('#ifndef {0}\n' '#define {0}')
'#define {0}')
GUARD_END = '#endif // {}' GUARD_END = '#endif // {}'
EMBED_INCLUDE = '#include \"{}/{}_embed.h"' EMBED_INCLUDE = '#include \"{}/{}_embed.h"'
EMBED_CLASS = ('class {0}Sandbox : public ::sapi::Sandbox {{\n' EMBED_CLASS = ('class {0}Sandbox : public ::sapi::Sandbox {{\n'
@ -678,9 +698,14 @@ class Generator(object):
self.functions = None self.functions = None
_init_libclang() _init_libclang()
def generate(self,
name,
function_names,
namespace=None,
output_file=None,
embed_dir=None,
embed_name=None):
# pylint: disable=line-too-long # pylint: disable=line-too-long
def generate(self, name, function_names, namespace=None, output_file=None,
embed_dir=None, embed_name=None):
# type: (Text, List[Text], Optional[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text # type: (Text, List[Text], Optional[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text
"""Generates structures, functions and typedefs. """Generates structures, functions and typedefs.
@ -689,11 +714,11 @@ class Generator(object):
function_names: list of function names to export to the interface function_names: list of function names to export to the interface
namespace: namespace of the interface namespace: namespace of the interface
output_file: path to the output file, used to generate header guards; output_file: path to the output file, used to generate header guards;
defaults to None that does not generate the guard defaults to None that does not generate the guard #include directives;
#include directives; defaults to None that causes to emit the whole file defaults to None that causes to emit the whole file path
path
embed_dir: path to directory with embed includes embed_dir: path to directory with embed includes
embed_name: name of the embed object embed_name: name of the embed object
Returns: Returns:
generated interface as a string generated interface as a string
""" """
@ -722,8 +747,10 @@ class Generator(object):
self.functions = [] self.functions = []
# TODO(szwl): for d in translation_unit.diagnostics:, handle that # TODO(szwl): for d in translation_unit.diagnostics:, handle that
for translation_unit in self.translation_units: for translation_unit in self.translation_units:
self.functions += [f for f in translation_unit.get_functions() self.functions += [
if not func_names or f.name in func_names] f for f in translation_unit.get_functions()
if not func_names or f.name in func_names
]
# allow only nonmangled functions - C++ overloads are not handled in # allow only nonmangled functions - C++ overloads are not handled in
# code generation # code generation
self.functions = [f for f in self.functions if not f.is_mangled()] self.functions = [f for f in self.functions if not f.is_mangled()]
@ -769,6 +796,7 @@ class Generator(object):
Returns: Returns:
list of #define string representations list of #define string representations
""" """
def make_sort_condition(translation_unit): def make_sort_condition(translation_unit):
return lambda cursor: translation_unit.order[cursor.hash] return lambda cursor: translation_unit.order[cursor.hash]
@ -781,8 +809,8 @@ class Generator(object):
define = tu.defines[name] define = tu.defines[name]
tmp_result.append(define) tmp_result.append(define)
for define in sorted(tmp_result, key=sort_condition): for define in sorted(tmp_result, key=sort_condition):
result.append('#define ' + _stringify_tokens(define.get_tokens(), result.append('#define ' +
separator=' \\\n')) _stringify_tokens(define.get_tokens(), separator=' \\\n'))
return result return result
def _get_forward_decls(self, types): def _get_forward_decls(self, types):

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tests for code.""" """Tests for code."""
from __future__ import absolute_import from __future__ import absolute_import
@ -23,7 +22,6 @@ from clang import cindex
import code import code
import code_test_util import code_test_util
CODE = """ CODE = """
typedef int(fun*)(int,int); typedef int(fun*)(int,int);
extern "C" int function_a(int x, int y) { return x + y; } extern "C" int function_a(int x, int y) { return x + y; }
@ -35,14 +33,15 @@ struct a {
""" """
def analyze_string(content, path='tmp.cc'): def analyze_string(content, path='tmp.cc', limit_scan_depth=False):
"""Returns Analysis object for in memory content.""" """Returns Analysis object for in memory content."""
return analyze_strings(path, [(path, content)]) return analyze_strings(path, [(path, content)], limit_scan_depth)
def analyze_strings(path, unsaved_files): def analyze_strings(path, unsaved_files, limit_scan_depth=False):
"""Returns Analysis object for in memory content.""" """Returns Analysis object for in memory content."""
return code.Analyzer._analyze_file_for_tu(path, None, False, unsaved_files) return code.Analyzer._analyze_file_for_tu(path, None, False, unsaved_files,
limit_scan_depth)
class CodeAnalysisTest(parameterized.TestCase): class CodeAnalysisTest(parameterized.TestCase):
@ -84,12 +83,16 @@ class CodeAnalysisTest(parameterized.TestCase):
def testExternC(self): def testExternC(self):
translation_unit = analyze_string('extern "C" int function(char* a);') translation_unit = analyze_string('extern "C" int function(char* a);')
cursor_kinds = [x.kind for x in translation_unit._walk_preorder() cursor_kinds = [
if x.kind != cindex.CursorKind.MACRO_DEFINITION] x.kind
self.assertListEqual(cursor_kinds, [cindex.CursorKind.TRANSLATION_UNIT, for x in translation_unit._walk_preorder()
cindex.CursorKind.UNEXPOSED_DECL, if x.kind != cindex.CursorKind.MACRO_DEFINITION
cindex.CursorKind.FUNCTION_DECL, ]
cindex.CursorKind.PARM_DECL]) self.assertListEqual(cursor_kinds, [
cindex.CursorKind.TRANSLATION_UNIT, cindex.CursorKind.UNEXPOSED_DECL,
cindex.CursorKind.FUNCTION_DECL, cindex.CursorKind.PARM_DECL
])
@parameterized.named_parameters( @parameterized.named_parameters(
('1:', '/tmp/test.h', 'tmp', 'tmp/test.h'), ('1:', '/tmp/test.h', 'tmp', 'tmp/test.h'),
('2:', '/a/b/c/d/tmp/test.h', 'c/d', 'c/d/tmp/test.h'), ('2:', '/a/b/c/d/tmp/test.h', 'c/d', 'c/d/tmp/test.h'),
@ -97,7 +100,6 @@ class CodeAnalysisTest(parameterized.TestCase):
('4:', '/tmp/test.h', '', '/tmp/test.h'), ('4:', '/tmp/test.h', '', '/tmp/test.h'),
('5:', '/tmp/test.h', 'xxx', 'xxx/test.h'), ('5:', '/tmp/test.h', 'xxx', 'xxx/test.h'),
) )
def testGetIncludes(self, path, prefix, expected): def testGetIncludes(self, path, prefix, expected):
function_body = 'extern "C" int function(bool a1) { return a1 ? 1 : 2; }' function_body = 'extern "C" int function(bool a1) { return a1 ? 1 : 2; }'
translation_unit = analyze_string(function_body) translation_unit = analyze_string(function_body)
@ -120,8 +122,10 @@ class CodeAnalysisTest(parameterized.TestCase):
void types_6(char* a0); void types_6(char* a0);
} }
""" """
functions = ['function_a', 'types_1', 'types_2', 'types_3', 'types_4', functions = [
'types_5', 'types_6'] 'function_a', 'types_1', 'types_2', 'types_3', 'types_4', 'types_5',
'types_6'
]
generator = code.Generator([analyze_string(body)]) generator = code.Generator([analyze_string(body)])
result = generator.generate('Test', functions, 'sapi::Tests', None, None) result = generator.generate('Test', functions, 'sapi::Tests', None, None)
self.assertMultiLineEqual(code_test_util.CODE_GOLD, result) self.assertMultiLineEqual(code_test_util.CODE_GOLD, result)
@ -132,7 +136,7 @@ class CodeAnalysisTest(parameterized.TestCase):
extern "C" int function(struct x a) { return a.a; } extern "C" int function(struct x a) { return a.a; }
""" """
generator = code.Generator([analyze_string(body)]) generator = code.Generator([analyze_string(body)])
with self.assertRaisesRegexp(ValueError, r'Elaborate.*mapped.*'): with self.assertRaisesRegex(ValueError, r'Elaborate.*mapped.*'):
generator.generate('Test', ['function'], 'sapi::Tests', None, None) generator.generate('Test', ['function'], 'sapi::Tests', None, None)
def testElaboratedArgument2(self): def testElaboratedArgument2(self):
@ -141,7 +145,7 @@ class CodeAnalysisTest(parameterized.TestCase):
extern "C" int function(x a) { return a.a; } extern "C" int function(x a) { return a.a; }
""" """
generator = code.Generator([analyze_string(body)]) generator = code.Generator([analyze_string(body)])
with self.assertRaisesRegexp(ValueError, r'Elaborate.*mapped.*'): with self.assertRaisesRegex(ValueError, r'Elaborate.*mapped.*'):
generator.generate('Test', ['function'], 'sapi::Tests', None, None) generator.generate('Test', ['function'], 'sapi::Tests', None, None)
def testGetMappedType(self): def testGetMappedType(self):
@ -170,8 +174,7 @@ class CodeAnalysisTest(parameterized.TestCase):
'extern "C" int function(bool arg_bool, char* arg_ptr);', 'extern "C" int function(bool arg_bool, char* arg_ptr);',
['arg_bool', 'arg_ptr']), ['arg_bool', 'arg_ptr']),
('function without return value and no arguments', ('function without return value and no arguments',
'extern "C" void function();', 'extern "C" void function();', []),
[]),
) )
def testArgumentNames(self, body, names): def testArgumentNames(self, body, names):
generator = code.Generator([analyze_string(body)]) generator = code.Generator([analyze_string(body)])
@ -436,19 +439,19 @@ class CodeAnalysisTest(parameterized.TestCase):
types = args[0].get_related_types() types = args[0].get_related_types()
names = [t._clang_type.spelling for t in types] names = [t._clang_type.spelling for t in types]
self.assertLen(types, 4) self.assertLen(types, 4)
self.assertSameElements(names, ['struct_6p', 'struct_6', self.assertSameElements(
'struct struct_6_def', 'function_p3']) names, ['struct_6p', 'struct_6', 'struct struct_6_def', 'function_p3'])
self.assertLen(generator.translation_units, 1) self.assertLen(generator.translation_units, 1)
self.assertLen(generator.translation_units[0].forward_decls, 1) self.assertLen(generator.translation_units[0].forward_decls, 1)
t = next(x for x in types t = next(
if x._clang_type.spelling == 'struct struct_6_def') x for x in types if x._clang_type.spelling == 'struct struct_6_def')
self.assertIn(t, generator.translation_units[0].forward_decls) self.assertIn(t, generator.translation_units[0].forward_decls)
names = [t._clang_type.spelling for t in generator._get_related_types()] names = [t._clang_type.spelling for t in generator._get_related_types()]
self.assertEqual(names, ['struct_6', 'struct_6p', self.assertEqual(
'function_p3', 'struct struct_6_def']) names, ['struct_6', 'struct_6p', 'function_p3', 'struct struct_6_def'])
# Extra check for generation, in case rendering throws error for this test. # Extra check for generation, in case rendering throws error for this test.
forward_decls = generator._get_forward_decls(generator._get_related_types()) forward_decls = generator._get_forward_decls(generator._get_related_types())
@ -565,8 +568,8 @@ class CodeAnalysisTest(parameterized.TestCase):
typedef unsigned char uchar;""" typedef unsigned char uchar;"""
file3_code = 'typedef unsigned long ulong;' file3_code = 'typedef unsigned long ulong;'
file4_code = 'typedef char chr;' file4_code = 'typedef char chr;'
files = [('f1.h', file1_code), ('/f2.h', file2_code), files = [('f1.h', file1_code), ('/f2.h', file2_code), ('/f3.h', file3_code),
('/f3.h', file3_code), ('/f4.h', file4_code)] ('/f4.h', file4_code)]
generator = code.Generator([analyze_strings('f1.h', files)]) generator = code.Generator([analyze_strings('f1.h', files)])
functions = generator._get_functions() functions = generator._get_functions()
self.assertLen(functions, 1) self.assertLen(functions, 1)
@ -576,6 +579,25 @@ class CodeAnalysisTest(parameterized.TestCase):
# Extra check for generation, in case rendering throws error for this test. # Extra check for generation, in case rendering throws error for this test.
generator.generate('Test', [], 'sapi::Tests', None, None) generator.generate('Test', [], 'sapi::Tests', None, None)
def testFilterFunctionsFromInputFilesOnly(self):
file1_code = """
#include "/f2.h"
extern "C" int function1();
"""
file2_code = """
extern "C" int function2();
"""
files = [('f1.h', file1_code), ('/f2.h', file2_code)]
generator = code.Generator([analyze_strings('f1.h', files)])
functions = generator._get_functions()
self.assertLen(functions, 2)
generator = code.Generator([analyze_strings('f1.h', files, True)])
functions = generator._get_functions()
self.assertLen(functions, 1)
def testTypeToString(self): def testTypeToString(self):
body = """ body = """
#define SIZE 1024 #define SIZE 1024

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""SAPI interface header generator. """SAPI interface header generator.
Parses headers to extract type information from functions and generate a SAPI Parses headers to extract type information from functions and generate a SAPI
@ -34,6 +33,9 @@ flags.DEFINE_list('sapi_functions', [], 'function list to analyze')
flags.DEFINE_list('sapi_in', None, 'input files to analyze') flags.DEFINE_list('sapi_in', None, 'input files to analyze')
flags.DEFINE_string('sapi_embed_dir', None, 'directory with embed includes') flags.DEFINE_string('sapi_embed_dir', None, 'directory with embed includes')
flags.DEFINE_string('sapi_embed_name', None, 'name of the embed object') flags.DEFINE_string('sapi_embed_name', None, 'name of the embed object')
flags.DEFINE_bool(
'sapi_limit_scan_depth', False,
'scan only functions from top level file in compilation unit')
def extract_includes(path, array): def extract_includes(path, array):
@ -52,14 +54,12 @@ def main(c_flags):
c_flags.pop(0) c_flags.pop(0)
logging.debug(FLAGS.sapi_functions) logging.debug(FLAGS.sapi_functions)
extract_includes(FLAGS.sapi_isystem, c_flags) extract_includes(FLAGS.sapi_isystem, c_flags)
tus = code.Analyzer.process_files(FLAGS.sapi_in, c_flags) tus = code.Analyzer.process_files(FLAGS.sapi_in, c_flags,
FLAGS.sapi_limit_scan_depth)
generator = code.Generator(tus) generator = code.Generator(tus)
result = generator.generate(FLAGS.sapi_name, result = generator.generate(FLAGS.sapi_name, FLAGS.sapi_functions,
FLAGS.sapi_functions, FLAGS.sapi_ns, FLAGS.sapi_out,
FLAGS.sapi_ns, FLAGS.sapi_embed_dir, FLAGS.sapi_embed_name)
FLAGS.sapi_out,
FLAGS.sapi_embed_dir,
FLAGS.sapi_embed_name)
if FLAGS.sapi_out: if FLAGS.sapi_out:
with open(FLAGS.sapi_out, 'w') as out_file: with open(FLAGS.sapi_out, 'w') as out_file:
@ -67,6 +67,7 @@ def main(c_flags):
else: else:
sys.stdout.write(result) sys.stdout.write(result)
if __name__ == '__main__': if __name__ == '__main__':
flags.mark_flags_as_required(['sapi_name', 'sapi_in']) flags.mark_flags_as_required(['sapi_name', 'sapi_in'])
app.run(main) app.run(main)