Skip to content

Commit

Permalink
Divide signatures (Chia-Network#15)
Browse files Browse the repository at this point in the history
* Add bip32 keys and test vectors
* Signature division working, identical on python and cpp
  • Loading branch information
mariano54 authored Sep 7, 2018
1 parent 91e1b8e commit 7964d46
Show file tree
Hide file tree
Showing 12 changed files with 264 additions and 38 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,8 @@ keybase team request-access chia_network
* Remove unnecessary dependency files
* Constant time and side channel attacks
* Adaptor signatures / Blind signatures
* More tests vectors (failed verifications, etc)


### Test vectors
Test vectors can be found in src/test.cpp and python-impl/tests.py.
1 change: 0 additions & 1 deletion python-impl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,4 @@ as BLS signatures and aggregation. Use for reference / educational purposes only
For a good introduction to pairings, read [Pairings for Beginners](http://www.craigcostello.com.au/pairings/PairingsForBeginners.pdf) by Craig Costello.

### TODO
* Signature division
* Fast algorithm for final exponentiation
20 changes: 18 additions & 2 deletions python-impl/aggregation_info.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from util import hash256, hash_pks
from copy import deepcopy


class AggregationInfo:
Expand Down Expand Up @@ -33,15 +34,30 @@ def __lt__(self, other):
other.public_keys[i])])
for i in range(len(other.public_keys))]

for i in range(len(combined)):
if i >= len(combined_other):
for i in range(max(len(combined), len(combined_other))):
if i == len(combined):
return True
if i == len(combined_other):
return False
if combined[i] < combined_other[i]:
return True
if combined_other[i] < combined[i]:
return False
return True

def __str__(self):
ret = ""
for key, value in self.tree.items():
ret += ("(" + key[0].hex() + "," + key[1].serialize().hex()
+ "):\n" + hex(value) + "\n")
return ret

def __deepcopy__(self, memo):
new_tree = deepcopy(self.tree, memo)
new_mh = deepcopy(self.message_hashes, memo)
new_pubkeys = deepcopy(self.public_keys, memo)
return AggregationInfo(new_tree, new_mh, new_pubkeys)

@staticmethod
def from_msg_hash(public_key, message_hash):
tree = {}
Expand Down
11 changes: 6 additions & 5 deletions python-impl/bls.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def aggregate_sigs_secure(signatures, public_keys, message_hashes):
"""
if (len(signatures) != len(public_keys) or
len(public_keys) != len(message_hashes)):
raise "Invalid number of keys"
raise Exception("Invalid number of keys")
mh_pub_sigs = [(message_hashes[i], public_keys[i], signatures[i])
for i in range(len(signatures))]

Expand Down Expand Up @@ -68,7 +68,8 @@ def aggregate_sigs(signatures):

for signature in signatures:
if signature.aggregation_info.empty():
raise "Each signature must have a valid aggregation info"
raise Exception("Each signature must have a valid aggregation "
+ "info")
public_keys.append(signature.aggregation_info.public_keys)
message_hashes.append(signature.aggregation_info.message_hashes)

Expand Down Expand Up @@ -129,7 +130,7 @@ def aggregate_sigs(signatures):
sort_keys_sorted.sort()
sorted_public_keys = [pk for (mh, pk) in sort_keys_sorted]

computed_Ts = hash_pks(len(sorted_public_keys), sorted_public_keys)
computed_Ts = hash_pks(len(colliding_sigs), sorted_public_keys)

# Raise each sig to a power of each t,
# and multiply all together into agg_sig
Expand Down Expand Up @@ -206,7 +207,7 @@ def aggregate_public_keys(public_keys, secure):
Aggregates public keys together
"""
if len(public_keys) < 1:
raise "Invalid number of keys"
raise Exception("Invalid number of keys")
public_keys.sort()

