[Python-checkins] cpython: Issue #14814: In the spirit of TOOWTDI, ditch the redundant version parameter

nick.coghlan python-checkins at python.org
Sat May 26 16:26:21 CEST 2012


http://hg.python.org/cpython/rev/f70e12499d05
changeset:   77160:f70e12499d05
user:        Nick Coghlan <ncoghlan at gmail.com>
date:        Sun May 27 00:25:58 2012 +1000
summary:
  Issue #14814: In the spirit of TOOWTDI, ditch the redundant version parameter to the factory functions by using the appropriate direct class references instead

files:
  Lib/ipaddress.py           |  122 +++++++++---------------
  Lib/test/test_ipaddress.py |   20 +---
  2 files changed, 48 insertions(+), 94 deletions(-)


diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py
--- a/Lib/ipaddress.py
+++ b/Lib/ipaddress.py
@@ -36,34 +36,22 @@
     """A Value Error related to the netmask."""
 
 
-def ip_address(address, version=None):
+def ip_address(address):
     """Take an IP string/int and return an object of the correct type.
 
     Args:
         address: A string or integer, the IP address.  Either IPv4 or
           IPv6 addresses may be supplied; integers less than 2**32 will
           be considered to be IPv4 by default.
-        version: An integer, 4 or 6.  If set, don't try to automatically
-          determine what the IP address type is.  Important for things
-          like ip_address(1), which could be IPv4, '192.0.2.1',  or IPv6,
-          '2001:db8::1'.
 
     Returns:
         An IPv4Address or IPv6Address object.
 
     Raises:
         ValueError: if the *address* passed isn't either a v4 or a v6
-          address, or if the version is not None, 4, or 6.
+          address
 
     """
-    if version is not None:
-        if version == 4:
-            return IPv4Address(address)
-        elif version == 6:
-            return IPv6Address(address)
-        else:
-            raise ValueError()
-
     try:
         return IPv4Address(address)
     except (AddressValueError, NetmaskValueError):
@@ -78,35 +66,22 @@
                      address)
 
 
-def ip_network(address, version=None, strict=True):
+def ip_network(address, strict=True):
     """Take an IP string/int and return an object of the correct type.
 
     Args:
         address: A string or integer, the IP network.  Either IPv4 or
           IPv6 networks may be supplied; integers less than 2**32 will
           be considered to be IPv4 by default.
-        version: An integer, 4 or 6.  If set, don't try to automatically
-          determine what the IP address type is.  Important for things
-          like ip_network(1), which could be IPv4, '192.0.2.1/32', or IPv6,
-          '2001:db8::1/128'.
 
     Returns:
         An IPv4Network or IPv6Network object.
 
     Raises:
         ValueError: if the string passed isn't either a v4 or a v6
-          address. Or if the network has host bits set.  Or if the version
-          is not None, 4, or 6.
+          address. Or if the network has host bits set.
 
     """
-    if version is not None:
-        if version == 4:
-            return IPv4Network(address, strict)
-        elif version == 6:
-            return IPv6Network(address, strict)
-        else:
-            raise ValueError()
-
     try:
         return IPv4Network(address, strict)
     except (AddressValueError, NetmaskValueError):
@@ -121,24 +96,20 @@
                      address)
 
 
-def ip_interface(address, version=None):
+def ip_interface(address):
     """Take an IP string/int and return an object of the correct type.
 
     Args:
         address: A string or integer, the IP address.  Either IPv4 or
           IPv6 addresses may be supplied; integers less than 2**32 will
           be considered to be IPv4 by default.
-        version: An integer, 4 or 6.  If set, don't try to automatically
-          determine what the IP address type is.  Important for things
-          like ip_interface(1), which could be IPv4, '192.0.2.1/32', or IPv6,
-          '2001:db8::1/128'.
 
     Returns:
         An IPv4Interface or IPv6Interface object.
 
     Raises:
         ValueError: if the string passed isn't either a v4 or a v6
-          address.  Or if the version is not None, 4, or 6.
+          address.
 
     Notes:
         The IPv?Interface classes describe an Address on a particular
@@ -146,14 +117,6 @@
         and Network classes.
 
     """
-    if version is not None:
-        if version == 4:
-            return IPv4Interface(address)
-        elif version == 6:
-            return IPv6Interface(address)
-        else:
-            raise ValueError()
-
     try:
         return IPv4Interface(address)
     except (AddressValueError, NetmaskValueError):
