diff --git a/rdpy/protocol/rdp/tpkt.py b/rdpy/protocol/rdp/tpkt.py index dba74dd..fb64096 100644 --- a/rdpy/protocol/rdp/tpkt.py +++ b/rdpy/protocol/rdp/tpkt.py @@ -70,7 +70,7 @@ class TPKT(RawLayer, IFastPathSender): represent the Raw Layer in stack (first layer) This layer only handle size of packet and determine if is a fast path packet """ - def __init__(self, presentation, fastPathListener): + def __init__(self, presentation, fastPathListener = None): """ @param presentation: presentation layer, in RDP case is x224 layer @param fastPathListener: IFastPathListener @@ -82,8 +82,10 @@ class TPKT(RawLayer, IFastPathSender): self._lastShortLength = UInt8() #fast path listener self._fastPathListener = fastPathListener - #set me as fast path sender - fastPathListener.setFastPathSender(self) + + if not fastPathListener is None: + #set me as fast path sender + fastPathListener.setFastPathSender(self) def connect(self): """ diff --git a/test/test_protocol_rdp_tpkt.py b/test/test_protocol_rdp_tpkt.py new file mode 100644 index 0000000..13b7053 --- /dev/null +++ b/test/test_protocol_rdp_tpkt.py @@ -0,0 +1,113 @@ +# +# Copyright (c) 2014 Sylvain Peyrefitte +# +# This file is part of rdpy. +# +# rdpy is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + +""" +unit test for rdpy.protocol.rdp.tpkt module +""" + +import os, sys +# Change path so we find rdpy +sys.path.insert(1, os.path.join(sys.path[0], '..')) + +import unittest +import rdpy.protocol.rdp.tpkt as tpkt +import rdpy.network.type as type +import rdpy.base.error as error + +class TPKTCase(unittest.TestCase): + """ + @summary: test case for tpkt layer (RDP) + """ + + class TPKT_PASS(Exception): + pass + + class TPKT_FAIL(Exception): + pass + + def test_tpkt_layer_connect(self): + """ + @summary: test forward connect event to presentation layer + """ + class Presentation(object): + def connect(self): + raise TPKTCase.TPKT_PASS() + + layer = tpkt.TPKT(Presentation(), None) + self.assertRaises(TPKTCase.TPKT_PASS, layer.connect) + + def test_tpkt_layer_recv(self): + """ + @summary: test receive in classic case + """ + class Presentation(object): + def connect(self): + pass + def recv(self, data): + data.readType(type.String("test_tpkt_layer_recv", constant = True)) + raise TPKTCase.TPKT_PASS() + + message = type.String("test_tpkt_layer_recv") + + s = type.Stream() + s.writeType((type.UInt8(tpkt.Action.FASTPATH_ACTION_X224), type.UInt8(), type.UInt16Be(type.sizeof(message) + 4), message)) + + layer = tpkt.TPKT(Presentation(), None) + layer.connect() + self.assertRaises(TPKTCase.TPKT_PASS, layer.dataReceived, s.getvalue()) + + def test_tpkt_layer_recv_fastpath(self): + """ + @summary: test receive in fastpath case + """ + class FastPathLayer(tpkt.IFastPathListener): + def setFastPathSender(self, fastPathSender): + pass + def recvFastPath(self, fastPathS): + fastPathS.readType(type.String("test_tpkt_layer_recv_fastpath", constant = True)) + raise TPKTCase.TPKT_PASS() + + message = type.String("test_tpkt_layer_recv_fastpath") + + s = type.Stream() + s.writeType((type.UInt8(tpkt.Action.FASTPATH_ACTION_FASTPATH), type.UInt8(type.sizeof(message) + 2), message)) + + layer = tpkt.TPKT(None, FastPathLayer()) + layer.connect() + self.assertRaises(TPKTCase.TPKT_PASS, layer.dataReceived, s.getvalue()) + + def test_tpkt_layer_recv_fastpath_ext_length(self): + """ + @summary: test receive in fastpath case with extended length + """ + class FastPathLayer(tpkt.IFastPathListener): + def setFastPathSender(self, fastPathSender): + pass + def recvFastPath(self, fastPathS): + fastPathS.readType(type.String("test_tpkt_layer_recv_fastpath_ext_length", constant = True)) + raise TPKTCase.TPKT_PASS() + + message = type.String("test_tpkt_layer_recv_fastpath_ext_length") + + s = type.Stream() + s.writeType((type.UInt8(tpkt.Action.FASTPATH_ACTION_FASTPATH), type.UInt16Be((type.sizeof(message) + 3) | 0x8000), message)) + + layer = tpkt.TPKT(None, FastPathLayer()) + layer.connect() + self.assertRaises(TPKTCase.TPKT_PASS, layer.dataReceived, s.getvalue()) diff --git a/test/test_protocol_rdp_x224.py b/test/test_protocol_rdp_x224.py index 4c27a19..69f018e 100644 --- a/test/test_protocol_rdp_x224.py +++ b/test/test_protocol_rdp_x224.py @@ -53,12 +53,12 @@ class X224Case(unittest.TestCase): """ class Presentation(object): def recv(self, data): - data.readType(type.String('\x01\x02', constant = True)) + data.readType(type.String('test_x224_layer_recvData', constant = True)) raise X224Case.X224_PASS() layer = x224.X224Layer(Presentation()) s = type.Stream() - s.writeType((x224.X224DataHeader(), type.String('\x01\x02'))) + s.writeType((x224.X224DataHeader(), type.String('test_x224_layer_recvData'))) #reinit position s.pos = 0 @@ -74,13 +74,13 @@ class X224Case(unittest.TestCase): s.writeType(data) s.pos = 0 s.readType(x224.X224DataHeader()) - s.readType(type.String('\x01\x02', constant = True)) + s.readType(type.String('test_x224_layer_send', constant = True)) raise X224Case.X224_PASS() layer = x224.X224Layer(None) layer._transport = Transport() - self.assertRaises(X224Case.X224_PASS, layer.send, type.String('\x01\x02')) + self.assertRaises(X224Case.X224_PASS, layer.send, type.String('test_x224_layer_send')) def test_x224_client_connect(self): """