Files
rdpy/rdpy/protocol/rdp/mcs.py
2014-01-24 11:21:58 +01:00

291 lines
10 KiB
Python

'''
MCS layer of the RDP protocol stack (client side): connect initial /
connect response exchange, erect domain, attach user, channel joins and
dispatch of incoming channel data to the layer bound to each channel.
@author: sylvain
'''
from rdpy.network.const import ConstAttributes, TypeAttributes
from rdpy.network.layer import LayerAutomata
from rdpy.network.type import sizeof, Stream, UInt8, UInt16Be
from rdpy.network.error import InvalidExpectedDataException, InvalidValue, InvalidSize
from rdpy.protocol.rdp.ber import writeLength
import ber, gcc, per
@ConstAttributes
@TypeAttributes(UInt8)
class Message(object):
    '''
    BER application tags of the MCS connect PDUs
    '''
    #written by the client (see MCS.sendConnectInitial)
    MCS_TYPE_CONNECT_INITIAL = 0x65
    #expected from the server (see MCS.recvConnectResponse)
    MCS_TYPE_CONNECT_RESPONSE = 0x66
@ConstAttributes
@TypeAttributes(UInt8)
class DomainMCSPDU:
    '''
    Domain MCS PDU type codes.
    On the wire the code sits in the upper six bits of the opcode byte:
    opcode = (code << 2) | options (see MCS.writeMCSPDUHeader).
    '''
    #connection sequence
    ERECT_DOMAIN_REQUEST = 1
    ATTACH_USER_REQUEST = 10
    ATTACH_USER_CONFIRM = 11
    CHANNEL_JOIN_REQUEST = 14
    CHANNEL_JOIN_CONFIRM = 15
    #data transfer
    SEND_DATA_REQUEST = 25
    SEND_DATA_INDICATION = 26
    #termination
    DISCONNECT_PROVIDER_ULTIMATUM = 8
@ConstAttributes
@TypeAttributes(UInt16Be)
class Channel:
    '''
    well-known MCS channel ids
    '''
    #channel bound to the presentation layer when MCS is built
    MCS_GLOBAL_CHANNEL = 1003
    #the user channel id is this base plus the user id
    MCS_USERCHANNEL_BASE = 1001