@@ -281,7 +244,7 @@
             If the first and last objects are not the same version.
         ValueError:
             If the last object is not greater than the first.
-            If the version is not 4 or 6.
+            If the version of the first address is not 4 or 6.
 
     """
     if (not (isinstance(first, _BaseAddress) and
@@ -318,7 +281,7 @@
         if current == ip._ALL_ONES:
             break
         first_int = current + 1
-        first = ip_address(first_int, version=first._version)
+        first = first.__class__(first_int)
 
 
 def _collapse_addresses_recursive(addresses):
@@ -586,12 +549,12 @@
     def __add__(self, other):
         if not isinstance(other, int):
             return NotImplemented
-        return ip_address(int(self) + other, version=self._version)
+        return self.__class__(int(self) + other)
 
     def __sub__(self, other):
         if not isinstance(other, int):
             return NotImplemented
-        return ip_address(int(self) - other, version=self._version)
+        return self.__class__(int(self) - other)
 
     def __repr__(self):
         return '%s(%r)' % (self.__class__.__name__, str(self))
@@ -612,13 +575,12 @@
 
 class _BaseNetwork(_IPAddressBase):
 
-    """A generic IP object.
+    """A generic IP network object.
 
     This IP class contains the version independent methods which are
     used by networks.
 
     """
-
     def __init__(self, address):
         self._cache = {}
 
@@ -642,14 +604,14 @@
         bcast = int(self.broadcast_address) - 1
         while cur <= bcast:
             cur += 1
-            yield ip_address(cur - 1, version=self._version)
+            yield self._address_class(cur - 1)
 
     def __iter__(self):
         cur = int(self.network_address)
         bcast = int(self.broadcast_address)
         while cur <= bcast:
             cur += 1
-            yield ip_address(cur - 1, version=self._version)
+            yield self._address_class(cur - 1)
 
     def __getitem__(self, n):
         network = int(self.network_address)
@@ -657,12 +619,12 @@
         if n >= 0:
             if network + n > broadcast:
                 raise IndexError
-            return ip_address(network + n, version=self._version)
+            return self._address_class(network + n)
         else:
             n += 1
             if broadcast + n < network:
                 raise IndexError
-            return ip_address(broadcast + n, version=self._version)
+            return self._address_class(broadcast + n)
 
     def __lt__(self, other):
         if self._version != other._version:
@@ -746,8 +708,8 @@
     def broadcast_address(self):
         x = self._cache.get('broadcast_address')
         if x is None:
-            x = ip_address(int(self.network_address) | int(self.hostmask),
-                           version=self._version)
+            x = self._address_class(int(self.network_address) |
+                                    int(self.hostmask))
             self._cache['broadcast_address'] = x
         return x
 
@@ -755,15 +717,15 @@
     def hostmask(self):
         x = self._cache.get('hostmask')
         if x is None:
-            x = ip_address(int(self.netmask) ^ self._ALL_ONES,
-                          version=self._version)
+            x = self._address_class(int(self.netmask) ^ self._ALL_ONES)
             self._cache['hostmask'] = x
         return x
 
     @property
     def network(self):
-        return ip_network('%s/%d' % (str(self.network_address),
-                                     self.prefixlen))
+        # XXX (ncoghlan): This is redundant now and will likely be removed
+        return self.__class__('%s/%d' % (str(self.network_address),
+                                         self.prefixlen))
 
     @property
     def with_prefixlen(self):
@@ -787,6 +749,10 @@
         raise NotImplementedError('BaseNet has no version')
 
     @property
+    def _address_class(self):
+        raise NotImplementedError('BaseNet has no associated address class')
+
+    @property
     def prefixlen(self):
         return self._prefixlen
 
@@ -840,9 +806,8 @@
             raise StopIteration
 
         # Make sure we're comparing the network of other.
-        other = ip_network('%s/%s' % (str(other.network_address),
-                                      str(other.prefixlen)),
-                           version=other._version)
+        other = other.__class__('%s/%s' % (str(other.network_address),
+                                           str(other.prefixlen)))
 
         s1, s2 = self.subnets()
         while s1 != other and s2 != other:
@@ -973,9 +938,9 @@
                 'prefix length diff %d is invalid for netblock %s' % (
                     new_prefixlen, str(self)))
 
