blob: e68f56669943caf992727784e494465392090be8 [file] [log] [blame]
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2017-2020 The Project X-Ray Authors.
#
# Use of this source code is governed by a ISC-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/ISC
#
# SPDX-License-Identifier: ISC
'''
This script solves the fuzzing problem by computing a least-mean-square
solution of an overdetermined linear equation system.
The advantages of this method are:
- Ability to detect negative correlations (tags which require clearing bits).
- Ability to detect partial tag <-> bit correlations. These occur when, for a
small number of specimens, a tag is reported as "1" but in fact it is not,
due to the way Vivado interprets requested features and encodes them into
the bitstream.
- Ease of detecting tags with no corresponding bits by evaluating the
solution error.
The solution is computed using the Tikhonov regularization scheme to ensure
numerical stability. The parameter -a can be used to vary the regularization
factor.
By default each tag is solved separately (best results); alternatively all
tags can be solved at once using --all (not recommended).
For each tag a vector of weights is calculated. Each weight corresponds to one
bit. Positive values indicate positive correlation and negative values
negative correlation.
Each weight vector is normalized so that its maximum absolute weight is equal
to one.
The parameter -t sets the threshold for those weights. Bits whose weights are
above the threshold or below the negated threshold are output as candidate
bits.
For each weight vector a solution error is computed. If the error exceeds
the threshold specified using the -e parameter, the tag is considered to
have no bits.
The option -m can be used to remove the bits found for the specified tag from
all other tags. This allows removing the bits of an "IS_BLOCK_IN_USE"-type tag
from other tags responsible for enabling other features of that block.
'''
import sys
import os
import argparse
import itertools
import json
import numpy as np
import numpy.linalg as linalg
from prjxray.util import OpenSafeFile
# =============================================================================
def load_data(file_name, tagfilter=lambda tag: True, address_map=None):
    """
    Loads data generated by the segmaker.

    The file is parsed line by line: a line starting with "seg" opens a new
    segment, and lines starting with "bit" / "tag" add a bit name or a
    (tag name, integer value) pair to the current segment. Lines before the
    first "seg" line are ignored.

    Parameters
    ----------
    file_name:
        Name of the text file with data.
    tagfilter:
        A function for filtering tags. Should return True or False; tags for
        which it returns a false value are skipped.
    address_map:
        A dict indexed by tuples (address, offset) containing a list
        of tile names. When given, a segment name "<hex addr>_<offset>"
        found in the map is replaced by its tile name(s) joined with "_or_".

    Returns
    -------
    A list of dicts. Each contains:
     - "seg": Segment name, prefixed with "<file_name>:"
     - "bit": A list of bit names
     - "tag": A list of tuples (tag name, tag value)
    Segments without any accepted tag are dropped.
    """
    all_segdata = []
    segdata = None

    def store(seg):
        # Only segments with at least one accepted tag are kept.
        if seg is not None and seg["tag"]:
            all_segdata.append(seg)

    def map_segname(segname):
        # Translate "<baseaddr>_<offset>" into tile name(s) when known.
        if address_map is None:
            return segname
        parts = segname.split("_")
        address = (int(parts[0], base=16), int(parts[1]))
        if address in address_map:
            return "_or_".join(address_map[address])
        return segname

    with OpenSafeFile(file_name, "r") as fp:
        # Iterate lazily rather than loading the whole file with readlines().
        for line in fp:
            line = line.strip()

            # Segment tag: finish the previous segment and start a new one.
            if line.startswith("seg"):
                store(segdata)
                segname = map_segname(line.split()[1])
                segdata = {
                    "seg": file_name + ":" + segname,
                    "bit": [],
                    "tag": []
                }
                continue

            # Nothing to attach bits/tags to before the first segment.
            if segdata is None:
                continue

            if line.startswith("bit"):
                segdata["bit"].append(line.split()[1])
            elif line.startswith("tag"):
                fields = line.split()
                if tagfilter(fields[1]):
                    segdata["tag"].append((fields[1], int(fields[2])))

    # Store the last segment if any
    store(segdata)
    return all_segdata
