diff --git a/rdpy/protocol/rdp/rdp.py b/rdpy/protocol/rdp/rdp.py index be9ca67..6aa3aae 100644 --- a/rdpy/protocol/rdp/rdp.py +++ b/rdpy/protocol/rdp/rdp.py @@ -324,10 +324,11 @@ class RDPServerController(pdu.layer.PDUServerListener): """ @summary: Controller use in server side mode """ - def __init__(self, privateKeyFileName, certificateFileName, colorDepth): + def __init__(self, colorDepth, privateKeyFileName = None, certificateFileName = None, rsaKeys = None): """ @param privateKeyFileName: file contain server private key @param certficiateFileName: file that contain public key + @param rsaKeys: {Tuple(rsa.PublicKey, rsa.PrivateKey)} rsa crypto @param colorDepth: 15, 16, 24 """ self._isReady = False @@ -336,11 +337,11 @@ class RDPServerController(pdu.layer.PDUServerListener): #build RDP protocol stack self._pduLayer = pdu.layer.Server(self) #secure layer - self._secLayer = sec.Server(self._pduLayer) + self._secLayer = sec.Server(self._pduLayer, rsaKeys) #multi channel service self._mcsLayer = mcs.Server(self._secLayer) #transport pdu layer - self._x224Layer = x224.Server(self._mcsLayer, privateKeyFileName, certificateFileName) + self._x224Layer = x224.Server(self._mcsLayer, privateKeyFileName, certificateFileName, False) #transport packet (protocol layer) self._tpktLayer = tpkt.TPKT(self._x224Layer) #fastpath stack @@ -553,7 +554,7 @@ class ServerFactory(layer.RawLayerServerFactory): @summary: Function call from twisted and build rdp protocol stack @param addr: destination address """ - controller = RDPServerController(self._privateKeyFileName, self._certificateFileName, self._colorDepth) + controller = RDPServerController(self._colorDepth, self._privateKeyFileName, self._certificateFileName) self.buildObserver(controller, addr) return controller.getProtocol() diff --git a/rdpy/protocol/rdp/sec.py b/rdpy/protocol/rdp/sec.py index 2fce0e8..85931d9 100644 --- a/rdpy/protocol/rdp/sec.py +++ b/rdpy/protocol/rdp/sec.py @@ -586,8 +586,14 @@ class Server(SecLayer): """ @summary: Client side of security layer """ - def __init__(self, presentation): + def __init__(self, presentation, rsaKeys = None): + """ + @param rsaKeys: {Tuple(rsa.PublicKey, rsa.PrivateKey)} rsa crypto + """ SecLayer.__init__(self, presentation) + self._rsaPublicKey, self._rsaPrivateKey = None, None + if not rsaKeys is None: + self._rsaPublicKey, self._rsaPrivateKey = rsaKeys def connect(self): """ diff --git a/rdpy/protocol/rdp/x224.py b/rdpy/protocol/rdp/x224.py index 5f04c04..24af912 100644 --- a/rdpy/protocol/rdp/x224.py +++ b/rdpy/protocol/rdp/x224.py @@ -220,16 +220,18 @@ class Server(X224Layer): """ @summary: Server automata of X224 layer """ - def __init__(self, presentation, privateKeyFileName, certificateFileName): + def __init__(self, presentation, privateKeyFileName = None, certificateFileName = None, forceSSL = False): """ - @param presentation: upper layer, MCS layer in RDP case - @param privateKeyFileName: file contain server private key - @param certficiateFileName: file that contain public key + @param presentation: {layer} upper layer, MCS layer in RDP case + @param privateKeyFileName: {str} file contain server private key + @param certficiateFileName: {str} file that contain public key + @param forceSSL: {boolean} reject old client that doerasn't support SSL """ X224Layer.__init__(self, presentation) #Server mode informations for TLS connection self._serverPrivateKeyFileName = privateKeyFileName self._serverCertificateFileName = certificateFileName + self._forceSSL = forceSSL and not self._serverPrivateKeyFileName is None and not self._serverCertificateFileName is None def connect(self): """ @@ -241,26 +243,33 @@ class Server(X224Layer): """ @summary: Read connection confirm packet Next state is send connection confirm - @param data: Stream + @param data: {Stream} @see : http://msdn.microsoft.com/en-us/library/cc240470.aspx """ message = ClientConnectionRequestPDU() data.readType(message) - if not message.protocolNeg._is_readed or message.protocolNeg.failureCode._is_readed: - raise InvalidExpectedDataException("Too older RDP client") + if not message.protocolNeg._is_readed: + self._requestedProtocol = Protocols.PROTOCOL_RDP + else: + self._requestedProtocol = message.protocolNeg.selectedProtocol.value - self._requestedProtocol = message.protocolNeg.selectedProtocol.value + #match best security layer available + if not self._serverPrivateKeyFileName is None and not self._serverCertificateFileName is None: + self._selectedProtocol = self._requestedProtocol & Protocols.PROTOCOL_SSL + else: + self._selectedProtocol = self._requestedProtocol & Protocols.PROTOCOL_RDP - if not self._requestedProtocol & Protocols.PROTOCOL_SSL: + #if force ssl is enable + if not self._selectedProtocol & Protocols.PROTOCOL_SSL and self._forceSSL: #send error message and quit message = ServerConnectionConfirm() message.protocolNeg.code.value = NegociationType.TYPE_RDP_NEG_FAILURE message.protocolNeg.failureCode.value = NegotiationFailureCode.SSL_REQUIRED_BY_SERVER self._transport.send(message) - raise InvalidExpectedDataException("rdpy needs ssl client compliant") + self.close() + return - self._selectedProtocol = Protocols.PROTOCOL_SSL self.sendConnectionConfirm() def sendConnectionConfirm(self): @@ -274,8 +283,10 @@ class Server(X224Layer): message.protocolNeg.code.value = NegociationType.TYPE_RDP_NEG_RSP message.protocolNeg.selectedProtocol.value = self._selectedProtocol self._transport.send(message) - #_transport is TPKT and transport is TCP layer of twisted - self._transport.transport.startTLS(ServerTLSContext(self._serverPrivateKeyFileName, self._serverCertificateFileName)) + if self._selectedProtocol == Protocols.PROTOCOL_SSL: + #_transport is TPKT and transport is TCP layer of twisted + self._transport.transport.startTLS(ServerTLSContext(self._serverPrivateKeyFileName, self._serverCertificateFileName)) + #connection is done send to presentation self.setNextState(self.recvData) self._presentation.connect() diff --git a/test/test_protocol_rdp_rc4.py b/test/test_protocol_rdp_rc4.py index 4f6cf7d..cfec719 100644 --- a/test/test_protocol_rdp_rc4.py +++ b/test/test_protocol_rdp_rc4.py @@ -26,7 +26,7 @@ import os, sys sys.path.insert(1, os.path.join(sys.path[0], '..')) import unittest -import rdpy.protocol.rdp.rc4 as rc4 +import rdpy.core.rc4 as rc4 class RC4Test(unittest.TestCase): diff --git a/test/test_protocol_rdp_x224.py b/test/test_protocol_rdp_x224.py index b56cc59..83c2dbe 100644 --- a/test/test_protocol_rdp_x224.py +++ b/test/test_protocol_rdp_x224.py @@ -172,37 +172,6 @@ class X224Test(unittest.TestCase): self.assertTrue(presentation_connect, "connect event is not forwarded") self.assertRaises(X224Test.X224_PASS, layer.recv, type.String('\x01\x02')) - def test_x224_server_recvConnectionRequest_invalid_old_client(self): - """ - @summary: unit test for X224Server.recvConnectionRequest function - old client with non protocol neg - """ - message = x224.ClientConnectionRequestPDU() - del message._typeName[message._typeName.index("protocolNeg")] - s = type.Stream() - s.writeType(message) - s.pos = 0 - - layer = x224.Server(None, "key", "cert") - layer.connect() - - self.assertRaises(error.InvalidExpectedDataException, layer.recv, s) - - def test_x224_server_recvConnectionRequest_invalid_protocol_neg_failure(self): - """ - @summary: unit test for X224Server.recvConnectionRequest function - """ - message = x224.ClientConnectionRequestPDU() - message.protocolNeg.code.value = x224.NegociationType.TYPE_RDP_NEG_FAILURE - s = type.Stream() - s.writeType(message) - s.pos = 0 - - layer = x224.Server(None, "key", "cert") - layer.connect() - - self.assertRaises(error.InvalidExpectedDataException, layer.recv, s) - def test_x224_server_recvConnectionRequest_client_accept_ssl(self): """ @summary: unit test for X224Server.recvConnectionRequest function @@ -215,6 +184,8 @@ class X224Test(unittest.TestCase): raise X224Test.X224_FAIL() if data.protocolNeg.code.value != x224.NegociationType.TYPE_RDP_NEG_FAILURE or data.protocolNeg.failureCode.value != x224.NegotiationFailureCode.SSL_REQUIRED_BY_SERVER: raise X224Test.X224_FAIL() + def close(self): + raise X224Test.X224_PASS() message = x224.ClientConnectionRequestPDU() message.protocolNeg.selectedProtocol.value = x224.Protocols.PROTOCOL_HYBRID @@ -222,11 +193,11 @@ class X224Test(unittest.TestCase): s.writeType(message) s.pos = 0 - layer = x224.Server(None, "key", "cert") + layer = x224.Server(None, "key", "cert", True) layer._transport = Transport() layer.connect() - self.assertRaises(error.InvalidExpectedDataException, layer.recv, s) + self.assertRaises(X224Test.X224_PASS, layer.recv, s) def test_x224_server_recvConnectionRequest_valid(self): """