-        first = ip_network('%s/%s' % (str(self.network_address),
-                                     str(self._prefixlen + prefixlen_diff)),
-                         version=self._version)
+        first = self.__class__('%s/%s' %
+                                 (str(self.network_address),
+                                  str(self._prefixlen + prefixlen_diff)))
 
         yield first
         current = first
@@ -983,16 +948,17 @@
             broadcast = current.broadcast_address
             if broadcast == self.broadcast_address:
                 return
-            new_addr = ip_address(int(broadcast) + 1, version=self._version)
-            current = ip_network('%s/%s' % (str(new_addr), str(new_prefixlen)),
-                                version=self._version)
+            new_addr = self._address_class(int(broadcast) + 1)
+            current = self.__class__('%s/%s' % (str(new_addr),
+                                                str(new_prefixlen)))
 
             yield current
 
     def masked(self):
         """Return the network object with the host bits masked out."""
-        return ip_network('%s/%d' % (self.network_address, self._prefixlen),
-                         version=self._version)
+        # XXX (ncoghlan): This is redundant now and will likely be removed
+        return self.__class__('%s/%d' % (self.network_address,
+                                         self._prefixlen))
 
     def supernet(self, prefixlen_diff=1, new_prefix=None):
         """The supernet containing the current network.
@@ -1030,11 +996,10 @@
                 'current prefixlen is %d, cannot have a prefixlen_diff of %d' %
                 (self.prefixlen, prefixlen_diff))
         # TODO (pmoody): optimize this.
-        t = ip_network('%s/%d' % (str(self.network_address),
-                                    self.prefixlen - prefixlen_diff),
-                         version=self._version, strict=False)
-        return ip_network('%s/%d' % (str(t.network_address), t.prefixlen),
-                          version=t._version)
+        t = self.__class__('%s/%d' % (str(self.network_address),
+                                      self.prefixlen - prefixlen_diff),
+                                     strict=False)
+        return t.__class__('%s/%d' % (str(t.network_address), t.prefixlen))
 
 
 class _BaseV4(object):
@@ -1391,6 +1356,9 @@
         .prefixlen: 27
 
     """
+    # Class to use when creating address objects
+    # TODO (ncoghlan): Investigate using IPv4Interface instead
+    _address_class = IPv4Address
 
     # the valid octets for host and netmasks. only useful for IPv4.
     _valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0))
@@ -2071,6 +2039,10 @@
 
     """
 
+    # Class to use when creating address objects
+    # TODO (ncoghlan): Investigate using IPv6Interface instead
+    _address_class = IPv6Address
+
     def __init__(self, address, strict=True):
         """Instantiate a new IPv6 Network object.
 
diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py
--- a/Lib/test/test_ipaddress.py
+++ b/Lib/test/test_ipaddress.py
@@ -780,12 +780,6 @@
         self.assertEqual(self.ipv4_address.version, 4)
         self.assertEqual(self.ipv6_address.version, 6)
 
-        with self.assertRaises(ValueError):
-            ipaddress.ip_address('1', version=[])
-
-        with self.assertRaises(ValueError):
-            ipaddress.ip_address('1', version=5)
-
     def testMaxPrefixLength(self):
         self.assertEqual(self.ipv4_interface.max_prefixlen, 32)
         self.assertEqual(self.ipv6_interface.max_prefixlen, 128)
@@ -1052,12 +1046,7 @@
 
     def testForceVersion(self):
         self.assertEqual(ipaddress.ip_network(1).version, 4)
-        self.assertEqual(ipaddress.ip_network(1, version=6).version, 6)
-
-        with self.assertRaises(ValueError):
-            ipaddress.ip_network(1, version='l')
-        with self.assertRaises(ValueError):
-            ipaddress.ip_network(1, version=3)
+        self.assertEqual(ipaddress.IPv6Network(1).version, 6)
 
     def testWithStar(self):
         self.assertEqual(str(self.ipv4_interface.with_prefixlen), "1.2.3.4/24")
@@ -1148,13 +1137,6 @@
                          sixtofouraddr.sixtofour)
         self.assertFalse(bad_addr.sixtofour)
 
-    def testIpInterfaceVersion(self):
-        with self.assertRaises(ValueError):
-            ipaddress.ip_interface(1, version=123)
-
-        with self.assertRaises(ValueError):
-            ipaddress.ip_interface(1, version='')
-
 
 if __name__ == '__main__':
     unittest.main()

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list