diff options
-rw-r--r-- | README | 75 | ||||
-rw-r--r-- | classy.py | 886 |
2 files changed, 961 insertions, 0 deletions
@@ -0,0 +1,75 @@ +Classy! + +IDA script for keeping track of RE info related to C++ classes (specifically, + CodeWarrior PPC) and keeping things consistent throughout the database. + +That's what makes IDA better than objdump anyway, right? :p + +Started by Treeki: 16th October 2011 +Developed using Python 2.6 and IDA 6.1. + + + +Supported so far: +- Save and reset database +- Classes: + - Create a class using an IDA struct + - Base classes are handled using a struct field named _ + - Struct is automatically created if one is not chosen - if this is used, + Classy adds the _ field for you +- Virtual Tables: + - Set a VTable for a class. End is automatically detected + - If the auto detection fails (pure virtuals, for example) you can + manually override it +- CTors/DTors: + - Set a ctor and dtor for a class (TODO: Multiple ctors) + - Dtors automatically become virtual if a pointer to them is found in the + vtable, even if the dtor was registered before the vtable was set +- Methods: + - Register regular and virtual methods + - Mark virtual calls by highlighting the "lwz r12, ..." line, pressing + Shift+C and choosing a class + - Xrefs to each virtual method are shown at the top of the method body + - Set method arguments including pointers, arrays, refs and consts + - Arguments are automatically mangled so they are shown in IDA names + - Shift+V automatically detects whether a method is a new virtual one or + an override + - Overrides are shown at the top of the original method body + - Renames and argument list changes are propagated from original methods + to overrides (Though not vice versa yet) + +That's all for now, I think... + + + + +Current Todo: +- Only show classes that fit the criteria in the chooser for Shift+C +- Support PTMFs, too +- Automatically create virtual methods using the vtable. +- Named arguments to methods +- Return types for methods +- Const methods +- Improve vtable end detection heuristics +- Support pure virtuals. 
+- Handle IDA chooser history/defaults correctly everywhere +- Better keybindings +- Rename class menu item +- Netnodes +- Override dtors +- Sanity check to see if a virtual method might exist in a base class +- Confirm unlinking of methods +- Don't let you create more than one method with the same name +- choose_class differentiation between "no class" and Cancel +- Auto usage of struct if a named one exists +- Remove virtual xrefs + +Caveats: +- Multiple inheritance will not be supported in the near future. +- Probably won't work as well with non-CodeWarrior stuff. + +Notes: +- When autocreating virtual methods, start with the base class!! + Otherwise, you might end up defining an override as a virtual + method. And that would be bad. + diff --git a/classy.py b/classy.py new file mode 100644 index 0000000..b032e94 --- /dev/null +++ b/classy.py @@ -0,0 +1,886 @@ +# Classy! + +# IDA script for keeping track of RE info related to C++ classes (specifically, +# CodeWarrior PPC) and keeping things consistent throughout the database. + +# That's what makes IDA better than objdump anyway, right? :p + +# Started by Treeki: 16th October 2011 +# Developed using Python 2.6 and IDA 6.1. 
import idaapi
import cPickle
import re

# NOTE(review): names such as BADADDR, Dword, MakeName, MakeComm,
# SetFunctionCmt, Warning, Refresh, Choose2, AddStrucMember, GetStrucSize,
# FF_DATA and FF_STRU are used unqualified throughout this script; they are
# presumably injected by the IDA scripting environment (idc / idaapi star
# imports) -- confirm before running this file outside of IDA.

print('Loading Classy')


# DATABASE
# Alternative location used on another machine:
#DB_PATH = '/home/me/Dev/Reversing/nsmb2.cy'
DB_PATH = 'Z:/stuff/Dev/Reversing/nsmb2.cy'
DB_VERSION = 1


