import struct
from abc import ABCMeta, abstractmethod
from ipaddress import IPv4Address
from collections import namedtuple
from switchyard.lib.packet.packet import PacketHeaderBase,Packet
from switchyard.lib.address import EthAddr,IPAddr,SpecialIPv4Addr,SpecialEthAddr
from switchyard.lib.packet.common import IPProtocol,IPFragmentFlag,IPOptionNumber, checksum
from switchyard.lib.packet.icmp import ICMP
from switchyard.lib.packet.igmp import IGMP
from switchyard.lib.packet.udp import UDP
from switchyard.lib.packet.tcp import TCP
'''
References:
RFC 791, INTERNET PROTOCOL. DARPA INTERNET PROGRAM PROTOCOL SPECIFICATION.
September 1981.
RFC 1063, MTU discovery options.
RFC 2113, Router alert option.
'''
# Maps an IP protocol number to the header class used for the payload that
# follows the IPv4 header; IPv4.next_header_class() looks protocols up here.
IPTypeClasses = {
IPProtocol.ICMP: ICMP,
IPProtocol.TCP: TCP,
IPProtocol.UDP: UDP,
IPProtocol.IGMP: IGMP,
}
class IPOption(object, metaclass=ABCMeta):
    '''
    Base class for a single IPv4 header option.

    On its own it models a one-byte option that carries nothing but its
    option number.  Subclasses override length(), to_bytes() and
    from_bytes() for options that carry data.
    '''
    _PACKFMT = 'B'
    __slots__ = ['_optnum']

    def __init__(self, optnum):
        '''Store optnum, coerced to an IPOptionNumber.'''
        self._optnum = IPOptionNumber(optnum)

    @property
    def optnum(self):
        '''The option number given at construction time.'''
        return self._optnum

    def length(self):
        '''Return the serialized size of the option in bytes.'''
        return struct.calcsize(IPOption._PACKFMT)

    def to_bytes(self):
        '''Return the option packed as a single byte (the option number).'''
        return struct.pack(IPOption._PACKFMT, self._optnum.value)

    def from_bytes(self, raw):
        '''
        Consume this option from the front of raw and return the number
        of bytes used.  The base option has no payload, so raw is not
        inspected.
        '''
        return self.length()

    def __eq__(self, other):
        '''
        Options are equal when their option numbers match.  Anything that
        is not an IPOption yields NotImplemented, so comparing against
        None or an unrelated object evaluates to False instead of raising
        AttributeError.
        '''
        if not isinstance(other, IPOption):
            return NotImplemented
        return self._optnum == other._optnum
class IPOptionNoOperation(IPOption):
    '''Option whose number is IPOptionNumber.NoOperation; it has no payload.'''

    def __init__(self):
        super().__init__(optnum=IPOptionNumber.NoOperation)
class IPOptionEndOfOptionList(IPOption):
    '''Option whose number is IPOptionNumber.EndOfOptionList; it has no payload.'''

    def __init__(self):
        super().__init__(optnum=IPOptionNumber.EndOfOptionList)
