I have a simple memoizer which I'm using to save some time around expensive network calls. Roughly, my code looks like this:
# mem.py
import functools
import time
def memoize(fn):
    """
    Decorate a function so that its results are cached in memory.

    Results are keyed on the positional arguments plus the keyword
    arguments (names *and* values), so every argument must be hashable.
    The cache lives as long as the returned wrapper does, which means
    results persist across callers -- including across separate tests.

    The wrapper exposes:

    * ``cache`` -- the underlying dict, for inspection;
    * ``cache_clear()`` -- discard every cached result (same name as
      ``functools.lru_cache``), e.g. in test setup/teardown;
    * ``__wrapped__`` -- the undecorated function, to bypass the cache.

    >>> calls = []
    >>> def double(x, scale=2):
    ...     calls.append((x, scale))
    ...     return x * scale
    >>> g = memoize(double)
    >>> [g(1) for _ in range(3)]
    [2, 2, 2]
    >>> g(1, scale=3)
    3
    >>> calls
    [(1, 2), (1, 3)]
    >>> g.cache_clear()
    >>> g(1)
    2
    >>> calls
    [(1, 2), (1, 3), (1, 2)]
    """
    cache = {}

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        # Keyword *values* must be part of the key: sorting the kwargs dict
        # alone yields only its names, so f(x=1) and f(x=2) would collide.
        key = args, tuple(sorted(kwargs.items()))
        try:
            return cache[key]
        except KeyError:
            pass
        # Call fn outside the except block so an exception raised by fn is
        # not reported as "during handling of KeyError".
        result = fn(*args, **kwargs)
        cache[key] = result
        return result

    wrapped.cache = cache
    wrapped.cache_clear = cache.clear
    # functools.wraps only sets __wrapped__ on Python 3; set it explicitly
    # so the undecorated function is reachable on Python 2 as well.
    wrapped.__wrapped__ = fn
    return wrapped
def network_call(user_id):
    """
    Stand-in for an expensive remote lookup of ``user_id``.

    Blocks for one second, then returns the constant 1. The ``user_id``
    argument is accepted but not used.
    """
    simulated_latency = 1
    time.sleep(simulated_latency)
    response = 1
    return response
@memoize
def search(user_id):
    """
    Return the ``network_call`` result for ``user_id``.

    Wrapped with ``@memoize``: only the first call for a given ``user_id``
    reaches ``network_call``, and later calls return the cached value.
    """
    response = network_call(user_id)
    return response
I also have tests for this code, in which I mock out different return values of `network_call()` to make sure that some modifications I make in `search()` work as expected.
import mock
import mem
@mock.patch('mem.network_call', return_value=2)
def test_search(mock_network_call):
    """search() should pass through whatever network_call returns."""
    # NOTE: mem.search is memoized when mem is imported and nothing here
    # resets its cache, so the value for user_id 1 stays cached afterwards.
    assert mem.search(1) == mock_network_call.return_value
@mock.patch('mem.network_call', return_value=3)
def test_search_2(mock_network_call):
    """search() should reflect a different mocked network_call result."""
    # NOTE: order-dependent. If an earlier test already called
    # mem.search(1), the memoized value is returned and this mock is never
    # consulted (the reported failure is "assert 2 == 3").
    assert mem.search(1) == mock_network_call.return_value
However, when I run these tests, I get a failure because `search()` returns a cached result.
CAESAR-BAUTISTA:~ caesarbautista$ py.test test_mem.py
============================= test session starts ==============================
platform darwin -- Python 2.7.8 -- py-1.4.26 -- pytest-2.6.4
collected 2 items
test_mem.py .F
=================================== FAILURES ===================================
________________________________ test_search_2 _________________________________
args = (<MagicMock name='network_call' id='4438999312'>,), keywargs = {}
extra_args = [<MagicMock name='network_call' id='4438999312'>]
entered_patchers = [<mock._patch object at 0x108913dd0>]
exc_info = (<class '_pytest.assertion.reinterpret.AssertionError'>, AssertionError(u'assert 2 == 3\n + where 2 = <function search at 0x10893f848>(1)\n + where <function search at 0x10893f848> = mem.search',), <traceback object at 0x1089502d8>)
patching = <mock._patch object at 0x108913dd0>
arg = <MagicMock name='network_call' id='4438999312'>
@wraps(func)
def patched(*args, **keywargs):
# don't use a with here (backwards compatability with Python 2.4)
extra_args = []
entered_patchers = []
# can't use try...except...finally because of Python 2.4
# compatibility
exc_info = tuple()
try:
try:
for patching in patched.patchings:
arg = patching.__enter__()
entered_patchers.append(patching)
if patching.attribute_name is not None:
keywargs.update(arg)
elif patching.new is DEFAULT:
extra_args.append(arg)
args += tuple(extra_args)
> return func(*args, **keywargs)
/opt/boxen/homebrew/lib/python2.7/site-packages/mock.py:1201:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
mock_network_call = <MagicMock name='network_call' id='4438999312'>
@mock.patch('mem.network_call')
def test_search_2(mock_network_call):
mock_network_call.return_value = 3
> assert mem.search(1) == 3
E assert 2 == 3
E + where 2 = <function search at 0x10893f848>(1)
E + where <function search at 0x10893f848> = mem.search
test_mem.py:15: AssertionError
====================== 1 failed, 1 passed in 0.03 seconds ======================
Is there a way to test memoized functions? I've considered some alternatives but they each have drawbacks.
One solution is to mock `memoize()`. I am reluctant to do this because it is an implementation detail. Theoretically, I should be able to memoize and unmemoize functions without the rest of the system, including tests, noticing from a functional standpoint.
Another solution is to rewrite the code to expose the undecorated function. That is, I could do something like this:
def _search():
    """Undecorated implementation, callable directly without the cache."""
    value = 1
    return value


search = memoize(_search)
However, this runs into the same problems as above, although it's arguably worse because it will not work for recursive functions.