Source code for ws4py.client

# -*- coding: utf-8 -*-
from base64 import b64encode
from hashlib import sha1
import os
import socket
import ssl

from ws4py import WS_KEY, WS_VERSION
from ws4py.exc import HandshakeError
from ws4py.websocket import WebSocket
from ws4py.compat import urlsplit, enc, dec

__all__ = ['WebSocketBaseClient']

[docs]class WebSocketBaseClient(WebSocket): def __init__(self, url, protocols, extensions): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) WebSocket.__init__(self, sock, protocols=protocols, extensions=extensions) self.stream.always_mask = True self.stream.expect_masking = False self.key = b64encode(os.urandom(16)) self.url = url self.host = None self.scheme = None self.port = None self.resource = None self._parse_url() # Adpated from: https://github.com/liris/websocket-client/blob/master/websocket.py#L105 def _parse_url(self): if ":" not in self.url: raise ValueError("Invalid URL: %s" % self.url) # Python 2.6.1 and below don't parse ws or wss urls properly. netloc is empty. # See: https://github.com/Lawouach/WebSocket-for-Python/issues/59 scheme, url = self.url.split(":", 1) parsed = urlsplit(url, scheme="http") if parsed.hostname: self.host = parsed.hostname else: raise ValueError("Invalid hostname from: %s", self.url) if parsed.port: self.port = parsed.port if scheme == "ws": if not self.port: self.port = 80 elif scheme == "wss": if not self.port: self.port = 443 else: raise ValueError("Invalid scheme: %s" % scheme) if parsed.path: resource = parsed.path else: resource = "/" if parsed.query: resource += "?" + parsed.query self.scheme = scheme self.resource = resource
[docs] def close(self, code=1000, reason=''): if not self.client_terminated: self.client_terminated = True self.sender(self.stream.close(code=code, reason=reason).single(mask=True))
[docs] def connect(self): #self.sock.settimeout(3) if self.scheme == "wss": # default port is now 443; upgrade self.sender to send ssl self.sock = ssl.wrap_socket(self.sock) self.sender = self.sock.sendall self.sock.connect((self.host, int(self.port))) self.sender(self.handshake_request) response = enc('') doubleCLRF = enc('\r\n\r\n') while True: bytes = self.sock.recv(128) if not bytes: break response += bytes if doubleCLRF in response: break if not response: self.close_connection() raise HandshakeError("Invalid response") headers, _, body = response.partition(doubleCLRF) response_line, _, headers = headers.partition(enc('\r\n')) try: self.process_response_line(response_line) self.protocols, self.extensions = self.process_handshake_header(headers) except HandshakeError: self.close_connection() raise self.handshake_ok() if body: self.process(body)
@property
[docs] def handshake_headers(self): headers = [ ('Host', self.host), ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), ('Sec-WebSocket-Key', dec(self.key)), ('Origin', self.url), ('Sec-WebSocket-Version', str(max(WS_VERSION))) ] if self.protocols: headers.append(('Sec-WebSocket-Protocol', ','.join(self.protocols))) return headers
@property
[docs] def handshake_request(self): headers = self.handshake_headers request = ["GET %s HTTP/1.1" % self.resource] for header, value in headers: request.append("%s: %s" % (header, value)) request.append('\r\n') return enc('\r\n'.join(request))
[docs] def process_response_line(self, response_line): protocol, code, status = response_line.split(enc(' '), 2) if code != enc('101'): raise HandshakeError("Invalid response status: %s %s" % (code, status))
[docs] def process_handshake_header(self, headers): protocols = [] extensions = [] headers = headers.strip() for header_line in headers.split(enc('\r\n')): header, value = header_line.split(enc(':'), 1) header = header.strip().lower() value = value.strip().lower() if header == 'upgrade' and value != 'websocket': raise HandshakeError("Invalid Upgrade header: %s" % value) elif header == 'connection' and value != 'upgrade': raise HandshakeError("Invalid Connection header: %s" % value) elif header == 'sec-websocket-accept': match = b64encode(sha1(enc(self.key + WS_KEY)).digest()) if value != match.lower(): raise HandshakeError("Invalid challenge response: %s" % value) elif header == 'sec-websocket-protocol': protocols = ','.join(value) elif header == 'sec-websocket-extensions': extensions = ','.join(value) return protocols, extensions