def write_segbits(file_name, all_tags, all_bits, W):
    """
    Writes solution to a raw database file.

    One line is written per row of W: the tag name followed by the names of
    its candidate bits. Bits with a negative weight are prefixed with "!",
    and bits with a zero weight are omitted. A tag without any candidate bit
    is written with the placeholder "<0 candidates>".

    Parameters
    ----------
    file_name:
        Name of the .rdb file.
    all_tags:
        List of considered tags (one per row of W).
    all_bits:
        List of considered bits (one per column of W).
    W:
        Matrix with binary solution.
    """

    def describe(weights):
        # Space-separated candidate list for a single tag.
        names = [
            ("!" if w < 0 else "") + all_bits[c]
            for c, w in enumerate(weights)
            if w < 0 or w > 0
        ]
        return " ".join(names) if names else "<0 candidates>"

    text = "".join(
        all_tags[r] + " " + describe(weights) + "\n"
        for r, weights in enumerate(W))

    with OpenSafeFile(file_name, "w") as fp:
        fp.write(text)
def dump_results(fp, all_tags, all_bits, W, X, E, tag_stats=None):
    """
    Dumps solution results to an open file in a nice readable format.

    The output starts with the names of the used bits written vertically
    (with "_" shown as "|"). It is followed by one row per tag with one
    column per bit: "1" (positive candidate), "0" (negative candidate) or
    "-" (not a candidate). Tags without any candidate bit are flagged with
    "(!)". Each row ends with the optional tag statistics, the minimum and
    maximum raw weight, and the solution error. Bits that are not a
    candidate for any tag are omitted.

    Parameters
    ----------
    fp:
        An open file or stream
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    W:
        Matrix with binary solution.
    X:
        Matrix with raw solution (floats).
    E:
        Vector with solution errors.
    tag_stats:
        Tag statistics: a dict mapping tag name to a (count0, count1) tuple.
    """
    pad = max((len(tag) for tag in all_tags), default=0)
    # Every row starts with the tag name padded to pad + 1 plus a 4-character
    # flag field; the header uses the same prefix width so columns line up.
    prefix_len = pad + 1 + 4

    # A bit column is hidden when no tag has it as a candidate.
    skip_bit = [bool((W[:, i] == 0).all()) for i in range(len(all_bits))]
    shown_bits = [bit for bit, skip in zip(all_bits, skip_bit) if not skip]

    # Header height: at least 6 rows, more when a bit name is longer, so
    # that no bit name gets truncated.
    bit_len = max([6] + [len(bit) for bit in shown_bits])
    columns = [bit.ljust(bit_len).replace("_", "|") for bit in shown_bits]

    lines = []

    # Bit names
    for i in range(bit_len):
        line = " " * prefix_len + "".join(col[i] for col in columns)
        if i == bit_len - 1 and tag_stats is not None:
            line += " #0 #1 "
        lines.append(line)

    # Tags and bit values
    for r, tag in enumerate(all_tags):
        weights = W[r, :]
        line = tag.ljust(pad + 1)
        # The flag field is always 4 characters wide to keep alignment.
        line += "(!) " if (weights == 0).all() else "    "
        for c, b in enumerate(weights):
            if skip_bit[c]:
                continue
            if b < 0:
                line += "0"
            elif b > 0:
                line += "1"
            else:
                line += "-"
        if tag_stats is not None:
            line += " %4d|%4d" % tag_stats[tag]
        x_min = np.min(X[r, :])
        x_max = np.max(X[r, :])
        line += " lo=%+.3f hi=%+.3f e=%.3f" % (x_min, x_max, E[r])
        lines.append(line)

    lines.append("")

    # Write
    fp.write("".join(line + "\n" for line in lines))
