# Classy!
# IDA script for keeping track of RE info related to C++ classes (specifically,
# CodeWarrior PPC) and keeping things consistent throughout the database.
# That's what makes IDA better than objdump anyway, right? :p
# Started by Treeki: 16th October 2011
# Developed using Python 2.6 and IDA 6.1.
#
# NOTE: names such as Dword, MakeName, MakeComm, SetFunctionCmt, Warning,
# Refresh, BADADDR, AddStrucMember, GetStrucSize, FF_DATA, FF_STRU and Choose2
# are used without being imported here; this script assumes the IDAPython
# script namespace already provides them (verify for your IDA version).

import cPickle
import functools
import re

import idaapi

print('Loading Classy')


# DATABASE

#DB_PATH = '/home/me/Dev/Reversing/nsmb2.cy'
DB_PATH = 'Z:/stuff/Dev/Reversing/nsmb2.cy'
DB_VERSION = 1


class ClassyDatabase(dict):
    """Pickle-backed dict that also exposes its keys as attributes.

    ``db.foo`` reads ``db['foo']`` and ``db.foo = x`` writes it. Keys listed in
    HASH_DEFAULTS / NONE_DEFAULTS are created on first access with an empty
    dict / None; any other missing attribute raises AttributeError.
    """

    # keys that lazily default to a fresh empty dict
    HASH_DEFAULTS = [
        'classes_by_name',
        'classes_by_struct',
        'known_methods',
        'virtual_calls',
    ]
    # keys that lazily default to None
    NONE_DEFAULTS = [
        'last_class_name',
    ]

    def __init__(self, path):
        """Load the database pickled at *path*, or start a new one.

        A missing or empty file yields a freshly initialised database.
        """
        # NOTE(review): unpickling executes code embedded in the file; only
        # load database files you created yourself.
        try:
            with open(path, 'rb') as handle:
                db = cPickle.load(handle)
            self.update(db)
            if 'version' not in db:
                # old DB, saved before versioning existed
                self.version = 1
        except (IOError, EOFError):
            # no file yet, or an empty one: fall through to a new DB
            pass

        self.path = path

        if 'version' not in self:
            # totally new DB
            self.initialise()
        else:
            self.upgrade()

    def __getattr__(self, key):
        # only reached when normal attribute lookup fails
        try:
            return self[key]
        except KeyError:
            default = self.default_for(key)
            self[key] = default
            return default

    def __setattr__(self, key, value):
        # every attribute (including 'path') lives in the dict, so it is pickled
        self[key] = value

    def default_for(self, key):
        """Return the lazily-created default for *key*.

        Raises AttributeError for keys with no registered default.
        """
        if key in self.HASH_DEFAULTS:
            return {}
        elif key in self.NONE_DEFAULTS:
            return None
        raise AttributeError(key)

    def clear(self):
        """Drop every entry except the file path, then re-initialise."""
        path = self.path
        dict.clear(self)
        self.path = path
        self.initialise()

    def initialise(self):
        self.version = DB_VERSION

    def upgrade(self):
        # nothing to migrate yet: DB_VERSION 1 is the only version
        pass

    def save(self):
        """Pickle the whole database to self.path."""
        with open(self.path, 'wb') as handle:
            cPickle.dump(self, handle)


def setup_db():
    """Create (or reload) the module-global ``db`` from DB_PATH."""
    global db
    db = ClassyDatabase(DB_PATH)


# MENUS

# (name, hotkey, callback) triples collected by @menu_item
cy_menu_defs = []


def setup_menus():
    """(Re)create the Edit/Classy menu items from cy_menu_defs.

    Items left over from a previous run of this script (kept in the global
    cy_menus) are deleted first, so re-running does not duplicate entries.
    """
    global cy_menus

    previous = globals().get('cy_menus') or {}
    for item in previous.values():
        try:
            idaapi.del_menu_item(item)
        except StandardError:
            # best-effort cleanup: the item may already be gone
            pass

    cy_menus = {}
    cy_menus['Classy'] = idaapi.add_menu_item(
        'Edit/', 'Classy', '', 0, about_classy, None)

    for name, sequence, callback in cy_menu_defs:
        cy_menus[name] = idaapi.add_menu_item(
            'Edit/Classy', "CY: %s" % name, sequence,
            0x04000000,  # SETMENU_CTXIDA
            callback, None)


