Merge pull request #1432 from YosysHQ/eddie/fix1427

Refactor peepopt_dffmux and be sensitive to \init when trimming
diff --git a/passes/pmgen/peepopt_dffmux.pmg b/passes/pmgen/peepopt_dffmux.pmg
index c88a522..bfd155c 100644
--- a/passes/pmgen/peepopt_dffmux.pmg
+++ b/passes/pmgen/peepopt_dffmux.pmg
@@ -8,21 +8,23 @@
 	select GetSize(port(dff, \D)) > 1
 endmatch
 
+code sigD
+	sigD = port(dff, \D);  // default: the dff's D input; replaced by the rstmux's non-constant input when an rstmux matches
+endcode
+
 match rstmux
 	select rstmux->type == $mux
 	select GetSize(port(rstmux, \Y)) > 1
-	index <SigSpec> port(rstmux, \Y) === port(dff, \D)
+	index <SigSpec> port(rstmux, \Y) === sigD
 	choice <IdString> BA {\B, \A}
 	select port(rstmux, BA).is_fully_const()
 	set rstmuxBA BA
-	optional
+	semioptional
 endmatch
 
 code sigD
 	if (rstmux)
 		sigD = port(rstmux, rstmuxBA == \B ? \A : \B);
-	else
-		sigD = port(dff, \D);
 endcode
 
 match cemux
@@ -32,66 +34,112 @@
 	choice <IdString> AB {\A, \B}
 	index <SigSpec> port(cemux, AB) === port(dff, \Q)
 	set cemuxAB AB
+	semioptional
 endmatch
 
 code
-	SigSpec D = port(cemux, cemuxAB == \A ? \B : \A);
-	SigSpec Q = port(dff, \Q);
+	if (!cemux && !rstmux)
+		reject;
+endcode
+
+code
 	Const rst;