def dump_solution_to_csv(fp, all_tags, all_bits, X):
    """
    Dumps solution data to CSV.

    The first row holds an empty corner cell followed by one column per bit.
    Every following row holds a tag name and its raw weights formatted with
    "%+e".

    Parameters
    ----------
    fp:
        An open file or stream
    all_tags:
        List of considered tags (one per row of X).
    all_bits:
        List of considered bits.
    X:
        Matrix with raw solution (floats).
    """
    header = [""] + list(all_bits)
    fp.write(",".join(header) + "\n")

    for r, tag in enumerate(all_tags):
        cells = [tag] + ["%+e" % X[r, c] for c in range(X.shape[1])]
        fp.write(",".join(cells) + "\n")
def dump_correlation_report(
        fp, all_tags, all_bits, W, C, correlation_exceptions):
    """
    Writes a report listing, per tag, the bits whose correlation with the
    tag is not perfect, together with the specimens that break it.

    Parameters
    ----------
    fp:
        An open file or stream.
    all_tags:
        List of considered tags (rows of W and C).
    all_bits:
        List of considered bits (columns of W and C).
    W:
        Matrix with binary solution; its sign selects "+" or "-".
    C:
        Matrix with correlation factors (printed as a percentage).
    correlation_exceptions:
        Dict of dicts indexed by tag and bit name holding lists of
        (actual value, expected value, segment name) tuples.
    """
    for i, tag in enumerate(all_tags):
        tag_exceptions = correlation_exceptions[tag]
        # Tags with 100% correlation are not reported.
        if not tag_exceptions:
            continue

        fp.write(tag + "\n")
        listed = [
            (j, bit) for j, bit in enumerate(all_bits) if bit in tag_exceptions
        ]
        for j, bit in listed:
            sgn = "+" if W[i, j] > 0 else "-"
            fp.write(
                " bit %s: (%s) %.1f%%\n" % (bit.ljust(6), sgn, C[i, j] * 100.0))
            # Counter-factual cases
            for actual, expected, seg in tag_exceptions[bit]:
                fp.write(
                    " is %d, should be %d - %s\n" % (actual, expected, seg))

        fp.write("\n")
# =============================================================================
def build_matrices(all_tags, all_bits, segdata, bias=0.0):
    """
    Builds matrices for the linear equation system to be solved.

    Parameters
    ----------
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        Constant added to every known entry of B. Entries for tags that a
        segment does not mention stay at 0.0 (no bias applied).

    Returns
    -------
    Tuple (A, B) of float64 arrays:
     - A with shape (len(segdata), len(all_bits)): +1.0 where the segment
       has the bit set, -1.0 otherwise.
     - B with shape (len(segdata), len(all_tags)): +1.0 + bias for a tag
       value > 0, -1.0 + bias otherwise, 0.0 when the tag is absent.
    """
    M = len(segdata)
    N = len(all_bits)
    K = len(all_tags)

    A = np.empty((M, N), dtype=np.float64)
    B = np.zeros((M, K), dtype=np.float64)

    for r, data in enumerate(segdata):
        # A set makes each membership test O(1) instead of a list scan.
        seg_bits = set(data["bit"])
        A[r, :] = [1.0 if bit in seg_bits else -1.0 for bit in all_bits]

        # A later occurrence of the same tag overrides an earlier one.
        seg_tags = {t: x for t, x in data["tag"]}
        for c, tag in enumerate(all_tags):
            if tag in seg_tags:
                B[r, c] = (1.0 if seg_tags[tag] > 0 else -1.0) + bias

    return A, B
def compute_error(A, B, X):
    """
    Computes solution error.

    For every column k the error is the Euclidean (L2) norm of the residual
    A @ X[:, k] - B[:, k].

    Parameters
    ----------
    A:
        Matrix A
    B:
        Matrix B
    X:
        Matrix with computed solution.

    Returns
    -------
    A float64 vector with one error per column of B.
    """
    residual = np.matmul(A, X) - B
    norms = [np.sqrt(np.sum(np.square(column))) for column in residual.T]
    return np.array(norms, dtype=float)
