diff --git a/swh/loader/cvs/cvsclient.py b/swh/loader/cvs/cvsclient.py --- a/swh/loader/cvs/cvsclient.py +++ b/swh/loader/cvs/cvsclient.py @@ -7,11 +7,11 @@ """ +import os.path +import re import socket import subprocess -import os.path import tempfile -import re from swh.loader.exception import NotFound @@ -20,12 +20,33 @@ EXAMPLE_PSERVER_URL = "pserver://user:password@cvs.example.com/cvsroot/repository" EXAMPLE_SSH_URL = "ssh://user@cvs.example.com/cvsroot/repository" -VALID_RESPONSES = ["ok", "error", "Valid-requests", "Checked-in", - "New-entry", "Checksum", "Copy-file", "Updated", "Created", - "Update-existing", "Merged", "Patched", "Rcs-diff", "Mode", - "Removed", "Remove-entry", "Template", "Notified", - "Module-expansion", "Wrapper-rcsOption", "M", "Mbinary", - "E", "F", "MT"] +VALID_RESPONSES = [ + "ok", + "error", + "Valid-requests", + "Checked-in", + "New-entry", + "Checksum", + "Copy-file", + "Updated", + "Created", + "Update-existing", + "Merged", + "Patched", + "Rcs-diff", + "Mode", + "Removed", + "Remove-entry", + "Template", + "Notified", + "Module-expansion", + "Wrapper-rcsOption", + "M", + "Mbinary", + "E", + "F", + "MT", +] # Trivially encode strings to protect them from innocent eyes (i.e., # inadvertent password compromises, like a network administrator @@ -36,7 +57,7 @@ def scramble_password(password): - s = ['A'] # scramble scheme version number + s = ["A"] # scramble scheme version number # fmt: off scramble_shifts = [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, # noqa: E241 @@ -57,7 +78,7 @@ 243,233,253,240,194,250,191,155,142,137,245,235,163,242,178,152] # noqa: E241,E131,E501 # fmt: on for c in password: - s.append('%c' % scramble_shifts[ord(c)]) + s.append("%c" % scramble_shifts[ord(c)]) return "".join(s) @@ -65,23 +86,26 @@ pass -_re_kb_opt = re.compile(b'\/-kb\/') # noqa: W605 +_re_kb_opt = re.compile(b"\/-kb\/") # noqa: W605 class CVSClient: - def connect_pserver(self, hostname, port, auth): if port is None: port = CVS_PSERVER_PORT if auth is None: - raise NotFound("Username and password are required for " - "a pserver connection: %s" % EXAMPLE_PSERVER_URL) + raise NotFound( + "Username and password are required for " + "a pserver connection: %s" % EXAMPLE_PSERVER_URL + ) try: - user = auth.split(':')[0] - password = auth.split(':')[1] + user = auth.split(":")[0] + password = auth.split(":")[1] except IndexError: - raise NotFound("Username and password are required for " - "a pserver connection: %s" % EXAMPLE_PSERVER_URL) + raise NotFound( + "Username and password are required for " + "a pserver connection: %s" % EXAMPLE_PSERVER_URL + ) try: self.socket = socket.create_connection((hostname, port)) @@ -89,59 +113,65 @@ raise NotFound("Could not connect to %s:%s", hostname, port) scrambled_password = scramble_password(password) - request = "BEGIN AUTH REQUEST\n%s\n%s\n%s\nEND AUTH REQUEST\n" \ - % (self.cvsroot_path, user, scrambled_password) + request = "BEGIN AUTH REQUEST\n%s\n%s\n%s\nEND AUTH REQUEST\n" % ( + self.cvsroot_path, + user, + scrambled_password, + ) print("Request: %s\n" % request) - self.socket.sendall(request.encode('UTF-8')) + self.socket.sendall(request.encode("UTF-8")) response = self.conn_read_line() if response != b"I LOVE YOU\n": - raise NotFound("pserver authentication failed for %s:%s: %s" % - (hostname, port, response)) + raise NotFound( + "pserver authentication failed for %s:%s: %s" + % (hostname, port, response) + ) def connect_ssh(self, hostname, port, auth): - command = ['ssh'] + command = ["ssh"] if auth is not None: # Assume 'auth' contains only a user name. # We do not support password authentication with SSH since the # anoncvs user is usually granted access without a password. - command += ['-l' , '%s' % auth] + command += ["-l", "%s" % auth] if port is not None: - command += ['-p' , '%d' % port] + command += ["-p", "%d" % port] # accept new SSH hosts keys upon first use; changed host keys # will require intervention - command += ['-o', "StrictHostKeyChecking=accept-new"] + command += ["-o", "StrictHostKeyChecking=accept-new"] # disable interactive prompting - command += ['-o', "BatchMode=yes"] + command += ["-o", "BatchMode=yes"] # disable further option processing by adding '--' - command += ['--'] + command += ["--"] - command += ['%s' % hostname, 'cvs', 'server'] + command += ["%s" % hostname, "cvs", "server"] # use non-buffered I/O to match behaviour of self.socket - self.ssh = subprocess.Popen(command, - bufsize=0, - stdin=subprocess.PIPE, stdout=subprocess.PIPE) + self.ssh = subprocess.Popen( + command, bufsize=0, stdin=subprocess.PIPE, stdout=subprocess.PIPE + ) def connect_fake(self, hostname, port, auth): - command = ['cvs', 'server'] + command = ["cvs", "server"] # use non-buffered I/O to match behaviour of self.socket - self.ssh = subprocess.Popen(command, - bufsize=0, - stdin=subprocess.PIPE, stdout=subprocess.PIPE) + self.ssh = subprocess.Popen( + command, bufsize=0, stdin=subprocess.PIPE, stdout=subprocess.PIPE + ) def conn_read_line(self, require_newline=True): if len(self.linebuffer) != 0: return self.linebuffer.pop(0) - buf = b'' + buf = b"" idx = -1 while idx == -1: if len(buf) >= CVS_PROTOCOL_BUFFER_SIZE: if require_newline: - raise CVSProtocolError("Overlong response from " - "CVS server: %s" % buf) + raise CVSProtocolError( + "Overlong response from " "CVS server: %s" % buf + ) else: break if self.socket: @@ -152,9 +182,9 @@ raise Exception("No valid connection") if not buf: return None - idx = buf.rfind(b'\n') + idx = buf.rfind(b"\n") if idx != -1: - self.linebuffer = buf[:idx + 1].splitlines(keepends=True) + self.linebuffer = buf[: idx + 1].splitlines(keepends=True) else: if require_newline: raise CVSProtocolError("Invalid response from CVS server: %s" % buf) @@ -163,9 +193,9 @@ if len(self.incomplete_line) > 0: self.linebuffer[0] = self.incomplete_line + self.linebuffer[0] if idx != -1: - self.incomplete_line = buf[idx + 1:] + self.incomplete_line = buf[idx + 1 :] else: - self.incomplete_line = b'' + self.incomplete_line = b"" return self.linebuffer.pop(0) def conn_write(self, data): @@ -177,7 +207,7 @@ raise Exception("No valid connection") def conn_write_str(self, s): - return self.conn_write(s.encode('UTF-8')) + return self.conn_write(s.encode("UTF-8")) def conn_close(self): if self.socket: @@ -187,8 +217,9 @@ try: self.ssh.wait(timeout=10) except subprocess.TimeoutExpired as e: - raise subprocess.TimeoutExpired("Could not terminate " - "ssh program: %s" % e) + raise subprocess.TimeoutExpired( + "Could not terminate " "ssh program: %s" % e + ) def __init__(self, url): """ @@ -201,13 +232,13 @@ self.socket = None self.ssh = None self.linebuffer = list() - self.incomplete_line = b'' + self.incomplete_line = b"" - if url.scheme == 'pserver': + if url.scheme == "pserver": self.connect_pserver(url.host, url.port, url.auth) - elif url.scheme == 'ssh': + elif url.scheme == "ssh": self.connect_ssh(url.host, url.port, url.auth) - elif url.scheme == 'fake': + elif url.scheme == "fake": self.connect_fake(url.host, url.port, url.auth) else: raise NotFound("Invalid CVS origin URL '%s'" % url) @@ -215,16 +246,18 @@ # we should have a connection now assert self.socket or self.ssh - self.conn_write_str("Root %s\nValid-responses %s\nvalid-requests\n" - "UseUnchanged\n" % - (self.cvsroot_path, ' '.join(VALID_RESPONSES))) + self.conn_write_str( + "Root %s\nValid-responses %s\nvalid-requests\n" + "UseUnchanged\n" % (self.cvsroot_path, " ".join(VALID_RESPONSES)) + ) response = self.conn_read_line() if not response: raise CVSProtocolError("No response from CVS server") try: if response[0:15] != b"Valid-requests ": - raise CVSProtocolError("Invalid response from " - "CVS server: %s" % response) + raise CVSProtocolError( + "Invalid response from " "CVS server: %s" % response + ) except IndexError: raise CVSProtocolError("Invalid response from CVS server: %s" % response) response = self.conn_read_line() @@ -239,31 +272,32 @@ expect_error = False for line in fp.readlines(): if expect_error: - raise CVSProtocolError('CVS server error: %s' % line) - if line == b'ok\n': + raise CVSProtocolError("CVS server error: %s" % line) + if line == b"ok\n": break - elif line == b'M \n': + elif line == b"M \n": continue - elif line[0:2] == b'M ': + elif line[0:2] == b"M ": rlog_output.write(line[2:]) - elif line[0:8] == b'MT text ': + elif line[0:8] == b"MT text ": rlog_output.write(line[8:-1]) - elif line[0:8] == b'MT date ': + elif line[0:8] == b"MT date ": rlog_output.write(line[8:-1]) - elif line[0:10] == b'MT newline': + elif line[0:10] == b"MT newline": rlog_output.write(line[10:]) - elif line[0:7] == b'error ': + elif line[0:7] == b"error ": expect_error = True continue else: - raise CVSProtocolError('Bad CVS protocol response: %s' % line) + raise CVSProtocolError("Bad CVS protocol response: %s" % line) rlog_output.seek(0) return rlog_output def fetch_rlog(self): fp = tempfile.TemporaryFile() - self.conn_write_str("Global_option -q\nArgument --\nArgument %s\nrlog\n" % - self.cvs_module_name) + self.conn_write_str( + "Global_option -q\nArgument --\nArgument %s\nrlog\n" % self.cvs_module_name + ) while True: response = self.conn_read_line() if response is None: @@ -286,16 +320,20 @@ if dirname: self.conn_write_str("Directory %s\n%s\n" % (dirname, dirname)) filename = os.path.basename(path) - co_output = tempfile.NamedTemporaryFile(dir=dest_dir, delete=True, - prefix='cvsclient-checkout-%s-r%s-' % - (filename, rev)) + co_output = tempfile.NamedTemporaryFile( + dir=dest_dir, + delete=True, + prefix="cvsclient-checkout-%s-r%s-" % (filename, rev), + ) # TODO: cvs <= 1.10 servers expect to be given every Directory along the path. - self.conn_write_str("Directory %s\n%s\n" - "Global_option -q\n" - "Argument -r%s\n" - "Argument -kb\n" - "Argument --\nArgument %s\nco \n" % - (self.cvs_module_name, self.cvs_module_name, rev, path)) + self.conn_write_str( + "Directory %s\n%s\n" + "Global_option -q\n" + "Argument -r%s\n" + "Argument -kb\n" + "Argument --\nArgument %s\nco \n" + % (self.cvs_module_name, self.cvs_module_name, rev, path) + ) while True: if have_bytecount and bytecount > 0: response = self.conn_read_line(require_newline=False) @@ -304,14 +342,15 @@ co_output.write(response) bytecount -= len(response) if bytecount < 0: - raise CVSProtocolError("Overlong response from " - "CVS server: %s" % response) + raise CVSProtocolError( + "Overlong response from " "CVS server: %s" % response + ) continue else: response = self.conn_read_line() - if response[0:2] == b'E ': - raise CVSProtocolError('Error from CVS server: %s' % response) - if have_bytecount and bytecount == 0 and response == b'ok\n': + if response[0:2] == b"E ": + raise CVSProtocolError("Error from CVS server: %s" % response) + if have_bytecount and bytecount == 0 and response == b"ok\n": break if skip_line: skip_line = False @@ -320,34 +359,34 @@ try: bytecount = int(response[0:-1]) # strip trailing \n except ValueError: - raise CVSProtocolError('Bad CVS protocol response: %s' % response) + raise CVSProtocolError("Bad CVS protocol response: %s" % response) have_bytecount = True continue - elif response == b'M \n': + elif response == b"M \n": continue - elif response == b'MT +updated\n': + elif response == b"MT +updated\n": continue - elif response == b'MT -updated\n': + elif response == b"MT -updated\n": continue - elif response[0:9] == b'MT fname ': + elif response[0:9] == b"MT fname ": continue - elif response[0:8] == b'Created ': + elif response[0:8] == b"Created ": skip_line = True continue - elif response[0:1] == b'/' and _re_kb_opt.search(response): + elif response[0:1] == b"/" and _re_kb_opt.search(response): expect_modeline = True continue - elif expect_modeline and response[0:2] == b'u=': + elif expect_modeline and response[0:2] == b"u=": expect_modeline = False expect_bytecount = True continue - elif response[0:2] == b'M ': + elif response[0:2] == b"M ": continue - elif response[0:8] == b'MT text ': + elif response[0:8] == b"MT text ": continue - elif response[0:10] == b'MT newline': + elif response[0:10] == b"MT newline": continue else: - raise CVSProtocolError('Bad CVS protocol response: %s' % response) + raise CVSProtocolError("Bad CVS protocol response: %s" % response) co_output.seek(0) return co_output