nxdt_host.py: add CLI mode.

This commit is contained in:
Pablo Curiel 2021-06-03 20:19:19 -04:00
parent c0e82b3686
commit 5bb5f0c858

View file

@ -31,6 +31,7 @@
# Under MacOS, use `brew install libusb` to install libusb via Homebrew. # Under MacOS, use `brew install libusb` to install libusb via Homebrew.
# Under Linux, you should be good to go from the start. If not, just use the package manager from your distro to install libusb. # Under Linux, you should be good to go from the start. If not, just use the package manager from your distro to install libusb.
import sys
import os import os
import platform import platform
import threading import threading
@ -39,6 +40,7 @@ import logging
import queue import queue
import shutil import shutil
import time import time
import datetime
import struct import struct
import usb.core import usb.core
import usb.util import usb.util
@ -52,6 +54,8 @@ from tqdm import tqdm
import base64 import base64
from argparse import ArgumentParser
# Scaling factors. # Scaling factors.
WINDOWS_SCALING_FACTOR = 96.0 WINDOWS_SCALING_FACTOR = 96.0
SCALE = 1.0 SCALE = 1.0
@ -61,7 +65,7 @@ WINDOW_WIDTH = 500
WINDOW_HEIGHT = 470 WINDOW_HEIGHT = 470
# Application version. # Application version.
APP_VERSION = '0.2' APP_VERSION = '0.3'
# Copyright year. # Copyright year.
COPYRIGHT_YEAR = '2021' COPYRIGHT_YEAR = '2021'
@ -114,6 +118,16 @@ USB_STATUS_UNSUPPORTED_ABI_VERSION = 6
USB_STATUS_MALFORMED_CMD = 7 USB_STATUS_MALFORMED_CMD = 7
USB_STATUS_HOST_IO_ERROR = 8 USB_STATUS_HOST_IO_ERROR = 8
# Script title.
SCRIPT_TITLE = "{} host script v{}".format(USB_DEV_PRODUCT, APP_VERSION)
# Copyright text.
now = datetime.datetime.now()
cur_year = now.year
COPYRIGHT_TEXT = "Copyright (c) {}".format(COPYRIGHT_YEAR)
if cur_year > int(COPYRIGHT_YEAR): COPYRIGHT_TEXT += "-{}".format(cur_year)
COPYRIGHT_TEXT += ", {}".format(USB_DEV_MANUFACTURER)
# Messages displayed as labels. # Messages displayed as labels.
SERVER_START_MSG = 'Please connect a Nintendo Switch console running {}.'.format(USB_DEV_PRODUCT) SERVER_START_MSG = 'Please connect a Nintendo Switch console running {}.'.format(USB_DEV_PRODUCT)
SERVER_STOP_MSG = 'Exit {} on your console or disconnect it at any time to stop the server.'.format(USB_DEV_PRODUCT) SERVER_STOP_MSG = 'Exit {} on your console or disconnect it at any time to stop the server.'.format(USB_DEV_PRODUCT)
@ -289,13 +303,17 @@ class LogQueueHandler(logging.Handler):
self.log_queue = log_queue self.log_queue = log_queue
def emit(self, record): def emit(self, record):
if g_cliMode:
msg = self.format(record)
print(msg)
else:
self.log_queue.put(record) self.log_queue.put(record)
# Reference: https://beenje.github.io/blog/posts/logging-to-a-tkinter-scrolledtext-widget. # Reference: https://beenje.github.io/blog/posts/logging-to-a-tkinter-scrolledtext-widget.
class LogConsole: class LogConsole:
def __init__(self, scrolled_text): def __init__(self, scrolled_text=None):
self.scrolled_text = scrolled_text self.scrolled_text = scrolled_text
self.frame = self.scrolled_text.winfo_toplevel() self.frame = (self.scrolled_text.winfo_toplevel() if self.scrolled_text else None)
# Create a logging handler using a queue. # Create a logging handler using a queue.
self.log_queue = queue.Queue() self.log_queue = queue.Queue()
@ -303,13 +321,14 @@ class LogConsole:
#formatter = logging.Formatter('[%(asctime)s] -> %(message)s') #formatter = logging.Formatter('[%(asctime)s] -> %(message)s')
formatter = logging.Formatter('%(message)s') formatter = logging.Formatter('%(message)s')
self.queue_handler.setFormatter(formatter) self.queue_handler.setFormatter(formatter)
g_Logger.addHandler(self.queue_handler) g_logger.addHandler(self.queue_handler)
# Start polling messages from the queue. # Start polling messages from the queue.
self.frame.after(100, self.poll_log_queue) if self.frame: self.frame.after(100, self.poll_log_queue)
def display(self, record): def display(self, record):
msg = self.queue_handler.format(record) msg = self.queue_handler.format(record)
if self.scrolled_text:
self.scrolled_text.configure(state='normal') self.scrolled_text.configure(state='normal')
self.scrolled_text.insert(tk.END, msg + '\n', record.levelname) self.scrolled_text.insert(tk.END, msg + '\n', record.levelname)
self.scrolled_text.configure(state='disabled') self.scrolled_text.configure(state='disabled')
@ -325,15 +344,13 @@ class LogConsole:
else: else:
self.display(record) self.display(record)
self.frame.after(100, self.poll_log_queue) if self.frame: self.frame.after(100, self.poll_log_queue)
# Loosely based on tk.py from tqdm. # Loosely based on tk.py from tqdm.
class ProgressBarWindow: class ProgressBarWindow:
global g_tlb, g_taskbar global g_tlb, g_taskbar
def __init__(self, bar_format=None, tk_parent=None, window_title='', window_resize=False, window_protocol=None): def __init__(self, bar_format=None, tk_parent=None, window_title='', window_resize=False, window_protocol=None):
if tk_parent is None: raise Exception('`tk_parent` must be provided!')
self.n = 0 self.n = 0
self.total = 0 self.total = 0
self.divider = 1 self.divider = 1
@ -346,9 +363,15 @@ class ProgressBarWindow:
self.hwnd = 0 self.hwnd = 0
self.tk_parent = tk_parent self.tk_parent = tk_parent
self.tk_window = (tk.Toplevel(self.tk_parent) if self.tk_parent else None)
self.withdrawn = False
self.tk_text_var = None
self.tk_n_var = None
self.tk_pbar = None
self.tk_window = tk.Toplevel(self.tk_parent) self.pbar = None
if self.tk_window:
self.tk_window.withdraw() self.tk_window.withdraw()
self.withdrawn = True self.withdrawn = True
@ -369,7 +392,7 @@ class ProgressBarWindow:
self.tk_pbar.pack() self.tk_pbar.pack()
def __del__(self): def __del__(self):
self.tk_parent.after(0, self.tk_window.destroy) if self.tk_parent: self.tk_parent.after(0, self.tk_window.destroy)
def start(self, total, n=0, divider=1, prefix='', unit='B'): def start(self, total, n=0, divider=1, prefix='', unit='B'):
if (total <= 0) or (n < 0) or (divider < 1): raise Exception('Invalid arguments!') if (total <= 0) or (n < 0) or (divider < 1): raise Exception('Invalid arguments!')
@ -381,14 +404,18 @@ class ProgressBarWindow:
self.prefix = prefix self.prefix = prefix
self.unit = unit self.unit = unit
if self.tk_pbar:
self.tk_pbar.configure(maximum=self.total_div, mode='determinate') self.tk_pbar.configure(maximum=self.total_div, mode='determinate')
self.start_time = time.time() self.start_time = time.time()
else:
n_div = (float(self.n) / self.divider)
self.pbar = tqdm(initial=n_div, total=self.total_div, unit=self.unit, dynamic_ncols=True, desc=self.prefix, bar_format=self.bar_format)
def update(self, n): def update(self, n):
cur_n = (self.n + n) cur_n = (self.n + n)
if cur_n > self.total: return if cur_n > self.total: return
if self.tk_window:
cur_n_div = (float(cur_n) / self.divider) cur_n_div = (float(cur_n) / self.divider)
self.elapsed_time = (time.time() - self.start_time) self.elapsed_time = (time.time() - self.start_time)
@ -410,6 +437,9 @@ class ProgressBarWindow:
self.withdrawn = False self.withdrawn = False
if g_taskbar: g_taskbar.SetProgressValue(self.hwnd, cur_n, self.total) if g_taskbar: g_taskbar.SetProgressValue(self.hwnd, cur_n, self.total)
else:
n_div = (float(n) / self.divider)
self.pbar.update(n_div)
self.n = cur_n self.n = cur_n
@ -423,6 +453,7 @@ class ProgressBarWindow:
self.start_time = 0 self.start_time = 0
self.elapsed_time = 0 self.elapsed_time = 0
if self.tk_window:
if g_taskbar: if g_taskbar:
g_taskbar.SetProgressState(self.hwnd, g_tlb.TBPF_NOPROGRESS) g_taskbar.SetProgressState(self.hwnd, g_tlb.TBPF_NOPROGRESS)
g_taskbar.UnregisterTab(self.hwnd) g_taskbar.UnregisterTab(self.hwnd)
@ -433,10 +464,24 @@ class ProgressBarWindow:
self.withdrawn = True self.withdrawn = True
self.tk_pbar.configure(maximum=100, mode='indeterminate') self.tk_pbar.configure(maximum=100, mode='indeterminate')
else:
self.pbar.close()
self.pbar = None
print()
def set_prefix(self, prefix): def set_prefix(self, prefix):
self.prefix = prefix self.prefix = prefix
def utilsGetPath(path_arg, fallback_path, is_file, create=False):
path = os.path.abspath(os.path.expanduser(os.path.expandvars(path_arg if path_arg else fallback_path)))
if not is_file and create: os.makedirs(path, exist_ok=True)
if not os.path.exists(path) or (is_file and os.path.isdir(path)) or (not is_file and os.path.isfile(path)):
raise Exception("Error: '%s' points to an invalid file/directory." % (path))
return path
def utilsIsValueAlignedToEndpointPacketSize(value): def utilsIsValueAlignedToEndpointPacketSize(value):
return bool((value & (g_usbEpMaxPacketSize - 1)) == 0) return bool((value & (g_usbEpMaxPacketSize - 1)) == 0)
@ -473,9 +518,11 @@ def usbGetDeviceEndpoints():
usb_ep_out_lambda = lambda ep: usb.util.endpoint_direction(ep.bEndpointAddress) == usb.util.ENDPOINT_OUT usb_ep_out_lambda = lambda ep: usb.util.endpoint_direction(ep.bEndpointAddress) == usb.util.ENDPOINT_OUT
usb_version = 0 usb_version = 0
if g_cliMode: g_logger.info('Please connect a Nintendo Switch console running nxdumptool.')
while True: while True:
# Check if the user decided to stop the server. # Check if the user decided to stop the server.
if g_stopEvent.is_set(): if not g_cliMode and g_stopEvent.is_set():
g_stopEvent.clear() g_stopEvent.clear()
return False return False
@ -491,7 +538,7 @@ def usbGetDeviceEndpoints():
# Check if the product and manufacturer strings match the ones used by nxdumptool. # Check if the product and manufacturer strings match the ones used by nxdumptool.
#if (cur_dev.manufacturer != USB_DEV_MANUFACTURER) or (cur_dev.product != USB_DEV_PRODUCT): #if (cur_dev.manufacturer != USB_DEV_MANUFACTURER) or (cur_dev.product != USB_DEV_PRODUCT):
if cur_dev.manufacturer != USB_DEV_MANUFACTURER: if cur_dev.manufacturer != USB_DEV_MANUFACTURER:
g_Logger.error('Invalid manufacturer/product strings! (bus %u, address %u).' % (cur_dev.bus, cur_dev.address)) g_logger.error('Invalid manufacturer/product strings! (bus %u, address %u).' % (cur_dev.bus, cur_dev.address))
time.sleep(0.1) time.sleep(0.1)
continue continue
@ -510,7 +557,7 @@ def usbGetDeviceEndpoints():
g_usbEpOut = usb.util.find_descriptor(intf, custom_match=usb_ep_out_lambda) g_usbEpOut = usb.util.find_descriptor(intf, custom_match=usb_ep_out_lambda)
if (g_usbEpIn is None) or (g_usbEpOut is None): if (g_usbEpIn is None) or (g_usbEpOut is None):
g_Logger.error('Invalid endpoint addresses! (bus %u, address %u).' % (cur_dev.bus, cur_dev.address)) g_logger.error('Invalid endpoint addresses! (bus %u, address %u).' % (cur_dev.bus, cur_dev.address))
time.sleep(0.1) time.sleep(0.1)
continue continue
@ -520,8 +567,10 @@ def usbGetDeviceEndpoints():
break break
g_Logger.debug('Successfully retrieved USB endpoints! (bus %u, address %u).' % (cur_dev.bus, cur_dev.address)) g_logger.debug('Successfully retrieved USB endpoints! (bus %u, address %u).' % (cur_dev.bus, cur_dev.address))
g_Logger.debug('Max packet size: 0x%X (USB %u.%u).\n' % (g_usbEpMaxPacketSize, usb_version >> 8, (usb_version & 0xFF) >> 4)) g_logger.debug('Max packet size: 0x%X (USB %u.%u).\n' % (g_usbEpMaxPacketSize, usb_version >> 8, (usb_version & 0xFF) >> 4))
if g_cliMode: g_logger.info('Exit nxdumptool or disconnect your console at any time to close this script.')
return True return True
@ -531,9 +580,9 @@ def usbRead(size, timeout=-1):
try: try:
# Convert read data to a bytes object for easier handling. # Convert read data to a bytes object for easier handling.
rd = bytes(g_usbEpIn.read(size, timeout)) rd = bytes(g_usbEpIn.read(size, timeout))
except: except usb.core.USBError:
traceback.print_exc() if not g_cliMode: traceback.print_exc()
g_Logger.error('USB timeout triggered or console disconnected.') g_logger.error('\nUSB timeout triggered or console disconnected.')
return rd return rd
@ -542,9 +591,9 @@ def usbWrite(data, timeout=-1):
try: try:
wr = g_usbEpOut.write(data, timeout) wr = g_usbEpOut.write(data, timeout)
except: except usb.core.USBError:
traceback.print_exc() if not g_cliMode: traceback.print_exc()
g_Logger.error('USB timeout triggered or console disconnected.') g_logger.error('\nUSB timeout triggered or console disconnected.')
return wr return wr
@ -555,18 +604,19 @@ def usbSendStatus(code):
def usbHandleStartSession(cmd_block): def usbHandleStartSession(cmd_block):
global g_nxdtVersionMajor, g_nxdtVersionMinor, g_nxdtVersionMicro, g_nxdtAbiVersion, g_nxdtGitCommit global g_nxdtVersionMajor, g_nxdtVersionMinor, g_nxdtVersionMicro, g_nxdtAbiVersion, g_nxdtGitCommit
g_Logger.debug('Received StartSession (%02X) command.' % (USB_CMD_START_SESSION)) if g_cliMode: print()
g_logger.debug('Received StartSession (%02X) command.' % (USB_CMD_START_SESSION))
# Parse command block. # Parse command block.
(g_nxdtVersionMajor, g_nxdtVersionMinor, g_nxdtVersionMicro, g_nxdtAbiVersion, g_nxdtGitCommit) = struct.unpack_from('<BBBB8s', cmd_block, 0) (g_nxdtVersionMajor, g_nxdtVersionMinor, g_nxdtVersionMicro, g_nxdtAbiVersion, g_nxdtGitCommit) = struct.unpack_from('<BBBB8s', cmd_block, 0)
g_nxdtGitCommit = g_nxdtGitCommit.decode('utf-8').strip('\x00') g_nxdtGitCommit = g_nxdtGitCommit.decode('utf-8').strip('\x00')
# Print client info. # Print client info.
g_Logger.info('Client info: nxdumptool v%u.%u.%u, ABI v%u (commit %s).\n' % (g_nxdtVersionMajor, g_nxdtVersionMinor, g_nxdtVersionMicro, g_nxdtAbiVersion, g_nxdtGitCommit)) g_logger.info('Client info: nxdumptool v%u.%u.%u, ABI v%u (commit %s).\n' % (g_nxdtVersionMajor, g_nxdtVersionMinor, g_nxdtVersionMicro, g_nxdtAbiVersion, g_nxdtGitCommit))
# Check if we support this ABI version. # Check if we support this ABI version.
if g_nxdtAbiVersion != USB_ABI_VERSION: if g_nxdtAbiVersion != USB_ABI_VERSION:
g_Logger.error('Unsupported ABI version!') g_logger.error('Unsupported ABI version!')
return USB_STATUS_UNSUPPORTED_ABI_VERSION return USB_STATUS_UNSUPPORTED_ABI_VERSION
# Return status code # Return status code
@ -575,7 +625,8 @@ def usbHandleStartSession(cmd_block):
def usbHandleSendFileProperties(cmd_block): def usbHandleSendFileProperties(cmd_block):
global g_nspTransferMode, g_nspSize, g_nspHeaderSize, g_nspRemainingSize, g_nspFile, g_nspFilePath, g_outputDir, g_tkRoot, g_progressBarWindow global g_nspTransferMode, g_nspSize, g_nspHeaderSize, g_nspRemainingSize, g_nspFile, g_nspFilePath, g_outputDir, g_tkRoot, g_progressBarWindow
g_Logger.debug('Received SendFileProperties (%02X) command.' % (USB_CMD_SEND_FILE_PROPERTIES)) if g_cliMode and not g_nspTransferMode: print()
g_logger.debug('Received SendFileProperties (%02X) command.' % (USB_CMD_SEND_FILE_PROPERTIES))
# Parse command block. # Parse command block.
(file_size, filename_length, nsp_header_size, raw_filename) = struct.unpack_from('<QII{}s'.format(USB_FILE_PROPERTIES_MAX_NAME_LENGTH), cmd_block, 0) (file_size, filename_length, nsp_header_size, raw_filename) = struct.unpack_from('<QII{}s'.format(USB_FILE_PROPERTIES_MAX_NAME_LENGTH), cmd_block, 0)
@ -585,22 +636,22 @@ def usbHandleSendFileProperties(cmd_block):
dbg_str = ('File size: 0x%X | Filename length: 0x%X' % (file_size, filename_length)) dbg_str = ('File size: 0x%X | Filename length: 0x%X' % (file_size, filename_length))
if nsp_header_size > 0: dbg_str += (' | NSP header size: 0x%X' % (nsp_header_size)) if nsp_header_size > 0: dbg_str += (' | NSP header size: 0x%X' % (nsp_header_size))
dbg_str += '.' dbg_str += '.'
g_Logger.debug(dbg_str) g_logger.debug(dbg_str)
file_type_str = ('file' if (not g_nspTransferMode) else 'NSP file entry') file_type_str = ('file' if (not g_nspTransferMode) else 'NSP file entry')
g_Logger.info('Receiving %s: "%s".' % (file_type_str, filename)) if g_cliMode and not g_nspTransferMode: g_logger.info('Receiving %s: "%s".' % (file_type_str, filename))
# Perform integrity checks # Perform integrity checks
if (not g_nspTransferMode) and file_size and (nsp_header_size >= file_size): if (not g_nspTransferMode) and file_size and (nsp_header_size >= file_size):
g_Logger.error('NSP header size must be smaller than the full NSP size!\n') g_logger.error('NSP header size must be smaller than the full NSP size!\n')
return USB_STATUS_MALFORMED_CMD return USB_STATUS_MALFORMED_CMD
if g_nspTransferMode and nsp_header_size: if g_nspTransferMode and nsp_header_size:
g_Logger.error('Received non-zero NSP header size during NSP transfer mode!\n') g_logger.error('Received non-zero NSP header size during NSP transfer mode!\n')
return USB_STATUS_MALFORMED_CMD return USB_STATUS_MALFORMED_CMD
if (not filename_length) or (filename_length > USB_FILE_PROPERTIES_MAX_NAME_LENGTH): if (not filename_length) or (filename_length > USB_FILE_PROPERTIES_MAX_NAME_LENGTH):
g_Logger.error('Invalid filename length!\n') g_logger.error('Invalid filename length!\n')
return USB_STATUS_MALFORMED_CMD return USB_STATUS_MALFORMED_CMD
# Enable NSP transfer mode (if needed). # Enable NSP transfer mode (if needed).
@ -611,7 +662,7 @@ def usbHandleSendFileProperties(cmd_block):
g_nspRemainingSize = (file_size - nsp_header_size) g_nspRemainingSize = (file_size - nsp_header_size)
g_nspFile = None g_nspFile = None
g_nspFilePath = None g_nspFilePath = None
g_Logger.debug('NSP transfer mode enabled!\n') g_logger.debug('NSP transfer mode enabled!\n')
# Perform additional integrity checks and get a file object to work with. # Perform additional integrity checks and get a file object to work with.
if (not g_nspTransferMode) or (g_nspFile is None): if (not g_nspTransferMode) or (g_nspFile is None):
@ -627,14 +678,14 @@ def usbHandleSendFileProperties(cmd_block):
# Make sure the output filepath doesn't point to an existing directory. # Make sure the output filepath doesn't point to an existing directory.
if os.path.exists(fullpath) and (not os.path.isfile(fullpath)): if os.path.exists(fullpath) and (not os.path.isfile(fullpath)):
utilsResetNspInfo() utilsResetNspInfo()
g_Logger.error('Output filepath points to an existing directory! ("%s").\n' % (fullpath)) g_logger.error('Output filepath points to an existing directory! ("%s").\n' % (fullpath))
return USB_STATUS_HOST_IO_ERROR return USB_STATUS_HOST_IO_ERROR
# Make sure we have enough free space. # Make sure we have enough free space.
(total_space, used_space, free_space) = shutil.disk_usage(dirpath) (total_space, used_space, free_space) = shutil.disk_usage(dirpath)
if free_space <= file_size: if free_space <= file_size:
utilsResetNspInfo() utilsResetNspInfo()
g_Logger.error('Not enough free space available in output volume!\n') g_logger.error('Not enough free space available in output volume!\n')
return USB_STATUS_HOST_IO_ERROR return USB_STATUS_HOST_IO_ERROR
# Get file object. # Get file object.
@ -667,7 +718,7 @@ def usbHandleSendFileProperties(cmd_block):
usbSendStatus(USB_STATUS_SUCCESS) usbSendStatus(USB_STATUS_SUCCESS)
# Start data transfer stage. # Start data transfer stage.
g_Logger.debug('Data transfer started. Saving %s to: "%s".' % (file_type_str, fullpath)) g_logger.debug('Data transfer started. Saving %s to: "%s".' % (file_type_str, fullpath))
offset = 0 offset = 0
blksize = USB_TRANSFER_BLOCK_SIZE blksize = USB_TRANSFER_BLOCK_SIZE
@ -675,6 +726,9 @@ def usbHandleSendFileProperties(cmd_block):
# Check if we should use the progress bar window. # Check if we should use the progress bar window.
use_pbar = (((not g_nspTransferMode) and (file_size > USB_TRANSFER_THRESHOLD)) or (g_nspTransferMode and (g_nspSize > USB_TRANSFER_THRESHOLD))) use_pbar = (((not g_nspTransferMode) and (file_size > USB_TRANSFER_THRESHOLD)) or (g_nspTransferMode and (g_nspSize > USB_TRANSFER_THRESHOLD)))
if use_pbar: if use_pbar:
if g_cliMode:
prefix = ''
else:
idx = filename.rfind(os.path.sep) idx = filename.rfind(os.path.sep)
prefix_filename = (filename[idx+1:] if (idx >= 0) else filename) prefix_filename = (filename[idx+1:] if (idx >= 0) else filename)
@ -722,7 +776,7 @@ def usbHandleSendFileProperties(cmd_block):
# Read current chunk. # Read current chunk.
chunk = usbRead(rd_size, USB_TRANSFER_TIMEOUT) chunk = usbRead(rd_size, USB_TRANSFER_TIMEOUT)
if chunk is None: if chunk is None:
g_Logger.error('Failed to read 0x%X-byte long data chunk!' % (rd_size)) g_logger.error('Failed to read 0x%X-byte long data chunk!' % (rd_size))
# Cancel file transfer. # Cancel file transfer.
cancelTransfer() cancelTransfer()
@ -736,12 +790,12 @@ def usbHandleSendFileProperties(cmd_block):
if chunk_size == USB_CMD_HEADER_SIZE: if chunk_size == USB_CMD_HEADER_SIZE:
(magic, cmd_id, cmd_block_size) = struct.unpack_from('<4sII', chunk, 0) (magic, cmd_id, cmd_block_size) = struct.unpack_from('<4sII', chunk, 0)
if (magic == USB_MAGIC_WORD) and (cmd_id == USB_CMD_CANCEL_FILE_TRANSFER): if (magic == USB_MAGIC_WORD) and (cmd_id == USB_CMD_CANCEL_FILE_TRANSFER):
g_Logger.debug('\nReceived CancelFileTransfer (%02X) command.' % (USB_CMD_CANCEL_FILE_TRANSFER))
g_Logger.warning('Transfer cancelled.')
# Cancel file transfer. # Cancel file transfer.
cancelTransfer() cancelTransfer()
g_logger.debug('Received CancelFileTransfer (%02X) command.' % (USB_CMD_CANCEL_FILE_TRANSFER))
g_logger.warning('Transfer cancelled.')
# Let the command handler take care of sending the status response for us. # Let the command handler take care of sending the status response for us.
return USB_STATUS_SUCCESS return USB_STATUS_SUCCESS
@ -759,7 +813,7 @@ def usbHandleSendFileProperties(cmd_block):
if use_pbar: g_progressBarWindow.update(chunk_size) if use_pbar: g_progressBarWindow.update(chunk_size)
elapsed_time = round(time.time() - start_time) elapsed_time = round(time.time() - start_time)
g_Logger.debug('File transfer successfully completed in %s!\n' % (tqdm.format_interval(elapsed_time))) g_logger.debug('File transfer successfully completed in %s!\n' % (tqdm.format_interval(elapsed_time)))
# Close file handle (if needed). # Close file handle (if needed).
if not g_nspTransferMode: file.close() if not g_nspTransferMode: file.close()
@ -774,19 +828,19 @@ def usbHandleSendNspHeader(cmd_block):
nsp_header_size = len(cmd_block) nsp_header_size = len(cmd_block)
g_Logger.debug('Received SendNspHeader (%02X) command.' % (USB_CMD_SEND_NSP_HEADER)) g_logger.debug('Received SendNspHeader (%02X) command.' % (USB_CMD_SEND_NSP_HEADER))
# Integrity checks. # Integrity checks.
if not g_nspTransferMode: if not g_nspTransferMode:
g_Logger.error('Received NSP header out of NSP transfer mode!\n') g_logger.error('Received NSP header out of NSP transfer mode!\n')
return USB_STATUS_MALFORMED_CMD return USB_STATUS_MALFORMED_CMD
if g_nspRemainingSize: if g_nspRemainingSize:
g_Logger.error('Received NSP header before receiving all NSP data! (missing 0x%X byte[s]).\n' % (g_nspRemainingSize)) g_logger.error('Received NSP header before receiving all NSP data! (missing 0x%X byte[s]).\n' % (g_nspRemainingSize))
return USB_STATUS_MALFORMED_CMD return USB_STATUS_MALFORMED_CMD
if nsp_header_size != g_nspHeaderSize: if nsp_header_size != g_nspHeaderSize:
g_Logger.error('NSP header size mismatch! (0x%X != 0x%X).\n' % (nsp_header_size, g_nspHeaderSize)) g_logger.error('NSP header size mismatch! (0x%X != 0x%X).\n' % (nsp_header_size, g_nspHeaderSize))
return USB_STATUS_MALFORMED_CMD return USB_STATUS_MALFORMED_CMD
# Write NSP header. # Write NSP header.
@ -794,7 +848,7 @@ def usbHandleSendNspHeader(cmd_block):
g_nspFile.write(cmd_block) g_nspFile.write(cmd_block)
g_nspFile.close() g_nspFile.close()
g_Logger.debug('Successfully wrote 0x%X-byte long NSP header to "%s".\n' % (nsp_header_size, g_nspFilePath)) g_logger.debug('Successfully wrote 0x%X-byte long NSP header to "%s".\n' % (nsp_header_size, g_nspFilePath))
# Disable NSP transfer mode. # Disable NSP transfer mode.
utilsResetNspInfo() utilsResetNspInfo()
@ -802,7 +856,7 @@ def usbHandleSendNspHeader(cmd_block):
return USB_STATUS_SUCCESS return USB_STATUS_SUCCESS
def usbHandleEndSession(cmd_block): def usbHandleEndSession(cmd_block):
g_Logger.debug('Received EndSession (%02X) command.' % (USB_CMD_END_SESSION)) g_logger.debug('Received EndSession (%02X) command.' % (USB_CMD_END_SESSION))
return USB_STATUS_SUCCESS return USB_STATUS_SUCCESS
def usbCommandHandler(): def usbCommandHandler():
@ -816,10 +870,12 @@ def usbCommandHandler():
# Get device endpoints. # Get device endpoints.
if not usbGetDeviceEndpoints(): if not usbGetDeviceEndpoints():
# Update UI and return. if not g_cliMode:
# Update UI.
uiToggleElements(True) uiToggleElements(True)
return return
if not g_cliMode:
# Update UI. # Update UI.
g_tkCanvas.itemconfigure(g_tkTipMessage, state='normal', text=SERVER_STOP_MSG) g_tkCanvas.itemconfigure(g_tkTipMessage, state='normal', text=SERVER_STOP_MSG)
g_tkServerButton.configure(state='disabled') g_tkServerButton.configure(state='disabled')
@ -831,7 +887,7 @@ def usbCommandHandler():
# Read command header. # Read command header.
cmd_header = usbRead(USB_CMD_HEADER_SIZE) cmd_header = usbRead(USB_CMD_HEADER_SIZE)
if (cmd_header is None) or (len(cmd_header) != USB_CMD_HEADER_SIZE): if (cmd_header is None) or (len(cmd_header) != USB_CMD_HEADER_SIZE):
g_Logger.error('Failed to read 0x%X-byte long command header!' % (USB_CMD_HEADER_SIZE)) g_logger.error('Failed to read 0x%X-byte long command header!' % (USB_CMD_HEADER_SIZE))
break break
# Parse command header. # Parse command header.
@ -849,19 +905,19 @@ def usbCommandHandler():
cmd_block = usbRead(rd_size, USB_TRANSFER_TIMEOUT) cmd_block = usbRead(rd_size, USB_TRANSFER_TIMEOUT)
if (cmd_block is None) or (len(cmd_block) != cmd_block_size): if (cmd_block is None) or (len(cmd_block) != cmd_block_size):
g_Logger.error('Failed to read 0x%X-byte long command block for command ID %02X!' % (cmd_block_size, cmd_id)) g_logger.error('Failed to read 0x%X-byte long command block for command ID %02X!' % (cmd_block_size, cmd_id))
break break
# Verify magic word. # Verify magic word.
if magic != USB_MAGIC_WORD: if magic != USB_MAGIC_WORD:
g_Logger.error('Received command header with invalid magic word!\n') g_logger.error('Received command header with invalid magic word!\n')
usbSendStatus(USB_STATUS_INVALID_MAGIC_WORD) usbSendStatus(USB_STATUS_INVALID_MAGIC_WORD)
continue continue
# Get command handler function. # Get command handler function.
cmd_func = cmd_dict.get(cmd_id, None) cmd_func = cmd_dict.get(cmd_id, None)
if cmd_func is None: if cmd_func is None:
g_Logger.error('Received command header with unsupported ID %02X.\n' % (cmd_id)) g_logger.error('Received command header with unsupported ID %02X.\n' % (cmd_id))
usbSendStatus(USB_STATUS_UNSUPPORTED_CMD) usbSendStatus(USB_STATUS_UNSUPPORTED_CMD)
continue continue
@ -869,7 +925,7 @@ def usbCommandHandler():
if (cmd_id == USB_CMD_START_SESSION and cmd_block_size != USB_CMD_BLOCK_SIZE_START_SESSION) or \ if (cmd_id == USB_CMD_START_SESSION and cmd_block_size != USB_CMD_BLOCK_SIZE_START_SESSION) or \
(cmd_id == USB_CMD_SEND_FILE_PROPERTIES and cmd_block_size != USB_CMD_BLOCK_SIZE_SEND_FILE_PROPERTIES) or \ (cmd_id == USB_CMD_SEND_FILE_PROPERTIES and cmd_block_size != USB_CMD_BLOCK_SIZE_SEND_FILE_PROPERTIES) or \
(cmd_id == USB_CMD_SEND_NSP_HEADER and not cmd_block_size): (cmd_id == USB_CMD_SEND_NSP_HEADER and not cmd_block_size):
g_Logger.error('Invalid command block size for command ID %02X! (0x%X).\n' % (cmd_id, cmd_block_size)) g_logger.error('Invalid command block size for command ID %02X! (0x%X).\n' % (cmd_id, cmd_block_size))
usbSendStatus(USB_STATUS_MALFORMED_COMMAND) usbSendStatus(USB_STATUS_MALFORMED_COMMAND)
continue continue
@ -879,8 +935,9 @@ def usbCommandHandler():
if (status is None) or (not usbSendStatus(status)) or (cmd_id == USB_CMD_END_SESSION) or (status == USB_STATUS_UNSUPPORTED_ABI_VERSION): if (status is None) or (not usbSendStatus(status)) or (cmd_id == USB_CMD_END_SESSION) or (status == USB_STATUS_UNSUPPORTED_ABI_VERSION):
break break
g_Logger.info('\nStopping server.') g_logger.info('\nStopping server.')
if not g_cliMode:
# Update UI. # Update UI.
uiToggleElements(True) uiToggleElements(True)
@ -952,14 +1009,16 @@ def uiScaleMeasure(measure):
def uiInitialize(): def uiInitialize():
global SCALE global SCALE
global g_tkRoot, g_tkCanvas, g_tkDirText, g_tkChooseDirButton, g_tkServerButton, g_tkTipMessage, g_tkScrolledTextLog global g_tkRoot, g_tkCanvas, g_tkDirText, g_tkChooseDirButton, g_tkServerButton, g_tkTipMessage, g_tkScrolledTextLog
global g_tlb, g_taskbar, g_progressBarWindow global g_stopEvent, g_tlb, g_taskbar, g_progressBarWindow
# Setup thread event.
g_stopEvent = threading.Event()
# Enable high DPI scaling under Windows (if possible). # Enable high DPI scaling under Windows (if possible).
dpi_aware = False dpi_aware = False
if g_isWindowsVista: if g_isWindowsVista:
try: try:
import ctypes import ctypes
dpi_aware = (ctypes.windll.user32.SetProcessDPIAware() == 1) dpi_aware = (ctypes.windll.user32.SetProcessDPIAware() == 1)
if not dpi_aware: dpi_aware = (ctypes.windll.shcore.SetProcessDpiAwareness(1) == 0) if not dpi_aware: dpi_aware = (ctypes.windll.shcore.SetProcessDpiAwareness(1) == 0)
except: except:
@ -989,7 +1048,7 @@ def uiInitialize():
# Create root Tkinter object. # Create root Tkinter object.
g_tkRoot = tk.Tk() g_tkRoot = tk.Tk()
g_tkRoot.title("{} host app v{}".format(USB_DEV_PRODUCT, APP_VERSION)) g_tkRoot.title(SCRIPT_TITLE)
g_tkRoot.protocol('WM_DELETE_WINDOW', uiHandleExitProtocol) g_tkRoot.protocol('WM_DELETE_WINDOW', uiHandleExitProtocol)
g_tkRoot.resizable(False, False) g_tkRoot.resizable(False, False)
@ -1026,7 +1085,7 @@ def uiInitialize():
g_tkCanvas.create_text(uiScaleMeasure(60), uiScaleMeasure(30), text='Output directory:', anchor=tk.CENTER) g_tkCanvas.create_text(uiScaleMeasure(60), uiScaleMeasure(30), text='Output directory:', anchor=tk.CENTER)
g_tkDirText = tk.Text(g_tkRoot, height=1, width=45, font=font.nametofont('TkDefaultFont'), wrap='none', state='disabled', bg='#F0F0F0') g_tkDirText = tk.Text(g_tkRoot, height=1, width=45, font=font.nametofont('TkDefaultFont'), wrap='none', state='disabled', bg='#F0F0F0')
uiUpdateDirectoryField(DEFAULT_DIR) uiUpdateDirectoryField(g_outputDir)
g_tkCanvas.create_window(uiScaleMeasure(260), uiScaleMeasure(30), window=g_tkDirText, anchor=tk.CENTER) g_tkCanvas.create_window(uiScaleMeasure(260), uiScaleMeasure(30), window=g_tkDirText, anchor=tk.CENTER)
g_tkChooseDirButton = tk.Button(g_tkRoot, text='Choose', width=10, command=uiChooseDirectory) g_tkChooseDirButton = tk.Button(g_tkRoot, text='Choose', width=10, command=uiChooseDirectory)
@ -1046,7 +1105,7 @@ def uiInitialize():
g_tkScrolledTextLog.tag_config('CRITICAL', foreground='red', underline=1) g_tkScrolledTextLog.tag_config('CRITICAL', foreground='red', underline=1)
g_tkCanvas.create_window(uiScaleMeasure(WINDOW_WIDTH / 2), uiScaleMeasure(280), window=g_tkScrolledTextLog, anchor=tk.CENTER) g_tkCanvas.create_window(uiScaleMeasure(WINDOW_WIDTH / 2), uiScaleMeasure(280), window=g_tkScrolledTextLog, anchor=tk.CENTER)
g_tkCanvas.create_text(uiScaleMeasure(5), uiScaleMeasure(WINDOW_HEIGHT - 10), text="Copyright (c) {}, {}".format(COPYRIGHT_YEAR, USB_DEV_MANUFACTURER), anchor=tk.W) g_tkCanvas.create_text(uiScaleMeasure(5), uiScaleMeasure(WINDOW_HEIGHT - 10), text=COPYRIGHT_TEXT, anchor=tk.W)
# Initialize console logger. # Initialize console logger.
console = LogConsole(g_tkScrolledTextLog) console = LogConsole(g_tkScrolledTextLog)
@ -1059,22 +1118,38 @@ def uiInitialize():
g_tkRoot.lift() g_tkRoot.lift()
g_tkRoot.mainloop() g_tkRoot.mainloop()
def cliInitialize():
global g_progressBarWindow
# Initialize console logger.
console = LogConsole()
# Initialize progress bar window object.
bar_format = '{percentage:.2f}% |{bar}| {n:.2f}/{total:.2f} [{elapsed}<{remaining}, {rate_fmt}]'
g_progressBarWindow = ProgressBarWindow(bar_format)
# Print info.
g_logger.info('\n' + SCRIPT_TITLE + '. ' + COPYRIGHT_TEXT + '.')
g_logger.info('Output directory: "' + g_outputDir + '".\n')
# Start USB command handler directly.
usbCommandHandler()
def main(): def main():
global g_Logger, g_stopEvent, g_osType, g_osVersion, g_isWindows, g_isWindowsVista, g_isWindows7 global g_cliMode, g_outputDir, g_osType, g_osVersion, g_isWindows, g_isWindowsVista, g_isWindows7, g_logger
# Disable warnings. # Disable warnings.
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
# Setup logging mechanism. # Parse command line arguments.
logging.basicConfig(level=logging.INFO) parser = ArgumentParser(description=SCRIPT_TITLE + '. ' + COPYRIGHT_TEXT + '.')
g_Logger = logging.getLogger() parser.add_argument('-c', '--cli', required=False, action='store_true', help='Start the script in CLI mode.')
if len(g_Logger.handlers): parser.add_argument('-o', '--outdir', required=False, type=str, metavar='DIR', help='Path to output directory. Defaults to "' + DEFAULT_DIR + '".')
# Remove stderr output handler from logger. args = parser.parse_args()
log_stderr = g_Logger.handlers[0]
g_Logger.removeHandler(log_stderr)
# Setup thread event. # Update global flags.
g_stopEvent = threading.Event() g_cliMode = args.cli
g_outputDir = utilsGetPath(args.outdir, DEFAULT_DIR, False, True)
# Get OS information. # Get OS information.
g_osType = platform.system() g_osType = platform.system()
@ -1091,6 +1166,18 @@ def main():
g_isWindowsVista = (win_ver_major >= 6) g_isWindowsVista = (win_ver_major >= 6)
g_isWindows7 = (True if (win_ver_major > 6) else (win_ver_major == 6 and win_ver_minor > 0)) g_isWindows7 = (True if (win_ver_major > 6) else (win_ver_major == 6 and win_ver_minor > 0))
# Setup logging mechanism.
logging.basicConfig(level=logging.INFO)
g_logger = logging.getLogger()
if len(g_logger.handlers):
# Remove stderr output handler from logger.
log_stderr = g_logger.handlers[0]
g_logger.removeHandler(log_stderr)
if g_cliMode:
# Initialize CLI.
cliInitialize()
else:
# Initialize UI. # Initialize UI.
uiInitialize() uiInitialize()
@ -1098,4 +1185,9 @@ if __name__ == "__main__":
try: try:
main() main()
except KeyboardInterrupt: except KeyboardInterrupt:
pass if g_cliMode:
print('\nScript interrupted.')
try:
sys.exit(0)
except SystemExit:
os._exit(0)