I've had a programming problem that's been bugging me for a while. I'm a Java developer who's a fan of both Spring and unit testing. So whenever I create a data access object (DAO) for persisting objects, I like to unit test it to make sure that I've coded everything properly.
One of the special problems with database testing, especially for those cases where you share a relational database with others, is relying on data being present to make your tests pass. What happens if the data that made my tests run at 100% success is removed by someone else?
One solution is to use an isolated database with the identical schema that is completely under your control. This may not be practical for large schemas.
Another solution is to make your tests transactional: open a transaction, populate the database with known data, run your test, and roll the transaction back. That way you're operating on the real, live schema, seeded with data you can rely on. The rollback wipes out your footprints and makes it look like you never modified the database at all.
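Stripped of any framework, the pattern looks roughly like this in plain JDBC. This is only a sketch: RollbackOnly and SqlWork are names invented for illustration, not Spring classes, and it assumes the database and driver honor transactions on the connection that is passed in.

```java
import java.sql.Connection;
import java.sql.SQLException;

/** Hypothetical helper, not part of Spring: runs work in a transaction that is always rolled back. */
final class RollbackOnly
{
    /** Hypothetical callback type for the body of a test. */
    interface SqlWork
    {
        void run(Connection connection) throws SQLException;
    }

    private RollbackOnly()
    {
    }

    static void run(Connection connection, SqlWork work) throws SQLException
    {
        boolean previousAutoCommit = connection.getAutoCommit();
        connection.setAutoCommit(false);      // open a transaction
        try
        {
            work.run(connection);             // seed data and run assertions here
        }
        finally
        {
            connection.rollback();            // undo everything the work did
            connection.setAutoCommit(previousAutoCommit);
        }
    }
}
```

A test body would go in an SqlWork implementation: insert the seed rows, run the queries under test, and assert on the results, all on the connection passed in. The rollback happens even when an assertion fails, so a failing test leaves no rows behind either.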
Spring provides base classes to accomplish exactly this: one for JUnit version 4 and another for TestNG.
I've become a fan of TestNG, but I'm unhappy to report that I couldn't make this ideal situation work for TestNG. I went back to the Spring reference docs in frustration and started again with JUnit version 4. Section 8.3.7.4, "Transaction management", lays it out perfectly. My tests were 100% successful. If I stopped in a debugger and looked at the database, I could see the seed data rows. When the test completed, the table was rolled back to its undisturbed state, as expected.
It should have been as simple as exchanging a single JUnit annotation for its TestNG equivalent, but autowiring of beans wasn't working as it should. When I tried to inject the bean manually from the application context I had another problem. I'll have to dig into this a bit more to see if I can make TestNG work with Spring.
I used a simple model object Product:
package tutorial.model;

import java.io.Serializable;
import java.text.DecimalFormat;

public class Product implements Serializable
{
    private Integer id;
    private String name;
    private double price;
    private int quantity;

    public Product(String name, double price, int quantity)
    {
        this(null, name, price, quantity);
    }

    public Product(Integer id, String name, double price, int quantity)
    {
        this.id = id;
        this.name = name;
        this.price = price;
        this.quantity = quantity;
    }

    public Integer getId()
    {
        return id;
    }

    public void setId(Integer id)
    {
        this.id = id;
    }

    public String getName()
    {
        return name;
    }

    public void setName(String name)
    {
        this.name = name;
    }

    public double getPrice()
    {
        return price;
    }

    public void setPrice(double price)
    {
        this.price = price;
    }

    public int getQuantity()
    {
        return quantity;
    }

    public void setQuantity(int quantity)
    {
        this.quantity = quantity;
    }

    @Override
    public boolean equals(Object o)
    {
        if (this == o)
        {
            return true;
        }
        if (o == null || getClass() != o.getClass())
        {
            return false;
        }

        Product product = (Product) o;

        if (Double.compare(product.price, price) != 0)
        {
            return false;
        }
        if (quantity != product.quantity)
        {
            return false;
        }
        if (id != null ? !id.equals(product.id) : product.id != null)
        {
            return false;
        }
        if (name != null ? !name.equals(product.name) : product.name != null)
        {
            return false;
        }

        return true;
    }

    @Override
    public int hashCode()
    {
        int result;
        long temp;
        result = id != null ? id.hashCode() : 0;
        result = 31 * result + (name != null ? name.hashCode() : 0);
        temp = price != +0.0d ? Double.doubleToLongBits(price) : 0L;
        result = 31 * result + (int) (temp ^ (temp >>> 32));
        result = 31 * result + quantity;
        return result;
    }

    @Override
    public String toString()
    {
        return "Product{" +
               "id=" + id +
               ", name='" + name + '\'' +
               ", price=" + DecimalFormat.getNumberInstance().format(price) +
               ", quantity=" + quantity +
               '}';
    }
}
There's a ProductDao interface:
package tutorial.persistence;

import tutorial.model.Product;

import java.util.List;

public interface ProductDao
{
    List<Product> find();

    Product find(Integer id);

    List<Product> find(String name);

    List<Product> find(double minPrice, double maxPrice);

    void save(Product product);

    void update(Product product);

    void delete(Product product);

    void delete();
}
The ProductDaoImpl uses Spring JDBC:
package tutorial.persistence.jdbc;

import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.simple.SimpleJdbcDaoSupport;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import org.springframework.stereotype.Repository;
import tutorial.model.Product;
import tutorial.persistence.ProductDao;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@Repository
public class ProductDaoImpl extends SimpleJdbcDaoSupport implements ProductDao
{
    public static final String BASE_SELECT = "select id, name, price, quantity from product ";
    public static final String FIND_ALL = BASE_SELECT + " order by id ";
    public static final String FIND_BY_ID = BASE_SELECT + " where id = ? ";
    public static final String FIND_BY_NAME = BASE_SELECT + " where name = ? ";
    public static final String FIND_BY_PRICE_RANGE = BASE_SELECT + " where price between ? and ? ";
    public static final String INSERT_SQL = "insert into product(name, price, quantity) values(?,?,?)";
    public static final String UPDATE_SQL = "update product set name = ?, price = ?, quantity = ? where id = ?";
    public static final String DELETE_ALL_SQL = "delete from product ";
    public static final String DELETE_BY_ID = DELETE_ALL_SQL + " where id = ?";

    public List<Product> find()
    {
        ProductRowMapper productRowMapper = new ProductRowMapper();
        return getSimpleJdbcTemplate().query(FIND_ALL, productRowMapper);
    }

    public Product find(Integer id)
    {
        return this.getSimpleJdbcTemplate().queryForObject(FIND_BY_ID, new ProductRowMapper(), id);
    }

    public List<Product> find(String name)
    {
        return this.getSimpleJdbcTemplate().query(FIND_BY_NAME, new ProductRowMapper(), name);
    }

    public List<Product> find(double minPrice, double maxPrice)
    {
        return this.getSimpleJdbcTemplate().query(FIND_BY_PRICE_RANGE, new ProductRowMapper(), minPrice, maxPrice);
    }

    public void save(final Product product)
    {
        // Ask the driver to return the generated "id" column so it can be copied back onto the product.
        KeyHolder keyHolder = new GeneratedKeyHolder();
        this.getJdbcTemplate().update(new PreparedStatementCreator()
        {
            public PreparedStatement createPreparedStatement(Connection connection) throws SQLException
            {
                PreparedStatement ps = connection.prepareStatement(INSERT_SQL, new String[] { "id" });
                ps.setString(1, product.getName());
                ps.setDouble(2, product.getPrice());
                ps.setInt(3, product.getQuantity());
                return ps;
            }
        }, keyHolder);
        product.setId(keyHolder.getKey().intValue());
    }

    public void update(final Product product)
    {
        this.getSimpleJdbcTemplate().update(UPDATE_SQL, product.getName(), product.getPrice(), product.getQuantity(), product.getId());
    }

    public void delete(final Product product)
    {
        this.getSimpleJdbcTemplate().update(DELETE_BY_ID, product.getId());
    }

    public void delete()
    {
        this.getSimpleJdbcTemplate().update(DELETE_ALL_SQL);
    }

    // Unused helper. The SQL constants above use positional '?' placeholders, so any
    // statement passed here would need named parameters (:id, :name, ...) instead.
    private void update(String sql, final Product product)
    {
        Map<String, Object> parameters = new HashMap<String, Object>();
        parameters.put("id", product.getId());
        parameters.put("name", product.getName());
        parameters.put("price", product.getPrice());
        parameters.put("quantity", product.getQuantity());
        this.getSimpleJdbcTemplate().update(sql, parameters);
    }
}
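The ProductRowMapper used by the finder methods isn't shown in this post. Here is a minimal, dependency-free sketch of the mapping it presumably performs, using only java.sql. ProductRowSketch and Row are made-up names for this example; the assumption is that the real class sits in tutorial.persistence.jdbc, implements Spring's row-mapper interface, and builds a tutorial.model.Product from the four columns named in BASE_SELECT.

```java
import java.sql.ResultSet;
import java.sql.SQLException;

/** Sketch of the row-mapping step; Row is a stand-in for tutorial.model.Product. */
final class ProductRowSketch
{
    /** Minimal stand-in so the sketch compiles on its own. */
    static final class Row
    {
        final Integer id;
        final String name;
        final double price;
        final int quantity;

        Row(Integer id, String name, double price, int quantity)
        {
            this.id = id;
            this.name = name;
            this.price = price;
            this.quantity = quantity;
        }
    }

    private ProductRowSketch()
    {
    }

    /** Reads the four selected columns from the current row; rowNum is unused here. */
    static Row mapRow(ResultSet rs, int rowNum) throws SQLException
    {
        return new Row(
                rs.getInt("id"),
                rs.getString("name"),
                rs.getDouble("price"),
                rs.getInt("quantity"));
    }
}
```

In the real mapper the return statement would construct a Product with the four-argument constructor shown earlier, so that objects read back from the database compare equal to the ones that were saved.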
The Spring application context uses the DataSourceTransactionManager:
<?xml version="1.0" encoding="UTF-8"?>
<beans xmlns="http://www.springframework.org/schema/beans"
       xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
       xmlns:aop="http://www.springframework.org/schema/aop"
       xmlns:tx="http://www.springframework.org/schema/tx"
       xsi:schemaLocation="
           http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans-2.5.xsd
           http://www.springframework.org/schema/aop http://www.springframework.org/schema/aop/spring-aop-2.5.xsd
           http://www.springframework.org/schema/tx http://www.springframework.org/schema/tx/spring-tx-2.5.xsd">

    <tx:annotation-driven transaction-manager="txManager"/>

    <bean id="dataSourceProperties" class="org.springframework.beans.factory.config.PreferencesPlaceholderConfigurer">
        <property name="location" value="classpath:product-datasource.properties"/>
    </bean>

    <bean id="dataSource" class="org.springframework.jdbc.datasource.DriverManagerDataSource">
        <property name="driverClassName" value="${datasource.driver}"/>
        <property name="url" value="${datasource.url}"/>
        <property name="username" value="${datasource.username}"/>
        <property name="password" value="${datasource.password}"/>
    </bean>

    <bean id="productDao" class="tutorial.persistence.jdbc.ProductDaoImpl">
        <property name="dataSource" ref="dataSource"/>
    </bean>

    <bean id="txManager" class="org.springframework.jdbc.datasource.DataSourceTransactionManager">
        <property name="dataSource" ref="dataSource"/>
    </bean>
</beans>
The Spring transactional JUnit 4 unit test has all the annotations from Chapter 8 of the reference manual:
package tutorial.persistence;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.test.annotation.Rollback;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.transaction.AfterTransaction;
import org.springframework.test.context.transaction.BeforeTransaction;
import org.springframework.test.context.transaction.TransactionConfiguration;
import org.springframework.transaction.annotation.Transactional;
import tutorial.model.Product;

import java.util.List;

// JUnit assertions are used rather than the assert keyword, which the JVM
// skips unless it is started with -ea.
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(locations = { "file:resources/product-*.xml" })
@Transactional
@TransactionConfiguration(transactionManager="txManager", defaultRollback=true)
public class ProductDaoTest extends AbstractTransactionalJUnit4SpringContextTests
{
    @Autowired
    private ProductDao productDao;

    Product [] testProducts =
    {
        new Product("Dell", 1000.0, 100),
        new Product("HP", 2000.0, 200),
        new Product("Cisco", 3000.0, 300),
        new Product("Microsoft", 4000.0, 400),
    };

    private static final double TOLERANCE = 1.0E-8;

    // Runs outside the transaction: the product table must start out empty.
    @BeforeTransaction
    public void verifyInitialDatabaseState()
    {
        List<Product> products = this.productDao.find();
        assertNotNull(products);
        assertEquals(0, products.size());
    }

    // Runs inside the transaction, so the seed rows are rolled back with the test.
    @Before
    public void populateDatabase()
    {
        for (Product product : testProducts)
        {
            productDao.save(product);
        }
    }

    @Test
    @Rollback(true)
    public void testFindAll()
    {
        List<Product> actual = this.productDao.find();
        assertNotNull(actual);
        assertEquals(testProducts.length, actual.size());
        for (Product product : testProducts)
        {
            assertTrue(actual.contains(product));
        }
    }

    @Test
    @Rollback(true)
    public void testById()
    {
        List<Product> actual = this.productDao.find();
        assertNotNull(actual);
        assertEquals(testProducts.length, actual.size());
        for (Product product : testProducts)
        {
            Product byId = productDao.find(product.getId());
            assertEquals(product, byId);
        }
    }

    @Test
    @Rollback(true)
    public void testFindByName()
    {
        List<Product> actual = this.productDao.find();
        assertNotNull(actual);
        assertEquals(testProducts.length, actual.size());
        for (Product product : testProducts)
        {
            List<Product> byName = productDao.find(product.getName());
            assertNotNull(byName);
            assertEquals(1, byName.size());
            assertEquals(product, byName.get(0));
        }
    }

    @Test
    @Rollback(true)
    public void testFindByPriceRange()
    {
        List<Product> actual = this.productDao.find();
        assertNotNull(actual);
        assertEquals(testProducts.length, actual.size());
        for (Product product : testProducts)
        {
            double minPrice = product.getPrice() - 10.0;
            double maxPrice = product.getPrice() + 10.0;
            List<Product> byPriceRange = productDao.find(minPrice, maxPrice);
            assertNotNull(byPriceRange);
            assertEquals(1, byPriceRange.size());
            assertEquals(product, byPriceRange.get(0));
        }
    }

    @Test
    @Rollback(true)
    public void testUpdate()
    {
        List<Product> actual = this.productDao.find();
        assertNotNull(actual);
        assertEquals(testProducts.length, actual.size());
        double priceIncrease = 1000.0;
        for (Product product : testProducts)
        {
            double oldPrice = product.getPrice();
            product.setPrice(oldPrice + priceIncrease);
            productDao.update(product);
            Product byId = productDao.find(product.getId());
            assertEquals(oldPrice + priceIncrease, byId.getPrice(), TOLERANCE);
        }
    }

    @Test
    @Rollback(true)
    public void testDelete()
    {
        List<Product> before = this.productDao.find();
        assertNotNull(before);
        assertEquals(testProducts.length, before.size());
        productDao.delete(testProducts[0]);
        List<Product> after = this.productDao.find();
        assertNotNull(after);
        assertEquals(before.size() - 1, after.size());
        assertFalse(after.contains(testProducts[0]));
    }

    // Runs after the rollback: the table must be empty again.
    @AfterTransaction
    public void verifyFinalDatabaseState()
    {
        List<Product> products = this.productDao.find();
        assertNotNull(products);
        assertEquals(0, products.size());
    }
}
Since all the transactional annotations are Spring, I thought that switching from JUnit 4 to TestNG would be as simple as the following three steps:
- Remove the @RunWith annotation calling the JUnit 4 runner
- Switch the base class
- Replace the JUnit 4 @Before annotation with what I took to be its closest TestNG equivalent (@BeforeSuite)
Unfortunately, there's some autowiring magic that's lost in the translation. I get a NullPointerException for the ProductDao reference in the populateDatabase method. When I added code to inject the bean from the application context it failed as well.
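One possible explanation, offered as an unverified assumption rather than a diagnosis: TestNG runs @BeforeSuite methods once, before any class-level setup, while Spring's TestNG base classes appear to inject dependencies in a class-level callback. If that is right, productDao would still be null when populateDatabase runs, and TestNG's per-test @BeforeMethod would be the closer match for JUnit's @Before. The toy model below only illustrates that ordering; LifecycleModel and its methods are invented names, not TestNG or Spring APIs.

```java
import java.util.ArrayList;
import java.util.List;

/** Toy model of hook ordering. None of these names are TestNG or Spring APIs. */
final class LifecycleModel
{
    private Object dao;                                   // stands in for the autowired ProductDao
    private final List<String> log = new ArrayList<String>();

    /** Stands in for a suite-level hook, assumed to run before injection. */
    void suiteHook()
    {
        log.add("suite hook sees dao: " + (dao != null));
    }

    /** Stands in for the class-level callback that injects dependencies. */
    void injectDependencies()
    {
        dao = new Object();
    }

    /** Stands in for a per-test hook, which runs after injection. */
    void methodHook()
    {
        log.add("method hook sees dao: " + (dao != null));
    }

    /** Plays the three phases in the assumed order and returns what each hook observed. */
    static List<String> run()
    {
        LifecycleModel model = new LifecycleModel();
        model.suiteHook();
        model.injectDependencies();
        model.methodHook();
        return model.log;
    }
}
```

In this model the suite-level hook observes a null reference and the per-test hook does not, which matches the NullPointerException described above.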
If anyone has any advice that would get me off the dime with TestNG I'd appreciate hearing it. In the meantime, I know that Spring's transactional database tests work exactly as advertised with JUnit 4.