# =============================================================================
def solve_lms(all_tags, all_bits, segdata, bias=0.0):
    """
    Solves using direct least square solution (NumPy)

    Parameters
    ----------
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        Constant added to known tag values (see build_matrices).

    Returns
    -------
    Tuple with:
     - Solution matrix X
     - Error vector.
    """
    A, B = build_matrices(all_tags, all_bits, segdata, bias)
    # Only the solution is needed; lstsq's residuals, rank and singular
    # values are discarded.
    solution = linalg.lstsq(A, B, rcond=None)[0]
    errors = compute_error(A, B, solution)
    return solution, errors
def solve_tichonov(all_tags, all_bits, segdata, bias=0.0, a=0.0):
    """
    Solves using Tikhonov regularization method.

    Solves the regularized normal equations
    (A^T A + a I) X = A^T B
    for X.

    Parameters
    ----------
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        Constant added to known tag values (see build_matrices).
    a:
        Regularization coefficient.

    Returns
    -------
    Tuple with:
     - Solution matrix X
     - Error vector.

    Raises
    ------
    numpy.linalg.LinAlgError if A^T A + a I is singular (possible e.g. for
    a == 0 with rank-deficient A).
    """
    N = len(all_bits)

    # Build matrices
    A, B = build_matrices(all_tags, all_bits, segdata, bias)

    # Tikhonov regularization
    # https://en.wikipedia.org/wiki/Tikhonov_regularization
    AtA = np.matmul(A.T, A)
    AtB = np.matmul(A.T, B)
    # Solving the linear system directly is cheaper and numerically more
    # accurate than forming the explicit inverse and multiplying by it.
    X = np.linalg.solve(AtA + a * np.eye(N), AtB)

    return X, compute_error(A, B, X)
# =============================================================================
def solve_onebyone(all_tags, all_bits, segdata, solver=solve_lms, **kw):
    """
    Solves each tag separately in one-by-one fashion.

    For every tag only the segments that mention it (with any value) are
    passed to the solver. The tag name and the number of such segments are
    printed to stdout.

    Parameters
    ----------
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    solver:
        Solver function.
    **kw:
        Parameters to solver function.

    Returns
    -------
    Tuple with:
     - Solution matrix X (bits x tags)
     - Error vector.
    """
    n_tags = len(all_tags)
    X = np.empty((len(all_bits), n_tags))
    E = np.empty((n_tags))

    for i, tag in enumerate(all_tags):
        relevant = [
            data for data in segdata
            if any(entry[0] == tag for entry in data["tag"])
        ]
        print("%s #%d" % (tag, len(relevant)))

        X_tag, E_tag = solver([tag], all_bits, relevant, **kw)
        X[:, i] = X_tag[:, 0]
        E[i] = E_tag[0]

    return X, E
# =============================================================================
def detect_candidates(X, th, norm=None):
    """
    Detects candidate bits.

    Parameters
    ----------
    X:
        Matrix with solution, shape (bits, tags).
    th:
        Threshold. Weights above +th become +1, weights below -th become -1.
    norm:
        Normalization scheme. None thresholds the raw weights; "max_abs"
        first scales each tag's weights so that the largest absolute weight
        equals one.

    Returns
    -------
    A tuple with:
     - Binary solution matrix W (tags x bits) with values -1, 0 and +1
     - Transposed matrix X (tags x bits, not normalized)
    """
    Xt = np.array(X.T)
    W = np.zeros_like(Xt, dtype=int)

    if norm == "max_abs" and Xt.shape[1] > 0:
        scale = np.max(np.abs(Xt), axis=1)
        # A tag whose weights are all zero would cause a 0/0 division
        # (NaNs plus a RuntimeWarning); dividing by 1 keeps it all-zero.
        scale[scale == 0] = 1.0
        Xt /= scale[:, None]

    W[Xt < -th] = -1
    W[Xt > +th] = +1

    return W, X.T