class IPOptionXRouting(IPOption):
    '''
    Shared implementation of the routing-type IP options.

    Serialized layout: option number, total option length and pointer
    (one byte each), followed by the route data as 4-byte IPv4 addresses.
    The object also behaves as a fixed-position sequence of those
    addresses (len(), indexing, item assignment and deletion).
    '''
    _PACKFMT = 'BBB'
    __slots__ = ['_routedata', '_ptr']

    def __init__(self, ipoptnum, numaddrs=9):
        '''
        Create the option with numaddrs address slots, all 0.0.0.0, and
        the pointer set to 4 (the first slot).

        Raises Exception if numaddrs is outside 1-9.
        '''
        super().__init__(ipoptnum)
        if numaddrs < 1 or numaddrs > 9:
            raise Exception("Invalid number of addresses for IP routing-type option (must be 1-9)")
        self._routedata = [IPv4Address("0.0.0.0")] * numaddrs
        self._ptr = 4

    def _checked_index(self, index):
        '''Resolve a possibly negative index or raise IndexError.'''
        if index < 0:
            index += len(self._routedata)
        if not 0 <= index < len(self._routedata):
            raise IndexError("Index out of range")
        return index

    def length(self):
        '''Return the serialized size: 3 header bytes plus 4 per address.'''
        return struct.calcsize(IPOptionXRouting._PACKFMT) + len(self._routedata) * 4

    def __len__(self):
        return len(self._routedata)

    def to_bytes(self):
        '''Return the packed header followed by every packed address.'''
        header = struct.pack(IPOptionXRouting._PACKFMT,
                             self.optnum.value, self.length(), self._ptr)
        return header + b''.join(addr.packed for addr in self._routedata)

    def from_bytes(self, raw):
        '''
        Load the option from the front of raw and return the number of
        bytes consumed (the option's length byte).

        Raises ValueError when raw is shorter than the 3-byte option
        header or than the length the option claims, and (via the pointer
        setter) when the pointer does not refer to one of the decoded
        address slots.
        '''
        if len(raw) < 3:
            raise ValueError("Not enough data to unpack raw {}: need at least 3 bytes but only have {}".format(
                self.__class__.__name__, len(raw)))
        length = raw[1]
        pointer = raw[2]
        if length > len(raw):
            raise ValueError("Not enough data to unpack raw {}: need {} but only have {}".format(
                self.__class__.__name__, length, len(raw)))
        numaddrs = (length - 3) // 4
        self._routedata = [IPv4Address(raw[3 + 4 * i:7 + 4 * i])
                           for i in range(numaddrs)]
        # Assign through the property so the pointer is validated against
        # the freshly decoded address list.
        self.pointer = pointer
        return length

    @property
    def pointer(self):
        '''The pointer byte of the option.'''
        return self._ptr

    @pointer.setter
    def pointer(self, value):
        # value // 4 - 1 converts the byte offset into an address-slot index.
        slot = value // 4 - 1
        if not 0 <= slot < len(self._routedata):
            raise ValueError("Invalid pointer value")
        self._ptr = value

    def num_addrs(self):
        '''Return the number of address slots.'''
        return len(self._routedata)

    def __getitem__(self, index):
        return self._routedata[self._checked_index(index)]

    def __setitem__(self, index, addr):
        # The type check comes first so a bad value reports ValueError even
        # when the index is also out of range.
        if not isinstance(addr, IPv4Address):
            raise ValueError("Value must be IPv4Address")
        self._routedata[self._checked_index(index)] = addr

    def __delitem__(self, index):
        del self._routedata[self._checked_index(index)]

    def __eq__(self, other):
        '''
        Equal when option number, pointer and route data all match.
        Objects that are not routing options yield NotImplemented instead
        of raising AttributeError.
        '''
        if not isinstance(other, IPOptionXRouting):
            return NotImplemented
        return (self.optnum == other.optnum and
                self._ptr == other._ptr and
                self._routedata == other._routedata)
class IPOptionLooseSourceRouting(IPOptionXRouting):
    '''Routing-type option numbered IPOptionNumber.LooseSourceRouting.'''

    def __init__(self):
        super().__init__(ipoptnum=IPOptionNumber.LooseSourceRouting)
class IPOptionStrictSourceRouting(IPOptionXRouting):
    '''Routing-type option numbered IPOptionNumber.StrictSourceRouting.'''

    def __init__(self):
        super().__init__(ipoptnum=IPOptionNumber.StrictSourceRouting)
class IPOptionRecordRoute(IPOptionXRouting):
    '''Routing-type option numbered IPOptionNumber.RecordRoute.'''

    def __init__(self):
        super().__init__(ipoptnum=IPOptionNumber.RecordRoute)
