Python模拟在类中使用的全局函数

11

我似乎无法理解Python中的模拟。我有一个全局函数:

a.py:

def has_permission(args):
    ret_val = ...get-true-or-false...
    return ret_val

b.py:

->

b.py:

class MySerializer(HyperlinkedModelSerializer):

     def get_fields():
         fields = super().get_fields()
         for f in :
             if has_permission(...):
                 ret_val[f.name] = fields[f]
         return ret_val

c.py:

class CountrySerializer(MySerializer):
    class Meta:
        model = Country

问题: 现在我想测试c.py文件,但是我想mock a.py文件中定义的has_permission函数,该函数在b.py文件定义的MySerializer类的get_fields方法中被调用...我该如何做呢?

我已经尝试了以下方法:

@patch('b.MySerializer.has_permission')

@patch('b.MySerializer.get_fields.has_permission')

@patch('a.has_permission')

但无论我尝试什么,要么根本不起作用并且“has_permission”仍然被执行,要么Python会抱怨找不到属性“has_permission”。

使用以下修补程序:

test.py

class TestSerializerFields(TestCase):
    @patch(... the above examples....)
    def test_my_country_serializer():
        s = CountrySerializer()
        self..assertTrue(issubclass(my_serializer_fields.MyCharField, type(s.get_fields()['field1'])))

你将在哪里应用这个“patch”? - vks
@vks:可能是在测试中。 - Martijn Pieters
2
@patch('b.MySerializer.has_permission') 是错误的,因为 has_permission 函数并不属于该类。你传递给 patch 的是一个导入路径。微妙之处在于,虽然它在 a.py 中定义,但一旦你在 b.py 中导入它,它也可以从 b.py 中导入……而且你想要修补的是从 b.py 导入的副本。 - Anentropic
2个回答

21
你需要在b模块中打补丁全局变量。
@patch('b.has_permission')

因为这就是你的代码寻找它的地方。

还请参阅Where to patch部分mock文档。


谢谢...就这样了!如果我理解正确的话,那么patch('b.has_permission')会修补整个b.py文件中定义的has_permission函数的所有调用? - Robin van Leeuwen
@RvL:是的,它替换了你在模块中导入的has_permission对象。所有使用has_permission的代码都会始终使用同一个全局对象。 - Martijn Pieters
@RvL:你不能仅为一个函数修补全局变量的使用。 - Martijn Pieters
啊,现在开始对我有点意义了。谢谢 :) - Robin van Leeuwen

8
你需要在测试运行时存在的方法中打补丁。如果你试图在测试代码已经导入它的情况下修补定义方法,那么修补将没有效果。在@patch(...)执行的点上,被测试代码已经将全局方法引入其自己的模块。
以下是一个例子:

app/util/config.py:

# This is the global method we want to mock
def is_search_enabled():
    return True

app/service/searcher.py:

# Here is where that global method will be imported 
#  when this file is first imported
from app.util.config import is_search_enabled

class Searcher:
    def __init__(self, api_service):
        self._api_service = api_service

    def search(self):
        if not is_search_enabled():
            return None
        return self._api_service.perform_request('/search')

test/service/test_searcher.py:

from unittest.mock import patch, Mock
# The next line will cause the imports of `searcher.py` to execute...
from app.service.searcher import Searcher
# At this point, searcher.py has imported is_search_enabled into its module.
# If you later try and patch the method at its definition 
#  (app.util.config.is_search_enabled), it will have no effect because 
#  searcher.py won't look there again.

class MockApiService:
    pass

class TestSearcher:

    # By the time this executes, `is_search_enabled` has already been
    #  imported into `app.service.searcher`.  So that is where we must
    #  patch it.
    @patch('app.service.searcher.is_search_enabled')
    def test_no_search_when_disabled(self, mock_is_search_enabled):
        mock_is_search_enabled.return_value = False
        mock_api_service = MockApiService()
        mock_api_service.perform_request = Mock()
        searcher = Searcher(mock_api_service)

        results = searcher.search()

        assert results is None
        mock_api_service.perform_request.assert_not_called()

    # (For completeness' sake, make sure the code actually works when search is enabled...)
    def test_search(self):
        mock_api_service = MockApiService()
        mock_api_service.perform_request = mock_perform_request = Mock()
        searcher = Searcher(mock_api_service)
        expected_results = [1, 2, 3]
        mock_perform_request.return_value = expected_results

        actual_results = searcher.search()

        assert actual_results == expected_results
        mock_api_service.perform_request.assert_called_once_with('/search')

老兄,上帝保佑你!你的回复解决了我两天才能解决的问题。 - Chukky Katz

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接