# =============================================================================
def compute_bit_correlations(tags_to_solve, bits_to_solve, segdata, W):
    """
    Basing on solution given in the matrix W returns a matrix C with
    correlation coefficients of each bit.

    For a candidate bit (W != 0) the coefficient is the fraction of segments
    mentioning the tag in which the bit value matches the tag value (or its
    complement for a negative weight). Non-candidate bits get 0.

    Also returns a dict of dicts indexed by tag names and bit names with
    correlation exceptions - concrete specimen names where the correlation
    does not occur. Each exception is a tuple
    (bit value, expected bit value, segment name).
    """
    C = np.zeros_like(W, dtype=float)
    exceptions = {}

    for i, tag in enumerate(tags_to_solve):
        tag_exceptions = exceptions[tag] = {}

        # Segments mentioning the tag, paired with the tag's first value.
        samples = []
        for data in segdata:
            values = [v for t, v in data["tag"] if t == tag]
            if values:
                samples.append((data, values[0]))

        for j, bit in enumerate(bits_to_solve):
            w = W[i, j]
            if w == 0:
                # Not a candidate bit; no correlation to compute.
                continue

            hits = 0
            for data, value in samples:
                # A negative weight means the bit is expected to be cleared.
                expected = int(1 - value) if w < 0 else int(value)
                actual = 1 if bit in data["bit"] else 0
                if expected == actual:
                    hits += 1
                else:
                    tag_exceptions.setdefault(bit, []).append(
                        (actual, expected, data["seg"]))

            C[i, j] = hits / len(samples)

    return C, exceptions
def compute_tag_stats(all_tags, segdata):
    """
    Counts occurrence of all considered tags.

    Every (tag, value) entry is counted: values > 0 count as 1, all other
    values count as 0.

    Parameters
    ----------
    all_tags:
        Considered tags
    segdata:
        List of segdata used

    Returns
    -------
    A dict indexed by tag name with tuples containing 0 and 1 occurrence count.
    """
    stats = {}
    for tag in all_tags:
        values = [v for data in segdata for t, v in data["tag"] if t == tag]
        ones = sum(1 for v in values if v > 0)
        stats[tag] = (len(values) - ones, ones)
    return stats
def sort_bits(bit_name):
    """
    Utility function for sorting bits.

    Converts a bit name "<frame>_<offset>" into the integer tuple
    (frame, offset) so that bits sort numerically.
    """
    frame, offset = (int(part) for part in bit_name.split("_"))
    return frame, offset
def build_address_map(tilegrid_file):
    """
    Loads the tilegrid and generates a map (baseaddr, offset) -> tile name(s).

    Only tiles with a non-empty "bits" entry containing a "CLB_IO_CLK" bus
    are included. The base address is parsed as hexadecimal and the offset
    as decimal.

    Parameters
    ----------
    tilegrid_file:
        The tilegrid.json file.

    Returns
    -------
    A dict with lists of tile names.
    """
    with OpenSafeFile(tilegrid_file, "r") as fp:
        tilegrid = json.load(fp)

    address_map = {}
    for tile_name, tile_data in tilegrid.items():
        # Skip tiles without bits or without the CLB_IO_CLK bus.
        if "bits" not in tile_data or not len(tile_data["bits"]):
            continue
        if "CLB_IO_CLK" not in tile_data["bits"]:
            continue

        bus = tile_data["bits"]["CLB_IO_CLK"]
        key = (int(bus["baseaddr"], 16), int(bus["offset"]))
        address_map.setdefault(key, []).append(tile_name)

    return address_map