class ClassyDatabase(dict):
    """Pickle-backed dictionary holding everything Classy knows.

    Attribute access is redirected to dictionary items (see __getattr__ and
    __setattr__), so ``db.version`` and ``db['version']`` are the same thing
    and every attribute, including ``path``, is saved with the database.
    """

    # Keys that default to an empty dict the first time they are read.
    HASH_DEFAULTS = [
        'classes_by_name', 'classes_by_struct',
        'known_methods', 'virtual_calls',
    ]

    # Keys that default to None the first time they are read.
    NONE_DEFAULTS = [
        'last_class_name',
    ]

    def __init__(self, path):
        """Load the database stored at *path*, or start an empty one.

        A missing/unreadable file (IOError) or an empty file (EOFError) is
        not an error: the database simply starts out empty.
        """
        try:
            # NOTE: unpickling can execute arbitrary code; only point
            # DB_PATH at files you created yourself.
            with open(path, 'rb') as handle:
                db = cPickle.load(handle)
            self.update(db)

            if 'version' not in db:
                # old DB
                self.version = 1
        except IOError:
            pass
        except EOFError:
            pass

        self.path = path

        if 'version' not in self:
            # totally new DB
            self.initialise()
        else:
            self.upgrade()

    def __getattr__(self, key):
        """Return item *key*, creating it from default_for() if missing."""
        try:
            return self[key]
        except KeyError:
            default = self.default_for(key)
            self[key] = default
            return self[key]

    def __setattr__(self, key, value):
        """Store every attribute assignment as a dictionary item."""
        self[key] = value

    def default_for(self, key):
        """Return the default value for *key*.

        Raises AttributeError for keys that have no registered default.
        """
        if key in self.HASH_DEFAULTS:
            return {}
        elif key in self.NONE_DEFAULTS:
            return None
        else:
            raise AttributeError(key)

    def clear(self):
        """Drop all contents, keeping only the path, and re-initialise."""
        path = self.path
        dict.clear(self)
        self.path = path
        self.initialise()

    def initialise(self):
        """Stamp a fresh database with the current DB_VERSION."""
        self.version = DB_VERSION

    def upgrade(self):
        """Hook for migrating older databases; nothing to do for now."""
        pass

    def save(self):
        """Pickle the whole database to self.path."""
        with open(self.path, 'wb') as handle:
            cPickle.dump(self, handle)


def setup_db():
    """Create the global ``db`` from DB_PATH."""
    global db
    db = ClassyDatabase(DB_PATH)


# MENUS
# (name, hotkey sequence, callback) tuples collected by @menu_item.
cy_menu_defs = []


def setup_menus():
    """(Re)create the Classy menu items.

    Items left over from a previous run of the script (a ``cy_menus``
    global that already exists) are removed first on a best-effort basis.
    """
    global cy_menus

    try:
        stale = cy_menus
    except NameError:
        # first run in this session: nothing to remove
        stale = {}

    for item in stale.values():
        try:
            idaapi.del_menu_item(item)
        except Exception:
            # best effort: a stale handle must not block re-registration
            pass

    cy_menus = {}
    cy_menus['Classy'] = idaapi.add_menu_item(
        'Edit/', 'Classy', '', 0, about_classy, None)

    for name, sequence, callback in cy_menu_defs:
        cy_menus[name] = idaapi.add_menu_item(
            'Edit/Classy', "CY: %s" % name, sequence,
            0x04000000,  # SETMENU_CTXIDA
            callback, None)


def menu_item(name, sequence=''):
    """Decorator registering a function as a Classy menu entry.

    The registered callback is the error-reporting wrapper, so an exception
    raised by the command is shown to the user instead of escaping into IDA.
    """
    def custom_decorator(func):
        def wrapper():
            try:
                func()
            except Exception as e:
                Warning('Crap.')
                print(e)

        # Register the wrapper (not the bare function); otherwise the
        # error handling above would never run for menu invocations.
        cy_menu_defs.append((name, sequence, wrapper))
        return wrapper
    return custom_decorator


def about_classy():
    """Callback of the top-level 'Classy' menu entry."""
    print('Classy!')


