summaryrefslogtreecommitdiff
path: root/passes/pmgen/pmgen.py
diff options
context:
space:
mode:
Diffstat (limited to 'passes/pmgen/pmgen.py')
-rw-r--r--passes/pmgen/pmgen.py263
1 files changed, 176 insertions, 87 deletions
diff --git a/passes/pmgen/pmgen.py b/passes/pmgen/pmgen.py
index d9747b06..81052afc 100644
--- a/passes/pmgen/pmgen.py
+++ b/passes/pmgen/pmgen.py
@@ -3,15 +3,42 @@
import re
import sys
import pprint
+import getopt
pp = pprint.PrettyPrinter(indent=4)
-pmgfile = sys.argv[1]
-assert pmgfile.endswith(".pmg")
-prefix = pmgfile[0:-4]
-prefix = prefix.split('/')[-1]
-outfile = sys.argv[2]
-
+prefix = None
+pmgfiles = list()
+outfile = None
+debug = False
+genhdr = False
+
+opts, args = getopt.getopt(sys.argv[1:], "p:o:dg")
+
+for o, a in opts:
+ if o == "-p":
+ prefix = a
+ elif o == "-o":
+ outfile = a
+ elif o == "-d":
+ debug = True
+ elif o == "-g":
+ genhdr = True
+
+if outfile is None:
+ outfile = "/dev/stdout"
+
+for a in args:
+ assert a.endswith(".pmg")
+ if prefix is None and len(args) == 1:
+ prefix = a[0:-4]
+ prefix = prefix.split('/')[-1]
+ pmgfiles.append(a)
+
+assert prefix is not None
+
+current_pattern = None
+patterns = dict()
state_types = dict()
udata_types = dict()
blocks = list()
@@ -77,7 +104,8 @@ def rewrite_cpp(s):
return "".join(t)
-with open(pmgfile, "r") as f:
+def process_pmgfile(f):
+ global current_pattern
while True:
line = f.readline()
if line == "": break
@@ -87,14 +115,31 @@ with open(pmgfile, "r") as f:
if len(cmd) == 0 or cmd[0].startswith("//"): continue
cmd = cmd[0]
+ if cmd == "pattern":
+ if current_pattern is not None:
+ block = dict()
+ block["type"] = "final"
+ block["pattern"] = current_pattern
+ blocks.append(block)
+ line = line.split()
+ assert len(line) == 2
+ assert line[1] not in patterns
+ current_pattern = line[1]
+ patterns[current_pattern] = len(blocks)
+ state_types[current_pattern] = dict()
+ udata_types[current_pattern] = dict()
+ continue
+
+ assert current_pattern is not None
+
if cmd == "state":
m = re.match(r"^state\s+<(.*?)>\s+(([A-Za-z_][A-Za-z_0-9]*\s+)*[A-Za-z_][A-Za-z_0-9]*)\s*$", line)
assert m
type_str = m.group(1)
states_str = m.group(2)
for s in re.split(r"\s+", states_str):
- assert s not in state_types
- state_types[s] = type_str
+ assert s not in state_types[current_pattern]
+ state_types[current_pattern][s] = type_str
continue
if cmd == "udata":
@@ -103,19 +148,20 @@ with open(pmgfile, "r") as f:
type_str = m.group(1)
udatas_str = m.group(2)
for s in re.split(r"\s+", udatas_str):
- assert s not in udata_types
- udata_types[s] = type_str
+ assert s not in udata_types[current_pattern]
+ udata_types[current_pattern][s] = type_str
continue
if cmd == "match":
block = dict()
block["type"] = "match"
+ block["pattern"] = current_pattern
line = line.split()
assert len(line) == 2
- assert line[1] not in state_types
+ assert line[1] not in state_types[current_pattern]
block["cell"] = line[1]
- state_types[line[1]] = "Cell*";
+ state_types[current_pattern][line[1]] = "Cell*";
block["if"] = list()
block["select"] = list()
@@ -158,15 +204,18 @@ with open(pmgfile, "r") as f:
assert False
blocks.append(block)
+ continue
if cmd == "code":
block = dict()
block["type"] = "code"
+ block["pattern"] = current_pattern
+
block["code"] = list()
block["states"] = set()
for s in line.split()[1:]:
- assert s in state_types
+ assert s in state_types[current_pattern]
block["states"].add(s)
while True:
@@ -179,18 +228,37 @@ with open(pmgfile, "r") as f:
block["code"].append(rewrite_cpp(l.rstrip()))
blocks.append(block)
+ continue
-with open(outfile, "w") as f:
- print("// Generated by pmgen.py from {}.pgm".format(prefix), file=f)
- print("", file=f)
+ assert False
- print("#include \"kernel/yosys.h\"", file=f)
- print("#include \"kernel/sigtools.h\"", file=f)
- print("", file=f)
+for fn in pmgfiles:
+ with open(fn, "r") as f:
+ process_pmgfile(f)
+
+if current_pattern is not None:
+ block = dict()
+ block["type"] = "final"
+ block["pattern"] = current_pattern
+ blocks.append(block)
+
+current_pattern = None
+
+if debug:
+ pp.pprint(blocks)
- print("YOSYS_NAMESPACE_BEGIN", file=f)
+with open(outfile, "w") as f:
+ for fn in pmgfiles:
+ print("// Generated by pmgen.py from {}".format(fn), file=f)
print("", file=f)
+ if genhdr:
+ print("#include \"kernel/yosys.h\"", file=f)
+ print("#include \"kernel/sigtools.h\"", file=f)
+ print("", file=f)
+ print("YOSYS_NAMESPACE_BEGIN", file=f)
+ print("", file=f)
+
print("struct {}_pm {{".format(prefix), file=f)
print(" Module *module;", file=f)
print(" SigMap sigmap;", file=f)
@@ -212,17 +280,19 @@ with open(outfile, "w") as f:
print(" int rollback;", file=f)
print("", file=f)
- print(" struct state_t {", file=f)
- for s, t in sorted(state_types.items()):
- print(" {} {};".format(t, s), file=f)
- print(" } st;", file=f)
- print("", file=f)
+ for current_pattern in sorted(patterns.keys()):
+ print(" struct state_{}_t {{".format(current_pattern), file=f)
+ for s, t in sorted(state_types[current_pattern].items()):
+ print(" {} {};".format(t, s), file=f)
+ print(" }} st_{};".format(current_pattern), file=f)
+ print("", file=f)
- print(" struct udata_t {", file=f)
- for s, t in sorted(udata_types.items()):
- print(" {} {};".format(t, s), file=f)
- print(" } ud;", file=f)
- print("", file=f)
+ print(" struct udata_{}_t {{".format(current_pattern), file=f)
+ for s, t in sorted(udata_types[current_pattern].items()):
+ print(" {} {};".format(t, s), file=f)
+ print(" }} ud_{};".format(current_pattern), file=f)
+ print("", file=f)
+ current_pattern = None
for v, n in sorted(ids.items()):
if n[0] == "\\":
@@ -258,20 +328,24 @@ with open(outfile, "w") as f:
print(" }", file=f)
print("", file=f)
- print(" void check_blacklist() {", file=f)
- print(" if (!blacklist_dirty)", file=f)
- print(" return;", file=f)
- print(" blacklist_dirty = false;", file=f)
- for index in range(len(blocks)):
- block = blocks[index]
- if block["type"] == "match":
- print(" if (st.{} != nullptr && blacklist_cells.count(st.{})) {{".format(block["cell"], block["cell"]), file=f)
- print(" rollback = {};".format(index+1), file=f)
- print(" return;", file=f)
- print(" }", file=f)
- print(" rollback = 0;", file=f)
- print(" }", file=f)
- print("", file=f)
+ for current_pattern in sorted(patterns.keys()):
+ print(" void check_blacklist_{}() {{".format(current_pattern), file=f)
+ print(" if (!blacklist_dirty)", file=f)
+ print(" return;", file=f)
+ print(" blacklist_dirty = false;", file=f)
+ for index in range(len(blocks)):
+ block = blocks[index]
+ if block["pattern"] != current_pattern:
+ continue
+ if block["type"] == "match":
+ print(" if (st_{}.{} != nullptr && blacklist_cells.count(st_{}.{})) {{".format(current_pattern, block["cell"], current_pattern, block["cell"]), file=f)
+ print(" rollback = {};".format(index+1), file=f)
+ print(" return;", file=f)
+ print(" }", file=f)
+ print(" rollback = 0;", file=f)
+ print(" }", file=f)
+ print("", file=f)
+ current_pattern = None
print(" SigSpec port(Cell *cell, IdString portname) {", file=f)
print(" return sigmap(cell->getPort(portname));", file=f)
@@ -294,11 +368,13 @@ with open(outfile, "w") as f:
print(" {}_pm(Module *module, const vector<Cell*> &cells) :".format(prefix), file=f)
print(" module(module), sigmap(module) {", file=f)
- for s, t in sorted(udata_types.items()):
- if t.endswith("*"):
- print(" ud.{} = nullptr;".format(s), file=f)
- else:
- print(" ud.{} = {}();".format(s, t), file=f)
+ for current_pattern in sorted(patterns.keys()):
+ for s, t in sorted(udata_types[current_pattern].items()):
+ if t.endswith("*"):
+ print(" ud_{}.{} = nullptr;".format(current_pattern,s), file=f)
+ else:
+ print(" ud_{}.{} = {}();".format(current_pattern, s, t), file=f)
+ current_pattern = None
print(" for (auto cell : module->cells()) {", file=f)
print(" for (auto &conn : cell->connections())", file=f)
print(" add_siguser(conn.second, cell);", file=f)
@@ -328,34 +404,52 @@ with open(outfile, "w") as f:
print(" }", file=f)
print("", file=f)
- print(" void run(std::function<void()> on_accept_f) {", file=f)
- print(" on_accept = on_accept_f;", file=f)
- print(" rollback = 0;", file=f)
- print(" blacklist_dirty = false;", file=f)
- for s, t in sorted(state_types.items()):
- if t.endswith("*"):
- print(" st.{} = nullptr;".format(s), file=f)
- else:
- print(" st.{} = {}();".format(s, t), file=f)
- print(" block_0();", file=f)
- print(" }", file=f)
- print("", file=f)
-
- print(" void run(std::function<void({}_pm&)> on_accept_f) {{".format(prefix), file=f)
- print(" run([&](){on_accept_f(*this);});", file=f)
- print(" }", file=f)
- print("", file=f)
+ for current_pattern in sorted(patterns.keys()):
+ print(" void run_{}(std::function<void()> on_accept_f) {{".format(current_pattern), file=f)
+ print(" on_accept = on_accept_f;", file=f)
+ print(" rollback = 0;", file=f)
+ print(" blacklist_dirty = false;", file=f)
+ for s, t in sorted(state_types[current_pattern].items()):
+ if t.endswith("*"):
+ print(" st_{}.{} = nullptr;".format(current_pattern, s), file=f)
+ else:
+ print(" st_{}.{} = {}();".format(current_pattern, s, t), file=f)
+ print(" block_{}();".format(patterns[current_pattern]), file=f)
+ print(" }", file=f)
+ print("", file=f)
+ print(" void run_{}(std::function<void({}_pm&)> on_accept_f) {{".format(current_pattern, prefix), file=f)
+ print(" run_{}([&](){{on_accept_f(*this);}});".format(current_pattern), file=f)
+ print(" }", file=f)
+ print("", file=f)
+ print(" void run_{}(std::function<void(state_{}_t&)> on_accept_f) {{".format(current_pattern, current_pattern), file=f)
+ print(" run_{}([&](){{on_accept_f(st_{});}});".format(current_pattern, current_pattern), file=f)
+ print(" }", file=f)
+ print("", file=f)
+ print(" void run_{}() {{".format(current_pattern), file=f)
+ print(" run_{}([](){{}});".format(current_pattern, current_pattern), file=f)
+ print(" }", file=f)
+ print("", file=f)
+ current_pattern = None
for index in range(len(blocks)):
block = blocks[index]
print(" void block_{}() {{".format(index), file=f)
+ current_pattern = block["pattern"]
+
+ if block["type"] == "final":
+ print(" on_accept();", file=f)
+ print(" check_blacklist_{}();".format(current_pattern), file=f)
+ print(" }", file=f)
+ if index+1 != len(blocks):
+ print("", file=f)
+ continue
const_st = set()
nonconst_st = set()
restore_st = set()
- for i in range(index):
+ for i in range(patterns[current_pattern], index):
if blocks[i]["type"] == "code":
for s in blocks[i]["states"]:
const_st.add(s)
@@ -378,27 +472,27 @@ with open(outfile, "w") as f:
assert False
for s in sorted(const_st):
- t = state_types[s]
+ t = state_types[current_pattern][s]
if t.endswith("*"):
- print(" {} const &{} YS_ATTRIBUTE(unused) = st.{};".format(t, s, s), file=f)
+ print(" {} const &{} YS_ATTRIBUTE(unused) = st_{}.{};".format(t, s, current_pattern, s), file=f)
else:
- print(" const {} &{} YS_ATTRIBUTE(unused) = st.{};".format(t, s, s), file=f)
+ print(" const {} &{} YS_ATTRIBUTE(unused) = st_{}.{};".format(t, s, current_pattern, s), file=f)
for s in sorted(nonconst_st):
- t = state_types[s]
- print(" {} &{} YS_ATTRIBUTE(unused) = st.{};".format(t, s, s), file=f)
+ t = state_types[current_pattern][s]
+ print(" {} &{} YS_ATTRIBUTE(unused) = st_{}.{};".format(t, s, current_pattern, s), file=f)
if len(restore_st):
print("", file=f)
for s in sorted(restore_st):
- t = state_types[s]
+ t = state_types[current_pattern][s]
print(" {} backup_{} = {};".format(t, s, s), file=f)
if block["type"] == "code":
print("", file=f)
print(" do {", file=f)
- print("#define reject do { check_blacklist(); goto rollback_label; } while(0)", file=f)
- print("#define accept do { on_accept(); check_blacklist(); if (rollback) goto rollback_label; } while(0)", file=f)
+ print("#define reject do {{ check_blacklist_{}(); goto rollback_label; }} while(0)".format(current_pattern), file=f)
+ print("#define accept do {{ on_accept(); check_blacklist_{}(); if (rollback) goto rollback_label; }} while(0)".format(current_pattern), file=f)
print("#define branch do {{ block_{}(); if (rollback) goto rollback_label; }} while(0)".format(index+1), file=f)
for line in block["code"]:
@@ -417,11 +511,11 @@ with open(outfile, "w") as f:
if len(restore_st) or len(nonconst_st):
print("", file=f)
for s in sorted(restore_st):
- t = state_types[s]
+ t = state_types[current_pattern][s]
print(" {} = backup_{};".format(s, s), file=f)
for s in sorted(nonconst_st):
if s not in restore_st:
- t = state_types[s]
+ t = state_types[current_pattern][s]
if t.endswith("*"):
print(" {} = nullptr;".format(s), file=f)
else:
@@ -470,17 +564,12 @@ with open(outfile, "w") as f:
else:
assert False
-
+ current_pattern = None
print(" }", file=f)
print("", file=f)
- print(" void block_{}() {{".format(len(blocks)), file=f)
- print(" on_accept();", file=f)
- print(" check_blacklist();", file=f)
- print(" }", file=f)
print("};", file=f)
- print("", file=f)
- print("YOSYS_NAMESPACE_END", file=f)
-
-# pp.pprint(blocks)
+ if genhdr:
+ print("", file=f)
+ print("YOSYS_NAMESPACE_END", file=f)