def menu_item(name, sequence=''):
    """Decorator registering a no-argument function as a Classy menu item.

    *sequence* is the hotkey string passed to idaapi.add_menu_item. The
    registered callback is the error-reporting wrapper: a StandardError raised
    by the action produces a warning and a printed message instead of
    propagating out of the callback.
    """
    def custom_decorator(func):
        @functools.wraps(func)
        def wrapper():
            try:
                func()
            except StandardError as e:
                Warning('Crap.')
                print(e)
        # register the wrapper (not func) so the handler above actually runs
        cy_menu_defs.append((name, sequence, wrapper))
        return wrapper
    return custom_decorator


def about_classy():
    print('Classy!')


# CLASSES

class Class(object):
    """RE bookkeeping for one C++ class: its struct, vtable and methods.

    Creating an instance registers it in db.classes_by_name and
    db.classes_by_struct.
    """

    def __init__(self, name, struct, base):
        """*struct* may be a struc_t or a struct id; *base* is a Class or None."""
        if isinstance(struct, idaapi.struc_t):
            struct = struct.id

        self.name = name
        self.mangled_name = self.mangle(name)
        self.struct = struct
        self.methods = []
        self.vmethods = []
        self.vmethods_by_offset = {}
        self.vtable = None
        self.vtable_end = None
        self.base = base
        self.ctor = NullMethod(self)
        self.dtor = NullMethod(self)

        db.classes_by_name[name] = self
        db.classes_by_struct[struct] = self

    @staticmethod
    def mangle(name):
        """Mangle a (possibly ::-qualified) class name.

        'Foo' -> '3Foo'; 'A::Bc' -> 'Q21A2Bc'.
        """
        components = name.split('::')
        if len(components) > 1:
            begin = "Q%d" % len(components)
            m_comps = ["%d%s" % (len(n), n) for n in components]
            return begin + ''.join(m_comps)
        else:
            return "%d%s" % (len(name), name)

    def rename(self, name):
        """Rename the class and re-apply every method's mangled name."""
        old_name = self.name
        del db.classes_by_name[old_name]
        db.classes_by_name[name] = self

        self.name = name
        self.mangled_name = self.mangle(name)

        self.refresh()

    def refresh(self):
        self.ctor.refresh()
        self.dtor.refresh()
        for m in self.methods:
            m.refresh()
        for m in self.vmethods:
            m.refresh()

    def set_vtable(self, ea):
        """Set the vtable start, infer its end, and re-check the dtor."""
        self.vtable = ea
        self.find_vtable_end()
        self.search_for_dtor()

    def set_vtable_end(self, ea):
        self.vtable_end = ea
        self.search_for_dtor()

    def find_vtable_end(self):
        """Guess the vtable end: the last slot before the next zero dword.

        The scan always includes at least the slot at vtable + 4.
        """
        check = self.vtable + 4
        while Dword(check + 4) != 0:
            check += 4

        print('Inferred vtable as %08x .. %08x' % (self.vtable, check))
        self.vtable_end = check

    def _install_dtor(self, ea, offset):
        """Make the function at *ea* this class's dtor.

        A truthy *offset* means the dtor occupies that vtable slot, so it
        becomes a VirtualMethod, or an OverrideMethod when a base class
        already defines that slot; otherwise it is a plain Method. The dtor's
        mangled args are fixed to 'Fv' and its signature cannot be changed.
        """
        if offset:
            base_method = None
            if self.base:
                base_cls, base_method = self.base.find_virtual(offset)
            if base_method is None:
                self.dtor = VirtualMethod(self, '__dt', ea, offset)
            else:
                self.dtor = OverrideMethod(self, '__dt', ea, offset, base_method)
        else:
            self.dtor = Method(self, '__dt', ea)

        self.dtor.mangled_args = 'Fv'
        self.dtor.arg_signature = None  # cannot be changed!
        self.dtor.refresh()

    def search_for_dtor(self):
        """Promote a non-virtual dtor to virtual if it appears in the vtable."""
        if self.dtor.is_null() or self.dtor.is_virtual():
            return
        if self.vtable is None:
            return

        ea = self.dtor.ea
        offset = self.vtable_contains(ea)
        if offset:
            self.dtor.unlink()
            self._install_dtor(ea, offset)

    def iter_vtable(self):
        """Yield (slot address, dword value) for each slot, end inclusive."""
        check = self.vtable
        end = self.vtable_end

        while check <= end:
            yield (check, Dword(check))
            check += 4

    def detect_overrides(self):
        # TODO: unfinished stub; walks the vtable but does nothing yet.
        offset = 0
        for addr, value in self.iter_vtable():
            offset += 4

    def find_virtual(self, offset):
        """Find the virtual method at vtable *offset* here or in a base class.

        Returns (class, method). If no class in the chain defines the slot,
        method is None and class is the root of the base chain.
        """
        try:
            return (self, self.vmethods_by_offset[offset])
        except KeyError:
            if self.base is None:
                return (self, None)
            else:
                return self.base.find_virtual(offset)

    def vtable_contains(self, ea):
        """Return the vtable offset whose slot holds *ea*, or None."""
        for addr, value in self.iter_vtable():
            if value == ea:
                return addr - self.vtable
        return None

    def generate_cpp_definition(self):
        """Return a C++ class skeleton for this class.

        Struct and member comments are emitted; member declarations are still
        a TODO.
        """
        contents = []

        sid = self.struct
        s = idaapi.get_struc(sid)

        for cmt in (idaapi.get_struc_cmt(sid, 0), idaapi.get_struc_cmt(sid, 1)):
            # guard against a missing comment (None) as well as ''
            if cmt:
                contents.append("\t/* %s */" % cmt)

        # note: structs should always include their base class
        if s is not None:
            ida_offset = idaapi.get_struc_first_offset(s)
            end_at = idaapi.get_struc_last_offset(s)
            # the BADADDR test stops the walk if the struct has no members
            while ida_offset != BADADDR and ida_offset <= end_at:
                m = idaapi.get_member(s, ida_offset)
                if m is not None:
                    for cmt in (idaapi.get_member_cmt(m.id, 0),
                                idaapi.get_member_cmt(m.id, 1)):
                        if cmt:
                            contents.append("\t/* %s */" % cmt)
                    # TODO: emit the member declaration (needs the SWIG type API)
                ida_offset = idaapi.get_struc_next_offset(s, ida_offset)

        if self.base:
            base_info = " : public %s" % self.base.name
        else:
            base_info = ''
        return "class %s%s {\n%s\n};" % (self.name, base_info, "\n".join(contents))


