Source code for switchyard.lib.packet.tcp

import struct
from enum import IntEnum
from abc import ABCMeta, abstractmethod

from .packet import PacketHeaderBase,Packet
from .common import checksum
from ..exceptions import *

'''
References:
    IETF RFCs 675, 793, 1122, 2581, 3540, 5681
'''

class TCPOption(metaclass=ABCMeta):
    @abstractmethod
    def to_bytes(self):
        pass

    @abstractmethod
    def from_bytes(self, raw):
        pass

    @abstractmethod
    def __eq__(self, other):
        pass

# EndOfOptions, Padding, MaxSegmentSize, WindowScaling, SACK, Timestamp, AltChecksum

class TCPOptions(PacketHeaderBase):
    __slots__ = ['_optlist']
    def __init__(self, **kwargs):
        self._optlist = []
        super().__init__(**kwargs)

    def size(self):
        return len(self.to_bytes())

    def next_header_class(self):
        return

    def pre_serialize(self, raw, pkt, i):
        return

    def __eq__(self, other):
        if self.size() != other.size():
            return False
        return True  # FIXME

    def to_bytes(self):
        return b''.join([opt.to_bytes() for opt in self._optlist])

    def from_bytes(self, raw):
        # FIXME
        return 0

class TCPFlags(IntEnum):
    FIN = 0
    SYN = 1
    RST = 2
    PSH = 3
    ACK = 4
    URG = 5
    ECE = 6 # ECN-echo RFC 3168
    CWR = 7 # Congestion-window reduced RFC 3168
    NS =  8 # ECN-nonce concealment protection RFC 3540

[docs]class TCP(PacketHeaderBase): __slots__ = ['_src','_dst','_seq','_ack', '_flags','_window','_urg','_options','_len', '_checksum'] _PACKFMT = '!HHIIHHHH' _MINLEN = struct.calcsize(_PACKFMT) _next_header_map = {} _next_header_class_key = '' def __init__(self, **kwargs): self.src = self.dst = 0 self.seq = self.ack = 0 self._flags = 0x000 self.window = 0 self.urgent_pointer = 0 self._options = TCPOptions() self._checksum = 0 self._len = 0 super().__init__(**kwargs) def size(self): return struct.calcsize(TCP._PACKFMT) def _compute_checksum_ipv4(self, ip4, xdata): if ip4 is None: return 0 phdr = struct.pack('!IIxBH', int(ip4.src), int(ip4.dst), ip4.protocol.value, self._len) tcphdr = self._make_header(0) return checksum(phdr + tcphdr + xdata) def pre_serialize(self, raw, pkt, i): self._len = self.size() + len(raw) # checksum calc currently assumes we're only dealing with ipv4. # will need to be modified for ipv6 support... self._checksum = self._compute_checksum_ipv4(pkt.get_header_by_name('IPv4'), raw) def _make_header(self, csum): offset_flags = self.offset << 12 | self._flags header = struct.pack(TCP._PACKFMT, self.src, self.dst, self.seq, self.ack, offset_flags, self.window, csum, self.urgent_pointer) return header def to_bytes(self): ''' Return packed byte representation of the TCP header. ''' header = self._make_header(self._checksum) return header + self._options.to_bytes() def from_bytes(self, raw): '''Return an Ethernet object reconstructed from raw bytes, or an Exception if we can't resurrect the packet.''' if len(raw) < TCP._MINLEN: raise NotEnoughDataError("Not enough bytes ({}) to reconstruct an TCP object".format(len(raw))) fields = struct.unpack(TCP._PACKFMT, raw[:TCP._MINLEN]) self._src = fields[0] self._dst = fields[1] self._seq = fields[2] self._ack = fields[3] offset = fields[4] >> 12 self._flags = fields[4] & 0x01ff self._window = fields[5] csum = fields[6] self._urg = fields[7] headerlen = offset * 4 optlen = headerlen - TCP._MINLEN self._options.from_bytes(raw[TCP._MINLEN:headerlen]) return raw[headerlen:] def __eq__(self, other): return self.src == other.src and \ self.dst == other.dst and \ self.seq == other.seq and \ self.ack == other.ack and \ self.offset == other.offset and \ self.flags == other.flags and \ self.window == other.window and \ self.urgent_pointer == other.urgent_pointer and \ self.options == other.options @property def offset(self): return TCP._MINLEN // 4 + len(self._options.to_bytes()) // 4 @property def src(self): return self._src @property def dst(self): return self._dst @src.setter def src(self,value): self._src = value @dst.setter def dst(self,value): self._dst = value def __str__(self): return '{} {}->{} ({} {}:{})'.format(self.__class__.__name__, self.src, self.dst, self.flagstr, self.seq, self.ack) @property def seq(self): return self._seq @seq.setter def seq(self, value): self._seq = value @property def ack(self): return self._ack @ack.setter def ack(self, value): self._ack = value @property def window(self): return self._window @window.setter def window(self, value): self._window = value @property def checksum(self): return self._checksum @property def flags(self): return self._flags @property def flagstr(self): flist = [] for f in range(9): if self._isset(f): flist.append(TCPFlags(f).name[0]) return "".join(flist) @property def urgent_pointer(self): return self._urg @urgent_pointer.setter def urgent_pointer(self, value): self._urg = value @property def options(self): return self._options def _isset(self, flag): if isinstance(flag, int): flag = TCPFlags(flag) mask = 0x01 << flag.value return (self._flags & mask) == mask def _setflag(self, flag, value): mask = 0x01 << flag.value if value: self._flags = self._flags | mask else: self._flags = self._flags & ~mask @property def NS(self): return self._isset(TCPFlags.NS) @NS.setter def NS(self, value): self._setflag(TCPFlags.NS, value) @property def CWR(self): return self._isset(TCPFlags.CWR) @CWR.setter def CWR(self, value): self._setflag(TCPFlags.CWR, value) @property def ECE(self): return self._isset(TCPFlags.ECE) @ECE.setter def ECE(self, value): self._setflag(TCPFlags.ECE, value) @property def URG(self): return self._isset(TCPFlags.URG) @URG.setter def URG(self, value): self._setflag(TCPFlags.URG, value) @property def ACK(self): return self._isset(TCPFlags.ACK) @ACK.setter def ACK(self, value): self._setflag(TCPFlags.ACK, value) @property def PSH(self): return self._isset(TCPFlags.PSH) @PSH.setter def PSH(self, value): self._setflag(TCPFlags.PSH, value) @property def RST(self): return self._isset(TCPFlags.RST) @RST.setter def RST(self, value): self._setflag(TCPFlags.RST, value) @property def SYN(self): return self._isset(TCPFlags.SYN) @SYN.setter def SYN(self, value): self._setflag(TCPFlags.SYN, value) @property def FIN(self): return self._isset(TCPFlags.FIN) @FIN.setter def FIN(self, value): self._setflag(TCPFlags.FIN, value)