# One entry of the IP Timestamp option: an IPv4 address paired with a
# timestamp value.
TimestampEntry = namedtuple('TimestampEntry', ['ipv4addr','timestamp'])
class IPOptionTimestamp(IPOption):
    '''
    IP Timestamp option.

    Serialized layout: type, length, pointer and flag bytes, followed by
    the entries.  With flag 0 every entry is a 4-byte timestamp; with any
    other flag every entry is a 4-byte IPv4 address followed by a 4-byte
    timestamp.
    '''
    __slots__ = ['_entries', '_ptr', '_flag']

    def __init__(self, tslist=None):
        '''
        Create the option with four zeroed (0.0.0.0, 0) entries.

        tslist is accepted for backward compatibility but is not used.
        Its default is None rather than a shared mutable list.
        '''
        super().__init__(IPOptionNumber.Timestamp)
        self._entries = [TimestampEntry(IPv4Address("0.0.0.0"), 0)] * 4
        self._ptr = 5
        # flags: 0x0 only timestamps, 0x1 ipaddr and timestamp, 0x3 optlist initialized
        # with up to 4 pairs of ipaddr and 0 timestamps
        self._flag = 0x1

    def length(self):
        '''Return the serialized size: 4 header bytes plus the entries.'''
        entrysize = 4 if self._flag == 0 else 8
        return 4 + len(self._entries) * entrysize

    def to_bytes(self):
        '''Return the packed option header followed by every entry.'''
        chunks = [struct.pack('!BBBB', 0x40 | self.optnum.value, self.length(),
                              self._ptr, self._flag)]
        for entry in self._entries:
            if self._flag > 0:
                chunks.append(entry.ipv4addr.packed)
            chunks.append(struct.pack('!I', entry.timestamp))
        return b''.join(chunks)

    def from_bytes(self, raw):
        '''
        Load the option from the front of raw and return the number of
        bytes consumed (the option's length byte).

        Addresses are decoded into IPv4Address objects so that the result
        can be passed straight back through to_bytes().

        Raises Exception when raw is shorter than the 4-byte header or
        than the length the option claims, ValueError when the length
        byte is smaller than the header itself, and struct.error when the
        entry area is not a whole number of entries.
        '''
        notenough = "Not enough data to unpack raw {}: need {} but only have {}"
        if len(raw) < 4:
            raise Exception(notenough.format(self.__class__.__name__, 4, len(raw)))
        _, xlen, ptr, flagbyte = struct.unpack('!BBBB', raw[:4])
        # A length below the header size would make the caller advance by
        # too little (or not at all) and loop on the same bytes.
        if xlen < 4:
            raise ValueError("Invalid length {} in raw {}".format(
                xlen, self.__class__.__name__))
        if xlen > len(raw):
            raise Exception(notenough.format(self.__class__.__name__, xlen, len(raw)))

        flag = flagbyte & 0x0f
        haveipaddr = flag != 0
        unpackfmt = '!II' if haveipaddr else '!I'
        entries = []
        for tstup in struct.iter_unpack(unpackfmt, raw[4:xlen]):
            if haveipaddr:
                entries.append(TimestampEntry(IPv4Address(tstup[0]), tstup[1]))
            else:
                entries.append(TimestampEntry(None, tstup[0]))

        # Only commit the new state once everything decoded cleanly.
        self._ptr = ptr
        self._flag = flag
        self._entries = entries
        return xlen

    def num_timestamps(self):
        '''Return the number of entries held by the option.'''
        return len(self._entries)

    def timestamp_entry(self, index):
        '''Return the TimestampEntry at index.'''
        return self._entries[index]
class IPOption4Bytes(IPOption):
    '''
    Shared implementation of the options that occupy exactly four bytes:
    a type byte, a length byte and a 16-bit value.
    '''
    __slots__ = ['_value', '_copyflag']
    _PACKFMT = '!BBH'

    def __init__(self, optnum, value=0, copyflag=False):
        '''
        Store the 16-bit value and, when copyflag is true, remember to set
        the high bit (0x80) of the type byte on serialization.
        '''
        super().__init__(optnum)
        self._value = value
        self._copyflag = 0x80 if copyflag else 0

    def length(self):
        '''Return the serialized size of the option in bytes.'''
        return struct.calcsize(IPOption4Bytes._PACKFMT)

    def from_bytes(self, raw):
        '''
        Read the 16-bit value from the first four bytes of raw and return
        the number of bytes consumed.  The type and length bytes are
        unpacked but not stored.
        '''
        fields = struct.unpack(IPOption4Bytes._PACKFMT, raw[:4])
        self._value = fields[2]
        return self.length()

    def to_bytes(self):
        '''Return the packed type byte (with copy flag), length and value.'''
        return struct.pack(IPOption4Bytes._PACKFMT,
                           self._copyflag | self.optnum.value,
                           self.length(), self._value)

    def __eq__(self, other):
        '''
        Equal when option number, value and copy flag all match.  Objects
        that are not 4-byte options yield NotImplemented instead of
        raising AttributeError.
        '''
        if not isinstance(other, IPOption4Bytes):
            return NotImplemented
        return (self.optnum == other.optnum and
                self._value == other._value and
                self._copyflag == other._copyflag)
