K3RN3LCTF 2021 - Game of Secrets

Cryptography – 500 pts (2 solves) – Chall author: Polymero (me)

“John wants to play a game, a game of secrets. Recover his secret or be encrypted.”

Files: gameofsecrets.py, output.txt

This challenge was part of our very first CTF, K3RN3LCTF 2021.

Exploration

The title of the challenge, ‘Game of Secrets’, and the reference to the name ‘John’ were direct hints at the origin of the home-rolled (read: insecure) cryptosystem below: John Conway’s Game of Life. We can see the shape of the inner state by taking a look at the slicer() method. It is a simple 8x8 bit plane (one row per byte), which is updated according to a slight variant of the Game of Life rules:

  1. Cells that contain a ‘1’ are considered alive, whereas ‘0’ cells are considered dead.
  2. An alive cell with 1 or fewer, or 4 or more, neighbours will die out in the next step.
  3. Any dead cell with 2 or 3 neighbours will come back alive in the next step (Conway’s standard rule requires exactly 3, but rule() ignores the cell’s own state).
  4. Any cells not affected by rules 2 or 3 retain their state.

Note that the neighbours are counted over the 3x3 grid around each cell (excluding the cell itself), which wraps around the full 8x8 plane. The key stream is then created by running the above for 12 steps and then XORing the field with the next of the four round keys (in cyclic order), where the round keys are derived from the sha256 hash of the 8-byte master key. Every 48 steps, all round keys are individually re-hashed with sha256 and truncated to 16 hex characters. Every 36 steps, the field is output as an 8-byte key stream block, and the resulting key stream is then simply XORed with the input plaintext to produce the ciphertext.
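As a quick sanity check of the rules above, here is a standalone sketch of one update step on the 8x8 torus. It re-implements the neighbour sum of get_sum() and the behaviour of rule(); the helper names neighbours and step are my own, and the grid is assumed to be a list of 8 rows of 8 bits, as produced by slicer(). It shows the one place where this rule departs from Conway's: a dead cell with exactly two live neighbours is born.

```python
from typing import List

Grid = List[List[int]]

def neighbours(g: Grid, x: int, y: int) -> int:
    # Sum of the 8 surrounding cells, wrapping around the 8x8 plane (like get_sum)
    return sum(g[(y + dy) % 8][(x + dx) % 8]
               for dy in (-1, 0, 1) for dx in (-1, 0, 1)
               if (dx, dy) != (0, 0))

def step(g: Grid) -> Grid:
    # Challenge rule: alive next step iff 2 or 3 neighbours, regardless of own state
    return [[1 if 2 <= neighbours(g, x, y) <= 3 else 0 for x in range(8)]
            for y in range(8)]

# A dead cell at (x=1, y=1) with exactly two live neighbours
g = [[0] * 8 for _ in range(8)]
g[0][0] = g[0][1] = 1
print(step(g)[1][1])  # 1 here; Conway's B3/S23 rule would leave it dead
```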

import os
from hashlib import sha256

class GoS:
    def __init__(self, key=None, spacing=12):
        self.spacing = spacing
        if type(key) == str:
            try:    key = bytes.fromhex(key)
            except: key = None
        if (key is None) or (len(key) != 8) or (type(key) != bytes):
            key = os.urandom(8)
        self.RKs = [ sha256(key).hexdigest()[i:i+16] for i in range(0,256//4,16) ]
        self.ratchet()
        self.state = self.slicer(self.RKs[0])
        self.i = 0
    
    def slicer(self, inp):
        if type(inp) == str:
            inp = bytes.fromhex(inp)
        return [ [ int(i) for i in list('{:08b}'.format(j)) ] for j in inp ]
    
    def ratchet(self):
        self.RKs = [ sha256(rk.encode()).hexdigest()[:16] for rk in self.RKs ]
    
    def update(self):
        rk_plane = self.slicer( self.RKs[ (self.i // self.spacing) % 4 ] )
        for yi in range(8):
            for xi in range(8):
                self.state[yi][xi] = self.state[yi][xi] ^ rk_plane[yi][xi]
                
    def get_sum(self, x, y):
        ret =  [ self.state[(y-1) % 8][i % 8] for i in [x-1, x, x+1] ]
        ret += [ self.state[ y    % 8][i % 8] for i in [x-1,    x+1] ]
        ret += [ self.state[(y+1) % 8][i % 8] for i in [x-1, x, x+1] ]
        return sum(ret)
    
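    def rule(self, ownval, neighsum):
        # ownval is ignored: any cell with 2 or 3 live neighbours is alive next step,
        # any other cell is dead (a B23/S23 variant of Conway's B3/S23)
        if ( neighsum < 2 ) or ( neighsum > 3 ):
            return 0
        return 1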
    
    def tick(self):
        new_state = [ [ 0 for _ in range(8) ] for _ in range(8) ]
        for yi in range(8):
            for xi in range(8):
                new_state[yi][xi] = self.rule( self.state[yi][xi], self.get_sum(xi, yi) )
        self.state = new_state
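        self.i += 1
        # Every 48 ticks: re-hash all four round keys
        if (self.i % (4 * self.spacing)) == 0:
            self.ratchet()
        # Every 12 ticks: XOR the field with the next round key
        if (self.i % self.spacing) == 0:
            self.update()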
        
    def output(self):
        return bytes([int(''.join([str(j) for j in i]),2) for i in self.state]).hex()
                
    def stream(self, nbyt):
        lst = ''
        # ceil(nbyt / 8) output blocks of 8 bytes, each after 3 * spacing = 36 ticks
        for _1 in range(-(-nbyt//8)):
            for _2 in range(3):
                for _3 in range(self.spacing):
                    self.tick()
            lst += self.output()
        return ''.join(lst[:2*nbyt])
    
    def xorstream(self, msgcip):
        if type(msgcip) == str:
            msgcip = bytes.fromhex(msgcip)
        keystream = list(bytes.fromhex(self.stream(len(msgcip))))
        bytstream = list(msgcip)
        return bytes([ bytstream[i] ^ keystream[i] for i in range(len(msgcip)) ]).hex()
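The key schedule in this class boils down to a little arithmetic that the exploit relies on later. The sketch below assumes the default spacing of 12 and numbers output blocks from 0; the function and constant names are illustrative only. Output block k is produced at tick 36(k+1), right after the round key with index ((36(k+1)) // 12) % 4 = 3 - (k % 4) has been XORed in, and between two blocks that use the same round key (four blocks, i.e. 144 ticks apart) the keys are ratcheted exactly three times.

```python
SPACING = 12                  # GoS default spacing
BLOCK_TICKS = 3 * SPACING     # 36 ticks per 8-byte output block
RATCHET_TICKS = 4 * SPACING   # round keys are re-hashed every 48 ticks

def block_tick(k: int) -> int:
    # Value of the tick counter i right after output block k (0-indexed)
    return BLOCK_TICKS * (k + 1)

def round_key_index(k: int) -> int:
    # Index into RKs of the round key XORed into the field just before block k
    return (block_tick(k) // SPACING) % 4

def ratchets_in(t0: int, t1: int) -> int:
    # Brute-force count of ratchet ticks in the interval (t0, t1]
    return sum(1 for t in range(t0 + 1, t1 + 1) if t % RATCHET_TICKS == 0)

print([round_key_index(k) for k in range(8)])  # [3, 2, 1, 0, 3, 2, 1, 0]
# Blocks k and k + 4 use the same round key, three ratchets apart
print({ratchets_in(block_tick(k), block_tick(k + 4)) for k in range(100)})  # {3}
```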

Finally, we are given 9600 bytes of key stream (1200 blocks of 8 bytes) for free. The provided ciphertext is solvable, but with this much known key stream almost any randomly generated key would be vulnerable to the exploit described in the next section.

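import base64
# GoS is the class above; FLAG is a bytes object defined elsewhere in the challenge script

def __main__():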

    gos = GoS()

    print("\n -- Here's your free stream to prepare you for your upcoming game! --\n")
    print(base64.urlsafe_b64encode(bytes.fromhex(gos.stream(1200*8))).decode())
    print('\n -- They are indistinguishable from noise, they are of variable length, and they are the key to your victory.')
    print('    Ladies and Gentlemen, give it up for THE ENCRYPTED PADDED FLAG!!! --\n')
    print(base64.urlsafe_b64encode(bytes.fromhex(
                         gos.xorstream( 
                         os.urandom(int(os.urandom(1).hex(),16)+8) +
                         FLAG +
                         os.urandom(int(os.urandom(1).hex(),16)+8) 
                         ))).decode().rstrip('='))
    print('\nGood Luck! ~^w^~\n')

__main__()
 -- Here's your free stream to prepare you for your upcoming game! --



 -- They are indistinguishable from noise, they are of variable length, and they are the key to your victory.
    Ladies and Gentlemen, give it up for THE ENCRYPTED PADDED FLAG!!! --

Z7slPiJ_t8Aq3Wm4_Yd-uqHXdXrxVfgOjq5tthxXyinrtZfWBhfmSXgLGB14jfZv4_EPcKtEkPB2ITlcBzHrocOuW7QZXAjgwMhzpcOISA4DoteJhw1w-O6AyunM0mdrdXxqTwOwr0jcTArKW06lEqVjDvu8HIxcFORRhDjQxzqyIVPwgXZ_Nsm44ih4knzu1INL1xAnilv1ZnMoAlu-iqwZe5tLhAeZeDf0ZDLWT_6NkzmsvoOcQFwSAjLtGk3u7Zzkaf1xSYAM8SrBejuSGK90q6t-I4Uqrv08mN4ABCWAgLx1uT2jetBMH_Sz7Gb1iytsC9wOzg46Be4d07lKmcyUpd63eDyRranLCzTH1lmFSm5Q0agCwWUcwLZNuiNJK_Y1ZHLnWtXEaoKCW9zHrHHNF_zXVakZlWl66eAZvTkPMhF_SVUZxglXg3FWoRVg-YMmMoAIfiXDYLgDT3cFEryudwacv6jLSMeUMPrt3IL_ktoxOX9ICoiEe3Hk

Good Luck! ~^w^~
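To work with the real output.txt, both base64 strings have to be decoded first. Note the .rstrip('=') in __main__: the encrypted flag is printed without its '=' padding, which has to be restored before Python's base64 module will decode it. A minimal sketch, where the helper names b64_decode_unpadded and to_blocks are hypothetical; the free stream is split into the same 8-byte blocks (as 16-character hex strings) that GoS.output() produces.

```python
import base64
from typing import List

def b64_decode_unpadded(s: str) -> bytes:
    # Restore the '=' padding removed by .rstrip('=') before decoding
    return base64.urlsafe_b64decode(s + '=' * (-len(s) % 4))

def to_blocks(stream: bytes) -> List[str]:
    # Split the key stream into 8-byte blocks as 16-character hex strings,
    # matching the format of GoS.output()
    return [stream[i:i + 8].hex() for i in range(0, len(stream), 8)]

# Round trip on a toy value
enc = base64.urlsafe_b64encode(b'example!!!').decode().rstrip('=')
print(b64_decode_unpadded(enc))     # b'example!!!'
print(to_blocks(bytes(range(16))))  # ['0001020304050607', '08090a0b0c0d0e0f']
```

Applied to the free stream, to_blocks() should yield the 1200 blocks that the recovery scripts work on.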

Exploitation

Cellular automata are fun and all, but there is a slight inherent problem with the Game of Life: the population can die out. When it does, the field is all zeroes and stays that way, since a dead cell without live neighbours remains dead; see the figure below.

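Whenever this occurs, the next round-key XOR turns the empty field into an exact copy of that round key, and if this XOR happens right before an output block (every third XOR), the round key is fully exposed in the key stream! So from here on out, our exploit strategy will be quite straightforward. The same round key is XORed in before every fourth output block, i.e. every 144 steps, during which all round keys are ratcheted exactly three times. Hence, for every 8-byte (16 hex character) output block, we check whether any of its hash derivatives appear later in the key stream, at the blocks belonging to the same round key. If one does, we have recovered a round key derivative some distance into the key stream. If we manage to recover all four round keys, and ratchet them forward to what they will be at the point of flag encryption, we can set up our own local version and simply reproduce the key stream used to encrypt the flag. Easy peasy.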
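The core observation can be reproduced in isolation. This standalone sketch (the helper names and the round key are made up for illustration) runs the challenge's update rule on a lone live cell, which dies after a single step, then XORs the empty field with a round-key plane exactly as update() does. The result is the round key itself, and because an empty field stays empty, the same happens no matter when within the 12-step window the population dies out.

```python
from typing import List

Grid = List[List[int]]

def step(g: Grid) -> Grid:
    # One tick of the challenge rule on the 8x8 torus: alive iff 2 or 3 neighbours
    def nsum(x: int, y: int) -> int:
        return sum(g[(y + dy) % 8][(x + dx) % 8]
                   for dy in (-1, 0, 1) for dx in (-1, 0, 1) if (dx, dy) != (0, 0))
    return [[1 if 2 <= nsum(x, y) <= 3 else 0 for x in range(8)] for y in range(8)]

def plane(hex_key: str) -> Grid:
    # Same layout as GoS.slicer(): one row per byte, most significant bit first
    return [[int(b) for b in f'{byte:08b}'] for byte in bytes.fromhex(hex_key)]

def to_hex(g: Grid) -> str:
    # Same as GoS.output(): pack each row back into a byte
    return bytes(int(''.join(map(str, row)), 2) for row in g).hex()

rk = '0123456789abcdef'            # made-up round key
g = [[0] * 8 for _ in range(8)]
g[3][4] = 1                        # a lone live cell: zero neighbours, so it dies
for _ in range(12):                # one full spacing of ticks
    g = step(g)
exposed = [[a ^ b for a, b in zip(r1, r2)] for r1, r2 in zip(g, plane(rk))]
print(to_hex(exposed))             # 0123456789abcdef
```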

Vulnerability Ghost-Checker

To further visualise this procedure, I wrote a little ghost-checker that shows at which output blocks each round key is exposed. Note that this is a ghost-checker, as it uses information not accessible to an attacker. Also note that in the scripts below I still called the class JSG instead of GoS, but they operate identically.

import matplotlib.pyplot as plt

JSG = GoS  # the scripts below use the old class name

jsg = JSG()
print(jsg.RKs)
n = 1000
out = []
iii = []
fks = []
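kis = [3 - (i % 4) for i in range(n)]  # index of the round key used at each output block
xks = []
for i in range(n):
    # 3 * 12 = 36 ticks per output block, as in GoS.stream()
    for j in range(3):
        for k in range(jsg.spacing):
            jsg.tick()
    xks += [jsg.RKs[(jsg.i // jsg.spacing) % 4]]  # round key just XORed into the field
    out += [jsg.output()]                          # key stream block seen by an attacker
    iii += [jsg.i]
    fks += [jsg.RKs]
    print(i+1, end='\r', flush=True)
# An output block equal to its round key means that key is exposed
vln = [out[i] == xks[i] for i in range(n)]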
print(sum(vln),'        ')
print(jsg.RKs)
plt.figure(figsize=(32,8))
xkslst = [int(i,16) for i in xks]
outlst = [int(i,16) for i in out]
#plt.plot(range(len(outlst)),outlst,c='black',alpha=0.6,zorder=0)
plt.scatter(range(len(xkslst)),xkslst,c='tab:green')
plt.scatter([i for i in range(n) if vln[i]],[xkslst[i] for i in range(n) if vln[i]],c='tab:red',s=500)
for i in range(n):
    if vln[i]:
        plt.text(i,xkslst[i],'{}'.format(kis[i]),fontsize=18,ha='center',va='center',c='white',fontweight='bold')
plt.ylim(0,2**64-1)
plt.xlim(-1,n+1)
plt.show()

Turns out round keys are exposed quite frequently. Let’s see if we can recover them.

Recovering Round Keys

The following script demonstrates the exploit on the output blocks (out) of the local GoS instance from the ghost-checker above.

from hashlib import sha256

def ratchet3(rk):
    # The same round key returns every 4 output blocks = 4 * 36 = 144 ticks,
    # and the keys are ratcheted every 48 ticks, so 144 / 48 = 3 hashes apply
    for _ in range(3):
        rk = sha256(rk.encode()).hexdigest()[:16]
    return rk

found_RKs = []

# For every one of the four round keys
for RKi in range(4):

    res = []
    # Output block k uses round key 3 - (k % 4)
    olst = out[3-RKi::4]

    # For every 8-byte (16 hex character) output block
    for i, kk in enumerate(olst):

        ki = kk

        for j, ok in enumerate(olst[i+1:], start=i+1):

            ki = ratchet3(ki)

            # Check for matching 8-byte blocks
            if ok == ki:

                for idx, blk in [(i, kk), (j, ok)]:
                    entry = [idx*4 + (3-RKi), blk]
                    if entry not in res:
                        res += [entry]

    found_RKs += [res]
    print('RK', RKi)

    for ri in res:
        print(ri)

    print()

RK 0
[79, '282885f5bcf4c433']
[271, '2e75d25cb01d0936']
[587, '759a7cf2fd3e2335']
[679, '25db17884d944100']
[747, '9a26e457a293a11d']

RK 1
[362, '7068f3322c6c8d77']
[382, 'a0d671e3e62923c3']
[646, '446e6b6954879b36']
[746, 'e091e79f2feff66a']
[870, 'b01febfb716b7694']

RK 2
[93, '73f0ba728086aedb']
[113, '302e9178c210d0e4']
[201, '3004a268c27f2cfe']
[289, 'f154d50db8eff779']
[293, 'acf51e7577aa4dc9']
[349, '332cbc6503cbf19d']
[353, 'dcdf9b1cf6004a32']
[641, 'f936915832dd3cbf']
[777, '53bbe6a5a2c71d45']
[789, '06255dc46f70ba54']
[861, '3b10eb5b714d4caf']
[993, 'be4638cf78c6c118']

RK 3
[156, '200311f7c4745a2a']
[436, '7c162a02f10ccc7a']
[588, 'b9ed7d99370d319b']
[800, '8da7a86d6f457676']
[968, '96663c3cb19857ed']
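Since all exposures of one round key lie on a single hash chain, any two blocks listed for the same round key should be linked by the ratchet. A small sketch to double-check this; verify_link is a hypothetical helper, its hash step mirrors GoS.ratchet(), and the block distance must be a multiple of 4 so that both blocks belong to the same round key.

```python
from hashlib import sha256

def ratchet(rk: str) -> str:
    # One ratchet step, as in GoS.ratchet(): hash the hex string, keep 16 hex chars
    return sha256(rk.encode()).hexdigest()[:16]

def verify_link(earlier: str, later: str, block_distance: int) -> bool:
    # Blocks 4 apart are 144 ticks apart, i.e. 3 ratchets per 4 blocks
    if block_distance % 4:
        raise ValueError('blocks must belong to the same round key')
    rk = earlier
    for _ in range(3 * block_distance // 4):
        rk = ratchet(rk)
    return rk == later

# Synthetic chain: a key and its value 8 blocks (6 ratchets) later
k0 = '0123456789abcdef'
k8 = k0
for _ in range(6):
    k8 = ratchet(k8)
print(verify_link(k0, k8, 8), verify_link(k0, k8, 4))  # True False
```

With the RK 0 results above, verify_link('282885f5bcf4c433', '2e75d25cb01d0936', 271 - 79) should then return True.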


Looks like we found quite a few collisions, so let’s line them up to the same index. The synced index will be the maximum of the minimum indices of the collisions found for each round key, i.e. the first output block by which all four round keys have been recovered.

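# Tick counter value at each output block: 36, 72, ..., 36 * n
iii = list(range(36, 36*n + 1, 36))
# First (earliest) linked exposure of every round key
mininds = [i[0][0] for i in found_RKs]

sync_ind = max(mininds)
sync_rks = []

for r in range(4):

    ki = found_RKs[r][0][1]

    # Number of ratchets (multiples of 48 ticks) between the exposure and the sync block
    n_ratchets = sum( [ 1 for t in range(iii[mininds[r]]+1, iii[sync_ind]+1) if t % (4*12) == 0 ] )

    for _ in range(n_ratchets):

        ki = sha256(ki.encode()).hexdigest()[:16]

    sync_rks += [ki]

print(sync_ind, sync_rks)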
362 ['a6a3185e68def646', '7068f3322c6c8d77', 'd2647070b21eedc4', '30bff585f280bc2a']
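The list comprehension in the sync script simply counts the multiples of 48 ticks after a key's first exposure, up to and including the sync block. That count also has a closed form, t1 // 48 - t0 // 48. A quick sketch comparing the two (the function names are illustrative), using RK 0's first exposure at block 79 and the sync block 362 from above:

```python
RATCHET_TICKS = 48  # 4 * spacing

def ratchets_brute(t0: int, t1: int) -> int:
    # Count ratchet ticks in (t0, t1] one by one, as the sync script does
    return sum(1 for t in range(t0 + 1, t1 + 1) if t % RATCHET_TICKS == 0)

def ratchets_closed(t0: int, t1: int) -> int:
    # Multiples of 48 up to t1, minus those up to t0
    return t1 // RATCHET_TICKS - t0 // RATCHET_TICKS

t0, t1 = 36 * (79 + 1), 36 * (362 + 1)   # tick counters after blocks 79 and 362
print(ratchets_brute(t0, t1), ratchets_closed(t0, t1))  # 212 212
```

The closed form avoids looping over every tick, which matters little here but keeps the sync step trivial to reason about.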

Now that we have a synced set of round keys, we can just iterate them to the beginning of the flag encryption and XOR its output with the encrypted flag!

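# Setting up a local copy with the recovered synced round keys
rec_jsg = JSG()

# Set round keys
rec_jsg.RKs = sync_rks

# Set index: the tick counter right after output block sync_ind
rec_jsg.i = (sync_ind+1)*3*12

# Set state: block sync_ind is an exposure, so the field equals its round key
rec_jsg.state = rec_jsg.slicer(sync_rks[ (rec_jsg.i // rec_jsg.spacing) % 4 ])

# Line it up to the flag encryption, i.e. the end of the n known blocks
# (n = 1000 in this local demo, n = 1200 in the real challenge)
rec_stream = rec_jsg.stream((n-sync_ind-1)*8)
assert jsg.i == rec_jsg.i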

# Little check if we have done it correctly
print(jsg.RKs)
print(rec_jsg.RKs)
['eb2f659d698b2da2', 'a82f725c8ef7297f', '8873205e2e629afd', '3060628f9dab8001']
['eb2f659d698b2da2', 'a82f725c8ef7297f', '8873205e2e629afd', '3060628f9dab8001']
assert jsg.xorstream(b'insert_flag_here') == rec_jsg.xorstream(b'insert_flag_here')
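For the real output.txt the same steps apply with n = 1200 known blocks, so the local copy ends up at tick 1200 * 36 = 43200, exactly where the flag encryption starts. Calling rec_jsg.xorstream() on the decoded ciphertext then returns the padded plaintext as a hex string (so bytes.fromhex() is needed). Since the flag is surrounded by 8 to 263 random bytes on each side, it still has to be located. A minimal sketch, assuming the flag{...} format; extract_flag is a hypothetical helper:

```python
import re
from typing import Optional

FLAG_RE = re.compile(rb'flag\{[^}]*\}')

def extract_flag(padded_plaintext: bytes) -> Optional[bytes]:
    # The flag sits between two runs of random padding; search for its format
    m = FLAG_RE.search(padded_plaintext)
    return m.group(0) if m else None

# Toy example with fixed stand-in "random" padding on both sides
print(extract_flag(b'\x8f' * 10 + b'flag{toy_example}' + b'\x01' * 9))  # b'flag{toy_example}'
```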

Ta-da!

flag{C0ngr4tul4t10ns_y0u_h4v3_w0n_th3_G4m3_0f_L1f3!}

Thanks for reading! <3

~ Polymero