# CLASSES
class Class(object):
    """A C++ class tracked by Classy, tied to an IDA struct id."""

    def __init__(self, name, struct, base):
        """Create the class and register it in the global database.

        name   -- C++ name, possibly qualified with '::'
        struct -- IDA struct id, or an idaapi.struc_t (its id is used)
        base   -- base Class, or None
        """
        if isinstance(struct, idaapi.struc_t):
            struct = struct.id

        self.name = name
        self.mangled_name = self.mangle(name)
        self.struct = struct
        self.methods = []
        self.vmethods = []
        self.vmethods_by_offset = {}
        self.vtable = None
        self.vtable_end = None
        self.base = base

        self.ctor = NullMethod(self)
        self.dtor = NullMethod(self)

        db.classes_by_name[name] = self
        db.classes_by_struct[struct] = self

    @staticmethod
    def mangle(name):
        """Return the length-prefixed mangled form of a class name.

        'Foo' becomes '3Foo'; 'A::Bc' becomes 'Q21A2Bc'.
        """
        components = name.split('::')
        if len(components) > 1:
            begin = "Q%d" % len(components)
            m_comps = ["%d%s" % (len(n), n) for n in components]
            return begin + ''.join(m_comps)
        else:
            return "%d%s" % (len(name), name)

    def rename(self, name):
        """Rename the class and refresh the names of all its methods."""
        old_name = self.name

        del db.classes_by_name[old_name]
        db.classes_by_name[name] = self

        self.name = name
        self.mangled_name = self.mangle(name)

        self.refresh()

    def refresh(self):
        """Refresh ctor, dtor and every regular and virtual method."""
        self.ctor.refresh()
        self.dtor.refresh()

        for m in self.methods:
            m.refresh()
        for m in self.vmethods:
            m.refresh()

    def set_vtable(self, ea):
        """Set the vtable start, infer its end and re-check the dtor."""
        self.vtable = ea
        self.find_vtable_end()
        self.search_for_dtor()

    def set_vtable_end(self, ea):
        """Manually override the (inclusive) vtable end address."""
        self.vtable_end = ea
        self.search_for_dtor()

    def find_vtable_end(self):
        """Guess the vtable end: stop before the first zero dword."""
        check = self.vtable + 4

        while Dword(check + 4) != 0:
            check += 4

        print('Inferred vtable as %08x .. %08x' % (self.vtable, check))
        self.vtable_end = check

    def find_inherited_virtual(self, offset):
        """Return the base classes' virtual method at *offset*, or None."""
        if self.base is None:
            return None
        return self.base.find_virtual(offset)[1]

    def install_dtor(self, ea, offset=None):
        """Create the '__dt' method at *ea* and make it this class's dtor.

        With offset None a plain Method is created; otherwise a
        VirtualMethod, or an OverrideMethod when a base class already
        defines a virtual method at that vtable offset.
        """
        if offset is None:
            dtor = Method(self, '__dt', ea)
        else:
            base_method = self.find_inherited_virtual(offset)
            if base_method is None:
                dtor = VirtualMethod(self, '__dt', ea, offset)
            else:
                dtor = OverrideMethod(self, '__dt', ea, offset, base_method)

        dtor.mangled_args = 'Fv'
        dtor.arg_signature = None  # cannot be changed!
        self.dtor = dtor
        dtor.refresh()

    def search_for_dtor(self):
        """Turn a non-virtual dtor into a virtual one if the vtable has it."""
        if self.vtable is None: return
        if self.dtor.is_null(): return
        if self.dtor.is_virtual(): return

        ea = self.dtor.ea
        offset = self.vtable_contains(ea)
        # "is None" rather than truthiness, matching register_virtual_method
        if offset is None:
            return

        self.dtor.unlink()
        self.install_dtor(ea, offset)

    def iter_vtable(self):
        """Yield (address, dword value) for every vtable slot, end included."""
        check = self.vtable
        end = self.vtable_end

        while check <= end:
            yield (check, Dword(check))
            check += 4

    def detect_overrides(self):
        """Placeholder: walks the vtable but does not detect anything yet."""
        offset = 0

        for addr, value in self.iter_vtable():
            offset += 4

    def find_virtual(self, offset):
        """Look up the virtual method at *offset* here or in a base class.

        Returns (class, method); method is None when nothing is found.
        """
        try:
            return (self, self.vmethods_by_offset[offset])
        except KeyError:
            if self.base is None:
                return (self, None)
            else:
                return self.base.find_virtual(offset)

    def vtable_contains(self, ea):
        """Return the vtable offset whose slot holds *ea*, or None."""
        for addr, value in self.iter_vtable():
            if value == ea:
                return addr - self.vtable

        return None

    def generate_cpp_definition(self):
        """Return a C++ class skeleton built from the struct's comments.

        Only struct and member comments are emitted so far; member
        declarations are still a TODO.
        """
        contents = []

        sid = self.struct
        s = idaapi.get_struc(sid)

        for repeatable in (0, 1):
            cmt = idaapi.get_struc_cmt(sid, repeatable)
            if cmt:
                contents.append("\t/* %s */" % cmt)

        # note: structs should always include their base class
        ida_offset = idaapi.get_struc_first_offset(s)
        end_at = idaapi.get_struc_last_offset(s)

        # BADADDR marks "no (more) members"; checking it explicitly keeps an
        # empty struct from looping forever.
        while ida_offset != BADADDR and ida_offset <= end_at:
            m = idaapi.get_member(s, ida_offset)

            if m is not None:
                for repeatable in (0, 1):
                    cmt = idaapi.get_member_cmt(m.id, repeatable)
                    if cmt:
                        contents.append("\t/* %s */" % cmt)

            # TODO: convert to swig api and emit the member declaration

            ida_offset = idaapi.get_struc_next_offset(s, ida_offset)

        if self.base is not None:
            base_info = " : public %s" % self.base.name
        else:
            base_info = ''
        return "class %s%s {\n%s\n};" % (self.name, base_info, "\n".join(contents))


