Viewing file: audit-tests.py (13.19 KB) -rw-r--r-- Select action/file-type: (+) | (+) | (+) | Code (+) | Session (+) | (+) | SDB (+) | (+) | (+) | (+) | (+) | (+) |
"""This script contains the actual auditing tests.
It should not be imported directly, but should be run by the test_audit module with arguments identifying each test.
"""
import contextlib
import os
import sys
class TestHook:
    """Used in standard hook tests to collect any logged events.

    Should be used in a with block to ensure that it has no impact after the
    test completes.  Entering the block registers the instance with
    sys.addaudithook(); leaving it calls close(), which does not unregister
    the hook but makes every later call ignore its event.
    """

    def __init__(self, raise_on_events=None, exc_type=RuntimeError):
        """Create a hook.

        raise_on_events: a single event name, or a collection of event names,
            for which __call__ raises *exc_type* after recording the event.
        exc_type: the exception class raised for those events.
        """
        events = raise_on_events or ()
        # A bare string is one event name.  Left as a string, the membership
        # test in __call__ would do substring matching ("sys" would match
        # "sys.addaudithook").
        if isinstance(events, str):
            events = (events,)
        self.raise_on_events = events
        self.exc_type = exc_type
        # (event, args) pairs, in the order they were received.
        self.seen = []
        self.closed = False

    def __enter__(self, *a):
        sys.addaudithook(self)
        return self

    def __exit__(self, *a):
        self.close()

    def close(self):
        """Stop recording events (the hook itself stays registered)."""
        self.closed = True

    @property
    def seen_events(self):
        """Names of the recorded events, in order."""
        return [i[0] for i in self.seen]

    def __call__(self, event, args):
        """Audit hook entry point: record the event, raising if configured."""
        if self.closed:
            return
        self.seen.append((event, args))
        if event in self.raise_on_events:
            raise self.exc_type("saw event " + event)
# Minimal assertion helpers; this script does not run under unittest.
def assertEqual(x, y):
    """Raise AssertionError when ``x != y`` is true."""
    differs = x != y
    if differs:
        raise AssertionError(f"{x!r} should equal {y!r}")
def assertIn(el, series):
    """Raise AssertionError unless *el* is a member of *series*."""
    found = el in series
    if not found:
        raise AssertionError(f"{el!r} should be in {series!r}")
def assertNotIn(el, series):
    """Raise AssertionError if *el* is a member of *series*."""
    found = el in series
    if found:
        raise AssertionError(f"{el!r} should not be in {series!r}")
def assertSequenceEqual(x, y):
    """Raise AssertionError unless *x* and *y* have equal length and items.

    Items are compared pairwise with ``!=``, so e.g. a list and a tuple with
    the same elements are considered equal.
    """
    mismatch = len(x) != len(y) or any(left != right for left, right in zip(x, y))
    if mismatch:
        raise AssertionError(f"{x!r} should equal {y!r}")
@contextlib.contextmanager
def assertRaises(ex_type):
    """Context manager asserting that its body raises exactly *ex_type*.

    The exception type must match exactly (``type(ex) is ex_type``); a
    subclass does not count.  A matching exception is suppressed.  Raises
    AssertionError if the body raises nothing or raises another type; an
    AssertionError raised by the body itself is always propagated.

    Failures are raised explicitly rather than via ``assert`` so the check
    still works when Python runs with -O.
    """
    try:
        yield
    except AssertionError:
        # Never swallow a failing assertion from inside the block.
        raise
    except BaseException as ex:
        if type(ex) is not ex_type:
            raise AssertionError(f"{ex} should be {ex_type}") from ex
    else:
        raise AssertionError(f"expected {ex_type}")
def test_basic():
    """A single sys.audit() call reaches the hook with its name and arguments."""
    with TestHook() as hook:
        sys.audit("test_event", 1, 2, 3)
        first_event, first_args = hook.seen[0]
        assertEqual(first_event, "test_event")
        assertEqual(first_args, (1, 2, 3))
def test_block_add_hook():
    # The outer hook raises RuntimeError on "sys.addaudithook", which must
    # stop the inner hook from being registered without the error escaping.
    with TestHook(raise_on_events="sys.addaudithook") as outer, TestHook() as inner:
        sys.audit("test_event")
        assertIn("test_event", outer.seen_events)
        assertNotIn("test_event", inner.seen_events)
