Skip to content

Commit

Permalink
Add multipart support for MockMvcTester
Browse files Browse the repository at this point in the history
File uploads with MockMvc require a separate
MockHttpServletRequestBuilder implementation. This commit applies the
same change to support AssertJ on this builder, but for the multipart
version.

Any request builder can now use `multipart()` to "down cast" to a
dedicated multipart request builder that contains the settings
configured thus far.

Closes gh-33027
  • Loading branch information
snicoll committed Jun 18, 2024
1 parent f2137c9 commit d76f37c
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockMultipartHttpServletRequest;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.RequestBuilder;
import org.springframework.test.web.servlet.request.AbstractMockHttpServletRequestBuilder;
import org.springframework.test.web.servlet.request.AbstractMockMultipartHttpServletRequestBuilder;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.setup.DefaultMockMvcBuilder;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
Expand Down Expand Up @@ -389,8 +391,42 @@ private GenericHttpMessageConverter<Object> findJsonMessageConverter(
public final class MockMvcRequestBuilder extends AbstractMockHttpServletRequestBuilder<MockMvcRequestBuilder>
implements AssertProvider<MvcTestResultAssert> {

private final HttpMethod httpMethod;

private MockMvcRequestBuilder(HttpMethod httpMethod) {
super(httpMethod);
this.httpMethod = httpMethod;
}

/**
* Enable file upload support using multipart.
* @return a {@link MockMultipartMvcRequestBuilder} with the settings
* configured thus far
*/
public MockMultipartMvcRequestBuilder multipart() {
return new MockMultipartMvcRequestBuilder(this);
}

public MvcTestResult exchange() {
return perform(this);
}

@Override
public MvcTestResultAssert assertThat() {
return new MvcTestResultAssert(exchange(), MockMvcTester.this.jsonMessageConverter);
}
}

/**
* A builder for {@link MockMultipartHttpServletRequest} that supports AssertJ.
*/
public final class MockMultipartMvcRequestBuilder
extends AbstractMockMultipartHttpServletRequestBuilder<MockMultipartMvcRequestBuilder>
implements AssertProvider<MvcTestResultAssert> {

private MockMultipartMvcRequestBuilder(MockMvcRequestBuilder currentBuilder) {
super(currentBuilder.httpMethod);
merge(currentBuilder);
}

public MvcTestResult exchange() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.test.web.servlet.request;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import jakarta.servlet.ServletContext;
import jakarta.servlet.http.Part;

import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.mock.web.MockMultipartHttpServletRequest;
import org.springframework.util.Assert;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

/**
* Base builder for {@link MockMultipartHttpServletRequest}.
*
* @author Rossen Stoyanchev
* @author Arjen Poutsma
* @author Stephane Nicoll
* @since 6.2
* @param <B> a self reference to the builder type
*/
public abstract class AbstractMockMultipartHttpServletRequestBuilder<B extends AbstractMockMultipartHttpServletRequestBuilder<B>>
extends AbstractMockHttpServletRequestBuilder<B> {

private final List<MockMultipartFile> files = new ArrayList<>();

private final MultiValueMap<String, Part> parts = new LinkedMultiValueMap<>();


protected AbstractMockMultipartHttpServletRequestBuilder(HttpMethod httpMethod) {
super(httpMethod);
}

/**
* Add a new {@link MockMultipartFile} with the given content.
* @param name the name of the file
* @param content the content of the file
*/
public B file(String name, byte[] content) {
this.files.add(new MockMultipartFile(name, content));
return self();
}

/**
* Add the given {@link MockMultipartFile}.
* @param file the multipart file
*/
public B file(MockMultipartFile file) {
this.files.add(file);
return self();
}

/**
* Add {@link Part} components to the request.
* @param parts one or more parts to add
* @since 5.0
*/
public B part(Part... parts) {
Assert.notEmpty(parts, "'parts' must not be empty");
for (Part part : parts) {
this.parts.add(part.getName(), part);
}
return self();
}

@Override
public Object merge(@Nullable Object parent) {
if (parent == null) {
return this;
}
if (parent instanceof AbstractMockHttpServletRequestBuilder<?>) {
super.merge(parent);
if (parent instanceof AbstractMockMultipartHttpServletRequestBuilder<?> parentBuilder) {
this.files.addAll(parentBuilder.files);
parentBuilder.parts.keySet().forEach(name ->
this.parts.putIfAbsent(name, parentBuilder.parts.get(name)));
}
}
else {
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
}
return this;
}

/**
* Create a new {@link MockMultipartHttpServletRequest} based on the
* supplied {@code ServletContext} and the {@code MockMultipartFiles}
* added to this builder.
*/
@Override
protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext);
Charset defaultCharset = (request.getCharacterEncoding() != null ?
Charset.forName(request.getCharacterEncoding()) : StandardCharsets.UTF_8);

this.files.forEach(request::addFile);
this.parts.values().stream().flatMap(Collection::stream).forEach(part -> {
request.addPart(part);
try {
String name = part.getName();
String filename = part.getSubmittedFileName();
InputStream is = part.getInputStream();
if (filename != null) {
request.addFile(new MockMultipartFile(name, filename, part.getContentType(), is));
}
else {
InputStreamReader reader = new InputStreamReader(is, getCharsetOrDefault(part, defaultCharset));
String value = FileCopyUtils.copyToString(reader);
request.addParameter(part.getName(), value);
}
}
catch (IOException ex) {
throw new IllegalStateException("Failed to read content for part " + part.getName(), ex);
}
});

return request;
}

private Charset getCharsetOrDefault(Part part, Charset defaultCharset) {
if (part.getContentType() != null) {
MediaType mediaType = MediaType.parseMediaType(part.getContentType());
if (mediaType.getCharset() != null) {
return mediaType.getCharset();
}
}
return defaultCharset;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,42 +16,22 @@

package org.springframework.test.web.servlet.request;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import jakarta.servlet.ServletContext;
import jakarta.servlet.http.Part;

import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.mock.web.MockMultipartHttpServletRequest;
import org.springframework.util.Assert;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

/**
* Default builder for {@link MockMultipartHttpServletRequest}.
*
* @author Rossen Stoyanchev
* @author Arjen Poutsma
* @author Stephane Nicoll
* @since 3.2
*/
public class MockMultipartHttpServletRequestBuilder extends AbstractMockHttpServletRequestBuilder<MockMultipartHttpServletRequestBuilder> {

private final List<MockMultipartFile> files = new ArrayList<>();

private final MultiValueMap<String, Part> parts = new LinkedMultiValueMap<>();
public class MockMultipartHttpServletRequestBuilder
extends AbstractMockMultipartHttpServletRequestBuilder<MockMultipartHttpServletRequestBuilder> {


/**
Expand Down Expand Up @@ -98,101 +78,4 @@ public class MockMultipartHttpServletRequestBuilder extends AbstractMockHttpServ
super.contentType(MediaType.MULTIPART_FORM_DATA);
}


/**
* Add a new {@link MockMultipartFile} with the given content.
* @param name the name of the file
* @param content the content of the file
*/
public MockMultipartHttpServletRequestBuilder file(String name, byte[] content) {
this.files.add(new MockMultipartFile(name, content));
return this;
}

/**
* Add the given {@link MockMultipartFile}.
* @param file the multipart file
*/
public MockMultipartHttpServletRequestBuilder file(MockMultipartFile file) {
this.files.add(file);
return this;
}

/**
* Add {@link Part} components to the request.
* @param parts one or more parts to add
* @since 5.0
*/
public MockMultipartHttpServletRequestBuilder part(Part... parts) {
Assert.notEmpty(parts, "'parts' must not be empty");
for (Part part : parts) {
this.parts.add(part.getName(), part);
}
return this;
}

@Override
public Object merge(@Nullable Object parent) {
if (parent == null) {
return this;
}
if (parent instanceof AbstractMockHttpServletRequestBuilder) {
super.merge(parent);
if (parent instanceof MockMultipartHttpServletRequestBuilder parentBuilder) {
this.files.addAll(parentBuilder.files);
parentBuilder.parts.keySet().forEach(name ->
this.parts.putIfAbsent(name, parentBuilder.parts.get(name)));
}
}
else {
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
}
return this;
}

/**
* Create a new {@link MockMultipartHttpServletRequest} based on the
* supplied {@code ServletContext} and the {@code MockMultipartFiles}
* added to this builder.
*/
@Override
protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext);
Charset defaultCharset = (request.getCharacterEncoding() != null ?
Charset.forName(request.getCharacterEncoding()) : StandardCharsets.UTF_8);

this.files.forEach(request::addFile);
this.parts.values().stream().flatMap(Collection::stream).forEach(part -> {
request.addPart(part);
try {
String name = part.getName();
String filename = part.getSubmittedFileName();
InputStream is = part.getInputStream();
if (filename != null) {
request.addFile(new MockMultipartFile(name, filename, part.getContentType(), is));
}
else {
InputStreamReader reader = new InputStreamReader(is, getCharsetOrDefault(part, defaultCharset));
String value = FileCopyUtils.copyToString(reader);
request.addParameter(part.getName(), value);
}
}
catch (IOException ex) {
throw new IllegalStateException("Failed to read content for part " + part.getName(), ex);
}
});

return request;
}

private Charset getCharsetOrDefault(Part part, Charset defaultCharset) {
if (part.getContentType() != null) {
MediaType mediaType = MediaType.parseMediaType(part.getContentType());
if (mediaType.getCharset() != null) {
return mediaType.getCharset();
}
}
return defaultCharset;
}

}
Loading

0 comments on commit d76f37c

Please sign in to comment.