class Method(object):
    """A non-virtual method of a Class, located at address *ea*."""

    def __init__(self, owner, name, ea, register=True):
        """Create the method.

        The method is indexed in db.known_methods unless *ea* is 0 or
        BADADDR, and appended to owner.methods when *register* is true.
        """
        self.owner = owner
        self.name = name
        self.arg_signature = '?'
        self.mangled_args = '1?'
        self.ea = ea

        if ea != BADADDR and ea != 0:
            db.known_methods[ea] = self

        if register:
            owner.methods.append(self)

    def __str__(self):
        return repr(self)

    def __repr__(self):
        if self.arg_signature is None:
            return "<%s %s::%s>" % (type(self), self.owner.name, self.name)
        else:
            return "<%s %s::%s(%s)>" % (type(self), self.owner.name, self.name, self.arg_signature)

    def is_null(self): return False
    def is_virtual(self): return False

    def rename(self, name):
        """Change the method name and update the IDA name."""
        self.name = name
        self.refresh()

    def refresh(self):
        """Re-apply the mangled IDA name and refresh dependent comments.

        Every marked virtual call inside this function has the comment of
        its target method refreshed.
        """
        MakeName(self.ea, "%s__%s%s" % (self.name, self.owner.mangled_name, self.mangled_args))

        # figure out if any virtual xrefs include us
        # a bit slow, but who cares
        func = idaapi.get_func(self.ea)
        if func is None:
            # not inside a function: no call sites to scan
            return

        call_dict = db.virtual_calls
        done_already = set()

        check = func.startEA
        end = func.endEA
        # NOTE(review): endEA is treated as exclusive here -- confirm.
        while check < end:
            if check in call_dict:
                method = call_dict[check]
                if method not in done_already:
                    method.refresh_comment()
                    done_already.add(method)
            check += 4

    def unlink(self, orphan=False):
        """Detach the method from its class and from the database.

        With *orphan* true the IDA name at the address is cleared as well.
        """
        if self.name == '__ct':
            self.owner.ctor = NullMethod(self.owner)
        elif self.name == '__dt':
            self.owner.dtor = NullMethod(self.owner)

        if self in self.owner.methods:
            self.owner.methods.remove(self)

        self.owner = None
        if self.ea != 0 and self.ea != BADADDR:
            db.known_methods.pop(self.ea, None)

        if orphan:
            MakeName(self.ea, '')

    def set_signature(self, sig):
        """Set the argument list from a C-like string and re-mangle.

        Raises TypeError when the signature is fixed (arg_signature None).
        """
        if self.arg_signature is None:
            raise TypeError("this method has a fixed signature")

        mang = mangle_args(sig)
        self.arg_signature = sig
        self.mangled_args = mang
        self.refresh()


