| 1 | """
|
| 2 | const_pass.py - AST pass that collects string constants.
|
| 3 |
|
| 4 | Instead of emitting a dynamic allocation StrFromC("foo"), we emit a
|
| 5 | GLOBAL_STR(str99, "foo"), and then a reference to str99.
|
| 6 | """
import collections
import json
import hashlib
import string

from mypy.nodes import (Expression, StrExpr, CallExpr)
from mypy.types import Type

from mycpp import format_strings
from mycpp import util
from mycpp.util import log
from mycpp import visitor

from typing import Dict, List, Tuple, Counter, TextIO, Union

_ = log

_ALPHABET = string.ascii_lowercase + string.ascii_uppercase
_ALPHABET = _ALPHABET[:32]  # 32 characters for the base-32 digits: 'a'-'z' plus 'A'-'F'

AllStrings = Dict[Union[int, StrExpr], str]  # Node -> raw string
UniqueStrings = Dict[bytes, str]  # SHA1 digest -> raw string
HashedStrings = Dict[str, List[str]]  # short hash -> list of raw strings
VarNames = Dict[str, str]  # raw string -> variable name
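
# ComputeStableVarNames() turns AllStrings into VarNames in three steps:
# dedupe by SHA1 (UniqueStrings), group by a 15-bit short hash (HashedStrings),
# then resolve collisions with a numeric suffix (VarNames).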


class GlobalStrings:
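    # Raw strings collected from the AST, plus the stable C++ variable names
    # ("S_" + short hash) that WriteConstants() emits for them.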

    def __init__(self) -> None:
        # StrExpr node (or int in tests) -> raw string
        self.all_strings: AllStrings = {}
        self.var_names: VarNames = {}

        # OLD
        self.unique: Dict[bytes, bytes] = {}
        self.int_id_lookup: Dict[Expression, str] = {}
        self.pairs: List[Tuple[str, str]] = []

    def Add(self, key: Union[int, StrExpr], s: str) -> None:
        """
        key: int for tests
             StrExpr node for production
        """
        self.all_strings[key] = s

    def ComputeStableVarNames(self) -> None:
        unique = _MakeUniqueStrings(self.all_strings)
        hash15 = _HashAndCollect(unique)
        self.var_names = _HandleCollisions(hash15)

    def GetVarName(self, node: StrExpr) -> str:
        # StrExpr -> raw string -> variable name
        return self.var_names[self.all_strings[node]]

    def WriteConstants(self, out_f: TextIO) -> None:
        if util.SMALL_STR:
            macro_name = 'GLOBAL_STR2'
        else:
            macro_name = 'GLOBAL_STR'

        # sort by the string value itself
        for raw_string in sorted(self.var_names):
            var_name = self.var_names[raw_string]
            out_f.write('%s(%s, %s);\n' %
                        (macro_name, var_name, json.dumps(raw_string)))

        out_f.write('\n')


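# Typical flow (illustrative sketch; the real driver lives elsewhere in mycpp):
#
#   gs = GlobalStrings()
#   collector = Collect(types, gs)   # Collect is defined below
#   ...                              # walk the AST; visit_str_expr() calls gs.Add()
#   gs.ComputeStableVarNames()
#   gs.WriteConstants(out_f)         # GLOBAL_STR() definitions
#   ...
#   gs.GetVarName(str_expr)          # later, at each use site

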
class Collect(visitor.TypedVisitor):
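    # Walks the typed AST and registers every string literal with
    # GlobalStrings.Add().  Literals that are written out literally (the left
    # side of % formatting, log() format strings) and DTRACE_PROBE() calls are
    # skipped.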

    def __init__(self, types: Dict[Expression, Type],
                 global_strings: GlobalStrings) -> None:
        visitor.TypedVisitor.__init__(self, types)
        self.global_strings = global_strings

        # Only generate unique strings.
        # Before this optimization, _gen/bin/oils_for_unix.mycpp.cc went up to:
        #   "str2824"
        # After:
        #   "str1789"
        #
        # So it saved over 1000 strings.
        #
        # The C++ compiler should also optimize it, but it's easy for us to
        # generate less source code.

        # unique string value -> id
        self.unique: Dict[str, str] = {}
        self.unique_id = 0

    def oils_visit_format_expr(self, left: Expression,
                               right: Expression) -> None:
        if isinstance(left, StrExpr):
            # Do NOT visit the left, because we write it literally
            pass
        else:
            self.accept(left)
        self.accept(right)

    def visit_str_expr(self, o: StrExpr) -> None:
        raw_string = format_strings.DecodeMyPyString(o.value)
        self.global_strings.Add(o, raw_string)

    def oils_visit_probe_call(self, o: CallExpr) -> None:
        # Don't generate constants for DTRACE_PROBE()
        pass

    def oils_visit_log_call(self, fmt: StrExpr,
                            args: List[Expression]) -> None:
        if len(args) == 0:
            self.accept(fmt)
            return

        # Don't generate a string constant for the format string, which is an
        # inlined C string, not a mycpp GC string
        for arg in args:
            self.accept(arg)


def _MakeUniqueStrings(all_strings: AllStrings) -> UniqueStrings:
    """
    Given all the strings, make a smaller set of unique strings.
    """
    unique: UniqueStrings = {}
    for _, raw_string in all_strings.items():
        b = raw_string.encode('utf-8')
        h = hashlib.sha1(b).digest()
        #print(repr(h))

        if h in unique:
            # extremely unlikely
            assert unique[h] == raw_string, ("SHA1 hash collision! %r and %r" %
                                             (unique[h], b))
        unique[h] = raw_string
    return unique


def _ShortHash15(h: bytes) -> str:
    """
    Given a SHA1, create a 15 bit hash value.

    We use three base-(2**5) aka base-32 digits, encoded as letters.
    """
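    # Illustrative example: if h[0] == 1 and h[1] == 0, then bits16 == 1, so
    # d1 == 1 and d2 == d3 == 0, and the result is 'baa'.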
    bits16 = h[0] | h[1] << 8

    assert 0 <= bits16 < 2**16, bits16

    # 5 least significant bits
    d1 = bits16 & 0b11111
    bits16 >>= 5
    d2 = bits16 & 0b11111
    bits16 >>= 5
    d3 = bits16 & 0b11111
    bits16 >>= 5

    return _ALPHABET[d1] + _ALPHABET[d2] + _ALPHABET[d3]


def _HashAndCollect(unique: UniqueStrings) -> HashedStrings:
    """
    Group the unique strings by their short hash.
    """
    hash15 = collections.defaultdict(list)
    for sha1, b in unique.items():
        short_hash = _ShortHash15(sha1)
        hash15[short_hash].append(b)
    return hash15


def _SummarizeCollisions(hash15: HashedStrings) -> None:
    collisions: Counter[int] = collections.Counter()
    for short_hash, strs in hash15.items():
        n = len(strs)
        #if n > 1:
        if 0:
            print(short_hash)
            print(strs)
        collisions[n] += 1

    log('%10s %s', 'COUNT', 'ITEM')
    for item, count in collisions.most_common():
        log('%10d %s', count, item)


def _HandleCollisions(hash15: HashedStrings) -> VarNames:
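    # Illustrative example: if 'bar' and 'foo' both land in bucket 'xyz', they
    # become S_xyz and S_xyz_1 (sorted order decides which one gets the suffix).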
    var_names: VarNames = {}
    for short_hash, bytes_list in hash15.items():
        bytes_list.sort()  # stable order; all but the first get a numeric suffix
        for i, b in enumerate(bytes_list):
            if i == 0:
                var_names[b] = 'S_%s' % short_hash
            else:
                var_names[b] = 'S_%s_%d' % (short_hash, i)
    return var_names


def HashDemo() -> None:
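    # Demo: read lines from stdin, assign each unique line a stable variable
    # name, and print collision statistics (run via the __main__ block below).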
    import sys

    # 5 bits
    #_ALPHABET = _ALPHABET.replace('l', 'Z')  # use a nicer one?
    log('alpha %r', _ALPHABET)

    global_strings = GlobalStrings()

    all_lines = sys.stdin.readlines()
    for i, line in enumerate(all_lines):
        global_strings.Add(i, line.strip())

    unique = _MakeUniqueStrings(global_strings.all_strings)
    hash15 = _HashAndCollect(unique)
    var_names = _HandleCollisions(hash15)

    if 0:
        for b, var_name in var_names.items():
            if var_name[-1].isdigit():
                log('%r %r', var_name, b)
            #log('%r %r', var_name, b)

    log('Unique %d' % len(unique))
    log('hash15 %d' % len(hash15))

    _SummarizeCollisions(hash15)


if __name__ == '__main__':
    HashDemo()
|