class Method(object):
    """A function registered as a method of a Class.

    Methods with a real address are recorded in db.known_methods; with
    *register* the method is also appended to owner.methods. An
    arg_signature of None marks a signature that cannot be changed.
    """

    def __init__(self, owner, name, ea, register=True):
        self.owner = owner
        self.name = name
        self.arg_signature = '?'
        self.mangled_args = '1?'
        self.ea = ea

        if ea != BADADDR and ea != 0:
            db.known_methods[ea] = self
        if register:
            owner.methods.append(self)

    def __str__(self):
        return repr(self)

    def __repr__(self):
        if self.arg_signature is None:
            return "<%s %s::%s>" % (type(self), self.owner.name, self.name)
        else:
            return "<%s %s::%s(%s)>" % (type(self), self.owner.name, self.name, self.arg_signature)

    def is_null(self):
        return False

    def is_virtual(self):
        return False

    def rename(self, name):
        self.name = name
        self.refresh()

    def refresh(self):
        """Re-apply the mangled name in IDA and refresh affected comments.

        Each virtual call site recorded inside this function has its target's
        function comment regenerated, because those comments include the
        calling function's name.
        """
        MakeName(self.ea, "%s__%s%s" % (self.name, self.owner.mangled_name, self.mangled_args))

        # figure out if any virtual xrefs include us
        # a bit slow, but who cares
        func = idaapi.get_func(self.ea)
        if func is None:
            return

        call_dict = db.virtual_calls
        done_already = set()
        check = func.startEA
        # endEA is exclusive; PPC instructions are 4 bytes wide
        while check < func.endEA:
            method = call_dict.get(check)
            if method is not None and method not in done_already:
                method.refresh_comment()
                done_already.add(method)
            check += 4

    def unlink(self, orphan=False):
        """Detach this method from its owner and from db.known_methods.

        With *orphan*, the IDA name at self.ea is cleared as well.
        """
        # only clear the owner's ctor/dtor if it is still this object
        if self.owner.ctor is self:
            self.owner.ctor = NullMethod(self.owner)
        elif self.owner.dtor is self:
            self.owner.dtor = NullMethod(self.owner)

        if self in self.owner.methods:
            self.owner.methods.remove(self)

        self.owner = None

        # don't delete a newer registration that reused this address
        if self.ea != 0 and self.ea != BADADDR and db.known_methods.get(self.ea) is self:
            del db.known_methods[self.ea]

        if orphan:
            MakeName(self.ea, '')

    def set_signature(self, sig):
        """Set the human-readable argument list and its mangled form.

        Raises TypeError if this method's signature is fixed.
        """
        if self.arg_signature is None:
            raise TypeError("this method has a fixed signature")

        mang = mangle_args(sig)
        self.arg_signature = sig
        self.mangled_args = mang
        self.refresh()