class VirtualMethod(Method):
    """A virtual method, identified by its offset within the vtable."""

    def __init__(self, owner, name, ea, vt_offset):
        Method.__init__(self, owner, name, ea, False)

        self.vt_offset = vt_offset
        self.virtual_calls = []
        self.overrides = []

        owner.vmethods.append(self)
        owner.vmethods_by_offset[vt_offset] = self

    def is_virtual(self): return True

    def unlink(self, orphan=False):
        """Detach this method and, recursively, all of its overrides."""
        owner = self.owner
        # filter instead of remove(): also cleans up duplicate entries that
        # databases written by earlier versions may contain
        owner.vmethods = [m for m in owner.vmethods if m is not self]
        if owner.vmethods_by_offset.get(self.vt_offset) is self:
            del owner.vmethods_by_offset[self.vt_offset]

        # iterate over a copy: each override removes itself from this list
        for m in list(self.overrides):
            m.unlink(True)

        Method.unlink(self, orphan)

    def rename(self, name):
        """Rename this method and propagate the name to its overrides."""
        Method.rename(self, name)
        for o in self.overrides:
            o.rename(name, False)
            o.refresh()

    def set_signature(self, sig):
        """Set the signature and copy it to every override."""
        Method.set_signature(self, sig)
        for o in self.overrides:
            o.copy_signature(self)
            o.refresh()

    def add_xref(self, ea):
        """Record a virtual call to this method at address *ea*."""
        self.virtual_calls.append(ea)
        self.virtual_calls.sort()
        self.refresh_xref(ea)
        self.refresh_comment()
        db.virtual_calls[ea] = self

    def remove_xref(self, ea):
        """Forget the virtual call at *ea* and clear its comment."""
        self.virtual_calls.remove(ea)
        self.refresh_comment()
        MakeComm(ea, '')
        del db.virtual_calls[ea]

    def refresh(self):
        Method.refresh(self)
        self.refresh_comment()
        for ea in self.virtual_calls:
            self.refresh_xref(ea)

    def refresh_xref(self, ea):
        """Write the 'address Class::method(args)' comment at a call site."""
        MakeComm(ea, "%08X %s::%s(%s)" % (self.ea, self.owner.name, self.name, self.arg_signature))

    def get_comment(self):
        """Build the function comment listing overrides and call sites."""
        lines = []
        if len(self.overrides) > 0:
            lines.append('OVERRIDDEN BY:')
            for method in self.overrides:
                lines.append(' - %08X : %s' % (method.ea, method.owner.name))
            lines.append('')

        if len(self.virtual_calls) > 0:
            lines.append('VIRTUAL METHOD CALLS:')
            for ea in self.virtual_calls:
                func = idaapi.get_func(ea)
                name = idaapi.demangle_name(idaapi.get_func_name(func.startEA), 0)
                offset = ea - func.startEA
                lines.append(' - %08X : %s + %X' % (ea, name, offset))
            lines.append('')

        return "\n".join(lines)

    def refresh_comment(self):
        SetFunctionCmt(self.ea, self.get_comment(), 0)


class OverrideMethod(VirtualMethod):
    """A virtual method that overrides one declared in a base class.

    ``base`` is the method directly overridden; ``original`` is the method
    at the root of the override chain.
    """

    def __init__(self, owner, name, ea, vt_offset, base):
        # Validate before registering anything, so that a rejected override
        # leaves no half-initialised entry in the owner's tables.
        if ea == BADADDR or ea == 0:
            raise ValueError("An override method can't be pure!!")

        # VirtualMethod.__init__ already adds us to owner.vmethods and
        # owner.vmethods_by_offset; doing it again would duplicate the entry.
        VirtualMethod.__init__(self, owner, name, ea, vt_offset)

        self.base = base
        self.original = getattr(base, 'original', None) or base
        self.copy_signature(self.original)

        base.overrides.append(self)

    def unlink(self, orphan=False):
        if self in self.base.overrides:
            self.base.overrides.remove(self)

        VirtualMethod.unlink(self, orphan)

    def copy_signature(self, other):
        """Take over *other*'s signature, here and in all overrides."""
        self.arg_signature = other.arg_signature
        self.mangled_args = other.mangled_args

        for m in self.overrides:
            m.copy_signature(other)

    def rename(self, name, delegate_to_original=True):
        """Rename via the original method unless told not to delegate."""
        # a strange little hack
        if delegate_to_original:
            self.original.rename(name)
        else:
            VirtualMethod.rename(self, name)

    def set_signature(self, sig):
        self.original.set_signature(sig)

    def refresh(self):
        VirtualMethod.refresh(self)
        self.base.refresh_comment()

    def get_comment(self):
        return "ORIGINAL: %08X : %s\nBASE: %08X : %s\n\n%s" % \
            (self.original.ea, self.original.owner.name,
             self.base.ea, self.base.owner.name,
             VirtualMethod.get_comment(self))


class NullMethod(Method):
    """Placeholder used while a class has no ctor or dtor assigned."""

    def __init__(self, owner):
        Method.__init__(self, owner, '', BADADDR, False)

    def is_null(self): return True

    def refresh(self): pass

    # same signature as Method.unlink so callers can pass orphan
    def unlink(self, orphan=False): pass