def test_block_add_hook_baseexception():
    # Unlike an ordinary exception, a BaseException raised by a hook on
    # "sys.addaudithook" propagates out of the second hook's registration.
    with assertRaises(BaseException), TestHook(
        raise_on_events="sys.addaudithook", exc_type=BaseException
    ), TestHook():
        pass
def test_marshal():
    """Check the audit events raised by marshal.dumps/loads/dump/load."""
    import marshal

    obj = ("a", "b", "c", 1, 2, 3)
    # Serialized before the hook exists; it is the expected argument of the
    # single "marshal.loads" event below.
    expected_payload = marshal.dumps(obj)
    filename = "test-marshal.bin"

    with TestHook() as hook:
        assertEqual(obj, marshal.loads(marshal.dumps(obj)))

        try:
            with open(filename, "wb") as out:
                marshal.dump(obj, out)
            with open(filename, "rb") as inp:
                assertEqual(obj, marshal.load(inp))
        finally:
            os.unlink(filename)

    # One dumps() call plus one dump() call: both are expected to be
    # reported as "marshal.dumps".
    dumps_args = [(args[0], args[1]) for event, args in hook.seen if event == "marshal.dumps"]
    assertSequenceEqual(dumps_args, [(obj, marshal.version)] * 2)

    loads_args = [args[0] for event, args in hook.seen if event == "marshal.loads"]
    assertSequenceEqual(loads_args, [expected_payload])

    load_events = [event for event, _ in hook.seen if event == "marshal.load"]
    assertSequenceEqual(load_events, ["marshal.load"])
def test_pickle():
    """A hook raising on "pickle.find_class" must block pickles that load globals."""
    import pickle

    class PicklePrint:
        # Reduces to a call of the global ``str``, so unpickling it requires
        # a global lookup.
        def __reduce_ex__(self, p):
            return str, ("Pwned!",)

    with_global = pickle.dumps(PicklePrint())
    without_global = pickle.dumps(("a", "b", "c", 1, 2, 3))

    # Sanity check before any hook exists: the "malicious" pickle loads.
    assertEqual("Pwned!", pickle.loads(with_global))

    with TestHook(raise_on_events="pickle.find_class"):
        # Loading a global is rejected by the hook ...
        with assertRaises(RuntimeError):
            pickle.loads(with_global)
        # ... while a pickle that references no globals still loads.
        pickle.loads(without_global)
def test_monkeypatch():
    """Check which class/instance mutations raise "object.__setattr__"."""
    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    instance = A()

    with TestHook() as hook:
        # Catch name changes
        C.__name__ = "X"
        # Catch type changes
        C.__bases__ = (B,)
        # Ensure bypassing __setattr__ is still caught
        type.__dict__["__bases__"].__set__(C, (B,))
        # Attribute replacement and addition: the expected list below
        # contains no events for these two assignments.
        C.__init__ = B.__init__
        C.new_attr = 123
        # Catch class changes
        instance.__class__ = B

    setattr_calls = [
        (args[0], args[1]) for event, args in hook.seen if event == "object.__setattr__"
    ]
    assertSequenceEqual(
        [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (instance, "__class__")],
        setattr_calls,
    )
