[Python-checkins] GH-98831: Implement basic cache effects (#99313)

gvanrossum webhook-mailer at python.org
Tue Nov 15 22:59:26 EST 2022


https://github.com/python/cpython/commit/e37744f289af00c6f6eba83f7abfb932b63de9e0
commit: e37744f289af00c6f6eba83f7abfb932b63de9e0
branch: main
author: Guido van Rossum <guido at python.org>
committer: gvanrossum <gvanrossum at gmail.com>
date: 2022-11-15T19:59:19-08:00
summary:

GH-98831: Implement basic cache effects (#99313)

files:
M Python/bytecodes.c
M Python/generated_cases.c.h
M Tools/cases_generator/generate_cases.py
M Tools/cases_generator/parser.py

diff --git a/Python/bytecodes.c b/Python/bytecodes.c
index 69ee741d5df0..1575b5390fb7 100644
--- a/Python/bytecodes.c
+++ b/Python/bytecodes.c
@@ -76,13 +76,9 @@ do { \
 #define NAME_ERROR_MSG \
     "name '%.200s' is not defined"
 
-typedef struct {
-    PyObject *kwnames;
-} CallShape;
-
 // Dummy variables for stack effects.
 static PyObject *value, *value1, *value2, *left, *right, *res, *sum, *prod, *sub;
-static PyObject *container, *start, *stop, *v;
+static PyObject *container, *start, *stop, *v, *lhs, *rhs;
 
 static PyObject *
 dummy_func(
@@ -101,6 +97,8 @@ dummy_func(
     binaryfunc binary_ops[]
 )
 {
+    _PyInterpreterFrame  entry_frame;
+
     switch (opcode) {
 
 // BEGIN BYTECODES //
@@ -193,7 +191,21 @@ dummy_func(
             ERROR_IF(res == NULL, error);
         }
 
-        inst(BINARY_OP_MULTIPLY_INT, (left, right -- prod)) {
+        family(binary_op, INLINE_CACHE_ENTRIES_BINARY_OP) = {
+            BINARY_OP,
+            BINARY_OP_ADD_FLOAT,
+            BINARY_OP_ADD_INT,
+            BINARY_OP_ADD_UNICODE,
+            BINARY_OP_GENERIC,
+            // BINARY_OP_INPLACE_ADD_UNICODE,  // This is an odd duck.
+            BINARY_OP_MULTIPLY_FLOAT,
+            BINARY_OP_MULTIPLY_INT,
+            BINARY_OP_SUBTRACT_FLOAT,
+            BINARY_OP_SUBTRACT_INT,
+        };
+
+
+        inst(BINARY_OP_MULTIPLY_INT, (left, right, unused/1 -- prod)) {
             assert(cframe.use_tracing == 0);
             DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP);
             DEOPT_IF(!PyLong_CheckExact(right), BINARY_OP);
@@ -202,10 +214,9 @@ dummy_func(
             _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
             _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
             ERROR_IF(prod == NULL, error);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
         }
 
-        inst(BINARY_OP_MULTIPLY_FLOAT, (left, right -- prod)) {
+        inst(BINARY_OP_MULTIPLY_FLOAT, (left, right, unused/1 -- prod)) {
             assert(cframe.use_tracing == 0);
             DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP);
             DEOPT_IF(!PyFloat_CheckExact(right), BINARY_OP);
@@ -216,10 +227,9 @@ dummy_func(
             _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
             _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
             ERROR_IF(prod == NULL, error);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
         }
 
-        inst(BINARY_OP_SUBTRACT_INT, (left, right -- sub)) {
+        inst(BINARY_OP_SUBTRACT_INT, (left, right, unused/1 -- sub)) {
             assert(cframe.use_tracing == 0);
             DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP);
             DEOPT_IF(!PyLong_CheckExact(right), BINARY_OP);
@@ -228,10 +238,9 @@ dummy_func(
             _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
             _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
             ERROR_IF(sub == NULL, error);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
         }
 
-        inst(BINARY_OP_SUBTRACT_FLOAT, (left, right -- sub)) {
+        inst(BINARY_OP_SUBTRACT_FLOAT, (left, right, unused/1 -- sub)) {
             assert(cframe.use_tracing == 0);
             DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP);
             DEOPT_IF(!PyFloat_CheckExact(right), BINARY_OP);
@@ -241,10 +250,9 @@ dummy_func(
             _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
             _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
             ERROR_IF(sub == NULL, error);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
         }
 
-        inst(BINARY_OP_ADD_UNICODE, (left, right -- res)) {
+        inst(BINARY_OP_ADD_UNICODE, (left, right, unused/1 -- res)) {
             assert(cframe.use_tracing == 0);
             DEOPT_IF(!PyUnicode_CheckExact(left), BINARY_OP);
             DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP);
@@ -253,7 +261,6 @@ dummy_func(
             _Py_DECREF_SPECIALIZED(left, _PyUnicode_ExactDealloc);
             _Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc);
             ERROR_IF(res == NULL, error);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
         }
 
         // This is a subtle one. It's a super-instruction for
@@ -292,7 +299,7 @@ dummy_func(
             JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP + 1);
         }
 
-        inst(BINARY_OP_ADD_FLOAT, (left, right -- sum)) {
+        inst(BINARY_OP_ADD_FLOAT, (left, right, unused/1 -- sum)) {
             assert(cframe.use_tracing == 0);
             DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP);
             DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP);
@@ -303,10 +310,9 @@ dummy_func(
             _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
             _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
             ERROR_IF(sum == NULL, error);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
         }
 
-        inst(BINARY_OP_ADD_INT, (left, right -- sum)) {
+        inst(BINARY_OP_ADD_INT, (left, right, unused/1 -- sum)) {
             assert(cframe.use_tracing == 0);
             DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP);
             DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP);
@@ -315,7 +321,6 @@ dummy_func(
             _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
             _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
             ERROR_IF(sum == NULL, error);
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
         }
 
         inst(BINARY_SUBSCR, (container, sub -- res)) {
@@ -3691,30 +3696,21 @@ dummy_func(
             PUSH(Py_NewRef(peek));
         }
 
-        // stack effect: (__0 -- )
-        inst(BINARY_OP_GENERIC) {
-            PyObject *rhs = POP();
-            PyObject *lhs = TOP();
+        inst(BINARY_OP_GENERIC, (lhs, rhs, unused/1 -- res)) {
             assert(0 <= oparg);
             assert((unsigned)oparg < Py_ARRAY_LENGTH(binary_ops));
             assert(binary_ops[oparg]);
-            PyObject *res = binary_ops[oparg](lhs, rhs);
+            res = binary_ops[oparg](lhs, rhs);
             Py_DECREF(lhs);
             Py_DECREF(rhs);
-            SET_TOP(res);
-            if (res == NULL) {
-                goto error;
-            }
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
+            ERROR_IF(res == NULL, error);
         }
 
-        // stack effect: (__0 -- )
-        inst(BINARY_OP) {
+        // This always dispatches, so the result is unused.
+        inst(BINARY_OP, (lhs, rhs, unused/1 -- unused)) {
             _PyBinaryOpCache *cache = (_PyBinaryOpCache *)next_instr;
             if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) {
                 assert(cframe.use_tracing == 0);
-                PyObject *lhs = SECOND();
-                PyObject *rhs = TOP();
                 next_instr--;
                 _Py_Specialize_BinaryOp(lhs, rhs, next_instr, oparg, &GETLOCAL(0));
                 DISPATCH_SAME_OPARG();
@@ -3761,13 +3757,8 @@ dummy_func(
     ;
 }
 
-// Families go below this point //
+// Future families go below this point //
 
-family(binary_op) = {
-    BINARY_OP, BINARY_OP_ADD_FLOAT,
-    BINARY_OP_ADD_INT, BINARY_OP_ADD_UNICODE, BINARY_OP_GENERIC, BINARY_OP_INPLACE_ADD_UNICODE,
-    BINARY_OP_MULTIPLY_FLOAT, BINARY_OP_MULTIPLY_INT, BINARY_OP_SUBTRACT_FLOAT,
-    BINARY_OP_SUBTRACT_INT };
 family(binary_subscr) = {
     BINARY_SUBSCR, BINARY_SUBSCR_DICT,
     BINARY_SUBSCR_GETITEM, BINARY_SUBSCR_LIST_INT, BINARY_SUBSCR_TUPLE_INT };
diff --git a/Python/generated_cases.c.h b/Python/generated_cases.c.h
index 552d0e6d0178..b8bc66b14889 100644
--- a/Python/generated_cases.c.h
+++ b/Python/generated_cases.c.h
@@ -145,9 +145,9 @@
             _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
             _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
             if (prod == NULL) goto pop_2_error;
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
             STACK_SHRINK(1);
             POKE(1, prod);
+            next_instr += 1;
             DISPATCH();
         }
 
@@ -165,9 +165,9 @@
             _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
             _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
             if (prod == NULL) goto pop_2_error;
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
             STACK_SHRINK(1);
             POKE(1, prod);
+            next_instr += 1;
             DISPATCH();
         }
 
@@ -183,9 +183,9 @@
             _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
             _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
             if (sub == NULL) goto pop_2_error;
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
             STACK_SHRINK(1);
             POKE(1, sub);
+            next_instr += 1;
             DISPATCH();
         }
 
@@ -202,9 +202,9 @@
             _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
             _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
             if (sub == NULL) goto pop_2_error;
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
             STACK_SHRINK(1);
             POKE(1, sub);
+            next_instr += 1;
             DISPATCH();
         }
 
@@ -220,9 +220,9 @@
             _Py_DECREF_SPECIALIZED(left, _PyUnicode_ExactDealloc);
             _Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc);
             if (res == NULL) goto pop_2_error;
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
             STACK_SHRINK(1);
             POKE(1, res);
+            next_instr += 1;
             DISPATCH();
         }
 
@@ -274,9 +274,9 @@
             _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
             _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
             if (sum == NULL) goto pop_2_error;
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
             STACK_SHRINK(1);
             POKE(1, sum);
+            next_instr += 1;
             DISPATCH();
         }
 
@@ -292,9 +292,9 @@
             _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
             _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
             if (sum == NULL) goto pop_2_error;
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
             STACK_SHRINK(1);
             POKE(1, sum);
+            next_instr += 1;
             DISPATCH();
         }
 
@@ -3703,29 +3703,30 @@
 
         TARGET(BINARY_OP_GENERIC) {
             PREDICTED(BINARY_OP_GENERIC);
-            PyObject *rhs = POP();
-            PyObject *lhs = TOP();
+            PyObject *rhs = PEEK(1);
+            PyObject *lhs = PEEK(2);
+            PyObject *res;
             assert(0 <= oparg);
             assert((unsigned)oparg < Py_ARRAY_LENGTH(binary_ops));
             assert(binary_ops[oparg]);
-            PyObject *res = binary_ops[oparg](lhs, rhs);
+            res = binary_ops[oparg](lhs, rhs);
             Py_DECREF(lhs);
             Py_DECREF(rhs);
-            SET_TOP(res);
-            if (res == NULL) {
-                goto error;
-            }
-            JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
+            if (res == NULL) goto pop_2_error;
+            STACK_SHRINK(1);
+            POKE(1, res);
+            next_instr += 1;
             DISPATCH();
         }
 
         TARGET(BINARY_OP) {
             PREDICTED(BINARY_OP);
+            assert(INLINE_CACHE_ENTRIES_BINARY_OP == 1);
+            PyObject *rhs = PEEK(1);
+            PyObject *lhs = PEEK(2);
             _PyBinaryOpCache *cache = (_PyBinaryOpCache *)next_instr;
             if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) {
                 assert(cframe.use_tracing == 0);
-                PyObject *lhs = SECOND();
-                PyObject *rhs = TOP();
                 next_instr--;
                 _Py_Specialize_BinaryOp(lhs, rhs, next_instr, oparg, &GETLOCAL(0));
                 DISPATCH_SAME_OPARG();
diff --git a/Tools/cases_generator/generate_cases.py b/Tools/cases_generator/generate_cases.py
index b4f5f8f01dc1..d01653175091 100644
--- a/Tools/cases_generator/generate_cases.py
+++ b/Tools/cases_generator/generate_cases.py
@@ -18,7 +18,6 @@
 arg_parser = argparse.ArgumentParser()
 arg_parser.add_argument("-i", "--input", type=str, default="Python/bytecodes.c")
 arg_parser.add_argument("-o", "--output", type=str, default="Python/generated_cases.c.h")
-arg_parser.add_argument("-c", "--compare", action="store_true")
 arg_parser.add_argument("-q", "--quiet", action="store_true")
 
 
@@ -40,7 +39,6 @@ def parse_cases(
     families: list[parser.Family] = []
     while not psr.eof():
         if inst := psr.inst_def():
-            assert inst.block
             instrs.append(inst)
         elif sup := psr.super_def():
             supers.append(sup)
@@ -69,17 +67,45 @@ def always_exits(block: parser.Block) -> bool:
     return line.startswith(("goto ", "return ", "DISPATCH", "GO_TO_", "Py_UNREACHABLE()"))
 
 
-def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, dedent: int = 0):
-    assert instr.block
+def find_cache_size(instr: InstDef, families: list[parser.Family]) -> str | None:
+    for family in families:
+        if instr.name == family.members[0]:
+            return family.size
+
+
+def write_instr(
+    instr: InstDef, predictions: set[str], indent: str, f: TextIO, dedent: int = 0, cache_size: str | None = None
+) -> int:
+    # Returns cache offset
     if dedent < 0:
         indent += " " * -dedent
+    # Separate stack inputs from cache inputs
+    input_names: set[str] = set()
+    stack: list[parser.StackEffect] = []
+    cache: list[parser.CacheEffect] = []
+    for input in instr.inputs:
+        if isinstance(input, parser.StackEffect):
+            stack.append(input)
+            input_names.add(input.name)
+        else:
+            assert isinstance(input, parser.CacheEffect), input
+            cache.append(input)
+    outputs = instr.outputs
+    cache_offset = 0
+    for ceffect in cache:
+        if ceffect.name != "unused":
+            bits = ceffect.size * 16
+            f.write(f"{indent}    PyObject *{ceffect.name} = read{bits}(next_instr + {cache_offset});\n")
+        cache_offset += ceffect.size
+    if cache_size:
+        f.write(f"{indent}    assert({cache_size} == {cache_offset});\n")
     # TODO: Is it better to count forward or backward?
-    for i, input in enumerate(reversed(instr.inputs), 1):
-        f.write(f"{indent}    PyObject *{input} = PEEK({i});\n")
+    for i, effect in enumerate(reversed(stack), 1):
+        if effect.name != "unused":
+            f.write(f"{indent}    PyObject *{effect.name} = PEEK({i});\n")
     for output in instr.outputs:
-        if output not in instr.inputs:
-            f.write(f"{indent}    PyObject *{output};\n")
-    assert instr.block is not None
+        if output.name not in input_names and output.name != "unused":
+            f.write(f"{indent}    PyObject *{output.name};\n")
     blocklines = instr.block.to_text(dedent=dedent).splitlines(True)
     # Remove blank lines from ends
     while blocklines and not blocklines[0].strip():
@@ -95,7 +121,7 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d
     while blocklines and not blocklines[-1].strip():
         blocklines.pop()
     # Write the body
-    ninputs = len(instr.inputs or ())
+    ninputs = len(stack)
     for line in blocklines:
         if m := re.match(r"(\s*)ERROR_IF\((.+), (\w+)\);\s*$", line):
             space, cond, label = m.groups()
@@ -107,46 +133,56 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d
                 f.write(f"{space}if ({cond}) goto {label};\n")
         else:
             f.write(line)
-    noutputs = len(instr.outputs or ())
+    if always_exits(instr.block):
+        # None of the rest matters
+        return cache_offset
+    # Stack effect
+    noutputs = len(outputs)
     diff = noutputs - ninputs
     if diff > 0:
         f.write(f"{indent}    STACK_GROW({diff});\n")
     elif diff < 0:
         f.write(f"{indent}    STACK_SHRINK({-diff});\n")
-    for i, output in enumerate(reversed(instr.outputs or ()), 1):
-        if output not in (instr.inputs or ()):
-            f.write(f"{indent}    POKE({i}, {output});\n")
-    assert instr.block
-
-def write_cases(f: TextIO, instrs: list[InstDef], supers: list[parser.Super]):
+    for i, output in enumerate(reversed(outputs), 1):
+        if output.name not in input_names and output.name != "unused":
+            f.write(f"{indent}    POKE({i}, {output.name});\n")
+    # Cache effect
+    if cache_offset:
+        f.write(f"{indent}    next_instr += {cache_offset};\n")
+    return cache_offset
+
+
+def write_cases(
+    f: TextIO, instrs: list[InstDef], supers: list[parser.Super], families: list[parser.Family]
+) -> dict[str, tuple[int, int, int]]:
     predictions: set[str] = set()
     for instr in instrs:
-        assert isinstance(instr, InstDef)
-        assert instr.block is not None
         for target in re.findall(RE_PREDICTED, instr.block.text):
             predictions.add(target)
     indent = "        "
     f.write(f"// This file is generated by {os.path.relpath(__file__)}\n")
     f.write(f"// Do not edit!\n")
     instr_index: dict[str, InstDef] = {}
+    effects_table: dict[str, tuple[int, int, int]] = {}  # name -> (ninputs, noutputs, cache_offset)
     for instr in instrs:
         instr_index[instr.name] = instr
         f.write(f"\n{indent}TARGET({instr.name}) {{\n")
         if instr.name in predictions:
             f.write(f"{indent}    PREDICTED({instr.name});\n")
-        write_instr(instr, predictions, indent, f)
-        assert instr.block
+        cache_offset = write_instr(
+            instr, predictions, indent, f,
+            cache_size=find_cache_size(instr, families)
+        )
+        effects_table[instr.name] = len(instr.inputs), len(instr.outputs), cache_offset
         if not always_exits(instr.block):
             f.write(f"{indent}    DISPATCH();\n")
         # Write trailing '}'
         f.write(f"{indent}}}\n")
 
     for sup in supers:
-        assert isinstance(sup, parser.Super)
         components = [instr_index[name] for name in sup.ops]
         f.write(f"\n{indent}TARGET({sup.name}) {{\n")
         for i, instr in enumerate(components):
-            assert instr.block
             if i > 0:
                 f.write(f"{indent}    NEXTOPARG();\n")
                 f.write(f"{indent}    next_instr++;\n")
@@ -156,6 +192,8 @@ def write_cases(f: TextIO, instrs: list[InstDef], supers: list[parser.Super]):
         f.write(f"{indent}    DISPATCH();\n")
         f.write(f"{indent}}}\n")
 
+    return effects_table
+
 
 def main():
     args = arg_parser.parse_args()
@@ -176,12 +214,28 @@ def main():
             file=sys.stderr,
         )
     with eopen(args.output, "w") as f:
-        write_cases(f, instrs, supers)
+        effects_table = write_cases(f, instrs, supers, families)
     if not args.quiet:
         print(
             f"Wrote {ninstrs + nsupers} instructions to {args.output}",
             file=sys.stderr,
         )
+    # Check that families have consistent effects
+    errors = 0
+    for family in families:
+        head = effects_table[family.members[0]]
+        for member in family.members:
+            if effects_table[member] != head:
+                errors += 1
+                print(
+                    f"Family {family.name!r} has inconsistent effects (inputs, outputs, cache units):",
+                    file=sys.stderr,
+                )
+                print(
+                    f"  {family.members[0]} = {head}; {member} = {effects_table[member]}",
+                )
+    if errors:
+        sys.exit(1)
 
 
 if __name__ == "__main__":
diff --git a/Tools/cases_generator/parser.py b/Tools/cases_generator/parser.py
index 9e95cdb42d42..1f855312aeba 100644
--- a/Tools/cases_generator/parser.py
+++ b/Tools/cases_generator/parser.py
@@ -56,11 +56,28 @@ class Block(Node):
     tokens: list[lx.Token]
 
 
+ at dataclass
+class Effect(Node):
+    pass
+
+
+ at dataclass
+class StackEffect(Effect):
+    name: str
+    # TODO: type, condition
+
+
+ at dataclass
+class CacheEffect(Effect):
+    name: str
+    size: int
+
+
 @dataclass
 class InstHeader(Node):
     name: str
-    inputs: list[str]
-    outputs: list[str]
+    inputs: list[Effect]
+    outputs: list[Effect]
 
 
 @dataclass
@@ -69,16 +86,17 @@ class InstDef(Node):
     block: Block
 
     @property
-    def name(self):
+    def name(self) -> str:
         return self.header.name
 
     @property
-    def inputs(self):
+    def inputs(self) -> list[Effect]:
         return self.header.inputs
 
     @property
-    def outputs(self):
-        return self.header.outputs
+    def outputs(self) -> list[StackEffect]:
+        # This is always true
+        return [x for x in self.header.outputs if isinstance(x, StackEffect)]
 
 
 @dataclass
@@ -90,6 +108,7 @@ class Super(Node):
 @dataclass
 class Family(Node):
     name: str
+    size: str  # Variable giving the cache size in code units
     members: list[str]
 
 
@@ -123,18 +142,16 @@ def inst_header(self) -> InstHeader | None:
                     return InstHeader(name, [], [])
         return None
 
-    def check_overlaps(self, inp: list[str], outp: list[str]):
+    def check_overlaps(self, inp: list[Effect], outp: list[Effect]):
         for i, name in enumerate(inp):
-            try:
-                j = outp.index(name)
-            except ValueError:
-                continue
-            else:
-                if i != j:
-                    raise self.make_syntax_error(
-                        f"Input {name!r} at pos {i} repeated in output at different pos {j}")
+            for j, name2 in enumerate(outp):
+                if name == name2:
+                    if i != j:
+                        raise self.make_syntax_error(
+                            f"Input {name!r} at pos {i} repeated in output at different pos {j}")
+                    break
 
-    def stack_effect(self) -> tuple[list[str], list[str]]:
+    def stack_effect(self) -> tuple[list[Effect], list[Effect]]:
         # '(' [inputs] '--' [outputs] ')'
         if self.expect(lx.LPAREN):
             inp = self.inputs() or []
@@ -144,8 +161,8 @@ def stack_effect(self) -> tuple[list[str], list[str]]:
                     return inp, outp
         raise self.make_syntax_error("Expected stack effect")
 
-    def inputs(self) -> list[str] | None:
-        # input (, input)*
+    def inputs(self) -> list[Effect] | None:
+        # input (',' input)*
         here = self.getpos()
         if inp := self.input():
             near = self.getpos()
@@ -157,27 +174,25 @@ def inputs(self) -> list[str] | None:
         self.setpos(here)
         return None
 
-    def input(self) -> str | None:
-        # IDENTIFIER
+    @contextual
+    def input(self) -> Effect | None:
+        # IDENTIFIER '/' INTEGER (CacheEffect)
+        # IDENTIFIER (StackEffect)
         if (tkn := self.expect(lx.IDENTIFIER)):
-            if self.expect(lx.LBRACKET):
-                if arg := self.expect(lx.IDENTIFIER):
-                    if self.expect(lx.RBRACKET):
-                        return f"{tkn.text}[{arg.text}]"
-                    if self.expect(lx.TIMES):
-                        if num := self.expect(lx.NUMBER):
-                            if self.expect(lx.RBRACKET):
-                                return f"{tkn.text}[{arg.text}*{num.text}]"
-                raise self.make_syntax_error("Expected argument in brackets", tkn)
-
-            return tkn.text
-        if self.expect(lx.CONDOP):
-            while self.expect(lx.CONDOP):
-                pass
-            return "??"
-        return None
+            if self.expect(lx.DIVIDE):
+                if num := self.expect(lx.NUMBER):
+                    try:
+                        size = int(num.text)
+                    except ValueError:
+                        raise self.make_syntax_error(
+                            f"Expected integer, got {num.text!r}")
+                    else:
+                        return CacheEffect(tkn.text, size)
+                raise self.make_syntax_error("Expected integer")
+            else:
+                return StackEffect(tkn.text)
 
-    def outputs(self) -> list[str] | None:
+    def outputs(self) -> list[Effect] | None:
         # output (, output)*
         here = self.getpos()
         if outp := self.output():
@@ -190,8 +205,10 @@ def outputs(self) -> list[str] | None:
         self.setpos(here)
         return None
 
-    def output(self) -> str | None:
-        return self.input()  # TODO: They're not quite the same.
+    @contextual
+    def output(self) -> Effect | None:
+        if (tkn := self.expect(lx.IDENTIFIER)):
+            return StackEffect(tkn.text)
 
     @contextual
     def super_def(self) -> Super | None:
@@ -216,24 +233,35 @@ def ops(self) -> list[str] | None:
     @contextual
     def family_def(self) -> Family | None:
         if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "family":
+            size = None
             if self.expect(lx.LPAREN):
                 if (tkn := self.expect(lx.IDENTIFIER)):
+                    if self.expect(lx.COMMA):
+                        if not (size := self.expect(lx.IDENTIFIER)):
+                            raise self.make_syntax_error(
+                                "Expected identifier")
                     if self.expect(lx.RPAREN):
                         if self.expect(lx.EQUALS):
+                            if not self.expect(lx.LBRACE):
+                                raise self.make_syntax_error("Expected {")
                             if members := self.members():
-                                if self.expect(lx.SEMI):
-                                    return Family(tkn.text, members)
+                                if self.expect(lx.RBRACE) and self.expect(lx.SEMI):
+                                    return Family(tkn.text, size.text if size else "", members)
         return None
 
     def members(self) -> list[str] | None:
         here = self.getpos()
         if tkn := self.expect(lx.IDENTIFIER):
-            near = self.getpos()
-            if self.expect(lx.COMMA):
-                if rest := self.members():
-                    return [tkn.text] + rest
-            self.setpos(near)
-            return [tkn.text]
+            members = [tkn.text]
+            while self.expect(lx.COMMA):
+                if tkn := self.expect(lx.IDENTIFIER):
+                    members.append(tkn.text)
+                else:
+                    break
+            peek = self.peek()
+            if not peek or peek.kind != lx.RBRACE:
+                raise self.make_syntax_error("Expected comma or right paren")
+            return members
         self.setpos(here)
         return None
 
@@ -274,5 +302,5 @@ def c_blob(self) -> list[lx.Token]:
         filename = None
         src = "if (x) { x.foo; // comment\n}"
     parser = Parser(src, filename)
-    x = parser.inst_def()
+    x = parser.inst_def() or parser.super_def() or parser.family_def()
     print(x)



More information about the Python-checkins mailing list