Source code for switchyard.lib.packet.ipv4

import struct
from abc import ABCMeta, abstractmethod
from ipaddress import IPv4Address
from collections import namedtuple

from switchyard.lib.packet.packet import PacketHeaderBase,Packet
from switchyard.lib.address import EthAddr,IPAddr,SpecialIPv4Addr,SpecialEthAddr
from switchyard.lib.packet.common import IPProtocol,IPFragmentFlag,IPOptionNumber, checksum
from switchyard.lib.packet.icmp import ICMP
from switchyard.lib.packet.igmp import IGMP
from switchyard.lib.packet.udp import UDP
from switchyard.lib.packet.tcp import TCP

'''
References:
    RFC 791, Internet Protocol: DARPA Internet Program Protocol Specification.
        September 1981.
    RFC 1063, MTU discovery options.
    RFC 2113, Router alert option.
'''

# Maps the protocol value of an IPv4 header to the header class used for the
# next header; looked up by IPv4.next_header_class().
IPTypeClasses = {
    IPProtocol.ICMP: ICMP,
    IPProtocol.TCP: TCP,
    IPProtocol.UDP: UDP,
    IPProtocol.IGMP: IGMP,
}


class IPOption(object, metaclass=ABCMeta):
    """Base class for a single IPv4 header option.

    On its own this class models an option that is serialized as a single
    byte holding the option number.  Subclasses override ``length``,
    ``to_bytes`` and ``from_bytes`` for options that carry data.
    """
    _PACKFMT = 'B'
    __slots__ = ['_optnum']

    def __init__(self, optnum):
        # Coerce to IPOptionNumber so that .value is available in to_bytes().
        self._optnum = IPOptionNumber(optnum)

    @property
    def optnum(self):
        """The option number this object was constructed with."""
        return self._optnum

    def length(self):
        """Return the serialized size of the option in bytes."""
        return struct.calcsize(IPOption._PACKFMT)

    def to_bytes(self):
        """Return the option packed as a single byte (the option number)."""
        return struct.pack(IPOption._PACKFMT, self._optnum.value)

    def from_bytes(self, raw):
        """Consume this option from ``raw`` and return the bytes used.

        A bare option has no payload, so nothing is read from ``raw``; only
        the fixed length is reported.
        """
        return self.length()

    def __eq__(self, other):
        # Defer to Python's default handling for unrelated types instead of
        # failing with AttributeError on a missing ``_optnum``.
        if not isinstance(other, IPOption):
            return NotImplemented
        return self._optnum == other._optnum


class IPOptionNoOperation(IPOption):
    """Single-byte option constructed with IPOptionNumber.NoOperation."""
    def __init__(self):
        super().__init__(IPOptionNumber.NoOperation)

 
class IPOptionEndOfOptionList(IPOption):
    """Single-byte option constructed with IPOptionNumber.EndOfOptionList."""
    def __init__(self):
        super().__init__(IPOptionNumber.EndOfOptionList)


