Skip to content

Commit 8e94ade

Browse files
committed
Divide broadcast
1 parent a414cb9 commit 8e94ade

File tree

7 files changed

+33
-24
lines changed

7 files changed

+33
-24
lines changed
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
extern "C" {
2-
__global__ void Run(int n, double* __restrict left, double* __restrict right, double* __restrict output) {
2+
__global__ void Run(int n, double* __restrict left, double* __restrict right, double* __restrict output, int rightIsScalar) {
33
int i = blockIdx.x*blockDim.x + threadIdx.x;
4-
if (i < n) output[i] = left[i] / right[i];
4+
if (i < n) {
5+
if (rightIsScalar == 1) {
6+
output[i] = left[i] / right[0];
7+
} else {
8+
output[i] = left[i] / right[i];
9+
}
10+
}
511
}
612
}

src/ConvNetSharp.Volume.GPU/Double/Volume.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ public override void ConvolutionGradient(Volume<double> filters, Volume<double>
406406

407407
public override void Divide(Volume<double> other, Volume<double> result)
408408
{
409-
_kernelLoader.RunKernel("div", this, other, result);
409+
_kernelLoader.RunKernel("div", this, other, result, other.Shape.TotalLength == 1 ? 1 : 0);
410410
}
411411

412412
public override void Dropout(double dropProbability, Volume<double> result)
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
extern "C" {
2-
__global__ void Run(int n, float* __restrict left, float* __restrict right, float* __restrict output) {
2+
__global__ void Run(int n, float* __restrict left, float* __restrict right, float* __restrict output, int rightIsScalar) {
33
int i = blockIdx.x*blockDim.x + threadIdx.x;
4-
if (i < n) output[i] = left[i] / right[i];
4+
if (i < n) {
5+
if (rightIsScalar == 1) {
6+
output[i] = left[i] / right[0];
7+
} else {
8+
output[i] = left[i] / right[i];
9+
}
10+
}
511
}
612
}

src/ConvNetSharp.Volume.GPU/Single/Volume.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ public override void ConvolutionGradient(Volume<float> filters, Volume<float> ou
414414

415415
public override void Divide(Volume<float> other, Volume<float> result)
416416
{
417-
_kernelLoader.RunKernel("div", this, other, result);
417+
_kernelLoader.RunKernel("div", this, other, result, other.Shape.TotalLength == 1 ? 1 : 0);
418418
}
419419

420420
public override void Dropout(float dropProbability, Volume<float> result)

src/ConvNetSharp.Volume.Tests/VolumeTests.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,19 @@ public void Div1DInPlace()
6161
AssertNumber.AreEqual(0.3, left.Get(2), 1e-6);
6262
}
6363

64+
[TestMethod]
65+
public void Div1DBroadcast()
66+
{
67+
var left = NewVolume(new[] { 1.0, 2.0, 3.0 }, new Shape(3));
68+
var right = NewVolume(new[] { 2.0 }, new Shape(1));
69+
var result = BuilderInstance<T>.Volume.SameAs(new Shape(3));
70+
71+
left.Divide(right, result);
72+
AssertNumber.AreEqual(0.5, result.Get(0));
73+
AssertNumber.AreEqual(1.0, result.Get(1));
74+
AssertNumber.AreEqual(1.5, result.Get(2));
75+
}
76+
6477
[TestMethod]
6578
public void Add2D()
6679
{

src/ConvNetSharp.Volume/Double/Volume.cs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -255,15 +255,7 @@ public override void ConvolutionGradient(Volume<double> filters, Volume<double>
255255

256256
public override void Divide(Volume<double> other, Volume<double> result)
257257
{
258-
if (this.Shape.Equals(other.Shape))
259-
{
260-
this.Storage.Map((left, right) => left / right, other.Storage, result.Storage);
261-
}
262-
else
263-
{
264-
//Todo: broadcast
265-
throw new NotImplementedException();
266-
}
258+
this.Storage.MapEx((left, right) => left / right, other.Storage, result.Storage);
267259
}
268260

269261
public override void Dropout(double dropProbability, Volume<double> result)

src/ConvNetSharp.Volume/Single/Volume.cs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -255,15 +255,7 @@ public override void ConvolutionGradient(Volume<float> filters, Volume<float> ou
255255

256256
public override void Divide(Volume<float> other, Volume<float> result)
257257
{
258-
if (this.Shape.Equals(other.Shape))
259-
{
260-
this.Storage.Map((left, right) => left / right, other.Storage, result.Storage);
261-
}
262-
else
263-
{
264-
//Todo: broadcast
265-
throw new NotImplementedException();
266-
}
258+
this.Storage.MapEx((left, right) => left / right, other.Storage, result.Storage);
267259
}
268260

269261
public override void Dropout(float dropProbability, Volume<float> result)

0 commit comments

Comments
 (0)