@@ -5,6 +5,8 @@ plugins {
5
5
6
6
def USE_NIGHTLY = System . getenv(' USE_LIBTORCH_NIGHTLY' )?. toBoolean()
7
7
8
+ def LIBTORCH_VERSION = USE_NIGHTLY ? " 1.5.0-SNAPSHOT" : " 1.4.0" ;
9
+
8
10
repositories {
9
11
jcenter()
10
12
@@ -14,20 +16,26 @@ repositories {
14
16
}
15
17
16
18
dependencies {
17
- if (USE_NIGHTLY ) {
18
- implementation ' org.pytorch:pytorch_java_only:1.5.0-SNAPSHOT'
19
- } else {
20
- implementation ' org.pytorch:pytorch_java_only:1.4.0'
21
- }
19
+ implementation " org.pytorch:pytorch_java_only:${ LIBTORCH_VERSION} "
22
20
}
23
21
24
22
def LIBTORCH_HOME = System . getenv(' LIBTORCH_HOME' )
25
23
if (! LIBTORCH_HOME ) {
26
24
throw new RuntimeException (' LIBTORCH_HOME not present in environment.' );
27
25
}
28
- if (! file(LIBTORCH_HOME ). isDirectory()) {
29
- throw new RuntimeException (' LIBTORCH_HOME does not refer to a directory.' );
26
+ def BUILD_VERSION_FILE = new File (LIBTORCH_HOME , " build-version" );
27
+ if (! BUILD_VERSION_FILE . isFile()) {
28
+ throw new RuntimeException (
29
+ " Cannot find ${ BUILD_VERSION_FILE} . " +
30
+ " Make sure LIBTORCH_HOME refers to the root of the libtorch distribution." );
30
31
}
32
+ def installedVersion = BUILD_VERSION_FILE . readLines();
33
+ def versionPattern = " ^" + java.util.regex.Pattern . quote(LIBTORCH_VERSION ) + " \\ b.*" ;
34
+ if (! USE_NIGHTLY && ! (installedVersion[0 ] ==~ versionPattern)) {
35
+ throw new RuntimeException (
36
+ " Found libtorch version ${ installedVersion} , but build.gradle expects ${ LIBTORCH_VERSION} ." );
37
+ }
38
+
31
39
32
40
application {
33
41
mainClassName = ' demo.App'
0 commit comments