class VirtualMethod(Method):
    """A method occupying vtable slot *vt_offset* that is new in its class.

    Tracks overriding methods in derived classes and the call sites marked
    as virtual calls to it.
    """

    def __init__(self, owner, name, ea, vt_offset):
        Method.__init__(self, owner, name, ea, False)
        self.vt_offset = vt_offset
        self.virtual_calls = []
        self.overrides = []

        owner.vmethods.append(self)
        owner.vmethods_by_offset[vt_offset] = self

    def is_virtual(self):
        return True

    def unlink(self, orphan=False):
        self.owner.vmethods.remove(self)
        del self.owner.vmethods_by_offset[self.vt_offset]
        # each override removes itself from self.overrides, so iterate a copy
        for m in list(self.overrides):
            m.unlink(True)
        Method.unlink(self, orphan)

    def rename(self, name):
        Method.rename(self, name)
        for o in self.overrides:
            o.rename(name, False)
            o.refresh()

    def set_signature(self, sig):
        Method.set_signature(self, sig)
        for o in self.overrides:
            o.copy_signature(self)
            o.refresh()

    def add_xref(self, ea):
        """Record *ea* as a virtual call site of this method."""
        self.virtual_calls.append(ea)
        self.virtual_calls.sort()
        self.refresh_xref(ea)
        self.refresh_comment()
        db.virtual_calls[ea] = self

    def remove_xref(self, ea):
        self.virtual_calls.remove(ea)
        self.refresh_comment()
        MakeComm(ea, '')
        del db.virtual_calls[ea]

    def refresh(self):
        Method.refresh(self)
        self.refresh_comment()
        for ea in self.virtual_calls:
            self.refresh_xref(ea)

    def refresh_xref(self, ea):
        MakeComm(ea, "%08X %s::%s(%s)" % (self.ea, self.owner.name, self.name, self.arg_signature))

    def get_comment(self):
        """Build the function comment listing overrides and call sites."""
        lines = []

        if len(self.overrides) > 0:
            lines.append('OVERRIDDEN BY:')
            for method in self.overrides:
                lines.append(' - %08X : %s' % (method.ea, method.owner.name))
            lines.append('')

        if len(self.virtual_calls) > 0:
            lines.append('VIRTUAL METHOD CALLS:')
            for ea in self.virtual_calls:
                func = idaapi.get_func(ea)
                if func is None:
                    # call site no longer inside a function
                    lines.append(' - %08X' % ea)
                    continue
                name = idaapi.demangle_name(idaapi.get_func_name(func.startEA), 0)
                offset = ea - func.startEA
                lines.append(' - %08X : %s + %X' % (ea, name, offset))
            lines.append('')

        return "\n".join(lines)

    def refresh_comment(self):
        SetFunctionCmt(self.ea, self.get_comment(), 0)