class IPOptionXRouting(IPOption):
    """Shared implementation of the routing-type IP options.

    Base class of the loose/strict source routing and record route options.
    The serialized form is three one-byte fields (option number, total
    length, pointer) followed by one 4-byte IPv4 address per route entry.
    Instances can be indexed like a sequence of IPv4Address objects.
    """
    _PACKFMT = 'BBB'
    __slots__ = ['_routedata','_ptr']

    def __init__(self, ipoptnum, numaddrs=9):
        """Create the option with ``numaddrs`` (1-9) zeroed address slots.

        Raises:
            Exception: if ``numaddrs`` is outside 1-9.
        """
        super().__init__(ipoptnum)
        if numaddrs < 1 or numaddrs > 9:
            raise Exception("Invalid number of addresses for IP routing-type option (must be 1-9)")
        self._routedata = [IPv4Address("0.0.0.0")] * numaddrs
        self._ptr = 4

    def length(self):
        """Return the serialized size: 3 header bytes plus 4 per address."""
        return struct.calcsize(IPOptionXRouting._PACKFMT) + len(self._routedata) * 4

    def __len__(self):
        return len(self._routedata)

    def to_bytes(self):
        """Return the packed option: header bytes followed by all addresses."""
        header = struct.pack(IPOptionXRouting._PACKFMT,
                             self.optnum.value, self.length(), self._ptr)
        return header + b''.join(ipaddr.packed for ipaddr in self._routedata)

    def from_bytes(self, raw):
        """Parse the option from the start of ``raw``.

        Returns:
            The value of the option's length byte, i.e. the number of bytes
            this option occupies in ``raw``.

        Raises:
            ValueError: if the length byte claims more data than ``raw``
                holds, or if the pointer byte does not refer to one of the
                parsed address slots.
        """
        length = raw[1]
        pointer = raw[2]
        # Refuse to read past the data we were actually given; otherwise the
        # caller would be told to skip more bytes than exist.
        if length > len(raw):
            raise ValueError(
                "Not enough data to unpack raw {}: need {} but only have {}".format(
                    self.__class__.__name__, length, len(raw)))
        numaddrs = (length - 3) // 4
        self._routedata = [IPv4Address(raw[start:start + 4])
                           for start in range(3, 3 + numaddrs * 4, 4)]
        # Goes through the property setter so the pointer is validated
        # against the freshly parsed address list.
        self.pointer = pointer
        return length

    @property
    def pointer(self):
        """The option's pointer byte."""
        return self._ptr

    @pointer.setter
    def pointer(self, value):
        # Pointer 4 corresponds to the first address slot, 8 to the second...
        xval = value // 4 - 1
        if not 0 <= xval < len(self._routedata):
            raise ValueError("Invalid pointer value")
        self._ptr = value

    def num_addrs(self):
        """Return the number of address slots held by the option."""
        return len(self._routedata)

    def _checked_index(self, index):
        """Resolve a possibly negative index; raise IndexError if out of range."""
        if index < 0:
            index = len(self._routedata) + index
        if not 0 <= index < len(self._routedata):
            raise IndexError("Index out of range")
        return index

    def __getitem__(self, index):
        return self._routedata[self._checked_index(index)]

    def __setitem__(self, index, addr):
        # The type check comes first so a bad value is reported as ValueError
        # even when the index is also out of range.
        if not isinstance(addr, IPv4Address):
            raise ValueError("Value must be IPv4Address")
        self._routedata[self._checked_index(index)] = addr

    def __delitem__(self, index):
        del self._routedata[self._checked_index(index)]

    def __eq__(self, other):
        # Unrelated types are not comparable; let Python fall back instead of
        # failing with AttributeError.
        if not isinstance(other, IPOptionXRouting):
            return NotImplemented
        return self.optnum == other.optnum and \
            self._ptr == other._ptr and \
            self._routedata == other._routedata

class IPOptionLooseSourceRouting(IPOptionXRouting):
    """Routing-type option constructed with IPOptionNumber.LooseSourceRouting
    and the default number of address slots."""
    def __init__(self):
        super().__init__(IPOptionNumber.LooseSourceRouting)


class IPOptionStrictSourceRouting(IPOptionXRouting):
    """Routing-type option constructed with IPOptionNumber.StrictSourceRouting
    and the default number of address slots."""
    def __init__(self):
        super().__init__(IPOptionNumber.StrictSourceRouting)


class IPOptionRecordRoute(IPOptionXRouting):
    """Routing-type option constructed with IPOptionNumber.RecordRoute and
    the default number of address slots."""
    def __init__(self):
        super().__init__(IPOptionNumber.RecordRoute)


# One slot of the timestamp option: an IPv4Address (None when the option
# carries timestamps only) and an integer timestamp.
TimestampEntry = namedtuple('TimestampEntry', ['ipv4addr','timestamp'])