# UTILITY FUNCTIONS
class ClassChooser(Choose2):
    """Single-column chooser listing the names stored in self.clslist."""

    def __init__(self):
        Choose2.__init__(self, 'Choose a class', [['Class', 40]])

    def OnClose(self): pass
    def OnGetLine(self, n): return [self.clslist[n]]
    def OnGetSize(self): return len(self.clslist)


_class_chooser = ClassChooser()


def choose_class(title=None, optional=False):
    """Let the user pick a class; return it, or None.

    With *optional* true a '[[ None ]]' entry is offered first. None is
    returned both for that entry and when the dialog is cancelled.
    """
    names = list(db.classes_by_name.keys())
    base_index = 0
    if optional:
        names.insert(0, '[[ None ]]')
        base_index = 1

    _class_chooser.title = title or 'Choose a class'
    try:
        _class_chooser.deflt = names.index(db.last_class_name) + (2 if optional else 1)
    except ValueError:
        # no previous choice, or it no longer exists
        _class_chooser.deflt = 0
    _class_chooser.clslist = names

    num = _class_chooser.Show(True)

    if num < base_index:
        return None
    else:
        choice = names[num]
        db.last_class_name = choice
        return db.classes_by_name[choice]


def unlink_method_if_exists(ea):
    """Unlink the method registered at *ea*, if there is one."""
    if ea in db.known_methods:
        db.known_methods[ea].unlink()


def get_current_function():
    """Return the start address of the function under the cursor.

    If the cursor is not inside a function, the dword at the cursor is
    tried as a pointer to one. Raises ValueError when neither works.
    """
    ea = idaapi.get_screen_ea()
    func = idaapi.get_func(ea)

    if func is None:
        # try whatever's pointed to by this
        func = idaapi.get_func(Dword(ea))

    if func is None:
        # and if it's still none ...
        Warning('Place the cursor on top of a function.')
        raise ValueError('no function at the cursor')

    return func.startEA


# DO STUFF
@menu_item('Create Class')
def create_class():
    """Ask for a name, a struct and a base class, then create the class.

    When no struct is chosen, one named after the class is looked up or
    created; a newly created struct gets a '_' member for the base class.
    """
    name = idaapi.askstr(idaapi.HIST_IDENT, '', 'Enter a class name')
    if not name:
        # cancelled, or nothing entered
        return

    if name in db.classes_by_name:
        Warning('That name is already used.')
        return

    struct = idaapi.choose_struc('Associate this class with a struct')
    if struct is not None and struct.id in db.classes_by_struct:
        Warning('That struct is already used.')
        return

    base = choose_class('Choose the base class for this one', True)

    if struct is None:
        safe_name = name.replace(':', '_')

        struct = idaapi.get_struc_id(safe_name)
        if struct == BADADDR:
            struct = idaapi.add_struc(BADADDR, safe_name, 0)
            if struct == -1 or struct == BADADDR:
                Warning('Oops.')
                return

            print("Created struct: %s" % safe_name)

            if base is not None:
                AddStrucMember(struct, '_', 0, FF_DATA | FF_STRU, base.struct, GetStrucSize(base.struct))

        else:
            if struct in db.classes_by_struct:
                Warning('That struct is already used.')
                return

            print("Used existing struct: %s" % safe_name)

    Class(name, struct, base)

    print("Done!")
    Refresh()


@menu_item('Set VTable')
def set_vtable():
    """Use the cursor address as the vtable start of a chosen class."""
    cls = choose_class()
    if cls is None: return

    ea = idaapi.get_screen_ea()
    cls.set_vtable(ea)

    print("Set vtable for %s to %08x" % (cls.name, ea))
    Refresh()


@menu_item('Set VTable End')
def set_vtable_end():
    """Use the cursor address as the vtable end of a chosen class."""
    cls = choose_class()
    if cls is None: return

    if cls.vtable is None:
        Warning('This class does not have a virtual table defined.')
        return

    ea = idaapi.get_screen_ea()
    if ea < cls.vtable:
        Warning("The virtual table cannot go backwards. It starts at %08x." % cls.vtable)
        return

    cls.set_vtable_end(ea)

    print("Set vtable end for %s to %08x" % (cls.name, ea))
    Refresh()