class IPOptionRouterAlert(IPOption4Bytes):
    '''4-byte option numbered IPOptionNumber.RouterAlert, copy flag set.'''

    def __init__(self):
        super().__init__(optnum=IPOptionNumber.RouterAlert, copyflag=True)
class IPOptionMTUProbe(IPOption4Bytes):
    '''4-byte option numbered IPOptionNumber.MTUProbe; value starts at 1500.'''

    def __init__(self):
        super().__init__(optnum=IPOptionNumber.MTUProbe,
                         value=1500,
                         copyflag=False)
class IPOptionMTUReply(IPOption4Bytes):
    '''4-byte option numbered IPOptionNumber.MTUReply; value starts at 1500.'''

    def __init__(self):
        super().__init__(optnum=IPOptionNumber.MTUReply,
                         value=1500,
                         copyflag=False)
# Maps each option number to the class that IPOptionList.from_bytes()
# instantiates (with no arguments) to parse an option of that number.
IPOptionClasses = {
IPOptionNumber.EndOfOptionList: IPOptionEndOfOptionList,
IPOptionNumber.NoOperation: IPOptionNoOperation,
IPOptionNumber.LooseSourceRouting: IPOptionLooseSourceRouting,
IPOptionNumber.Timestamp: IPOptionTimestamp,
IPOptionNumber.RecordRoute: IPOptionRecordRoute,
IPOptionNumber.StrictSourceRouting: IPOptionStrictSourceRouting,
IPOptionNumber.MTUProbe: IPOptionMTUProbe,
IPOptionNumber.MTUReply: IPOptionMTUReply,
IPOptionNumber.RouterAlert: IPOptionRouterAlert,
}
class IPOptionList(object):
    '''
    Ordered collection of IPOption objects that can be parsed from, and
    serialized to, the options area of an IPv4 header.
    '''

    def __init__(self):
        self._options = []

    @staticmethod
    def from_bytes(rawbytes):
        '''
        Takes a byte string as a parameter and returns an IPOptionList
        holding one IPOption object per option found in it.

        Raises ValueError if an option reports that it consumed no bytes,
        since parsing could then never advance.
        '''
        ipopts = IPOptionList()
        i = 0
        while i < len(rawbytes):
            # The low-order 5 bits of the type byte are the option number;
            # the copied flag and option class bits are not needed to pick
            # the parser class.
            optnum = IPOptionNumber(rawbytes[i] & 0x1f)
            obj = IPOptionClasses[optnum]()
            eaten = obj.from_bytes(rawbytes[i:])
            if eaten <= 0:
                raise ValueError("IP option {} at offset {} consumed no data".format(
                    obj.__class__.__name__, i))
            i += eaten
            ipopts.append(obj)
        return ipopts

    def to_bytes(self):
        '''
        Returns the options as a packed byte string, zero-padded up to the
        next multiple of 4 bytes.  Nothing is added when the options are
        already aligned, and an empty list yields an empty byte string.
        '''
        if not self._options:
            return b''
        raw = b''.join(ipopt.to_bytes() for ipopt in self._options)
        # -len % 4 is the distance to the next 4-byte boundary (0 when the
        # data is already aligned).
        return raw + b'\x00' * (-len(raw) % 4)

    def append(self, opt):
        '''Add opt to the end of the list; it must be an IPOption.'''
        if isinstance(opt, IPOption):
            self._options.append(opt)
        else:
            raise Exception("Option to be added must be an IPOption object")

    def _checked_index(self, i):
        '''Resolve a possibly negative index or raise IndexError.'''
        if i < 0:
            i += len(self._options)
        if not 0 <= i < len(self._options):
            raise IndexError("Invalid IP option index")
        return i

    def __len__(self):
        return len(self._options)

    def __getitem__(self, i):
        return self._options[self._checked_index(i)]

    def __setitem__(self, i, val):
        # The type check comes first so a bad value reports ValueError even
        # when the index is also out of range.
        if not isinstance(val, IPOption):
            raise ValueError("Assigned value must be of type IPOption, but {} is not.".format(val.__class__.__name__))
        self._options[self._checked_index(i)] = val

    def __delitem__(self, i):
        del self._options[self._checked_index(i)]

    def raw_length(self):
        '''Return the size in bytes of the serialized (padded) options.'''
        return len(self.to_bytes())

    def size(self):
        '''Return the number of options in the list.'''
        return len(self._options)

    def __eq__(self, other):
        '''
        Equal when both lists hold equal options in the same order.
        Objects that are not IPOptionList yield NotImplemented instead of
        raising AttributeError.
        '''
        if not isinstance(other, IPOptionList):
            return NotImplemented
        return self._options == other._options
