|  | 
|  | 1 | +# Copyright 2020-present MongoDB, Inc. | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +# http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | + | 
|  | 15 | +"""Test that each file in mypy_fails/ actually fails mypy, and test some | 
|  | 16 | +sample client code that uses PyMongo typings.""" | 
|  | 17 | + | 
|  | 18 | +import os | 
|  | 19 | +import sys | 
|  | 20 | +import unittest | 
|  | 21 | +from typing import Any, Dict, Iterable, List | 
|  | 22 | + | 
|  | 23 | +try: | 
|  | 24 | +    from mypy import api | 
|  | 25 | +except ImportError: | 
|  | 26 | +    api = None | 
|  | 27 | + | 
|  | 28 | +from bson.son import SON | 
|  | 29 | +from pymongo.collection import Collection | 
|  | 30 | +from pymongo.errors import ServerSelectionTimeoutError | 
|  | 31 | +from pymongo.mongo_client import MongoClient | 
|  | 32 | +from pymongo.operations import InsertOne | 
|  | 33 | + | 
|  | 34 | +TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mypy_fails") | 
|  | 35 | + | 
|  | 36 | + | 
|  | 37 | +def get_tests() -> Iterable[str]: | 
|  | 38 | +    for dirpath, _, filenames in os.walk(TEST_PATH): | 
|  | 39 | +        for filename in filenames: | 
|  | 40 | +            yield os.path.join(dirpath, filename) | 
|  | 41 | + | 
|  | 42 | + | 
|  | 43 | +class TestMypyFails(unittest.TestCase): | 
|  | 44 | +    def ensure_mypy_fails(self, filename: str) -> None: | 
|  | 45 | +        if api is None: | 
|  | 46 | +            raise unittest.SkipTest("Mypy is not installed") | 
|  | 47 | +        stdout, stderr, exit_status = api.run([filename]) | 
|  | 48 | +        self.assertTrue(exit_status, msg=stdout) | 
|  | 49 | + | 
|  | 50 | +    def test_mypy_failures(self) -> None: | 
|  | 51 | +        for filename in get_tests(): | 
|  | 52 | +            with self.subTest(filename=filename): | 
|  | 53 | +                self.ensure_mypy_fails(filename) | 
|  | 54 | + | 
|  | 55 | + | 
|  | 56 | +class TestPymongo(unittest.TestCase): | 
|  | 57 | +    client: MongoClient | 
|  | 58 | +    coll: Collection | 
|  | 59 | + | 
|  | 60 | +    @classmethod | 
|  | 61 | +    def setUpClass(cls) -> None: | 
|  | 62 | +        cls.client = MongoClient(serverSelectionTimeoutMS=250, directConnection=False) | 
|  | 63 | +        cls.coll = cls.client.test.test | 
|  | 64 | +        try: | 
|  | 65 | +            cls.client.admin.command("ping") | 
|  | 66 | +        except ServerSelectionTimeoutError as exc: | 
|  | 67 | +            raise unittest.SkipTest(f"Could not connect to MongoDB: {exc}") | 
|  | 68 | + | 
|  | 69 | +    @classmethod | 
|  | 70 | +    def tearDownClass(cls) -> None: | 
|  | 71 | +        cls.client.close() | 
|  | 72 | + | 
|  | 73 | +    def test_insert_find(self) -> None: | 
|  | 74 | +        doc = {"my": "doc"} | 
|  | 75 | +        coll2 = self.client.test.test2 | 
|  | 76 | +        result = self.coll.insert_one(doc) | 
|  | 77 | +        self.assertEqual(result.inserted_id, doc["_id"]) | 
|  | 78 | +        retreived = self.coll.find_one({"_id": doc["_id"]}) | 
|  | 79 | +        if retreived: | 
|  | 80 | +            # Documents returned from find are mutable. | 
|  | 81 | +            retreived["new_field"] = 1 | 
|  | 82 | +            result2 = coll2.insert_one(retreived) | 
|  | 83 | +            self.assertEqual(result2.inserted_id, result.inserted_id) | 
|  | 84 | + | 
|  | 85 | +    def test_cursor_iterable(self) -> None: | 
|  | 86 | +        def to_list(iterable: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]: | 
|  | 87 | +            return list(iterable) | 
|  | 88 | + | 
|  | 89 | +        self.coll.insert_one({}) | 
|  | 90 | +        cursor = self.coll.find() | 
|  | 91 | +        docs = to_list(cursor) | 
|  | 92 | +        self.assertTrue(docs) | 
|  | 93 | + | 
|  | 94 | +    def test_bulk_write(self) -> None: | 
|  | 95 | +        self.coll.insert_one({}) | 
|  | 96 | +        requests = [InsertOne({})] | 
|  | 97 | +        result = self.coll.bulk_write(requests) | 
|  | 98 | +        self.assertTrue(result.acknowledged) | 
|  | 99 | + | 
|  | 100 | +    def test_aggregate_pipeline(self) -> None: | 
|  | 101 | +        coll3 = self.client.test.test3 | 
|  | 102 | +        coll3.insert_many( | 
|  | 103 | +            [ | 
|  | 104 | +                {"x": 1, "tags": ["dog", "cat"]}, | 
|  | 105 | +                {"x": 2, "tags": ["cat"]}, | 
|  | 106 | +                {"x": 2, "tags": ["mouse", "cat", "dog"]}, | 
|  | 107 | +                {"x": 3, "tags": []}, | 
|  | 108 | +            ] | 
|  | 109 | +        ) | 
|  | 110 | + | 
|  | 111 | +        class mydict(Dict[str, Any]): | 
|  | 112 | +            pass | 
|  | 113 | + | 
|  | 114 | +        result = coll3.aggregate( | 
|  | 115 | +            [ | 
|  | 116 | +                mydict({"$unwind": "$tags"}), | 
|  | 117 | +                {"$group": {"_id": "$tags", "count": {"$sum": 1}}}, | 
|  | 118 | +                {"$sort": SON([("count", -1), ("_id", -1)])}, | 
|  | 119 | +            ] | 
|  | 120 | +        ) | 
|  | 121 | +        self.assertTrue(len(list(result))) | 
|  | 122 | + | 
|  | 123 | + | 
|  | 124 | +if __name__ == "__main__": | 
|  | 125 | +    unittest.main() | 
0 commit comments