computed_Ts = hash_pks(len(public_keys), public_keys)
Expand All @@ -228,7 +229,7 @@ def aggregate_private_keys(private_keys, public_keys, secure):
Aggregates private keys together
"""
if secure and len(private_keys) != len(public_keys):
raise "Invalid number of keys"
raise Exception("Invalid number of keys")

priv_pub_keys = [(public_keys[i], private_keys[i])
for i in range(len(private_keys))]
Expand Down
33 changes: 17 additions & 16 deletions python-impl/ec.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import namedtuple
from fields import Fq, Fq2, Fq6, Fq12, FieldExtBase
import bls12381
import random
from copy import deepcopy
from util import hash256

# Struct for elliptic curve parameters
Expand All @@ -21,7 +21,7 @@ def __init__(self, x, y, infinity, ec=default_ec):
if (not isinstance(x, Fq) and not isinstance(x, FieldExtBase) or
(not isinstance(y, Fq) and not isinstance(y, FieldExtBase)) or
type(x) != type(y)):
raise "x,y should be field elements"
raise Exception("x,y should be field elements")
self.FE = type(x)
self.x = x
self.y = y
Expand All @@ -43,7 +43,7 @@ def __add__(self, other):
if other == 0:
return self
if not isinstance(other, AffinePoint):
raise "Incorrect object"
raise Exception("Incorrect object")

return add_points(self, other, self.ec, self.FE)

Expand Down Expand Up @@ -107,6 +107,12 @@ def serialize(self):

return bytes(output)

def __deepcopy__(self, memo):
return AffinePoint(deepcopy(self.x, memo),
deepcopy(self.y, memo),
self.infinity,
self.ec)


class JacobianPoint:
"""
Expand All @@ -118,7 +124,7 @@ def __init__(self, x, y, z, infinity, ec=default_ec):
if (not isinstance(x, Fq) and not isinstance(x, FieldExtBase) or
(not isinstance(y, Fq) and not isinstance(y, FieldExtBase)) or
(not isinstance(z, Fq) and not isinstance(z, FieldExtBase))):
raise "x,y should be field elements"
raise Exception("x,y should be field elements")
self.FE = type(x)
self.x = x
self.y = y
Expand Down Expand Up @@ -176,6 +182,13 @@ def to_affine(self):
def serialize(self):
return self.to_affine().serialize()

def __deepcopy__(self, memo):
return JacobianPoint(deepcopy(self.x, memo),
deepcopy(self.y, memo),
deepcopy(self.z, memo),
self.infinity,
self.ec)


def y_for_x(x, ec=default_ec, FE=Fq):
"""
Expand Down Expand Up @@ -344,18 +357,6 @@ def scalar_mult_jacobian(c, p1, ec=default_ec, FE=Fq):
return result


def order(ec=default_ec):
return ec.n


def rand_scalar(ec=default_ec):
return random.randrange(1, ec.n)


def rand_field_element(ec=default_ec):
return Fq(ec.q, random.randrange(1, ec.q))


def generator_Fq(ec=default_ec):
return AffinePoint(ec.gx, ec.gy, False, ec)

Expand Down
17 changes: 15 additions & 2 deletions python-impl/fields.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from copy import deepcopy


class Fq(int):
"""
Represents an element of a finite field mod a prime q.
Expand Down Expand Up @@ -143,6 +146,9 @@ def modsqrt(self):
t = (t * c) % self.Q
R = (R * b) % self.Q

def __deepcopy__(self, memo):
return Fq(self.Q, int(self))

def getint():
return super()

