Update: Start MCS layer

This commit is contained in:
citronneur
2020-04-21 22:02:35 +02:00
parent 018c59fe42
commit 73b97d6929
8 changed files with 382 additions and 346 deletions

View File

@@ -26,6 +26,7 @@ import asyncio
from rdpy.core import tpkt, x224 from rdpy.core import tpkt, x224
from rdpy.core.nla import ntlm from rdpy.core.nla import ntlm
from rdpy.core.t125 import mcs
from rdpy.model.message import UInt8 from rdpy.model.message import UInt8
if __name__ == '__main__': if __name__ == '__main__':
@@ -37,7 +38,9 @@ if __name__ == '__main__':
'127.0.0.1', 33389) '127.0.0.1', 33389)
x224_layer = await x224.connect(tpkt.Tpkt(reader, writer), ntlm.NTLMv2("", "sylvain", "sylvain")) x224_layer = await x224.connect(tpkt.Tpkt(reader, writer), ntlm.NTLMv2("", "sylvain", "sylvain"))
await x224_layer.write(UInt8(8)) mcs_layer = mcs.Client(x224_layer)
await mcs_layer.connect()
await asyncio.sleep(10) await asyncio.sleep(10)
print("foooooooooooooooooooo") print("foooooooooooooooooooo")

View File

@@ -99,8 +99,8 @@ class RDPClientController(PDUClientListener):
@param height: height in pixel of screen @param height: height in pixel of screen
""" """
#set screen definition in MCS layer #set screen definition in MCS layer
self._mcsLayer._clientSettings.getBlock(gcc.MessageType.CS_CORE).desktopHeight.value = height self._mcsLayer._clientSettings.get_block(gcc.MessageType.CS_CORE).desktopHeight.value = height
self._mcsLayer._clientSettings.getBlock(gcc.MessageType.CS_CORE).desktopWidth.value = width self._mcsLayer._clientSettings.get_block(gcc.MessageType.CS_CORE).desktopWidth.value = width
def setUsername(self, username): def setUsername(self, username):
""" """

View File

@@ -22,22 +22,25 @@ Basic Encoding Rules use in RDP.
ASN.1 standard ASN.1 standard
""" """
from rdpy.model.type import UInt8, UInt16Be, UInt32Be, Buffer from rdpy.model.message import UInt8, UInt16Be, UInt32Be, Buffer
from rdpy.model.error import InvalidExpectedDataException, InvalidSize from rdpy.model.error import InvalidExpectedDataException, InvalidSize
class BerPc(object):
class BerPc:
BER_PC_MASK = 0x20 BER_PC_MASK = 0x20
BER_PRIMITIVE = 0x00 BER_PRIMITIVE = 0x00
BER_CONSTRUCT = 0x20 BER_CONSTRUCT = 0x20
class Class(object):
class Class:
BER_CLASS_MASK = 0xC0 BER_CLASS_MASK = 0xC0
BER_CLASS_UNIV = 0x00 BER_CLASS_UNIV = 0x00
BER_CLASS_APPL = 0x40 BER_CLASS_APPL = 0x40
BER_CLASS_CTXT = 0x80 BER_CLASS_CTXT = 0x80
BER_CLASS_PRIV = 0xC0 BER_CLASS_PRIV = 0xC0
class Tag(object):
class Tag:
BER_TAG_MASK = 0x1F BER_TAG_MASK = 0x1F
BER_TAG_BOOLEAN = 0x01 BER_TAG_BOOLEAN = 0x01
BER_TAG_INTEGER = 0x02 BER_TAG_INTEGER = 0x02
@@ -48,6 +51,7 @@ class Tag(object):
BER_TAG_SEQUENCE = 0x10 BER_TAG_SEQUENCE = 0x10
BER_TAG_SEQUENCE_OF = 0x10 BER_TAG_SEQUENCE_OF = 0x10
def berPC(pc): def berPC(pc):
""" """
@summary: Return BER_CONSTRUCT if true @summary: Return BER_CONSTRUCT if true
@@ -59,7 +63,8 @@ def berPC(pc):
return BerPc.BER_CONSTRUCT return BerPc.BER_CONSTRUCT
else: else:
return BerPc.BER_PRIMITIVE return BerPc.BER_PRIMITIVE
def readLength(s): def readLength(s):
""" """
@summary: Read length of BER structure @summary: Read length of BER structure
@@ -84,6 +89,7 @@ def readLength(s):
size = length size = length
return size.value return size.value
def writeLength(size): def writeLength(size):
""" """
@summary: Return structure length as expected in BER specification @summary: Return structure length as expected in BER specification
@@ -94,7 +100,8 @@ def writeLength(size):
return (UInt8(0x82), UInt16Be(size)) return (UInt8(0x82), UInt16Be(size))
else: else:
return UInt8(size) return UInt8(size)
def readUniversalTag(s, tag, pc): def readUniversalTag(s, tag, pc):
""" """
@summary: Read tag of BER packet @summary: Read tag of BER packet
@@ -106,6 +113,7 @@ def readUniversalTag(s, tag, pc):
s.read_type(byte) s.read_type(byte)
return byte.value == ((Class.BER_CLASS_UNIV | berPC(pc)) | (Tag.BER_TAG_MASK & tag)) return byte.value == ((Class.BER_CLASS_UNIV | berPC(pc)) | (Tag.BER_TAG_MASK & tag))
def writeUniversalTag(tag, pc): def writeUniversalTag(tag, pc):
""" """
@summary: Return universal tag byte @summary: Return universal tag byte
@@ -115,6 +123,7 @@ def writeUniversalTag(tag, pc):
""" """
return UInt8((Class.BER_CLASS_UNIV | berPC(pc)) | (Tag.BER_TAG_MASK & tag)) return UInt8((Class.BER_CLASS_UNIV | berPC(pc)) | (Tag.BER_TAG_MASK & tag))
def readApplicationTag(s, tag): def readApplicationTag(s, tag):
""" """
@summary: Read application tag @summary: Read application tag
@@ -136,6 +145,7 @@ def readApplicationTag(s, tag):
return readLength(s) return readLength(s)
def writeApplicationTag(tag, size): def writeApplicationTag(tag, size):
""" """
@summary: Return structure that represent BER application tag @summary: Return structure that represent BER application tag

View File