class OverrideMethod(VirtualMethod):
    """A virtual method overriding *base*, a method in a base class's vtable.

    Name and signature are always those of the *original* (non-override)
    method at the root of the override chain.
    """

    def __init__(self, owner, name, ea, vt_offset, base):
        # validate before registering anything, so a failure leaves no trace
        if ea == BADADDR or ea == 0:
            raise ValueError("An override method can't be pure!!")

        # VirtualMethod.__init__ already adds us to owner.vmethods and
        # owner.vmethods_by_offset; adding again would leave a stale entry
        # behind after unlink().
        VirtualMethod.__init__(self, owner, name, ea, vt_offset)

        self.base = base
        self.original = getattr(base, 'original', None) or base
        self.copy_signature(self.original)

        base.overrides.append(self)

    def unlink(self, orphan=False):
        if self in self.base.overrides:
            self.base.overrides.remove(self)
        VirtualMethod.unlink(self, orphan)

    def copy_signature(self, other):
        """Copy *other*'s signature here and to every override of this one."""
        self.arg_signature = other.arg_signature
        self.mangled_args = other.mangled_args
        for m in self.overrides:
            m.copy_signature(other)

    def rename(self, name, delegate_to_original=True):
        # a strange little hack: renames go through the original, which then
        # calls back here with delegate_to_original=False
        if delegate_to_original:
            self.original.rename(name)
        else:
            VirtualMethod.rename(self, name)

    def set_signature(self, sig):
        self.original.set_signature(sig)

    def refresh(self):
        VirtualMethod.refresh(self)
        self.base.refresh_comment()

    def get_comment(self):
        return "ORIGINAL: %08X : %s\nBASE: %08X : %s\n\n%s" % \
            (self.original.ea, self.original.owner.name,
             self.base.ea, self.base.owner.name,
             VirtualMethod.get_comment(self))


class NullMethod(Method):
    """Placeholder for an unset ctor/dtor; never registered anywhere."""

    def __init__(self, owner):
        Method.__init__(self, owner, '', BADADDR, False)

    def is_null(self):
        return True

    def refresh(self):
        pass

    def unlink(self, orphan=False):
        pass


# UTILITY FUNCTIONS

class ClassChooser(Choose2):
    """Single-column chooser over self.clslist (set by choose_class)."""

    def __init__(self):
        Choose2.__init__(self, 'Choose a class', [['Class', 40]])
        self.clslist = []

    def OnClose(self):
        pass

    def OnGetLine(self, n):
        return [self.clslist[n]]

    def OnGetSize(self):
        return len(self.clslist)

_class_chooser = ClassChooser()


def choose_class(title=None, optional=False):
    """Ask the user to pick a Class and return it, or None.

    With *optional*, a '[[ None ]]' entry is listed first and picking it
    returns None. A negative result from Show() (presumably a cancel) also
    returns None. The pick is stored in db.last_class_name and used as the
    default selection next time.
    """
    names = list(db.classes_by_name.keys())
    base_index = 0
    if optional:
        names.insert(0, '[[ None ]]')
        base_index = 1

    _class_chooser.title = title or 'Choose a class'
    try:
        # NOTE(review): names.index() already counts the '[[ None ]]' entry,
        # so the extra +1 for optional lists looks like a double offset;
        # confirm against Choose2's deflt indexing before changing it.
        _class_chooser.deflt = names.index(db.last_class_name) + (2 if optional else 1)
    except ValueError:
        # last choice unknown or since deleted
        _class_chooser.deflt = 0
    _class_chooser.clslist = names

    num = _class_chooser.Show(True)
    if num < base_index:
        return None

    choice = names[num]
    db.last_class_name = choice
    return db.classes_by_name[choice]


