macro fixes:

- made sure that define order is correct
- made sure to emit all defines related to target define
- fixed a bug where '(' was separated with macro name with space, this breaking the macro

PiperOrigin-RevId: 256129616
Change-Id: I636b13a72c6198fb59e8e387f42567c442b24352
This commit is contained in:
Maciej Szaw?owski 2019-07-02 02:57:57 -07:00 committed by Copybara-Service
parent 4e20e0702a
commit 9435f97538
2 changed files with 62 additions and 32 deletions

View File

@ -65,8 +65,8 @@ def get_header_guard(path):
return path + '_' return path + '_'
def _stringify_tokens(tokens, separator='\n', callbacks=None): def _stringify_tokens(tokens, separator='\n'):
# type: (Sequence[cindex.Token], Text, Dict[int, Callable]) -> Text # type: (Sequence[cindex.Token], Text) -> Text
"""Converts tokens to text respecting line position (disrespecting column).""" """Converts tokens to text respecting line position (disrespecting column)."""
previous = OutputLine(0, []) # not used in output previous = OutputLine(0, []) # not used in output
lines = [] # type: List[OutputLine] lines = [] # type: List[OutputLine]
@ -75,9 +75,6 @@ def _stringify_tokens(tokens, separator='\n', callbacks=None):
group_list = list(group) group_list = list(group)
line = OutputLine(previous.next_tab, group_list) line = OutputLine(previous.next_tab, group_list)
if callbacks and len(group_list) in callbacks:
callbacks[len(group_list)](group_list)
lines.append(line) lines.append(line)
previous = line previous = line
@ -298,6 +295,7 @@ class Type(object):
# skip unnamed structures eg. typedef struct {...} x; # skip unnamed structures eg. typedef struct {...} x;
# struct {...} will be rendered as part of typedef rendering # struct {...} will be rendered as part of typedef rendering
if self._get_declaration().spelling and not skip_self: if self._get_declaration().spelling and not skip_self:
self._tu.search_for_macro_name(self._get_declaration())
result.add(self) result.add(self)
for f in self._clang_type.get_fields(): for f in self._clang_type.get_fields():
@ -335,17 +333,8 @@ class Type(object):
# break things; this will go through clang format nevertheless # break things; this will go through clang format nevertheless
tokens = [x for x in self._get_declaration().get_tokens() tokens = [x for x in self._get_declaration().get_tokens()
if x.kind is not cindex.TokenKind.COMMENT] if x.kind is not cindex.TokenKind.COMMENT]
# look for lines with two tokens: a way of finding structures with
# body as a macro eg: return _stringify_tokens(tokens)
# #define BODY \
# int a; \
# int b;
# struct test {
# BODY;
# }
callbacks = {}
callbacks[2] = lambda x: self._tu.required_defines.add(x[0].spelling)
return _stringify_tokens(tokens, callbacks=callbacks)
class OutputLine(object): class OutputLine(object):
@ -354,19 +343,15 @@ class OutputLine(object):
def __init__(self, tab, tokens): def __init__(self, tab, tokens):
# type: (int, List[cindex.Token]) -> None # type: (int, List[cindex.Token]) -> None
self.tokens = tokens self.tokens = tokens
self.spellings = []
self.define = False self.define = False
self.tab = tab self.tab = tab
self.next_tab = tab self.next_tab = tab
map(self._process_token, self.tokens) list(map(self._process_token, self.tokens))
def append(self, t):
# type: (cindex.Token) -> None
"""Appends token to the line."""
self._process_token(t)
self.tokens.append(t)
def _process_token(self, t): def _process_token(self, t):
# type: (cindex.Token) -> None # type: (cindex.Token) -> None
"""Processes a token, setting up internal states rel. to intendation."""
if t.spelling == '#': if t.spelling == '#':
self.define = True self.define = True
elif t.spelling == '{': elif t.spelling == '{':
@ -375,11 +360,16 @@ class OutputLine(object):
self.tab -= 1 self.tab -= 1
self.next_tab -= 1 self.next_tab -= 1
is_bracket = t.spelling == '('
is_macro = len(self.spellings) == 1 and self.spellings[0] == '#'
if self.spellings and not is_bracket and not is_macro:
self.spellings.append(' ')
self.spellings.append(t.spelling)
def __str__(self): def __str__(self):
# type: () -> Text # type: () -> Text
tabs = ('\t'*self.tab) if not self.define else '' tabs = ('\t'*self.tab) if not self.define else ''
return tabs + ''.join(t for t in self.spellings)
return tabs + ' '.join(t.spelling for t in self.tokens)
class ArgumentType(Type): class ArgumentType(Type):
@ -587,6 +577,7 @@ class _TranslationUnit(object):
if (cursor.kind == cindex.CursorKind.MACRO_DEFINITION and if (cursor.kind == cindex.CursorKind.MACRO_DEFINITION and
cursor.location.file): cursor.location.file):
self.order[cursor.hash] = i
self.defines[cursor.spelling] = cursor self.defines[cursor.spelling] = cursor
# most likely a forward decl of struct # most likely a forward decl of struct
@ -609,15 +600,15 @@ class _TranslationUnit(object):
for c in self._tu.cursor.walk_preorder(): for c in self._tu.cursor.walk_preorder():
yield c yield c
# TODO(szwl): expand to look for macros in structs, unions etc.
def search_for_macro_name(self, cursor): def search_for_macro_name(self, cursor):
# type: (cindex.Cursor) -> None # type: (cindex.Cursor) -> None
"""Searches for possible macro usage in constant array types.""" """Searches for possible macro usage in constant array types."""
tokens = list(t.spelling for t in cursor.get_tokens()) tokens = list(t.spelling for t in cursor.get_tokens())
try: try:
for token in tokens: for token in tokens:
if token in self.defines: if token in self.defines and token not in self.required_defines:
self.required_defines.add(token) self.required_defines.add(token)
self.search_for_macro_name(self.defines[token])
except ValueError: except ValueError:
return return
@ -776,13 +767,20 @@ class Generator(object):
Returns: Returns:
list of #define string representations list of #define string representations
""" """
def make_sort_condition(translation_unit):
return lambda cursor: translation_unit.order[cursor.hash]
result = [] result = []
for tu in self.translation_units: for tu in self.translation_units:
tmp_result = []
sort_condition = make_sort_condition(tu)
for name in tu.required_defines: for name in tu.required_defines:
if name in tu.defines: if name in tu.defines:
define = tu.defines[name] define = tu.defines[name]
result.append('#define ' + _stringify_tokens(define.get_tokens(), tmp_result.append(define)
separator=' \\\n')) for define in sorted(tmp_result, key=sort_condition):
result.append('#define ' + _stringify_tokens(define.get_tokens(),
separator=' \\\n'))
return result return result
def _get_forward_decls(self, types): def _get_forward_decls(self, types):