@@ -23,7 +23,7 @@ http://msdn.microsoft.com/en-us/library/cc240508.aspx
""" """
from hashlib import md5 from hashlib import md5
from rdpy.model.type import UInt8, UInt16Le, UInt32Le, CompositeType, CallableValue, Buffer, Stream, sizeof, FactoryType, ArrayType from rdpy.model.message import UInt8, UInt16Le, UInt32Le, CompositeType, Buffer, Stream, sizeof, FactoryType, ArrayType
from rdpy.core.t125 import per, mcs from rdpy.core.t125 import per, mcs
from rdpy.model.error import InvalidExpectedDataException from rdpy.model.error import InvalidExpectedDataException
from rdpy.model import log from rdpy.model import log
@@ -32,8 +32,8 @@ import rdpy.security.rsa_wrapper as rsa
t124_02_98_oid = ( 0, 0, 20, 124, 0, 1 ) t124_02_98_oid = ( 0, 0, 20, 124, 0, 1 )
h221_cs_key = "Duca"; h221_cs_key = b"Duca";
h221_sc_key = "McDn"; h221_sc_key = b"McDn";
class MessageType(object): class MessageType(object):
""" """
@@ -208,43 +208,47 @@ class CertificateType(object):
""" """
CERT_CHAIN_VERSION_1 = 0x00000001 CERT_CHAIN_VERSION_1 = 0x00000001
CERT_CHAIN_VERSION_2 = 0x00000002 CERT_CHAIN_VERSION_2 = 0x00000002
class DataBlock(CompositeType): class DataBlock(CompositeType):
""" """
@summary: Block settings @summary: Block settings
""" """
def __init__(self, dataBlock = None): def __init__(self, data_block=None):
CompositeType.__init__(self) super().__init__()
self.type = UInt16Le(lambda:self.dataBlock.__class__._TYPE_) self.type = UInt16Le(lambda:data_block._TYPE_)
self.length = UInt16Le(lambda:sizeof(self)) self.length = UInt16Le(lambda: sizeof(self))
def DataBlockFactory(): def factory():
""" """
@summary: build settings in accordance of type self.type.value
""" """
for c in [ClientCoreData, ClientSecurityData, ClientNetworkData, ServerCoreData, ServerNetworkData, ServerSecurityData]: gcc_type = [
ClientCoreData, ClientSecurityData, ClientNetworkData,
ServerCoreData, ServerNetworkData, ServerSecurityData
]
for c in gcc_type:
if self.type.value == c._TYPE_: if self.type.value == c._TYPE_:
return c(readLen = self.length - 4) return c(read_len=lambda: (self.length.value - 4))
log.debug("unknown GCC block type : %s"%hex(self.type.value)) log.debug("unknown GCC block type : %s"%hex(self.type.value))
#read entire packet # read entire packet
return Buffer(readLen =self.length - 4) return Buffer(read_len=lambda: (self.length.value - 4))
if dataBlock is None: if data_block is None:
dataBlock = FactoryType(DataBlockFactory) data_block = FactoryType(factory)
elif not "_TYPE_" in dataBlock.__class__.__dict__: elif "_TYPE_" not in data_block.__class__.__dict__:
raise InvalidExpectedDataException("Try to send an invalid GCC blocks") raise InvalidExpectedDataException("Try to send an invalid GCC blocks")
self.dataBlock = dataBlock self.dataBlock = data_block
class ClientCoreData(CompositeType): class ClientCoreData(CompositeType):
""" """
@summary: Class that represent core setting of client
@see: http://msdn.microsoft.com/en-us/library/cc240510.aspx
""" """
_TYPE_ = MessageType.CS_CORE _TYPE_ = MessageType.CS_CORE
def __init__(self, readLen = None): def __init__(self, read_len=None):
CompositeType.__init__(self, readLen = readLen) super().__init__(read_len=read_len)
self.rdpVersion = UInt32Le(Version.RDP_VERSION_5_PLUS) self.rdpVersion = UInt32Le(Version.RDP_VERSION_5_PLUS)
self.desktopWidth = UInt16Le(1280) self.desktopWidth = UInt16Le(1280)
self.desktopHeight = UInt16Le(800) self.desktopHeight = UInt16Le(800)
@@ -252,62 +256,59 @@ class ClientCoreData(CompositeType):
self.sasSequence = UInt16Le(Sequence.RNS_UD_SAS_DEL) self.sasSequence = UInt16Le(Sequence.RNS_UD_SAS_DEL)
self.kbdLayout = UInt32Le(KeyboardLayout.US) self.kbdLayout = UInt32Le(KeyboardLayout.US)
self.clientBuild = UInt32Le(3790) self.clientBuild = UInt32Le(3790)
self.clientName = Buffer("rdpy" + "\x00" * 11, readLen = CallableValue(32), unicode = True) self.clientName = Buffer(b"rdpy" + b"\x00" * 11, read_len=lambda: 32)
self.keyboardType = UInt32Le(KeyboardType.IBM_101_102_KEYS) self.keyboardType = UInt32Le(KeyboardType.IBM_101_102_KEYS)
self.keyboardSubType = UInt32Le(0) self.keyboardSubType = UInt32Le(0)
self.keyboardFnKeys = UInt32Le(12) self.keyboardFnKeys = UInt32Le(12)
self.imeFileName = Buffer("\x00" * 64, readLen = CallableValue(64), optional = True) self.imeFileName = Buffer(b"\x00" * 64, read_len=lambda: 64, optional=True)
self.postBeta2ColorDepth = UInt16Le(ColorDepth.RNS_UD_COLOR_8BPP, optional = True) self.postBeta2ColorDepth = UInt16Le(ColorDepth.RNS_UD_COLOR_8BPP, optional=True)
self.clientProductId = UInt16Le(1, optional = True) self.clientProductId = UInt16Le(1, optional=True)
self.serialNumber = UInt32Le(0, optional = True) self.serialNumber = UInt32Le(0, optional=True)
self.highColorDepth = UInt16Le(HighColor.HIGH_COLOR_24BPP, optional = True) self.highColorDepth = UInt16Le(HighColor.HIGH_COLOR_24BPP, optional=True)
self.supportedColorDepths = UInt16Le(Support.RNS_UD_15BPP_SUPPORT | Support.RNS_UD_16BPP_SUPPORT | Support.RNS_UD_24BPP_SUPPORT | Support.RNS_UD_32BPP_SUPPORT, optional = True) self.supportedColorDepths = UInt16Le(Support.RNS_UD_15BPP_SUPPORT | Support.RNS_UD_16BPP_SUPPORT | Support.RNS_UD_24BPP_SUPPORT | Support.RNS_UD_32BPP_SUPPORT, optional=True)
self.earlyCapabilityFlags = UInt16Le(CapabilityFlags.RNS_UD_CS_SUPPORT_ERRINFO_PDU, optional = True) self.earlyCapabilityFlags = UInt16Le(CapabilityFlags.RNS_UD_CS_SUPPORT_ERRINFO_PDU, optional=True)
self.clientDigProductId = Buffer("\x00" * 64, readLen = CallableValue(64), optional = True) self.clientDigProductId = Buffer(b"\x00" * 64, read_len=lambda: 64, optional=True)
self.connectionType = UInt8(optional = True) self.connectionType = UInt8(optional=True)
self.pad1octet = UInt8(optional = True) self.pad1octet = UInt8(optional=True)
self.serverSelectedProtocol = UInt32Le(optional = True) self.serverSelectedProtocol = UInt32Le(optional=True)
class ServerCoreData(CompositeType): class ServerCoreData(CompositeType):
""" """
@summary: Server side core settings structure
@see: http://msdn.microsoft.com/en-us/library/cc240517.aspx
""" """
_TYPE_ = MessageType.SC_CORE _TYPE_ = MessageType.SC_CORE
def __init__(self, readLen = None): def __init__(self, read_len=None):
CompositeType.__init__(self, readLen = readLen) super().__init__(read_len=read_len)
self.rdpVersion = UInt32Le(Version.RDP_VERSION_5_PLUS) self.rdpVersion = UInt32Le(Version.RDP_VERSION_5_PLUS)
self.clientRequestedProtocol = UInt32Le(optional = True) self.clientRequestedProtocol = UInt32Le(optional=True)
self.earlyCapabilityFlags = UInt32Le(optional = True) self.earlyCapabilityFlags = UInt32Le(optional=True)
class ClientSecurityData(CompositeType): class ClientSecurityData(CompositeType):
""" """
@summary: Client security setting
@see: http://msdn.microsoft.com/en-us/library/cc240511.aspx
""" """
_TYPE_ = MessageType.CS_SECURITY _TYPE_ = MessageType.CS_SECURITY
def __init__(self, readLen = None): def __init__(self, read_len=None):
CompositeType.__init__(self, readLen = readLen) super().__init__(read_len=read_len)
self.encryptionMethods = UInt32Le(EncryptionMethod.ENCRYPTION_FLAG_40BIT | EncryptionMethod.ENCRYPTION_FLAG_56BIT | EncryptionMethod.ENCRYPTION_FLAG_128BIT) self.encryptionMethods = UInt32Le(EncryptionMethod.ENCRYPTION_FLAG_40BIT | EncryptionMethod.ENCRYPTION_FLAG_56BIT | EncryptionMethod.ENCRYPTION_FLAG_128BIT)
self.extEncryptionMethods = UInt32Le() self.extEncryptionMethods = UInt32Le()
class ServerSecurityData(CompositeType): class ServerSecurityData(CompositeType):
""" """
@summary: Server security settings
@see: http://msdn.microsoft.com/en-us/library/cc240518.aspx
""" """
_TYPE_ = MessageType.SC_SECURITY _TYPE_ = MessageType.SC_SECURITY
def __init__(self, readLen = None): def __init__(self, read_len=None):
CompositeType.__init__(self, readLen = readLen) super().__init__(read_len=read_len)
self.encryptionMethod = UInt32Le() self.encryptionMethod = UInt32Le()
self.encryptionLevel = UInt32Le() self.encryptionLevel = UInt32Le()
self.serverRandomLen = UInt32Le(0x00000020, constant = True, conditional = lambda:not(self.encryptionMethod.value == 0 and self.encryptionLevel == 0)) self.serverRandomLen = UInt32Le(0x00000020, constant=True, conditional=lambda: not(self.encryptionMethod.value == 0 and self.encryptionLevel.value == 0))
self.serverCertLen = UInt32Le(lambda:sizeof(self.serverCertificate), conditional = lambda:not(self.encryptionMethod.value == 0 and self.encryptionLevel == 0)) self.serverCertLen = UInt32Le(lambda: sizeof(self.serverCertificate), conditional=lambda:not(self.encryptionMethod.value == 0 and self.encryptionLevel.value == 0))
self.serverRandom = Buffer(readLen = self.serverRandomLen, conditional = lambda:not(self.encryptionMethod.value == 0 and self.encryptionLevel == 0)) self.serverRandom = Buffer(read_len=lambda: self.serverRandomLen.value, conditional=lambda: not(self.encryptionMethod.value == 0 and self.encryptionLevel.value == 0))
self.serverCertificate = ServerCertificate(readLen = self.serverCertLen, conditional = lambda:not(self.encryptionMethod.value == 0 and self.encryptionLevel == 0)) self.serverCertificate = ServerCertificate(readLen=lambda: self.serverCertLen.value, conditional=lambda: not(self.encryptionMethod.value == 0 and self.encryptionLevel.value == 0))
class ServerCertificate(CompositeType): class ServerCertificate(CompositeType):
""" """
@@ -448,79 +449,62 @@ class RSAPublicKey(CompositeType):
self.datalen = UInt32Le(lambda:((self.bitlen.value / 8) - 1)) self.datalen = UInt32Le(lambda:((self.bitlen.value / 8) - 1))
self.pubExp = UInt32Le() self.pubExp = UInt32Le()
self.modulus = Buffer(readLen = CallableValue(lambda:(self.keylen.value - 8))) self.modulus = Buffer(readLen = CallableValue(lambda:(self.keylen.value - 8)))
self.padding = Buffer("\x00" * 8, readLen = CallableValue(8)) self.padding = Buffer(b"\x00" * 8, readLen = CallableValue(8))
class ChannelDef(CompositeType): class ChannelDef(CompositeType):
""" """
Channels structure share between client and server
@see: http://msdn.microsoft.com/en-us/library/cc240513.aspx
""" """
def __init__(self, name = "", options = 0): def __init__(self, name=b""):
CompositeType.__init__(self) super().__init__()
#name of channel # name of channel
self.name = Buffer(name[0:8] + "\x00" * (8 - len(name)), readLen = CallableValue(8)) self.name = Buffer(name[0:8] + b"\x00" * (8 - len(name)), read_len=lambda: 8)
#unknown # unknown
self.options = UInt32Le() self.options = UInt32Le()
class ClientNetworkData(CompositeType): class ClientNetworkData(CompositeType):
""" """
@summary: GCC client network block
All channels asked by client are listed here
@see: http://msdn.microsoft.com/en-us/library/cc240512.aspx
""" """
_TYPE_ = MessageType.CS_NET _TYPE_ = MessageType.CS_NET
def __init__(self, readLen = None): def __init__(self, read_len=None):
CompositeType.__init__(self, readLen = readLen) CompositeType.__init__(self, read_len=read_len)
self.channelCount = UInt32Le(lambda:len(self.channelDefArray._array)) self.channelCount = UInt32Le(lambda: len(self.channelDefArray))
self.channelDefArray = ArrayType(ChannelDef, readLen = self.channelCount) self.channelDefArray = ArrayType(ChannelDef, read_len=lambda: self.channelCount.value)
class ServerNetworkData(CompositeType): class ServerNetworkData(CompositeType):
""" """
@summary: GCC server network block
All channels asked by client are listed here
@see: All channels asked by client are listed here
""" """
_TYPE_ = MessageType.SC_NET _TYPE_ = MessageType.SC_NET
def __init__(self, readLen = None): def __init__(self, read_len=None):
CompositeType.__init__(self, readLen = readLen) super().__init__(read_len=read_len)
self.MCSChannelId = UInt16Le(mcs.Channel.MCS_GLOBAL_CHANNEL) self.MCSChannelId = UInt16Le(mcs.Channel.MCS_GLOBAL_CHANNEL)
self.channelCount = UInt16Le(lambda:len(self.channelIdArray._array)) self.channelCount = UInt16Le(lambda: len(self.channelIdArray))
self.channelIdArray = ArrayType(UInt16Le, readLen = self.channelCount) self.channelIdArray = ArrayType(UInt16Le, read_len=lambda: self.channelCount.value)
self.pad = UInt16Le(conditional = lambda:((self.channelCount.value % 2) == 1)) self.pad = UInt16Le(conditional=lambda: ((self.channelCount.value % 2) == 1))
class Settings(CompositeType): class Settings(CompositeType):
""" """
@summary: Class which group all clients settings supported by RDPY
""" """
def __init__(self, init = [], readLen = None): def __init__(self, init=None, read_len=None):
CompositeType.__init__(self, readLen = readLen) super().__init__(read_len=read_len)
self.settings = ArrayType(DataBlock, [DataBlock(i) for i in init]) self.settings = ArrayType(DataBlock, [DataBlock(i) for i in init])
def getBlock(self, messageType): def get_block(self, message_type):
""" """
@param messageType: type of block
@return: specific block of type messageType
""" """
for i in self.settings._array: for i in self.settings._array:
if i.type.value == messageType: if i.type.value == message_type:
return i.dataBlock return i.dataBlock
return None return None
def __getattr__(self, name):
""" def client_settings():
@summary: Magic function for better access
@return: _value parameter
"""
if not name in MessageType.__dict__:
return None
return self.getBlock(MessageType.__dict__[name])
def clientSettings():
""" """
@summary: Build settings for client
@return: Settings
""" """
return Settings([ClientCoreData(), ClientNetworkData(), ClientSecurityData()]) return Settings([ClientCoreData(), ClientNetworkData(), ClientSecurityData()])
@@ -578,9 +562,9 @@ def readConferenceCreateResponse(s):
raise InvalidExpectedDataException("cannot read h221_sc_key") raise InvalidExpectedDataException("cannot read h221_sc_key")
length = per.readLength(s) length = per.readLength(s)
serverSettings = Settings(readLen = CallableValue(length)) server_settings = Settings(read_len=lambda: length)
s.read_type(serverSettings) s.read_type(server_settings)
return serverSettings return server_settings
def writeConferenceCreateRequest(userData): def writeConferenceCreateRequest(userData):
""" """
@@ -593,7 +577,7 @@ def writeConferenceCreateRequest(userData):
return (per.writeChoice(0), per.writeObjectIdentifier(t124_02_98_oid), return (per.writeChoice(0), per.writeObjectIdentifier(t124_02_98_oid),
per.writeLength(len(userDataStream.getvalue()) + 14), per.writeChoice(0), per.writeLength(len(userDataStream.getvalue()) + 14), per.writeChoice(0),
per.writeSelection(0x08), per.writeNumericString("1", 1), per.writePadding(1), per.writeSelection(0x08), per.writeNumericString(b"1", 1), per.writePadding(1),
per.writeNumberOfSet(1), per.writeChoice(0xc0), per.writeNumberOfSet(1), per.writeChoice(0xc0),
per.writeOctetStream(h221_cs_key, 4), per.writeOctetStream(userDataStream.getvalue())) per.writeOctetStream(h221_cs_key, 4), per.writeOctetStream(userDataStream.getvalue()))

