In the post Spring Security: Authentication and authorization for separated backend, I showed you how to use Spring Security to handle authentication and authorization for a separated web-application backend. With that approach, the server stores a session for each user and sends a session cookie to the front-end, which means the server keeps the user's state.

In this post, I will show you how to implement stateless authentication and authorization with Spring Security and JWT.

What is JWT

JWT is short for ‘JSON Web Token’. It provides a compact, URL-safe way to represent claims that need to be transferred between two parties. The claims are encoded as JSON, which makes them easy to read, write, and produce.

Note that JWT itself does not authenticate anyone: it only encodes and decodes the claims so that they can be transferred between the two sides. We still have to implement the logic that verifies the user's information ourselves.

Then why use it?

  • First, it is tamper-evident. The token is signed, so the server can detect a token whose content has been altered.

  • Second, it is self-contained and stateless. In most situations we do not need to store it on the server side.

  • Third, it is a cross-language solution. If your application is built as microservices, it works well across services implemented in different languages.

  • Fourth, it supports expiration. We can set an expiration time, after which the token is no longer valid.
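The tamper protection in the first point comes from the signature. As a minimal sketch that assumes only JDK classes (no jjwt), the hypothetical helper MiniJwt below builds an HS512 token, the algorithm the project uses later, as header.payload.signature with each part Base64URL-encoded, and verifies it by recomputing the HMAC. Swapping the payload for a different one makes verification fail. It is an illustration only; real code should rely on a JWT library, which also checks claims such as exp.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.util.Base64;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

// Hypothetical helper for illustration: signs and verifies an HS512 token with JDK classes only
final class MiniJwt {
    private static final Base64.Encoder ENCODER = Base64.getUrlEncoder().withoutPadding();
    private static final String HEADER = "{\"alg\":\"HS512\",\"typ\":\"JWT\"}";

    // Returns header.payload.signature, each part Base64URL-encoded without padding
    static String sign(String payloadJson, byte[] secret) {
        String signingInput = ENCODER.encodeToString(HEADER.getBytes(StandardCharsets.UTF_8))
                + "." + ENCODER.encodeToString(payloadJson.getBytes(StandardCharsets.UTF_8));
        return signingInput + "." + ENCODER.encodeToString(hmac(signingInput, secret));
    }

    // True only if the signature matches the header and payload for this secret
    static boolean verify(String token, byte[] secret) {
        int lastDot = token.lastIndexOf('.');
        if (lastDot <= 0) {
            return false;
        }
        byte[] actual;
        try {
            actual = Base64.getUrlDecoder().decode(token.substring(lastDot + 1));
        } catch (IllegalArgumentException ex) {
            return false;   // signature part is not valid Base64URL
        }
        byte[] expected = hmac(token.substring(0, lastDot), secret);
        // Comparison time depends only on the length, not on how many bytes match
        return MessageDigest.isEqual(expected, actual);
    }

    private static byte[] hmac(String data, byte[] secret) {
        try {
            Mac mac = Mac.getInstance("HmacSHA512");
            mac.init(new SecretKeySpec(secret, "HmacSHA512"));
            return mac.doFinal(data.getBytes(StandardCharsets.UTF_8));
        } catch (GeneralSecurityException ex) {
            throw new IllegalStateException(ex);
        }
    }
}
```

For example, if an attacker replaces a payload claiming role "user" with one claiming role "admin" but keeps the old signature, verify(...) returns false, because the attacker cannot compute a matching signature without the secret.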

But there are also some points we need to be careful about.

  • The JWT payload is only Base64URL-encoded, not encrypted, so it is easy to decode. Anybody who gets the token can read its content, so do not put sensitive data in it.

  • Also, a JWT should not carry too much data: every claim makes the token longer, and the token is sent with every request.

  • Last, we cannot invalidate a JWT before it expires. Even if it has been removed from the front-end, it is still valid on the server side. If somebody copied it, they can still call the APIs and fetch data from the server.
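The first caution is easy to demonstrate: no key is needed to read a token's payload. A minimal sketch using only the JDK's Base64 URL decoder (PayloadReader is a hypothetical name, not part of any library):

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;

// Hypothetical helper: reads a JWT payload without any key, as anyone holding the token can
final class PayloadReader {
    static String readPayload(String token) {
        String[] parts = token.split("\\.");
        if (parts.length != 3) {
            throw new IllegalArgumentException("A signed JWT has three dot-separated parts");
        }
        // The payload is Base64URL-encoded JSON; decoding needs no secret at all
        byte[] json = Base64.getUrlDecoder().decode(parts[1]);
        return new String(json, StandardCharsets.UTF_8);
    }
}
```

The signature only proves that the content was not changed; it does not hide the content.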

How to integrate with Spring Security

OK, let's discuss how to use it with Spring Security.

First, take a look at the following picture, which was shown in the previous post.

When we customized the authentication process in that post, we extended UsernamePasswordAuthenticationFilter and replaced it in the SecurityFilterChain. With JWT, we also customize the authentication process with a filter. The difference is that this filter cannot read the username and password directly from the request parameters or body; the user's information is obtained by decoding the token. When there is no token, the filter simply calls the next filter in the SecurityFilterChain. So we do not need ‘form-login’ at all, and the JWT filter alone does the authentication job.

Before we write code, let's see what the JWT filter should do:

The flowchart shows that, when a token exists, we only need to create a UsernamePasswordAuthenticationToken object with the user details and authorities and save it into the SecurityContextHolder. In every case, we then call filterChain.doFilter().
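The first decision in the flowchart, whether a token exists at all, can be isolated in a small helper. This is a sketch under the assumption that the token arrives as ‘Authorization: Bearer <token>’, as in the filter code later in this post; BearerTokens is a hypothetical class, not part of Spring Security.

```java
import java.util.Optional;

// Hypothetical helper: the "does a token exist?" branch of the flowchart
final class BearerTokens {
    private static final String PREFIX = "Bearer ";

    static Optional<String> extract(String authorizationHeader) {
        if (authorizationHeader == null || !authorizationHeader.startsWith(PREFIX)) {
            return Optional.empty();   // no token: the filter just calls filterChain.doFilter()
        }
        String token = authorizationHeader.substring(PREFIX.length()).trim();
        return token.isEmpty() ? Optional.empty() : Optional.of(token);
    }
}
```

Returning Optional keeps the "no token" case explicit, so the filter can fall through to the next filter without special null handling.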

From Spring Security: How Filter Chain works, we know which filters come after our custom filter. When we call filterChain.doFilter(), they are executed one by one until the FilterSecurityInterceptor, which checks the authorities and performs the authorization. If that check passes, the requested API is called.
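To make the "one by one" behavior concrete, here is a toy model in plain Java. ToyFilterChain, Filter and Request are hypothetical stand-ins, not the Servlet or Spring Security types: the JWT filter only records who the user is and always passes the request on, while the authorization filter at the end decides whether the controller runs.

```java
import java.util.ArrayList;
import java.util.List;

// Toy model of a filter chain (hypothetical types, not the Servlet or Spring API)
final class ToyFilterChain {
    // One link of the chain; a filter either calls chain.doFilter(...) or stops the request
    interface Filter {
        void doFilter(Request request, ToyFilterChain chain);
    }

    // Minimal request: Authorization header, authenticated user, and a trace of visited steps
    static final class Request {
        final String authorization;
        String authenticatedUser;   // plays the role of SecurityContextHolder
        final List<String> trace = new ArrayList<>();

        Request(String authorization) {
            this.authorization = authorization;
        }
    }

    private final List<Filter> filters;
    private int position = 0;

    ToyFilterChain(List<Filter> filters) {
        this.filters = filters;
    }

    // Runs the next filter, or the "controller" once every filter has passed the request on
    void doFilter(Request request) {
        if (position < filters.size()) {
            filters.get(position++).doFilter(request, this);
        } else {
            request.trace.add("controller");
        }
    }

    // Stand-in for JWTRequestFilter: sets the user when a bearer token exists, then always continues
    static Filter jwtFilter() {
        return (request, chain) -> {
            request.trace.add("jwt");
            if (request.authorization != null && request.authorization.startsWith("Bearer ")) {
                request.authenticatedUser = "tom";   // pretend the token was valid and belongs to tom
            }
            chain.doFilter(request);
        };
    }

    // Stand-in for FilterSecurityInterceptor: stops the request when nobody is authenticated
    static Filter authorizationFilter() {
        return (request, chain) -> {
            request.trace.add("authorization");
            if (request.authenticatedUser == null) {
                request.trace.add("401");
                return;
            }
            chain.doFilter(request);
        };
    }
}
```

A request with a bearer token visits jwt, authorization and then the controller; a request without one is stopped at the authorization step. This is why the JWT filter never needs to reject a request itself.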

So where should the authentication logic, i.e. checking the username and password, be implemented? In an API method of a controller. That API must not require any permission, because a login request carries no token and therefore no authorities.

OK, let's write some code.

The first class is a utility class that implements all the token operations.

JWTAuthenticationUtil.java

package com.simplejourney.securityjwt.utils;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Component;

import javax.servlet.http.HttpServletRequest;   // used by generateIdent(), added later
import java.util.Base64;                        // used by generateIdent(), added later
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

@Component
public class JWTAuthenticationUtil {
    // jjwt treats a String signing key as Base64-encoded, so jwt.secret must be a Base64 string
    @Value("${jwt.secret}")
    private String secret;

    // Token lifetime in seconds (5 hours)
    public static final long JWT_TOKEN_VALIDITY = 5 * 60 * 60;

    public String getUsername(String token) {
        return getClaimFromToken(token, Claims::getSubject);
    }

    public String generateToken(Map<String, Object> claims, UserDetails userDetails) {
        if (null == claims) {
            claims = new HashMap<>();
        }
        return doGenerateToken(claims, userDetails.getUsername());
    }

    public Boolean validateToken(String token, UserDetails userDetails) {
        final String username = getClaimFromToken(token, Claims::getSubject);
        return (username.equals(userDetails.getUsername()) && !isTokenExpired(token));
    }

    // Note: jjwt throws ExpiredJwtException while parsing an expired token,
    // so for such a token this method throws rather than returning true
    public Boolean isTokenExpired(String token) {
        final Date expiration = getClaimFromToken(token, Claims::getExpiration);
        return expiration.before(new Date());
    }

    private <T> T getClaimFromToken(String token, Function<Claims, T> claimsResolver) {
        final Claims claims = getAllClaimsFromToken(token);
        return claimsResolver.apply(claims);
    }

    // Verifies the signature and the expiration time, then returns the payload
    private Claims getAllClaimsFromToken(String token) {
        return Jwts.parser().setSigningKey(secret).parseClaimsJws(token).getBody();
    }

    private String doGenerateToken(Map<String, Object> claims, String subject) {
        /*
        Registered claim names:
        iss: issuer
        exp: expiration time
        sub: subject
        aud: audience
        nbf: Not Before
        iat: Issued At
        jti: JWT ID
        */
        return Jwts.builder()
                .setClaims(claims)
                .setSubject(subject)                                                              // sub
                .setIssuedAt(new Date(System.currentTimeMillis()))                                // iat
                .setExpiration(new Date(System.currentTimeMillis() + JWT_TOKEN_VALIDITY * 1000))  // exp
                .signWith(SignatureAlgorithm.HS512, secret)
                .compact();
    }
}

Then we can implement the JWT filter.

JWTRequestFilter.java

package com.simplejourney.securityjwt.config;

import com.simplejourney.securityjwt.services.JWTTokenManageService;
import com.simplejourney.securityjwt.services.impl.UserDetailsServiceImpl;
import com.simplejourney.securityjwt.utils.JWTAuthenticationUtil;
import io.jsonwebtoken.JwtException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

@Component
public class JWTRequestFilter extends OncePerRequestFilter {

    @Autowired
    private UserDetailsServiceImpl userDetailService;

    @Autowired
    private JWTAuthenticationUtil jwtAuthenticationUtil;

    // Used in the "Make Token Invalid Immediately" section
    @Autowired
    private JWTTokenManageService tokenManageService;

    @Override
    protected void doFilterInternal(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, FilterChain filterChain) throws ServletException, IOException {
        String authorization = httpServletRequest.getHeader("Authorization");

        String username = null;
        String token = null;

        // The token is expected in the header "Authorization: Bearer <token>"
        if (null != authorization && authorization.startsWith("Bearer ")) {
            token = authorization.substring(7);

            try {
                // Parsing also verifies the signature and the expiration time
                username = jwtAuthenticationUtil.getUsername(token);
            } catch (IllegalArgumentException ex) {
                System.err.println("IllegalArgumentException: ");
                ex.printStackTrace();
            } catch (JwtException ex) {
                System.err.println("JwtException: ");
                ex.printStackTrace();
            }
        } else {
            System.err.println("Token not found or does not start with 'Bearer '");
        }

        // Authenticate only if the token yielded a user and nobody is authenticated yet
        if (null != username && null == SecurityContextHolder.getContext().getAuthentication()) {
            UserDetails userDetails = this.userDetailService.loadUserByUsername(username);

            if (jwtAuthenticationUtil.validateToken(token, userDetails)) {
                UsernamePasswordAuthenticationToken authenticationToken =
                        new UsernamePasswordAuthenticationToken(userDetails, null, userDetails.getAuthorities());
                authenticationToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(httpServletRequest));
                SecurityContextHolder.getContext().setAuthentication(authenticationToken);
            }
        }

        // Always continue the chain; FilterSecurityInterceptor decides whether access is allowed
        filterChain.doFilter(httpServletRequest, httpServletResponse);
    }
}

Add WebSecurityConfig

package com.simplejourney.securityjwt.config;

import com.simplejourney.securityjwt.services.impl.UserDetailsServiceImpl;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
import org.springframework.security.config.annotation.method.configuration.EnableGlobalMethodSecurity;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;

@Configuration
@EnableWebSecurity
@EnableGlobalMethodSecurity(prePostEnabled = true)
public class WebSecurityConfig extends WebSecurityConfigurerAdapter {
    @Autowired
    private UserDetailsServiceImpl userDetailsService;

    @Autowired
    private UserAccessDeniedHandler accessDeniedHandler;

    @Autowired
    private DemoAuthenticationEntryPoint authenticationEntryPoint;

    @Autowired
    private JWTRequestFilter requestFilter;

    @Autowired
    public void configureGlobal(AuthenticationManagerBuilder authenticationManagerBuilder) throws Exception {
        authenticationManagerBuilder.userDetailsService(userDetailsService).passwordEncoder(passwordEncoder());
    }

    @Bean
    public PasswordEncoder passwordEncoder() {
        return new BCryptPasswordEncoder();
    }

    // Expose the AuthenticationManager as a bean so the login API can inject it
    @Bean
    @Override
    public AuthenticationManager authenticationManagerBean() throws Exception {
        return super.authenticationManagerBean();
    }

    @Override
    protected void configure(HttpSecurity http) throws Exception {
        http
            // Disable Cross-Site Request Forgery protection (no session cookie is used)
            .csrf().disable()
            // Do not apply Spring Security's Cross-Origin Resource Sharing support
            .cors().disable()

            // Resources which need no authorization
            .authorizeRequests()
            .antMatchers("/hello", "/auth/login").permitAll()

            // All other resources need authentication
            .anyRequest().authenticated()

            // Exception handling
            .and().exceptionHandling()
            // called when an unauthenticated request needs authentication
            .authenticationEntryPoint(authenticationEntryPoint)
            // called when an authenticated user lacks the required authority
            .accessDeniedHandler(accessDeniedHandler)

            // Session: never create or use an HttpSession, every request carries its JWT
            .and().sessionManagement()
            .sessionCreationPolicy(SessionCreationPolicy.STATELESS);

        // Add the JWT filter before the authentication filter position
        http.addFilterBefore(requestFilter, UsernamePasswordAuthenticationFilter.class);
    }
}

After that, we implement the authentication in the login API.

AuthenticationController.java

package com.simplejourney.securityjwt.controllers;

import com.simplejourney.securityjwt.dto.LoginData;
import com.simplejourney.securityjwt.services.JWTTokenManageService;
import com.simplejourney.securityjwt.services.impl.UserDetailsServiceImpl;
import com.simplejourney.securityjwt.utils.JWTAuthenticationUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.DisabledException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.web.bind.annotation.*;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

@RestController
@CrossOrigin
@RequestMapping("/auth")
public class AuthenticationController {
    @Autowired
    private AuthenticationManager authenticationManager;

    @Autowired
    private UserDetailsServiceImpl userDetailsService;

    @Autowired
    private JWTTokenManageService tokenManageService;

    @Autowired
    private JWTAuthenticationUtil jwtAuthenticationUtil;

    @PostMapping("/login")
    public ResponseEntity<String> login(HttpServletRequest request, @RequestBody LoginData data) throws Exception {
        // Check the username and password with the configured UserDetailsService and PasswordEncoder
        try {
            this.authenticationManager.authenticate(new UsernamePasswordAuthenticationToken(data.getUsername(), data.getPassword()));
        } catch (DisabledException ex) {
            throw new Exception("USER_DISABLED", ex);
        } catch (BadCredentialsException ex) {
            throw new Exception("INVALID_CREDENTIALS", ex);
        }

        // Credentials are valid: issue a signed token for this user
        final UserDetails userDetails = this.userDetailsService.loadUserByUsername(data.getUsername());
        final String token = jwtAuthenticationUtil.generateToken(null, userDetails);

        return ResponseEntity.ok(token);
    }
}

OK, everything we need is in place, so let's test it. (For the rest of the project code, please see the demo project on GitHub.)

For this project, we can use Swagger UI to test the APIs.

Add the following dependencies to ‘build.gradle’:

implementation 'io.springfox:springfox-swagger2:2.4.0'
implementation 'io.springfox:springfox-swagger-ui:2.4.0'

And add a SwaggerUI class to the same package as DemoApplication:

package com.simplejourney.securityjwt;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import springfox.documentation.builders.ApiInfoBuilder;
import springfox.documentation.builders.ParameterBuilder;
import springfox.documentation.builders.PathSelectors;
import springfox.documentation.builders.RequestHandlerSelectors;
import springfox.documentation.schema.ModelRef;
import springfox.documentation.service.ApiInfo;
import springfox.documentation.service.Parameter;
import springfox.documentation.spi.DocumentationType;
import springfox.documentation.spring.web.plugins.Docket;
import springfox.documentation.swagger2.annotations.EnableSwagger2;

import java.util.ArrayList;
import java.util.List;

@Configuration
@EnableSwagger2
public class SwaggerUI {
    @Bean
    public Docket createRestApi() {
        // Optional "Authorization" header field on every operation, used to send the JWT
        ParameterBuilder tokenPar = new ParameterBuilder();
        List<Parameter> pars = new ArrayList<>();

        tokenPar.name("Authorization").description("Token").modelRef(new ModelRef("string")).parameterType("header").required(false);
        pars.add(tokenPar.build());

        return new Docket(DocumentationType.SWAGGER_2)
                .select()
                .apis(RequestHandlerSelectors.basePackage("com.simplejourney.securityjwt.controllers"))
                .paths(PathSelectors.any())
                .build()
                .globalOperationParameters(pars)
                .apiInfo(apiInfo());
    }

    private ApiInfo apiInfo() {
        return new ApiInfoBuilder().title("Simple APIs")
                .description("simple apis")
                .version("1.0")
                .build();
    }
}

And add the following configuration after ‘authorizeRequests()’ in the WebSecurityConfig class:

// Swagger UI resources need no authorization
.antMatchers(
        "/v2/api-docs",
        "/swagger-resources",
        "/swagger-resources/**",
        "/configuration/ui",
        "/configuration/security",
        "/swagger-ui.html/**",
        "/webjars/**"
).permitAll()

Then launch the application, open your web browser, and enter the URL

http://localhost:8080/swagger-ui.html

You should see a page like the following:

Now expand the ‘POST /auth/login’ pane, enter the username and password in ‘data’, and click the ‘Try it out!’ button.

If it succeeds, you can find the token in the ‘Response Body’ section, as shown below:

This is just one example of using Swagger UI; you can test the other APIs yourself in the same way.
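If you prefer testing from code instead of Swagger UI, the same calls can be built with the JDK's java.net.http client (Java 11+). This sketch assumes the application listens on http://localhost:8080 and that LoginData is posted as JSON with ‘username’ and ‘password’ fields; ApiRequests is a hypothetical helper.

```java
import java.net.URI;
import java.net.http.HttpRequest;

// Hypothetical client-side helper that builds the requests Swagger UI would send
final class ApiRequests {
    // Assumption: the demo application runs locally on port 8080
    private static final String BASE = "http://localhost:8080";

    static HttpRequest login(String username, String password) {
        // Naive JSON building: only suitable for values without quotes or backslashes
        String body = String.format("{\"username\":\"%s\",\"password\":\"%s\"}", username, password);
        return HttpRequest.newBuilder(URI.create(BASE + "/auth/login"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
    }

    // Calls a protected API with the token returned by the login API
    static HttpRequest books(String token) {
        return HttpRequest.newBuilder(URI.create(BASE + "/book"))
                .header("Authorization", "Bearer " + token)
                .GET()
                .build();
    }
}
```

To send them, pass each request to HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString()); the body of the login response is the token to hand to books(...).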

Make Token Invalid Immediately

In the What is JWT section, we saw that there are some points to be careful about when using JWT. One of them is that a JWT cannot be invalidated before it expires.

We can reproduce this in the project's Swagger UI with the following steps:

  • Step 1: Login with tom, and copy the token from response
  • Step 2: Logout
  • Step 3: Set the ‘Authorization’ header of the ‘GET /book’ API to ‘Bearer ’ + the copied token, and send the request

You will find that you still receive the book list.

But some requirements need a token to become invalid immediately, for example when we have to kick a user offline. How can we achieve that?

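One solution is to use Redis or a database to cache an identifier for each issued token and manage these identifiers. When a token is generated, we put its identifier into the cache; when the token expires, or when we need to invalidate it, we remove the identifier from the cache. A request whose identifier is not in the cache is treated as unauthenticated, even if its token is still valid.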
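Here is a sketch of such a cache, under the assumption that we also record when each token expires so that stale identifiers disappear on their own. IdentStore is a hypothetical class; the JWTTokenManageService used in this post is simpler and keeps entries until they are removed explicitly. With Redis, the same effect can be had by setting a TTL on the key.

```java
import java.time.Clock;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical in-memory store: an identifier counts as "logged in" only while
// it is present and the token it belongs to has not expired
final class IdentStore {
    private final Map<String, Instant> expiries = new ConcurrentHashMap<>();
    private final Clock clock;

    IdentStore(Clock clock) {
        this.clock = clock;
    }

    // At login: remember the identifier together with the token's expiration time
    void add(String ident, Instant expiresAt) {
        expiries.put(ident, expiresAt);
    }

    // At logout, or to kick a user offline
    void remove(String ident) {
        expiries.remove(ident);
    }

    // True only for cached, unexpired identifiers; expired entries are dropped on access
    boolean has(String ident) {
        Instant expiresAt = expiries.get(ident);
        if (expiresAt == null) {
            return false;
        }
        if (!clock.instant().isBefore(expiresAt)) {
            expiries.remove(ident, expiresAt);
            return false;
        }
        return true;
    }
}
```

Injecting a Clock keeps the expiry logic testable; in production, Clock.systemUTC() would be passed in.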

The next question is how to generate the identifier. There are two pieces of data we can use.

The first is the user's IP address. When the application runs behind a reverse proxy, the proxy has to pass the client address on. For NGINX, add

proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;

to the NGINX configuration.

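For Apache, mod_proxy already adds X-Forwarded-For by itself when it acts as a reverse proxy (for example with ProxyPass). A directive such as

RequestHeader set X-Forwarded-Proto "http"

only sets X-Forwarded-Proto, i.e. the original request scheme, and is not needed to obtain the client address.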

Then we can read it from the request:

request.getHeader("X-FORWARDED-FOR")
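Keep in mind that X-Forwarded-For can hold a comma-separated list. NGINX's $proxy_add_x_forwarded_for appends the address it saw to whatever value the client sent, so the entries on the left are client-controlled and can be forged. The sketch below assumes exactly one trusted proxy and therefore takes the right-most entry, falling back to the socket address; ClientIp is a hypothetical helper.

```java
// Hypothetical helper: picks the client address appended by our own (single, trusted) proxy
final class ClientIp {
    static String fromForwardedFor(String headerValue, String remoteAddr) {
        if (headerValue == null || headerValue.isBlank()) {
            return remoteAddr;   // no proxy header: use request.getRemoteAddr()
        }
        // Limit -1 keeps trailing empty entries, so a malformed value cannot yield an empty array
        String[] entries = headerValue.split(",", -1);
        String last = entries[entries.length - 1].trim();
        return last.isEmpty() ? remoteAddr : last;
    }
}
```

With more than one proxy in front of the application, you would skip one entry from the right for each additional trusted proxy instead.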

The second is the ‘User-Agent’, which is also a request header.

request.getHeader("User-Agent")

We can also combine both of them.

In my example, I combine both and encode the result with Base64. You can also hash them with MD5 or another hash algorithm instead.
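Base64 is reversible, and since the identifier ends up as a claim in the token, anyone holding the token can decode the user's IP address and User-Agent from it. Hashing hides the raw values (though a short input such as an IPv4 address can still be guessed by brute force) and gives a fixed-length value. A sketch with SHA-256 from java.security.MessageDigest; IdentHasher is a hypothetical name, and the project itself uses Base64:

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

// Hypothetical alternative to the Base64 identifier: a SHA-256 hash of "ip;userAgent"
final class IdentHasher {
    static String ident(String ip, String userAgent) {
        String source = String.format("%s;%s", ip, userAgent);
        try {
            byte[] digest = MessageDigest.getInstance("SHA-256")
                    .digest(source.getBytes(StandardCharsets.UTF_8));
            // Hex-encode: 32 bytes become 64 lowercase hex characters
            StringBuilder hex = new StringBuilder(digest.length * 2);
            for (byte b : digest) {
                hex.append(String.format("%02x", b));
            }
            return hex.toString();
        } catch (NoSuchAlgorithmException ex) {
            // Every Java platform is required to support SHA-256
            throw new IllegalStateException(ex);
        }
    }
}
```

The same client always produces the same identifier, so lookups in the cache work exactly as with the Base64 version.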

Now let's modify our code to add this solution.

First, we create a service that manages the identifiers in the cache. In this post, we simply use memory as the cache.

JWTTokenManageService.java

package com.simplejourney.securityjwt.services;

public interface JWTTokenManageService {
    void add(String ident, Object token);
    void remove(String ident);
    boolean has(String ident);
}

JWTTokenManageServiceImpl.java

package com.simplejourney.securityjwt.services.impl;

import com.simplejourney.securityjwt.services.JWTTokenManageService;
import org.springframework.stereotype.Service;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Service
public class JWTTokenManageServiceImpl implements JWTTokenManageService {
    // ConcurrentHashMap, because many request threads read and write it at the same time
    private final Map<String, Object> tokens = new ConcurrentHashMap<>();

    @Override
    public void add(String ident, Object token) {
        this.tokens.put(ident, token);
    }

    @Override
    public void remove(String ident) {
        this.tokens.remove(ident);
    }

    @Override
    public boolean has(String ident) {
        return this.tokens.containsKey(ident);
    }
}

Next, we add the code that generates the identifier, as the following methods in JWTAuthenticationUtil:

// Identifier of the requesting client: "<ip>;<user agent>", Base64-encoded
public String generateIdent(HttpServletRequest request) {
    String identSource = String.format("%s;%s", getIPAddress(request), getUserAgent(request));
    return Base64.getEncoder().encodeToString(identSource.getBytes());
}

// Reads the "ident" claim stored in the token at login
public String getIdent(String token) {
    final Claims claims = getAllClaimsFromToken(token);
    return String.valueOf(claims.get("ident"));
}

private String getIPAddress(HttpServletRequest request) {
    String ip = "";
    if (null != request) {
        ip = request.getHeader("X-FORWARDED-FOR");
        if (null == ip || ip.isEmpty()) {
            // Not behind a proxy: use the address of the connection
            ip = request.getRemoteAddr();
        }
    }
    return ip;
}

private String getUserAgent(HttpServletRequest request) {
    String userAgent = "";
    if (null != request) {
        userAgent = request.getHeader("User-Agent");
    }
    return userAgent;
}

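After that, we modify the ‘login’ API to add the identifier to the cache when the token is created. If the identifier is already cached, this client is still logged in, and the API answers 409 Conflict instead of issuing a second token: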

// Additional imports needed: java.util.HashMap, java.util.Map, org.springframework.http.HttpStatus
@PostMapping("/login")
public ResponseEntity<String> login(HttpServletRequest request, @RequestBody LoginData data) throws Exception {
    try {
        this.authenticationManager.authenticate(new UsernamePasswordAuthenticationToken(data.getUsername(), data.getPassword()));
    } catch (DisabledException ex) {
        throw new Exception("USER_DISABLED", ex);
    } catch (BadCredentialsException ex) {
        throw new Exception("INVALID_CREDENTIALS", ex);
    }

    String ident = jwtAuthenticationUtil.generateIdent(request);
    if (!tokenManageService.has(ident)) {
        // Store the identifier in the token so the filter can compare it with the requesting client
        Map<String, Object> claims = new HashMap<>();
        claims.put("ident", ident);

        final UserDetails userDetails = this.userDetailsService.loadUserByUsername(data.getUsername());
        final String token = jwtAuthenticationUtil.generateToken(claims, userDetails);

        tokenManageService.add(ident, token);

        return ResponseEntity.ok(token);
    } else {
        // This client already holds a valid token
        return ResponseEntity.status(HttpStatus.CONFLICT).build();
    }
}

And add a logout API that removes the identifier from the cache when the user logs out:

@GetMapping("/logout")
public ResponseEntity<String> logout(HttpServletRequest request, HttpServletResponse response) {
    // Without its identifier in the cache, the client's token is no longer accepted by the filter
    String ident = jwtAuthenticationUtil.generateIdent(request);
    if (tokenManageService.has(ident)) {
        tokenManageService.remove(ident);
    }

    return ResponseEntity.ok("Logout Succeed");
}

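Finally, modify the ‘doFilterInternal’ method of JWTRequestFilter so that it only authenticates requests whose identifier is cached, and removes the identifier when the token has expired: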

// Requires one more import in JWTRequestFilter: io.jsonwebtoken.ExpiredJwtException
@Override
protected void doFilterInternal(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, FilterChain filterChain) throws ServletException, IOException {
    String authorization = httpServletRequest.getHeader("Authorization");
    // Identifier of the requesting client (IP address + User-Agent)
    String ident = jwtAuthenticationUtil.generateIdent(httpServletRequest);

    String username = null;
    String token = null;

    if (null != authorization && authorization.startsWith("Bearer ")) {
        token = authorization.substring(7);

        try {
            username = jwtAuthenticationUtil.getUsername(token);
        } catch (ExpiredJwtException ex) {
            // jjwt rejects an expired token while parsing it, so an isTokenExpired() check
            // after parsing is never reached: drop the identifier here instead
            tokenManageService.remove(ident);
        } catch (IllegalArgumentException | JwtException ex) {
            System.err.println("Invalid token: " + ex.getMessage());
        }
    } else {
        System.err.println("Token not found or does not start with 'Bearer '");
    }

    // Authenticate only while the identifier is cached, i.e. the user has not logged out
    if (null != username
            && null == SecurityContextHolder.getContext().getAuthentication()
            && tokenManageService.has(ident)) {
        UserDetails userDetails = this.userDetailService.loadUserByUsername(username);

        // The token must have been issued to this client and belong to this user
        if (ident.equals(jwtAuthenticationUtil.getIdent(token)) && jwtAuthenticationUtil.validateToken(token, userDetails)) {
            UsernamePasswordAuthenticationToken authenticationToken =
                    new UsernamePasswordAuthenticationToken(userDetails, null, userDetails.getAuthorities());
            authenticationToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(httpServletRequest));
            SecurityContextHolder.getContext().setAuthentication(authenticationToken);
        }
    }

    filterChain.doFilter(httpServletRequest, httpServletResponse);
}

Done. Rebuild and relaunch the application, then try the test case above again. After logging out, the response should be:

{
  "timestamp": "2020-05-15T05:53:51.257+0000",
  "status": 401,
  "error": "Unauthorized",
  "message": "Unauthorized",
  "path": "/book"
}

OK, that's everything I wanted to cover in this post. If you have any questions or suggestions, please feel free to open an issue.

You can download the demo project from GitHub.

References