October 16, 2020

Spring Boot Testing Tutorial

In Part 1 of this Spring Boot Testing Tutorial series, we are going to learn about unit testing a Spring Boot application using JUnit 5, and we will see how to use mocking frameworks like Mockito.

This article is the first part of a 3-part tutorial series that covers the following topics:

  • Unit Testing with JUnit 5 and Mockito
  • Integration Tests using Testcontainers
  • Testing REST APIs using MockMvc

Source code for Example Project

I am going to explain the above concepts using a complete project as an example: the Reddit Clone application, which I built using Spring Boot and Angular. You can check out the source code of the tutorial here.

I also created a written and video tutorial series where I show you how to build the application step by step; you can check it out here if you are interested.

You can follow along with this tutorial by downloading the source code and writing the tests with me. I will explain the overall application functionality as we progress through this tutorial.

Source Code with Tests included

You can find the source code which includes Unit Tests at this URL: https://github.com/SaiUpadhyayula/spring-boot-testing-reddit-clone

If you are a visual learner like me, you can check out the Video Version of this tutorial below:

Unit Testing with JUnit 5

We are going to write unit tests using JUnit 5, a popular unit testing library for Java applications. Before we start writing them, let’s discuss what exactly unit testing is.

Unit testing is a practice in software development where you test the functionality of a single component (in our case, a Java class) in isolation, without relying on its external dependencies.

As mentioned before, we are going to use JUnit 5 for writing unit tests in our application. You can add JUnit 5 to your project by adding the Maven dependency below to the pom.xml file.
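To make the idea of isolation concrete, here is a minimal plain-Java sketch that needs no test library at all. GreetingRepository, GreetingService and InMemoryGreetingRepository are hypothetical names, not part of the Reddit clone. Because the service receives its dependency through its constructor (just as CommentService does later via Lombok’s @AllArgsConstructor), a test can hand it an in-memory fake instead of a database-backed implementation.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

// Hypothetical dependency: in a real application this could be a Spring Data repository.
interface GreetingRepository {
    Optional<String> findNameById(long id);
}

// Hypothetical class under test: it only knows the interface, never the implementation.
class GreetingService {
    private final GreetingRepository repository;

    GreetingService(GreetingRepository repository) {
        this.repository = repository;
    }

    String greet(long id) {
        return repository.findNameById(id)
                .map(name -> "Hello, " + name)
                .orElse("Hello, stranger");
    }
}

// Hand-written fake: keeps data in memory, so a test of GreetingService needs no database.
class InMemoryGreetingRepository implements GreetingRepository {
    private final Map<Long, String> names = new HashMap<>();

    void put(long id, String name) {
        names.put(id, name);
    }

    @Override
    public Optional<String> findNameById(long id) {
        return Optional.ofNullable(names.get(id));
    }
}
```

A test would simply construct GreetingService with an InMemoryGreetingRepository and assert on greet(); Mockito, introduced later in this article, saves us from writing such fakes by hand.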

<dependency>
    <groupId>org.junit.jupiter</groupId>
    <artifactId>junit-jupiter</artifactId>
    <version>5.6.2</version>
    <scope>test</scope>
</dependency>

We also need to make sure that the Spring Boot Starter Test dependency is added to our pom.xml:

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
</dependency>

The example project I linked above already contains the Spring Boot Starter Test dependency, but if you check the pom.xml of the spring-boot-starter-test library, you can see that it includes JUnit 4 as a transitive dependency.

We can exclude this dependency by adding the following exclusion to the spring-boot-starter-test dependency:

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
    <exclusions>
        <exclusion>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
        </exclusion>
    </exclusions>
</dependency>

Now only the JUnit 5 dependency should be on our classpath.

Overview of Application Architecture

Here is an overview of the application architecture:

Figure: Overview of Spring Boot Application Architecture

We have a 3-tier architecture with a Controller, Service and Persistence layer, and we are going to cover each layer in this tutorial series.

As we are mainly focusing on unit testing in this part, we will take some example classes from the Service layer.

Now let’s start writing our first unit test.

Your first Unit Test

We will start off by writing tests for the CommentService class, which looks like below:

CommentService.java

package com.programming.techie.springredditclone.service;

import com.programming.techie.springredditclone.dto.CommentsDto;
import com.programming.techie.springredditclone.exceptions.PostNotFoundException;
import com.programming.techie.springredditclone.exceptions.SpringRedditException;
import com.programming.techie.springredditclone.mapper.CommentMapper;
import com.programming.techie.springredditclone.model.Comment;
import com.programming.techie.springredditclone.model.NotificationEmail;
import com.programming.techie.springredditclone.model.Post;
import com.programming.techie.springredditclone.model.User;
import com.programming.techie.springredditclone.repository.CommentRepository;
import com.programming.techie.springredditclone.repository.PostRepository;
import com.programming.techie.springredditclone.repository.UserRepository;
import lombok.AllArgsConstructor;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.stereotype.Service;

import java.util.List;

import static java.util.stream.Collectors.toList;

@Service
@AllArgsConstructor
public class CommentService {
    private static final String POST_URL = "";
    private final PostRepository postRepository;
    private final UserRepository userRepository;
    private final AuthService authService;
    private final CommentMapper commentMapper;
    private final CommentRepository commentRepository;
    private final MailContentBuilder mailContentBuilder;
    private final MailService mailService;

    public void save(CommentsDto commentsDto) {
        Post post = postRepository.findById(commentsDto.getPostId())
                .orElseThrow(() -> new PostNotFoundException(commentsDto.getPostId().toString()));
        Comment comment = commentMapper.map(commentsDto, post, authService.getCurrentUser());
        commentRepository.save(comment);

        String message = mailContentBuilder.build(authService.getCurrentUser().getUsername() + " posted a comment on your post." + POST_URL);
        sendCommentNotification(message, post.getUser());
    }

    private void sendCommentNotification(String message, User user) {
        mailService.sendMail(new NotificationEmail(user.getUsername() + " Commented on your post", user.getEmail(), message));
    }

    public List<CommentsDto> getAllCommentsForPost(Long postId) {
        Post post = postRepository.findById(postId).orElseThrow(() -> new PostNotFoundException(postId.toString()));
        return commentRepository.findByPost(post)
                .stream()
                .map(commentMapper::mapToDto).collect(toList());
    }

    public List<CommentsDto> getAllCommentsForUser(String userName) {
        User user = userRepository.findByUsername(userName)
                .orElseThrow(() -> new UsernameNotFoundException(userName));
        return commentRepository.findAllByUser(user)
                .stream()
                .map(commentMapper::mapToDto)
                .collect(toList());
    }

    public boolean containsSwearWords(String comment) {
        if (comment.contains("shit")) {
            throw new SpringRedditException("Comments contains unacceptable language");
        }
        return false;
    }
}

The CommentService class communicates with the CommentRepository and CommentsController classes, which are part of the Persistence and Controller layers respectively.

CommentRepository.java

package com.programming.techie.springredditclone.repository;

import com.programming.techie.springredditclone.model.Comment;
import com.programming.techie.springredditclone.model.Post;
import com.programming.techie.springredditclone.model.User;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;

import java.util.List;

@Repository
public interface CommentRepository extends JpaRepository<Comment, Long> {
    List<Comment> findByPost(Post post);

    List<Comment> findAllByUser(User user);
}

CommentsController.java

package com.programming.techie.springredditclone.controller;

import com.programming.techie.springredditclone.dto.CommentsDto;
import com.programming.techie.springredditclone.service.CommentService;
import lombok.AllArgsConstructor;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

import java.util.List;

import static org.springframework.http.HttpStatus.CREATED;
import static org.springframework.http.HttpStatus.OK;

@RestController
@RequestMapping("/api/comments/")
@AllArgsConstructor
public class CommentsController {
    private final CommentService commentService;

    @PostMapping
    public ResponseEntity<Void> createComment(@RequestBody CommentsDto commentsDto) {
        commentService.save(commentsDto);
        return new ResponseEntity<>(CREATED);
    }

    @GetMapping("/by-post/{postId}")
    public ResponseEntity<List<CommentsDto>> getAllCommentsForPost(@PathVariable Long postId) {
        return ResponseEntity.status(OK)
                .body(commentService.getAllCommentsForPost(postId));
    }

    @GetMapping("/by-user/{userName}")
    public ResponseEntity<List<CommentsDto>> getAllCommentsForUser(@PathVariable String userName){
        return ResponseEntity.status(OK)
                .body(commentService.getAllCommentsForUser(userName));
    }

}

Let’s create a unit test for the CommentService class by creating a class called CommentServiceTest. We will concentrate on writing a test for the method containsSwearWords(String):

CommentServiceTest.java

package com.programming.techie.springredditclone.service;

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertFalse;

public class CommentServiceTest {

    @Test
    @DisplayName("Test Should Pass When Comment Does Not Contain Swear Words")
    public void shouldNotContainSwearWordsInsideComment() {
        // containsSwearWords() uses none of the 7 dependencies, so passing null is safe here
        CommentService commentService = new CommentService(null, null, null, null, null, null, null);
        assertFalse(commentService.containsSwearWords("This is a comment"));
    }
}

Let’s understand what is going on in this class:

  • First, we created a CommentServiceTest class, and inside it we declared a method annotated with @Test, which marks the method as a test.
  • We also added a @DisplayName annotation, which lets us write a meaningful description for our test. The method name shouldNotContainSwearWordsInsideComment() is descriptive, but not easily readable.
  • As we are writing tests for the CommentService class, we first need to instantiate it. Since this is a unit test, we test the class in isolation, which is why we passed null for every constructor parameter; this only works because containsSwearWords() does not use any of those dependencies.
  • Next, we called the containsSwearWords() method and asserted whether it returns the expected value.
  • We perform the assertion using the JUnit 5 Assertions class.

If you run this test, you will see that it passes without any problem. We have our first passing test 🎉🎉🥳🥳

Figure: Test result for CommentServiceTest

A rule of thumb when testing our code is to make sure that a test actually fails when the behavior of the code changes. That is the main reason we write tests: to get immediate feedback when we unintentionally change the behavior of a method.

Let’s change the logic of the method to return true instead of false when a clean comment is passed in as input.

    public boolean containsSwearWords(String comment) {
        if (comment.contains("shit")) {
            throw new SpringRedditException("Comments contains unacceptable language");
        }
        return true;
    }

And if we run our test again, it should fail.

Figure: Failed test result for CommentServiceTest

Testing Negative Case

Now let’s write a test that is supposed to throw an exception. As developers, we tend to focus only on happy-path testing, but it is also important to cover the negative cases.

When the submitted comment contains swear words, the method should throw an exception. Let’s write a unit test for this scenario:

    @Test
    @DisplayName("Should Throw Exception when Comment Contains Swear Words")
    public void shouldFailWhenCommentContainsSwearWords() {
        // assertThrows and assertTrue are statically imported from org.junit.jupiter.api.Assertions
        CommentService commentService = new CommentService(null, null, null, null, null, null, null);
        SpringRedditException exception = assertThrows(SpringRedditException.class, () -> {
            commentService.containsSwearWords("This is shitty comment");
        });
        assertTrue(exception.getMessage().contains("Comments contains unacceptable language"));
    }

In this case, we expect our method under test to throw a SpringRedditException with the message “Comments contains unacceptable language”.

We can use the assertThrows() method from the Assertions class to verify this behavior. It returns the thrown exception, so we can call exception.getMessage() to retrieve the exception message and assert on it.
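The following is a simplified, plain-Java sketch of what an assertThrows-style helper does conceptually, to show why it can hand the exception back to us. MiniAssertions and ThrowingBlock are hypothetical names; JUnit’s real Assertions.assertThrows takes an org.junit.jupiter.api.function.Executable and produces richer failure messages.

```java
import java.util.Objects;

// Hypothetical functional interface playing the role of JUnit's Executable:
// a block of code that may throw anything.
@FunctionalInterface
interface ThrowingBlock {
    void run() throws Throwable;
}

final class MiniAssertions {
    private MiniAssertions() {
    }

    // Runs the block; returns the thrown exception if it has the expected type,
    // otherwise fails with an AssertionError. A simplified sketch, not JUnit's code.
    static <T extends Throwable> T assertThrows(Class<T> expectedType, ThrowingBlock block) {
        Objects.requireNonNull(expectedType, "expectedType");
        try {
            block.run();
        } catch (Throwable actual) {
            if (expectedType.isInstance(actual)) {
                return expectedType.cast(actual); // handing it back lets the caller inspect the message
            }
            throw new AssertionError("Expected " + expectedType.getName()
                    + " but got " + actual.getClass().getName(), actual);
        }
        throw new AssertionError("Expected " + expectedType.getName() + " but nothing was thrown");
    }
}
```

The returned exception is what makes a follow-up check such as exception.getMessage().contains(...) possible.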

Improved Assertions using AssertJ

So far we have used the built-in JUnit 5 Assertions class to make some basic assertions, but we can write more readable assertions using the AssertJ library. Let’s add it to our pom.xml and see how it helps us write better tests (spring-boot-starter-test already includes AssertJ, so this step is only needed if you don’t use that starter).

<dependency>
    <groupId>org.assertj</groupId>
    <artifactId>assertj-core</artifactId>
    <version>3.17.2</version>
    <scope>test</scope>
</dependency>

Let’s see how we can improve our assertions and make use of the Fluent API provided by AssertJ

In our existing test, we are asserting for Boolean values using the assertTrue and assertFalse methods from Assertions class in Junit 5. We can replace them with AssertJ like below:

    @Test
    @DisplayName("Test Should Pass When Comment Does Not Contain Swear Words")
    public void shouldNotContainSwearWordsInsideComment() {
        // assertThat is statically imported from org.assertj.core.api.Assertions
        CommentService commentService = new CommentService(null, null, null, null, null, null, null);
        assertThat(commentService.containsSwearWords("This is a comment")).isFalse();
    }

You may be thinking that’s not much of an improvement, but bear with me. Let’s look at the next example, where we assert whether an exception is thrown from our method.

    @Test
    @DisplayName("Should Throw Exception when Comment Contains Swear Words")
    public void shouldFailWhenCommentContainsSwearWords() {
        CommentService commentService = new CommentService(null, null, null, null, null, null, null);

        assertThatThrownBy(() -> {
            commentService.containsSwearWords("This is a shitty comment");
        }).isInstanceOf(SpringRedditException.class)
                .hasMessage("Comments contains unacceptable language");
    }

In the above example, the assertThatThrownBy() method gives us access to the isInstanceOf() and hasMessage() methods, which make our test more readable than the previous implementation.
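Fluent assertions read well because every check returns the assertion object itself, so the calls can be chained. The toy class below illustrates only that idea; ThrowableCheck and its methods are hypothetical and are not AssertJ’s implementation, which offers far more checks and error detail.

```java
// A toy fluent assertion, loosely in the style of assertThatThrownBy(...).isInstanceOf(...).hasMessage(...).
final class ThrowableCheck {
    private final Throwable actual; // null when the block threw nothing

    private ThrowableCheck(Throwable actual) {
        this.actual = actual;
    }

    // Runs the block and captures the exception it throws. Runnable cannot throw checked
    // exceptions, so catching RuntimeException is enough here (Errors still propagate).
    static ThrowableCheck thrownBy(Runnable block) {
        try {
            block.run();
            return new ThrowableCheck(null);
        } catch (RuntimeException e) {
            return new ThrowableCheck(e);
        }
    }

    // Each check returns `this`, which is what makes the calls chainable.
    ThrowableCheck isInstanceOf(Class<? extends Throwable> type) {
        if (!type.isInstance(actual)) {
            throw new AssertionError("Expected an instance of " + type.getName() + " but got " + actual);
        }
        return this;
    }

    ThrowableCheck hasMessage(String expectedMessage) {
        String actualMessage = actual == null ? null : actual.getMessage();
        if (!expectedMessage.equals(actualMessage)) {
            throw new AssertionError("Expected message <" + expectedMessage + "> but was <" + actualMessage + ">");
        }
        return this;
    }
}
```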

Mocking the dependencies using Mockito

In this tutorial we are doing unit testing (which by definition means testing the functionality of a unit in isolation), so we should not use any Spring features. However, if you look at our CommentService class, it depends on many Spring components, as you can see in the screenshot below:

Figure: Dependencies for CommentService

The CommentService class has 7 dependencies. Clearly, this class is doing a lot of things and the implementation could be refactored, but that’s another discussion.

In this scenario, if we want to test the complete class, we have to mock the dependencies used by CommentService. To provide the mocks, we can use the Mockito library.

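We can add Mockito to our project by adding the below dependency to our pom.xml file. Use mockito-core rather than the old mockito-all 1.x artifact, which predates classes such as ArgumentMatchers that we use later in this article. spring-boot-starter-test already brings in Mockito and manages its version, so with that starter this step is optional.

<dependency>
    <groupId>org.mockito</groupId>
    <artifactId>mockito-core</artifactId>
    <scope>test</scope>
</dependency>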

Now let’s write a test using Mockito, taking another class, PostService.java, as an example:

package com.programming.techie.springredditclone.service;

import com.programming.techie.springredditclone.dto.PostRequest;
import com.programming.techie.springredditclone.dto.PostResponse;
import com.programming.techie.springredditclone.exceptions.PostNotFoundException;
import com.programming.techie.springredditclone.exceptions.SubredditNotFoundException;
import com.programming.techie.springredditclone.mapper.PostMapper;
import com.programming.techie.springredditclone.model.Post;
import com.programming.techie.springredditclone.model.Subreddit;
import com.programming.techie.springredditclone.model.User;
import com.programming.techie.springredditclone.repository.PostRepository;
import com.programming.techie.springredditclone.repository.SubredditRepository;
import com.programming.techie.springredditclone.repository.UserRepository;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.util.List;

import static java.util.stream.Collectors.toList;

@Service
@AllArgsConstructor
@Slf4j
@Transactional
public class PostService {

    private final PostRepository postRepository;
    private final SubredditRepository subredditRepository;
    private final UserRepository userRepository;
    private final AuthService authService;
    private final PostMapper postMapper;

    public void save(PostRequest postRequest) {
        Subreddit subreddit = subredditRepository.findByName(postRequest.getSubredditName())
                .orElseThrow(() -> new SubredditNotFoundException(postRequest.getSubredditName()));
        postRepository.save(postMapper.map(postRequest, subreddit, authService.getCurrentUser()));
    }

    @Transactional(readOnly = true)
    public PostResponse getPost(Long id) {
        Post post = postRepository.findById(id)
                .orElseThrow(() -> new PostNotFoundException(id.toString()));
        return postMapper.mapToDto(post);
    }

    @Transactional(readOnly = true)
    public List<PostResponse> getAllPosts() {
        return postRepository.findAll()
                .stream()
                .map(postMapper::mapToDto)
                .collect(toList());
    }

    @Transactional(readOnly = true)
    public List<PostResponse> getPostsBySubreddit(Long subredditId) {
        Subreddit subreddit = subredditRepository.findById(subredditId)
                .orElseThrow(() -> new SubredditNotFoundException(subredditId.toString()));
        List<Post> posts = postRepository.findAllBySubreddit(subreddit);
        return posts.stream().map(postMapper::mapToDto).collect(toList());
    }

    @Transactional(readOnly = true)
    public List<PostResponse> getPostsByUsername(String username) {
        User user = userRepository.findByUsername(username)
                .orElseThrow(() -> new UsernameNotFoundException(username));
        return postRepository.findByUser(user)
                .stream()
                .map(postMapper::mapToDto)
                .collect(toList());
    }
}

We are going to write a test to check the behavior of the getPost(Long) method

PostServiceTest.java

package com.programming.techie.springredditclone.service;

import com.programming.techie.springredditclone.dto.PostResponse;
import com.programming.techie.springredditclone.mapper.PostMapper;
import com.programming.techie.springredditclone.model.Post;
import com.programming.techie.springredditclone.repository.PostRepository;
import com.programming.techie.springredditclone.repository.SubredditRepository;
import com.programming.techie.springredditclone.repository.UserRepository;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import java.time.Instant;
import java.util.Optional;

class PostServiceTest {

    private PostRepository postRepository = Mockito.mock(PostRepository.class);
    private SubredditRepository subredditRepository = Mockito.mock(SubredditRepository.class);
    private UserRepository userRepository = Mockito.mock(UserRepository.class);
    private AuthService authService = Mockito.mock(AuthService.class);
    private PostMapper postMapper = Mockito.mock(PostMapper.class);

    @Test
    @DisplayName("Should Retrieve Post by Id")
    public void shouldFindPostById() {

        PostService postService = new PostService(postRepository, subredditRepository, userRepository, authService, postMapper);

        Post post = new Post(123L, "First Post", "http://url.site", "Test",
                0, null, Instant.now(), null);
        PostResponse expectedPostResponse = new PostResponse(123L, "First Post", "http://url.site", "Test",
                "Test User", "Test Subredit", 0, 0, "1 Hour Ago", false, false);

        Mockito.when(postRepository.findById(123L)).thenReturn(Optional.of(post));
        Mockito.when(postMapper.mapToDto(Mockito.any(Post.class))).thenReturn(expectedPostResponse);

        PostResponse actualPostResponse = postService.getPost(123L);

        Assertions.assertThat(actualPostResponse.getId()).isEqualTo(expectedPostResponse.getId());
        Assertions.assertThat(actualPostResponse.getPostName()).isEqualTo(expectedPostResponse.getPostName());
    }
}

Ok let’s go through what we are doing in this test:

  • The class under test is PostService, which depends on 5 Spring beans. To mock them out, we use Mockito.mock(), which provides mock objects for our dependencies.
  • We pass the mocked objects as constructor parameters to create an instance of PostService.
  • As the dependencies are mocked, we have to define how method calls on the mocked objects should behave. We can do that using the method call chain Mockito.when().thenReturn().

    Example:
    As part of the postService.getPost() call, we retrieve the post from the PostRepository using postRepository.findById(). Since we are using a mocked PostRepository, we have to define the expected output of that call, which is what Mockito.when(postRepository.findById(123L)).thenReturn(Optional.of(post)) does.
  • Finally, we assert that the return value of the method is what we expect.
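Conceptually, a mock is an object that implements the dependency’s type and returns whatever we configured. The plain-Java sketch below imitates the when().thenReturn() idea with java.lang.reflect.Proxy. TinyMock, stub() and the PostLookup interface are hypothetical; the sketch only works for interfaces and matches calls by method name alone, whereas Mockito generates classes at runtime (so it can also mock classes) and matches on arguments.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

// Hypothetical dependency, standing in for something like PostRepository.
interface PostLookup {
    Optional<String> findTitleById(Long id);

    long count();
}

// A deliberately tiny mock: answers are looked up by method name only.
final class TinyMock implements InvocationHandler {
    private final Map<String, Object> answers = new HashMap<>();

    // Creates an object implementing the interface; every call on it is routed to invoke().
    static <T> T mock(Class<T> type, TinyMock handler) {
        Object proxy = Proxy.newProxyInstance(type.getClassLoader(), new Class<?>[]{type}, handler);
        return type.cast(proxy);
    }

    // Rough analogue of when(someCall).thenReturn(value).
    void stub(String methodName, Object value) {
        answers.put(methodName, value);
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) {
        if (answers.containsKey(method.getName())) {
            return answers.get(method.getName());
        }
        return defaultValue(method.getReturnType()); // unstubbed: null, false or zero
    }

    // Primitive return types must not receive null from a proxy, so return a zero value instead.
    private static Object defaultValue(Class<?> type) {
        if (!type.isPrimitive() || type == void.class) return null;
        if (type == boolean.class) return false;
        if (type == char.class) return '\0';
        if (type == byte.class) return (byte) 0;
        if (type == short.class) return (short) 0;
        if (type == int.class) return 0;
        if (type == long.class) return 0L;
        if (type == float.class) return 0f;
        return 0d;
    }
}
```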

Using @Mock annotations

So far we have used the Mockito.mock() method to create mocks. There is a more convenient, declarative way to do this: the @Mock annotation.

To use the annotation with JUnit 5, we have to add the mockito-junit-jupiter dependency to our project:

<dependency>
    <groupId>org.mockito</groupId>
    <artifactId>mockito-junit-jupiter</artifactId>
    <scope>test</scope>
</dependency>

After adding the dependency, we can change our test like below:

package com.programming.techie.springredditclone.service;

import com.programming.techie.springredditclone.dto.PostResponse;
import com.programming.techie.springredditclone.mapper.PostMapper;
import com.programming.techie.springredditclone.model.Post;
import com.programming.techie.springredditclone.repository.PostRepository;
import com.programming.techie.springredditclone.repository.SubredditRepository;
import com.programming.techie.springredditclone.repository.UserRepository;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;

import java.time.Instant;
import java.util.Optional;

@ExtendWith(MockitoExtension.class)
class PostServiceTest {

    @Mock
    private PostRepository postRepository;
    @Mock
    private SubredditRepository subredditRepository;
    @Mock
    private UserRepository userRepository;
    @Mock
    private AuthService authService;
    @Mock
    private PostMapper postMapper;

    @Test
    @DisplayName("Should Retrieve Post by Id")
    public void shouldFindPostById() {

        PostService postService = new PostService(postRepository, subredditRepository, userRepository, authService, postMapper);

        Post post = new Post(123L, "First Post", "http://url.site", "Test",
                0, null, Instant.now(), null);
        PostResponse expectedPostResponse = new PostResponse(123L, "First Post", "http://url.site", "Test",
                "Test User", "Test Subredit", 0, 0, "1 Hour Ago", false, false);

        Mockito.when(postRepository.findById(123L)).thenReturn(Optional.of(post));
        Mockito.when(postMapper.mapToDto(Mockito.any(Post.class))).thenReturn(expectedPostResponse);

        PostResponse actualPostResponse = postService.getPost(123L);

        Assertions.assertThat(actualPostResponse.getId()).isEqualTo(expectedPostResponse.getId());
        Assertions.assertThat(actualPostResponse.getPostName()).isEqualTo(expectedPostResponse.getPostName());
    }
}

To enable the @Mock annotations, we have to annotate our test class (not the class under test) with @ExtendWith(MockitoExtension.class).
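Conceptually, an extension like MockitoExtension has to find the @Mock fields on the test instance and assign a mock to each one. The plain-Java sketch below shows that mechanism with a hypothetical @FakeMock annotation and a Proxy-based dummy; FakeMockInitializer, AuthLookup and SampleTest are illustration names, and this is not Mockito’s implementation. Note that the field stays null until the initializer runs.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.Proxy;

// Hypothetical stand-in for @Mock; RUNTIME retention makes it visible to reflection.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
@interface FakeMock {
}

final class FakeMockInitializer {
    private FakeMockInitializer() {
    }

    // Finds annotated fields on the test instance and assigns a generated dummy to each.
    // Proxy only supports interfaces, so this sketch does too.
    static void initFakes(Object testInstance) {
        for (Field field : testInstance.getClass().getDeclaredFields()) {
            if (!field.isAnnotationPresent(FakeMock.class)) {
                continue;
            }
            Class<?> type = field.getType();
            Object fake = Proxy.newProxyInstance(type.getClassLoader(), new Class<?>[]{type},
                    (proxy, method, args) -> defaultValue(method.getReturnType()));
            try {
                field.setAccessible(true);
                field.set(testInstance, fake);
            } catch (IllegalAccessException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    // Every call on the dummy returns null, false or zero.
    private static Object defaultValue(Class<?> type) {
        if (!type.isPrimitive() || type == void.class) return null;
        if (type == boolean.class) return false;
        if (type == char.class) return '\0';
        if (type == byte.class) return (byte) 0;
        if (type == short.class) return (short) 0;
        if (type == int.class) return 0;
        if (type == long.class) return 0L;
        if (type == float.class) return 0f;
        return 0d;
    }
}

// Hypothetical dependency and test class used to exercise the initializer.
interface AuthLookup {
    String currentUsername();

    boolean isLoggedIn();
}

class SampleTest {
    @FakeMock
    private AuthLookup authLookup;

    AuthLookup authLookup() {
        return authLookup;
    }
}
```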

Verifying Method Invocations using Mockito

We can verify whether methods on our mocks are invoked by the logic under test using the Mockito.verify() method.

    @Test
    @DisplayName("Should Save Posts")
    public void shouldSavePosts() {
        PostService postService = new PostService(postRepository, subredditRepository, userRepository, authService, postMapper);

        User currentUser = new User(123L, "test user", "secret password", "user@email.com", Instant.now(), true);
        Subreddit subreddit = new Subreddit(123L, "First Subreddit", "Subreddit Description", emptyList(), Instant.now(), currentUser);
        Post post = new Post(123L, "First Post", "http://url.site", "Test",
                0, null, Instant.now(), null);
        PostRequest postRequest = new PostRequest(null, "First Subreddit", "First Post", "http://url.site", "Test");

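        Mockito.when(subredditRepository.findByName("First Subreddit"))
                .thenReturn(Optional.of(subreddit));
        // PostService.save() calls authService.getCurrentUser(), so it must be stubbed as well
        Mockito.when(authService.getCurrentUser())
                .thenReturn(currentUser);
        Mockito.when(postMapper.map(postRequest, subreddit, currentUser))
                .thenReturn(post);
        postService.save(postRequest);
        Mockito.verify(postRepository, Mockito.times(1)).save(ArgumentMatchers.any(Post.class));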

    }

In this test for the postService.save() method, we want to verify whether the post is saved into the database. As we don’t have access to a database, the only way to verify this behavior is to check whether the postRepository.save() method is invoked.

We can do that using Mockito.verify(postRepository, Mockito.times(1)), where Mockito.times() specifies how many times the method must have been invoked during the test.

To the save() method we pass the matcher ArgumentMatchers.any(Post.class), which tells Mockito to accept any non-null argument of type Post.
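Without a mocking library, verifying an interaction means recording the calls yourself. The sketch below uses a hand-written recording fake to approximate verify(postRepository, times(1)).save(any(Post.class)). SimplePost, SimplePostRepository, RecordingPostRepository and SimplePostService are hypothetical stand-ins, not classes from the project.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

// Hypothetical stand-in for the Post entity.
class SimplePost {
    final Long id;
    final String name;

    SimplePost(Long id, String name) {
        this.id = id;
        this.name = name;
    }
}

// Hypothetical stand-in for PostRepository.
interface SimplePostRepository {
    SimplePost save(SimplePost post);
}

// Hand-written recording fake: remembers every argument passed to save().
class RecordingPostRepository implements SimplePostRepository {
    private final List<SimplePost> savedArguments = new ArrayList<>();

    @Override
    public SimplePost save(SimplePost post) {
        savedArguments.add(post);
        return post;
    }

    // Rough analogue of verify(repo, times(expected)).save(any(SimplePost.class)):
    // the static type already guarantees SimplePost, so only nulls are excluded.
    void verifySavedTimes(int expected) {
        long matching = savedArguments.stream().filter(Objects::nonNull).count();
        if (matching != expected) {
            throw new AssertionError("save() expected " + expected + " call(s) but was " + matching);
        }
    }
}

// Hypothetical service that should persist exactly one post per call.
class SimplePostService {
    private final SimplePostRepository repository;

    SimplePostService(SimplePostRepository repository) {
        this.repository = repository;
    }

    void save(String name) {
        repository.save(new SimplePost(1L, name));
    }
}
```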

Capturing Method Arguments

In the above scenario, we can also capture the argument passed to the save() method and verify whether it matches our requirements.

We can do this using the ArgumentCaptor in Mockito.

The test below, a slightly modified version of the earlier test, captures the argument passed to the postRepository.save() method and asserts that it is the object we intended.

package com.programming.techie.springredditclone.service;

import com.programming.techie.springredditclone.dto.PostRequest;
import com.programming.techie.springredditclone.dto.PostResponse;
import com.programming.techie.springredditclone.mapper.PostMapper;
import com.programming.techie.springredditclone.model.Post;
import com.programming.techie.springredditclone.model.Subreddit;
import com.programming.techie.springredditclone.model.User;
import com.programming.techie.springredditclone.repository.PostRepository;
import com.programming.techie.springredditclone.repository.SubredditRepository;
import com.programming.techie.springredditclone.repository.UserRepository;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;

import java.time.Instant;
import java.util.Optional;

import static java.util.Collections.emptyList;

@ExtendWith(MockitoExtension.class)
class PostServiceTest {

    @Mock
    private PostRepository postRepository;
    @Mock
    private SubredditRepository subredditRepository;
    @Mock
    private UserRepository userRepository;
    @Mock
    private AuthService authService;
    @Mock
    private PostMapper postMapper;

    @Captor
    private ArgumentCaptor<Post> postArgumentCaptor;

    @Test
    @DisplayName("Should Retrieve Post by Id")
    public void shouldFindPostById() {

        PostService postService = new PostService(postRepository, subredditRepository, userRepository, authService, postMapper);

        Post post = new Post(123L, "First Post", "http://url.site", "Test",
                0, null, Instant.now(), null);
        PostResponse expectedPostResponse = new PostResponse(123L, "First Post", "http://url.site", "Test",
                "Test User", "Test Subredit", 0, 0, "1 Hour Ago", false, false);

        Mockito.when(postRepository.findById(123L)).thenReturn(Optional.of(post));
        Mockito.when(postMapper.mapToDto(Mockito.any(Post.class))).thenReturn(expectedPostResponse);

        PostResponse actualPostResponse = postService.getPost(123L);

        Assertions.assertThat(actualPostResponse.getId()).isEqualTo(expectedPostResponse.getId());
        Assertions.assertThat(actualPostResponse.getPostName()).isEqualTo(expectedPostResponse.getPostName());
    }

    @Test
    @DisplayName("Should Save Posts")
    public void shouldSavePosts() {
        PostService postService = new PostService(postRepository, subredditRepository, userRepository, authService, postMapper);

        User currentUser = new User(123L, "test user", "secret password", "user@email.com", Instant.now(), true);
        Subreddit subreddit = new Subreddit(123L, "First Subreddit", "Subreddit Description", emptyList(), Instant.now(), currentUser);
        Post post = new Post(123L, "First Post", "http://url.site", "Test",
                0, null, Instant.now(), null);
        PostRequest postRequest = new PostRequest(null, "First Subreddit", "First Post", "http://url.site", "Test");

        Mockito.when(subredditRepository.findByName("First Subreddit"))
                .thenReturn(Optional.of(subreddit));
        Mockito.when(authService.getCurrentUser())
                .thenReturn(currentUser);
        Mockito.when(postMapper.map(postRequest, subreddit, currentUser))
                .thenReturn(post);

        postService.save(postRequest);
        Mockito.verify(postRepository, Mockito.times(1)).save(postArgumentCaptor.capture());

        Assertions.assertThat(postArgumentCaptor.getValue().getPostId()).isEqualTo(123L);
        Assertions.assertThat(postArgumentCaptor.getValue().getPostName()).isEqualTo("First Post");
    }
}
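An argument captor is essentially a container that remembers the arguments a dependency received. The plain-Java sketch below mimics that idea. SimpleCaptor, DraftPost, DraftRepository and DraftService are hypothetical names; like ArgumentCaptor.getValue(), getValue() here returns the most recently captured value. Because DraftRepository has a single method, the captor’s capture method can stand in for the repository via a method reference.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical captor: collects values and exposes the latest one.
final class SimpleCaptor<T> {
    private final List<T> values = new ArrayList<>();

    void capture(T value) {
        values.add(value);
    }

    T getValue() {
        if (values.isEmpty()) {
            throw new IllegalStateException("Nothing was captured");
        }
        return values.get(values.size() - 1);
    }

    List<T> getAllValues() {
        return Collections.unmodifiableList(values);
    }
}

// Hypothetical entity and single-method repository.
class DraftPost {
    final Long id;
    final String name;

    DraftPost(Long id, String name) {
        this.id = id;
        this.name = name;
    }
}

interface DraftRepository {
    void save(DraftPost post);
}

// Hypothetical service: it builds the entity internally, so a test can only
// inspect it through the argument handed to the repository.
class DraftService {
    private final DraftRepository repository;

    DraftService(DraftRepository repository) {
        this.repository = repository;
    }

    void publish(Long id, String title) {
        repository.save(new DraftPost(id, title.trim()));
    }
}
```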

Improving our Tests by using JUnit Lifecycle Methods

So we have written some tests; now it’s time to refactor and improve them. If you look at the tests we wrote, we instantiate PostService inside each test. This is because JUnit creates a new instance of the test class for each test method, and the @Mock fields are only populated by MockitoExtension after that instance has been created, so we cannot simply instantiate PostService once in a field initializer and re-use it across the whole test class (it would receive null dependencies).

This leads to code duplication, and we can use the JUnit 5 lifecycle methods to reduce it.

Here are some of the lifecycle annotations provided by JUnit 5:

  • @BeforeEach
    Using this annotation, we can define the pre-processing logic that runs before each test. In our example, we can move the instantiation of PostService into the method annotated with @BeforeEach.
  • @AfterEach
    Using this annotation, we can define all the post-processing logic required after running each test.
  • @BeforeAll
    Using this annotation, we can define the pre-processing logic that runs once, before any test in the class. By default, the annotated method must be static. We will see how to use this annotation in future parts.
  • @AfterAll
    Using this annotation, we can define the post-processing logic that runs once, after all tests in the class have finished. By default, the annotated method must also be static. We will see how to use this annotation in future parts.
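To make the ordering of these hooks concrete, here is a toy, plain-Java model of JUnit 5's default lifecycle, in which a new test-class instance is created for every test method. It is a sketch, not JUnit: it uses none of JUnit's annotations, and `LifecycleDemo`, `SampleTest`, `run` and `runSample` are hypothetical names chosen for illustration. Each "test" increments an instance field and sees the value 1 both times, which shows why instance state (like a PostService field) does not carry over from one test to the next.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Supplier;

// A toy model of JUnit 5's default (per-method) test instance lifecycle.
// It is NOT JUnit: SampleTest and run(...) are hypothetical names used only
// to show the order in which the four kinds of hooks fire.
public class LifecycleDemo {

    // Shared log of hook invocations, inspected after a run.
    static final List<String> EVENTS = new ArrayList<>();

    // Stand-in for a test class such as PostServiceTest.
    static class SampleTest {
        static int instancesCreated = 0;
        int counter = 0; // instance state: starts at 0 for every new instance

        SampleTest() { instancesCreated++; }

        // Class-level hooks are static, like @BeforeAll/@AfterAll by default.
        static void beforeAll() { EVENTS.add("beforeAll"); }
        static void afterAll() { EVENTS.add("afterAll"); }

        // Per-test hooks, like @BeforeEach/@AfterEach.
        void beforeEach() { EVENTS.add("beforeEach"); }
        void afterEach() { EVENTS.add("afterEach"); }

        // Each "test" bumps the counter; it always sees 1 because the
        // instance (and therefore the field) is brand new for every test.
        void testOne() { counter++; EVENTS.add("testOne counter=" + counter); }
        void testTwo() { counter++; EVENTS.add("testTwo counter=" + counter); }
    }

    // Runs every test on a fresh instance, wrapped by the per-test hooks,
    // with the class-level hooks around the whole run.
    static <T> void run(Supplier<T> factory, Runnable beforeAll, Runnable afterAll,
                        Consumer<T> beforeEach, Consumer<T> afterEach,
                        List<Consumer<T>> tests) {
        beforeAll.run();
        for (Consumer<T> test : tests) {
            T instance = factory.get(); // new instance per test method
            beforeEach.accept(instance);
            test.accept(instance);
            afterEach.accept(instance);
        }
        afterAll.run();
    }

    // Resets the shared state, runs SampleTest and returns a copy of the log.
    static List<String> runSample() {
        EVENTS.clear();
        SampleTest.instancesCreated = 0;
        List<Consumer<SampleTest>> tests = List.of(SampleTest::testOne, SampleTest::testTwo);
        run(SampleTest::new, SampleTest::beforeAll, SampleTest::afterAll,
            SampleTest::beforeEach, SampleTest::afterEach, tests);
        return new ArrayList<>(EVENTS);
    }

    public static void main(String[] args) {
        runSample().forEach(System.out::println);
        System.out.println("instances created: " + SampleTest.instancesCreated);
    }
}
```

Running `main` prints beforeAll, beforeEach, testOne counter=1, afterEach, beforeEach, testTwo counter=1, afterEach, afterAll, followed by "instances created: 2". The toy only models the default per-method mode; it does not cover JUnit's optional per-class instance lifecycle.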

Here is what our PostServiceTest class looks like after using the lifecycle methods.

package com.programming.techie.springredditclone.service;

import com.programming.techie.springredditclone.dto.PostRequest;
import com.programming.techie.springredditclone.dto.PostResponse;
import com.programming.techie.springredditclone.mapper.PostMapper;
import com.programming.techie.springredditclone.model.Post;
import com.programming.techie.springredditclone.model.Subreddit;
import com.programming.techie.springredditclone.model.User;
import com.programming.techie.springredditclone.repository.PostRepository;
import com.programming.techie.springredditclone.repository.SubredditRepository;
import com.programming.techie.springredditclone.repository.UserRepository;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;

import java.time.Instant;
import java.util.Optional;

import static java.util.Collections.emptyList;

@ExtendWith(MockitoExtension.class)
class PostServiceTest {

    @Mock
    private PostRepository postRepository;
    @Mock
    private SubredditRepository subredditRepository;
    @Mock
    private UserRepository userRepository;
    @Mock
    private AuthService authService;
    @Mock
    private PostMapper postMapper;

    @Captor
    private ArgumentCaptor<Post> postArgumentCaptor;

    private PostService postService;

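    // Runs before every @Test method. By this point MockitoExtension has already
    // initialized the @Mock fields, so PostService receives the mocks.
    @BeforeEach
    public void setup() {
        postService = new PostService(postRepository, subredditRepository, userRepository, authService, postMapper);
    }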

    @Test
    @DisplayName("Should Retrieve Post by Id")
    public void shouldFindPostById() {
        Post post = new Post(123L, "First Post", "http://url.site", "Test",
                0, null, Instant.now(), null);
        PostResponse expectedPostResponse = new PostResponse(123L, "First Post", "http://url.site", "Test",
                "Test User", "Test Subreddit", 0, 0, "1 Hour Ago", false, false);

        Mockito.when(postRepository.findById(123L)).thenReturn(Optional.of(post));
        Mockito.when(postMapper.mapToDto(Mockito.any(Post.class))).thenReturn(expectedPostResponse);

        PostResponse actualPostResponse = postService.getPost(123L);

        Assertions.assertThat(actualPostResponse.getId()).isEqualTo(expectedPostResponse.getId());
        Assertions.assertThat(actualPostResponse.getPostName()).isEqualTo(expectedPostResponse.getPostName());
    }

    @Test
    @DisplayName("Should Save Posts")
    public void shouldSavePosts() {
        User currentUser = new User(123L, "test user", "secret password", "user@email.com", Instant.now(), true);
        Subreddit subreddit = new Subreddit(123L, "First Subreddit", "Subreddit Description", emptyList(), Instant.now(), currentUser);
        Post post = new Post(123L, "First Post", "http://url.site", "Test",
                0, null, Instant.now(), null);
        PostRequest postRequest = new PostRequest(null, "First Subreddit", "First Post", "http://url.site", "Test");

        Mockito.when(subredditRepository.findByName("First Subreddit"))
                .thenReturn(Optional.of(subreddit));
        Mockito.when(authService.getCurrentUser())
                .thenReturn(currentUser);
        Mockito.when(postMapper.map(postRequest, subreddit, currentUser))
                .thenReturn(post);

        postService.save(postRequest);
        Mockito.verify(postRepository, Mockito.times(1)).save(postArgumentCaptor.capture());

        Assertions.assertThat(postArgumentCaptor.getValue().getPostId()).isEqualTo(123L);
        Assertions.assertThat(postArgumentCaptor.getValue().getPostName()).isEqualTo("First Post");
    }
}

What’s next?

So this is the end of the first part of the tutorial. In the next part, we will concentrate on writing integration tests for our Spring Boot application with the help of Test Containers.

I hope this article was helpful to you.

About the author 

Sai Upadhyayula