View File

@@ -24,6 +24,9 @@ Each channel have a particular role.
The main channel is the graphical channel. The main channel is the graphical channel.
It exist channel for file system order, audio channel, clipboard etc... It exist channel for file system order, audio channel, clipboard etc...
""" """
from typing import Tuple
from rdpy.core import x224
from rdpy.model.layer import LayerAutomata, IStreamSender, Layer from rdpy.model.layer import LayerAutomata, IStreamSender, Layer
from rdpy.model.message import sizeof, Stream, UInt8, UInt16Le, Buffer from rdpy.model.message import sizeof, Stream, UInt8, UInt16Le, Buffer
from rdpy.model.error import InvalidExpectedDataException, InvalidValue, InvalidSize, CallPureVirtualFuntion from rdpy.model.error import InvalidExpectedDataException, InvalidValue, InvalidSize, CallPureVirtualFuntion
@@ -33,16 +36,16 @@ import rdpy.model.log as log
from rdpy.core.t125 import ber, gcc, per from rdpy.core.t125 import ber, gcc, per
import rdpy.security.rsa_wrapper as rsa import rdpy.security.rsa_wrapper as rsa
class Message(object):
class Message:
""" """
@summary: Message type
""" """
MCS_TYPE_CONNECT_INITIAL = 0x65 MCS_TYPE_CONNECT_INITIAL = 0x65
MCS_TYPE_CONNECT_RESPONSE = 0x66 MCS_TYPE_CONNECT_RESPONSE = 0x66
class DomainMCSPDU: class DomainMCSPDU:
""" """
@summary: Domain MCS PDU header
""" """
ERECT_DOMAIN_REQUEST = 1 ERECT_DOMAIN_REQUEST = 1
DISCONNECT_PROVIDER_ULTIMATUM = 8 DISCONNECT_PROVIDER_ULTIMATUM = 8
@@ -53,44 +56,114 @@ class DomainMCSPDU:
SEND_DATA_REQUEST = 25 SEND_DATA_REQUEST = 25
SEND_DATA_INDICATION = 26 SEND_DATA_INDICATION = 26
class Channel: class Channel:
""" """
@summary: Channel id of main channels use in RDP
""" """
MCS_GLOBAL_CHANNEL = 1003 MCS_GLOBAL_CHANNEL = 1003
MCS_USERCHANNEL_BASE = 1001 MCS_USERCHANNEL_BASE = 1001
class IGCCConfig(object): class IGCCConfig(object):
""" """
@summary: Channel information @summary: Channel information
""" """
def getUserId(self): def getUserId(self):
""" """
@return: {integer} mcs user id @return: {integer} mcs user id
@see: mcs.IGCCConfig @see: mcs.IGCCConfig
""" """
raise CallPureVirtualFuntion("%s:%s defined by interface %s"%(self.__class__, "getUserId", "IGCCConfig")) raise CallPureVirtualFuntion("%s:%s defined by interface %s" % (self.__class__, "getUserId", "IGCCConfig"))
def getChannelId(self): def getChannelId(self):
""" """
@return: {integer} return channel id of proxy @return: {integer} return channel id of proxy
@see: mcs.IGCCConfig @see: mcs.IGCCConfig
""" """
raise CallPureVirtualFuntion("%s:%s defined by interface %s"%(self.__class__, "getChannelId", "IGCCConfig")) raise CallPureVirtualFuntion("%s:%s defined by interface %s" % (self.__class__, "getChannelId", "IGCCConfig"))
def getGCCClientSettings(self): def getGCCClientSettings(self):
""" """
@return: {gcc.Settings} mcs layer gcc client settings @return: {gcc.Settings} mcs layer gcc client settings
@see: mcs.IGCCConfig @see: mcs.IGCCConfig
""" """
raise CallPureVirtualFuntion("%s:%s defined by interface %s"%(self.__class__, "getGCCClientSettings", "IGCCConfig")) raise CallPureVirtualFuntion(
"%s:%s defined by interface %s" % (self.__class__, "getGCCClientSettings", "IGCCConfig"))
def getGCCServerSettings(self): def getGCCServerSettings(self):
""" """
@return: {gcc.Settings} mcs layer gcc server settings @return: {gcc.Settings} mcs layer gcc server settings
@see: mcs.IGCCConfig @see: mcs.IGCCConfig
""" """
raise CallPureVirtualFuntion("%s:%s defined by interface %s"%(self.__class__, "getGCCServerSettings", "IGCCConfig")) raise CallPureVirtualFuntion(
"%s:%s defined by interface %s" % (self.__class__, "getGCCServerSettings", "IGCCConfig"))
def write_domain_params(max_channels: int, max_users: int, max_tokens: int, max_pdu_size: int) -> tuple:
"""
"""
domain_param = (ber.writeInteger(max_channels), ber.writeInteger(max_users), ber.writeInteger(max_tokens),
ber.writeInteger(1), ber.writeInteger(0), ber.writeInteger(1),
ber.writeInteger(max_pdu_size), ber.writeInteger(2))
return ber.writeUniversalTag(ber.Tag.BER_TAG_SEQUENCE, True), writeLength(sizeof(domain_param)), domain_param
def read_domain_params(stream: Stream) -> Tuple[int, int, int, int]:
"""
"""
if not ber.readUniversalTag(stream, ber.Tag.BER_TAG_SEQUENCE, True):
raise InvalidValue("bad BER tags")
ber.readLength(stream) # length
max_channels = ber.readInteger(stream)
max_users = ber.readInteger(stream)
max_tokens = ber.readInteger(stream)
ber.readInteger(stream)
ber.readInteger(stream)
ber.readInteger(stream)
max_pdu_size = ber.readInteger(stream)
ber.readInteger(stream)
return max_channels, max_users, max_tokens, max_pdu_size
class Client:
def __init__(self, x224_layer: x224.X224):
self.x224 = x224_layer
async def write_connect_initial(self):
"""
"""
settings = gcc.client_settings()
settings.get_block(gcc.MessageType.CS_CORE).serverSelectedProtocol.value = self.x224.get_selected_protocol()
cc_req = gcc.writeConferenceCreateRequest(settings)
cc_req_stream = Stream()
cc_req_stream.write_type(cc_req)
tmp = (ber.writeOctetstring(b"\x01"), ber.writeOctetstring(b"\x01"), ber.writeBoolean(True),
write_domain_params(34, 2, 0, 0xffff),
write_domain_params(1, 1, 1, 0x420),
write_domain_params(0xffff, 0xfc17, 0xffff, 0xffff),
ber.writeOctetstring(cc_req_stream.getvalue()))
await self.x224.write((ber.writeApplicationTag(Message.MCS_TYPE_CONNECT_INITIAL, sizeof(tmp)), tmp))
async def read_connect_response(self):
payload = await self.x224.read()
ber.readApplicationTag(payload, UInt8(Message.MCS_TYPE_CONNECT_RESPONSE))
ber.readEnumerated(payload)
ber.readInteger(payload)
read_domain_params(payload)
if not ber.readUniversalTag(payload, ber.Tag.BER_TAG_OCTET_STRING, False):
raise InvalidExpectedDataException("invalid expected BER tag")
gccRequestLength = ber.readLength(payload)
if payload.data_len() != gccRequestLength:
raise InvalidSize("bad size of GCC request")
gcc.readConferenceCreateResponse(payload)
async def connect(self):
await self.write_connect_initial()
await self.read_connect_response()
class MCSLayer(LayerAutomata): class MCSLayer(LayerAutomata):
""" """
@@ -98,11 +171,13 @@ class MCSLayer(LayerAutomata):
the main layer of RDP protocol the main layer of RDP protocol
is why he can do everything and more! is why he can do everything and more!
""" """
class MCSProxySender(Layer, IStreamSender, IGCCConfig): class MCSProxySender(Layer, IStreamSender, IGCCConfig):
""" """
@summary: Proxy use to set as transport layer for upper channel @summary: Proxy use to set as transport layer for upper channel
use to abstract channel id for presentation layer use to abstract channel id for presentation layer
""" """
def __init__(self, presentation, mcs, channelId): def __init__(self, presentation, mcs, channelId):
""" """
@param presentation: {Layer} presentation layer @param presentation: {Layer} presentation layer
@@ -112,7 +187,7 @@ class MCSLayer(LayerAutomata):
Layer.__init__(self, presentation) Layer.__init__(self, presentation)
self._mcs = mcs self._mcs = mcs
self._channelId = channelId self._channelId = channelId
def send(self, data): def send(self, data):
""" """
@summary: A send proxy function, use channel id and specific @summary: A send proxy function, use channel id and specific
@@ -120,43 +195,42 @@ class MCSLayer(LayerAutomata):
@param data: {type.Type | Tuple} @param data: {type.Type | Tuple}
""" """
self._mcs.send(self._channelId, data) self._mcs.send(self._channelId, data)
def close(self): def close(self):
""" """
@summary: Close wrapped layer @summary: Close wrapped layer
""" """
self._mcs.close() self._mcs.close()
def getUserId(self): def getUserId(self):
""" """
@return: {integer} mcs user id @return: {integer} mcs user id
@see: mcs.IGCCConfig @see: mcs.IGCCConfig
""" """
return self._mcs._userId return self._mcs._userId
def getChannelId(self): def getChannelId(self):
""" """
@return: {integer} return channel id of proxy @return: {integer} return channel id of proxy
@see: mcs.IGCCConfig @see: mcs.IGCCConfig
""" """
return self._channelId return self._channelId
def getGCCClientSettings(self): def getGCCClientSettings(self):
""" """
@return: {gcc.Settings} mcs layer gcc client settings @return: {gcc.Settings} mcs layer gcc client settings
@see: mcs.IGCCConfig @see: mcs.IGCCConfig
""" """
return self._mcs._clientSettings return self._mcs._clientSettings
def getGCCServerSettings(self): def getGCCServerSettings(self):
""" """
@return: {gcc.Settings} mcs layer gcc server settings @return: {gcc.Settings} mcs layer gcc server settings
@see: mcs.IGCCConfig @see: mcs.IGCCConfig
""" """
return self._mcs._serverSettings return self._mcs._serverSettings
def __init__(self, presentation, receiveOpcode, sendOpcode, virtualChannels=[]):
def __init__(self, presentation, receiveOpcode, sendOpcode, virtualChannels = []):
""" """
@param presentation: {Layer} presentation layer @param presentation: {Layer} presentation layer
@param virtualChannels: {Array(Layer]} list additional channels like rdpsnd... [tuple(mcs.ChannelDef, layer)] @param virtualChannels: {Array(Layer]} list additional channels like rdpsnd... [tuple(mcs.ChannelDef, layer)]
@@ -166,50 +240,50 @@ class MCSLayer(LayerAutomata):
LayerAutomata.__init__(self, presentation) LayerAutomata.__init__(self, presentation)
self._clientSettings = gcc.clientSettings() self._clientSettings = gcc.clientSettings()
self._serverSettings = gcc.serverSettings() self._serverSettings = gcc.serverSettings()
#default user Id # default user Id
self._userId = 1 + Channel.MCS_USERCHANNEL_BASE self._userId = 1 + Channel.MCS_USERCHANNEL_BASE
#list of channel use in this layer and connection state # list of channel use in this layer and connection state
self._channels = {Channel.MCS_GLOBAL_CHANNEL: presentation} self._channels = {Channel.MCS_GLOBAL_CHANNEL: presentation}
#virtual channels # virtual channels
self._virtualChannels = virtualChannels self._virtualChannels = virtualChannels
#send opcode # send opcode
self._sendOpcode = sendOpcode self._sendOpcode = sendOpcode
#receive opcode # receive opcode
self._receiveOpcode = receiveOpcode self._receiveOpcode = receiveOpcode
def close(self): def close(self):
""" """
@summary: Send disconnect provider ultimatum @summary: Send disconnect provider ultimatum
""" """
self._transport.send((UInt8(self.writeMCSPDUHeader(DomainMCSPDU.DISCONNECT_PROVIDER_ULTIMATUM, 1)), self._transport.send((UInt8(self.writeMCSPDUHeader(DomainMCSPDU.DISCONNECT_PROVIDER_ULTIMATUM, 1)),
per.writeEnumerates(0x80), Buffer("\x00" * 6))) per.writeEnumerates(0x80), Buffer(b"\x00" * 6)))
self._transport.close() self._transport.close()
def allChannelConnected(self): def allChannelConnected(self):
""" """
@summary: All channels are connected to MCS layer @summary: All channels are connected to MCS layer
Send connect to upper channel Send connect to upper channel
And prepare MCS layer to receive data And prepare MCS layer to receive data
""" """
#connection is done # connection is done
self.setNextState(self.recvData) self.setNextState(self.recvData)
#try connection on all requested channel # try connection on all requested channel
for (channelId, layer) in self._channels.iteritems(): for (channelId, layer) in self._channels.iteritems():
#use proxy for each channel # use proxy for each channel
MCSLayer.MCSProxySender(layer, self, channelId).connect() MCSLayer.MCSProxySender(layer, self, channelId).connect()
def send(self, channelId, data): def send(self, channelId, data):
""" """
@summary: Specific send function for channelId @summary: Specific send function for channelId
@param channelId: {integer} Channel use to send @param channelId: {integer} Channel use to send
@param data: {type.type | tuple} message to send @param data: {type.type | tuple} message to send
""" """
self._transport.send((self.writeMCSPDUHeader(UInt8(self._sendOpcode)), self._transport.send((self.writeMCSPDUHeader(UInt8(self._sendOpcode)),
per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE), per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE),
per.writeInteger16(channelId), per.writeInteger16(channelId),
UInt8(0x70), UInt8(0x70),
per.writeLength(sizeof(data)), data)) per.writeLength(sizeof(data)), data))
def recvData(self, data): def recvData(self, data):
""" """
@summary: Main receive method @summary: Main receive method
@@ -217,31 +291,31 @@ class MCSLayer(LayerAutomata):
""" """
opcode = UInt8() opcode = UInt8()
data.read_type(opcode) data.read_type(opcode)
if self.readMCSPDUHeader(opcode.value, DomainMCSPDU.DISCONNECT_PROVIDER_ULTIMATUM): if self.readMCSPDUHeader(opcode.value, DomainMCSPDU.DISCONNECT_PROVIDER_ULTIMATUM):
log.info("MCS DISCONNECT_PROVIDER_ULTIMATUM") log.info("MCS DISCONNECT_PROVIDER_ULTIMATUM")
self._transport.close() self._transport.close()
return return
#client case # client case
elif not self.readMCSPDUHeader(opcode.value, self._receiveOpcode): elif not self.readMCSPDUHeader(opcode.value, self._receiveOpcode):
raise InvalidExpectedDataException("Invalid expected MCS opcode receive data") raise InvalidExpectedDataException("Invalid expected MCS opcode receive data")
#server user id # server user id
per.readInteger16(data, Channel.MCS_USERCHANNEL_BASE) per.readInteger16(data, Channel.MCS_USERCHANNEL_BASE)
channelId = per.readInteger16(data) channelId = per.readInteger16(data)
per.readEnumerates(data) per.readEnumerates(data)
per.readLength(data) per.readLength(data)
#channel id doesn't match a requested layer # channel id doesn't match a requested layer
if not self._channels.has_key(channelId): if not self._channels.has_key(channelId):
log.error("receive data for an unconnected layer") log.error("receive data for an unconnected layer")
return return
self._channels[channelId].recv(data) self._channels[channelId].recv(data)
def writeDomainParams(self, maxChannels, maxUsers, maxTokens, maxPduSize): def writeDomainParams(self, maxChannels, maxUsers, maxTokens, maxPduSize):
""" """
@summary: Write a special domain parameter structure @summary: Write a special domain parameter structure
@@ -256,8 +330,8 @@ class MCSLayer(LayerAutomata):
ber.writeInteger(1), ber.writeInteger(0), ber.writeInteger(1), ber.writeInteger(1), ber.writeInteger(0), ber.writeInteger(1),
ber.writeInteger(maxPduSize), ber.writeInteger(2)) ber.writeInteger(maxPduSize), ber.writeInteger(2))
return (ber.writeUniversalTag(ber.Tag.BER_TAG_SEQUENCE, True), writeLength(sizeof(domainParam)), domainParam) return (ber.writeUniversalTag(ber.Tag.BER_TAG_SEQUENCE, True), writeLength(sizeof(domainParam)), domainParam)
def writeMCSPDUHeader(self, mcsPdu, options = 0): def writeMCSPDUHeader(self, mcsPdu, options=0):
""" """
@summary: Write MCS PDU header @summary: Write MCS PDU header
@param mcsPdu: {integer} PDU code @param mcsPdu: {integer} PDU code
@@ -265,7 +339,7 @@ class MCSLayer(LayerAutomata):
@return: {integer} @return: {integer}
""" """
return (mcsPdu << 2) | options return (mcsPdu << 2) | options
def readMCSPDUHeader(self, opcode, mcsPdu): def readMCSPDUHeader(self, opcode, mcsPdu):
""" """
@summary: Read mcsPdu header and return options parameter @summary: Read mcsPdu header and return options parameter
@@ -274,7 +348,7 @@ class MCSLayer(LayerAutomata):
@return: {boolean} true if opcode is correct @return: {boolean} true if opcode is correct
""" """
return (opcode >> 2) == mcsPdu return (opcode >> 2) == mcsPdu
def readDomainParams(self, s): def readDomainParams(self, s):
""" """
@summary: Read domain parameters structure @summary: Read domain parameters structure
@@ -283,7 +357,7 @@ class MCSLayer(LayerAutomata):
""" """
if not ber.readUniversalTag(s, ber.Tag.BER_TAG_SEQUENCE, True): if not ber.readUniversalTag(s, ber.Tag.BER_TAG_SEQUENCE, True):
raise InvalidValue("bad BER tags") raise InvalidValue("bad BER tags")
ber.readLength(s)#length ber.readLength(s) # length
max_channels = ber.readInteger(s) max_channels = ber.readInteger(s)
max_users = ber.readInteger(s) max_users = ber.readInteger(s)
max_tokens = ber.readInteger(s) max_tokens = ber.readInteger(s)
@@ -293,23 +367,26 @@ class MCSLayer(LayerAutomata):
max_pdu_size = ber.readInteger(s) max_pdu_size = ber.readInteger(s)
ber.readInteger(s) ber.readInteger(s)
return (max_channels, max_users, max_tokens, max_pdu_size) return (max_channels, max_users, max_tokens, max_pdu_size)
class Client(MCSLayer):
class ClientOld(MCSLayer):
""" """
@summary: Client automata of multiple channel service layer @summary: Client automata of multiple channel service layer
""" """
def __init__(self, presentation, virtualChannels = []):
def __init__(self, presentation, virtualChannels=[]):
""" """
@param presentation: {Layer} presentation layer @param presentation: {Layer} presentation layer
@param virtualChannels: {Array(Layer)} list additional channels like rdpsnd... [tuple(mcs.ChannelDef, layer)] @param virtualChannels: {Array(Layer)} list additional channels like rdpsnd... [tuple(mcs.ChannelDef, layer)]
""" """
MCSLayer.__init__(self, presentation, DomainMCSPDU.SEND_DATA_INDICATION, DomainMCSPDU.SEND_DATA_REQUEST, virtualChannels) MCSLayer.__init__(self, presentation, DomainMCSPDU.SEND_DATA_INDICATION, DomainMCSPDU.SEND_DATA_REQUEST,
#use to know state of static channel virtualChannels)
# use to know state of static channel
self._isGlobalChannelRequested = False self._isGlobalChannelRequested = False
self._isUserChannelRequested = False self._isUserChannelRequested = False
#nb channel requested # nb channel requested
self._nbChannelRequested = 0 self._nbChannelRequested = 0
def connect(self): def connect(self):
""" """
@summary: Connect message in client automata case @summary: Connect message in client automata case
@@ -317,13 +394,13 @@ class Client(MCSLayer):
Wait ConnectResponse Wait ConnectResponse
""" """
self._clientSettings.CS_CORE.serverSelectedProtocol.value = self._transport._selectedProtocol self._clientSettings.CS_CORE.serverSelectedProtocol.value = self._transport._selectedProtocol
#ask for virtual channel # ask for virtual channel
self._clientSettings.CS_NET.channelDefArray._array = [x for (x, _) in self._virtualChannels] self._clientSettings.CS_NET.channelDefArray._array = [x for (x, _) in self._virtualChannels]
#send connect initial # send connect initial
self.sendConnectInitial() self.sendConnectInitial()
#next wait response # next wait response
self.setNextState(self.recvConnectResponse) self.setNextState(self.recvConnectResponse)
def connectNextChannel(self): def connectNextChannel(self):
""" """
@summary: Send sendChannelJoinRequest message on next disconnect channel @summary: Send sendChannelJoinRequest message on next disconnect channel
@@ -331,27 +408,27 @@ class Client(MCSLayer):
Wait channel confirm Wait channel confirm
""" """
self.setNextState(self.recvChannelJoinConfirm) self.setNextState(self.recvChannelJoinConfirm)
#global channel # global channel
if not self._isGlobalChannelRequested: if not self._isGlobalChannelRequested:
self.sendChannelJoinRequest(Channel.MCS_GLOBAL_CHANNEL) self.sendChannelJoinRequest(Channel.MCS_GLOBAL_CHANNEL)
self._isGlobalChannelRequested = True self._isGlobalChannelRequested = True
return return
#user channel # user channel
if not self._isUserChannelRequested: if not self._isUserChannelRequested:
self.sendChannelJoinRequest(self._userId) self.sendChannelJoinRequest(self._userId)
self._isUserChannelRequested = True self._isUserChannelRequested = True
return return
#static virtual channel # static virtual channel
if self._nbChannelRequested < self._serverSettings.getBlock(gcc.MessageType.SC_NET).channelCount.value: if self._nbChannelRequested < self._serverSettings.get_block(gcc.MessageType.SC_NET).channelCount.value:
channelId = self._serverSettings.getBlock(gcc.MessageType.SC_NET).channelIdArray[self._nbChannelRequested] channelId = self._serverSettings.get_block(gcc.MessageType.SC_NET).channelIdArray[self._nbChannelRequested]
self._nbChannelRequested += 1 self._nbChannelRequested += 1
self.sendChannelJoinRequest(channelId) self.sendChannelJoinRequest(channelId)
return return
self.allChannelConnected() self.allChannelConnected()
def recvConnectResponse(self, data): def recvConnectResponse(self, data):
""" """
@summary: Receive MCS connect response from server @summary: Receive MCS connect response from server
@@ -370,14 +447,14 @@ class Client(MCSLayer):
if data.data_len() != gccRequestLength: if data.data_len() != gccRequestLength:
raise InvalidSize("bad size of GCC request") raise InvalidSize("bad size of GCC request")
self._serverSettings = gcc.readConferenceCreateResponse(data) self._serverSettings = gcc.readConferenceCreateResponse(data)
#send domain request # send domain request
self.sendErectDomainRequest() self.sendErectDomainRequest()
#send attach user request # send attach user request
self.sendAttachUserRequest() self.sendAttachUserRequest()
#now wait user confirm from server # now wait user confirm from server
self.setNextState(self.recvAttachUserConfirm) self.setNextState(self.recvAttachUserConfirm)
def recvAttachUserConfirm(self, data): def recvAttachUserConfirm(self, data):
""" """
@summary: Receive an attach user confirm @summary: Receive an attach user confirm
@@ -386,17 +463,17 @@ class Client(MCSLayer):
""" """
opcode = UInt8() opcode = UInt8()
data.read_type(opcode) data.read_type(opcode)
if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.ATTACH_USER_CONFIRM): if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.ATTACH_USER_CONFIRM):
raise InvalidExpectedDataException("Invalid MCS PDU : ATTACH_USER_CONFIRM expected") raise InvalidExpectedDataException("Invalid MCS PDU : ATTACH_USER_CONFIRM expected")
if per.readEnumerates(data) != 0: if per.readEnumerates(data) != 0:
raise InvalidExpectedDataException("Server reject user") raise InvalidExpectedDataException("Server reject user")
self._userId = per.readInteger16(data, Channel.MCS_USERCHANNEL_BASE) self._userId = per.readInteger16(data, Channel.MCS_USERCHANNEL_BASE)
self.connectNextChannel() self.connectNextChannel()
def recvChannelJoinConfirm(self, data): def recvChannelJoinConfirm(self, data):
""" """
@summary: Receive a channel join confirm from server @summary: Receive a channel join confirm from server
@@ -405,29 +482,29 @@ class Client(MCSLayer):
""" """
opcode = UInt8() opcode = UInt8()
data.read_type(opcode) data.read_type(opcode)
if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.CHANNEL_JOIN_CONFIRM): if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.CHANNEL_JOIN_CONFIRM):
raise InvalidExpectedDataException("Invalid MCS PDU : CHANNEL_JOIN_CONFIRM expected") raise InvalidExpectedDataException("Invalid MCS PDU : CHANNEL_JOIN_CONFIRM expected")
confirm = per.readEnumerates(data) confirm = per.readEnumerates(data)
userId = per.readInteger16(data, Channel.MCS_USERCHANNEL_BASE) userId = per.readInteger16(data, Channel.MCS_USERCHANNEL_BASE)
if self._userId != userId: if self._userId != userId:
raise InvalidExpectedDataException("Invalid MCS User Id") raise InvalidExpectedDataException("Invalid MCS User Id")
channelId = per.readInteger16(data) channelId = per.readInteger16(data)
#must confirm global channel and user channel # must confirm global channel and user channel
if (confirm != 0) and (channelId == Channel.MCS_GLOBAL_CHANNEL or channelId == self._userId): if (confirm != 0) and (channelId == Channel.MCS_GLOBAL_CHANNEL or channelId == self._userId):
raise InvalidExpectedDataException("Server must confirm static channel") raise InvalidExpectedDataException("Server must confirm static channel")
if confirm == 0: if confirm == 0:
serverNet = self._serverSettings.getBlock(gcc.MessageType.SC_NET) serverNet = self._serverSettings.get_block(gcc.MessageType.SC_NET)
for i in range(0, serverNet.channelCount.value): for i in range(0, serverNet.channelCount.value):
if channelId == serverNet.channelIdArray[i].value: if channelId == serverNet.channelIdArray[i].value:
self._channels[channelId] = self._virtualChannels[i][1] self._channels[channelId] = self._virtualChannels[i][1]
self.connectNextChannel() self.connectNextChannel()
def sendConnectInitial(self): def sendConnectInitial(self):
""" """
@summary: Send connect initial packet @summary: Send connect initial packet
@@ -436,67 +513,69 @@ class Client(MCSLayer):
ccReq = gcc.writeConferenceCreateRequest(self._clientSettings) ccReq = gcc.writeConferenceCreateRequest(self._clientSettings)
ccReqStream = Stream() ccReqStream = Stream()
ccReqStream.write_type(ccReq) ccReqStream.write_type(ccReq)
tmp = (ber.writeOctetstring("\x01"), ber.writeOctetstring("\x01"), ber.writeBoolean(True), tmp = (ber.writeOctetstring("\x01"), ber.writeOctetstring("\x01"), ber.writeBoolean(True),
self.writeDomainParams(34, 2, 0, 0xffff), self.writeDomainParams(34, 2, 0, 0xffff),
self.writeDomainParams(1, 1, 1, 0x420), self.writeDomainParams(1, 1, 1, 0x420),
self.writeDomainParams(0xffff, 0xfc17, 0xffff, 0xffff), self.writeDomainParams(0xffff, 0xfc17, 0xffff, 0xffff),
ber.writeOctetstring(ccReqStream.getvalue())) ber.writeOctetstring(ccReqStream.getvalue()))
self._transport.send((ber.writeApplicationTag(Message.MCS_TYPE_CONNECT_INITIAL, sizeof(tmp)), tmp)) self._transport.send((ber.writeApplicationTag(Message.MCS_TYPE_CONNECT_INITIAL, sizeof(tmp)), tmp))
def sendErectDomainRequest(self): def sendErectDomainRequest(self):
""" """
@summary: Send a formated erect domain request for RDP connection @summary: Send a formated erect domain request for RDP connection
""" """
self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.ERECT_DOMAIN_REQUEST)), self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.ERECT_DOMAIN_REQUEST)),
per.writeInteger(0), per.writeInteger(0),
per.writeInteger(0))) per.writeInteger(0)))
def sendAttachUserRequest(self): def sendAttachUserRequest(self):
""" """
@summary: Send a formated attach user request for RDP connection @summary: Send a formated attach user request for RDP connection
""" """
self._transport.send(self.writeMCSPDUHeader(UInt8(DomainMCSPDU.ATTACH_USER_REQUEST))) self._transport.send(self.writeMCSPDUHeader(UInt8(DomainMCSPDU.ATTACH_USER_REQUEST)))
def sendChannelJoinRequest(self, channelId): def sendChannelJoinRequest(self, channelId):
""" """
@summary: Send a formated Channel join request from client to server @summary: Send a formated Channel join request from client to server
client automata function client automata function
@param channelId: {integer} id of channel requested @param channelId: {integer} id of channel requested
""" """
self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.CHANNEL_JOIN_REQUEST)), self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.CHANNEL_JOIN_REQUEST)),
per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE), per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE),
per.writeInteger16(channelId))) per.writeInteger16(channelId)))
class Server(MCSLayer): class Server(MCSLayer):
""" """
@summary: Server automata of multiple channel service layer @summary: Server automata of multiple channel service layer
""" """
def __init__(self, presentation, virtualChannels = []):
def __init__(self, presentation, virtualChannels=[]):
""" """
@param presentation: {Layer} presentation layer @param presentation: {Layer} presentation layer
@param virtualChannels: {List(Layer)} list additional channels like rdpsnd... [tuple(mcs.ChannelDef, layer)] @param virtualChannels: {List(Layer)} list additional channels like rdpsnd... [tuple(mcs.ChannelDef, layer)]
""" """
MCSLayer.__init__(self, presentation, DomainMCSPDU.SEND_DATA_REQUEST, DomainMCSPDU.SEND_DATA_INDICATION, virtualChannels) MCSLayer.__init__(self, presentation, DomainMCSPDU.SEND_DATA_REQUEST, DomainMCSPDU.SEND_DATA_INDICATION,
#nb channel requested virtualChannels)
# nb channel requested
self._nbChannelConfirmed = 0 self._nbChannelConfirmed = 0
def connect(self): def connect(self):
""" """
@summary: Connect message for server automata @summary: Connect message for server automata
Wait Connect Initial Wait Connect Initial
""" """
#basic rdp security layer # basic rdp security layer
if self._transport._selectedProtocol == 0: if self._transport._selectedProtocol == 0:
self._serverSettings.SC_SECURITY.encryptionMethod.value = gcc.EncryptionMethod.ENCRYPTION_FLAG_128BIT self._serverSettings.SC_SECURITY.encryptionMethod.value = gcc.EncryptionMethod.ENCRYPTION_FLAG_128BIT
self._serverSettings.SC_SECURITY.encryptionLevel.value = gcc.EncryptionLevel.ENCRYPTION_LEVEL_HIGH self._serverSettings.SC_SECURITY.encryptionLevel.value = gcc.EncryptionLevel.ENCRYPTION_LEVEL_HIGH
self._serverSettings.SC_SECURITY.serverRandom.value = rsa.random(256) self._serverSettings.SC_SECURITY.serverRandom.value = rsa.random(256)
self._serverSettings.SC_SECURITY.serverCertificate = self._presentation.getCertificate() self._serverSettings.SC_SECURITY.serverCertificate = self._presentation.getCertificate()
self._serverSettings.SC_CORE.clientRequestedProtocol.value = self._transport._requestedProtocol self._serverSettings.SC_CORE.clientRequestedProtocol.value = self._transport._requestedProtocol
self.setNextState(self.recvConnectInitial) self.setNextState(self.recvConnectInitial)
def recvConnectInitial(self, data): def recvConnectInitial(self, data):
""" """
@summary: Receive MCS connect initial from client @summary: Receive MCS connect initial from client
@@ -507,28 +586,28 @@ class Server(MCSLayer):
ber.readApplicationTag(data, UInt8(Message.MCS_TYPE_CONNECT_INITIAL)) ber.readApplicationTag(data, UInt8(Message.MCS_TYPE_CONNECT_INITIAL))
ber.readOctetString(data) ber.readOctetString(data)
ber.readOctetString(data) ber.readOctetString(data)
if not ber.readBoolean(data): if not ber.readBoolean(data):
raise InvalidExpectedDataException("invalid expected BER boolean tag") raise InvalidExpectedDataException("invalid expected BER boolean tag")
self.readDomainParams(data) self.readDomainParams(data)
self.readDomainParams(data) self.readDomainParams(data)
self.readDomainParams(data) self.readDomainParams(data)
self._clientSettings = gcc.readConferenceCreateRequest(Stream(ber.readOctetString(data))) self._clientSettings = gcc.readConferenceCreateRequest(Stream(ber.readOctetString(data)))
if not self._clientSettings.CS_NET is None: if not self._clientSettings.CS_NET is None:
i = 1 i = 1
for channelDef in self._clientSettings.CS_NET.channelDefArray._array: for channelDef in self._clientSettings.CS_NET.channelDefArray._array:
self._serverSettings.SC_NET.channelIdArray._array.append(UInt16Le(i + Channel.MCS_GLOBAL_CHANNEL)) self._serverSettings.SC_NET.channelIdArray._array.append(UInt16Le(i + Channel.MCS_GLOBAL_CHANNEL))
#if channel can be handle by serve add it # if channel can be handle by serve add it
for serverChannelDef, layer in self._virtualChannels: for serverChannelDef, layer in self._virtualChannels:
if channelDef.name == serverChannelDef.name: if channelDef.name == serverChannelDef.name:
self._channels[i + Channel.MCS_GLOBAL_CHANNEL] = layer self._channels[i + Channel.MCS_GLOBAL_CHANNEL] = layer
i += 1 i += 1
self.sendConnectResponse() self.sendConnectResponse()
self.setNextState(self.recvErectDomainRequest) self.setNextState(self.recvErectDomainRequest)
def recvErectDomainRequest(self, data): def recvErectDomainRequest(self, data):
""" """
@summary: Receive erect domain request @summary: Receive erect domain request
@@ -537,15 +616,15 @@ class Server(MCSLayer):
""" """
opcode = UInt8() opcode = UInt8()
data.read_type(opcode) data.read_type(opcode)
if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.ERECT_DOMAIN_REQUEST): if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.ERECT_DOMAIN_REQUEST):
raise InvalidExpectedDataException("Invalid MCS PDU : ERECT_DOMAIN_REQUEST expected") raise InvalidExpectedDataException("Invalid MCS PDU : ERECT_DOMAIN_REQUEST expected")
per.readInteger(data) per.readInteger(data)
per.readInteger(data) per.readInteger(data)
self.setNextState(self.recvAttachUserRequest) self.setNextState(self.recvAttachUserRequest)
def recvAttachUserRequest(self, data): def recvAttachUserRequest(self, data):
""" """
@summary: Receive Attach user request @summary: Receive Attach user request
@@ -555,13 +634,13 @@ class Server(MCSLayer):
""" """
opcode = UInt8() opcode = UInt8()
data.read_type(opcode) data.read_type(opcode)
if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.ATTACH_USER_REQUEST): if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.ATTACH_USER_REQUEST):
raise InvalidExpectedDataException("Invalid MCS PDU : ATTACH_USER_REQUEST expected") raise InvalidExpectedDataException("Invalid MCS PDU : ATTACH_USER_REQUEST expected")
self.sendAttachUserConfirm() self.sendAttachUserConfirm()
self.setNextState(self.recvChannelJoinRequest) self.setNextState(self.recvChannelJoinRequest)
def recvChannelJoinRequest(self, data): def recvChannelJoinRequest(self, data):
""" """
@summary: Receive for each client channel a request @summary: Receive for each client channel a request
@@ -571,22 +650,22 @@ class Server(MCSLayer):
""" """
opcode = UInt8() opcode = UInt8()
data.read_type(opcode) data.read_type(opcode)
if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.CHANNEL_JOIN_REQUEST): if not self.readMCSPDUHeader(opcode.value, DomainMCSPDU.CHANNEL_JOIN_REQUEST):
raise InvalidExpectedDataException("Invalid MCS PDU : CHANNEL_JOIN_REQUEST expected") raise InvalidExpectedDataException("Invalid MCS PDU : CHANNEL_JOIN_REQUEST expected")
userId = per.readInteger16(data, Channel.MCS_USERCHANNEL_BASE) userId = per.readInteger16(data, Channel.MCS_USERCHANNEL_BASE)
if self._userId != userId: if self._userId != userId:
raise InvalidExpectedDataException("Invalid MCS User Id") raise InvalidExpectedDataException("Invalid MCS User Id")
channelId = per.readInteger16(data) channelId = per.readInteger16(data)
#actually algo support virtual channel but RDPY have no virtual channel # actually algo support virtual channel but RDPY have no virtual channel
confirm = 0 if channelId in self._channels.keys() or channelId == self._userId else 1 confirm = 0 if channelId in self._channels.keys() or channelId == self._userId else 1
self.sendChannelJoinConfirm(channelId, confirm) self.sendChannelJoinConfirm(channelId, confirm)
self._nbChannelConfirmed += 1 self._nbChannelConfirmed += 1
if self._nbChannelConfirmed == self._serverSettings.getBlock(gcc.MessageType.SC_NET).channelCount.value + 2: if self._nbChannelConfirmed == self._serverSettings.get_block(gcc.MessageType.SC_NET).channelCount.value + 2:
self.allChannelConnected() self.allChannelConnected()
def sendConnectResponse(self): def sendConnectResponse(self):
""" """
@summary: Send connect response @summary: Send connect response
@@ -594,27 +673,27 @@ class Server(MCSLayer):
ccReq = gcc.writeConferenceCreateResponse(self._serverSettings) ccReq = gcc.writeConferenceCreateResponse(self._serverSettings)
ccReqStream = Stream() ccReqStream = Stream()
ccReqStream.write_type(ccReq) ccReqStream.write_type(ccReq)
tmp = (ber.writeEnumerated(0), ber.writeInteger(0), self.writeDomainParams(22, 3, 0, 0xfff8), tmp = (ber.writeEnumerated(0), ber.writeInteger(0), self.writeDomainParams(22, 3, 0, 0xfff8),
ber.writeOctetstring(ccReqStream.getvalue())) ber.writeOctetstring(ccReqStream.getvalue()))
self._transport.send((ber.writeApplicationTag(Message.MCS_TYPE_CONNECT_RESPONSE, sizeof(tmp)), tmp)) self._transport.send((ber.writeApplicationTag(Message.MCS_TYPE_CONNECT_RESPONSE, sizeof(tmp)), tmp))
def sendAttachUserConfirm(self): def sendAttachUserConfirm(self):
""" """
@summary: Send attach user confirm @summary: Send attach user confirm
""" """
self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.ATTACH_USER_CONFIRM), 2), self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.ATTACH_USER_CONFIRM), 2),
per.writeEnumerates(0), per.writeEnumerates(0),
per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE))) per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE)))
def sendChannelJoinConfirm(self, channelId, confirm): def sendChannelJoinConfirm(self, channelId, confirm):
""" """
@summary: Send a confirm channel (or not) to client @summary: Send a confirm channel (or not) to client
@param channelId: {integer} id of channel @param channelId: {integer} id of channel
@param confirm: {boolean} connection state @param confirm: {boolean} connection state
""" """
self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.CHANNEL_JOIN_CONFIRM), 2), self._transport.send((self.writeMCSPDUHeader(UInt8(DomainMCSPDU.CHANNEL_JOIN_CONFIRM), 2),
per.writeEnumerates(int(confirm)), per.writeEnumerates(int(confirm)),
per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE), per.writeInteger16(self._userId, Channel.MCS_USERCHANNEL_BASE),
per.writeInteger16(channelId), per.writeInteger16(channelId),
per.writeInteger16(channelId))) per.writeInteger16(channelId)))