def unlink_method_if_exists(ea):
    if ea in db.known_methods:
        db.known_methods[ea].unlink()


def get_current_function():
    """Return the start of the function under the cursor.

    Also accepts the cursor on a dword that points into a function (e.g. a
    vtable slot). Shows a warning and raises ValueError otherwise.
    """
    ea = idaapi.get_screen_ea()
    func = idaapi.get_func(ea)
    if func is None:
        # try whatever's pointed to by this
        func = idaapi.get_func(Dword(ea))
    if func is None:
        Warning('Place the cursor on top of a function.')
        raise ValueError("no function at %08X" % ea)
    return func.startEA


# DO STUFF

@menu_item('Create Class')
def create_class():
    """Create a Class, associating or creating its struct."""
    name = idaapi.askstr(idaapi.HIST_IDENT, '', 'Enter a class name')
    if not name:
        # cancelled (None) or empty name
        return
    if name in db.classes_by_name:
        Warning('That name is already used.')
        return

    struct = idaapi.choose_struc('Associate this class with a struct')
    if struct is not None and struct.id in db.classes_by_struct:
        Warning('That struct is already used.')
        return

    base = choose_class('Choose the base class for this one', True)

    if struct is None:
        safe_name = name.replace(':', '_')
        struct = idaapi.get_struc_id(safe_name)
        if struct == BADADDR:
            struct = idaapi.add_struc(BADADDR, safe_name, 0)
            # add_struc reports failure with BADADDR; -1 kept as well
            if struct == BADADDR or struct == -1:
                Warning('Oops.')
                return
            print("Created struct: %s" % safe_name)

            if base is not None:
                # embed the base class's struct at offset 0
                AddStrucMember(struct, '_', 0, FF_DATA | FF_STRU, base.struct, GetStrucSize(base.struct))
        else:
            if struct in db.classes_by_struct:
                Warning('That struct is already used.')
                return
            print("Used existing struct: %s" % safe_name)

    Class(name, struct, base)
    print("Done!")
    Refresh()


@menu_item('Delete Class')
def delete_class():
    cls = choose_class('Choose the class to delete', False)
    if cls is not None:
        name = cls.name
        del db.classes_by_struct[cls.struct]
        del db.classes_by_name[cls.name]
        print("Deleted class %s" % name)
    else:
        Warning("No class chosen")
    Refresh()


@menu_item('Set VTable')
def set_vtable():
    """Use the cursor address as the chosen class's vtable start."""
    cls = choose_class()
    if cls is None:
        return

    ea = idaapi.get_screen_ea()
    cls.set_vtable(ea)

    print("Set vtable for %s to %08x" % (cls.name, ea))
    Refresh()


@menu_item('Set VTable End')
def set_vtable_end():
    """Use the cursor address as the chosen class's last vtable slot."""
    cls = choose_class()
    if cls is None:
        return
    if cls.vtable is None:
        Warning('This class does not have a virtual table defined.')
        return

    ea = idaapi.get_screen_ea()
    if ea < cls.vtable:
        Warning("The virtual table cannot go backwards. It starts at %08x." % cls.vtable)
        return

    cls.set_vtable_end(ea)
    print("Set vtable end for %s to %08x" % (cls.name, ea))
    Refresh()


@menu_item('Set Ctor')
def set_ctor():
    """Register the current function as the chosen class's constructor."""
    cls = choose_class()
    if cls is None:
        return

    ea = get_current_function()
    unlink_method_if_exists(ea)
    if not cls.ctor.is_null():
        # drop the previous ctor so it doesn't linger in cls.methods
        cls.ctor.unlink()

    cls.ctor = Method(cls, '__ct', ea)
    cls.ctor.refresh()

    print("Set ctor for %s to %08x" % (cls.name, ea))
    Refresh()