@menu_item('Set Ctor')
def set_ctor():
    """Register the current function as the ctor of a chosen class."""
    cls = choose_class()
    if cls is None: return

    ea = get_current_function()
    unlink_method_if_exists(ea)
    cls.ctor = Method(cls, '__ct', ea)
    cls.ctor.refresh()

    print("Set ctor for %s to %08x" % (cls.name, ea))
    Refresh()


@menu_item('Set Dtor')
def set_dtor():
    """Register the current function as the dtor of a chosen class.

    The dtor becomes virtual when its address is found in the vtable.
    """
    cls = choose_class()
    if cls is None: return

    ea = get_current_function()
    unlink_method_if_exists(ea)

    # is it virtual?
    if cls.vtable is None:
        v_offset = None
    else:
        v_offset = cls.vtable_contains(ea)

    cls.install_dtor(ea, v_offset)

    print("Set dtor for %s to %08x" % (cls.name, ea))
    Refresh()


@menu_item('Register as Method', 'Shift+M')
def register_method():
    """Register the current function as a regular method."""
    cls = choose_class()
    if cls is None: return

    ea = get_current_function()

    name = idaapi.askstr(idaapi.HIST_IDENT, 's_%08X' % ea, 'Enter a method name')
    if name is None: return

    # only drop an existing registration once the user has committed
    unlink_method_if_exists(ea)
    method = Method(cls, name, ea)
    method.refresh()

    print("Method %s::%s() created at %08x" % (cls.name, name, ea))
    Refresh()


@menu_item('Set Method Arguments', 'Shift+A')
def set_method_args():
    """Ask for and apply the argument list of the current method."""
    ea = get_current_function()

    try:
        method = db.known_methods[ea]
    except KeyError:
        Warning("No method at this address (%08X)." % ea)
        return

    if method.arg_signature is None:
        Warning("This method's argument signature cannot be changed")
        return

    arg = idaapi.askstr(idaapi.HIST_TYPE, method.arg_signature, 'Enter the arguments')
    if arg is None:
        return

    method.set_signature(arg)

    print("Method %s::%s() set args to %s" % (method.owner.name, method.name, method.mangled_args))
    Refresh()


@menu_item('Register as Virtual Method', 'Shift+V')
def register_virtual_method():
    """Register the current function as a new virtual method or override.

    It becomes an override when a base class already defines a virtual
    method at the same vtable offset.
    """
    cls = choose_class()
    if cls is None: return

    if cls.vtable is None:
        Warning('This class does not have a virtual table defined.')
        return

    ea = get_current_function()

    # try to find it within the vtable
    offset = cls.vtable_contains(ea)

    if offset is None:
        Warning("Could not find this method within the vtable for %s." % cls.name)
        print("Virtual table for %s: %08x .. %08x" % (cls.name, cls.vtable, cls.vtable_end))
        return

    # Unlink only after the vtable check passed, but before the base-class
    # lookup: the existing registration may itself be the method a base
    # class would otherwise report for this offset.
    unlink_method_if_exists(ea)

    # check to make sure it's not defined anywhere in a base class
    base_method = cls.find_inherited_virtual(offset)
    if base_method is None:
        name = idaapi.askstr(idaapi.HIST_IDENT, 'vf%02X' % offset, 'Enter a method name')
        if name is None:
            return
        method = VirtualMethod(cls, name, ea, offset)
    else:
        method = OverrideMethod(cls, base_method.name, ea, offset, base_method)

    method.refresh()

    print("Virtual Method %s::%s() created at %08x with offset %x" % (cls.name, method.name, ea, offset))
    Refresh()


@menu_item('Rename Method', 'Shift+N')
def rename_method():
    """Ask for a new name for the method at the current function."""
    ea = get_current_function()

    if ea in db.known_methods:
        method = db.known_methods[ea]

        name = idaapi.askstr(idaapi.HIST_IDENT, method.name, 'Enter a method name')
        if name is None: return

        method.rename(name)
        Refresh()
    else:
        print("Don't know about a method here")


@menu_item('Unlink Method', 'Shift+U')
def unlink_method():
    """Unlink the method at the current function and clear its name."""
    ea = get_current_function()

    if ea in db.known_methods:
        db.known_methods[ea].unlink(True)
        Refresh()
    else:
        print("Don't know about a method here")