class IPOptionTimestamp(IPOption):
    """IP timestamp option.

    The serialized form is a 4-byte header (option type, length, pointer,
    flag) followed by the entries.  With flag 0 each entry is a 4-byte
    timestamp; with any other flag each entry is a 4-byte IPv4 address
    followed by a 4-byte timestamp.
    """
    __slots__ = ['_entries','_ptr','_flag']

    def __init__(self, tslist=None):
        """Create the option with four zeroed (address, timestamp) entries.

        ``tslist`` is accepted for interface compatibility but is not used.
        """
        super().__init__(IPOptionNumber.Timestamp)
        self._entries = [TimestampEntry(IPv4Address("0.0.0.0"), 0)] * 4
        self._ptr = 5
        # flags: 0x0 only timestamps, 0x1 ipaddr and timestamp, 0x3 optlist initialized
        # with up to 4 pairs of ipaddr and 0 timestamps
        self._flag = 0x1

    def length(self):
        """Return the serialized size: 4 header bytes plus 4 or 8 per entry."""
        entrysize = 4 if self._flag == 0 else 8
        return 4 + len(self._entries) * entrysize

    def to_bytes(self):
        """Return the packed option: header followed by every entry."""
        # 0x40 is OR-ed into the type byte that goes on the wire.
        raw = struct.pack('!BBBB', 0x40 | self.optnum.value, self.length(),
            self._ptr, self._flag)
        for entry in self._entries:
            if self._flag > 0:
                raw += entry.ipv4addr.packed
            raw += struct.pack('!I', entry.timestamp)
        return raw

    def from_bytes(self, raw):
        """Parse the option from the start of ``raw``.

        Returns:
            The value of the option's length byte, i.e. the number of bytes
            this option occupies in ``raw``.

        Raises:
            Exception: if the length byte is smaller than the 4-byte header
                or larger than the data available.
            struct.error: if ``raw`` holds fewer than 4 bytes or the entry
                data is not a whole number of entries.
        """
        fields = struct.unpack('!BBBB', raw[:4])
        self._ptr = fields[2]
        self._flag = fields[3]&0x0f
        self._entries = []
        xlen = fields[1]
        # A length below the header size is malformed.  Returning it would
        # make the caller re-parse header bytes or, for 0, never advance.
        if xlen < 4:
            raise Exception("Invalid option length {} for {} (must be at least 4)".format(xlen, self.__class__.__name__))
        if xlen > len(raw):
            raise Exception("Not enough data to unpack raw {}: need {} but only have {}".format(self.__class__.__name__, xlen, len(raw)))
        raw = raw[4:xlen]
        haveipaddr = self._flag != 0
        unpackfmt = '!II' if haveipaddr else '!I'
        for tstup in struct.iter_unpack(unpackfmt, raw):
            if haveipaddr:
                # Store a real address object: to_bytes() relies on .packed.
                ts = TimestampEntry(IPv4Address(tstup[0]), tstup[1])
            else:
                ts = TimestampEntry(None, tstup[0])
            self._entries.append(ts)
        return xlen

    def num_timestamps(self):
        """Return the number of entries held by the option."""
        return len(self._entries)

    def timestamp_entry(self, index):
        """Return the TimestampEntry at ``index``."""
        return self._entries[index]


class IPOption4Bytes(IPOption):
    """Base class for options serialized as exactly four bytes.

    The layout is a type byte (copy flag OR-ed with the option number), a
    length byte and a 16-bit value in network byte order.
    """
    __slots__ = ['_value', '_copyflag']
    _PACKFMT = '!BBH'

    def __init__(self, optnum, value=0, copyflag=False):
        super().__init__(optnum)
        self._value = value
        # Stored as the bit mask that is OR-ed into the type byte.
        self._copyflag = 0x80 if copyflag else 0

    def length(self):
        """Return the serialized size of the option (4 bytes)."""
        return struct.calcsize(IPOption4Bytes._PACKFMT)

    def from_bytes(self, raw):
        """Read the 16-bit value from the first four bytes of ``raw``.

        Only the value is taken from ``raw``; the option number and copy
        flag keep whatever the constructor set.  Returns the number of bytes
        consumed.
        """
        fields = struct.unpack(IPOption4Bytes._PACKFMT, raw[:4])
        self._value = fields[2]
        return self.length()

    def to_bytes(self):
        """Return the packed four-byte option."""
        return struct.pack(IPOption4Bytes._PACKFMT,
            self._copyflag | self.optnum.value, self.length(), self._value)

    def __eq__(self, other):
        # Unrelated types are not comparable; let Python fall back instead of
        # failing with AttributeError.
        if not isinstance(other, IPOption4Bytes):
            return NotImplemented
        return self.optnum == other.optnum and \
            self._value == other._value and \
            self._copyflag == other._copyflag