-	if (rstmux)
-		rst = port(rstmux, rstmuxBA).as_const();
-	int width = GetSize(D);
-
-	SigSpec &ceA = cemux->connections_.at(\A);
-	SigSpec &ceB = cemux->connections_.at(\B);
-	SigSpec &ceY = cemux->connections_.at(\Y);
-	SigSpec &dffD = dff->connections_.at(\D);
-	SigSpec &dffQ = dff->connections_.at(\Q);
-
-	if (D[width-1] == D[width-2]) {
-		did_something = true;
-
-		SigBit sign = D[width-1];
-		bool is_signed = sign.wire;
-		int i;
-		for (i = width-1; i >= 2; i--) {
-			if (!is_signed) {
-				module->connect(Q[i], sign);
-				if (D[i-1] != sign || (rst.size() && rst[i-1] != rst[width-1]))
-					break;
-			}
-			else {
-				module->connect(Q[i], Q[i-1]);
-				if (D[i-2] != sign || (rst.size() && rst[i-1] != rst[width-1]))
-					break;
-			}
-		}
-
-		ceA.remove(i, width-i);
-		ceB.remove(i, width-i);
-		ceY.remove(i, width-i);
-		cemux->fixup_parameters();
-		dffD.remove(i, width-i);
-		dffQ.remove(i, width-i);
-		dff->fixup_parameters();
-
-		log("dffcemux pattern in %s: dff=%s, cemux=%s; removed top %d bits.\n", log_id(module), log_id(dff), log_id(cemux), width-i);
-		accept;
+	SigSpec D;
+	if (cemux) {
+		D = port(cemux, cemuxAB == \A ? \B : \A);
+		if (rstmux)
+			rst = port(rstmux, rstmuxBA).as_const();
+		else
+			rst = Const(State::Sx, GetSize(D));
 	}
 	else {
+		log_assert(rstmux);
+		D = port(rstmux, rstmuxBA == \B ? \A : \B);
+		rst = port(rstmux, rstmuxBA).as_const();
+	}
+	SigSpec Q = port(dff, \Q);
+	int width = GetSize(D);
+
+	SigSpec &dffD = dff->connections_.at(\D);
+	SigSpec &dffQ = dff->connections_.at(\Q);
+	// Per-bit init value of Q; x where the wire carries no init attribute.
+	Const init;
+	for (const auto &b : Q) {
+		auto it = b.wire->attributes.find(\init);
+		init.bits.push_back(it == b.wire->attributes.end() ? State::Sx : it->second[b.offset]);
+	}
+
+	// x-tolerant equality: an undefined (x) value is compatible with anything.
+	auto cmpx = [](State lhs, State rhs) {
+		if (lhs == State::Sx || rhs == State::Sx)
+			return true;
+		return lhs == rhs;
+	};
+
+	// Fold redundant top bits (identical D, e.g. sign extension) into the bit
+	// below them. cmpx() is not transitive, so each candidate bit is compared
+	// with the accumulated (first non-x) rst/init value of the whole folded
+	// group rather than only with its neighbour; otherwise a run such as
+	// init={1,x,0} could be aliased into one bit. At least two bits are kept.
+	State grp_rst = rst[width-1];
+	State grp_init = init[width-1];
+	int i = width-1;
+	while (i > 1) {
+		if (D[i] != D[i-1])
+			break;
+		if (!cmpx(grp_rst, rst[i-1]))
+			break;
+		if (!cmpx(grp_init, init[i-1]))
+			break;
+		if (!cmpx(rst[i], init[i]))
+			break;
+		if (grp_rst == State::Sx)
+			grp_rst = rst[i-1];
+		if (grp_init == State::Sx)
+			grp_init = init[i-1];
+		module->connect(Q[i], Q[i-1]);
+		i--;
+	}
+	if (i < width-1) {
+		did_something = true;
+		if (cemux) {
+			SigSpec &ceA = cemux->connections_.at(\A);
+			SigSpec &ceB = cemux->connections_.at(\B);
+			SigSpec &ceY = cemux->connections_.at(\Y);
+			ceA.remove(i, width-1-i);
+			ceB.remove(i, width-1-i);
+			ceY.remove(i, width-1-i);
+			cemux->fixup_parameters();
+		}
+		if (rstmux) {
+			SigSpec &rstA = rstmux->connections_.at(\A);
+			SigSpec &rstB = rstmux->connections_.at(\B);
+			SigSpec &rstY = rstmux->connections_.at(\Y);
+			rstA.remove(i, width-1-i);
+			rstB.remove(i, width-1-i);
+			rstY.remove(i, width-1-i);
+			rstmux->fixup_parameters();
+		}
+		dffD.remove(i, width-1-i);
+		dffQ.remove(i, width-1-i);
+		dff->fixup_parameters();
+
+		log("dffcemux pattern in %s: dff=%s, cemux=%s, rstmux=%s; removed top %d bits.\n", log_id(module), log_id(dff), log_id(cemux, "n/a"), log_id(rstmux, "n/a"), width-1-i);
+		// Index i now stands for the whole folded group in the constant-bit pass below.
+		rst.bits[i] = grp_rst;
+		init.bits[i] = grp_init;
+		width = i+1;
+	}
+	if (cemux) {
+		SigSpec &ceA = cemux->connections_.at(\A);
+		SigSpec &ceB = cemux->connections_.at(\B);
+		SigSpec &ceY = cemux->connections_.at(\Y);
+
 		int count = 0;
 		for (int i = width-1; i >= 0; i--) {
 			if (D[i].wire)
 				continue;
-			Wire *w = Q[i].wire;
-			auto it = w->attributes.find(\init);
-			State init;
-			if (it != w->attributes.end())
-				init = it->second[Q[i].offset];
-			else
-				init = State::Sx;
-
-			if (init == State::Sx || init == D[i].data) {
+			if (cmpx(rst[i], D[i].data) && cmpx(init[i], D[i].data)) {
 				count++;
 				module->connect(Q[i], D[i]);
 				ceA.remove(i);
@@ -105,9 +153,10 @@
 			did_something = true;
 			cemux->fixup_parameters();
 			dff->fixup_parameters();
-			log("dffcemux pattern in %s: dff=%s, cemux=%s; removed %d constant bits.\n", log_id(module), log_id(dff), log_id(cemux), count);
+			log("dffcemux pattern in %s: dff=%s, cemux=%s, rstmux=%s; removed %d constant bits.\n", log_id(module), log_id(dff), log_id(cemux), log_id(rstmux, "n/a"), count);
 		}
-
-		accept;
 	}
+
+	if (did_something)
+		accept;
 endcode
diff --git a/passes/sat/sat.cc b/passes/sat/sat.cc
index 430bba1..93a4f22 100644
--- a/passes/sat/sat.cc
+++ b/passes/sat/sat.cc
@@ -265,15 +265,18 @@
 				RTLIL::SigSpec rhs = it.second->attributes.at("\\init");
 				log_assert(lhs.size() == rhs.size());
 
+				dict<RTLIL::SigBit,SigBit> seen_init;  // init value kept so far for each lhs bit; a later, different value for the same bit is moved to removed_bits
 				RTLIL::SigSpec removed_bits;
 				for (int i = 0; i < lhs.size(); i++) {
 					RTLIL::SigSpec bit = lhs.extract(i, 1);
-					if (rhs[i] == State::Sx || !satgen.initial_state.check_all(bit)) {
+					if (rhs[i] == State::Sx || !satgen.initial_state.check_all(bit) || seen_init.at(bit, rhs[i]) != rhs[i]) {
 						removed_bits.append(bit);
 						lhs.remove(i, 1);
 						rhs.remove(i, 1);
 						i--;
 					}
+					else
+						seen_init[bit] = rhs[i];  // first accepted init value for this bit wins
 				}
 
 				if (removed_bits.size())
diff --git a/tests/sat/initval.ys b/tests/sat/initval.ys
index 2079d2f..1627a37 100644
--- a/tests/sat/initval.ys
+++ b/tests/sat/initval.ys
@@ -2,3 +2,23 @@
 proc;;
 
 sat -seq 10 -prove-asserts
+
+read_verilog <<EOT
+module gold(input clk, input i, output reg [1:0] o);
+initial o = 2'b10;
+always @(posedge clk)
+   o[0] <= {i,i};  // {i,i} is truncated to the 1-bit o[0]
+endmodule
+
+module gate(input clk, input i, output reg [1:0] o);
+initial o = 2'b10;  // o[1] starts at 1, o[0] at 0
+always @(posedge clk)
+   o[0] <= i;
+always @*
+   o[1] <= o[0];  // o[1] follows o[0], so the two differing init bits conflict
+endmodule
+EOT
+
+proc
+miter -equiv -flatten -make_assert -make_outputs gold gate miter
+sat -seq 1 -falsify -prove-asserts -show-ports miter
diff --git a/tests/various/peepopt.ys b/tests/various/peepopt.ys
index 6bca62e..ee5ad8a 100644
--- a/tests/various/peepopt.ys
+++ b/tests/various/peepopt.ys
@@ -131,8 +131,8 @@
 proc
 equiv_opt -assert peepopt
 design -load postopt
-select -assert-count 1 t:$dff r:WIDTH=5 %i
-select -assert-count 1 t:$mux r:WIDTH=5 %i
+select -assert-count 1 t:$dff r:WIDTH=4 %i
+select -assert-count 1 t:$mux r:WIDTH=4 %i
 select -assert-count 0 t:$dff t:$mux %% t:* %D
 
 ####################
@@ -173,3 +173,41 @@
 select -assert-count 2 t:$mux
 select -assert-count 2 t:$mux r:WIDTH=2 %i
 select -assert-count 0 t:$logic_not t:$dff t:$mux %% t:* %D
+
+####################
+
+design -reset
+read_verilog <<EOT
+module peepopt_dffmuxext_signed_rst_init(input clk, ce, rstn, input signed [1:0] i, output reg signed [3:0] o);
+    initial o <= 4'b0010;  // init value differs from the reset value 4'b1111
+    always @(posedge clk) begin
+        if (ce) o <= i;
+        if (!rstn) o <= 4'b1111;
+    end
+endmodule
+EOT
+
+proc
+# NB: equiv_opt uses equiv_induct, which only performs
+#     the induction step of temporal induction and
+#     skips the base case. That makes it comparable to
+#     `sat -tempinduct-inductonly` rather than to
+#     `sat -tempinduct-baseonly` or `sat -tempinduct`,
+#     which this testcase needs; hence the explicit
+#     miter + `sat -tempinduct` check below.
+#equiv_opt -assert peepopt
+
+design -save gold
+peepopt
+wreduce
+design -stash gate
+design -import gold -as gold
+design -import gate -as gate
+miter -equiv -flatten -make_assert -make_outputs gold gate miter
+sat -tempinduct -verify -prove-asserts -show-ports miter
+
+design -load gate
+select -assert-count 1 t:$dff r:WIDTH=4 %i
+select -assert-count 2 t:$mux
+select -assert-count 2 t:$mux r:WIDTH=4 %i
+select -assert-count 0 t:$logic_not t:$dff t:$mux %% t:* %D