Code with Finding: |
class SharedKeyCryptoComm {
    /**
     * Receives one encrypted frame from {@code is}, decrypts it with {@code sk}, and verifies its
     * checksum and nonce.
     *
     * <p>Fields are read from the stream in this order:
     * <ol>
     *   <li>checksum ({@code MD5CHECKSUMLEN} bytes)</li>
     *   <li>nonce ({@code NONCE_LENGTH} bytes)</li>
     *   <li>IV ({@code c.getBlockSize()} bytes)</li>
     *   <li>4-byte big-endian ciphertext length</li>
     *   <li>the ciphertext</li>
     * </ol>
     * The received checksum must equal
     * {@code Hash.generateChecksum(nonce || iv || length || plaintext)}. The received nonce must
     * equal the encoding of {@code recvNonce}.
     *
     * <p>This method never returns {@code null}. Every failure results in a
     * {@link ConnectionException}: a failed read or timeout, an invalid length, a decryption
     * failure, a checksum mismatch or a nonce mismatch.
     *
     * <p>NOTE(review): the checksum is an unkeyed hash of the plaintext, sent unencrypted next to
     * the ciphertext. It detects accidental corruption but does not authenticate the frame
     * against an active attacker, and it exposes a hash of the plaintext. Moving to an HMAC or
     * an AEAD cipher mode would fix this, but it needs a matching change on the sending side.
     *
     * @param is        stream the frame is read from
     * @param c         cipher used for decryption; re-initialized here in DECRYPT_MODE with the
     *                  received IV
     * @param sk        shared secret key
     * @param recvNonce nonce the sender is expected to have placed in this frame
     * @return the decrypted plaintext of a frame that passed both checks
     * @throws ConnectionException if the frame cannot be read, decrypted or verified
     */
    public static byte[] receiveBytes(InputStream is, Cipher c, SecretKey sk,
            BigInteger recvNonce) throws ConnectionException {
        int blockSize = c.getBlockSize();
        // NOTE(review): BigInteger.toByteArray() is minimal two's-complement. Arrays.copyOf
        // right-pads with zeros or drops trailing bytes, so this must match the sender's
        // encoding exactly. Confirm against the sender.
        byte[] expctnonce = Arrays.copyOf(recvNonce.toByteArray(), NONCE_LENGTH);

        byte[] checksum = readField(is, MD5CHECKSUMLEN, "checksum");
        byte[] recvnonce = readField(is, NONCE_LENGTH, "recvnonce");
        byte[] iv = readField(is, blockSize, "iv");
        byte[] size = readField(is, 4, "encmsglen");

        int encmsglen = ByteBuffer.wrap(size).getInt();
        // The length comes from the peer. Without this check, a negative value would escape as
        // an unchecked NegativeArraySizeException instead of a ConnectionException.
        // NOTE(review): there is no upper bound, so a peer can force an allocation of up to
        // ~2 GiB. Consider capping this at the protocol's maximum message size.
        if (encmsglen < 0) {
            throw connectionFailure("Invalid encrypted message length received: " + encmsglen, null);
        }
        byte[] encmsg = readField(is, encmsglen, "encmsg");

        byte[] msgbytes;
        try {
            c.init(Cipher.DECRYPT_MODE, sk, new IvParameterSpec(iv));
            msgbytes = c.doFinal(encmsg);
        } catch (java.security.GeneralSecurityException e) {
            // Bad key/IV, wrong block size or bad padding. Corrupted or tampered input can cause
            // all of these, so they are connection failures rather than "cannot happen".
            throw connectionFailure("Error decrypting the message.", e);
        }

        // The checksum covers nonce || iv || length || plaintext, matching the sender's layout.
        byte[] wholeMessage = concat(recvnonce, iv, size, msgbytes);
        try {
            byte[] generated = Hash.generateChecksum(wholeMessage);
            // Constant-time comparison, so timing does not reveal how many bytes matched.
            boolean checksumOk = java.security.MessageDigest.isEqual(checksum, generated);
            boolean nonceOk = java.security.MessageDigest.isEqual(recvnonce, expctnonce);
            if (checksumOk && nonceOk) {
                return msgbytes;
            }
            // Do not leave the plaintext of a rejected frame in memory longer than needed.
            Arrays.fill(msgbytes, (byte) 0x00);
            throw connectionFailure(checksumOk
                    ? "Received nonce for the message does not equal the expected nonce!"
                    : "Generated checksum for message does not equal the received checksum!",
                    null);
        } finally {
            // Zero the plaintext copy on every path, not just on success.
            Arrays.fill(wholeMessage, (byte) 0x00);
        }
    }

    /**
     * Reads exactly {@code length} bytes for the named frame field. Throws instead of returning
     * a partially filled buffer.
     */
    private static byte[] readField(InputStream is, int length, String fieldName)
            throws ConnectionException {
        byte[] buf = new byte[length];
        if (!readIntoBuffer(is, buf)) {
            throw connectionFailure("Error/Timeout receiving the message. (" + fieldName + ")", null);
        }
        return buf;
    }

    /**
     * Prints the failure reason and the standard closing notice, then builds the exception to
     * throw. A non-null {@code cause} is attached to the exception.
     */
    private static ConnectionException connectionFailure(String reason, Throwable cause) {
        System.out.println(reason);
        System.out.println("Closing the connection...");
        ConnectionException ex = new ConnectionException();
        if (cause != null) {
            ex.initCause(cause);
        }
        return ex;
    }

    /** Returns a new array holding the contents of {@code parts} back to back, in order. */
    private static byte[] concat(byte[]... parts) {
        int total = 0;
        for (byte[] part : parts) {
            total += part.length;
        }
        byte[] out = new byte[total];
        int pos = 0;
        for (byte[] part : parts) {
            System.arraycopy(part, 0, out, pos, part.length);
            pos += part.length;
        }
        return out;
    }
}
|