def test_open():
    """Every way of opening a file must raise the "open" audit event."""
    # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        load_dh_params = None

    # Each call is expected to be blocked by the hook below.  Entries whose
    # callable is unavailable (None) are skipped.
    attempts = [
        (open, sys.argv[2], "r"),
        (open, sys.executable, "rb"),
        (open, 3, "wb"),
        (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
        (load_dh_params, sys.argv[2]),
    ]

    with TestHook(raise_on_events={"open"}) as hook:
        for fn, *args in attempts:
            if not fn:
                continue
            with assertRaises(RuntimeError):
                fn(*args)

    expected_mode = [
        (sys.argv[2], "r"),
        (sys.executable, "r"),
        (3, "w"),
        (sys.argv[2], "w"),
    ]
    if load_dh_params:
        expected_mode.append((sys.argv[2], "rb"))

    open_args = [args for event, args in hook.seen if event == "open"]
    # Events carrying a mode string vs. events without one.
    actual_mode = [(args[0], args[1]) for args in open_args if args[1]]
    actual_flag = [(args[0], args[2]) for args in open_args if not args[1]]
    assertSequenceEqual(expected_mode, actual_mode)
    assertSequenceEqual([], actual_flag)
def test_cantrace():
    """Check that a hook's __cantrace__ attribute controls tracing of its calls."""
    traced = []

    def trace(frame, event, *args):
        # Only record events for frames executing TestHook.__call__.
        if frame.f_code == TestHook.__call__.__code__:
            traced.append(event)

    # sys.settrace() returns None, so the previous trace function has to be
    # read with sys.gettrace() to be restored afterwards.
    old = sys.gettrace()
    sys.settrace(trace)
    try:
        with TestHook() as hook:
            # No traced call
            eval("1")

            # No traced call
            hook.__cantrace__ = False
            eval("2")

            # One traced call
            hook.__cantrace__ = True
            eval("3")

            # Two traced calls (writing to private member, eval)
            hook.__cantrace__ = 1
            eval("4")

            # One traced call (writing to private member)
            hook.__cantrace__ = 0
    finally:
        sys.settrace(old)

    assertSequenceEqual(["call"] * 4, traced)
def test_mmap():
    """Creating an mmap reports its fileno and length as the first audit args."""
    import mmap

    with TestHook() as hook:
        mmap.mmap(-1, 8)
        first_args = hook.seen[0][1]
        assertEqual(first_args[:2], (-1, 8))
def test_excepthook():
    """Print the "sys.excepthook" event raised for an uncaught RuntimeError."""
    def excepthook(exc_type, exc_value, exc_tb):
        # Silence the expected RuntimeError; anything else gets default handling.
        if exc_type is not RuntimeError:
            sys.__excepthook__(exc_type, exc_value, exc_tb)

    def hook(event, args):
        if event != "sys.excepthook":
            return
        seen_hook, seen_type, seen_value = args[0], args[1], args[2]
        if not isinstance(seen_value, seen_type):
            raise TypeError(f"Expected isinstance({seen_value!r}, {seen_type!r})")
        if seen_hook != excepthook:
            raise ValueError(f"Expected {seen_hook} == {excepthook}")
        print(event, repr(seen_value))

    sys.addaudithook(hook)
    sys.excepthook = excepthook
    raise RuntimeError("fatal-error")
def test_unraisablehook():
    """Print the "sys.unraisablehook" event raised for an unraisable exception."""
    from _testcapi import write_unraisable_exc

    def unraisablehook(hookargs):
        pass

    def hook(event, args):
        if event != "sys.unraisablehook":
            return
        installed = args[0]
        if installed != unraisablehook:
            raise ValueError(f"Expected {installed} == {unraisablehook}")
        details = args[1]
        print(event, repr(details.exc_value), details.err_msg)

    sys.addaudithook(hook)
    sys.unraisablehook = unraisablehook
    write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)
def test_winreg():
    """Print the winreg.* audit events for open/enum/detach/close calls."""
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE

    def hook(event, args):
        if event.startswith("winreg."):
            print(event, *args)

    sys.addaudithook(hook)

    key = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(key, 0)
    # An out-of-range subkey index is expected to fail with OSError.
    try:
        EnumKey(key, 10000)
    except OSError:
        pass
    else:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    raw_handle = key.Detach()
    CloseKey(raw_handle)
