From 73b97d692979140309e9f594112404fa063f897f Mon Sep 17 00:00:00 2001 From: citronneur Date: Tue, 21 Apr 2020 22:02:35 +0200 Subject: [PATCH] Update: Start MCS layer --- bin/rdpy-rdpclient.py | 5 +- rdpy/core/rdp.py | 4 +- rdpy/core/t125/ber.py | 24 ++- rdpy/core/t125/gcc.py | 194 ++++++++++------------ rdpy/core/t125/mcs.py | 365 +++++++++++++++++++++++++----------------- rdpy/core/t125/per.py | 17 +- rdpy/core/x224.py | 21 +-- rdpy/model/message.py | 98 ++++-------- 8 files changed, 382 insertions(+), 346 deletions(-) diff --git a/bin/rdpy-rdpclient.py b/bin/rdpy-rdpclient.py index 69bc0da..7503a95 100755 --- a/bin/rdpy-rdpclient.py +++ b/bin/rdpy-rdpclient.py @@ -26,6 +26,7 @@ import asyncio from rdpy.core import tpkt, x224 from rdpy.core.nla import ntlm +from rdpy.core.t125 import mcs from rdpy.model.message import UInt8 if __name__ == '__main__': @@ -37,7 +38,9 @@ if __name__ == '__main__': '127.0.0.1', 33389) x224_layer = await x224.connect(tpkt.Tpkt(reader, writer), ntlm.NTLMv2("", "sylvain", "sylvain")) - await x224_layer.write(UInt8(8)) + mcs_layer = mcs.Client(x224_layer) + await mcs_layer.connect() + await asyncio.sleep(10) print("foooooooooooooooooooo") diff --git a/rdpy/core/rdp.py b/rdpy/core/rdp.py index eb4575b..395371b 100644 --- a/rdpy/core/rdp.py +++ b/rdpy/core/rdp.py @@ -99,8 +99,8 @@ class RDPClientController(PDUClientListener): @param height: height in pixel of screen """ #set screen definition in MCS layer - self._mcsLayer._clientSettings.getBlock(gcc.MessageType.CS_CORE).desktopHeight.value = height - self._mcsLayer._clientSettings.getBlock(gcc.MessageType.CS_CORE).desktopWidth.value = width + self._mcsLayer._clientSettings.get_block(gcc.MessageType.CS_CORE).desktopHeight.value = height + self._mcsLayer._clientSettings.get_block(gcc.MessageType.CS_CORE).desktopWidth.value = width def setUsername(self, username): """ diff --git a/rdpy/core/t125/ber.py b/rdpy/core/t125/ber.py index 709eb51..e27b7bd 100644 --- a/rdpy/core/t125/ber.py +++ b/rdpy/core/t125/ber.py @@ -22,22 +22,25 @@ Basic Encoding Rules use in RDP. ASN.1 standard """ -from rdpy.model.type import UInt8, UInt16Be, UInt32Be, Buffer +from rdpy.model.message import UInt8, UInt16Be, UInt32Be, Buffer from rdpy.model.error import InvalidExpectedDataException, InvalidSize -class BerPc(object): + +class BerPc: BER_PC_MASK = 0x20 BER_PRIMITIVE = 0x00 BER_CONSTRUCT = 0x20 -class Class(object): + +class Class: BER_CLASS_MASK = 0xC0 BER_CLASS_UNIV = 0x00 BER_CLASS_APPL = 0x40 BER_CLASS_CTXT = 0x80 BER_CLASS_PRIV = 0xC0 - -class Tag(object): + + +class Tag: BER_TAG_MASK = 0x1F BER_TAG_BOOLEAN = 0x01 BER_TAG_INTEGER = 0x02 @@ -48,6 +51,7 @@ class Tag(object): BER_TAG_SEQUENCE = 0x10 BER_TAG_SEQUENCE_OF = 0x10 + def berPC(pc): """ @summary: Return BER_CONSTRUCT if true @@ -59,7 +63,8 @@ def berPC(pc): return BerPc.BER_CONSTRUCT else: return BerPc.BER_PRIMITIVE - + + def readLength(s): """ @summary: Read length of BER structure @@ -84,6 +89,7 @@ def readLength(s): size = length return size.value + def writeLength(size): """ @summary: Return structure length as expected in BER specification @@ -94,7 +100,8 @@ def writeLength(size): return (UInt8(0x82), UInt16Be(size)) else: return UInt8(size) - + + def readUniversalTag(s, tag, pc): """ @summary: Read tag of BER packet @@ -106,6 +113,7 @@ def readUniversalTag(s, tag, pc): s.read_type(byte) return byte.value == ((Class.BER_CLASS_UNIV | berPC(pc)) | (Tag.BER_TAG_MASK & tag)) + def writeUniversalTag(tag, pc): """ @summary: Return universal tag byte @@ -115,6 +123,7 @@ def writeUniversalTag(tag, pc): """ return UInt8((Class.BER_CLASS_UNIV | berPC(pc)) | (Tag.BER_TAG_MASK & tag)) + def readApplicationTag(s, tag): """ @summary: Read application tag @@ -136,6 +145,7 @@ def readApplicationTag(s, tag): return readLength(s) + def writeApplicationTag(tag, size): """ @summary: Return structure that represent BER application tag diff --git a/rdpy/core/t125/gcc.py b/rdpy/core/t125/gcc.py index 82508ee..58f1fbb 100644 --- a/rdpy/core/t125/gcc.py +++ b/rdpy/core/t125/gcc.py @@ -23,7 +23,7 @@ http://msdn.microsoft.com/en-us/library/cc240508.aspx """ from hashlib import md5 -from rdpy.model.type import UInt8, UInt16Le, UInt32Le, CompositeType, CallableValue, Buffer, Stream, sizeof, FactoryType, ArrayType +from rdpy.model.message import UInt8, UInt16Le, UInt32Le, CompositeType, Buffer, Stream, sizeof, FactoryType, ArrayType from rdpy.core.t125 import per, mcs from rdpy.model.error import InvalidExpectedDataException from rdpy.model import log @@ -32,8 +32,8 @@ import rdpy.security.rsa_wrapper as rsa t124_02_98_oid = ( 0, 0, 20, 124, 0, 1 ) -h221_cs_key = "Duca"; -h221_sc_key = "McDn"; +h221_cs_key = b"Duca"; +h221_sc_key = b"McDn"; class MessageType(object): """ @@ -208,43 +208,47 @@ class CertificateType(object): """ CERT_CHAIN_VERSION_1 = 0x00000001 CERT_CHAIN_VERSION_2 = 0x00000002 - + + class DataBlock(CompositeType): """ @summary: Block settings """ - def __init__(self, dataBlock = None): - CompositeType.__init__(self) - self.type = UInt16Le(lambda:self.dataBlock.__class__._TYPE_) - self.length = UInt16Le(lambda:sizeof(self)) + def __init__(self, data_block=None): + super().__init__() + self.type = UInt16Le(lambda:data_block._TYPE_) + self.length = UInt16Le(lambda: sizeof(self)) - def DataBlockFactory(): + def factory(): """ - @summary: build settings in accordance of type self.type.value """ - for c in [ClientCoreData, ClientSecurityData, ClientNetworkData, ServerCoreData, ServerNetworkData, ServerSecurityData]: + gcc_type = [ + ClientCoreData, ClientSecurityData, ClientNetworkData, + ServerCoreData, ServerNetworkData, ServerSecurityData + ] + + for c in gcc_type: if self.type.value == c._TYPE_: - return c(readLen = self.length - 4) + return c(read_len=lambda: (self.length.value - 4)) log.debug("unknown GCC block type : %s"%hex(self.type.value)) - #read entire packet - return Buffer(readLen =self.length - 4) + # read entire packet + return Buffer(read_len=lambda: (self.length.value - 4)) - if dataBlock is None: - dataBlock = FactoryType(DataBlockFactory) - elif not "_TYPE_" in dataBlock.__class__.__dict__: + if data_block is None: + data_block = FactoryType(factory) + elif "_TYPE_" not in data_block.__class__.__dict__: raise InvalidExpectedDataException("Try to send an invalid GCC blocks") - self.dataBlock = dataBlock + self.dataBlock = data_block + class ClientCoreData(CompositeType): """ - @summary: Class that represent core setting of client - @see: http://msdn.microsoft.com/en-us/library/cc240510.aspx """ _TYPE_ = MessageType.CS_CORE - def __init__(self, readLen = None): - CompositeType.__init__(self, readLen = readLen) + def __init__(self, read_len=None): + super().__init__(read_len=read_len) self.rdpVersion = UInt32Le(Version.RDP_VERSION_5_PLUS) self.desktopWidth = UInt16Le(1280) self.desktopHeight = UInt16Le(800) @@ -252,62 +256,59 @@ class ClientCoreData(CompositeType): self.sasSequence = UInt16Le(Sequence.RNS_UD_SAS_DEL) self.kbdLayout = UInt32Le(KeyboardLayout.US) self.clientBuild = UInt32Le(3790) - self.clientName = Buffer("rdpy" + "\x00" * 11, readLen = CallableValue(32), unicode = True) + self.clientName = Buffer(b"rdpy" + b"\x00" * 11, read_len=lambda: 32) self.keyboardType = UInt32Le(KeyboardType.IBM_101_102_KEYS) self.keyboardSubType = UInt32Le(0) self.keyboardFnKeys = UInt32Le(12) - self.imeFileName = Buffer("\x00" * 64, readLen = CallableValue(64), optional = True) - self.postBeta2ColorDepth = UInt16Le(ColorDepth.RNS_UD_COLOR_8BPP, optional = True) - self.clientProductId = UInt16Le(1, optional = True) - self.serialNumber = UInt32Le(0, optional = True) - self.highColorDepth = UInt16Le(HighColor.HIGH_COLOR_24BPP, optional = True) - self.supportedColorDepths = UInt16Le(Support.RNS_UD_15BPP_SUPPORT | Support.RNS_UD_16BPP_SUPPORT | Support.RNS_UD_24BPP_SUPPORT | Support.RNS_UD_32BPP_SUPPORT, optional = True) - self.earlyCapabilityFlags = UInt16Le(CapabilityFlags.RNS_UD_CS_SUPPORT_ERRINFO_PDU, optional = True) - self.clientDigProductId = Buffer("\x00" * 64, readLen = CallableValue(64), optional = True) - self.connectionType = UInt8(optional = True) - self.pad1octet = UInt8(optional = True) - self.serverSelectedProtocol = UInt32Le(optional = True) - + self.imeFileName = Buffer(b"\x00" * 64, read_len=lambda: 64, optional=True) + self.postBeta2ColorDepth = UInt16Le(ColorDepth.RNS_UD_COLOR_8BPP, optional=True) + self.clientProductId = UInt16Le(1, optional=True) + self.serialNumber = UInt32Le(0, optional=True) + self.highColorDepth = UInt16Le(HighColor.HIGH_COLOR_24BPP, optional=True) + self.supportedColorDepths = UInt16Le(Support.RNS_UD_15BPP_SUPPORT | Support.RNS_UD_16BPP_SUPPORT | Support.RNS_UD_24BPP_SUPPORT | Support.RNS_UD_32BPP_SUPPORT, optional=True) + self.earlyCapabilityFlags = UInt16Le(CapabilityFlags.RNS_UD_CS_SUPPORT_ERRINFO_PDU, optional=True) + self.clientDigProductId = Buffer(b"\x00" * 64, read_len=lambda: 64, optional=True) + self.connectionType = UInt8(optional=True) + self.pad1octet = UInt8(optional=True) + self.serverSelectedProtocol = UInt32Le(optional=True) + + class ServerCoreData(CompositeType): """ - @summary: Server side core settings structure - @see: http://msdn.microsoft.com/en-us/library/cc240517.aspx """ _TYPE_ = MessageType.SC_CORE - def __init__(self, readLen = None): - CompositeType.__init__(self, readLen = readLen) + def __init__(self, read_len=None): + super().__init__(read_len=read_len) self.rdpVersion = UInt32Le(Version.RDP_VERSION_5_PLUS) - self.clientRequestedProtocol = UInt32Le(optional = True) - self.earlyCapabilityFlags = UInt32Le(optional = True) - + self.clientRequestedProtocol = UInt32Le(optional=True) + self.earlyCapabilityFlags = UInt32Le(optional=True) + + class ClientSecurityData(CompositeType): """ - @summary: Client security setting - @see: http://msdn.microsoft.com/en-us/library/cc240511.aspx """ _TYPE_ = MessageType.CS_SECURITY - def __init__(self, readLen = None): - CompositeType.__init__(self, readLen = readLen) + def __init__(self, read_len=None): + super().__init__(read_len=read_len) self.encryptionMethods = UInt32Le(EncryptionMethod.ENCRYPTION_FLAG_40BIT | EncryptionMethod.ENCRYPTION_FLAG_56BIT | EncryptionMethod.ENCRYPTION_FLAG_128BIT) self.extEncryptionMethods = UInt32Le() - + + class ServerSecurityData(CompositeType): """ - @summary: Server security settings - @see: http://msdn.microsoft.com/en-us/library/cc240518.aspx """ _TYPE_ = MessageType.SC_SECURITY - def __init__(self, readLen = None): - CompositeType.__init__(self, readLen = readLen) + def __init__(self, read_len=None): + super().__init__(read_len=read_len) self.encryptionMethod = UInt32Le() self.encryptionLevel = UInt32Le() - self.serverRandomLen = UInt32Le(0x00000020, constant = True, conditional = lambda:not(self.encryptionMethod.value == 0 and self.encryptionLevel == 0)) - self.serverCertLen = UInt32Le(lambda:sizeof(self.serverCertificate), conditional = lambda:not(self.encryptionMethod.value == 0 and self.encryptionLevel == 0)) - self.serverRandom = Buffer(readLen = self.serverRandomLen, conditional = lambda:not(self.encryptionMethod.value == 0 and self.encryptionLevel == 0)) - self.serverCertificate = ServerCertificate(readLen = self.serverCertLen, conditional = lambda:not(self.encryptionMethod.value == 0 and self.encryptionLevel == 0)) + self.serverRandomLen = UInt32Le(0x00000020, constant=True, conditional=lambda: not(self.encryptionMethod.value == 0 and self.encryptionLevel.value == 0)) + self.serverCertLen = UInt32Le(lambda: sizeof(self.serverCertificate), conditional=lambda:not(self.encryptionMethod.value == 0 and self.encryptionLevel.value == 0)) + self.serverRandom = Buffer(read_len=lambda: self.serverRandomLen.value, conditional=lambda: not(self.encryptionMethod.value == 0 and self.encryptionLevel.value == 0)) + self.serverCertificate = ServerCertificate(readLen=lambda: self.serverCertLen.value, conditional=lambda: not(self.encryptionMethod.value == 0 and self.encryptionLevel.value == 0)) class ServerCertificate(CompositeType): """ @@ -448,79 +449,62 @@ class RSAPublicKey(CompositeType): self.datalen = UInt32Le(lambda:((self.bitlen.value / 8) - 1)) self.pubExp = UInt32Le() self.modulus = Buffer(readLen = CallableValue(lambda:(self.keylen.value - 8))) - self.padding = Buffer("\x00" * 8, readLen = CallableValue(8)) + self.padding = Buffer(b"\x00" * 8, readLen = CallableValue(8)) + class ChannelDef(CompositeType): """ - Channels structure share between client and server - @see: http://msdn.microsoft.com/en-us/library/cc240513.aspx """ - def __init__(self, name = "", options = 0): - CompositeType.__init__(self) - #name of channel - self.name = Buffer(name[0:8] + "\x00" * (8 - len(name)), readLen = CallableValue(8)) - #unknown + def __init__(self, name=b""): + super().__init__() + # name of channel + self.name = Buffer(name[0:8] + b"\x00" * (8 - len(name)), read_len=lambda: 8) + # unknown self.options = UInt32Le() - + + class ClientNetworkData(CompositeType): """ - @summary: GCC client network block - All channels asked by client are listed here - @see: http://msdn.microsoft.com/en-us/library/cc240512.aspx """ _TYPE_ = MessageType.CS_NET - def __init__(self, readLen = None): - CompositeType.__init__(self, readLen = readLen) - self.channelCount = UInt32Le(lambda:len(self.channelDefArray._array)) - self.channelDefArray = ArrayType(ChannelDef, readLen = self.channelCount) - + def __init__(self, read_len=None): + CompositeType.__init__(self, read_len=read_len) + self.channelCount = UInt32Le(lambda: len(self.channelDefArray)) + self.channelDefArray = ArrayType(ChannelDef, read_len=lambda: self.channelCount.value) + + class ServerNetworkData(CompositeType): """ - @summary: GCC server network block - All channels asked by client are listed here - @see: All channels asked by client are listed here """ _TYPE_ = MessageType.SC_NET - def __init__(self, readLen = None): - CompositeType.__init__(self, readLen = readLen) + def __init__(self, read_len=None): + super().__init__(read_len=read_len) self.MCSChannelId = UInt16Le(mcs.Channel.MCS_GLOBAL_CHANNEL) - self.channelCount = UInt16Le(lambda:len(self.channelIdArray._array)) - self.channelIdArray = ArrayType(UInt16Le, readLen = self.channelCount) - self.pad = UInt16Le(conditional = lambda:((self.channelCount.value % 2) == 1)) - + self.channelCount = UInt16Le(lambda: len(self.channelIdArray)) + self.channelIdArray = ArrayType(UInt16Le, read_len=lambda: self.channelCount.value) + self.pad = UInt16Le(conditional=lambda: ((self.channelCount.value % 2) == 1)) + + class Settings(CompositeType): """ - @summary: Class which group all clients settings supported by RDPY """ - def __init__(self, init = [], readLen = None): - CompositeType.__init__(self, readLen = readLen) + def __init__(self, init=None, read_len=None): + super().__init__(read_len=read_len) self.settings = ArrayType(DataBlock, [DataBlock(i) for i in init]) - def getBlock(self, messageType): + def get_block(self, message_type): """ - @param messageType: type of block - @return: specific block of type messageType """ for i in self.settings._array: - if i.type.value == messageType: + if i.type.value == message_type: return i.dataBlock return None - - def __getattr__(self, name): - """ - @summary: Magic function for better access - @return: _value parameter - """ - if not name in MessageType.__dict__: - return None - return self.getBlock(MessageType.__dict__[name]) - -def clientSettings(): + + +def client_settings(): """ - @summary: Build settings for client - @return: Settings """ return Settings([ClientCoreData(), ClientNetworkData(), ClientSecurityData()]) @@ -578,9 +562,9 @@ def readConferenceCreateResponse(s): raise InvalidExpectedDataException("cannot read h221_sc_key") length = per.readLength(s) - serverSettings = Settings(readLen = CallableValue(length)) - s.read_type(serverSettings) - return serverSettings + server_settings = Settings(read_len=lambda: length) + s.read_type(server_settings) + return server_settings def writeConferenceCreateRequest(userData): """ @@ -593,7 +577,7 @@ def writeConferenceCreateRequest(userData): return (per.writeChoice(0), per.writeObjectIdentifier(t124_02_98_oid), per.writeLength(len(userDataStream.getvalue()) + 14), per.writeChoice(0), - per.writeSelection(0x08), per.writeNumericString("1", 1), per.writePadding(1), + per.writeSelection(0x08), per.writeNumericString(b"1", 1), per.writePadding(1), per.writeNumberOfSet(1), per.writeChoice(0xc0), per.writeOctetStream(h221_cs_key, 4), per.writeOctetStream(userDataStream.getvalue())) diff --git a/rdpy/core/t125/mcs.py b/rdpy/core/t125/mcs.py index 61e33dd..5879089 100644 --- a/rdpy/core/t125/mcs.py +++ b/rdpy/core/t125/mcs.py @@ -24,6 +24,9 @@ Each channel have a particular role. The main channel is the graphical channel. It exist channel for file system order, audio channel, clipboard etc... """ +from typing import Tuple + +from rdpy.core import x224 from rdpy.model.layer import LayerAutomata, IStreamSender, Layer from rdpy.model.message import sizeof, Stream, UInt8, UInt16Le, Buffer from rdpy.model.error import InvalidExpectedDataException, InvalidValue, InvalidSize, CallPureVirtualFuntion @@ -33,16 +36,16 @@ import rdpy.model.log as log from rdpy.core.t125 import ber, gcc, per import rdpy.security.rsa_wrapper as rsa -class Message(object): + +class Message: """ - @summary: Message type """ MCS_TYPE_CONNECT_INITIAL = 0x65 MCS_TYPE_CONNECT_RESPONSE = 0x66 + class DomainMCSPDU: """ - @summary: Domain MCS PDU header """ ERECT_DOMAIN_REQUEST = 1 DISCONNECT_PROVIDER_ULTIMATUM = 8 @@ -53,44 +56,114 @@ class DomainMCSPDU: SEND_DATA_REQUEST = 25 SEND_DATA_INDICATION = 26 + class Channel: """ - @summary: Channel id of main channels use in RDP """ MCS_GLOBAL_CHANNEL = 1003 MCS_USERCHANNEL_BASE = 1001 - + + class IGCCConfig(object): """ @summary: Channel information """ + def getUserId(self): """ @return: {integer} mcs user id @see: mcs.IGCCConfig """ - raise CallPureVirtualFuntion("%s:%s defined by interface %s"%(self.__class__, "getUserId", "IGCCConfig")) - + raise CallPureVirtualFuntion("%s:%s defined by interface %s" % (self.__class__, "getUserId", "IGCCConfig")) + def getChannelId(self): """ @return: {integer} return channel id of proxy @see: mcs.IGCCConfig """ - raise CallPureVirtualFuntion("%s:%s defined by interface %s"%(self.__class__, "getChannelId", "IGCCConfig")) - + raise CallPureVirtualFuntion("%s:%s defined by interface %s" % (self.__class__, "getChannelId", "IGCCConfig")) + def getGCCClientSettings(self): """ @return: {gcc.Settings} mcs layer gcc client settings @see: mcs.IGCCConfig """ - raise CallPureVirtualFuntion("%s:%s defined by interface %s"%(self.__class__, "getGCCClientSettings", "IGCCConfig")) - + raise CallPureVirtualFuntion( + "%s:%s defined by interface %s" % (self.__class__, "getGCCClientSettings", "IGCCConfig")) + def getGCCServerSettings(self): """ @return: {gcc.Settings} mcs layer gcc server settings @see: mcs.IGCCConfig """ - raise CallPureVirtualFuntion("%s:%s defined by interface %s"%(self.__class__, "getGCCServerSettings", "IGCCConfig")) + raise CallPureVirtualFuntion( + "%s:%s defined by interface %s" % (self.__class__, "getGCCServerSettings", "IGCCConfig")) + + +def write_domain_params(max_channels: int, max_users: int, max_tokens: int, max_pdu_size: int) -> tuple: + """ + """ + domain_param = (ber.writeInteger(max_channels), ber.writeInteger(max_users), ber.writeInteger(max_tokens), + ber.writeInteger(1), ber.writeInteger(0), ber.writeInteger(1), + ber.writeInteger(max_pdu_size), ber.writeInteger(2)) + return ber.writeUniversalTag(ber.Tag.BER_TAG_SEQUENCE, True), writeLength(sizeof(domain_param)), domain_param + + +def read_domain_params(stream: Stream) -> Tuple[int, int, int, int]: + """ + """ + if not ber.readUniversalTag(stream, ber.Tag.BER_TAG_SEQUENCE, True): + raise InvalidValue("bad BER tags") + ber.readLength(stream) # length + max_channels = ber.readInteger(stream) + max_users = ber.readInteger(stream) + max_tokens = ber.readInteger(stream) + ber.readInteger(stream) + ber.readInteger(stream) + ber.readInteger(stream) + max_pdu_size = ber.readInteger(stream) + ber.readInteger(stream) + return max_channels, max_users, max_tokens, max_pdu_size + + +class Client: + def __init__(self, x224_layer: x224.X224): + self.x224 = x224_layer + + async def write_connect_initial(self): + """ + """ + settings = gcc.client_settings() + settings.get_block(gcc.MessageType.CS_CORE).serverSelectedProtocol.value = self.x224.get_selected_protocol() + + cc_req = gcc.writeConferenceCreateRequest(settings) + cc_req_stream = Stream() + cc_req_stream.write_type(cc_req) + + tmp = (ber.writeOctetstring(b"\x01"), ber.writeOctetstring(b"\x01"), ber.writeBoolean(True), + write_domain_params(34, 2, 0, 0xffff), + write_domain_params(1, 1, 1, 0x420), + write_domain_params(0xffff, 0xfc17, 0xffff, 0xffff), + ber.writeOctetstring(cc_req_stream.getvalue())) + await self.x224.write((ber.writeApplicationTag(Message.MCS_TYPE_CONNECT_INITIAL, sizeof(tmp)), tmp)) + + async def read_connect_response(self): + payload = await self.x224.read() + ber.readApplicationTag(payload, UInt8(Message.MCS_TYPE_CONNECT_RESPONSE)) + ber.readEnumerated(payload) + ber.readInteger(payload) + read_domain_params(payload) + if not ber.readUniversalTag(payload, ber.Tag.BER_TAG_OCTET_STRING, False): + raise InvalidExpectedDataException("invalid expected BER tag") + gccRequestLength = ber.readLength(payload) + if payload.data_len() != gccRequestLength: + raise InvalidSize("bad size of GCC request") + gcc.readConferenceCreateResponse(payload) + + async def connect(self): + await self.write_connect_initial() + await self.read_connect_response() + class MCSLayer(LayerAutomata): """ @@ -98,11 +171,13 @@ class MCSLayer(LayerAutomata): the main layer of RDP protocol is why he can do everything and more! """ + class MCSProxySender(Layer, IStreamSender, IGCCConfig): """ @summary: Proxy use to set as transport layer for upper channel use to abstract channel id for presentation layer """ + def __init__(self, presentation, mcs, channelId): """ @param presentation: {Layer} presentation layer @@ -112,7 +187,7 @@ class MCSLayer(LayerAutomata): Layer.__init__(self, presentation) self._mcs = mcs self._channelId = channelId - + def send(self, data): """ @summary: A send proxy function, use channel id and specific @@ -120,43 +195,42 @@ class MCSLayer(LayerAutomata): @param data: {type.Type | Tuple} """ self._mcs.send(self._channelId, data) - + def close(self): """ @summary: Close wrapped layer """ self._mcs.close() - + def getUserId(self): """ @return: {integer} mcs user id @see: mcs.IGCCConfig """ return self._mcs._userId - + def getChannelId(self): """ @return: {integer} return channel id of proxy @see: mcs.IGCCConfig """ return self._channelId - + def getGCCClientSettings(self): """ @return: {gcc.Settings} mcs layer gcc client settings @see: mcs.IGCCConfig """ return self._mcs._clientSettings - + def getGCCServerSettings(self): """ @return: {gcc.Settings} mcs layer gcc server settings @see: mcs.IGCCConfig """ return self._mcs._serverSettings - - - def __init__(self, presentation, receiveOpcode, sendOpcode, virtualChannels = []): + + def __init__(self, presentation, receiveOpcode, sendOpcode, virtualChannels=[]): """ @param presentation: {Layer} presentation layer @param virtualChannels: {Array(Layer]} list additional channels like rdpsnd... [tuple(mcs.ChannelDef, layer)] @@ -166,50 +240,50 @@ class MCSLayer(LayerAutomata): LayerAutomata.__init__(self, presentation) self._clientSettings = gcc.clientSettings() self._serverSettings = gcc.serverSettings() - #default user Id + # default user Id self._userId = 1 + Channel.MCS_USERCHANNEL_BASE - #list of channel use in this layer and connection state + # list of channel use in this layer and connection state self._channels = {Channel.MCS_GLOBAL_CHANNEL: presentation} - #virtual channels + # virtual channels self._virtualChannels = virtualChannels - #send opcode + # send opcode self._sendOpcode = sendOpcode - #receive opcode + # receive opcode self._receiveOpcode = receiveOpcode - + def close(self): """ @summary: Send disconnect provider ultimatum """ self._transport.send((UInt8(self.writeMCSPDUHeader(DomainMCSPDU.DISCONNECT_PROVIDER_ULTIMATUM, 1)), - per.writeEnumerates(0x80), Buffer("\x00" * 6))) + per.writeEnumerates(0x80), Buffer(b"\x00" * 6))) self._transport.close() - + def allChannelConnected(self): """ @summary: All channels are connected to MCS layer Send connect to upper channel And prepare MCS layer to receive data """ - #connection is done + # connection is done self.setNextState(self.recvData) - #try connection on all requested channel + # try connection on all requested channel for (channelId, layer) in self._channels.iteritems(): - #use proxy for each channel + # use proxy for each channel MCSLayer.MCSProxySender(layer, self, channelId).connect() - + def send(self, channelId, data): """ @summary: Specific send function for channelId @param channelId: {integer} Channel use to send @param data: {type.type | tuple} message to send """ - self._transport.send((self.writeMCSPDUHeader(UInt8(self._sendOpcode)), - per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE), - per.writeInteger16(channelId), - UInt8(0x70), + self._transport.send((self.writeMCSPDUHeader(UInt8(self._sendOpcode)), + per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE), + per.writeInteger16(channelId), + UInt8(0x70), per.writeLength(sizeof(data)), data)) - + def recvData(self, data): """ @summary: Main receive method @@ -217,31 +291,31 @@ class MCSLayer(LayerAutomata): """ opcode = UInt8() data.read_type(opcode) - + if self.readMCSPDUHeader(opcode.value, DomainMCSPDU.DISCONNECT_PROVIDER_ULTIMATUM): log.info("MCS DISCONNECT_PROVIDER_ULTIMATUM") self._transport.close() return - - #client case + + # client case elif not self.readMCSPDUHeader(opcode.value, self._receiveOpcode): raise InvalidExpectedDataException("Invalid expected MCS opcode receive data") - - #server user id + + # server user id per.readInteger16(data, Channel.MCS_USERCHANNEL_BASE) - + channelId = per.readInteger16(data) - - per.readEnumerates(data) + + per.readEnumerates(data) per.readLength(data) - - #channel id doesn't match a requested layer + + # channel id doesn't match a requested layer if not self._channels.has_key(channelId): log.error("receive data for an unconnected layer") return - self._channels[channelId].recv(data) - + self._channels[channelId].recv(data) + def writeDomainParams(self, maxChannels, maxUsers, maxTokens, maxPduSize): """ @summary: Write a special domain parameter structure @@ -256,8 +330,8 @@ class MCSLayer(LayerAutomata): ber.writeInteger(1), ber.writeInteger(0), ber.writeInteger(1), ber.writeInteger(maxPduSize), ber.writeInteger(2)) return (ber.writeUniversalTag(ber.Tag.BER_TAG_SEQUENCE, True), writeLength(sizeof(domainParam)), domainParam) - - def writeMCSPDUHeader(self, mcsPdu, options = 0): + + def writeMCSPDUHeader(self, mcsPdu, options=0): """ @summary: Write MCS PDU header @param mcsPdu: {integer} PDU code @@ -265,7 +339,7 @@ class MCSLayer(LayerAutomata): @return: {integer} """ return (mcsPdu << 2) | options - + def readMCSPDUHeader(self, opcode, mcsPdu): """ @summary: Read mcsPdu header and return options parameter @@ -274,7 +348,7 @@ class MCSLayer(LayerAutomata): @return: {boolean} true if opcode is correct """ return (opcode >> 2) == mcsPdu - + def readDomainParams(self, s): """ @summary: Read domain parameters structure @@ -283,7 +357,7 @@ class MCSLayer(LayerAutomata): """ if not ber.readUniversalTag(s, ber.Tag.BER_TAG_SEQUENCE, True): raise InvalidValue("bad BER tags") - ber.readLength(s)#length + ber.readLength(s) # length max_channels = ber.readInteger(s) max_users = ber.readInteger(s) max_tokens = ber.readInteger(s) @@ -293,23 +367,26 @@ class MCSLayer(LayerAutomata): max_pdu_size = ber.readInteger(s) ber.readInteger(s) return (max_channels, max_users, max_tokens, max_pdu_size) - -class Client(MCSLayer): + + +class ClientOld(MCSLayer): """ @summary: Client automata of multiple channel service layer """ - def __init__(self, presentation, virtualChannels = []): + + def __init__(self, presentation, virtualChannels=[]): """ @param presentation: {Layer} presentation layer @param virtualChannels: {Array(Layer)} list additional channels like rdpsnd... [tuple(mcs.ChannelDef, layer)] """ - MCSLayer.__init__(self, presentation, DomainMCSPDU.SEND_DATA_INDICATION, DomainMCSPDU.SEND_DATA_REQUEST, virtualChannels) - #use to know state of static channel + MCSLayer.__init__(self, presentation, DomainMCSPDU.SEND_DATA_INDICATION, DomainMCSPDU.SEND_DATA_REQUEST, + virtualChannels) + # use to know state of static channel self._isGlobalChannelRequested = False self._isUserChannelRequested = False - #nb channel requested + # nb channel requested self._nbChannelRequested = 0 - + def connect(self): """ @summary: Connect message in client automata case @@ -317,13 +394,13 @@ class Client(MCSLayer): Wait ConnectResponse """ self._clientSettings.CS_CORE.serverSelectedProtocol.value = self._transport._selectedProtocol - #ask for virtual channel + # ask for virtual channel self._clientSettings.CS_NET.channelDefArray._array = [x for (x, _) in self._virtualChannels] - #send connect initial + # send connect initial self.sendConnectInitial() - #next wait response + # next wait response self.setNextState(self.recvConnectResponse) - + def connectNextChannel(self): """ @summary: Send sendChannelJoinRequest message on next disconnect channel @@ -331,27 +408,27 @@ class Client(MCSLayer): Wait channel confirm """ self.setNextState(self.recvChannelJoinConfirm) - #global channel + # global channel if not self._isGlobalChannelRequested: self.sendChannelJoinRequest(Channel.MCS_GLOBAL_CHANNEL) self._isGlobalChannelRequested = True return - - #user channel + + # user channel if not self._isUserChannelRequested: self.sendChannelJoinRequest(self._userId) self._isUserChannelRequested = True return - - #static virtual channel - if self._nbChannelRequested < self._serverSettings.getBlock(gcc.MessageType.SC_NET).channelCount.value: - channelId = self._serverSettings.getBlock(gcc.MessageType.SC_NET).channelIdArray[self._nbChannelRequested] + + # static virtual channel + if self._nbChannelRequested < self._serverSettings.get_block(gcc.MessageType.SC_NET).channelCount.value: + channelId = self._serverSettings.get_block(gcc.MessageType.SC_NET).channelIdArray[self._nbChannelRequested] self._nbChannelRequested += 1 self.sendChannelJoinRequest(channelId) return - + self.allChannelConnected() - + def recvConnectResponse(self, data): """ @summary: Receive MCS connect response from server @@ -370,14 +447,14 @@ class Client(MCSLayer): if data.data_len() != gccRequestLength: raise InvalidSize("bad size of GCC request") self._serverSettings = gcc.readConferenceCreateResponse(data) - - #send domain request + + # send domain request self.sendErectDomainRequest() - #send attach user request + # send attach user request self.sendAttachUserRequest() - #now wait user confirm from server + # now wait user confirm from server self.setNextState(self.recvAttachUserConfirm) - + def recvAttachUserConfirm(self, data): """ @summary: Receive an attach user confirm @@ -386,17 +463,17 @@ class Client(MCSLayer): """ opcode = UInt8() data.read_type(opcode) - + if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.ATTACH_USER_CONFIRM): raise InvalidExpectedDataException("Invalid MCS PDU : ATTACH_USER_CONFIRM expected") - + if per.readEnumerates(data) != 0: raise InvalidExpectedDataException("Server reject user") - + self._userId = per.readInteger16(data, Channel.MCS_USERCHANNEL_BASE) - + self.connectNextChannel() - + def recvChannelJoinConfirm(self, data): """ @summary: Receive a channel join confirm from server @@ -405,29 +482,29 @@ class Client(MCSLayer): """ opcode = UInt8() data.read_type(opcode) - + if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.CHANNEL_JOIN_CONFIRM): raise InvalidExpectedDataException("Invalid MCS PDU : CHANNEL_JOIN_CONFIRM expected") - + confirm = per.readEnumerates(data) - + userId = per.readInteger16(data, Channel.MCS_USERCHANNEL_BASE) if self._userId != userId: raise InvalidExpectedDataException("Invalid MCS User Id") - + channelId = per.readInteger16(data) - #must confirm global channel and user channel + # must confirm global channel and user channel if (confirm != 0) and (channelId == Channel.MCS_GLOBAL_CHANNEL or channelId == self._userId): raise InvalidExpectedDataException("Server must confirm static channel") - + if confirm == 0: - serverNet = self._serverSettings.getBlock(gcc.MessageType.SC_NET) + serverNet = self._serverSettings.get_block(gcc.MessageType.SC_NET) for i in range(0, serverNet.channelCount.value): if channelId == serverNet.channelIdArray[i].value: - self._channels[channelId] = self._virtualChannels[i][1] - + self._channels[channelId] = self._virtualChannels[i][1] + self.connectNextChannel() - + def sendConnectInitial(self): """ @summary: Send connect initial packet @@ -436,67 +513,69 @@ class Client(MCSLayer): ccReq = gcc.writeConferenceCreateRequest(self._clientSettings) ccReqStream = Stream() ccReqStream.write_type(ccReq) - + tmp = (ber.writeOctetstring("\x01"), ber.writeOctetstring("\x01"), ber.writeBoolean(True), self.writeDomainParams(34, 2, 0, 0xffff), self.writeDomainParams(1, 1, 1, 0x420), self.writeDomainParams(0xffff, 0xfc17, 0xffff, 0xffff), ber.writeOctetstring(ccReqStream.getvalue())) self._transport.send((ber.writeApplicationTag(Message.MCS_TYPE_CONNECT_INITIAL, sizeof(tmp)), tmp)) - + def sendErectDomainRequest(self): """ @summary: Send a formated erect domain request for RDP connection """ - self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.ERECT_DOMAIN_REQUEST)), - per.writeInteger(0), + self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.ERECT_DOMAIN_REQUEST)), + per.writeInteger(0), per.writeInteger(0))) - + def sendAttachUserRequest(self): """ @summary: Send a formated attach user request for RDP connection """ self._transport.send(self.writeMCSPDUHeader(UInt8(DomainMCSPDU.ATTACH_USER_REQUEST))) - + def sendChannelJoinRequest(self, channelId): """ @summary: Send a formated Channel join request from client to server client automata function @param channelId: {integer} id of channel requested """ - self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.CHANNEL_JOIN_REQUEST)), - per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE), + self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.CHANNEL_JOIN_REQUEST)), + per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE), per.writeInteger16(channelId))) - + + class Server(MCSLayer): """ @summary: Server automata of multiple channel service layer """ - def __init__(self, presentation, virtualChannels = []): + + def __init__(self, presentation, virtualChannels=[]): """ @param presentation: {Layer} presentation layer @param virtualChannels: {List(Layer)} list additional channels like rdpsnd... [tuple(mcs.ChannelDef, layer)] """ - MCSLayer.__init__(self, presentation, DomainMCSPDU.SEND_DATA_REQUEST, DomainMCSPDU.SEND_DATA_INDICATION, virtualChannels) - #nb channel requested + MCSLayer.__init__(self, presentation, DomainMCSPDU.SEND_DATA_REQUEST, DomainMCSPDU.SEND_DATA_INDICATION, + virtualChannels) + # nb channel requested self._nbChannelConfirmed = 0 - + def connect(self): """ @summary: Connect message for server automata Wait Connect Initial """ - #basic rdp security layer + # basic rdp security layer if self._transport._selectedProtocol == 0: - self._serverSettings.SC_SECURITY.encryptionMethod.value = gcc.EncryptionMethod.ENCRYPTION_FLAG_128BIT self._serverSettings.SC_SECURITY.encryptionLevel.value = gcc.EncryptionLevel.ENCRYPTION_LEVEL_HIGH self._serverSettings.SC_SECURITY.serverRandom.value = rsa.random(256) self._serverSettings.SC_SECURITY.serverCertificate = self._presentation.getCertificate() - + self._serverSettings.SC_CORE.clientRequestedProtocol.value = self._transport._requestedProtocol self.setNextState(self.recvConnectInitial) - + def recvConnectInitial(self, data): """ @summary: Receive MCS connect initial from client @@ -507,28 +586,28 @@ class Server(MCSLayer): ber.readApplicationTag(data, UInt8(Message.MCS_TYPE_CONNECT_INITIAL)) ber.readOctetString(data) ber.readOctetString(data) - + if not ber.readBoolean(data): raise InvalidExpectedDataException("invalid expected BER boolean tag") - + self.readDomainParams(data) self.readDomainParams(data) self.readDomainParams(data) self._clientSettings = gcc.readConferenceCreateRequest(Stream(ber.readOctetString(data))) - + if not self._clientSettings.CS_NET is None: i = 1 for channelDef in self._clientSettings.CS_NET.channelDefArray._array: self._serverSettings.SC_NET.channelIdArray._array.append(UInt16Le(i + Channel.MCS_GLOBAL_CHANNEL)) - #if channel can be handle by serve add it + # if channel can be handle by serve add it for serverChannelDef, layer in self._virtualChannels: if channelDef.name == serverChannelDef.name: self._channels[i + Channel.MCS_GLOBAL_CHANNEL] = layer i += 1 - + self.sendConnectResponse() self.setNextState(self.recvErectDomainRequest) - + def recvErectDomainRequest(self, data): """ @summary: Receive erect domain request @@ -537,15 +616,15 @@ class Server(MCSLayer): """ opcode = UInt8() data.read_type(opcode) - + if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.ERECT_DOMAIN_REQUEST): raise InvalidExpectedDataException("Invalid MCS PDU : ERECT_DOMAIN_REQUEST expected") - + per.readInteger(data) per.readInteger(data) - + self.setNextState(self.recvAttachUserRequest) - + def recvAttachUserRequest(self, data): """ @summary: Receive Attach user request @@ -555,13 +634,13 @@ class Server(MCSLayer): """ opcode = UInt8() data.read_type(opcode) - + if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.ATTACH_USER_REQUEST): raise InvalidExpectedDataException("Invalid MCS PDU : ATTACH_USER_REQUEST expected") - + self.sendAttachUserConfirm() self.setNextState(self.recvChannelJoinRequest) - + def recvChannelJoinRequest(self, data): """ @summary: Receive for each client channel a request @@ -571,22 +650,22 @@ class Server(MCSLayer): """ opcode = UInt8() data.read_type(opcode) - + if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.CHANNEL_JOIN_REQUEST): raise InvalidExpectedDataException("Invalid MCS PDU : CHANNEL_JOIN_REQUEST expected") - + userId = per.readInteger16(data, Channel.MCS_USERCHANNEL_BASE) if self._userId != userId: raise InvalidExpectedDataException("Invalid MCS User Id") - + channelId = per.readInteger16(data) - #actually algo support virtual channel but RDPY have no virtual channel + # actually algo support virtual channel but RDPY have no virtual channel confirm = 0 if channelId in self._channels.keys() or channelId == self._userId else 1 self.sendChannelJoinConfirm(channelId, confirm) self._nbChannelConfirmed += 1 - if self._nbChannelConfirmed == self._serverSettings.getBlock(gcc.MessageType.SC_NET).channelCount.value + 2: + if self._nbChannelConfirmed == self._serverSettings.get_block(gcc.MessageType.SC_NET).channelCount.value + 2: self.allChannelConnected() - + def sendConnectResponse(self): """ @summary: Send connect response @@ -594,27 +673,27 @@ class Server(MCSLayer): ccReq = gcc.writeConferenceCreateResponse(self._serverSettings) ccReqStream = Stream() ccReqStream.write_type(ccReq) - - tmp = (ber.writeEnumerated(0), ber.writeInteger(0), self.writeDomainParams(22, 3, 0, 0xfff8), + + tmp = (ber.writeEnumerated(0), ber.writeInteger(0), self.writeDomainParams(22, 3, 0, 0xfff8), ber.writeOctetstring(ccReqStream.getvalue())) self._transport.send((ber.writeApplicationTag(Message.MCS_TYPE_CONNECT_RESPONSE, sizeof(tmp)), tmp)) - + def sendAttachUserConfirm(self): """ @summary: Send attach user confirm """ - self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.ATTACH_USER_CONFIRM), 2), - per.writeEnumerates(0), - per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE))) - + self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.ATTACH_USER_CONFIRM), 2), + per.writeEnumerates(0), + per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE))) + def sendChannelJoinConfirm(self, channelId, confirm): """ @summary: Send a confirm channel (or not) to client @param channelId: {integer} id of channel @param confirm: {boolean} connection state """ - self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.CHANNEL_JOIN_CONFIRM), 2), - per.writeEnumerates(int(confirm)), - per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE), - per.writeInteger16(channelId), - per.writeInteger16(channelId))) \ No newline at end of file + self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.CHANNEL_JOIN_CONFIRM), 2), + per.writeEnumerates(int(confirm)), + per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE), + per.writeInteger16(channelId), + per.writeInteger16(channelId))) diff --git a/rdpy/core/t125/per.py b/rdpy/core/t125/per.py index f629a4f..9500928 100644 --- a/rdpy/core/t125/per.py +++ b/rdpy/core/t125/per.py @@ -223,12 +223,9 @@ def readNumericString(s, minValue): length = (length + minValue + 1) / 2 s.read(length) + def writeNumericString(nStr, minValue): """ - @summary: write string in per format - @param str: python string to write - @param min: min value - @return: String type that contain str encoded in per format """ length = len(nStr) mlength = minValue @@ -238,9 +235,9 @@ def writeNumericString(nStr, minValue): result = [] for i in range(0, length, 2): - c1 = ord(nStr[i]) + c1 = nStr[i] if i + 1 < length: - c2 = ord(nStr[i + 1]) + c2 = nStr[i + 1] else: c2 = 0x30 c1 = (c1 - 0x30) % 10 @@ -248,7 +245,7 @@ def writeNumericString(nStr, minValue): result.append(UInt8((c1 << 4) | c2)) - return (writeLength(mlength), tuple(result)) + return writeLength(mlength), tuple(result) def readPadding(s, length): """ @@ -264,7 +261,7 @@ def writePadding(length): @param length: length of padding @return: String with \x00 * length """ - return Buffer("\x00" * length) + return Buffer(b"\x00" * length) def readOctetStream(s, octetStream, minValue = 0): """ @@ -300,6 +297,6 @@ def writeOctetStream(oStr, minValue = 0): result = [] for i in range(0, length): - result.append(UInt8(ord(oStr[i]))) + result.append(UInt8(oStr[i])) - return (writeLength(mlength), tuple(result)) \ No newline at end of file + return writeLength(mlength), tuple(result) \ No newline at end of file diff --git a/rdpy/core/x224.py b/rdpy/core/x224.py index 54f5849..18f1ae4 100644 --- a/rdpy/core/x224.py +++ b/rdpy/core/x224.py @@ -134,20 +134,15 @@ class Negotiation(CompositeType): class X224: """ - @summary: x224 layer management - there is an connection automata """ - def __init__(self, tpkt: tpkt.Tpkt): + def __init__(self, tpkt: tpkt.Tpkt, selected_protocol: int): """ - @param tpkt: TPKT layer """ self.tpkt = tpkt + self.selected_protocol = selected_protocol async def read(self): """ - @summary: Read data header from packet - And pass to presentation layer - @param data: Stream """ header = X224DataHeader() payload = await self.tpkt.read() @@ -156,12 +151,12 @@ class X224: async def write(self, message): """ - @summary: Write message packet for TPDU layer - Add TPDU header - @param message: """ await self.tpkt.write((X224DataHeader(), message)) + def get_selected_protocol(self): + return self.selected_protocol + async def connect(tpkt: tpkt.Tpkt, authentication_protocol: sspi.IAuthenticationProtocol) -> X224: """ @@ -190,11 +185,11 @@ async def connect(tpkt: tpkt.Tpkt, authentication_protocol: sspi.IAuthentication raise InvalidExpectedDataException("RDPY doesn't support PROTOCOL_HYBRID_EX security Layer") if selected_protocol == Protocols.PROTOCOL_RDP: - return X224(tpkt) + return X224(tpkt, selected_protocol) elif selected_protocol == Protocols.PROTOCOL_SSL: - return X224(await tpkt.start_tls()) + return X224(await tpkt.start_tls(), selected_protocol) elif selected_protocol == Protocols.PROTOCOL_HYBRID: - return X224(await tpkt.start_nla(authentication_protocol)) + return X224(await tpkt.start_nla(authentication_protocol), selected_protocol) class Server(X224): diff --git a/rdpy/model/message.py b/rdpy/model/message.py index f94cd71..a551488 100644 --- a/rdpy/model/message.py +++ b/rdpy/model/message.py @@ -354,72 +354,51 @@ class SimpleType(DynMessage): class CompositeType(Message): """ - @summary: Type node in Type tree - Track type field declared in __init__ function - Ex: self.lengthOfPacket = UInt16Le() -> record lengthOfPacket as sub type of node """ - def __init__(self, read_len = None, conditional = lambda:True, optional = False, constant = False, ): + def __init__(self, read_len=None, conditional=lambda: True, optional=False, constant=False): """ - @param conditional : Callable object - Read and Write operation depend on return of this function - @param optional: If there is no enough byte in current stream - And optional is True, read type is ignored - @param constant: Check if object value doesn't change after read operation - @param readLen: Max length in bytes can be readed from stream - Use to check length information """ super().__init__(conditional=conditional, optional=optional, constant=constant) - # list of ordoned type - self._typeName = [] + # list of ordorred type + self._type_name = [] self._read_len = read_len def __setattr__(self, name, value): """ - @summary: Track Type field - For Type field record it in same order as declared - Keep other but bot handle in read or write function - @param name: name of new attribute - @param value: value of new attribute """ - if name[0] != '_' and (isinstance(value, Message) or isinstance(value, tuple)) and not name in self._typeName: - self._typeName.append(name) + if name[0] != '_' and (isinstance(value, Message) or isinstance(value, tuple)) and name not in self._type_name: + self._type_name.append(name) self.__dict__[name] = value def __read__(self, s): """ - @summary: Read composite type - Call read on each ordered sub-type - And check read length parameter - If an error occurred rollback type already read - @param s: Stream - @raise InvalidSize: if stream is greater than readLen parameter """ - readLen = 0 - for name in self._typeName: + read_len = 0 + for name in self._type_name: try: s.read_type(self.__dict__[name]) - readLen += sizeof(self.__dict__[name]) + read_len += sizeof(self.__dict__[name]) # read is ok but read out of bound - if not self._read_len is None and readLen > self._read_len.value: + if self._read_len is not None and read_len > self._read_len.value: # roll back - s.pos -= sizeof(self.__dict__[name]) + s.seek(-sizeof(self.__dict__[name]), 1) # and notify if not optional if not self.__dict__[name]._optional: raise InvalidSize("Impossible to read type %s : read length is too small"%(self.__class__)) except Exception as e: log.error("Error during read %s::%s"%(self.__class__, name)) - #roll back already read - for tmpName in self._typeName: - if tmpName == name: + # roll back already read + for tmp_name in self._type_name: + if tmp_name == name: break - s.seek(-sizeof(self.__dict__[tmpName]), 1) + s.seek(-sizeof(self.__dict__[tmp_name]), 1) raise e - if not self._read_len is None and readLen < self._read_len.value: - log.debug("Still have correct data in packet %s, read %s bytes as padding"%(self.__class__, self._read_len.value - readLen)) - s.read(self._read_len.value - readLen) + if self._read_len is not None and read_len < self._read_len.value: + log.debug("Still have correct data in packet %s, read %s bytes as padding"%(self.__class__, self._read_len.value - read_len)) + s.read(self._read_len.value - read_len) def __write__(self, s): """ @@ -427,7 +406,7 @@ class CompositeType(Message): Call write on each ordered sub type @param s: Stream """ - for name in self._typeName: + for name in self._type_name: try: s.write_type(self.__dict__[name]) except Exception as e: @@ -443,7 +422,7 @@ class CompositeType(Message): return self._read_len.value size = 0 - for name in self._typeName: + for name in self._type_name: size += sizeof(self.__dict__[name]) return size @@ -454,9 +433,9 @@ class CompositeType(Message): @param other: CompositeType @return: True if each sub-type are equals """ - if self._typeName != other._typeName: + if self._type_name != other._typeName: return False - for name in self._typeName: + for name in self._type_name: if self.__dict__[name] != other.__dict__[name]: return False return True @@ -840,35 +819,23 @@ class ArrayType(Message): """ @summary: Factory af n element """ - def __init__(self, typeFactory, init = None, readLen = None, conditional = lambda:True, optional = False, constant = False): + def __init__(self, type_factory, init=None, read_len=None, conditional=lambda:True, optional=False, constant=False): """ - @param typeFactory: class use to init new element on read - @param init: init array - @param readLen: number of element in sequence - @param conditional : Callable object - Read and Write operation depend on return of this function - @param optional: If there is no enough byte in current stream - And optional is True, read type is ignored - @param constant: Check if object value doesn't change after read operation """ - Message.__init__(self, conditional, optional, constant) - self._typeFactory = typeFactory - self._readLen = readLen - self._array = [] - if not init is None: - self._array = init + super().__init__(conditional, optional, constant) + self._type_factory = type_factory + self._read_len = read_len + self._array = init or [] def __read__(self, s): """ - @summary: Create readLen new object and read it - @param s: Stream """ self._array = [] i = 0 - #self._readLen is None means that array will be read until end of stream - while self._readLen is None or i < self._readLen.value: - element = self._typeFactory() - element._optional = self._readLen is None + # self._read_len is None means that array will be read until end of stream + while self._read_len is None or i < self._read_len(): + element = self._type_factory() + element._optional = self._read_len is None s.read_type(element) if not element._is_readed: break @@ -877,8 +844,6 @@ class ArrayType(Message): def __write__(self, s): """ - @summary: Just write array - @param s: Stream """ s.write_type(self._array) @@ -895,6 +860,9 @@ class ArrayType(Message): """ return sizeof(self._array) + def __len__(self): + return len(self._array) + class FactoryType(Message): """ @summary: Call a factory callback at read or write time