diff --git a/README.MD b/README.MD index 37f76fd..24d4d27 100644 --- a/README.MD +++ b/README.MD @@ -1,6 +1,6 @@ # Mock in Bean -[@MockInBean](src/main/java/com/teketik/test/mockinbean/MockInBean.java) and [@SpyInBean](src/main/java/com/teketik/test/mockinbean/SpyInBean.java) are alternatives to @MockBean and @SpyBean for Spring Boot tests *(>= 2.2.0 including >= 3.X.X)*. +[@MockInBean](src/main/java/com/teketik/test/mockinbean/MockInBean.java) and [@SpyInBean](src/main/java/com/teketik/test/mockinbean/SpyInBean.java) are alternatives to @MockBean and @SpyBean for Spring Boot tests *(>= 2.6.15 including >= 3.X.X)*. They surgically replace a field value in a Spring Bean by a Mock/Spy for the duration of a test and set back the original value afterwards, leaving the Spring Context clean. diff --git a/pom.xml b/pom.xml index d55844c..a9c8690 100644 --- a/pom.xml +++ b/pom.xml @@ -36,7 +36,7 @@ org.springframework.boot spring-boot-dependencies - 2.2.0.RELEASE + 2.6.15 diff --git a/src/main/java/com/teketik/test/mockinbean/BeanFieldState.java b/src/main/java/com/teketik/test/mockinbean/BeanFieldState.java index f233eaa..20d867a 100644 --- a/src/main/java/com/teketik/test/mockinbean/BeanFieldState.java +++ b/src/main/java/com/teketik/test/mockinbean/BeanFieldState.java @@ -1,16 +1,20 @@ package com.teketik.test.mockinbean; import org.springframework.test.context.TestContext; +import org.springframework.util.ReflectionUtils; import java.lang.reflect.Field; class BeanFieldState extends FieldState { - private Object bean; + final Object bean; + + final Object originalValue; public BeanFieldState(Object bean, Field field, Object originalValue, Definition definition) { - super(field, originalValue, definition); + super(field, definition); this.bean = bean; + this.originalValue = originalValue; } @Override @@ -18,4 +22,13 @@ public Object resolveTarget(TestContext testContext) { return bean; } + public void rollback(TestContext testContext) { + final Object target = resolveTarget(testContext); + ReflectionUtils.setField(field, target, originalValue); + } + + public Object createMockOrSpy() { + return definition.create(originalValue); + } + } diff --git a/src/main/java/com/teketik/test/mockinbean/BeanUtils.java b/src/main/java/com/teketik/test/mockinbean/BeanUtils.java index fea7f96..388c368 100644 --- a/src/main/java/com/teketik/test/mockinbean/BeanUtils.java +++ b/src/main/java/com/teketik/test/mockinbean/BeanUtils.java @@ -1,5 +1,7 @@ package com.teketik.test.mockinbean; +import org.springframework.aop.TargetSource; +import org.springframework.aop.framework.Advised; import org.springframework.aop.framework.AopProxyUtils; import org.springframework.aop.support.AopUtils; import org.springframework.context.ApplicationContext; @@ -44,8 +46,8 @@ static T findBean(Class type, @Nullable String name, ApplicationContext a .findFirst() .orElseThrow(() -> new IllegalArgumentException("No beans of type " + type + " and name " + name)); } - return AopUtils.isAopProxy(beanOrProxy) - ? (T) AopProxyUtils.getSingletonTarget(beanOrProxy) + return AopUtils.isAopProxy(beanOrProxy) + ? (T) AopProxyUtils.getSingletonTarget(beanOrProxy) : beanOrProxy; } @@ -96,4 +98,27 @@ static Field findField(Class clazz, @Nullable String name, Class type) { return null; } + static @Nullable TargetSource getProxyTarget(Object candidate) { + try { + while (AopUtils.isAopProxy(candidate) && candidate instanceof Advised) { + Advised advised = (Advised) candidate; + TargetSource targetSource = advised.getTargetSource(); + + if (targetSource.isStatic()) { + Object target = targetSource.getTarget(); + + if (target == null || !AopUtils.isAopProxy(target)) { + return targetSource; + } + candidate = target; + } else { + return null; + } + } + } catch (Throwable ex) { + throw new IllegalStateException("Failed to unwrap proxied object", ex); + } + return null; + } + } diff --git a/src/main/java/com/teketik/test/mockinbean/FieldState.java b/src/main/java/com/teketik/test/mockinbean/FieldState.java index bd3ccb7..1e97e94 100644 --- a/src/main/java/com/teketik/test/mockinbean/FieldState.java +++ b/src/main/java/com/teketik/test/mockinbean/FieldState.java @@ -1,6 +1,5 @@ package com.teketik.test.mockinbean; -import org.springframework.lang.Nullable; import org.springframework.test.context.TestContext; import java.lang.reflect.Field; @@ -9,14 +8,10 @@ abstract class FieldState { final Field field; - @Nullable - final Object originalValue; - final Definition definition; - public FieldState(Field targetField, Object originalValue, Definition definition) { + public FieldState(Field targetField, Definition definition) { this.field = targetField; - this.originalValue = originalValue; this.definition = definition; } diff --git a/src/main/java/com/teketik/test/mockinbean/MockInBeanTestExecutionListener.java b/src/main/java/com/teketik/test/mockinbean/MockInBeanTestExecutionListener.java index 356b433..2b99ba0 100644 --- a/src/main/java/com/teketik/test/mockinbean/MockInBeanTestExecutionListener.java +++ b/src/main/java/com/teketik/test/mockinbean/MockInBeanTestExecutionListener.java @@ -3,6 +3,7 @@ import org.junit.jupiter.api.Nested; import org.mockito.Mock; import org.mockito.Spy; +import org.springframework.aop.TargetSource; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.test.context.TestContext; import org.springframework.test.context.TestExecutionListener; @@ -59,30 +60,22 @@ public void beforeTestClass(TestContext testContext) throws Exception { for (InBeanDefinition inBeanDefinition : definitionToInbeans.getValue()) { final Object inBean = BeanUtils.findBean(inBeanDefinition.clazz, inBeanDefinition.name, testContext.getApplicationContext()); beanField = BeanUtils.findField(inBean.getClass(), definition.getName(), mockOrSpyType); + Assert.notNull(beanField, "Cannot find any field for definition:" + definitionToInbeans.getKey()); beanField.setAccessible(true); - originalValues.add( - new BeanFieldState( - inBean, - beanField, - ReflectionUtils.getField( - beanField, - inBean - ), - definition - ) - ); + final Object beanFieldValue = ReflectionUtils.getField(beanField, inBean); + final TargetSource proxyTarget = BeanUtils.getProxyTarget(beanFieldValue); + final BeanFieldState beanFieldState; + if (proxyTarget != null) { + beanFieldState = new ProxiedBeanFieldState(inBean, beanField, beanFieldValue, proxyTarget, definition); + } else { + beanFieldState = new BeanFieldState(inBean, beanField, beanFieldValue, definition); + } + originalValues.add(beanFieldState); } - Assert.notNull(beanField, "Cannot find any field for definition:" + definitionToInbeans.getKey()); Assert.isTrue(visitedFields.add(beanField), beanField + " can only be mapped once, as a mock or a spy, not both!"); final Field testField = ReflectionUtils.findField(targetTestClass, definition.getName(), mockOrSpyType); testField.setAccessible(true); - originalValues.add( - new TestFieldState( - testField, - null, - definition - ) - ); + originalValues.add(new TestFieldState(testField, definition)); } testContext.setAttribute(ORIGINAL_VALUES_ATTRIBUTE_NAME, originalValues); super.beforeTestClass(testContext); @@ -100,10 +93,13 @@ public void beforeTestMethod(TestContext testContext) throws Exception { final Map spyTracker = new IdentityHashMap<>(); //First loop to setup all the mocks and spies fieldStates + .stream() + .filter(BeanFieldState.class::isInstance) + .map(BeanFieldState.class::cast) .forEach(fieldState -> { Object mockOrSpy = mockOrSpys.get(fieldState.definition); if (mockOrSpy == null) { - mockOrSpy = fieldState.definition.create(fieldState.originalValue); + mockOrSpy = fieldState.createMockOrSpy(); mockOrSpys.put(fieldState.definition, mockOrSpy); if (fieldState.definition instanceof SpyDefinition) { spyTracker.put(fieldState.originalValue, mockOrSpy); @@ -143,15 +139,10 @@ public void afterTestClass(TestContext testContext) throws Exception { return; } ((LinkedList) testContext.getAttribute(ORIGINAL_VALUES_ATTRIBUTE_NAME)) - .forEach(fieldValue -> { - if (fieldValue.originalValue != null) { - ReflectionUtils.setField( - fieldValue.field, - fieldValue.resolveTarget(testContext), - fieldValue.originalValue - ); - } - }); + .stream() + .filter(BeanFieldState.class::isInstance) + .map(BeanFieldState.class::cast) + .forEach(fieldState -> fieldState.rollback(testContext)); ROOT_TEST_CONTEXT_TRACKER.remove(testContext.getTestClass()); super.afterTestClass(testContext); } diff --git a/src/main/java/com/teketik/test/mockinbean/ProxiedBeanFieldState.java b/src/main/java/com/teketik/test/mockinbean/ProxiedBeanFieldState.java new file mode 100644 index 0000000..8dcb16a --- /dev/null +++ b/src/main/java/com/teketik/test/mockinbean/ProxiedBeanFieldState.java @@ -0,0 +1,43 @@ +package com.teketik.test.mockinbean; + +import org.springframework.aop.TargetSource; +import org.springframework.test.context.TestContext; +import org.springframework.test.util.ReflectionTestUtils; + +import java.lang.reflect.Field; + +/** + * Special kind of {@link BeanFieldState} handling proxied beans (like aspects).
+ * The mock is not injected into the field but into the target of its {@link TargetSource}. + * @author Antoine Meyer + * @see https://github.com/antoinemeyer/mock-in-bean/issues/23 + */ +class ProxiedBeanFieldState extends BeanFieldState { + + private static void setTargetSourceValue(TargetSource targetSource, Object value) { + ReflectionTestUtils.setField(targetSource, "target", value); + } + + final TargetSource proxyTargetSource; + + final Object proxyTargetOriginalValue; + + public ProxiedBeanFieldState(Object inBean, Field beanField, Object beanFieldValue, TargetSource proxyTargetSource, Definition definition) throws Exception { + super(inBean, beanField, beanFieldValue, definition); + this.proxyTargetSource = proxyTargetSource; + this.proxyTargetOriginalValue = proxyTargetSource.getTarget(); + } + + @Override + public void rollback(TestContext testContext) { + setTargetSourceValue(proxyTargetSource, proxyTargetOriginalValue); + } + + @Override + public Object createMockOrSpy() { + Object applicableMockOrSpy = definition.create(proxyTargetOriginalValue); + setTargetSourceValue(proxyTargetSource, applicableMockOrSpy); + return originalValue; //the 'mock or spy' to operate for proxied beans are the actual proxy + } + +} diff --git a/src/main/java/com/teketik/test/mockinbean/TestFieldState.java b/src/main/java/com/teketik/test/mockinbean/TestFieldState.java index 512f371..2dfbfb8 100644 --- a/src/main/java/com/teketik/test/mockinbean/TestFieldState.java +++ b/src/main/java/com/teketik/test/mockinbean/TestFieldState.java @@ -6,8 +6,8 @@ class TestFieldState extends FieldState { - TestFieldState(Field targetField, Object originalValue, Definition definition) { - super(targetField, originalValue, definition); + TestFieldState(Field targetField, Definition definition) { + super(targetField, definition); } @Override diff --git a/src/test/java/com/teketik/test/mockinbean/test/VerifyAdvisedSpyInBeanTest.java b/src/test/java/com/teketik/test/mockinbean/test/VerifyAdvisedSpyInBeanTest.java new file mode 100644 index 0000000..7d63dea --- /dev/null +++ b/src/test/java/com/teketik/test/mockinbean/test/VerifyAdvisedSpyInBeanTest.java @@ -0,0 +1,113 @@ +package com.teketik.test.mockinbean.test; + +import static org.mockito.Mockito.verify; + +import com.teketik.test.mockinbean.SpyInBean; +import com.teketik.test.mockinbean.test.VerifyAdvisedSpyInBeanTest.Config.AnAspect; +import com.teketik.test.mockinbean.test.VerifyAdvisedSpyInBeanTest.Config.LoggingService; +import com.teketik.test.mockinbean.test.VerifyAdvisedSpyInBeanTest.Config.ProviderService; + +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.annotation.Before; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.ApplicationContext; +import org.springframework.core.Ordered; +import org.springframework.stereotype.Component; +import org.springframework.stereotype.Service; +import org.springframework.test.context.TestContext; +import org.springframework.test.context.TestExecutionListener; +import org.springframework.test.context.TestExecutionListeners; +import org.springframework.test.context.TestExecutionListeners.MergeMode; +import org.springframework.test.util.ReflectionTestUtils; + +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Covering test case from https://github.com/inkassso/mock-in-bean-issue-23/blob/master/src/test/java/com/github/inkassso/mockinbean/issue23/service/BrokenLoggingServiceTest1_SpyInBean.java + */ +@TestExecutionListeners(value = {VerifyAdvisedSpyInBeanTest.class}, mergeMode = MergeMode.MERGE_WITH_DEFAULTS) +@SpringBootTest +public class VerifyAdvisedSpyInBeanTest implements TestExecutionListener, Ordered { + + @org.springframework.boot.test.context.TestConfiguration + static class Config { + + @Aspect + @Component + public class AnAspect { + + private final AtomicInteger invocationCounter = new AtomicInteger(); + + @Before("execution(* com.teketik.test.mockinbean.test.VerifyAdvisedSpyInBeanTest.Config.ProviderService.provideValue())") + public void run() { + invocationCounter.incrementAndGet(); + } + } + + @Service + public class ProviderService { + public String provideValue() { + return ""; + } + } + + @Component + public class LoggingService { + + @Autowired + private ProviderService providerService; + + public String logCurrentValue() { + return providerService.provideValue(); + } + } + } + + @Autowired + protected LoggingService loggingService; + + @Autowired + protected AnAspect anAspect; + + @SpyInBean(LoggingService.class) + private ProviderService providerService; + + @Test + void testAspectInvocation() { + int initialCounterValue = anAspect.invocationCounter.get(); + loggingService.logCurrentValue(); + Assertions.assertEquals(initialCounterValue + 1, anAspect.invocationCounter.get()); + verify(providerService).provideValue(); + Assertions.assertEquals(initialCounterValue + 2, anAspect.invocationCounter.get()); + } + + @Test + void testSpyAnswer() { + Mockito.doAnswer(i -> "value").when(providerService).provideValue(); + Assertions.assertEquals("value", loggingService.logCurrentValue()); + } + + @Override + public void afterTestClass(TestContext testContext) throws Exception { + final ApplicationContext applicationContext = testContext.getApplicationContext(); + + //ensure context clean + final Object loggingServiceBean = applicationContext.getBean(LoggingService.class); + final Object providerServiceInBean = ReflectionTestUtils.getField(loggingServiceBean, "providerService"); + Assertions.assertFalse(TestUtils.isMockOrSpy(providerServiceInBean)); + Assertions.assertSame(applicationContext.getBean(ProviderService.class), providerServiceInBean); + + //ensure aspect invoked + final AnAspect anAspect = applicationContext.getBean(AnAspect.class); + Assertions.assertEquals(4, anAspect.invocationCounter.get()); + } + + @Override + public int getOrder() { + return Integer.MAX_VALUE; + } +}