View File

@@ -223,12 +223,9 @@ def readNumericString(s, minValue):
length = (length + minValue + 1) / 2 length = (length + minValue + 1) / 2
s.read(length) s.read(length)
def writeNumericString(nStr, minValue): def writeNumericString(nStr, minValue):
""" """
@summary: write string in per format
@param str: python string to write
@param min: min value
@return: String type that contain str encoded in per format
""" """
length = len(nStr) length = len(nStr)
mlength = minValue mlength = minValue
@@ -238,9 +235,9 @@ def writeNumericString(nStr, minValue):
result = [] result = []
for i in range(0, length, 2): for i in range(0, length, 2):
c1 = ord(nStr[i]) c1 = nStr[i]
if i + 1 < length: if i + 1 < length:
c2 = ord(nStr[i + 1]) c2 = nStr[i + 1]
else: else:
c2 = 0x30 c2 = 0x30
c1 = (c1 - 0x30) % 10 c1 = (c1 - 0x30) % 10
@@ -248,7 +245,7 @@ def writeNumericString(nStr, minValue):
result.append(UInt8((c1 << 4) | c2)) result.append(UInt8((c1 << 4) | c2))
return (writeLength(mlength), tuple(result)) return writeLength(mlength), tuple(result)
def readPadding(s, length): def readPadding(s, length):
""" """
@@ -264,7 +261,7 @@ def writePadding(length):
@param length: length of padding @param length: length of padding
@return: String with \x00 * length @return: String with \x00 * length
""" """
return Buffer("\x00" * length) return Buffer(b"\x00" * length)
def readOctetStream(s, octetStream, minValue = 0): def readOctetStream(s, octetStream, minValue = 0):
""" """
@@ -300,6 +297,6 @@ def writeOctetStream(oStr, minValue = 0):
result = [] result = []
for i in range(0, length): for i in range(0, length):
result.append(UInt8(ord(oStr[i]))) result.append(UInt8(oStr[i]))
return (writeLength(mlength), tuple(result)) return writeLength(mlength), tuple(result)