def test_socket():
    """Print the socket.* audit events for gethostname, socket creation and bind."""
    import socket

    def hook(event, args):
        if not event.startswith("socket."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    socket.gethostname()

    # Only the audit messages matter; the bind itself is allowed to fail.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind(('127.0.0.1', 8080))
        except Exception:
            pass
def test_gc():
    """Print the gc.* audit events for the object-introspection functions."""
    import gc

    def hook(event, args):
        if not event.startswith("gc."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    gc.get_objects(generation=1)

    target = object()
    container = [target]

    gc.get_referrers(target)
    gc.get_referents(container)
def test_http_client():
    """Print the http.client.* audit events for a simple GET request."""
    import http.client

    def hook(event, args):
        if not event.startswith("http.client."):
            return
        # args[0] is left out of the output (presumably the connection
        # object, whose repr is not stable).
        print(event, *args[1:])

    sys.addaudithook(hook)

    conn = http.client.HTTPConnection('www.python.org')
    try:
        conn.request('GET', '/')
    except OSError:
        # No network: still emit a send line so the output shape is the same.
        print('http.client.send', '[cannot send]')
    finally:
        conn.close()
def test_sqlite3():
    """Print the sqlite3.* audit events for connecting and extension loading."""
    import sqlite3

    # NOTE(review): unlike the other hooks in this file, this one takes
    # ``*args``, so the audit argument tuple arrives wrapped in a 1-tuple and
    # is printed as one tuple.  The output format is deliberately unchanged.
    def hook(event, *args):
        if event.startswith("sqlite3."):
            print(event, *args)

    sys.addaudithook(hook)
    cx1 = sqlite3.connect(":memory:")
    cx2 = sqlite3.Connection(":memory:")
    try:
        # enable_load_extension is absent when Python was configured without
        # --enable-loadable-sqlite-extensions.
        if hasattr(sqlite3.Connection, "enable_load_extension"):
            cx1.enable_load_extension(False)
            try:
                cx1.load_extension("test")
            except sqlite3.OperationalError:
                pass
            else:
                raise RuntimeError("Expected sqlite3.load_extension to fail")
    finally:
        # Don't leak the connections, even if the check above fails.
        cx1.close()
        cx2.close()
def test_sys_getframe():
    """Print the sys.* audit event raised by sys._getframe()."""
    import sys

    def hook(event, args):
        if not event.startswith("sys."):
            return
        frame = args[0]
        print(event, frame.f_code.co_name)

    sys.addaudithook(hook)
    sys._getframe()
def test_sys_getframemodulename():
    """Print the sys.* audit event raised by sys._getframemodulename()."""
    import sys

    def hook(event, args):
        if not event.startswith("sys."):
            return
        print(event, *args)

    sys.addaudithook(hook)
    sys._getframemodulename()
def test_threading():
    # Print the _thread / thread-state / test.* audit events produced while
    # starting one thread with _thread.start_new_thread.
    import _thread

    def hook(event, args):
        if event.startswith(("_thread.", "cpython.PyThreadState", "test.")):
            print(event, args)

    sys.addaudithook(hook)

    # The lock starts held, so the final acquire() below blocks until the
    # new thread releases it from test_func.__call__.
    lock = _thread.allocate_lock()
    lock.acquire()

    class test_func:
        # Fixed repr: presumably the callable shows up in the printed
        # "_thread.start_new_thread" args, so it must be stable — confirm.
        def __repr__(self):
            return "<test_func>"
        def __call__(self):
            # Raise a custom event from inside the new thread, then unblock
            # the main thread.
            sys.audit("test.test_func")
            lock.release()

    i = _thread.start_new_thread(test_func(), ())
    # Wait for the thread to run test_func.__call__.
    lock.acquire()
def test_threading_abort():
    """A hook raising on "cpython.PyThreadState_New" aborts thread creation.

    The hook's own exception type is the one that is caught here; any other
    exception escapes and fails the test.
    """
    import _thread

    class ThreadNewAbortError(Exception):
        pass

    def hook(event, args):
        if event != "cpython.PyThreadState_New":
            return
        raise ThreadNewAbortError()

    sys.addaudithook(hook)

    try:
        _thread.start_new_thread(lambda: None, ())
    except ThreadNewAbortError:
        pass
def test_wmi_exec_query():
    """Print the _wmi.* audit event (query text only) for exec_query."""
    import _wmi

    def hook(event, args):
        if not event.startswith("_wmi."):
            return
        print(event, args[0])

    sys.addaudithook(hook)
    _wmi.exec_query("SELECT * FROM Win32_OperatingSystem")
def test_syslog():
    """Print the syslog.* audit events for the various syslog entry points."""
    import syslog

    def hook(event, args):
        if not event.startswith("syslog."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    # Explicit open, log, mask change and close.
    syslog.openlog('python')
    syslog.syslog('test')
    syslog.setlogmask(syslog.LOG_DEBUG)
    syslog.closelog()
    # implicit open
    syslog.syslog('test2')
    # open with default ident
    syslog.openlog(logoption=syslog.LOG_NDELAY, facility=syslog.LOG_LOCAL0)
    # NOTE(review): with sys.argv cleared, openlog() presumably cannot derive
    # an ident from argv — confirm this is the case being exercised.
    sys.argv = None
    syslog.openlog()
    syslog.closelog()
def test_not_in_gc(): import gc
hook = lambda *a: None sys.addaudithook(hook)
for o in gc.get_objects(): if isinstance(o, list): assert hook not in o
def test_sys_monitoring_register_callback():
    """Print the sys.monitoring* audit event raised by register_callback."""
    import sys

    def hook(event, args):
        if not event.startswith("sys.monitoring"):
            return
        print(event, args)

    sys.addaudithook(hook)
    sys.monitoring.register_callback(1, 1, None)
if __name__ == "__main__":
    from test.support import suppress_msvcrt_asserts

    suppress_msvcrt_asserts()

    # argv[1] names the test function in this module to run; later argv
    # entries are read by individual tests (test_open uses sys.argv[2]).
    selected_name = sys.argv[1]
    selected_test = globals()[selected_name]
    selected_test()
|