class IPv4(PacketHeaderBase):
    '''
    IPv4 packet header: the fixed 20-byte header plus an IPOptionList.

    The total-length field and the checksum are derived values; the
    former is refreshed by pre_serialize() and the latter is recomputed
    each time the checksum property is read.
    '''
    __slots__ = ['_tos','_totallen','_ttl',
                 '_ipid','_flags','_fragoffset',
                 '_protocol','_csum',
                 '_srcip','_dstip','_options']
    _PACKFMT = '!BBHHHBBH4s4s'
    _MINLEN = struct.calcsize(_PACKFMT)

    def __init__(self, **kwargs):
        # fill in fields with (essentially) zero values
        self.tos = 0x00
        self._totallen = IPv4._MINLEN
        self.ipid = 0x0000
        self.ttl = 0
        self._flags = IPFragmentFlag.NoFragments
        self._fragoffset = 0
        self.protocol = IPProtocol.ICMP
        self._csum = 0x0000
        self.src = SpecialIPv4Addr.IP_ANY.value
        self.dst = SpecialIPv4Addr.IP_ANY.value
        self._options = IPOptionList()
        super().__init__(**kwargs)

    def size(self):
        '''Return the header size in bytes, including serialized options.'''
        return struct.calcsize(IPv4._PACKFMT) + self._options.raw_length()

    def pre_serialize(self, raw, pkt, i):
        '''Refresh the total length as header size plus len(raw).'''
        self._totallen = self.size() + len(raw)

    def _pack_fixed_header(self, csum):
        '''Pack the fixed 20-byte header with csum in the checksum field.'''
        return struct.pack(IPv4._PACKFMT,
                           4 << 4 | self.hl, self.tos, self._totallen,
                           self.ipid,
                           self.flags.value << 13 | self.fragment_offset,
                           self.ttl, self.protocol.value, csum,
                           self.src.packed, self.dst.packed)

    def to_bytes(self):
        '''Return the serialized header followed by the serialized options.'''
        return self._pack_fixed_header(self.checksum) + self._options.to_bytes()

    def from_bytes(self, raw):
        '''
        Load the header (and options) from the front of raw and return
        the bytes that follow the header.

        Raises Exception when raw is too short, when the version is not
        4, or when the header-length field describes a header smaller
        than the fixed 20 bytes.
        '''
        if len(raw) < 20:
            raise Exception("Not enough data to unpack IPv4 header (only {} bytes)".format(len(raw)))
        headerfields = struct.unpack(IPv4._PACKFMT, raw[:20])
        v = headerfields[0] >> 4
        if v != 4:
            raise Exception("Version in raw bytes for IPv4 isn't 4!")
        hl = (headerfields[0] & 0x0f) * 4
        # A header length below the fixed header would make the returned
        # payload overlap bytes that were already parsed as header fields.
        if hl < IPv4._MINLEN:
            raise Exception("Invalid IPv4 header length {} (must be at least {})".format(hl, IPv4._MINLEN))
        if len(raw) < hl:
            raise Exception("Not enough data to unpack IPv4 header (only {} bytes, but header length field claims {})".format(len(raw), hl))
        optionbytes = raw[20:hl]
        self.tos = headerfields[1]
        self._totallen = headerfields[2]
        self.ipid = headerfields[3]
        self.flags = IPFragmentFlag(headerfields[4] >> 13)
        self.fragment_offset = headerfields[4] & 0x1fff
        self.ttl = headerfields[5]
        self.protocol = IPProtocol(headerfields[6])
        self._csum = headerfields[7]
        self.src = headerfields[8]
        self.dst = headerfields[9]
        self._options = IPOptionList.from_bytes(optionbytes)
        return raw[hl:]

    def __eq__(self, other):
        '''
        Headers are equal when tos, ipid, flags, fragment offset, ttl,
        protocol, source and destination all match; the checksum, total
        length and options are not compared.  Objects that are not IPv4
        headers yield NotImplemented instead of raising AttributeError.
        '''
        if not isinstance(other, IPv4):
            return NotImplemented
        return (self.tos == other.tos and
                self.ipid == other.ipid and
                self.flags == other.flags and
                self.fragment_offset == other.fragment_offset and
                self.ttl == other.ttl and
                self.protocol == other.protocol and
                self.src == other.src and
                self.dst == other.dst)

    def next_header_class(self):
        '''
        Return the class registered in IPTypeClasses for this header's
        protocol, or None (after printing a warning) if there is none.
        '''
        cls = IPTypeClasses.get(self.protocol, None)
        if cls is None:
            print ("Warning: no class exists to parse next protocol type: {}".format(self.protocol))
        return cls

    # accessors and mutators
    @property
    def options(self):
        return self._options

    @property
    def total_length(self):
        return self._totallen

    @property
    def ttl(self):
        return self._ttl

    @ttl.setter
    def ttl(self, value):
        value = int(value)
        if not (0 <= value <= 255):
            raise ValueError("Invalid TTL value {}".format(value))
        self._ttl = value

    @property
    def tos(self):
        return self._tos

    @tos.setter
    def tos(self, value):
        if not (0 <= value < 256):
            raise Exception("Invalid type of service value; must be 0-255")
        self._tos = value

    @property
    def dscp(self):
        '''The upper six bits of the type-of-service byte.'''
        return self._tos >> 2

    @dscp.setter
    def dscp(self, value):
        if not (0 <= value < 64):
            raise Exception("Invalid DSCP value; must be 0-63")
        # Keep the two ECN bits and replace the six DSCP bits.
        self._tos = (self._tos & 0x03) | value << 2

    @property
    def ecn(self):
        '''The lower two bits of the type-of-service byte.'''
        return (self._tos & 0x03)

    @ecn.setter
    def ecn(self, value):
        if not (0 <= value < 4):
            raise Exception("Invalid ECN value; must be 0-3")
        # 0xfc keeps the six DSCP bits and clears both ECN bits before the
        # new value is merged in.
        self._tos = (self._tos & 0xfc) | value

    @property
    def ipid(self):
        return self._ipid

    @ipid.setter
    def ipid(self, value):
        if not (0 <= value < 65536):
            raise Exception("Invalid IP ID value; must be 0-65535")
        self._ipid = value

    @property
    def protocol(self):
        return self._protocol

    @protocol.setter
    def protocol(self, value):
        self._protocol = IPProtocol(value)

    @property
    def src(self):
        return self._srcip

    @src.setter
    def src(self, value):
        self._srcip = IPAddr(value)

    @property
    def srcip(self):
        '''Deprecated property. Use src instead.'''
        return self._srcip

    @srcip.setter
    def srcip(self, value):
        '''Deprecated property. Use src instead.'''
        self._srcip = IPAddr(value)

    @property
    def dst(self):
        return self._dstip

    @dst.setter
    def dst(self, value):
        self._dstip = IPAddr(value)

    @property
    def dstip(self):
        '''Deprecated property. Use dst instead.'''
        return self._dstip

    @dstip.setter
    def dstip(self, value):
        '''Deprecated property. Use dst instead.'''
        self._dstip = IPAddr(value)

    @property
    def flags(self):
        return self._flags

    @flags.setter
    def flags(self, value):
        self._flags = IPFragmentFlag(value)

    @property
    def fragment_offset(self):
        return self._fragoffset

    @fragment_offset.setter
    def fragment_offset(self, value):
        if not (0 <= value < 2**13):
            raise Exception("Invalid fragment offset value")
        self._fragoffset = value

    @property
    def hl(self):
        '''Header length in 32-bit words, options included.'''
        return self.size() // 4

    @property
    def checksum(self):
        '''
        Recompute the checksum over the header (checksum field zeroed)
        plus the options, cache it in _csum and return it.
        '''
        data = self._pack_fixed_header(0) + self._options.to_bytes()
        self._csum = checksum(data, 0)
        return self._csum

    def __str__(self):
        return '{} {}->{} {}'.format(self.__class__.__name__, self.src, self.dst, self.protocol.name)