class MCS(LayerAutomata):
    '''
    Multi Channel Service layer, the main layer of the RDP protocol.
    It runs the MCS connection sequence (connect initial / response,
    erect domain, attach user, channel joins) and then dispatches
    incoming data to the layer bound to each joined channel.
    '''
    def __init__(self, presentation):
        '''
        ctor, calls base class ctor
        @param presentation: presentation layer, bound to the global channel
        '''
        LayerAutomata.__init__(self, presentation)
        self._clientSettings = gcc.ClientSettings()
        self._serverSettings = gcc.ServerSettings()
        #default user Id, updated in place by the attach user confirm when present
        self._userId = UInt16Be(1)
        #channel id -> layer bound to it (None when no layer consumes the channel)
        self._channelIds = {Channel.MCS_GLOBAL_CHANNEL: presentation}
        #channel id -> True if the server accepted the join request, False otherwise
        self._channelIdsRequest = {}

    def connect(self):
        '''
        client mode entry point: copy the protocol selected by the transport
        layer into the GCC core settings, then send the connect initial packet
        '''
        self._clientSettings.core.serverSelectedProtocol = self._transport._selectedProtocol
        self.sendConnectInitial()

    def connectNextChannel(self):
        '''
        send a channel join request for the next channel not yet requested.
        Once every channel has been answered, switch to data mode and
        connect each layer bound to an accepted channel.
        '''
        for channelId in self._channelIds:
            #one join request at a time: wait for its confirm before the next one
            if channelId not in self._channelIdsRequest:
                self.sendChannelJoinRequest(channelId)
                self.setNextState(self.recvChannelJoinConfirm)
                return
        #every channel has been answered: the connection sequence is done
        self.setNextState(self.recvData)
        #plug this layer as transport of every accepted channel's layer
        for channelId, layer in self._channelIds.items():
            if self._channelIdsRequest[channelId] and layer is not None:
                layer._transport = self
                layer._channelId = channelId
                layer.connect()

    def sendConnectInitial(self):
        '''
        send the connect initial packet: two octet strings ("\\x01"), a
        boolean (True), three domain parameter structures and the GCC
        conference create request wrapped in an octet string, all inside
        the MCS_TYPE_CONNECT_INITIAL application tag
        '''
        ccReq = gcc.writeConferenceCreateRequest(self._clientSettings)
        ccReqStream = Stream()
        ccReqStream.writeType(ccReq)
        # NOTE(review): presumably calling/called domain selectors and upward flag (T.125) - confirm
        tmp = (ber.writeOctetstring("\x01"), ber.writeOctetstring("\x01"), ber.writeBoolean(True),
               self.writeDomainParams(34, 2, 0, 0xffff),
               self.writeDomainParams(1, 1, 1, 0x420),
               self.writeDomainParams(0xffff, 0xfc17, 0xffff, 0xffff),
               ber.writeOctetstring(ccReqStream.getvalue()))
        self._transport.send((ber.writeApplicationTag(Message.MCS_TYPE_CONNECT_INITIAL, sizeof(tmp)), tmp))
        #we must receive a connect response
        self.setNextState(self.recvConnectResponse)

    def sendErectDomainRequest(self):
        '''
        send a formatted erect domain request (header + two PER integers, both 0)
        '''
        self._transport.send((self.writeMCSPDUHeader(DomainMCSPDU.ERECT_DOMAIN_REQUEST), per.writeInteger(0), per.writeInteger(0)))

    def sendAttachUserRequest(self):
        '''
        send a formatted attach user request (header only)
        '''
        self._transport.send(self.writeMCSPDUHeader(DomainMCSPDU.ATTACH_USER_REQUEST))

    def sendChannelJoinRequest(self, channelId):
        '''
        send a formatted channel join request from client to server
        @param channelId: id of the channel to join
        '''
        self._transport.send((self.writeMCSPDUHeader(DomainMCSPDU.CHANNEL_JOIN_REQUEST), self._userId, channelId))

    def recvConnectResponse(self, data):
        '''
        receive the MCS connect response from the server, read the GCC
        conference create response it carries, then start the erect
        domain / attach user exchange
        @param data: Stream
        @raise InvalidExpectedDataException: if the GCC octet string tag is missing
        @raise InvalidSize: if the GCC length does not match the remaining data
        '''
        ber.readApplicationTag(data, Message.MCS_TYPE_CONNECT_RESPONSE)
        #skip an enumerated and an integer field (presumably result and connect id)
        ber.readEnumerated(data)
        ber.readInteger(data)
        self.readDomainParams(data)
        if not ber.readUniversalTag(data, ber.Tag.BER_TAG_OCTET_STRING, False):
            raise InvalidExpectedDataException("invalid expected ber tag")
        gccRequestLength = ber.readLength(data)
        if data.dataLen() != gccRequestLength:
            raise InvalidSize("bad size of gcc request")
        self._serverSettings = gcc.readConferenceCreateResponse(data)
        #send domain request
        self.sendErectDomainRequest()
        #send attach user request
        self.sendAttachUserRequest()
        #now wait user confirm from server
        self.setNextState(self.recvAttachUserConfirm)

    def recvAttachUserConfirm(self, data):
        '''
        receive an attach user confirm, read the user id if present and
        start joining channels
        @param data: Stream
        @raise InvalidExpectedDataException: if the PDU is not an attach user confirm
        @raise Exception: if the server rejects the user
        '''
        opcode = UInt8()
        confirm = UInt8()
        data.readType((opcode, confirm))
        if not self.readMCSPDUHeader(opcode, DomainMCSPDU.ATTACH_USER_CONFIRM):
            raise InvalidExpectedDataException("invalid MCS PDU")
        if confirm != 0:
            raise Exception("server reject user")
        #when bit 1 of the opcode is set a user id follows
        if opcode & UInt8(2) == UInt8(2):
            data.readType(self._userId)
        #the user channel id is derived from the user id; no layer consumes it
        # TODO: also register the channels accepted by the GCC sequence
        # (self._serverSettings.channelsId)
        self._channelIds[self._userId + Channel.MCS_USERCHANNEL_BASE] = None
        self.connectNextChannel()

    def recvChannelJoinConfirm(self, data):
        '''
        receive a channel join confirm from server, record whether the
        channel was accepted and go on with the next channel
        @param data: Stream
        @raise InvalidExpectedDataException: if the PDU is not a channel join confirm
        '''
        opcode = UInt8()
        confirm = UInt8()
        data.readType((opcode, confirm))
        if not self.readMCSPDUHeader(opcode, DomainMCSPDU.CHANNEL_JOIN_CONFIRM):
            raise InvalidExpectedDataException("invalid MCS PDU")
        userId = UInt16Be()
        channelId = UInt16Be()
        data.readType((userId, channelId))
        #save state of channel
        self._channelIdsRequest[channelId] = confirm == 0
        if confirm == 0:
            print("server accept channel %d" % channelId.value)
        else:
            print("server refused channel %d" % channelId.value)
        self.connectNextChannel()

    def recvData(self, data):
        '''
        main receive method once connected: close on a disconnect provider
        ultimatum, otherwise dispatch a send data indication to the layer
        bound to its channel
        @param data: Stream
        @raise InvalidExpectedDataException: on any other MCS PDU
        '''
        opcode = UInt8()
        data.readType(opcode)
        if self.readMCSPDUHeader(opcode, DomainMCSPDU.DISCONNECT_PROVIDER_ULTIMATUM):
            print("receive DISCONNECT_PROVIDER_ULTIMATUM")
            self.close()
            #not a send data indication: nothing more to parse
            return
        if not self.readMCSPDUHeader(opcode, DomainMCSPDU.SEND_DATA_INDICATION):
            raise InvalidExpectedDataException("invalid expected mcs opcode")
        userId = UInt16Be()
        channelId = UInt16Be()
        flags = UInt8()
        length = UInt8()
        data.readType((userId, channelId, flags, length))
        #high bit set: length spans two bytes, the first one holds the 7 high bits
        if length & UInt8(0x80) == UInt8(0x80):
            lengthP2 = UInt8()
            data.readType(lengthP2)
            #parenthesized: '<<' binds tighter than '&' in Python
            length = UInt16Be(((length.value & 0x7f) << 8) | lengthP2.value)
        # NOTE(review): the decoded length is not used to bound the payload handed to the layer
        #channel id doesn't match a requested layer
        if channelId not in self._channelIdsRequest:
            print("receive data for an unrequested layer")
            return
        #channel id matches a channel refused by the server
        if not self._channelIdsRequest[channelId]:
            print("receive data for an unconnected layer")
            return
        #channel joined but no layer consumes it (e.g. the user channel)
        layer = self._channelIds.get(channelId)
        if layer is None:
            print("receive data for a channel without layer")
            return
        layer.recv(data)

    def send(self, channelId, data):
        '''
        send data on a specific channel as a send data request
        @param channelId: id of the target channel (a key of self._channelIds)
        @param data: message to send
        '''
        self._transport.send((self.writeMCSPDUHeader(DomainMCSPDU.SEND_DATA_REQUEST),
                              self._userId, channelId,
                              #flags byte
                              UInt8(0x70),
                              #two bytes length with the high bit set
                              UInt16Be(sizeof(data)) | UInt16Be(0x8000),
                              data))

    def writeDomainParams(self, maxChannels, maxUsers, maxTokens, maxPduSize):
        '''
        write a domain param structure (BER sequence of 8 integers)
        used in the connection sequence; the fixed values 1, 0, 1 and 2
        fill the remaining fields of the structure
        @param maxChannels: number of mcs channels used
        @param maxUsers: number of mcs users used
        @param maxTokens: max tokens field
        @param maxPduSize: max pdu size field
        @return: domain param structure
        '''
        domainParam = (ber.writeInteger(maxChannels), ber.writeInteger(maxUsers), ber.writeInteger(maxTokens),
                       ber.writeInteger(1), ber.writeInteger(0), ber.writeInteger(1),
                       ber.writeInteger(maxPduSize), ber.writeInteger(2))
        return (ber.writeUniversalTag(ber.Tag.BER_TAG_SEQUENCE, True), writeLength(sizeof(domainParam)), domainParam)

    def writeMCSPDUHeader(self, mcsPdu, options = 0):
        '''
        write mcs pdu header
        @param mcsPdu: pdu code, stored in the upper six bits
        @param options: option bits stored in the lower two bits
        @return: UInt8
        '''
        return (mcsPdu << 2) | options

    def readMCSPDUHeader(self, opcode, mcsPdu):
        '''
        check the pdu code carried by an mcs pdu header
        @param opcode: UInt8 opcode
        @param mcsPdu: expected pdu code
        @return: True if the upper six bits of opcode equal mcsPdu
        '''
        return (opcode >> 2) == mcsPdu

    def readDomainParams(self, s):
        '''
        read a domain params structure (BER sequence of 8 integers)
        @param s: Stream
        @return: (max_channels, max_users, max_tokens, max_pdu_size)
        @raise InvalidValue: if the sequence tag is missing
        '''
        if not ber.readUniversalTag(s, ber.Tag.BER_TAG_SEQUENCE, True):
            raise InvalidValue("bad BER tags")
        ber.readLength(s)#length
        max_channels = ber.readInteger(s)
        max_users = ber.readInteger(s)
        max_tokens = ber.readInteger(s)
        #three fields not returned
        ber.readInteger(s)
        ber.readInteger(s)
        ber.readInteger(s)
        max_pdu_size = ber.readInteger(s)
        #last field not returned
        ber.readInteger(s)
        return (max_channels, max_users, max_tokens, max_pdu_size)