@menu_item('Set Dtor')
def set_dtor():
    """Register the current function as the chosen class's destructor."""
    cls = choose_class()
    if cls is None:
        return

    ea = get_current_function()
    unlink_method_if_exists(ea)

    # is it virtual?
    v_offset = (cls.vtable and cls.vtable_contains(ea))
    cls._install_dtor(ea, v_offset)

    print("Set dtor for %s to %08x" % (cls.name, ea))
    Refresh()


@menu_item('Register as Method', 'Shift+M')
def register_method():
    """Register the current function as a non-virtual method."""
    cls = choose_class()
    if cls is None:
        return

    ea = get_current_function()

    name = idaapi.askstr(idaapi.HIST_IDENT, 's_%08X' % ea, 'Enter a method name')
    if name is None:
        return

    # unlink only once the user has committed, so cancelling loses nothing
    unlink_method_if_exists(ea)

    method = Method(cls, name, ea)
    method.refresh()

    print("Method %s::%s() created at %08x" % (cls.name, name, ea))
    Refresh()


@menu_item('Set Method Arguments', 'Shift+A')
def set_method_args():
    """Prompt for and apply the argument list of the current method."""
    ea = get_current_function()

    try:
        method = db.known_methods[ea]
    except KeyError:
        Warning("No method at this address (%08X)." % ea)
        return

    if method.arg_signature is None:
        Warning("This method's argument signature cannot be changed")
        return

    arg = idaapi.askstr(idaapi.HIST_TYPE, method.arg_signature, 'Enter the arguments')
    if arg is None:
        return

    method.set_signature(arg)

    print("Method %s::%s() set args to %s" % (method.owner.name, method.name, method.mangled_args))
    Refresh()


@menu_item('Register as Virtual Method', 'Shift+V')
def register_virtual_method():
    """Register the current function as a virtual method (or override)."""
    cls = choose_class()
    if cls is None:
        return

    if cls.vtable is None:
        Warning('This class does not have a virtual table defined.')
        return

    ea = get_current_function()
    unlink_method_if_exists(ea)

    # try to find it within the vtable
    offset = cls.vtable_contains(ea)
    if offset is None:
        Warning("Could not find this method within the vtable for %s." % cls.name)
        print("Virtual table for %s: %08x .. %08x" % (cls.name, cls.vtable, cls.vtable_end))
        return

    # check to make sure it's not defined anywhere in a base class
    base, base_method = (cls.base and cls.base.find_virtual(offset)) or (None, None)

    if base_method is None:
        name = idaapi.askstr(idaapi.HIST_IDENT, 'vf%02X' % offset, 'Enter a method name')
        if name is None:
            return
        method = VirtualMethod(cls, name, ea, offset)
    else:
        method = OverrideMethod(cls, base_method.name, ea, offset, base_method)

    method.refresh()

    print("Virtual Method %s::%s() created at %08x with offset %x" % (cls.name, method.name, ea, offset))
    Refresh()


@menu_item('Rename Method', 'Shift+N')
def rename_method():
    ea = get_current_function()
    if ea in db.known_methods:
        method = db.known_methods[ea]
        name = idaapi.askstr(idaapi.HIST_IDENT, method.name, 'Enter a method name')
        if name is None:
            return
        method.rename(name)
        Refresh()
    else:
        print("Don't know about a method here")


@menu_item('Unlink Method', 'Shift+U')
def unlink_method():
    ea = get_current_function()
    if ea in db.known_methods:
        db.known_methods[ea].unlink(True)
        Refresh()
    else:
        print("Don't know about a method here")