View File

@ -600,11 +600,11 @@ class CodeAnalysisTest(parameterized.TestCase):
# pylint: disable=trailing-whitespace # pylint: disable=trailing-whitespace
expected = """typedef struct { expected = """typedef struct {
# if SOME_DEFINE >= 12 && SOME_OTHER == 13 #if SOME_DEFINE >= 12 && SOME_OTHER == 13
\tuint a ; \tuint a ;
# else #else
\tuint aa ; \tuint aa ;
# endif #endif
\tstruct { \tstruct {
\t\tuint a ; \t\tuint a ;
\t\tint b ; \t\tint b ;
@ -656,6 +656,38 @@ class CodeAnalysisTest(parameterized.TestCase):
# Extra check for generation, in case rendering throws error for this test. # Extra check for generation, in case rendering throws error for this test.
generator.generate('Test', [], 'sapi::Tests', None, None) generator.generate('Test', [], 'sapi::Tests', None, None)
def testYaraCase(self):
body = """
#define YR_ALIGN(n) __attribute__((aligned(n)))
#define DECLARE_REFERENCE(type, name) union { \
type name; \
int64_t name##_; \
} YR_ALIGN(8)
struct YR_NAMESPACE {
int32_t t_flags[1337];
DECLARE_REFERENCE(char*, name);
};
extern "C" int function_1(struct YR_NAMESPACE* a1);
"""
generator = code.Generator([analyze_string(body)])
self.assertLen(generator.translation_units, 1)
generator._get_related_types()
tu = generator.translation_units[0]
tu._process()
self.assertLen(tu.required_defines, 2)
defines = generator._get_defines()
# _get_defines will add dependant defines to tu.required_defines
self.assertLen(defines, 2)
gold = '#define DECLARE_REFERENCE('
# DECLARE_REFERENCE must be second to pass this test
self.assertTrue(defines[1].startswith(gold))
# Extra check for generation, in case rendering throws error for this test.
generator.generate('Test', [], 'sapi::Tests', None, None)
def testDoubleFunction(self): def testDoubleFunction(self):
body = """ body = """
extern "C" int function_1(int a); extern "C" int function_1(int a);