MtasSolrResultMerge.java

  1. package mtas.solr.handler.component.util;

  2. import java.io.IOException;
  3. import java.util.ArrayList;
  4. import java.util.HashMap;
  5. import java.util.Iterator;
  6. import java.util.Map.Entry;
  7. import java.util.SortedSet;

  8. import org.apache.commons.logging.Log;
  9. import org.apache.commons.logging.LogFactory;
  10. import org.apache.solr.common.util.NamedList;
  11. import org.apache.solr.common.util.SimpleOrderedMap;
  12. import org.apache.solr.handler.component.ResponseBuilder;
  13. import org.apache.solr.handler.component.ShardRequest;
  14. import org.apache.solr.handler.component.ShardResponse;

  15. import mtas.solr.handler.component.MtasSolrSearchComponent;

  16. /**
  17.  * The Class MtasSolrResultMerge.
  18.  */
  19. public class MtasSolrResultMerge {

  20.   /** The Constant log. */
  21.   private static final Log log = LogFactory.getLog(MtasSolrResultMerge.class);

  22.   /**
  23.    * Merge.
  24.    *
  25.    * @param rb the rb
  26.    */
  27.   @SuppressWarnings("unchecked")
  28.   public void merge(ResponseBuilder rb) {
  29.     if (rb.req.getParams().getBool(MtasSolrSearchComponent.PARAM_MTAS, false)) {
  30.       // mtas response
  31.       NamedList<Object> mtasResponse = null;
  32.       boolean newResponse = false;
  33.       try {
  34.         mtasResponse = (NamedList<Object>) rb.rsp.getValues().get(MtasSolrSearchComponent.NAME);        
  35.       } catch (ClassCastException e) {
  36.         log.debug(e);
  37.         mtasResponse = null;        
  38.       }
  39.       if (mtasResponse == null) {
  40.         newResponse = true;
  41.         mtasResponse = new SimpleOrderedMap<>();              
  42.       }

  43.       for (ShardRequest sreq : rb.finished) {
  44.         if (rb.stage == ResponseBuilder.STAGE_EXECUTE_QUERY) {
  45.           // merge stats
  46.           if (rb.req.getParams()
  47.               .getBool(MtasSolrComponentStats.PARAM_MTAS_STATS, false)) {
  48.             mergeNamedList(sreq, mtasResponse, MtasSolrComponentStats.NAME, null);
  49.           }
  50.           // merge group
  51.           if (rb.req.getParams()
  52.               .getBool(MtasSolrComponentGroup.PARAM_MTAS_GROUP, false)) {
  53.             mergeArrayList(sreq, mtasResponse, MtasSolrComponentGroup.NAME, null, false);
  54.           }
  55.           // merge facet
  56.           if (rb.req.getParams()
  57.               .getBool(MtasSolrComponentFacet.PARAM_MTAS_FACET, false)) {
  58.             mergeArrayList(sreq, mtasResponse, MtasSolrComponentFacet.NAME, null, false);
  59.           }
  60.           // merge collection
  61.           if (rb.req.getParams().getBool(
  62.               MtasSolrComponentCollection.PARAM_MTAS_COLLECTION, false)) {
  63.             mergeArrayList(sreq, mtasResponse, MtasSolrComponentCollection.NAME, null, false);
  64.           }
  65.           // merge prefix
  66.           if (rb.req.getParams()
  67.               .getBool(MtasSolrComponentPrefix.PARAM_MTAS_PREFIX, false)) {
  68.             mergeArrayList(sreq, mtasResponse, MtasSolrComponentPrefix.NAME, null, false);
  69.           }
  70.         } else if (rb.stage == MtasSolrSearchComponent.STAGE_COLLECTION_INIT) {
  71.           // merge collection
  72.           if (rb.req.getParams().getBool(
  73.               MtasSolrComponentCollection.PARAM_MTAS_COLLECTION, false)) {
  74.             mergeArrayList(sreq, mtasResponse, MtasSolrComponentCollection.NAME, null, false);
  75.           }
  76.         } else if (rb.stage == MtasSolrSearchComponent.STAGE_TERMVECTOR_MISSING_KEY) {
  77.           // merge termvector
  78.           if (rb.req.getParams().getBool(
  79.               MtasSolrComponentTermvector.PARAM_MTAS_TERMVECTOR, false)) {
  80.             mergeArrayList(sreq, mtasResponse, MtasSolrComponentTermvector.NAME, null, false);
  81.           }
  82.         } else if (rb.stage == MtasSolrSearchComponent.STAGE_LIST) {
  83.           // merge list
  84.           if (rb.req.getParams().getBool(MtasSolrComponentList.PARAM_MTAS_LIST,
  85.               false)) {
  86.             mergeArrayList(sreq, mtasResponse, MtasSolrComponentList.NAME,
  87.                 ShardRequest.PURPOSE_PRIVATE, true);
  88.           }
  89.         } else if (rb.stage == ResponseBuilder.STAGE_GET_FIELDS) {
  90.           // merge document
  91.           if (rb.req.getParams()
  92.               .getBool(MtasSolrComponentDocument.PARAM_MTAS_DOCUMENT, false)) {
  93.             mergeArrayList(sreq, mtasResponse, MtasSolrComponentDocument.NAME,
  94.                 ShardRequest.PURPOSE_PRIVATE, true);
  95.           }
  96.           // merge kwic
  97.           if (rb.req.getParams().getBool(MtasSolrComponentKwic.PARAM_MTAS_KWIC,
  98.               false)) {
  99.             mergeArrayList(sreq, mtasResponse, MtasSolrComponentKwic.NAME,
  100.                 ShardRequest.PURPOSE_PRIVATE, true);
  101.           }
  102.         }
  103.       }
  104.       if(newResponse && mtasResponse.size()>0) {
  105.         rb.rsp.add(MtasSolrSearchComponent.NAME, mtasResponse);
  106.       }
  107.     }
  108.   }

  109.   /**
  110.    * Merge named list.
  111.    *
  112.    * @param sreq the sreq
  113.    * @param mtasResponse the mtas response
  114.    * @param key the key
  115.    * @param preferredPurpose the preferred purpose
  116.    */
  117.   @SuppressWarnings("unchecked")
  118.   private void mergeNamedList(ShardRequest sreq, NamedList<Object> mtasResponse,
  119.       String key, Integer preferredPurpose) {
  120.     // create new response for key
  121.     NamedList<Object> mtasListResponse;
  122.     Object o = mtasResponse.get(key);
  123.     if (o instanceof NamedList) {
  124.       mtasListResponse = (NamedList<Object>) o;
  125.     } else {
  126.       mtasListResponse = new SimpleOrderedMap<>();
  127.       mtasResponse.removeAll(key);
  128.       mtasResponse.add(key, mtasListResponse);
  129.     }
  130.     // collect responses for each shard
  131.     HashMap<String, NamedList<Object>> mtasListShardResponses = new HashMap<>();
  132.     for (ShardResponse response : sreq.responses) {
  133.       // only continue if new shard or preferred purpose
  134.       if (mtasListShardResponses.containsKey(response.getShard())
  135.           && ((preferredPurpose == null)
  136.               || (sreq.purpose != preferredPurpose))) {
  137.         break;
  138.       }
  139.       // update
  140.       try {
  141.         NamedList<Object> result = response.getSolrResponse().getResponse();
  142.         NamedList<Object> data = (NamedList<Object>) result
  143.             .findRecursive("mtas", key);
  144.         if (data != null) {
  145.           mtasListShardResponses.put(response.getShard(),
  146.               MtasSolrResultUtil.decode(data));
  147.         }
  148.       } catch (ClassCastException e) {
  149.         log.debug(e);
  150.       }
  151.     }
  152.     try {
  153.       for (NamedList<Object> mtasListShardResponse : mtasListShardResponses
  154.           .values()) {
  155.         mergeResponsesNamedList(mtasListResponse, mtasListShardResponse);
  156.       }
  157.     } catch (IOException e) {
  158.       log.error(e);
  159.     }
  160.   }

  161.   /**
  162.    * Merge array list.
  163.    *
  164.    * @param sreq the sreq
  165.    * @param mtasResponse the mtas response
  166.    * @param key the key
  167.    * @param preferredPurpose the preferred purpose
  168.    * @param mergeAllShardResponses the merge all shard responses
  169.    */
  170.   @SuppressWarnings("unchecked")
  171.   private void mergeArrayList(ShardRequest sreq, NamedList<Object> mtasResponse,
  172.       String key, Integer preferredPurpose, boolean mergeAllShardResponses) {
  173.     // create new response for key
  174.     ArrayList<Object> mtasListResponse;
  175.     Object o = mtasResponse.get(key);
  176.     if (o instanceof ArrayList) {
  177.       mtasListResponse = (ArrayList<Object>) o;
  178.     } else {
  179.       mtasListResponse = new ArrayList<>();
  180.       mtasResponse.removeAll(key);
  181.       mtasResponse.add(key, mtasListResponse);
  182.     }
  183.     // collect responses for each shard
  184.     HashMap<String, ArrayList<Object>> mtasListShardResponses = new HashMap<>();
  185.     ArrayList<ArrayList<Object>> mtasListShardResponsesExtra = new ArrayList<>();
  186.     for (ShardResponse response : sreq.responses) {
  187.       // only continue if new shard or preferred purpose
  188.       if (mtasListShardResponses.containsKey(response.getShard())
  189.           && ((preferredPurpose == null)
  190.               || (sreq.purpose != preferredPurpose))) {
  191.         break;
  192.       }
  193.       // update
  194.       try {
  195.         NamedList<Object> result = response.getSolrResponse().getResponse();
  196.         ArrayList<Object> data = (ArrayList<Object>) result
  197.             .findRecursive("mtas", key);
  198.         if (data != null) {
  199.           if (mtasListShardResponses.containsKey(response.getShardAddress())) {
  200.             if (mergeAllShardResponses) {
  201.               mtasListShardResponsesExtra.add(data);
  202.             }
  203.           } else {
  204.             mtasListShardResponses.put(response.getShardAddress(), data);
  205.           }
  206.         }
  207.       } catch (ClassCastException e) {
  208.         log.error(e);
  209.       }
  210.     }

  211.     try {
  212.       for (ArrayList<Object> mtasListShardResponse : mtasListShardResponses
  213.           .values()) {
  214.         mergeResponsesArrayList(mtasListResponse, mtasListShardResponse);
  215.       }
  216.       for (ArrayList<Object> mtasListShardResponse : mtasListShardResponsesExtra) {
  217.         mergeResponsesArrayList(mtasListResponse, mtasListShardResponse);
  218.       }
  219.     } catch (IOException e) {
  220.       log.error(e);
  221.     }
  222.   }

  223.   /**
  224.    * Merge responses sorted set.
  225.    *
  226.    * @param originalList the original list
  227.    * @param shardList the shard list
  228.    */
  229.   private void mergeResponsesSortedSet(SortedSet<Object> originalList,
  230.       SortedSet<Object> shardList) {
  231.     for (Object item : shardList) {
  232.       originalList.add(item);
  233.     }
  234.   }

  235.   /**
  236.    * Merge responses array list.
  237.    *
  238.    * @param originalList the original list
  239.    * @param shardList the shard list
  240.    * @throws IOException Signals that an I/O exception has occurred.
  241.    */
  242.   @SuppressWarnings("unchecked")
  243.   private void mergeResponsesArrayList(ArrayList<Object> originalList,
  244.       ArrayList<Object> shardList) throws IOException {
  245.     // get keys from original
  246.     HashMap<String, Object> originalKeyList = new HashMap<>();
  247.     for (Object item : originalList) {
  248.       if (item instanceof NamedList<?>) {
  249.         NamedList<Object> itemList = (NamedList<Object>) item;
  250.         Object key = itemList.get("key");
  251.         if ((key != null) && (key instanceof String)) {
  252.           originalKeyList.put((String) key, item);
  253.         }
  254.       }
  255.     }
  256.     for (Object item : shardList) {
  257.       if (item instanceof NamedList<?>) {
  258.         NamedList<Object> itemList = (NamedList<Object>) item;
  259.         Object key = itemList.get("key");
  260.         // item with key
  261.         if ((key != null) && (key instanceof String)) {
  262.           // merge
  263.           if (originalKeyList.containsKey(key)) {
  264.             Object originalItem = originalKeyList.get(key);
  265.             if (originalItem.getClass().equals(item.getClass())) {
  266.               mergeResponsesNamedList((NamedList<Object>) originalItem,
  267.                   (NamedList<Object>) item);
  268.             } else {
  269.               // ignore?
  270.             }
  271.             // add
  272.           } else {
  273.             Object clonedItem = adjustablePartsCloned(item);
  274.             originalList.add(clonedItem);
  275.             originalKeyList.put((String) key, clonedItem);
  276.           }
  277.         } else {
  278.           originalList.add(item);
  279.         }
  280.       } else {
  281.         originalList.add(item);
  282.       }
  283.     }
  284.   }

  285.   /**
  286.    * Merge responses named list.
  287.    *
  288.    * @param mainResponse the main response
  289.    * @param shardResponse the shard response
  290.    * @throws IOException Signals that an I/O exception has occurred.
  291.    */
  292.   @SuppressWarnings({ "rawtypes", "unchecked" })
  293.   private void mergeResponsesNamedList(NamedList<Object> mainResponse,
  294.       NamedList<Object> shardResponse) throws IOException {
  295.     Iterator<Entry<String, Object>> it = shardResponse.iterator();
  296.     while (it.hasNext()) {
  297.       Entry<String, Object> entry = it.next();
  298.       String name = entry.getKey();
  299.       Object shardValue = entry.getValue();
  300.       int originalId = mainResponse.indexOf(name, 0);
  301.       if (originalId < 0) {
  302.         mainResponse.add(name, adjustablePartsCloned(shardValue));
  303.       } else {
  304.         Object original = mainResponse.getVal(originalId);
  305.         if (original == null) {
  306.           original = adjustablePartsCloned(shardValue);
  307.         } else if (shardValue != null
  308.             && original.getClass().equals(shardValue.getClass())) {
  309.           // merge ArrayList
  310.           if (original instanceof ArrayList) {
  311.             ArrayList originalList = (ArrayList) original;
  312.             ArrayList shardList = (ArrayList) shardValue;
  313.             mergeResponsesArrayList(originalList, shardList);
  314.             // merge Namedlist
  315.           } else if (original instanceof NamedList<?>) {
  316.             mergeResponsesNamedList((NamedList<Object>) original,
  317.                 (NamedList<Object>) shardValue);
  318.             // merge SortedSet
  319.           } else if (original instanceof SortedSet<?>) {
  320.             mergeResponsesSortedSet((SortedSet<Object>) original,
  321.                 (SortedSet<Object>) shardValue);
  322.           } else if (original instanceof MtasSolrMtasResult) {
  323.             MtasSolrMtasResult originalComponentResult = (MtasSolrMtasResult) original;
  324.             originalComponentResult.merge((MtasSolrMtasResult) shardValue);
  325.           } else if (original instanceof MtasSolrCollectionResult) {
  326.             MtasSolrCollectionResult originalComponentResult = (MtasSolrCollectionResult) original;
  327.             originalComponentResult
  328.                 .merge((MtasSolrCollectionResult) shardValue);
  329.           } else if (original instanceof String) {
  330.             // ignore?
  331.           } else if (original instanceof Integer) {
  332.             original = (Integer) original + ((Integer) shardValue);
  333.           } else if (original instanceof Long) {
  334.             original = (Long) original + ((Long) shardValue);
  335.           } else {
  336.             // ignore?
  337.           }
  338.           mainResponse.setVal(originalId, original);
  339.         } else {
  340.           // ignore?
  341.         }
  342.       }
  343.     }
  344.   }

  345.   /**
  346.    * Adjustable parts cloned.
  347.    *
  348.    * @param original the original
  349.    * @return the object
  350.    */
  351.   @SuppressWarnings({ "rawtypes", "unchecked" })
  352.   private Object adjustablePartsCloned(Object original) {
  353.     if (original instanceof NamedList) {
  354.       NamedList<Object> newObject = new SimpleOrderedMap();
  355.       NamedList<Object> originalObject = (NamedList<Object>) original;
  356.       for (int i = 0; i < originalObject.size(); i++) {
  357.         newObject.add(originalObject.getName(i),
  358.             adjustablePartsCloned(originalObject.getVal(i)));
  359.       }
  360.       return newObject;
  361.     } else if (original instanceof ArrayList) {
  362.       ArrayList<Object> newObject = new ArrayList<>();
  363.       ArrayList<Object> originalObject = (ArrayList<Object>) original;
  364.       for (int i = 0; i < originalObject.size(); i++) {
  365.         newObject.add(adjustablePartsCloned(originalObject.get(i)));
  366.       }
  367.       return newObject;
  368.     } else if (original instanceof Integer) {
  369.       return original;
  370.     }
  371.     return original;
  372.   }

  373. }