@menu_item('Mark Virtual Call', 'Shift+C')
def mark_virtual_call():
    """Mark the instruction under the cursor as a virtual call.

    The vtable offset is taken from operand 1, which must be a
    displacement (reg+offset) operand.
    """
    ea = idaapi.get_screen_ea()

    inslen = idaapi.decode_insn(ea)
    if inslen == 0:
        raise ValueError("cannot decode an instruction at %08X" % ea)

    op = idaapi.cmd.Operands[1]
    if not op:
        raise ValueError("instruction at %08X has no second operand" % ea)
    if op.type != idaapi.o_displ:
        raise ValueError("operand at %08X is not a displacement" % ea)

    chosen_cls = choose_class('Choose the class for this method')
    if chosen_cls is None:
        return

    offset = op.addr

    cls, method = chosen_cls.find_virtual(offset)
    if method is None:
        Warning("The class %s and its base classes do not have this method defined." % chosen_cls.name)
        return

    if ea in db.virtual_calls:
        db.virtual_calls[ea].remove_xref(ea)

    # calls are always recorded on the original, never on an override
    original = getattr(method, 'original', None) or method
    original.add_xref(ea)
    Refresh()


@menu_item('Unmark Virtual Call', 'Shift+Ctrl+C')
def unmark_virtual_call():
    ea = idaapi.get_screen_ea()
    if ea in db.virtual_calls:
        db.virtual_calls[ea].remove_xref(ea)
        Refresh()


@menu_item('Save')
def save_db():
    db.save()
    print('Saved the database.')


@menu_item('Reset DB')
def reset_db():
    if idaapi.askyn_c(0, 'Are you really sure you want to lose EVERYTHING?') == 1:
        db.clear()
        print('Done. You asked for it ...')


# TYPE DECLARATIONS

MANGLED_TYPES = {
    's32': 'i', 'int': 'i',
    'u32': 'Ui', 'uint': 'Ui', 'unsigned int': 'Ui',
    's16': 's', 'short': 's',
    'u16': 'Us', 'ushort': 'Us', 'unsigned short': 'Us',
    's8': 'c', 'char': 'c',
    'u8': 'Uc', 'uchar': 'Uc', 'unsigned char': 'Uc',
    'bool': 'b',
    'wchar_t': 'w',
    'f32': 'f', 'float': 'f',
    'f64': 'd', 'double': 'd',
    'void': 'v',
}

ARRAY_REGEX = re.compile(r'\[(\d+)\]$')


def mangle_args(string):
    """Mangle a comma-separated argument type list, CodeWarrior style.

    Returns the function part of a mangled name, e.g. 'int, Foo&' gives
    'FiR3Foo' and 'const Foo*' gives 'FPC3Foo'. An empty or blank list
    mangles like '(void)': 'Fv'. Types not in MANGLED_TYPES are mangled as
    class names.
    """
    pieces = [p.strip() for p in string.split(',')]
    if pieces == ['']:
        # ''.split(',') gives [''], which is "no arguments", not one empty type
        pieces = []

    # 'CF' for const methods
    output = ['F']

    for p in pieces:
        type_name = p
        prefix = ''

        # Peel declarators off the right-hand end. Pointer/reference/const
        # suffixes are met outermost-first, so each goes after those already
        # found ('Foo*&' -> 'RP', 'Foo const&' -> 'RC'). Array bounds read the
        # other way ('int[2][3]' -> 'A2_A3_'), so they are prepended.
        while True:
            if type_name.endswith('&'):
                type_name = type_name[:-1].strip()
                prefix += 'R'
            elif type_name.endswith('*'):
                type_name = type_name[:-1].strip()
                prefix += 'P'
            elif type_name.endswith('const'):
                type_name = type_name[:-5].strip()
                prefix += 'C'
            else:
                match = ARRAY_REGEX.search(type_name)
                if match:
                    type_name = type_name[:type_name.rindex('[')].strip()
                    prefix = 'A%s_%s' % (match.group(1), prefix)
                elif type_name.startswith('const '):
                    # a leading const qualifies the innermost type, so it is
                    # handled last: 'const Foo&' -> 'RC'
                    type_name = type_name[6:].strip()
                    prefix += 'C'
                else:
                    # nothing, we must be done
                    break

        try:
            type_def = MANGLED_TYPES[type_name]
        except KeyError:
            type_def = Class.mangle(type_name)

        output.append(prefix + type_def)

    if not pieces:
        output.append('v')

    return ''.join(output)


# FINAL SETUP

setup_menus()
setup_db()
print('Done!')