Expand Down Expand Up @@ -174,12 +180,12 @@ def __new__(cls, Q, *args):
args[1].extension
except AttributeError:
if len(args) != 2:
raise "Invalid number of arguments"
raise Exception("Invalid number of arguments")
arg_extension = 1
new_args = [Fq(Q, a) for a in args]
if arg_extension != 1:
if len(args) != cls.embedding:
raise "Invalid number of arguments"
raise Exception("Invalid number of arguments")
for arg in new_args:
assert(arg.extension == arg_extension)
assert all(isinstance(arg, cls.basefield
Expand Down Expand Up @@ -345,6 +351,13 @@ def from_fq(cls, Q, fq):
ret.set_root(r)
return ret

def __deepcopy__(self, memo):
cls = type(self)
ret = super().__new__(cls, (deepcopy(a, memo) for a in self))
ret.Q = self.Q
ret.root = self.root
return ret


class Fq2(FieldExtBase):
# Fq2 is constructed as Fq(u) / (u2 - β) where β = -1
Expand Down
16 changes: 13 additions & 3 deletions python-impl/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
hash_to_point_prehashed_Fq2, y_for_x,
AffinePoint)
from fields import Fq
from copy import deepcopy
from aggregation_info import AggregationInfo
from signature import BLSSignature

Expand Down Expand Up @@ -43,6 +44,12 @@ def get_fingerprint(self):
def serialize(self):
return self.value.serialize()

def __eq__(self, other):
return self.value.serialize() == other.value.serialize()

def __hash__(self):
return int.from_bytes(self.value.serialize(), "big")

def __lt__(self, other):
return self.value.serialize() < other.value.serialize()

Expand All @@ -52,6 +59,9 @@ def __str__(self):
def __repr__(self):
return "BLSPublicKey(" + self.value.to_affine().__repr__() + ")"

def __deepcopy__(self, memo):
return BLSPublicKey.from_g1(deepcopy(self.value, memo))


class BLSPrivateKey:
"""
Expand Down Expand Up @@ -128,7 +138,7 @@ def from_seed(seed):

def private_child(self, i):
if (self.depth >= 255):
raise "Cannot go further than 255 levels"
raise Exception("Cannot go further than 255 levels")

# Hardened keys have i >= 2^31. Non-hardened have i < 2^31
hardened = (i >= (2 ** 31))
Expand Down Expand Up @@ -193,11 +203,11 @@ def from_bytes(serialized):

def public_child(self, i):
if (self.depth >= 255):
raise "Cannot go further than 255 levels"
raise Exception("Cannot go further than 255 levels")

# Hardened keys have i >= 2^31. Non-hardened have i < 2^31
if i >= (2 ** 31):
raise "Cannot derive hardened children from public key"
raise Exception("Cannot derive hardened children from public key")

hmac_input = self.public_key.serialize() + i.to_bytes(4, "big")
i_left = hmac256(hmac_input + bytes([0]), self.chain_code)
Expand Down
71 changes: 69 additions & 2 deletions python-impl/signature.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ec import default_ec, y_for_x, AffinePoint
from ec import default_ec, y_for_x, AffinePoint, JacobianPoint
from fields import Fq, Fq2
from copy import deepcopy


class BLSSignature:
Expand Down Expand Up @@ -28,9 +29,76 @@ def from_bytes(buffer, aggregation_info=None):
def from_g2(g2_el, aggregation_info=None):
return BLSSignature(g2_el, aggregation_info)

def divide_by(self, divisor_signatures):
"""
Signature division (elliptic curve subtraction). This is useful if
you have already verified parts of the tree, since verification
of the resulting quotient signature will be faster (less pairings
have to be perfomed).
This function Divides an aggregate signature by other signatures
in the aggregate trees. A signature can only be divided if it is
part of the subset, and all message/public key pairs in the
aggregationInfo for the divisor signature are unique. i.e you cannot
divide s1 / s2, if s2 is an aggregate signature containing m1,pk1,
which is also present somewhere else in s1's tree. Note, s2 itself
does not have to be unique.
"""
message_hashes_to_remove = []
pubkeys_to_remove = []
prod = JacobianPoint(Fq2.one(default_ec.q), Fq2.one(default_ec.q),
Fq2.zero(default_ec.q), True, default_ec)
for divisor_sig in divisor_signatures:
pks = divisor_sig.aggregation_info.public_keys
message_hashes = divisor_sig.aggregation_info.message_hashes
if len(pks) != len(message_hashes):
raise Exception("Invalid aggregation info")

for i in range(len(pks)):
divisor = divisor_sig.aggregation_info.tree[
(message_hashes[i], pks[i])]
try:
dividend = self.aggregation_info.tree[
(message_hashes[i], pks[i])]
except KeyError:
raise Exception("Signature is not a subset")
if i == 0:
quotient = (Fq(default_ec.n, dividend)
/ Fq(default_ec.n, divisor))
else:
# Makes sure the quotient is identical for each public
# key, which means message/pk pair is unique.
new_quotient = (Fq(default_ec.n, dividend)
/ Fq(default_ec.n, divisor))
if quotient != new_quotient:
raise Exception("Cannot divide by aggregate signature,"
+ "msg/pk pairs are not unique")
message_hashes_to_remove.append(message_hashes[i])
pubkeys_to_remove.append(pks[i])
prod += (divisor_sig.value * -quotient)
copy = BLSSignature(deepcopy(self.value + prod),
deepcopy(self.aggregation_info))

for i in range(len(message_hashes_to_remove)):
a = message_hashes_to_remove[i]
b = pubkeys_to_remove[i]
if (a, b) in copy.aggregation_info.tree:
del copy.aggregation_info.tree[(a, b)]
sorted_keys = list(copy.aggregation_info.tree.keys())
sorted_keys.sort()
copy.aggregation_info.message_hashes = [t[0] for t in sorted_keys]
copy.aggregation_info.public_keys = [t[1] for t in sorted_keys]
return copy

def set_aggregation_info(self, aggregation_info):
self.aggregation_info = aggregation_info

def __eq__(self, other):
return self.value.serialize() == other.value.serialize()

def __hash__(self):
return int.from_bytes(self.value.serialize(), "big")

def __lt__(self, other):
return self.value.serialize() < other.value.serialize()

Expand Down Expand Up @@ -59,4 +127,3 @@ def __repr__(self):
See the License for the specific language governing permissions and
limitations under the License.
"""

Loading

0 comments on commit 7964d46

Please sign in to comment.