GooFit  v2.1.3
MetricTaker.cu
Go to the documentation of this file.
1 #include <goofit/PDFs/MetricTaker.h>
2 
3 #include <goofit/GlobalCudaDefines.h>
4 #include <goofit/PDFs/GooPdf.h>
5 #include <goofit/detail/ThrustOverride.h>
6 
7 #include <goofit/BinnedDataSet.h>
8 #include <goofit/Error.h>
9 #include <goofit/FitControl.h>
10 #include <goofit/Log.h>
11 #include <goofit/UnbinnedDataSet.h>
12 #include <goofit/Variable.h>
13 
14 #include <thrust/device_vector.h>
15 #include <thrust/host_vector.h>
16 #include <thrust/iterator/constant_iterator.h>
17 #include <thrust/iterator/zip_iterator.h>
18 #include <thrust/sequence.h>
19 #include <thrust/transform.h>
20 #include <thrust/transform_reduce.h>
21 
22 namespace GooFit {
23 
/// Evaluate the PDF for one event and feed the result through the fit metric.
///
/// The tuple holds, in order: the event index, the base address of the event
/// array, and the event size. abs() is applied to the event size before it is
/// used, so only its magnitude matters for addressing.
///
/// Returns the value produced by the metric function stored at metricIndex in
/// device_function_table.
__device__ fptype MetricTaker::operator()(thrust::tuple<int, fptype *, int> t) const {
    // Number of values per event; doubles as the stride between events.
    const int stride = abs(thrust::get<2>(t));

    // Start of the event handled by this invocation.
    fptype *evt = thrust::get<1>(t) + (thrust::get<0>(t) * stride);

    // Causes stack size to be statically undeterminable.
    const fptype pdfValue = callFunction(evt, functionIdx, parameters);

    // Notice the assumption here! For unbinned fits the event pointer is not used by
    // the metric, so its value is irrelevant. For binned fits the event is assumed to
    // be laid out as (obs1 obs2 ... binentry binvolume), so the metric is handed the
    // last two entries, i.e. (binentry binvolume).
    auto metric = reinterpret_cast<device_metric_ptr>(device_function_table[metricIndex]);
    return (*metric)(pdfValue, evt + (stride - 2), parameters);
}
41 
// Width of the scratch slice reserved for each thread in the binned evaluation below.
#define MAX_NUM_OBSERVABLES 5

/// Evaluate the PDF at the centre of one bin of a multi-dimensional grid.
///
/// The tuple holds, in order: the global bin index, the number of observables,
/// and the base address of an array of (lower, upper, numBins) triplets, one
/// triplet per observable.
///
/// NOTE(review): nothing here checks that the observable index read from
/// paramIndices is below MAX_NUM_OBSERVABLES, that THREADIDX is below 1024
/// (presumably the maximum block size - confirm), or that numBins is non-zero.
__device__ fptype MetricTaker::operator()(thrust::tuple<int, int, fptype *> t) const {
    __shared__ fptype binCenters[1024 * MAX_NUM_OBSERVABLES];

    int remainingBin      = thrust::get<0>(t);
    const int numObs      = thrust::get<1>(t);
    fptype *binning       = thrust::get<2>(t);
    unsigned int *indices = paramIndices + parameters;

    // Each thread works in its own MAX_NUM_OBSERVABLES-wide slice of the shared array.
    fptype *slot = binCenters + THREADIDX * MAX_NUM_OBSERVABLES;

    // Unravel the global bin number into per-dimension coordinates: the remainder
    // modulo the bin count gives the coordinate in the current dimension, and the
    // integer division drops that dimension before moving on to the next one.
    for(int dim = 0; dim < numObs; ++dim) {
        const fptype *desc = binning + 3 * dim;
        const fptype lo    = desc[0];
        const fptype hi    = desc[1];
        const int nBins    = static_cast<int>(floor(desc[2] + 0.5));

        const int localBin = remainingBin % nBins;
        remainingBin /= nBins;

        // Centre of the local bin: lo + (localBin + 0.5) * binWidth.
        fptype center = hi - lo;
        center /= nBins;
        center *= (localBin + 0.5);
        center += lo;

        slot[indices[indices[0] + 2 + dim]] = center;
    }

    // Causes stack size to be statically undeterminable.
    return callFunction(slot, functionIdx, parameters);
}
76 
/// Build a MetricTaker for the given PDF and metric function.
///
/// The function and parameter indices are taken from `dat`. The metric pointer is
/// looked up in functionAddressToDeviceIndexMap; if it is not there yet, it is
/// appended to host_function_table, recorded in the map, and the table is copied
/// to device_function_table.
///
/// NOTE(review): the capacity of host_function_table is not visible from here and
/// no bounds check is made before appending - confirm against its declaration.
MetricTaker::MetricTaker(PdfBase *dat, void *dev_functionPtr)
    : metricIndex(0)
    , functionIdx(dat->getFunctionIndex())
    , parameters(dat->getParameterIndex()) {
    const auto known = functionAddressToDeviceIndexMap.find(dev_functionPtr);

    if(known == functionAddressToDeviceIndexMap.end()) {
        // First time this metric is seen: give it the next free slot and publish
        // the updated table to the device.
        metricIndex                                       = num_device_functions;
        host_function_table[num_device_functions]         = dev_functionPtr;
        functionAddressToDeviceIndexMap[dev_functionPtr]  = num_device_functions;
        num_device_functions++;
        MEMCPY_TO_SYMBOL(device_function_table,
                         host_function_table,
                         num_device_functions * sizeof(void *),
                         0,
                         cudaMemcpyHostToDevice);
        return;
    }

    // Already registered: reuse the slot assigned earlier.
    metricIndex = known->second;
}
99 
/// Build a MetricTaker directly from a function index and a parameter index.
///
/// Intended only for binned evaluation (i.e. integrals); metricIndex is left at 0
/// and no metric function is registered.
MetricTaker::MetricTaker(int fIdx, int pIdx)
    : metricIndex(0)
    , functionIdx(fIdx)
    , parameters(pIdx) {}
106 
107 } // namespace GooFit