diff --git a/src/h2/stream.py b/src/h2/stream.py index 9b5bce78..6ff1d8dd 100644 --- a/src/h2/stream.py +++ b/src/h2/stream.py @@ -1364,15 +1364,26 @@ def _initialize_content_length(self, headers: Iterable[Header]) -> None: self._expected_content_length = 0 return + content_length = None + for n, v in headers: if n == b"content-length": try: - self._expected_content_length = int(v, 10) + parsed_content_length = int(v, 10) except ValueError as err: msg = f"Invalid content-length header: {v!r}" raise ProtocolError(msg) from err - return + if content_length is None: + content_length = parsed_content_length + elif parsed_content_length != content_length: + msg = "Conflicting content-length headers" + raise ProtocolError(msg) + + if content_length is None: + return + + self._expected_content_length = content_length def _track_content_length(self, length: int, end_stream: bool) -> None: """ diff --git a/tests/test_invalid_content_lengths.py b/tests/test_invalid_content_lengths.py index 39401ea2..430b21cf 100644 --- a/tests/test_invalid_content_lengths.py +++ b/tests/test_invalid_content_lengths.py @@ -19,18 +19,24 @@ class TestInvalidContentLengths: peer is not valid. """ - example_request_headers = [ + example_request_headers_without_content_length = [ (":authority", "example.com"), (":path", "/"), (":scheme", "https"), (":method", "POST"), + ] + example_request_headers = [ + *example_request_headers_without_content_length, ("content-length", "15"), ] - example_request_headers_bytes = [ + example_request_headers_bytes_without_content_length = [ (b":authority", b"example.com"), (b":path", b"/"), (b":scheme", b"https"), (b":method", b"POST"), + ] + example_request_headers_bytes = [ + *example_request_headers_bytes_without_content_length, (b"content-length", b"15"), ] example_response_headers = [ @@ -39,6 +45,75 @@ class TestInvalidContentLengths: ] server_config = h2.config.H2Configuration(client_side=False) + @pytest.mark.parametrize( + "request_headers", + [ + example_request_headers_without_content_length, + example_request_headers_bytes_without_content_length, + ], + ) + def test_duplicate_matching_content_lengths(self, frame_factory, request_headers) -> None: + """ + Remote peers sending duplicate matching content-length fields are + accepted. + """ + c = h2.connection.H2Connection(config=self.server_config) + c.initiate_connection() + c.receive_data(frame_factory.preamble()) + c.clear_outbound_data_buffer() + + headers = frame_factory.build_headers_frame( + headers=[ + *request_headers, + ("content-length", "15"), + ("content-length", "15"), + ], + ) + data = frame_factory.build_data_frame( + data=b"\x01"*15, + flags=["END_STREAM"], + ) + + events = c.receive_data(headers.serialize() + data.serialize()) + + assert isinstance(events[0], h2.events.RequestReceived) + assert isinstance(events[1], h2.events.DataReceived) + assert isinstance(events[2], h2.events.StreamEnded) + assert c.data_to_send() == b"" + + @pytest.mark.parametrize( + "request_headers", + [ + example_request_headers_without_content_length, + example_request_headers_bytes_without_content_length, + ], + ) + def test_duplicate_conflicting_content_lengths(self, frame_factory, request_headers) -> None: + """ + Remote peers sending duplicate conflicting content-length fields cause + Protocol Errors. + """ + c = h2.connection.H2Connection(config=self.server_config) + c.initiate_connection() + c.receive_data(frame_factory.preamble()) + c.clear_outbound_data_buffer() + + headers = frame_factory.build_headers_frame( + headers=[ + *request_headers, + ("content-length", "15"), + ("content-length", "16"), + ], + ) + with pytest.raises(h2.exceptions.ProtocolError): + c.receive_data(headers.serialize()) + + expected_frame = frame_factory.build_goaway_frame( + last_stream_id=1, + error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, + ) + assert c.data_to_send() == expected_frame.serialize() + @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) def test_too_much_data(self, frame_factory, request_headers) -> None: """