diff --git a/vma.py b/vma.py index 820323c..7ceb470 100755 --- a/vma.py +++ b/vma.py @@ -84,6 +84,23 @@ class VmaHeader(): # make sure the file object points at the end of the vma header fo.seek(self.header_size, os.SEEK_SET) + # reread the header and generate a md5 checksum of the data + self.__gen_md5sum(fo) + + + def __gen_md5sum(self, fo): + p = fo.tell() + fo.seek(0, os.SEEK_SET) + h = hashlib.md5() + + data = fo.read(self.header_size) + data = data[:32] + b'\0' * 16 + data[48:] + h.update(data) + + self.generated_md5sum = h.digest() + + fo.seek(p, os.SEEK_SET) + class VmaDeviceInfoHeader(): def __init__(self, fo, vma_header): @@ -109,6 +126,8 @@ class VmaDeviceInfoHeader(): class VmaExtentHeader(): def __init__(self, fo, vma_header): + self.pos_start = fo.tell() + # 0 - 3: magic # VMA extent magic string ("VMAE") magic = fo.read(4) @@ -135,6 +154,24 @@ class VmaExtentHeader(): for i in range(59): self.blockinfo.append(Blockinfo(fo, vma_header)) + self.pos_end = fo.tell() + + self.__gen_md5sum(fo) + + + def __gen_md5sum(self, fo): + p = fo.tell() + fo.seek(self.pos_start, os.SEEK_SET) + h = hashlib.md5() + + data = fo.read(self.pos_end - self.pos_start) + data = data[:24] + b'\0' * 16 + data[40:] + h.update(data) + + self.generated_md5sum = h.digest() + + fo.seek(p, os.SEEK_SET) + class Blob(): def __init__(self, fo): @@ -201,6 +238,11 @@ def extract(fo, args): vma_header = VmaHeader(fo) + # check the md5 checksum given in the header with the value calculated from + # the file + if not args.skip_hash: + assert vma_header.md5sum == vma_header.generated_md5sum + extract_configs(fo, args, vma_header) # extract_configs may move the read head somewhere into the blob buffer @@ -228,6 +270,11 @@ def extract(fo, args): extent_header = VmaExtentHeader(fo, vma_header) assert vma_header.uuid == extent_header.uuid + # check the md5 checksum given in the header with the value calculated from + # the file + if not args.skip_hash: + assert extent_header.md5sum == extent_header.generated_md5sum + for blockinfo in extent_header.blockinfo: if blockinfo.dev_id == 0: continue @@ -283,7 +330,10 @@ def main(): parser.add_argument('filename', type=str) parser.add_argument('destination', type=str) parser.add_argument('-v', '--verbose', default=False, action='store_true') - parser.add_argument('-f', '--force', default=False, action='store_true') + parser.add_argument('-f', '--force', default=False, action='store_true', + help='overwrite target file if it exists') + parser.add_argument('--skip-hash', default=False, action='store_true', + help='do not perform md5 checksum test of data') args = parser.parse_args() if(not os.path.exists(args.filename)):