View File

@@ -134,20 +134,15 @@ class Negotiation(CompositeType):
class X224: class X224:
""" """
@summary: x224 layer management
there is an connection automata
""" """
def __init__(self, tpkt: tpkt.Tpkt): def __init__(self, tpkt: tpkt.Tpkt, selected_protocol: int):
""" """
@param tpkt: TPKT layer
""" """
self.tpkt = tpkt self.tpkt = tpkt
self.selected_protocol = selected_protocol
async def read(self): async def read(self):
""" """
@summary: Read data header from packet
And pass to presentation layer
@param data: Stream
""" """
header = X224DataHeader() header = X224DataHeader()
payload = await self.tpkt.read() payload = await self.tpkt.read()
@@ -156,12 +151,12 @@ class X224:
async def write(self, message): async def write(self, message):
""" """
@summary: Write message packet for TPDU layer
Add TPDU header
@param message:
""" """
await self.tpkt.write((X224DataHeader(), message)) await self.tpkt.write((X224DataHeader(), message))
def get_selected_protocol(self):
return self.selected_protocol
async def connect(tpkt: tpkt.Tpkt, authentication_protocol: sspi.IAuthenticationProtocol) -> X224: async def connect(tpkt: tpkt.Tpkt, authentication_protocol: sspi.IAuthenticationProtocol) -> X224:
""" """
@@ -190,11 +185,11 @@ async def connect(tpkt: tpkt.Tpkt, authentication_protocol: sspi.IAuthentication
raise InvalidExpectedDataException("RDPY doesn't support PROTOCOL_HYBRID_EX security Layer") raise InvalidExpectedDataException("RDPY doesn't support PROTOCOL_HYBRID_EX security Layer")
if selected_protocol == Protocols.PROTOCOL_RDP: if selected_protocol == Protocols.PROTOCOL_RDP:
return X224(tpkt) return X224(tpkt, selected_protocol)
elif selected_protocol == Protocols.PROTOCOL_SSL: elif selected_protocol == Protocols.PROTOCOL_SSL:
return X224(await tpkt.start_tls()) return X224(await tpkt.start_tls(), selected_protocol)
elif selected_protocol == Protocols.PROTOCOL_HYBRID: elif selected_protocol == Protocols.PROTOCOL_HYBRID:
return X224(await tpkt.start_nla(authentication_protocol)) return X224(await tpkt.start_nla(authentication_protocol), selected_protocol)
class Server(X224): class Server(X224):

View File

@@ -354,72 +354,51 @@ class SimpleType(DynMessage):
class CompositeType(Message): class CompositeType(Message):
""" """
@summary: Type node in Type tree
Track type field declared in __init__ function
Ex: self.lengthOfPacket = UInt16Le() -> record lengthOfPacket as sub type of node
""" """
def __init__(self, read_len = None, conditional = lambda:True, optional = False, constant = False, ): def __init__(self, read_len=None, conditional=lambda: True, optional=False, constant=False):
""" """
@param conditional : Callable object
Read and Write operation depend on return of this function
@param optional: If there is no enough byte in current stream
And optional is True, read type is ignored
@param constant: Check if object value doesn't change after read operation
@param readLen: Max length in bytes can be readed from stream
Use to check length information
""" """
super().__init__(conditional=conditional, optional=optional, constant=constant) super().__init__(conditional=conditional, optional=optional, constant=constant)
# list of ordoned type # list of ordorred type
self._typeName = [] self._type_name = []
self._read_len = read_len self._read_len = read_len
def __setattr__(self, name, value): def __setattr__(self, name, value):
""" """
@summary: Track Type field
For Type field record it in same order as declared
Keep other but bot handle in read or write function
@param name: name of new attribute
@param value: value of new attribute
""" """
if name[0] != '_' and (isinstance(value, Message) or isinstance(value, tuple)) and not name in self._typeName: if name[0] != '_' and (isinstance(value, Message) or isinstance(value, tuple)) and name not in self._type_name:
self._typeName.append(name) self._type_name.append(name)
self.__dict__[name] = value self.__dict__[name] = value
def __read__(self, s): def __read__(self, s):
""" """
@summary: Read composite type
Call read on each ordered sub-type
And check read length parameter
If an error occurred rollback type already read
@param s: Stream
@raise InvalidSize: if stream is greater than readLen parameter
""" """
readLen = 0 read_len = 0
for name in self._typeName: for name in self._type_name:
try: try:
s.read_type(self.__dict__[name]) s.read_type(self.__dict__[name])
readLen += sizeof(self.__dict__[name]) read_len += sizeof(self.__dict__[name])
# read is ok but read out of bound # read is ok but read out of bound
if not self._read_len is None and readLen > self._read_len.value: if self._read_len is not None and read_len > self._read_len.value:
# roll back # roll back
s.pos -= sizeof(self.__dict__[name]) s.seek(-sizeof(self.__dict__[name]), 1)
# and notify if not optional # and notify if not optional
if not self.__dict__[name]._optional: if not self.__dict__[name]._optional:
raise InvalidSize("Impossible to read type %s : read length is too small"%(self.__class__)) raise InvalidSize("Impossible to read type %s : read length is too small"%(self.__class__))
except Exception as e: except Exception as e:
log.error("Error during read %s::%s"%(self.__class__, name)) log.error("Error during read %s::%s"%(self.__class__, name))
#roll back already read # roll back already read
for tmpName in self._typeName: for tmp_name in self._type_name:
if tmpName == name: if tmp_name == name:
break break
s.seek(-sizeof(self.__dict__[tmpName]), 1) s.seek(-sizeof(self.__dict__[tmp_name]), 1)
raise e raise e
if not self._read_len is None and readLen < self._read_len.value: if self._read_len is not None and read_len < self._read_len.value:
log.debug("Still have correct data in packet %s, read %s bytes as padding"%(self.__class__, self._read_len.value - readLen)) log.debug("Still have correct data in packet %s, read %s bytes as padding"%(self.__class__, self._read_len.value - read_len))
s.read(self._read_len.value - readLen) s.read(self._read_len.value - read_len)
def __write__(self, s): def __write__(self, s):
""" """
@@ -427,7 +406,7 @@ class CompositeType(Message):
Call write on each ordered sub type Call write on each ordered sub type
@param s: Stream @param s: Stream
""" """
for name in self._typeName: for name in self._type_name:
try: try:
s.write_type(self.__dict__[name]) s.write_type(self.__dict__[name])
except Exception as e: except Exception as e:
@@ -443,7 +422,7 @@ class CompositeType(Message):
return self._read_len.value return self._read_len.value
size = 0 size = 0
for name in self._typeName: for name in self._type_name:
size += sizeof(self.__dict__[name]) size += sizeof(self.__dict__[name])
return size return size
@@ -454,9 +433,9 @@ class CompositeType(Message):
@param other: CompositeType @param other: CompositeType
@return: True if each sub-type are equals @return: True if each sub-type are equals
""" """
if self._typeName != other._typeName: if self._type_name != other._typeName:
return False return False
for name in self._typeName: for name in self._type_name:
if self.__dict__[name] != other.__dict__[name]: if self.__dict__[name] != other.__dict__[name]:
return False return False
return True return True
@@ -840,35 +819,23 @@ class ArrayType(Message):
""" """
@summary: Factory af n element @summary: Factory af n element
""" """
def __init__(self, typeFactory, init = None, readLen = None, conditional = lambda:True, optional = False, constant = False): def __init__(self, type_factory, init=None, read_len=None, conditional=lambda:True, optional=False, constant=False):
""" """
@param typeFactory: class use to init new element on read
@param init: init array
@param readLen: number of element in sequence
@param conditional : Callable object
Read and Write operation depend on return of this function
@param optional: If there is no enough byte in current stream
And optional is True, read type is ignored
@param constant: Check if object value doesn't change after read operation
""" """
Message.__init__(self, conditional, optional, constant) super().__init__(conditional, optional, constant)
self._typeFactory = typeFactory self._type_factory = type_factory
self._readLen = readLen self._read_len = read_len
self._array = [] self._array = init or []
if not init is None:
self._array = init
def __read__(self, s): def __read__(self, s):
""" """
@summary: Create readLen new object and read it
@param s: Stream
""" """
self._array = [] self._array = []
i = 0 i = 0
#self._readLen is None means that array will be read until end of stream # self._read_len is None means that array will be read until end of stream
while self._readLen is None or i < self._readLen.value: while self._read_len is None or i < self._read_len():
element = self._typeFactory() element = self._type_factory()
element._optional = self._readLen is None element._optional = self._read_len is None
s.read_type(element) s.read_type(element)
if not element._is_readed: if not element._is_readed:
break break
@@ -877,8 +844,6 @@ class ArrayType(Message):
def __write__(self, s): def __write__(self, s):
""" """
@summary: Just write array
@param s: Stream
""" """
s.write_type(self._array) s.write_type(self._array)
@@ -895,6 +860,9 @@ class ArrayType(Message):
""" """
return sizeof(self._array) return sizeof(self._array)
def __len__(self):
return len(self._array)
class FactoryType(Message): class FactoryType(Message):
""" """
@summary: Call a factory callback at read or write time @summary: Call a factory callback at read or write time