class IPOptionRouterAlert(IPOption4Bytes):
    """Four-byte option constructed with IPOptionNumber.RouterAlert, the
    default value of 0 and the copy flag set."""
    def __init__(self):
        super().__init__(IPOptionNumber.RouterAlert, copyflag=True)


class IPOptionMTUProbe(IPOption4Bytes):
    """Four-byte option constructed with IPOptionNumber.MTUProbe, an initial
    value of 1500 and the copy flag clear."""
    def __init__(self):
        super().__init__(IPOptionNumber.MTUProbe, value=1500, copyflag=False)


class IPOptionMTUReply(IPOption4Bytes):
    """Four-byte option constructed with IPOptionNumber.MTUReply, an initial
    value of 1500 and the copy flag clear."""
    def __init__(self):
        super().__init__(IPOptionNumber.MTUReply, value=1500, copyflag=False)


# Maps an option number to the class used to parse that option.
# IPOptionList.from_bytes() instantiates the class with no arguments, so
# every class listed here must be constructible without parameters.
IPOptionClasses = {
    IPOptionNumber.EndOfOptionList: IPOptionEndOfOptionList,
    IPOptionNumber.NoOperation: IPOptionNoOperation,
    IPOptionNumber.LooseSourceRouting: IPOptionLooseSourceRouting,
    IPOptionNumber.Timestamp: IPOptionTimestamp,
    IPOptionNumber.RecordRoute: IPOptionRecordRoute,
    IPOptionNumber.StrictSourceRouting: IPOptionStrictSourceRouting,
    IPOptionNumber.MTUProbe: IPOptionMTUProbe,
    IPOptionNumber.MTUReply: IPOptionMTUReply,
    IPOptionNumber.RouterAlert: IPOptionRouterAlert,
}

class IPOptionList(object):
    """Ordered collection of IPOption objects belonging to one IPv4 header.

    Supports len(), indexing (including negative indexes), item assignment
    and deletion, and conversion to and from the packed option bytes.
    """
    def __init__(self):
        self._options = []

    @staticmethod
    def from_bytes(rawbytes):
        '''
        Takes a byte string as a parameter and returns an IPOptionList
        holding one IPOption object per option found in the bytes.

        Raises ValueError if an option parser reports that it consumed no
        bytes, since parsing could then never finish.
        '''
        ipopts = IPOptionList()

        i = 0
        while i < len(rawbytes):
            # Only the low-order 5 bits of the type byte (the option number)
            # are needed to pick the parser class.
            optnum = IPOptionNumber(rawbytes[i] & 0x1f)
            obj = IPOptionClasses[optnum]()
            eaten = obj.from_bytes(rawbytes[i:])
            if eaten <= 0:
                raise ValueError("Malformed IP option {}: parser consumed {} bytes".format(optnum, eaten))
            i += eaten
            ipopts.append(obj)
        return ipopts

    def to_bytes(self):
        '''
        Returns the options as a packed byte string, zero-padded up to the
        next multiple of 4 bytes if (and only if) padding is necessary.
        '''
        if not self._options:
            return b''
        raw = b''.join(ipopt.to_bytes() for ipopt in self._options)
        # -len % 4 is 0 when the data is already 32-bit aligned, so no
        # superfluous padding word is emitted.
        padbytes = -len(raw) % 4
        return raw + b'\x00' * padbytes

    def append(self, opt):
        '''Adds an IPOption to the end of the list; anything else is rejected.'''
        if isinstance(opt, IPOption):
            self._options.append(opt)
        else:
            raise Exception("Option to be added must be an IPOption object")

    def __len__(self):
        return len(self._options)

    def _checked_index(self, i):
        '''Resolves a possibly negative index; raises IndexError if out of range.'''
        if i < 0:
            i = len(self._options) + i
        if not 0 <= i < len(self._options):
            raise IndexError("Invalid IP option index")
        return i

    def __getitem__(self, i):
        return self._options[self._checked_index(i)]

    def __setitem__(self, i, val):
        # The type check comes first so a bad value is reported as ValueError
        # even when the index is also out of range.
        if not isinstance(val, IPOption):
            raise ValueError("Assigned value must be of type IPOption, but {} is not.".format(val.__class__.__name__))
        self._options[self._checked_index(i)] = val

    def __delitem__(self, i):
        del self._options[self._checked_index(i)]

    def raw_length(self):
        '''Returns the size in bytes of the packed (padded) options.'''
        return len(self.to_bytes())

    def size(self):
        '''Returns the number of options in the list.'''
        return len(self._options)

    def __eq__(self, other):
        # Unrelated types are not comparable; let Python fall back instead of
        # failing with AttributeError.
        if not isinstance(other, IPOptionList):
            return NotImplemented
        return self._options == other._options


