summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README75
-rw-r--r--classy.py886
2 files changed, 961 insertions, 0 deletions
diff --git a/README b/README
new file mode 100644
index 0000000..39b7b88
--- /dev/null
+++ b/README
@@ -0,0 +1,75 @@
+Classy!
+
+IDA script for keeping track of RE info related to C++ classes (specifically,
+ CodeWarrior PPC) and keeping things consistent throughout the database.
+
+That's what makes IDA better than objdump anyway, right? :p
+
+Started by Treeki: 16th October 2011
+Developed using Python 2.6 and IDA 6.1.
+
+
+
+Supported so far:
+- Save and reset database
+- Classes:
+ - Create a class using an IDA struct
+ - Base classes are handled using a struct field named _
+ - Struct is automatically created if one is not chosen - if this is used,
+ Classy adds the _ field for you
+- Virtual Tables:
+ - Set a VTable for a class. End is automatically detected
+ - If the auto detection fails (pure virtuals, for example) you can
+ manually override it
+- CTors/DTors:
+ - Set a ctor and dtor for a class (TODO: Multiple ctors)
+ - Dtors automatically become virtual if a pointer to them is found in the
+ vtable, even if the dtor was registered before the vtable was set
+- Methods:
+ - Register regular and virtual methods
+ - Mark virtual calls by highlighting the "lwz r12, ..." line, pressing
+ Shift+C and choosing a class
+ - Xrefs to each virtual method are shown at the top of the method body
+ - Set method arguments including pointers, arrays, refs and consts
+ - Arguments are automatically mangled so they are shown in IDA names
+ - Shift+V automatically detects whether a method is a new virtual one or
+ an override
+ - Overrides are shown at the top of the original method body
+ - Renames and argument list changes are propagated from original methods
+ to overrides (Though not vice versa yet)
+
+That's all for now, I think...
+
+
+
+
+Current Todo:
+- Only show classes that fit the criteria in the chooser for Shift+C
+- Support PTMFs, too
+- Automatically create virtual methods using the vtable.
+- Named arguments to methods
+- Return types for methods
+- Const methods
+- Improve vtable end detection heuristics
+- Support pure virtuals.
+- Handle IDA chooser history/defaults correctly everywhere
+- Better keybindings
+- Rename class menu item
+- Netnodes
+- Override dtors
+- Sanity check to see if a virtual method might exist in a base class
+- Confirm unlinking of methods
+- Don't let you create more than one method with the same name
+- choose_class differentiation between "no class" and Cancel
+- Auto usage of struct if a named one exists
+- Remove virtual xrefs
+
+Caveats:
+- Multiple inheritance will not be supported in the near future.
+- Probably won't work as well with non-CodeWarrior stuff.
+
+Notes:
+- When autocreating virtual methods, start with the base class!!
+ Otherwise, you might end up defining an override as a virtual
+ method. And that would be bad.
+
diff --git a/classy.py b/classy.py
new file mode 100644
index 0000000..b032e94
--- /dev/null
+++ b/classy.py
@@ -0,0 +1,886 @@
+# Classy!
+
+# IDA script for keeping track of RE info related to C++ classes (specifically,
+# CodeWarrior PPC) and keeping things consistent throughout the database.
+
+# That's what makes IDA better than objdump anyway, right? :p
+
+# Started by Treeki: 16th October 2011
+# Developed using Python 2.6 and IDA 6.1.
+
+import idaapi, cPickle, re
+
+print('Loading Classy')
+
+
+# DATABASE
+#DB_PATH = '/home/me/Dev/Reversing/nsmb2.cy'
+DB_PATH = 'Z:/stuff/Dev/Reversing/nsmb2.cy'
+DB_VERSION = 1
+
+class ClassyDatabase(dict):
+ def __init__(self, path):
+ try:
+ handle = open(path, 'rb')
+ db = cPickle.load(handle)
+ self.update(db)
+
+ if 'version' not in db:
+ # old DB
+ self.version = 1
+ except IOError:
+ pass
+ except EOFError:
+ pass
+
+ self.path = path
+
+ if 'version' not in self:
+ # totally new DB
+ self.initialise()
+ else:
+ self.upgrade()
+
+ def __getattr__(self, key):
+ try:
+ return self[key]
+ except KeyError:
+ default = self.default_for(key)
+ self[key] = default
+ return self[key]
+
+ def __setattr__(self, key, value):
+ self[key] = value
+
+ HASH_DEFAULTS = [
+ 'classes_by_name', 'classes_by_struct',
+ 'known_methods', 'virtual_calls',
+ ]
+
+ NONE_DEFAULTS = [
+ 'last_class_name',
+ ]
+
+ def default_for(self, key):
+ if key in self.HASH_DEFAULTS:
+ return {}
+ elif key in self.NONE_DEFAULTS:
+ return None
+ else:
+ raise AttributeError
+
+ def clear(self):
+ path = self.path
+ dict.clear(self)
+ self.path = path
+ self.initialise()
+
+ def initialise(self):
+ self.version = DB_VERSION
+
+ def upgrade(self):
+ pass
+
+ def save(self):
+ cPickle.dump(self, open(self.path, 'wb'))
+
+def setup_db():
+ global db
+ db = ClassyDatabase(DB_PATH)
+
+# MENUS
+cy_menu_defs = []
+
+def setup_menus():
+ global cy_menus
+
+ try:
+ cy_menus
+ for item in cy_menus.values():
+ idaapi.del_menu_item(item)
+ except:
+ pass
+
+ cy_menus = {}
+ cy_menus['Classy'] = idaapi.add_menu_item(
+ 'Edit/', 'Classy', '', 0, about_classy, None)
+
+ for name, sequence, callback in cy_menu_defs:
+ cy_menus[name] = idaapi.add_menu_item(
+ 'Edit/Classy', "CY: %s" % name, sequence,
+ 0x04000000, # SETMENU_CTXIDA
+ callback, None)
+
+
+def menu_item(name, sequence=''):
+ def custom_decorator(func):
+ cy_menu_defs.append((name, sequence, func))
+
+ def wrapper():
+ try:
+ func()
+ except StandardError, e:
+ Warning('Crap.')
+ print e
+
+ return wrapper
+ return custom_decorator
+
+
+def about_classy():
+ print('Classy!')
+
+
+# CLASSES
+class Class(object):
+ def __init__(self, name, struct, base):
+ if type(struct) == idaapi.struc_t:
+ struct = struct.id
+
+ self.name = name
+ self.mangled_name = self.mangle(name)
+ self.struct = struct
+ self.methods = []
+ self.vmethods = []
+ self.vmethods_by_offset = {}
+ self.vtable = None
+ self.base = base
+
+ self.ctor = NullMethod(self)
+ self.dtor = NullMethod(self)
+
+ db.classes_by_name[name] = self
+ db.classes_by_struct[struct] = self
+
+ @staticmethod
+ def mangle(name):
+ components = name.split('::')
+ if len(components) > 1:
+ begin = "Q%d" % len(components)
+ m_comps = ["%d%s" % (len(n), n) for n in components]
+ return begin + ''.join(m_comps)
+ else:
+ return "%d%s" % (len(name), name)
+
+ def rename(self, name):
+ old_name = self.name
+
+ del db.classes_by_name[old_name]
+ db.classes_by_name[name] = self
+
+ self.name = name
+ self.mangled_name = self.mangle(name)
+
+ self.refresh()
+
+ def refresh(self):
+ self.ctor.refresh()
+ self.dtor.refresh()
+
+ for m in self.methods:
+ m.refresh()
+ for m in self.vmethods:
+ m.refresh()
+
+ def set_vtable(self, ea):
+ self.vtable = ea
+ self.find_vtable_end()
+ self.search_for_dtor()
+
+ def set_vtable_end(self, ea):
+ self.vtable_end = ea
+ self.search_for_dtor()
+
+ def find_vtable_end(self):
+ check = self.vtable + 4
+
+ while Dword(check + 4) != 0:
+ check += 4
+
+ print('Inferred vtable as %08x .. %08x' % (self.vtable, check))
+ self.vtable_end = check
+
+ def search_for_dtor(self):
+ if self.dtor.is_null(): return
+ if self.dtor.is_virtual(): return
+
+ ea = self.dtor.ea
+ offset = self.vtable_contains(ea)
+ if offset:
+ self.dtor.unlink()
+
+ # TODO: make this bit more DRY
+ base, base_method = (self.base and self.base.find_virtual(offset)) or (None, None)
+ if base_method is None:
+ self.dtor = VirtualMethod(self, '__dt', ea, offset)
+ else:
+ self.dtor = OverrideMethod(self, '__dt', ea, offset, base_method)
+
+ self.dtor.mangled_args = 'Fv'
+ self.dtor.arg_signature = None # cannot be changed!
+ self.dtor.refresh()
+
+ def iter_vtable(self):
+ check = self.vtable
+ end = self.vtable_end
+
+ while check <= end:
+ yield (check, Dword(check))
+ check += 4
+
+ def detect_overrides(self):
+ offset = 0
+
+ for addr, value in self.iter_vtable():
+ addr
+ offset += 4
+
+
+ def find_virtual(self, offset):
+ try:
+ return (self, self.vmethods_by_offset[offset])
+ except KeyError:
+ if self.base is None:
+ return (self, None)
+ else:
+ return self.base.find_virtual(offset)
+
+ def vtable_contains(self, ea):
+ for addr, value in self.iter_vtable():
+ if value == ea:
+ return addr - self.vtable
+
+ return None
+
+ def generate_cpp_definition(self):
+ contents = []
+
+ vtable_found = False
+
+ sid = self.struct
+ s = idaapi.get_struc(sid)
+
+ comments = (idaapi.get_struc_cmt(sid, 0), idaapi.get_struc_cmt(sid, 1))
+ for cmt in comments:
+ if len(cmt) > 0: contents.append("\t/* %s */" % cmt)
+
+ gen_offset = 0
+
+ # note: structs should always include their base class
+ while ida_offset <= end_at:
+ m = idaapi.get_member(s, ida_offset)
+ mid = m.id
+
+ comments = (idaapi.get_member_cmt(mid, 0), idaapi.get_member_cmt(mid, 1))
+ for cmt in comments:
+ if len(cmt) > 0: contents.append("\t/* %s */" % cmt)
+
+ # TODO: convert to swig api, figure out how the fuck I'll do this
+
+ ida_offset = idaapi.get_struc_next_offset(s, ida_offset)
+
+ base_info = (self.base and (" : public %s" % self.base.name)) or ''
+ return "class %s%s {\n%s\n};" % (self.name, base_info, "\n".join(contents))
+
+
+class Method(object):
+ def __init__(self, owner, name, ea, register=True):
+ self.owner = owner
+ self.name = name
+ self.arg_signature = '?'
+ self.mangled_args = '1?'
+ self.ea = ea
+
+ if ea != BADADDR and ea != 0:
+ db.known_methods[ea] = self
+
+ if register:
+ owner.methods.append(self)
+
+ def __str__(self):
+ return repr(self)
+ def __repr__(self):
+ if self.arg_signature is None:
+ return "<%s %s::%s>" % (type(self), self.owner.name, self.name)
+ else:
+ return "<%s %s::%s(%s)>" % (type(self), self.owner.name, self.name, self.arg_signature)
+
+ def is_null(self): return False
+ def is_virtual(self): return False
+
+ def rename(self, name):
+ self.name = name
+ self.refresh()
+
+ def refresh(self):
+ MakeName(self.ea, "%s__%s%s" % (self.name, self.owner.mangled_name, self.mangled_args))
+
+ # figure out if any virtual xrefs include us
+ # a bit slow, but who cares
+ call_dict = db.virtual_calls
+ func = idaapi.get_func(self.ea)
+ check = func.startEA
+ end = func.endEA
+
+ done_already = {}
+
+ while check <= end:
+ if check in call_dict:
+ method = call_dict[check]
+ if method not in done_already:
+ method.refresh_comment()
+ done_already[method] = True
+ check += 4
+
+ def unlink(self, orphan=False):
+ if self.name == '__ct':
+ self.owner.ctor = NullMethod(self.owner)
+ elif self.name == '__dt':
+ self.owner.dtor = NullMethod(self.owner)
+
+ if self in self.owner.methods:
+ self.owner.methods.remove(self)
+
+ self.owner = None
+ if self.ea != 0 and self.ea != BADADDR:
+ del db.known_methods[self.ea]
+
+ if orphan:
+ MakeName(self.ea, '')
+
+ def set_signature(self, sig):
+ if self.arg_signature is None:
+ raise TypeError, "this method has a fixed signature"
+
+ mang = mangle_args(sig)
+ self.arg_signature = sig
+ self.mangled_args = mang
+ self.refresh()
+
+
+class VirtualMethod(Method):
+ def __init__(self, owner, name, ea, vt_offset):
+ Method.__init__(self, owner, name, ea, False)
+
+ self.vt_offset = vt_offset
+ self.virtual_calls = []
+ self.overrides = []
+
+ owner.vmethods.append(self)
+ owner.vmethods_by_offset[vt_offset] = self
+
+ def is_virtual(self): return True
+
+ def unlink(self, orphan=False):
+ self.owner.vmethods.remove(self)
+ del self.owner.vmethods_by_offset[self.vt_offset]
+
+ for m in self.overrides:
+ m.unlink(True)
+
+ Method.unlink(self, orphan)
+
+ def rename(self, name):
+ Method.rename(self, name)
+ for o in self.overrides:
+ o.rename(name, False)
+ o.refresh()
+
+ def set_signature(self, sig):
+ Method.set_signature(self, sig)
+ for o in self.overrides:
+ o.copy_signature(self)
+ o.refresh()
+
+ def add_xref(self, ea):
+ self.virtual_calls.append(ea)
+ self.virtual_calls.sort()
+ self.refresh_xref(ea)
+ self.refresh_comment()
+ db.virtual_calls[ea] = self
+
+ def remove_xref(self, ea):
+ self.virtual_calls.remove(ea)
+ self.refresh_comment()
+ MakeComm(ea, '')
+ del db.virtual_calls[ea]
+
+ def refresh(self):
+ Method.refresh(self)
+ self.refresh_comment()
+ for ea in self.virtual_calls:
+ self.refresh_xref(ea)
+
+ def refresh_xref(self, ea):
+ MakeComm(ea, "%08X %s::%s(%s)" % (self.ea, self.owner.name, self.name, self.arg_signature))
+
+ def get_comment(self):
+ lines = []
+ if len(self.overrides) > 0:
+ lines.append('OVERRIDDEN BY:')
+ for method in self.overrides:
+ lines.append(' - %08X : %s' % (method.ea, method.owner.name))
+ lines.append('')
+
+ if len(self.virtual_calls) > 0:
+ lines.append('VIRTUAL METHOD CALLS:')
+ for ea in self.virtual_calls:
+ func = idaapi.get_func(ea)
+ name = idaapi.demangle_name(idaapi.get_func_name(func.startEA), 0)
+ offset = ea - func.startEA
+ lines.append(' - %08X : %s + %X' % (ea, name, offset))
+ lines.append('')
+
+ return "\n".join(lines)
+
+ def refresh_comment(self):
+ SetFunctionCmt(self.ea, self.get_comment(), 0)
+
+
+class OverrideMethod(VirtualMethod):
+ def __init__(self, owner, name, ea, vt_offset, base):
+ VirtualMethod.__init__(self, owner, name, ea, vt_offset)
+
+ if ea == BADADDR || ea == 0:
+ raise ValueError("An override method can't be pure!!")
+
+ self.base = base
+ self.original = (hasattr(base, 'original') and base.original) or base
+ self.copy_signature(self.original)
+
+ owner.vmethods.append(self)
+ owner.vmethods_by_offset[vt_offset] = self
+
+ base.overrides.append(self)
+
+ def unlink(self, orphan=False):
+ if self in self.base.overrides:
+ self.base.overrides.remove(self)
+
+ VirtualMethod.unlink(self, orphan)
+
+ def copy_signature(self, other):
+ self.arg_signature = other.arg_signature
+ self.mangled_args = other.mangled_args
+
+ for m in self.overrides:
+ m.copy_signature(other)
+
+ def rename(self, name, delegate_to_original=True):
+ # a strange little hack
+ if delegate_to_original:
+ self.original.rename(name)
+ else:
+ VirtualMethod.rename(self, name)
+
+ def set_signature(self, sig):
+ self.original.set_signature(sig)
+
+ def refresh(self):
+ VirtualMethod.refresh(self)
+ self.base.refresh_comment()
+
+ def get_comment(self):
+ return "ORIGINAL: %08X : %s\nBASE: %08X : %s\n\n%s" % \
+ (self.original.ea, self.original.owner.name,
+ self.base.ea, self.base.owner.name,
+ VirtualMethod.get_comment(self))
+
+
+class NullMethod(Method):
+ def __init__(self, owner):
+ Method.__init__(self, owner, '', BADADDR, False)
+
+ def is_null(self): return True
+
+ def refresh(self): pass
+ def unlink(self): pass
+
+
+# UTILITY FUNCTIONS
+class ClassChooser(Choose2):
+ def __init__(self):
+ Choose2.__init__(self, 'Choose a class', [['Class', 40]])
+
+ def OnClose(self): pass
+ def OnGetLine(self, n): return [self.clslist[n]]
+ def OnGetSize(self): return len(self.clslist)
+
+
+_class_chooser = ClassChooser()
+
+def choose_class(title=None, optional=False):
+ names = db.classes_by_name.keys()
+ base_index = 0
+ if optional:
+ names.insert(0, '[[ None ]]')
+ base_index = 1
+
+ _class_chooser.title = title or 'Choose a class'
+ try:
+ _class_chooser.deflt = names.index(db.last_class_name) + (2 if optional else 1)
+ except StandardError:
+ _class_chooser.deflt = 0
+ _class_chooser.clslist = names
+
+ num = _class_chooser.Show(True)
+
+ if num < base_index:
+ return None
+ else:
+ choice = names[num]
+ db.last_class_name = choice
+ return db.classes_by_name[choice]
+
+
+def unlink_method_if_exists(ea):
+ if ea in db.known_methods:
+ db.known_methods[ea].unlink()
+
+
+def get_current_function():
+ ea = idaapi.get_screen_ea()
+ func = idaapi.get_func(ea)
+
+ if func is None:
+ # try whatever's pointed to by this
+ func = idaapi.get_func(Dword(ea))
+
+ if func is None:
+ # and if it's still none ...
+ Warning('Place the cursor on top of a function.')
+ raise ValueError
+
+ return func.startEA
+
+
+# DO STUFF
+@menu_item('Create Class')
+def create_class():
+ name = idaapi.askstr(idaapi.HIST_IDENT, '', 'Enter a class name')
+ if name in db.classes_by_name:
+ Warning('That name is already used.')
+ return
+
+ if name is None:
+ return
+
+ struct = idaapi.choose_struc('Associate this class with a struct')
+ if struct is not None and struct.id in db.classes_by_struct:
+ Warning('That struct is already used.')
+ return
+
+ base = choose_class('Choose the base class for this one', True)
+
+ if struct is None:
+ safe_name = name.replace(':', '_')
+
+ struct = idaapi.get_struc_id(safe_name)
+ if struct == BADADDR:
+ struct = idaapi.add_struc(BADADDR, safe_name, 0)
+ if struct == -1:
+ Warning('Oops.')
+ return
+
+ print("Created struct: %s" % safe_name)
+
+ if base != None:
+ AddStrucMember(struct, '_', 0, FF_DATA | FF_STRU, base.struct, GetStrucSize(base.struct))
+
+ else:
+ print("Used existing struct: %s" % safe_name)
+
+ cls = Class(name, struct, base)
+
+ print("Done!")
+ Refresh()
+
+
+@menu_item('Set VTable')
+def set_vtable():
+ cls = choose_class()
+ if cls is None: return
+
+ ea = idaapi.get_screen_ea()
+ cls.set_vtable(ea)
+
+ print("Set vtable for %s to %08x" % (cls.name, ea))
+ Refresh()
+
+
+@menu_item('Set VTable End')
+def set_vtable_end():
+ cls = choose_class()
+ if cls is None: return
+
+ ea = idaapi.get_screen_ea()
+ if ea < cls.vtable:
+ Warning("The virtual table cannot go backwards. It starts at %08x." % cls.vtable)
+ return
+
+ cls.set_vtable_end(ea)
+
+ print("Set vtable end for %s to %08x" % (cls.name, ea))
+ Refresh()
+
+
+@menu_item('Set Ctor')
+def set_ctor():
+ cls = choose_class()
+ if cls is None: return
+
+ ea = get_current_function()
+ unlink_method_if_exists(ea)
+ cls.ctor = Method(cls, '__ct', ea)
+ cls.ctor.refresh()
+
+ print("Set ctor for %s to %08x" % (cls.name, ea))
+ Refresh()
+
+
+@menu_item('Set Dtor')
+def set_dtor():
+ cls = choose_class()
+ if cls is None: return
+
+ ea = get_current_function()
+ unlink_method_if_exists(ea)
+
+ # is it virtual?
+ v_offset = (cls.vtable and cls.vtable_contains(ea))
+ if v_offset:
+ base, base_method = (cls.base and cls.base.find_virtual(v_offset)) or (None, None)
+ if base_method is None:
+ cls.dtor = VirtualMethod(cls, '__dt', ea, v_offset)
+ else:
+ cls.dtor = OverrideMethod(cls, '__dt', ea, v_offset, base_method)
+ else:
+ cls.dtor = Method(cls, '__dt', ea)
+
+ cls.dtor.mangled_args = 'Fv'
+ cls.dtor.arg_signature = None # cannot be changed!
+ cls.dtor.refresh()
+
+ print("Set dtor for %s to %08x" % (cls.name, ea))
+ Refresh()
+
+
+@menu_item('Register as Method', 'Shift+M')
+def register_method():
+ cls = choose_class()
+ if cls is None: return
+
+ ea = get_current_function()
+ unlink_method_if_exists(ea)
+
+ name = idaapi.askstr(idaapi.HIST_IDENT, 's_%08X' % ea, 'Enter a method name')
+ if name is None: return
+ method = Method(cls, name, ea)
+ method.refresh()
+
+ print("Method %s::%s() created at %08x" % (cls.name, name, ea))
+ Refresh()
+
+
+@menu_item('Set Method Arguments', 'Shift+A')
+def set_method_args():
+ ea = get_current_function()
+
+ try:
+ method = db.known_methods[ea]
+ except KeyError:
+ Warning("No method at this address (%08X)." % ea)
+ return
+
+ if method.arg_signature is None:
+ Warning("This method's argument signature cannot be changed")
+ return
+
+ arg = idaapi.askstr(idaapi.HIST_TYPE, method.arg_signature, 'Enter the arguments')
+ if arg is None:
+ return
+
+ method.set_signature(arg)
+
+ print("Method %s::%s() set args to %s" % (method.owner.name, method.name, method.mangled_args))
+ Refresh()
+
+
+@menu_item('Register as Virtual Method', 'Shift+V')
+def register_virtual_method():
+ cls = choose_class()
+ if cls is None: return
+
+ if cls.vtable is None:
+ Warning('This class does not have a virtual table defined.')
+ return
+
+ ea = get_current_function()
+ unlink_method_if_exists(ea)
+
+ # try to find it within the vtable
+ offset = cls.vtable_contains(ea)
+
+ if offset is None:
+ Warning("Could not find this method within the vtable for %s." % cls.name)
+ print("Virtual table for %s: %08x .. %08x" % (cls.name, cls.vtable, cls.vtable_end))
+ return
+
+ # check to make sure it's not defined anywhere in a base class
+ base, base_method = (cls.base and cls.base.find_virtual(offset)) or (None, None)
+ if base_method is None:
+ name = idaapi.askstr(idaapi.HIST_IDENT, 'vf%02X' % offset, 'Enter a method name')
+ if name is None:
+ return
+ method = VirtualMethod(cls, name, ea, offset)
+ else:
+ method = OverrideMethod(cls, base_method.name, ea, offset, base_method)
+
+ method.refresh()
+
+ print("Virtual Method %s::%s() created at %08x with offset %x" % (cls.name, method.name, ea, offset))
+ Refresh()
+
+
+@menu_item('Rename Method', 'Shift+N')
+def rename_method():
+ ea = get_current_function()
+
+ if ea in db.known_methods:
+ method = db.known_methods[ea]
+
+ name = idaapi.askstr(idaapi.HIST_IDENT, method.name, 'Enter a method name')
+ if name is None: return
+
+ method.rename(name)
+ Refresh()
+ else:
+ print("Don't know about a method here")
+
+
+@menu_item('Unlink Method', 'Shift+U')
+def unlink_method():
+ ea = get_current_function()
+
+ if ea in db.known_methods:
+ db.known_methods[ea].unlink(True)
+ Refresh()
+ else:
+ print("Don't know about a method here")
+
+
+@menu_item('Mark Virtual Call', 'Shift+C')
+def mark_virtual_call():
+ ea = idaapi.get_screen_ea()
+
+ inslen = idaapi.decode_insn(ea)
+ if inslen == 0: raise ValueError
+
+ op = idaapi.cmd.Operands[1]
+ if not op: raise ValueError
+ if op.type != idaapi.o_displ: raise ValueError
+
+ chosen_cls = choose_class('Choose the class for this method')
+ offset = op.addr
+
+ cls, method = chosen_cls.find_virtual(offset)
+ if method is None:
+ Warning("The class %s and its base classes do not have this method defined." % chosen_cls.name)
+ return
+
+ if ea in db.virtual_calls:
+ db.virtual_calls[ea].remove_xref(ea)
+
+ original = (hasattr(method, 'original') and method.original) or method
+ original.add_xref(ea)
+ Refresh()
+
+
+@menu_item('Unmark Virtual Call', 'Shift+Ctrl+C')
+def unmark_virtual_call():
+ ea = idaapi.get_screen_ea()
+ if ea in db.virtual_calls:
+ db.virtual_calls[ea].remove_xref(ea)
+ Refresh()
+
+
+@menu_item('Save')
+def save_db():
+ db.save()
+ print('Saved the database.')
+
+
+@menu_item('Reset DB')
+def reset_db():
+ if idaapi.askyn_c(0, 'Are you really sure you want to lose EVERYTHING?') == 1:
+ db.clear()
+ print('Done. You asked for it ...')
+
+
+MANGLED_TYPES = {
+ 's32': 'i', 'int': 'i',
+ 'u32': 'Ui', 'uint': 'Ui', 'unsigned int': 'Ui',
+ 's16': 's', 'short': 's',
+ 'u16': 'Us', 'ushort': 'Us', 'unsigned short': 'Us',
+ 's8': 'c', 'char': 'c',
+ 'u8': 'Uc', 'uchar': 'Uc', 'unsigned char': 'Uc',
+ 'bool': 'b',
+ 'wchar_t': 'w',
+ 'f32': 'f', 'float': 'f',
+ 'f64': 'd', 'double': 'd',
+ 'void': 'v',
+ }
+
+ARRAY_REGEX = re.compile(r'\[(\d+)\]$')
+
+# TYPE DECLARATIONS
+def mangle_args(string):
+ pieces = [p.strip() for p in string.split(',')]
+
+ # 'CF' for const methods
+ output = ['F']
+
+ for p in pieces:
+ type_name = p
+ prefix, postfix = '', ''
+
+ while True:
+ if type_name.endswith('&'):
+ type_name = type_name[:-1].strip()
+ prefix = 'R' + prefix
+ elif type_name.endswith('*'):
+ type_name = type_name[:-1].strip()
+ prefix = 'P' + prefix
+ elif type_name.endswith('const'):
+ type_name = type_name[:-5].strip()
+ prefix = 'C' + prefix
+ else:
+ match = ARRAY_REGEX.search(type_name)
+ if match:
+ type_name = type_name[:type_name.rindex('[')].strip()
+ prefix = 'A%s_%s' % (match.group(1), prefix)
+ else:
+ # nothing, we must be done
+ break
+
+ try:
+ type_def = MANGLED_TYPES[type_name]
+ except KeyError:
+ type_def = Class.mangle(type_name)
+
+ output.append(prefix + type_def + postfix)
+
+ if len(pieces) == 0:
+ output.append('v')
+
+ return ''.join(output)
+
+
+
+# FINAL SETUP
+
+setup_menus()
+setup_db()
+
+print('Done!')
+