View Javadoc
1   /*
2    *
3    * The DbUnit Database Testing Framework
4    * Copyright (C)2002-2005, DbUnit.org
5    *
6    * This library is free software; you can redistribute it and/or
7    * modify it under the terms of the GNU Lesser General Public
8    * License as published by the Free Software Foundation; either
9    * version 2.1 of the License, or (at your option) any later version.
10   *
11   * This library is distributed in the hope that it will be useful,
12   * but WITHOUT ANY WARRANTY; without even the implied warranty of
13   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14   * Lesser General Public License for more details.
15   *
16   * You should have received a copy of the GNU Lesser General Public
17   * License along with this library; if not, write to the Free Software
18   * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19   *
20   */
21  package org.dbunit.database;
22  
23  import java.sql.PreparedStatement;
24  import java.sql.ResultSet;
25  import java.sql.SQLException;
26  import java.util.*;
27  
28  import org.dbunit.database.search.ForeignKeyRelationshipEdge;
29  import org.dbunit.dataset.DataSetException;
30  import org.dbunit.dataset.IDataSet;
31  import org.dbunit.dataset.ITable;
32  import org.dbunit.dataset.ITableIterator;
33  import org.dbunit.dataset.ITableMetaData;
34  import org.dbunit.dataset.filter.AbstractTableFilter;
35  import org.dbunit.util.SQLHelper;
36  import org.slf4j.Logger;
37  import org.slf4j.LoggerFactory;
38  
39  /**
40   * Filter a table given a map of the allowed rows based on primary key values.<br>
 * It uses a depth-first algorithm (although not recursive - it might be refactored
 * in the future) to define which rows are allowed, as well as which rows are necessary
 * (and hence allowed) because of dependencies on the allowed rows.<br>
44   * <strong>NOTE:</strong> multi-column primary keys are not supported at the moment.
45   * TODO: test cases
46   * @author Felipe Leme (dbunit@felipeal.net)
47   * @author Last changed by: $Author$
48   * @version $Revision$ $Date$
49   * @since Sep 9, 2005
50   */
51  public class PrimaryKeyFilter extends AbstractTableFilter {
52  
    // connection used to run the FK/PK lookup queries and to resolve PK columns
    private final IDatabaseConnection connection;

    // PKs accepted so far, per table; read by FilterIterator.getTable() to wrap tables
    private final PkTableMap allowedPKsPerTable;
    // PKs handed to the constructor; a non-empty entry restricts which PKs allowPKs() may accept
    private final PkTableMap allowedPKsInput;
    // work-list of PKs still to be scanned, per table; starts as a copy of the input
    private final PkTableMap pksToScanPerTable;

    // when true, scanReversePKs() also follows relationships pointing at a scanned table
    private final boolean reverseScan;

    protected final Logger logger = LoggerFactory.getLogger(getClass());

    // cache the primary keys (table name -> PK column name)
    private final Map pkColumnPerTable = new HashMap();

    // edges indexed by edge.getFrom() (table name -> Set of ForeignKeyRelationshipEdge)
    private final Map fkEdgesPerTable = new HashMap();
    // the same edges indexed by edge.getTo()
    private final Map fkReverseEdgesPerTable = new HashMap();

    // name of the tables, in reverse order of dependency (filled by nodeAdded(),
    // walked from the last entry to the first by searchPKs())
    private final List tableNames = new ArrayList();
71  
72      /**
73       * Default constructor, it takes as input a map with desired rows in a final
74       * dataset; the filter will ensure that the rows necessary by these initial rows
75       * are also allowed (and so on...).
76       * @param connection database connection
77       * @param allowedPKs map of allowed rows, based on the primary keys (key is the name
78       * of a table; value is a Set with allowed primary keys for that table)
79       * @param reverseDependency flag indicating if the rows that depend on a row should
80       * also be allowed by the filter
81       */
82      public PrimaryKeyFilter(IDatabaseConnection connection, PkTableMap allowedPKs, boolean reverseDependency) {
83          this.connection = connection;    
84          this.allowedPKsPerTable = new PkTableMap();    
85          this.allowedPKsInput = allowedPKs;
86          this.reverseScan = reverseDependency;
87  
88          // we need a deep copy here
89          this.pksToScanPerTable = new PkTableMap(allowedPKs);
90      }
91  
92      public void nodeAdded(Object node) {
93          this.tableNames.add( node );
94          if ( this.logger.isDebugEnabled() ) {
95              this.logger.debug("nodeAdded: " + node );
96          }
97      }
98  
99      public void edgeAdded(ForeignKeyRelationshipEdge edge) {
100         if ( this.logger.isDebugEnabled() ) {
101             this.logger.debug("edgeAdded: " + edge );
102         }
103         // first add it to the "direct edges"
104         String from = (String) edge.getFrom();
105         Set edges = (Set) this.fkEdgesPerTable.get(from);
106         if ( edges == null ) {
107             edges = new HashSet();
108             this.fkEdgesPerTable.put( from, edges );
109         }
110         if ( ! edges.contains(edge) ) {
111             edges.add(edge);
112         }
113 
114         // then add it to the "reverse edges"
115         String to = (String) edge.getTo();
116         edges = (Set) this.fkReverseEdgesPerTable.get(to);
117         if ( edges == null ) {
118             edges = new HashSet();
119             this.fkReverseEdgesPerTable.put(to, edges);
120         }
121         if ( ! edges.contains(edge) ) {
122             edges.add(edge);
123         }
124 
125         // finally, update the PKs cache
126         updatePkCache(to, edge);
127 
128     }
129 
130     /**
131      * @see AbstractTableFilter
132      */
133     public boolean isValidName(String tableName) throws DataSetException {
134         //    boolean isValid = this.allowedIds.containsKey(tableName);
135         //    return isValid;
136         return true;
137     }
138 
139     public ITableIterator iterator(IDataSet dataSet, boolean reversed)
140     throws DataSetException {
141         if ( this.logger.isDebugEnabled() ) {
142             this.logger.debug("Filter.iterator()" );
143         }
144         try {
145             searchPKs(dataSet);
146         } catch (SQLException e) {
147             throw new DataSetException( e );
148         }
149         return new FilterIterator(reversed ? dataSet.reverseIterator() : dataSet
150                 .iterator());
151     }
152 
153     private void searchPKs(IDataSet dataSet) throws DataSetException, SQLException {
154         logger.debug("searchPKs(dataSet={}) - start", dataSet);
155 
156         int counter = 0;
157         while ( !this.pksToScanPerTable.isEmpty() ) {
158             counter ++;
159             if ( this.logger.isDebugEnabled() ) {
160                 this.logger.debug( "RUN # " + counter );
161             }
162 
163             for( int i=this.tableNames.size()-1; i>=0; i-- ) {
164                 String tableName = (String) this.tableNames.get(i);
165                 // TODO: support multi-column PKs
166                 String pkColumn = dataSet.getTable(tableName).getTableMetaData().getPrimaryKeys()[0].getColumnName();
167                 Set tmpSet = this.pksToScanPerTable.get( tableName );
168                 if ( tmpSet != null && ! tmpSet.isEmpty() ) {
169                     Set pksToScan = new HashSet( tmpSet );
170                     if ( this.logger.isDebugEnabled() ) {
171                         this.logger.debug(  "before search: "+ tableName + "=>" + pksToScan );
172                     }
173                     scanPKs( tableName, pkColumn, pksToScan );
174                     scanReversePKs( tableName, pksToScan );
175                     allowPKs( tableName, pksToScan );
176                     removePKsToScan( tableName, pksToScan );
177                 } // if
178             } // for 
179             removeScannedTables();
180         } // while
181         if ( this.logger.isDebugEnabled() ) {
182             this.logger.debug( "Finished searchIds()" );
183         }
184     } 
185 
186     private void removeScannedTables() {
187         logger.debug("removeScannedTables() - start");
188         this.pksToScanPerTable.retainOnly(this.tableNames);
189     }
190 
191     private void allowPKs(String table, Set newAllowedPKs) {
192         logger.debug("allowPKs(table={}, newAllowedPKs={}) - start", table, newAllowedPKs);
193 
194         // then, add the new IDs, but checking if it should be allowed to add them
195         Set forcedAllowedPKs = this.allowedPKsInput.get( table );
196         if( forcedAllowedPKs == null || forcedAllowedPKs.isEmpty() ) {
197             allowedPKsPerTable.addAll(table, newAllowedPKs );
198         } else {
199             for(Iterator iterator = newAllowedPKs.iterator(); iterator.hasNext(); ) {
200                 Object id = iterator.next();
201                 if( forcedAllowedPKs.contains(id) ) {
202                     allowedPKsPerTable.add(table, id);
203                 } 
204                 else 
205                 {
206                     logger.debug("Discarding id {} of table {} as it was not included in the input!", id, table);
207                 }
208             }
209         }
210     }
211 
212     private void scanPKs( String table, String pkColumn, Set allowedIds ) throws SQLException {
213         if (logger.isDebugEnabled())
214         {
215             logger.debug("scanPKs(table={}, pkColumn={}, allowedIds={}) - start",
216                     new Object[]{ table, pkColumn, allowedIds });
217         }
218 
219         Set fkEdges = (Set) this.fkEdgesPerTable.get( table );
220         if ( fkEdges == null || fkEdges.isEmpty() ) {
221             return;
222         }
223         // we need a temporary list as there is no warranty about the set order...
224         List fkTables = new ArrayList( fkEdges.size() );
225         final StringBuilder colsBuffer = new StringBuilder();
226         for(Iterator iterator = fkEdges.iterator(); iterator.hasNext(); ) {
227             ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
228             fkTables.add( edge.getTo() );
229             colsBuffer.append( edge.getFKColumn() );
230             if ( iterator.hasNext() ) {
231                 colsBuffer.append( ", " );
232             }
233         }
234         // NOTE: make sure the query below is compatible standard SQL
235         String sql = "SELECT " + colsBuffer + " FROM " + table + 
236         " WHERE " + pkColumn + " = ? ";
237         if ( this.logger.isDebugEnabled() ) {
238             this.logger.debug( "SQL: " + sql );
239         }
240 
241         scanPKs(table, sql, allowedIds, fkTables);
242     }
243 
244     private void scanPKs(String table, String sql, Set allowedIds, List fkTables) throws SQLException
245     {
246         PreparedStatement pstmt = null;
247         ResultSet rs = null;
248         try {
249             pstmt = this.connection.getConnection().prepareStatement( sql );
250             for(Iterator iterator = allowedIds.iterator(); iterator.hasNext(); ) {
251                 Object pk = iterator.next(); // id being scanned
252                 if( this.logger.isDebugEnabled() ) {
253                     this.logger.debug("Executing sql for ? = " + pk );
254                 }
255                 pstmt.setObject( 1, pk );
256                 rs = pstmt.executeQuery();
257                 while( rs.next() ) {
258                     for( int i=0; i<fkTables.size(); i++ ) {
259                         String newTable = (String) fkTables.get(i);
260                         Object fk = rs.getObject(i+1);
261                         if( fk != null ) {
262                             logger.debug("New ID: {}->{}", newTable, fk);
263                             addPKToScan( newTable, fk );
264                         } 
265                         else {
266                             logger.warn( "Found null FK for relationship {} =>{}", table, newTable );
267                         }
268                     }
269                 }
270             }
271         } catch (SQLException e) {
272             logger.error("scanPKs()", e);
273         }
274         finally {
275             // new in the finally block. has been in the catch only before
276             SQLHelper.close( rs, pstmt );
277         }
278     }
279 
280     private void scanReversePKs(String table, Set pksToScan) throws SQLException {
281         logger.debug("scanReversePKs(table={}, pksToScan={}) - start", table, pksToScan);
282 
283         if ( ! this.reverseScan ) {
284             return; 
285         }
286         Set fkReverseEdges = (Set) this.fkReverseEdgesPerTable.get( table );
287         if ( fkReverseEdges == null || fkReverseEdges.isEmpty() ) {
288             return;
289         }
290         Iterator iterator = fkReverseEdges.iterator();
291         while ( iterator.hasNext() ) {
292             ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
293             addReverseEdge( edge, pksToScan );
294         }
295     }
296 
297     private void addReverseEdge(ForeignKeyRelationshipEdge edge, Set idsToScan) throws SQLException {
298         logger.debug("addReverseEdge(edge={}, idsToScan=) - start", edge, idsToScan);
299 
300         String fkTable = (String) edge.getFrom();
301         String fkColumn = edge.getFKColumn();
302         String pkColumn = getPKColumn( fkTable );
303         // NOTE: make sure the query below is compatible standard SQL
304         String sql = "SELECT " + pkColumn + " FROM " + fkTable + " WHERE " + fkColumn + " = ? ";
305 
306         PreparedStatement pstmt = null;
307         ResultSet rs = null;
308         try {
309             logger.debug("Preparing SQL query '{}'", sql);
310 
311             pstmt = this.connection.getConnection().prepareStatement(sql);
312             for(Iterator iterator = idsToScan.iterator(); iterator.hasNext(); ) {
313                 Object pk = iterator.next();
314                 if ( this.logger.isDebugEnabled() ) {
315                     this.logger.debug( "executing query '" + sql + "' for ? = " + pk );
316                 }
317                 pstmt.setObject( 1, pk );
318                 rs = pstmt.executeQuery();
319                 while( rs.next() ) {
320                     Object fk = rs.getObject(1);
321                     addPKToScan( fkTable, fk );
322                 }
323             } 
324         } finally {
325             SQLHelper.close( rs, pstmt );
326         }
327     }
328 
329     private void updatePkCache(String table, ForeignKeyRelationshipEdge edge) {
330         logger.debug("updatePkCache(to={}, edge={}) - start", table, edge);
331 
332         Object pkTo = this.pkColumnPerTable.get(table);
333         if ( pkTo == null ) {
334             String pkColumn = edge.getPKColumn();
335             this.pkColumnPerTable.put( table, pkColumn );
336         }
337     }
338 
339     // TODO: support PKs with multiple values
340     private String getPKColumn( String table ) throws SQLException {
341         logger.debug("getPKColumn(table={}) - start", table);
342 
343         // Try to get the cached column
344         String pkColumn = (String) this.pkColumnPerTable.get( table );
345         if ( pkColumn == null ) {
346             // If the column has not been cached until now retrieve it from the database connection
347             pkColumn = SQLHelper.getPrimaryKeyColumn( this.connection.getConnection(), table );
348             this.pkColumnPerTable.put( table, pkColumn );
349         }
350         return pkColumn;
351     }
352 
353 
354     private void removePKsToScan(String table, Set ids) {
355         logger.debug("removePKsToScan(table={}, ids={}) - start", table, ids);
356 
357         Set pksToScan = this.pksToScanPerTable.get(table);
358         if ( pksToScan != null ) {
359             if ( pksToScan == ids ) {   
360                 throw new RuntimeException( "INTERNAL ERROR on removeIdsToScan() for table " + table );
361             } else {
362                 pksToScan.removeAll( ids );
363             }
364         }    
365     }
366 
367     private void addPKToScan(String table, Object pk) {
368         logger.debug("addPKToScan(table={}, pk={}) - start", table, pk);
369 
370         // first, check if it wasn't added yet
371         if(this.allowedPKsPerTable.contains(table, pk)) {
372             if ( this.logger.isDebugEnabled() ) {
373                 this.logger.debug( "Discarding already scanned id=" + pk + " for table " + table );
374             }
375             return;
376         }
377 
378         this.pksToScanPerTable.add(table, pk);
379     }
380 
381     public String toString() {
382         final StringBuilder sb = new StringBuilder();
383         sb.append("tableNames=").append(tableNames);
384         sb.append(", allowedPKsInput=").append(allowedPKsInput);
385         sb.append(", allowedPKsPerTable=").append(allowedPKsPerTable);
386         sb.append(", fkEdgesPerTable=").append(fkEdgesPerTable);
387         sb.append(", fkReverseEdgesPerTable=").append(fkReverseEdgesPerTable);
388         sb.append(", pkColumnPerTable=").append(pkColumnPerTable);
389         sb.append(", pksToScanPerTable=").append(pksToScanPerTable);
390         sb.append(", reverseScan=").append(reverseScan);
391         sb.append(", connection=").append(connection);
392         return sb.toString();
393     }
394 
395 
396     private class FilterIterator implements ITableIterator {
397 
398         private final ITableIterator _iterator;
399 
400         public FilterIterator(ITableIterator iterator) {
401 
402             _iterator = iterator;
403         }
404 
405         ////////////////////////////////////////////////////////////////////////////
406         // ITableIterator interface
407 
408         public boolean next() throws DataSetException {
409             if ( logger.isDebugEnabled() ) {
410                 logger.debug("Iterator.next()" );
411             }      
412             while (_iterator.next()) {
413                 if (accept(_iterator.getTableMetaData().getTableName())) {
414                     return true;
415                 }
416             }
417             return false;
418         }
419 
420         public ITableMetaData getTableMetaData() throws DataSetException {
421             if ( logger.isDebugEnabled() ) {
422                 logger.debug("Iterator.getTableMetaData()" );
423             }      
424             return _iterator.getTableMetaData();
425         }
426 
427         public ITable getTable() throws DataSetException {
428             if ( logger.isDebugEnabled() ) {
429                 logger.debug("Iterator.getTable()" );
430             }
431             ITable table = _iterator.getTable();
432             String tableName = table.getTableMetaData().getTableName();
433             Set allowedPKs = allowedPKsPerTable.get( tableName );
434             if ( allowedPKs != null ) {
435                 return new PrimaryKeyFilteredTableWrapper(table, allowedPKs);
436             }
437             return table;
438         }
439     }
440 
441     /**
442      * Map that associates a table with a set of primary key objects.
443      * 
444      * @author gommma (gommma AT users.sourceforge.net)
445      * @author Last changed by: $Author$
446      * @version $Revision$ $Date$
447      * @since 2.3.0
448      */
449     public static class PkTableMap
450     {
451         private final LinkedHashMap pksPerTable;
452         private final Logger logger = LoggerFactory.getLogger(PkTableMap.class);
453 
454         public PkTableMap()
455         {
456             this.pksPerTable = new LinkedHashMap();
457         }
458 
459         /**
460          * Copy constructor
461          * @param allowedPKs
462          */
463         public PkTableMap(PkTableMap allowedPKs) {
464             this.pksPerTable = new LinkedHashMap();
465             Iterator iterator = allowedPKs.pksPerTable.entrySet().iterator();
466             while ( iterator.hasNext() ) {
467                 Map.Entry entry = (Map.Entry) iterator.next();
468                 String table = (String)entry.getKey();
469                 SortedSet pkObjectSet = (SortedSet) entry.getValue();
470                 SortedSet newSet = new TreeSet( pkObjectSet );
471                 this.pksPerTable.put( table, newSet );
472             }
473         }
474 
475         public int size() {
476             return pksPerTable.size();
477         }
478 
479         public boolean isEmpty() {
480             return pksPerTable.isEmpty();
481         }
482 
483         public boolean contains(String table, Object pkObject) {
484             Set pksPerTable = this.get(table);
485             return (pksPerTable != null && pksPerTable.contains(pkObject));
486         }
487 
488         public void remove(String tableName) {
489             this.pksPerTable.remove(tableName);
490         }
491 
492         public void put(String table, SortedSet pkObjects) {
493             this.pksPerTable.put(table, pkObjects);
494         }
495 
496         public void add(String tableName, Object pkObject) {
497             Set pksPerTable = getCreateIfNeeded(tableName);
498             pksPerTable.add(pkObject);
499         }
500 
501         public void addAll(String tableName, Set pkObjectsToAdd) {
502             Set pksPerTable = this.getCreateIfNeeded(tableName);
503             pksPerTable.addAll(pkObjectsToAdd);
504         }
505 
506         public SortedSet get(String tableName) {
507             return (SortedSet) this.pksPerTable.get(tableName);
508         }
509 
510         private SortedSet getCreateIfNeeded(String tableName){
511             SortedSet pksPerTable = this.get(tableName);
512             // Lazily create the set if it did not exist yet
513             if( pksPerTable == null ) {
514                 pksPerTable = new TreeSet();
515                 this.pksPerTable.put(tableName, pksPerTable);
516             }
517             return pksPerTable;
518         }
519 
520         public String[] getTableNames() {
521             return (String[]) this.pksPerTable.keySet().toArray(new String[0]);
522         }
523 
524         public void retainOnly(List tableNames) {
525 
526             List tablesToRemove = new ArrayList();
527             for(Iterator iterator = this.pksPerTable.entrySet().iterator(); iterator.hasNext(); ) {
528                 Map.Entry entry = (Map.Entry) iterator.next();
529                 String table = (String) entry.getKey();
530                 SortedSet pksToScan = (SortedSet) entry.getValue();
531                 boolean removeIt = pksToScan.isEmpty();
532 
533                 if ( ! tableNames.contains(table) ) {
534                     if ( this.logger.isWarnEnabled() ) {
535                         this.logger.warn("Discarding ids " + pksToScan + " of table " + table +
536                         "as this table has not been passed as input" );
537                     }
538                     removeIt = true;
539                 }
540                 if ( removeIt ) {
541                     tablesToRemove.add( table );
542                 }
543             }
544 
545             for(Iterator iterator = tablesToRemove.iterator(); iterator.hasNext(); ) {
546                 this.remove( (String)iterator.next() );
547             }
548         }
549         
550         
551         public String toString() {
552             final StringBuilder sb = new StringBuilder();
553             sb.append("pKsPerTable=").append(pksPerTable);
554             return sb.toString();
555         }
556 
557     }
558 }