class IPv4(PacketHeaderBase):
    """IPv4 packet header.

    Holds the fixed 20-byte header fields plus an IPOptionList, and converts
    between that representation and raw bytes.  The header length and the
    checksum are computed on demand from the current field values.
    """
    __slots__ = ['_tos','_totallen','_ttl',
                 '_ipid','_flags','_fragoffset',
                 '_protocol','_csum',
                 '_srcip','_dstip','_options']
    _PACKFMT = '!BBHHHBBH4s4s'
    _MINLEN = struct.calcsize(_PACKFMT)

    def __init__(self, **kwargs):
        # fill in fields with (essentially) zero values
        self.tos = 0x00
        self._totallen = IPv4._MINLEN
        self.ipid = 0x0000
        self.ttl = 0
        self._flags = IPFragmentFlag.NoFragments
        self._fragoffset = 0
        self.protocol = IPProtocol.ICMP
        self._csum = 0x0000
        self.src = SpecialIPv4Addr.IP_ANY.value
        self.dst = SpecialIPv4Addr.IP_ANY.value
        self._options = IPOptionList()
        super().__init__(**kwargs)

    def size(self):
        """Return the header size in bytes, including packed options."""
        return struct.calcsize(IPv4._PACKFMT) + self._options.raw_length()

    def pre_serialize(self, raw, pkt, i):
        """Set the total length to this header's size plus ``len(raw)``."""
        self._totallen = self.size() + len(raw)

    def _pack_fixed_header(self, csum):
        """Pack the fixed 20-byte header using ``csum`` as the checksum field."""
        return struct.pack(IPv4._PACKFMT,
            4 << 4 | self.hl, self.tos, self._totallen,
            self.ipid, self._flags.value << 13 | self.fragment_offset,
            self.ttl, self.protocol.value, csum,
            self.src.packed, self.dst.packed)

    def to_bytes(self):
        """Return the packed header (with a freshly computed checksum) and options."""
        return self._pack_fixed_header(self.checksum) + self._options.to_bytes()

    def from_bytes(self, raw):
        """Populate the header from ``raw`` and return the bytes that follow it.

        Raises:
            Exception: if ``raw`` is too short, the version is not 4, or the
                header length field is below the 20-byte minimum.
        """
        if len(raw) < 20:
            raise Exception("Not enough data to unpack IPv4 header (only {} bytes)".format(len(raw)))
        headerfields = struct.unpack(IPv4._PACKFMT, raw[:20])
        v = headerfields[0] >> 4
        if v != 4:
            raise Exception("Version in raw bytes for IPv4 isn't 4!")
        hl = (headerfields[0] & 0x0f) * 4
        # A header length under 20 would make raw[hl:] hand part of the
        # header itself back to the caller as payload.
        if hl < 20:
            raise Exception("Invalid IPv4 header length field ({} bytes; minimum is 20)".format(hl))
        if len(raw) < hl:
            raise Exception("Not enough data to unpack IPv4 header (only {} bytes, but header length field claims {})".format(len(raw), hl))
        optionbytes = raw[20:hl]
        self.tos = headerfields[1]
        self._totallen = headerfields[2]
        self.ipid = headerfields[3]
        self.flags = IPFragmentFlag(headerfields[4] >> 13)
        self.fragment_offset = headerfields[4] & 0x1fff
        self.ttl = headerfields[5]
        self.protocol = IPProtocol(headerfields[6])
        self._csum = headerfields[7]
        self.src = headerfields[8]
        self.dst = headerfields[9]
        self._options = IPOptionList.from_bytes(optionbytes)
        return raw[hl:]

    def __eq__(self, other):
        # Unrelated types are not comparable; let Python fall back instead of
        # failing with AttributeError.
        if not isinstance(other, IPv4):
            return NotImplemented
        # The checksum, total length and options are not part of the
        # comparison.
        return self.tos == other.tos and \
            self.ipid == other.ipid and \
            self.flags == other.flags and \
            self.fragment_offset == other.fragment_offset and \
            self.ttl == other.ttl and \
            self.protocol == other.protocol and \
            self.src == other.src and \
            self.dst == other.dst

    def next_header_class(self):
        """Return the header class for ``self.protocol``, or None if unknown."""
        cls = IPTypeClasses.get(self.protocol, None)
        if cls is None:
            print ("Warning: no class exists to parse next protocol type: {}".format(self.protocol))
        return cls

    # accessors and mutators
    @property
    def options(self):
        return self._options

    @property
    def total_length(self):
        return self._totallen

    @property
    def ttl(self):
        return self._ttl

    @ttl.setter
    def ttl(self, value):
        value = int(value)
        if not (0 <= value <= 255):
            raise ValueError("Invalid TTL value {}".format(value))
        self._ttl = value

    @property
    def tos(self):
        return self._tos

    @tos.setter
    def tos(self, value):
        if not (0 <= value < 256):
            raise Exception("Invalid type of service value; must be 0-255")
        self._tos = value

    @property
    def dscp(self):
        """Upper six bits of the type-of-service byte."""
        return self._tos >> 2

    @dscp.setter
    def dscp(self, value):
        if not (0 <= value < 64):
            raise Exception("Invalid DSCP value; must be 0-63")
        # Keep the two ECN bits, replace the six DSCP bits.
        self._tos = (self._tos & 0x03) | value << 2

    @property
    def ecn(self):
        """Lower two bits of the type-of-service byte."""
        return (self._tos & 0x03)

    @ecn.setter
    def ecn(self, value):
        if not (0 <= value < 4):
            raise Exception("Invalid ECN value; must be 0-3")
        # 0xfc keeps all six DSCP bits and clears both ECN bits.
        self._tos = (self._tos & 0xfc) | value

    @property
    def ipid(self):
        return self._ipid

    @ipid.setter
    def ipid(self, value):
        if not (0 <= value < 65536):
            raise Exception("Invalid IP ID value; must be 0-65535")
        self._ipid = value

    @property
    def protocol(self):
        return self._protocol

    @protocol.setter
    def protocol(self, value):
        self._protocol = IPProtocol(value)

    @property
    def src(self):
        return self._srcip

    @src.setter
    def src(self, value):
        self._srcip = IPAddr(value)

    @property
    def srcip(self):
        '''Deprecated property. Use src instead.'''
        return self._srcip

    @srcip.setter
    def srcip(self, value):
        '''Deprecated property. Use src instead.'''
        self._srcip = IPAddr(value)

    @property
    def dst(self):
        return self._dstip

    @dst.setter
    def dst(self, value):
        self._dstip = IPAddr(value)

    @property
    def dstip(self):
        '''Deprecated property. Use dst instead.'''
        return self._dstip

    @dstip.setter
    def dstip(self, value):
        '''Deprecated property. Use dst instead.'''
        self._dstip = IPAddr(value)

    @property
    def flags(self):
        return self._flags

    @flags.setter
    def flags(self, value):
        self._flags = IPFragmentFlag(value)

    @property
    def fragment_offset(self):
        return self._fragoffset

    @fragment_offset.setter
    def fragment_offset(self, value):
        if not (0 <= value < 2**13):
            raise Exception("Invalid fragment offset value")
        self._fragoffset = value

    @property
    def hl(self):
        """Header length in 32-bit words, derived from size()."""
        return self.size() // 4

    @property
    def checksum(self):
        """Compute, store and return the checksum of the current header.

        The header is packed with a zero checksum field, the packed options
        are appended, and the result is passed to the module-level
        checksum() helper.
        """
        data = self._pack_fixed_header(0) + self._options.to_bytes()
        self._csum = checksum(data, 0)
        return self._csum

    def __str__(self):
        return '{} {}->{} {}'.format(self.__class__.__name__, self.src, self.dst, self.protocol.name)