# =============================================================================
class FileOrStream(object):
    """
    Context manager returning either an existing stream or a freshly opened
    file for writing.

    When file_name is None or "-", the given stream (sys.stdout by default)
    is returned and left open on exit. Otherwise file_name is opened in "w"
    mode and closed on exit.
    """

    def __init__(self, file_name, stream=sys.stdout):
        self.file_name = file_name
        self.stream = stream
        # Holds the file object only when this manager opened it itself.
        self.fp = None

    def __enter__(self):
        use_stream = self.file_name is None or self.file_name == "-"
        if not use_stream:
            self.fp = open(self.file_name, "w")
            return self.fp
        return self.stream

    def __exit__(self, exc_typ, exc_val, exc_tb):
        # Never close the caller-provided stream, only a file opened here.
        if self.fp is not None:
            self.fp.close()
# =============================================================================
def _parse_args():
    """Builds the command-line parser and returns the parsed arguments."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument(
        "files",
        nargs="*",
        type=str,
        help="Input file(s) generated by segmaker")
    parser.add_argument(
        "-o",
        type=str,
        default="segbits.rdb",
        help="Output database file (def. segbits.rdb)")
    parser.add_argument(
        "-f",
        type=str,
        default=None,
        help="Tag filter. Processes only tags containing the specified text")
    parser.add_argument(
        "-t", type=float, default=0.95, help="Candidate threshold (def. 0.95)")
    parser.add_argument(
        "-e",
        type=float,
        default=0.1,
        help=
        "Solution error (residual norm) threshold above which a tag is rejected (def. 0.1)"
    )
    parser.add_argument(
        "-a",
        type=float,
        default=0.01,
        help="Regularization coefficient (def. 0.01)")
    parser.add_argument(
        "--all",
        action="store_true",
        help="Solve all tags at once (may give worse results)")
    parser.add_argument(
        "-x",
        type=str,
        default=None,
        help="A CSV file name to write the numerical solution to")
    parser.add_argument(
        "-r",
        type=str,
        default=None,
        help=
        "A text file name to write bit correlation report to. Specify '-' for stdout"
    )
    parser.add_argument(
        "-m",
        type=str,
        default=None,
        help="Mask bits found for this feature in all other features")
    parser.add_argument("-b", type=float, default=0.0, help="Bias")
    parser.add_argument("-no_0", action="store_true", help="Do not output 0s")
    parser.add_argument("-no_1", action="store_true", help="Do not output 1s")

    return parser.parse_args()


def _tilegrid_file():
    """
    Returns the path to tilegrid.json of the selected database and fabric.
    Exits with an error message if a required environment variable is unset.
    """
    names = ("XRAY_DATABASE_DIR", "XRAY_DATABASE", "XRAY_FABRIC")
    missing = [name for name in names if os.getenv(name) is None]
    if missing:
        print("Environment variable(s) not set: " + ", ".join(missing))
        sys.exit(-1)
    return os.path.join(*(os.getenv(name) for name in names), "tilegrid.json")


def _mask_bits(W, tags_to_solve, pattern):
    """
    Modifies W in place. For every tag whose name contains pattern, clears
    (sets to 0) each entry of every other tag that equals that tag's entry
    in the same bit column.
    """
    for i, tag in enumerate(tags_to_solve):
        if pattern not in tag:
            continue
        same = W == W[i, :]
        same[i, :] = False
        W[same] = 0


def main():
    """
    The main: loads segmaker data, solves the tag/bit equation system and
    writes the segbits database plus optional reports.
    """
    args = _parse_args()

    # Build (baseaddr, offset) -> tile name map
    address_map = build_address_map(_tilegrid_file())

    # Compute threshold
    th = args.t

    # Load and filter segdata
    def tagfilter(tag):
        return args.f is None or args.f in tag

    segdata = []
    for name in args.files:
        print(name)
        segdata.extend(load_data(name, tagfilter, address_map))

    # Make list of all bits
    all_bits = set()
    for seg in segdata:
        all_bits |= set(seg["bit"])
    all_bits = sorted(all_bits, key=sort_bits)

    # Detect bits that are always set
    const1_bits = set(all_bits)
    for seg in segdata:
        const1_bits &= set(seg["bit"])

    # Make list of all tags
    all_tags = sorted({tag[0] for seg in segdata for tag in seg["tag"]})

    # Count 0s and 1s for each tag
    tag_count = {}
    for seg in segdata:
        for tag, val in seg["tag"]:
            counts = tag_count.setdefault(tag, [0, 0])
            counts[1 if val > 0 else 0] += 1

    # Identify const0 and const1 tags
    const_tags = {}
    for tag in all_tags:
        if tag_count[tag][0] == 0:
            const_tags[tag] = 1
        if tag_count[tag][1] == 0:
            const_tags[tag] = 0

    const0_tags = [t for t, v in const_tags.items() if v == 0]
    const1_tags = [t for t, v in const_tags.items() if v == 1]

    # Print config
    print("# segs:", len(segdata))
    print("# tags:", len(all_tags))
    print("# bits:", len(all_bits))
    print("threshold: %.2f" % th)

    if not segdata:
        print("No data!")
        sys.exit(-1)
    if not all_tags:
        print("No tags!")
        sys.exit(-1)
    if not all_bits:
        print("No bits!")
        sys.exit(-1)

    if const1_bits:
        # Sorted so that the log output is deterministic.
        print(
            "const 1 bits: " + ", ".join(sorted(const1_bits, key=sort_bits)))
    if const0_tags:
        print("const 0 tags: " + ", ".join(const0_tags))
    if const1_tags:
        print("const 1 tags: " + ", ".join(const1_tags))

    # Data to solve: constant tags and always-set bits carry no information.
    tags_to_solve = [t for t in all_tags if t not in const_tags]
    bits_to_solve = [b for b in all_bits if b not in const1_bits]

    # Without these checks the solver / report code fails later on empty
    # matrices or an empty tag list.
    if not tags_to_solve:
        print("No non-constant tags to solve!")
        sys.exit(-1)
    if not bits_to_solve:
        print("No non-constant bits to solve!")
        sys.exit(-1)

    # Statistics
    tag_stats = compute_tag_stats(tags_to_solve, segdata)

    # Solve
    print("Solving...")
    if args.all:
        X, E = solve_tichonov(
            tags_to_solve, bits_to_solve, segdata, bias=args.b, a=args.a)
    else:
        X, E = solve_onebyone(
            tags_to_solve,
            bits_to_solve,
            segdata,
            solver=solve_tichonov,
            bias=args.b,
            a=args.a)

    # Detect candidate bits (W and X are tags x bits from here on)
    W, X = detect_candidates(X, th, norm="max_abs")

    # Mask
    if args.m is not None:
        print("Masking out %s" % args.m)
        _mask_bits(W, tags_to_solve, args.m)

    # Reject 0s and/or 1s
    if args.no_0:
        W[W < 0] = 0
    if args.no_1:
        W[W > 0] = 0

    # Reject tags with error greater than threshold
    W[E > args.e, :] = 0

    # Compute correlation
    C, correlation_exceptions = compute_bit_correlations(
        tags_to_solve, bits_to_solve, segdata, W)

    # Write segbits
    write_segbits(args.o, tags_to_solve, bits_to_solve, W)

    # Dump to CSV
    if args.x is not None:
        with OpenSafeFile(args.x, "w") as fp:
            dump_solution_to_csv(fp, tags_to_solve, bits_to_solve, X)

    # Dump results
    dump_results(sys.stdout, tags_to_solve, bits_to_solve, W, X, E, tag_stats)

    # Dump correlation report
    if args.r is not None:
        if args.r != "-":
            print("Dumping bit correlation report to '{}'".format(args.r))
        with FileOrStream(args.r, sys.stdout) as fp:
            dump_correlation_report(
                fp, tags_to_solve, bits_to_solve, W, C, correlation_exceptions)
# =============================================================================
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()