Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 40 additions & 35 deletions python_multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,12 @@ class MultipartState(IntEnum):
# Mask for ASCII characters that can be http tokens.
# Per RFC7230 - 3.2.6, this is all alpha-numeric characters
# and these: !#$%&'*+-.^_`|~
TOKEN_CHARS_SET = frozenset(
TOKEN_CHARS = (
b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
b"abcdefghijklmnopqrstuvwxyz"
b"0123456789"
b"!#$%&'*+-.^_`|~")
TOKEN_CHARS_SET = frozenset(TOKEN_CHARS)
# fmt: on

DEFAULT_MAX_HEADER_COUNT = 8
Expand Down Expand Up @@ -647,8 +648,7 @@ def callback(
end: An integer that is passed to the data callback.
start: An integer that is passed to the data callback.
"""
on_name = "on_" + name
func = self.callbacks.get(on_name)
func = self.callbacks.get("on_" + name)
if func is None:
return
func = cast("Callable[..., Any]", func)
Expand All @@ -657,11 +657,8 @@ def callback(
# Don't do anything if we have start == end.
if start is not None and start == end:
return

self.logger.debug("Calling %s with data[%d:%d]", on_name, start, end)
func(data, start, end)
else:
self.logger.debug("Calling %s with no data", on_name)
func()

def set_callback(self, name: CallbackName, new_func: Callable[..., Any] | None) -> None:
Expand Down Expand Up @@ -1078,6 +1075,7 @@ def write(self, data: bytes) -> int:
def _internal_write(self, data: bytes, length: int) -> int:
# Get values from locals.
boundary = self.boundary
boundary_length = len(boundary)

# Get our state, flags and index. These are persisted between calls to
# this function.
Expand Down Expand Up @@ -1128,7 +1126,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
# We need to use self.flags (and not flags) because we care about
# the state when we entered the loop.
lookbehind_len = -marked_index
if lookbehind_len <= len(boundary):
if lookbehind_len <= boundary_length:
self.callback(name, boundary, 0, lookbehind_len)
elif self.flags & FLAG_PART_BOUNDARY:
lookback = boundary + b"\r\n"
Expand Down Expand Up @@ -1173,7 +1171,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
elif state == MultipartState.START_BOUNDARY:
# Check to ensure that the last 2 characters in our boundary
# are CRLF.
if index == len(boundary) - 2:
if index == boundary_length - 2:
if c == HYPHEN:
# Potential empty message.
state = MultipartState.END_BOUNDARY
Expand All @@ -1185,7 +1183,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No

index += 1

elif index == len(boundary) - 2 + 1:
elif index == boundary_length - 1:
if c != LF:
msg = "Did not find LF at end of boundary (%d)" % (i,)
self.logger.warning(msg)
Expand Down Expand Up @@ -1247,31 +1245,38 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
i += 1
continue

# Increment our index in the header.
index += 1
# The field name runs until the colon; jump straight to it and
# validate the whole span at once instead of byte by byte.
colon = data.find(b":", i, length)
end = colon if colon != -1 else length
field = data[i:end]
if field.translate(None, TOKEN_CHARS):
Comment on lines +1252 to +1253

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Enforce header size before bulk-validating field names

For a malformed multipart part whose header field never contains : and is delivered in a large write() chunk, this slices and translates the entire remaining chunk before advance_header_size(end - i) runs below. That means the default max_header_size no longer bounds the work or temporary allocation for oversized header names; a request with megabytes of token characters in a header line will be scanned/copied in full instead of failing once the 4 KiB limit is crossed.

Useful? React with 👍 / 👎.

bad = next(b for b in field if b not in TOKEN_CHARS_SET)
bad_i = i + field.index(bad)
msg = "Found invalid character %r in header at %d" % (bad, bad_i)
self.logger.warning(msg)
raise MultipartParseError(msg, offset=bad_i)

# If we've reached a colon, we're done with this header.
if c == COLON:
advance_header_size()
index += end - i
if colon == -1:
# Field name continues into the next chunk.
advance_header_size(end - i)
i = length
else:
advance_header_size(end - i + 1)
# A 0-length header is an error.
if index == 1:
if index == 0:
msg = "Found 0-length header at %d" % (i,)
self.logger.warning(msg)
raise MultipartParseError(msg, offset=i)

# Call our callback with the header field.
i = colon
data_callback("header_field", i)

# Move to parsing the header value.
state = MultipartState.HEADER_VALUE_START

elif c not in TOKEN_CHARS_SET:
msg = "Found invalid character %r in header at %d" % (c, i)
self.logger.warning(msg)
raise MultipartParseError(msg, offset=i)
else:
advance_header_size()

elif state == MultipartState.HEADER_VALUE_START:
# Skip leading spaces.
if c == SPACE:
Expand All @@ -1287,15 +1292,19 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
i -= 1

elif state == MultipartState.HEADER_VALUE:
# If we've got a CR, we're nearly done our headers. Otherwise,
# we do nothing and just move past this character.
if c == CR:
# The value runs until the terminating CR; jump straight to it
# instead of inspecting every byte.
cr = data.find(b"\r", i, length)
end = cr if cr != -1 else length
advance_header_size(end - i)
if cr != -1:
i = cr
data_callback("header_value", i)
self.callback("header_end")
current_header_size = 0
state = MultipartState.HEADER_VALUE_ALMOST_DONE
else:
advance_header_size()
i = length

elif state == MultipartState.HEADER_VALUE_ALMOST_DONE:
# The last character should be a LF. If not, it's an error.
Expand Down Expand Up @@ -1338,17 +1347,13 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
# find part of a boundary, but it doesn't match fully.
prev_index = index

# Set up variables.
boundary_length = len(boundary)
data_length = length

# If our index is 0, we're starting a new part, so start our
# search.
if index == 0:
# The most common case is likely to be that the whole
# boundary is present in the buffer.
# Calling `find` is much faster than iterating here.
i0 = data.find(boundary, i, data_length)
i0 = data.find(boundary, i, length)
if i0 >= 0:
# We matched the whole boundary string.
index = boundary_length - 1
Expand All @@ -1360,9 +1365,9 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
# Since the length to be searched is limited to the
# boundary length, scan the tail for boundary[0] via
# bytes.find (C-level) to keep cost off the Python loop.
i = max(i, data_length - boundary_length)
j = data.find(boundary[:1], i, data_length - 1)
i = j if j >= 0 else data_length - 1
i = max(i, length - boundary_length)
j = data.find(boundary[:1], i, length - 1)
i = j if j >= 0 else length - 1

c = data[i]

Expand Down Expand Up @@ -1456,7 +1461,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
i -= 1

elif state == MultipartState.END_BOUNDARY:
if index == len(boundary) - 2 + 1:
if index == boundary_length - 1:
if c != HYPHEN:
msg = "Did not find - at end of boundary (%d)" % (i,)
self.logger.warning(msg)
Expand Down
Loading