@menu_item('Mark Virtual Call', 'Shift+C')
def mark_virtual_call():
    """Mark the instruction under the cursor as a call to a virtual method.

    The displacement of the second operand is used as the vtable offset and
    looked up in a chosen class and its bases.
    """
    ea = idaapi.get_screen_ea()

    inslen = idaapi.decode_insn(ea)
    if inslen == 0:
        raise ValueError('no instruction at the cursor')

    op = idaapi.cmd.Operands[1]
    if not op:
        raise ValueError('instruction has no second operand')
    if op.type != idaapi.o_displ:
        raise ValueError('second operand is not a displacement')

    chosen_cls = choose_class('Choose the class for this method')
    if chosen_cls is None:
        # chooser cancelled
        return

    offset = op.addr

    method = chosen_cls.find_virtual(offset)[1]
    if method is None:
        Warning("The class %s and its base classes do not have this method defined." % chosen_cls.name)
        return

    if ea in db.virtual_calls:
        db.virtual_calls[ea].remove_xref(ea)

    original = getattr(method, 'original', None) or method
    original.add_xref(ea)
    Refresh()


@menu_item('Unmark Virtual Call', 'Shift+Ctrl+C')
def unmark_virtual_call():
    """Remove the virtual call mark at the cursor, if any."""
    ea = idaapi.get_screen_ea()
    if ea in db.virtual_calls:
        db.virtual_calls[ea].remove_xref(ea)
        Refresh()


@menu_item('Save')
def save_db():
    """Write the database to disk."""
    db.save()
    print('Saved the database.')


@menu_item('Reset DB')
def reset_db():
    """Wipe the database after confirmation."""
    if idaapi.askyn_c(0, 'Are you really sure you want to lose EVERYTHING?') == 1:
        db.clear()
        print('Done. You asked for it ...')


# Mangled codes of the built-in types accepted in argument lists.
MANGLED_TYPES = {
    's32': 'i', 'int': 'i',
    'u32': 'Ui', 'uint': 'Ui', 'unsigned int': 'Ui',
    's16': 's', 'short': 's',
    'u16': 'Us', 'ushort': 'Us', 'unsigned short': 'Us',
    's8': 'c', 'char': 'c',
    'u8': 'Uc', 'uchar': 'Uc', 'unsigned char': 'Uc',
    'bool': 'b',
    'wchar_t': 'w',
    'f32': 'f', 'float': 'f',
    'f64': 'd', 'double': 'd',
    'void': 'v',
}

# Trailing array suffix such as "[16]".
ARRAY_REGEX = re.compile(r'\[(\d+)\]$')


# TYPE DECLARATIONS
def mangle_args(string):
    """Mangle a comma-separated argument list into its 'F...' form.

    Each argument is a type name followed by any number of trailing
    '&', '*', 'const' or '[N]' modifiers. Type names missing from
    MANGLED_TYPES are mangled as class names. An empty list gives 'Fv'.
    """
    pieces = [p.strip() for p in string.split(',')]
    pieces = [p for p in pieces if p]

    if not pieces:
        # ''.split(',') yields [''], so emptiness has to be tested after
        # dropping blank entries
        return 'Fv'

    # 'CF' for const methods
    output = ['F']

    for p in pieces:
        type_name = p
        prefix = ''

        # Peel modifiers off the right-hand end, one per iteration.
        # NOTE(review): each code is prepended, so the right-most modifier
        # of the input ends up right-most in the prefix -- confirm this is
        # the intended reading of multi-modifier types.
        while True:
            if type_name.endswith('&'):
                type_name = type_name[:-1].strip()
                prefix = 'R' + prefix
            elif type_name.endswith('*'):
                type_name = type_name[:-1].strip()
                prefix = 'P' + prefix
            elif type_name.endswith('const'):
                type_name = type_name[:-5].strip()
                prefix = 'C' + prefix
            else:
                match = ARRAY_REGEX.search(type_name)
                if match:
                    type_name = type_name[:type_name.rindex('[')].strip()
                    prefix = 'A%s_%s' % (match.group(1), prefix)
                else:
                    # nothing, we must be done
                    break

        try:
            type_def = MANGLED_TYPES[type_name]
        except KeyError:
            type_def = Class.mangle(type_name)

        output.append(prefix + type_def)

    return ''.join(output)


# FINAL SETUP

setup_menus()
setup_db()

print('Done!')